diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 288b2e7..f4fbc58 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -1,6 +1,7 @@ """Subprocess transport implementation using Claude Code CLI.""" import json +import logging import os import shutil from collections.abc import AsyncIterator @@ -17,6 +18,8 @@ from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError from ...types import ClaudeCodeOptions from . import Transport +logger = logging.getLogger(__name__) + _MAX_BUFFER_SIZE = 1024 * 1024 # 1MB buffer limit @@ -170,69 +173,104 @@ class SubprocessCLITransport(Transport): if not self._process or not self._stdout_stream: raise CLIConnectionError("Not connected") - stderr_lines = [] + # Safety constants + max_stderr_size = 10 * 1024 * 1024 # 10MB + stderr_timeout = 30.0 # 30 seconds - async def read_stderr() -> None: - """Read stderr in background.""" - if self._stderr_stream: - try: - async for line in self._stderr_stream: - stderr_lines.append(line.strip()) - except anyio.ClosedResourceError: - pass + json_buffer = "" - async with anyio.create_task_group() as tg: - tg.start_soon(read_stderr) + # Process stdout messages first + try: + async for line in self._stdout_stream: + line_str = line.strip() + if not line_str: + continue - json_buffer = "" + json_lines = line_str.split("\n") - try: - async for line in self._stdout_stream: - line_str = line.strip() - if not line_str: + for json_line in json_lines: + json_line = json_line.strip() + if not json_line: continue - json_lines = line_str.split("\n") + # Keep accumulating partial JSON until we can parse it + json_buffer += json_line - for json_line in json_lines: - json_line = json_line.strip() - if not json_line: - continue - - # Keep accumulating partial JSON until we can parse it - json_buffer += json_line - - if len(json_buffer) > _MAX_BUFFER_SIZE: - json_buffer = "" - raise SDKJSONDecodeError( - f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes", - ValueError( - f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}" - ), - ) + if len(json_buffer) > _MAX_BUFFER_SIZE: + json_buffer = "" + raise SDKJSONDecodeError( + f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes", + ValueError( + f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}" + ), + ) + try: + data = json.loads(json_buffer) + json_buffer = "" try: - data = json.loads(json_buffer) - json_buffer = "" - try: - yield data - except GeneratorExit: - return - except json.JSONDecodeError: - continue + yield data + except GeneratorExit: + return + except json.JSONDecodeError: + continue + except anyio.ClosedResourceError: + pass + except GeneratorExit: + # Client disconnected - still need to clean up + pass + + # Process stderr with safety limits + stderr_lines = [] + stderr_size = 0 + + if self._stderr_stream: + try: + # Use timeout to prevent hanging + with anyio.fail_after(stderr_timeout): + async for line in self._stderr_stream: + line_text = line.strip() + line_size = len(line_text) + + # Enforce memory limit + if stderr_size + line_size > max_stderr_size: + stderr_lines.append( + f"[stderr truncated after {stderr_size} bytes]" + ) + # Drain rest of stream without storing + async for _ in self._stderr_stream: + pass + break + + stderr_lines.append(line_text) + stderr_size += line_size + + except TimeoutError: + stderr_lines.append( + f"[stderr collection timed out after {stderr_timeout}s]" + ) except anyio.ClosedResourceError: pass - await self._process.wait() - if self._process.returncode is not None and self._process.returncode != 0: - stderr_output = "\n".join(stderr_lines) - if stderr_output and "error" in stderr_output.lower(): - raise ProcessError( - "CLI process failed", - exit_code=self._process.returncode, - stderr=stderr_output, - ) + # Check process completion and handle errors + try: + returncode = await self._process.wait() + except Exception: + returncode = -1 + + stderr_output = "\n".join(stderr_lines) if stderr_lines else "" + + # Use exit code for error detection, not string matching + if returncode is not None and returncode != 0: + raise ProcessError( + f"Command failed with exit code {returncode}", + exit_code=returncode, + stderr=stderr_output, + ) + elif stderr_output: + # Log stderr for debugging but don't fail on non-zero exit + logger.debug(f"Process stderr: {stderr_output}") def is_connected(self) -> bool: """Check if subprocess is running.""" diff --git a/tests/test_subprocess_buffering.py b/tests/test_subprocess_buffering.py index 3f9b0be..426d42e 100644 --- a/tests/test_subprocess_buffering.py +++ b/tests/test_subprocess_buffering.py @@ -255,9 +255,8 @@ class TestSubprocessBuffering: async for msg in transport.receive_messages(): messages.append(msg) - assert len(exc_info.value.exceptions) == 1 - assert isinstance(exc_info.value.exceptions[0], CLIJSONDecodeError) - assert "exceeded maximum buffer size" in str(exc_info.value.exceptions[0]) + assert isinstance(exc_info.value, CLIJSONDecodeError) + assert "exceeded maximum buffer size" in str(exc_info.value) anyio.run(_test)