mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
fix: add write lock to prevent concurrent transport writes
When parallel subagents invoke MCP tools, the CLI sends multiple
concurrent control_request messages. Without synchronization, handlers
race to write responses back, causing trio.BusyResourceError.
This adds an anyio.Lock to serialize writes to stdin, and moves the
_ready flag inside the lock to prevent TOCTOU races with close().
🏠 Remote-Dev: homespace
This commit is contained in:
parent
00332f32dc
commit
a876feece1
2 changed files with 158 additions and 36 deletions
|
|
@ -65,6 +65,7 @@ class SubprocessCLITransport(Transport):
|
|||
else _DEFAULT_MAX_BUFFER_SIZE
|
||||
)
|
||||
self._temp_files: list[str] = [] # Track temporary files for cleanup
|
||||
self._write_lock: anyio.Lock = anyio.Lock()
|
||||
|
||||
def _find_cli(self) -> str:
|
||||
"""Find Claude Code CLI binary."""
|
||||
|
|
@ -471,8 +472,6 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport and clean up resources."""
|
||||
self._ready = False
|
||||
|
||||
# Clean up temporary files first (before early return)
|
||||
for temp_file in self._temp_files:
|
||||
with suppress(Exception):
|
||||
|
|
@ -480,6 +479,7 @@ class SubprocessCLITransport(Transport):
|
|||
self._temp_files.clear()
|
||||
|
||||
if not self._process:
|
||||
self._ready = False
|
||||
return
|
||||
|
||||
# Close stderr task group if active
|
||||
|
|
@ -489,21 +489,19 @@ class SubprocessCLITransport(Transport):
|
|||
await self._stderr_task_group.__aexit__(None, None, None)
|
||||
self._stderr_task_group = None
|
||||
|
||||
# Close streams
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
# Close stdin stream (acquire lock to prevent race with concurrent writes)
|
||||
async with self._write_lock:
|
||||
self._ready = False # Set inside lock to prevent TOCTOU with write()
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
|
||||
if self._stderr_stream:
|
||||
with suppress(Exception):
|
||||
await self._stderr_stream.aclose()
|
||||
self._stderr_stream = None
|
||||
|
||||
if self._process.stdin:
|
||||
with suppress(Exception):
|
||||
await self._process.stdin.aclose()
|
||||
|
||||
# Terminate and wait for process
|
||||
if self._process.returncode is None:
|
||||
with suppress(ProcessLookupError):
|
||||
|
|
@ -521,37 +519,37 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
async def write(self, data: str) -> None:
|
||||
"""Write raw data to the transport."""
|
||||
# Check if ready (like TypeScript)
|
||||
if not self._ready or not self._stdin_stream:
|
||||
raise CLIConnectionError("ProcessTransport is not ready for writing")
|
||||
async with self._write_lock:
|
||||
# All checks inside lock to prevent TOCTOU races with close()/end_input()
|
||||
if not self._ready or not self._stdin_stream:
|
||||
raise CLIConnectionError("ProcessTransport is not ready for writing")
|
||||
|
||||
# Check if process is still alive (like TypeScript)
|
||||
if self._process and self._process.returncode is not None:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to terminated process (exit code: {self._process.returncode})"
|
||||
)
|
||||
if self._process and self._process.returncode is not None:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to terminated process (exit code: {self._process.returncode})"
|
||||
)
|
||||
|
||||
# Check for exit errors (like TypeScript)
|
||||
if self._exit_error:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to process that exited with error: {self._exit_error}"
|
||||
) from self._exit_error
|
||||
if self._exit_error:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to process that exited with error: {self._exit_error}"
|
||||
) from self._exit_error
|
||||
|
||||
try:
|
||||
await self._stdin_stream.send(data)
|
||||
except Exception as e:
|
||||
self._ready = False # Mark as not ready (like TypeScript)
|
||||
self._exit_error = CLIConnectionError(
|
||||
f"Failed to write to process stdin: {e}"
|
||||
)
|
||||
raise self._exit_error from e
|
||||
try:
|
||||
await self._stdin_stream.send(data)
|
||||
except Exception as e:
|
||||
self._ready = False
|
||||
self._exit_error = CLIConnectionError(
|
||||
f"Failed to write to process stdin: {e}"
|
||||
)
|
||||
raise self._exit_error from e
|
||||
|
||||
async def end_input(self) -> None:
|
||||
"""End the input stream (close stdin)."""
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
async with self._write_lock:
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
|
||||
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Read and parse messages from the transport."""
|
||||
|
|
|
|||
|
|
@ -693,3 +693,127 @@ class TestSubprocessCLITransport:
|
|||
|
||||
cmd = transport._build_command()
|
||||
assert "--tools" not in cmd
|
||||
|
||||
def test_concurrent_writes_are_serialized(self):
|
||||
"""Test that concurrent write() calls are serialized by the lock.
|
||||
|
||||
When parallel subagents invoke MCP tools, they trigger concurrent write()
|
||||
calls. Without the _write_lock, trio raises BusyResourceError.
|
||||
|
||||
Uses the exact same stream chain as production:
|
||||
FdStream -> SendStreamWrapper -> TextSendStream
|
||||
"""
|
||||
|
||||
async def _test():
|
||||
from anyio._backends._trio import SendStreamWrapper
|
||||
from anyio.streams.text import TextSendStream
|
||||
from trio.lowlevel import FdStream
|
||||
|
||||
transport = SubprocessCLITransport(
|
||||
prompt="test",
|
||||
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
|
||||
)
|
||||
|
||||
# Create a pipe - FdStream is the same type used for process stdin
|
||||
read_fd, write_fd = os.pipe()
|
||||
|
||||
# Exact same wrapping as production: FdStream -> SendStreamWrapper -> TextSendStream
|
||||
fd_stream = FdStream(write_fd)
|
||||
transport._ready = True
|
||||
transport._process = MagicMock(returncode=None)
|
||||
transport._stdin_stream = TextSendStream(SendStreamWrapper(fd_stream))
|
||||
|
||||
# Spawn concurrent writes - the lock should serialize them
|
||||
num_writes = 10
|
||||
errors: list[Exception] = []
|
||||
|
||||
async def do_write(i: int):
|
||||
try:
|
||||
await transport.write(f'{{"msg": {i}}}\n')
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
for i in range(num_writes):
|
||||
tg.start_soon(do_write, i)
|
||||
|
||||
# All writes should succeed - the lock serializes them
|
||||
assert len(errors) == 0, f"Got errors: {errors}"
|
||||
finally:
|
||||
os.close(read_fd)
|
||||
await fd_stream.aclose()
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
|
||||
def test_concurrent_writes_fail_without_lock(self):
|
||||
"""Verify that without the lock, concurrent writes cause BusyResourceError.
|
||||
|
||||
Uses the exact same stream chain as production to prove the lock is necessary.
|
||||
"""
|
||||
|
||||
async def _test():
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from anyio._backends._trio import SendStreamWrapper
|
||||
from anyio.streams.text import TextSendStream
|
||||
from trio.lowlevel import FdStream
|
||||
|
||||
transport = SubprocessCLITransport(
|
||||
prompt="test",
|
||||
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
|
||||
)
|
||||
|
||||
# Create a pipe - FdStream is the same type used for process stdin
|
||||
read_fd, write_fd = os.pipe()
|
||||
|
||||
# Exact same wrapping as production
|
||||
fd_stream = FdStream(write_fd)
|
||||
transport._ready = True
|
||||
transport._process = MagicMock(returncode=None)
|
||||
transport._stdin_stream = TextSendStream(SendStreamWrapper(fd_stream))
|
||||
|
||||
# Replace lock with no-op to trigger the race condition
|
||||
class NoOpLock:
|
||||
@asynccontextmanager
|
||||
async def __call__(self):
|
||||
yield
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
transport._write_lock = NoOpLock()
|
||||
|
||||
# Spawn concurrent writes - should fail without lock
|
||||
num_writes = 10
|
||||
errors: list[Exception] = []
|
||||
|
||||
async def do_write(i: int):
|
||||
try:
|
||||
await transport.write(f'{{"msg": {i}}}\n')
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
for i in range(num_writes):
|
||||
tg.start_soon(do_write, i)
|
||||
|
||||
# Should have gotten errors due to concurrent access
|
||||
assert len(errors) > 0, (
|
||||
"Expected errors from concurrent access, but got none"
|
||||
)
|
||||
|
||||
# Check that at least one error mentions the concurrent access
|
||||
error_strs = [str(e) for e in errors]
|
||||
assert any("another task" in s for s in error_strs), (
|
||||
f"Expected 'another task' error, got: {error_strs}"
|
||||
)
|
||||
finally:
|
||||
os.close(read_fd)
|
||||
await fd_stream.aclose()
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue