diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py
index 288b2e7..c6da949 100644
--- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py
+++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py
@@ -170,57 +170,53 @@ class SubprocessCLITransport(Transport):
         if not self._process or not self._stdout_stream:
             raise CLIConnectionError("Not connected")

-        stderr_lines = []
+        json_buffer = ""
+
+        # Process stdout messages first
+        try:
+            async for line in self._stdout_stream:
+                line_str = line.strip()
+                if not line_str:
+                    continue

-        async def read_stderr() -> None:
-            """Read stderr in background."""
-            if self._stderr_stream:
-                try:
-                    async for line in self._stderr_stream:
-                        stderr_lines.append(line.strip())
-                except anyio.ClosedResourceError:
-                    pass
+                json_lines = line_str.split("\n")

-        async with anyio.create_task_group() as tg:
-            tg.start_soon(read_stderr)
-
-            json_buffer = ""
-
-            try:
-                async for line in self._stdout_stream:
-                    line_str = line.strip()
-                    if not line_str:
+                for json_line in json_lines:
+                    json_line = json_line.strip()
+                    if not json_line:
                         continue

-                    json_lines = line_str.split("\n")
+                    # Keep accumulating partial JSON until we can parse it
+                    json_buffer += json_line

-                    for json_line in json_lines:
-                        json_line = json_line.strip()
-                        if not json_line:
-                            continue
-
-                        # Keep accumulating partial JSON until we can parse it
-                        json_buffer += json_line
-
-                        if len(json_buffer) > _MAX_BUFFER_SIZE:
-                            json_buffer = ""
-                            raise SDKJSONDecodeError(
-                                f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes",
-                                ValueError(
-                                    f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}"
-                                ),
-                            )
+                    if len(json_buffer) > _MAX_BUFFER_SIZE:
+                        json_buffer = ""
+                        raise SDKJSONDecodeError(
+                            f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes",
+                            ValueError(
+                                f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}"
+                            ),
+                        )

+                    try:
+                        data = json.loads(json_buffer)
+                        json_buffer = ""
                         try:
-                            data = json.loads(json_buffer)
-                            json_buffer = ""
-                            try:
-                                yield data
-                            except GeneratorExit:
-                                return
-                        except json.JSONDecodeError:
-                            continue
+                            yield data
+                        except GeneratorExit:
+                            return
+                    except json.JSONDecodeError:
+                        continue
+        except anyio.ClosedResourceError:
+            pass
+
+        # Read stderr after stdout completes (no concurrent task group)
+        stderr_lines = []
+        if self._stderr_stream:
+            try:
+                async for line in self._stderr_stream:
+                    stderr_lines.append(line.strip())
             except anyio.ClosedResourceError:
                 pass
diff --git a/tests/test_fastapi_streaming_compatibility.py b/tests/test_fastapi_streaming_compatibility.py
new file mode 100644
index 0000000..47a678c
--- /dev/null
+++ b/tests/test_fastapi_streaming_compatibility.py
@@ -0,0 +1,25 @@
+"""Test FastAPI streaming compatibility (issue #4 fix)."""
+
+import inspect
+
+from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport
+
+
+def test_no_task_groups_in_receive_messages():
+    """Verify receive_messages doesn't use task groups (fixes FastAPI issue #4)."""
+    # Get the source code of receive_messages
+    source = inspect.getsource(SubprocessCLITransport.receive_messages)
+
+    # The fix: ensure no task group or task creation
+    assert "create_task_group" not in source, (
+        "receive_messages must not use create_task_group to avoid "
+        "RuntimeError with FastAPI streaming"
+    )
+    assert "asyncio.create_task" not in source, (
+        "receive_messages must not create tasks to maintain "
+        "compatibility with FastAPI's generator handling"
+    )
+
+    # Verify stderr is still being read (sequential approach)
+    assert "_stderr_stream" in source, "Should still read stderr"
+    assert "stderr_lines" in source, "Should collect stderr output"
\ No newline at end of file
diff --git a/tests/test_subprocess_buffering.py b/tests/test_subprocess_buffering.py
index 3f9b0be..426d42e 100644
--- a/tests/test_subprocess_buffering.py
+++ b/tests/test_subprocess_buffering.py
@@ -255,9 +255,8 @@ class TestSubprocessBuffering:
                 async for msg in transport.receive_messages():
                     messages.append(msg)

-            assert len(exc_info.value.exceptions) == 1
-            assert isinstance(exc_info.value.exceptions[0], CLIJSONDecodeError)
-            assert "exceeded maximum buffer size" in str(exc_info.value.exceptions[0])
+            assert isinstance(exc_info.value, CLIJSONDecodeError)
+            assert "exceeded maximum buffer size" in str(exc_info.value)

         anyio.run(_test)
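
For context on what the new test protects: FastAPI/Starlette drives (and eventually closes) a streaming generator from its own task, and an anyio task group held open across yield points in receive_messages() raises a cancel-scope RuntimeError when that happens. The sketch below is illustrative only and not part of this diff; it assumes the public query() entry point of claude_code_sdk, and the /ask route, ask(), and event_stream() names are invented for the example.

# Hypothetical usage sketch (not part of this diff): stream SDK messages from a
# FastAPI endpoint. The /ask route, ask(), and event_stream() names are made up.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from claude_code_sdk import query

app = FastAPI()


@app.get("/ask")
async def ask(prompt: str) -> StreamingResponse:
    async def event_stream():
        # Starlette iterates and closes this generator from its own task.
        # With receive_messages() no longer holding an anyio cancel scope
        # open across yields, closing the stream here no longer raises
        # a cancel-scope RuntimeError.
        async for message in query(prompt=prompt):
            yield f"{type(message).__name__}\n"

    return StreamingResponse(event_stream(), media_type="text/plain")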