Mirror of https://github.com/anthropics/claude-code-sdk-python.git, synced 2025-07-07 14:45:00 +00:00
Merge pull request #60 from anthropics/fix-fastapi-streaming-issue-4

Fix FastAPI SSE streaming compatibility

Commit e4de22eba4
2 changed files with 90 additions and 53 deletions
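The change targets the case where SDK messages are relayed to a browser over Server-Sent Events. Below is a minimal sketch of that setup; it is not part of this repository, and the endpoint path, the serialization, and the use of the SDK's query() entry point are illustrative assumptions. The relevant behavior is that when the HTTP client disconnects mid-stream, FastAPI closes the response and the async generator receives GeneratorExit, which then propagates into the transport's receive_messages() generator patched below.

# Illustrative FastAPI SSE endpoint (assumed names, not from this repository).
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from claude_code_sdk import query  # assumed public entry point

app = FastAPI()


@app.get("/stream")
async def stream(prompt: str) -> StreamingResponse:
    async def event_source():
        # A client disconnect raises GeneratorExit here and, transitively,
        # inside the SDK transport's receive_messages() generator.
        async for message in query(prompt=prompt):
            yield f"data: {message}\n\n"

    return StreamingResponse(event_source(), media_type="text/event-stream")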
Changed file 1 of 2: the subprocess CLI transport module.

@@ -1,6 +1,7 @@
 """Subprocess transport implementation using Claude Code CLI."""
 
 import json
+import logging
 import os
 import shutil
 from collections.abc import AsyncIterator
@@ -17,6 +18,8 @@ from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
 from ...types import ClaudeCodeOptions
 from . import Transport
 
+logger = logging.getLogger(__name__)
+
 _MAX_BUFFER_SIZE = 1024 * 1024  # 1MB buffer limit
 
 
@@ -170,69 +173,104 @@ class SubprocessCLITransport(Transport):
         if not self._process or not self._stdout_stream:
             raise CLIConnectionError("Not connected")
 
-        stderr_lines = []
-
-        async def read_stderr() -> None:
-            """Read stderr in background."""
-            if self._stderr_stream:
-                try:
-                    async for line in self._stderr_stream:
-                        stderr_lines.append(line.strip())
-                except anyio.ClosedResourceError:
-                    pass
-
-        async with anyio.create_task_group() as tg:
-            tg.start_soon(read_stderr)
-
-            json_buffer = ""
-
-            try:
-                async for line in self._stdout_stream:
-                    line_str = line.strip()
-                    if not line_str:
-                        continue
-
-                    json_lines = line_str.split("\n")
-
-                    for json_line in json_lines:
-                        json_line = json_line.strip()
-                        if not json_line:
-                            continue
-
-                        # Keep accumulating partial JSON until we can parse it
-                        json_buffer += json_line
-
-                        if len(json_buffer) > _MAX_BUFFER_SIZE:
-                            json_buffer = ""
-                            raise SDKJSONDecodeError(
-                                f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes",
-                                ValueError(
-                                    f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}"
-                                ),
-                            )
-
-                        try:
-                            data = json.loads(json_buffer)
-                            json_buffer = ""
-                            try:
-                                yield data
-                            except GeneratorExit:
-                                return
-                        except json.JSONDecodeError:
-                            continue
-
-            except anyio.ClosedResourceError:
-                pass
-
-        await self._process.wait()
-        if self._process.returncode is not None and self._process.returncode != 0:
-            stderr_output = "\n".join(stderr_lines)
-            if stderr_output and "error" in stderr_output.lower():
-                raise ProcessError(
-                    "CLI process failed",
-                    exit_code=self._process.returncode,
-                    stderr=stderr_output,
-                )
+        # Safety constants
+        max_stderr_size = 10 * 1024 * 1024  # 10MB
+        stderr_timeout = 30.0  # 30 seconds
+
+        json_buffer = ""
+
+        # Process stdout messages first
+        try:
+            async for line in self._stdout_stream:
+                line_str = line.strip()
+                if not line_str:
+                    continue
+
+                json_lines = line_str.split("\n")
+
+                for json_line in json_lines:
+                    json_line = json_line.strip()
+                    if not json_line:
+                        continue
+
+                    # Keep accumulating partial JSON until we can parse it
+                    json_buffer += json_line
+
+                    if len(json_buffer) > _MAX_BUFFER_SIZE:
+                        json_buffer = ""
+                        raise SDKJSONDecodeError(
+                            f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes",
+                            ValueError(
+                                f"Buffer size {len(json_buffer)} exceeds limit {_MAX_BUFFER_SIZE}"
+                            ),
+                        )
+
+                    try:
+                        data = json.loads(json_buffer)
+                        json_buffer = ""
+                        try:
+                            yield data
+                        except GeneratorExit:
+                            return
+                    except json.JSONDecodeError:
+                        continue
+
+        except anyio.ClosedResourceError:
+            pass
+        except GeneratorExit:
+            # Client disconnected - still need to clean up
+            pass
+
+        # Process stderr with safety limits
+        stderr_lines = []
+        stderr_size = 0
+
+        if self._stderr_stream:
+            try:
+                # Use timeout to prevent hanging
+                with anyio.fail_after(stderr_timeout):
+                    async for line in self._stderr_stream:
+                        line_text = line.strip()
+                        line_size = len(line_text)
+
+                        # Enforce memory limit
+                        if stderr_size + line_size > max_stderr_size:
+                            stderr_lines.append(
+                                f"[stderr truncated after {stderr_size} bytes]"
+                            )
+                            # Drain rest of stream without storing
+                            async for _ in self._stderr_stream:
+                                pass
+                            break
+
+                        stderr_lines.append(line_text)
+                        stderr_size += line_size
+
+            except TimeoutError:
+                stderr_lines.append(
+                    f"[stderr collection timed out after {stderr_timeout}s]"
+                )
+            except anyio.ClosedResourceError:
+                pass
+
+        # Check process completion and handle errors
+        try:
+            returncode = await self._process.wait()
+        except Exception:
+            returncode = -1
+
+        stderr_output = "\n".join(stderr_lines) if stderr_lines else ""
+
+        # Use exit code for error detection, not string matching
+        if returncode is not None and returncode != 0:
+            raise ProcessError(
+                f"Command failed with exit code {returncode}",
+                exit_code=returncode,
+                stderr=stderr_output,
+            )
+        elif stderr_output:
+            # Log stderr for debugging but don't fail on non-zero exit
+            logger.debug(f"Process stderr: {stderr_output}")
 
     def is_connected(self) -> bool:
         """Check if subprocess is running."""
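For readers who want the buffering strategy in isolation, here is a standalone sketch of the technique the hunk above applies inside receive_messages(): accumulate stripped fragments until json.loads() succeeds, reset the buffer on success, keep accumulating on JSONDecodeError, and cap the buffer at the 1 MB limit. The helper name and the synchronous iterator are illustrative; nothing below is imported from the SDK.

import json
from collections.abc import Iterable, Iterator
from typing import Any

_MAX_BUFFER_SIZE = 1024 * 1024  # 1MB buffer limit, mirroring the diff


def parse_json_lines(lines: Iterable[str]) -> Iterator[Any]:
    """Yield JSON objects from stream lines that may be split or embedded."""
    json_buffer = ""
    for raw in lines:
        for piece in raw.strip().split("\n"):
            piece = piece.strip()
            if not piece:
                continue
            # Accumulate until the buffer parses as a complete JSON document.
            json_buffer += piece
            if len(json_buffer) > _MAX_BUFFER_SIZE:
                raise ValueError(
                    f"JSON message exceeded maximum buffer size of {_MAX_BUFFER_SIZE} bytes"
                )
            try:
                data = json.loads(json_buffer)
            except json.JSONDecodeError:
                continue  # partial message; keep accumulating
            json_buffer = ""
            yield data


# One logical message split across two stream reads parses as a single object.
assert list(parse_json_lines(['{"type": "res', 'ult", "ok": true}'])) == [
    {"type": "result", "ok": True}
]

The other notable change in this hunk is error detection: instead of substring-matching "error" in stderr, the new code raises ProcessError whenever the CLI exits with a non-zero code and only logs stderr at debug level otherwise, so warnings on stderr no longer mask or trigger failures.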
Changed file 2 of 2: the subprocess buffering tests.

@@ -255,9 +255,8 @@ class TestSubprocessBuffering:
                 async for msg in transport.receive_messages():
                     messages.append(msg)
 
-            assert len(exc_info.value.exceptions) == 1
-            assert isinstance(exc_info.value.exceptions[0], CLIJSONDecodeError)
-            assert "exceeded maximum buffer size" in str(exc_info.value.exceptions[0])
+            assert isinstance(exc_info.value, CLIJSONDecodeError)
+            assert "exceeded maximum buffer size" in str(exc_info.value)
 
         anyio.run(_test)
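The assertion change follows from the transport change: the old implementation raised CLIJSONDecodeError from inside an anyio task group, so the test saw it wrapped in an exception group, while the new implementation raises it directly from the generator. A minimal, self-contained sketch of that difference follows; it assumes anyio >= 4 (where task groups wrap exceptions in exception groups by default) and Python 3.11+ for the built-in ExceptionGroup, and uses a stand-in exception rather than the SDK's own error types.

import anyio
import pytest


class FakeDecodeError(Exception):
    """Stand-in for CLIJSONDecodeError."""


async def raise_inside_task_group() -> None:
    # Old shape: the failure surfaces through a task group, so callers
    # see an ExceptionGroup wrapping the real error.
    async with anyio.create_task_group():
        raise FakeDecodeError("exceeded maximum buffer size")


async def raise_directly() -> None:
    # New shape: no task group, the exception propagates unwrapped.
    raise FakeDecodeError("exceeded maximum buffer size")


def test_wrapped_vs_unwrapped() -> None:
    with pytest.raises(ExceptionGroup) as group_info:
        anyio.run(raise_inside_task_group)
    assert isinstance(group_info.value.exceptions[0], FakeDecodeError)

    with pytest.raises(FakeDecodeError) as direct_info:
        anyio.run(raise_directly)
    assert "exceeded maximum buffer size" in str(direct_info.value)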