Merge pull request #60 from anthropics/fix-fastapi-streaming-issue-4

Fix FastAPI SSE streaming compatibility
This commit is contained in:
Lina Tawfik 2025-07-01 20:40:12 -07:00 committed by GitHub
commit e4de22eba4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 90 additions and 53 deletions

View file

@ -1,6 +1,7 @@
"""Subprocess transport implementation using Claude Code CLI.""" """Subprocess transport implementation using Claude Code CLI."""
import json import json
import logging
import os import os
import shutil import shutil
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@ -17,6 +18,8 @@ from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
from ...types import ClaudeCodeOptions from ...types import ClaudeCodeOptions
from . import Transport from . import Transport
logger = logging.getLogger(__name__)
_MAX_BUFFER_SIZE = 1024 * 1024 # 1MB buffer limit _MAX_BUFFER_SIZE = 1024 * 1024 # 1MB buffer limit
@ -170,69 +173,104 @@ class SubprocessCLITransport(Transport):
if not self._process or not self._stdout_stream: if not self._process or not self._stdout_stream:
raise CLIConnectionError("Not connected") raise CLIConnectionError("Not connected")
stderr_lines = [] # Safety constants
max_stderr_size = 10 * 1024 * 1024 # 10MB
stderr_timeout = 30.0 # 30 seconds
async def read_stderr() -> None: json_buffer = ""
"""Read stderr in background."""
if self._stderr_stream:
try:
async for line in self._stderr_stream:
stderr_lines.append(line.strip())
except anyio.ClosedResourceError:
pass
async with anyio.create_task_group() as tg: # Process stdout messages first
tg.start_soon(read_stderr) try:
async for line in self._stdout_stream:
line_str = line.strip()
if not line_str:
continue
json_buffer = "" json_lines = line_str.split("\n")
try: for json_line in json_lines:
async for line in self._stdout_stream: json_line = json_line.strip()
line_str = line.strip() if not json_line:
if not line_str:
continue continue
json_lines = line_str.split("\n") # Keep accumulating partial JSON until we can parse it
json_buffer += json_line
for json_line in json_lines: if len(json_buffer) > _MAX_BUFFER_SIZE:
json_line = json_line.strip() json_buffer = ""
if not json_line: raise SDKJSONDecodeError(
continue f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes",
ValueError(
# Keep accumulating partial JSON until we can parse it f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}"
json_buffer += json_line ),
)
if len(json_buffer) > _MAX_BUFFER_SIZE:
json_buffer = ""
raise SDKJSONDecodeError(
f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes",
ValueError(
f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}"
),
)
try:
data = json.loads(json_buffer)
json_buffer = ""
try: try:
data = json.loads(json_buffer) yield data
json_buffer = "" except GeneratorExit:
try: return
yield data except json.JSONDecodeError:
except GeneratorExit: continue
return
except json.JSONDecodeError:
continue
except anyio.ClosedResourceError:
pass
except GeneratorExit:
# Client disconnected - still need to clean up
pass
# Process stderr with safety limits
stderr_lines = []
stderr_size = 0
if self._stderr_stream:
try:
# Use timeout to prevent hanging
with anyio.fail_after(stderr_timeout):
async for line in self._stderr_stream:
line_text = line.strip()
line_size = len(line_text)
# Enforce memory limit
if stderr_size + line_size > max_stderr_size:
stderr_lines.append(
f"[stderr truncated after {stderr_size} bytes]"
)
# Drain rest of stream without storing
async for _ in self._stderr_stream:
pass
break
stderr_lines.append(line_text)
stderr_size += line_size
except TimeoutError:
stderr_lines.append(
f"[stderr collection timed out after {stderr_timeout}s]"
)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
pass pass
await self._process.wait() # Check process completion and handle errors
if self._process.returncode is not None and self._process.returncode != 0: try:
stderr_output = "\n".join(stderr_lines) returncode = await self._process.wait()
if stderr_output and "error" in stderr_output.lower(): except Exception:
raise ProcessError( returncode = -1
"CLI process failed",
exit_code=self._process.returncode, stderr_output = "\n".join(stderr_lines) if stderr_lines else ""
stderr=stderr_output,
) # Use exit code for error detection, not string matching
if returncode is not None and returncode != 0:
raise ProcessError(
f"Command failed with exit code {returncode}",
exit_code=returncode,
stderr=stderr_output,
)
elif stderr_output:
# Log stderr for debugging; stderr output alone is not a failure when the exit code is zero
logger.debug(f"Process stderr: {stderr_output}")
def is_connected(self) -> bool: def is_connected(self) -> bool:
"""Check if subprocess is running.""" """Check if subprocess is running."""

View file

@ -255,9 +255,8 @@ class TestSubprocessBuffering:
async for msg in transport.receive_messages(): async for msg in transport.receive_messages():
messages.append(msg) messages.append(msg)
assert len(exc_info.value.exceptions) == 1 assert isinstance(exc_info.value, CLIJSONDecodeError)
assert isinstance(exc_info.value.exceptions[0], CLIJSONDecodeError) assert "exceeded maximum buffer size" in str(exc_info.value)
assert "exceeded maximum buffer size" in str(exc_info.value.exceptions[0])
anyio.run(_test) anyio.run(_test)