mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
fix: add write lock to prevent concurrent transport writes (#391)
## TL;DR
Adds a write lock to `SubprocessCLITransport` to prevent concurrent
writes from parallel subagents.
---
## Overview
When multiple subagents run in parallel and invoke MCP tools, the CLI
sends concurrent `control_request` messages. Each handler tries to write
a response back to the subprocess stdin at the same time. anyio's
`TextSendStream` does not support concurrent sends from multiple tasks,
so under the trio backend this raises `BusyResourceError`.
This PR adds an `anyio.Lock` around all write operations (`write()`,
`end_input()`, and the stdin-closing part of `close()`). The lock
serializes concurrent writes so they happen one at a time. The `_ready`
flag is now set to `False` inside the lock during `close()` to prevent a
TOCTOU (time-of-check to time-of-use) race where `write()` checks `_ready`,
then `close()` clears it and closes the stream before `write()` actually
sends data.
---
## Call Flow
```mermaid
flowchart TD
A["write()<br/>subprocess_cli.py:505"] --> B["acquire _write_lock<br/>subprocess_cli.py:507"]
B --> C["check _ready & stream<br/>subprocess_cli.py:509"]
C --> D["_stdin_stream.send()<br/>subprocess_cli.py:523"]
E["close()<br/>subprocess_cli.py:458"] --> F["acquire _write_lock<br/>subprocess_cli.py:478"]
F --> G["set _ready = False<br/>subprocess_cli.py:479"]
G --> H["close _stdin_stream<br/>subprocess_cli.py:481"]
I["end_input()<br/>subprocess_cli.py:531"] --> J["acquire _write_lock<br/>subprocess_cli.py:533"]
J --> K["close _stdin_stream<br/>subprocess_cli.py:535"]
```
This commit is contained in:
parent
00332f32dc
commit
2d67166cae
2 changed files with 167 additions and 36 deletions
|
|
@ -65,6 +65,7 @@ class SubprocessCLITransport(Transport):
|
|||
else _DEFAULT_MAX_BUFFER_SIZE
|
||||
)
|
||||
self._temp_files: list[str] = [] # Track temporary files for cleanup
|
||||
self._write_lock: anyio.Lock = anyio.Lock()
|
||||
|
||||
def _find_cli(self) -> str:
|
||||
"""Find Claude Code CLI binary."""
|
||||
|
|
@ -471,8 +472,6 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport and clean up resources."""
|
||||
self._ready = False
|
||||
|
||||
# Clean up temporary files first (before early return)
|
||||
for temp_file in self._temp_files:
|
||||
with suppress(Exception):
|
||||
|
|
@ -480,6 +479,7 @@ class SubprocessCLITransport(Transport):
|
|||
self._temp_files.clear()
|
||||
|
||||
if not self._process:
|
||||
self._ready = False
|
||||
return
|
||||
|
||||
# Close stderr task group if active
|
||||
|
|
@ -489,21 +489,19 @@ class SubprocessCLITransport(Transport):
|
|||
await self._stderr_task_group.__aexit__(None, None, None)
|
||||
self._stderr_task_group = None
|
||||
|
||||
# Close streams
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
# Close stdin stream (acquire lock to prevent race with concurrent writes)
|
||||
async with self._write_lock:
|
||||
self._ready = False # Set inside lock to prevent TOCTOU with write()
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
|
||||
if self._stderr_stream:
|
||||
with suppress(Exception):
|
||||
await self._stderr_stream.aclose()
|
||||
self._stderr_stream = None
|
||||
|
||||
if self._process.stdin:
|
||||
with suppress(Exception):
|
||||
await self._process.stdin.aclose()
|
||||
|
||||
# Terminate and wait for process
|
||||
if self._process.returncode is None:
|
||||
with suppress(ProcessLookupError):
|
||||
|
|
@ -521,37 +519,37 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
async def write(self, data: str) -> None:
|
||||
"""Write raw data to the transport."""
|
||||
# Check if ready (like TypeScript)
|
||||
if not self._ready or not self._stdin_stream:
|
||||
raise CLIConnectionError("ProcessTransport is not ready for writing")
|
||||
async with self._write_lock:
|
||||
# All checks inside lock to prevent TOCTOU races with close()/end_input()
|
||||
if not self._ready or not self._stdin_stream:
|
||||
raise CLIConnectionError("ProcessTransport is not ready for writing")
|
||||
|
||||
# Check if process is still alive (like TypeScript)
|
||||
if self._process and self._process.returncode is not None:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to terminated process (exit code: {self._process.returncode})"
|
||||
)
|
||||
if self._process and self._process.returncode is not None:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to terminated process (exit code: {self._process.returncode})"
|
||||
)
|
||||
|
||||
# Check for exit errors (like TypeScript)
|
||||
if self._exit_error:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to process that exited with error: {self._exit_error}"
|
||||
) from self._exit_error
|
||||
if self._exit_error:
|
||||
raise CLIConnectionError(
|
||||
f"Cannot write to process that exited with error: {self._exit_error}"
|
||||
) from self._exit_error
|
||||
|
||||
try:
|
||||
await self._stdin_stream.send(data)
|
||||
except Exception as e:
|
||||
self._ready = False # Mark as not ready (like TypeScript)
|
||||
self._exit_error = CLIConnectionError(
|
||||
f"Failed to write to process stdin: {e}"
|
||||
)
|
||||
raise self._exit_error from e
|
||||
try:
|
||||
await self._stdin_stream.send(data)
|
||||
except Exception as e:
|
||||
self._ready = False
|
||||
self._exit_error = CLIConnectionError(
|
||||
f"Failed to write to process stdin: {e}"
|
||||
)
|
||||
raise self._exit_error from e
|
||||
|
||||
async def end_input(self) -> None:
|
||||
"""End the input stream (close stdin)."""
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
async with self._write_lock:
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
|
||||
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Read and parse messages from the transport."""
|
||||
|
|
|
|||
|
|
@ -693,3 +693,136 @@ class TestSubprocessCLITransport:
|
|||
|
||||
cmd = transport._build_command()
|
||||
assert "--tools" not in cmd
|
||||
|
||||
def test_concurrent_writes_are_serialized(self):
|
||||
"""Test that concurrent write() calls are serialized by the lock.
|
||||
|
||||
When parallel subagents invoke MCP tools, they trigger concurrent write()
|
||||
calls. Without the _write_lock, trio raises BusyResourceError.
|
||||
|
||||
Uses a real subprocess with the same stream setup as production:
|
||||
process.stdin -> TextSendStream
|
||||
"""
|
||||
|
||||
async def _test():
|
||||
import sys
|
||||
from subprocess import PIPE
|
||||
|
||||
from anyio.streams.text import TextSendStream
|
||||
|
||||
# Create a real subprocess that consumes stdin (cross-platform)
|
||||
process = await anyio.open_process(
|
||||
[sys.executable, "-c", "import sys; sys.stdin.read()"],
|
||||
stdin=PIPE,
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
)
|
||||
|
||||
try:
|
||||
transport = SubprocessCLITransport(
|
||||
prompt="test",
|
||||
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
|
||||
)
|
||||
|
||||
# Same setup as production: TextSendStream wrapping process.stdin
|
||||
transport._ready = True
|
||||
transport._process = MagicMock(returncode=None)
|
||||
transport._stdin_stream = TextSendStream(process.stdin)
|
||||
|
||||
# Spawn concurrent writes - the lock should serialize them
|
||||
num_writes = 10
|
||||
errors: list[Exception] = []
|
||||
|
||||
async def do_write(i: int):
|
||||
try:
|
||||
await transport.write(f'{{"msg": {i}}}\n')
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for i in range(num_writes):
|
||||
tg.start_soon(do_write, i)
|
||||
|
||||
# All writes should succeed - the lock serializes them
|
||||
assert len(errors) == 0, f"Got errors: {errors}"
|
||||
finally:
|
||||
process.terminate()
|
||||
await process.wait()
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
|
||||
def test_concurrent_writes_fail_without_lock(self):
|
||||
"""Verify that without the lock, concurrent writes cause BusyResourceError.
|
||||
|
||||
Uses a real subprocess with the same stream setup as production.
|
||||
"""
|
||||
|
||||
async def _test():
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from subprocess import PIPE
|
||||
|
||||
from anyio.streams.text import TextSendStream
|
||||
|
||||
# Create a real subprocess that consumes stdin (cross-platform)
|
||||
process = await anyio.open_process(
|
||||
[sys.executable, "-c", "import sys; sys.stdin.read()"],
|
||||
stdin=PIPE,
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
)
|
||||
|
||||
try:
|
||||
transport = SubprocessCLITransport(
|
||||
prompt="test",
|
||||
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
|
||||
)
|
||||
|
||||
# Same setup as production
|
||||
transport._ready = True
|
||||
transport._process = MagicMock(returncode=None)
|
||||
transport._stdin_stream = TextSendStream(process.stdin)
|
||||
|
||||
# Replace lock with no-op to trigger the race condition
|
||||
class NoOpLock:
|
||||
@asynccontextmanager
|
||||
async def __call__(self):
|
||||
yield
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
transport._write_lock = NoOpLock()
|
||||
|
||||
# Spawn concurrent writes - should fail without lock
|
||||
num_writes = 10
|
||||
errors: list[Exception] = []
|
||||
|
||||
async def do_write(i: int):
|
||||
try:
|
||||
await transport.write(f'{{"msg": {i}}}\n')
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for i in range(num_writes):
|
||||
tg.start_soon(do_write, i)
|
||||
|
||||
# Should have gotten errors due to concurrent access
|
||||
assert len(errors) > 0, (
|
||||
"Expected errors from concurrent access, but got none"
|
||||
)
|
||||
|
||||
# Check that at least one error mentions the concurrent access
|
||||
error_strs = [str(e) for e in errors]
|
||||
assert any("another task" in s for s in error_strs), (
|
||||
f"Expected 'another task' error, got: {error_strs}"
|
||||
)
|
||||
finally:
|
||||
process.terminate()
|
||||
await process.wait()
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue