fix: add write lock to prevent concurrent transport writes

When parallel subagents invoke MCP tools, the CLI sends multiple
concurrent control_request messages. Without synchronization, handlers
race to write responses back, causing trio.BusyResourceError.

This adds an anyio.Lock to serialize writes to stdin, and moves the
_ready flag inside the lock to prevent TOCTOU races with close().

🏠 Remote-Dev: homespace
This commit is contained in:
Carlos Cuevas 2025-11-25 02:15:08 +00:00
parent 00332f32dc
commit a876feece1
No known key found for this signature in database
2 changed files with 158 additions and 36 deletions

View file

@ -65,6 +65,7 @@ class SubprocessCLITransport(Transport):
else _DEFAULT_MAX_BUFFER_SIZE
)
self._temp_files: list[str] = [] # Track temporary files for cleanup
self._write_lock: anyio.Lock = anyio.Lock()
def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
@ -471,8 +472,6 @@ class SubprocessCLITransport(Transport):
async def close(self) -> None:
"""Close the transport and clean up resources."""
self._ready = False
# Clean up temporary files first (before early return)
for temp_file in self._temp_files:
with suppress(Exception):
@ -480,6 +479,7 @@ class SubprocessCLITransport(Transport):
self._temp_files.clear()
if not self._process:
self._ready = False
return
# Close stderr task group if active
@ -489,21 +489,19 @@ class SubprocessCLITransport(Transport):
await self._stderr_task_group.__aexit__(None, None, None)
self._stderr_task_group = None
# Close streams
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
# Close stdin stream (acquire lock to prevent race with concurrent writes)
async with self._write_lock:
self._ready = False # Set inside lock to prevent TOCTOU with write()
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
if self._stderr_stream:
with suppress(Exception):
await self._stderr_stream.aclose()
self._stderr_stream = None
if self._process.stdin:
with suppress(Exception):
await self._process.stdin.aclose()
# Terminate and wait for process
if self._process.returncode is None:
with suppress(ProcessLookupError):
@ -521,37 +519,37 @@ class SubprocessCLITransport(Transport):
async def write(self, data: str) -> None:
"""Write raw data to the transport."""
# Check if ready (like TypeScript)
if not self._ready or not self._stdin_stream:
raise CLIConnectionError("ProcessTransport is not ready for writing")
async with self._write_lock:
# All checks inside lock to prevent TOCTOU races with close()/end_input()
if not self._ready or not self._stdin_stream:
raise CLIConnectionError("ProcessTransport is not ready for writing")
# Check if process is still alive (like TypeScript)
if self._process and self._process.returncode is not None:
raise CLIConnectionError(
f"Cannot write to terminated process (exit code: {self._process.returncode})"
)
if self._process and self._process.returncode is not None:
raise CLIConnectionError(
f"Cannot write to terminated process (exit code: {self._process.returncode})"
)
# Check for exit errors (like TypeScript)
if self._exit_error:
raise CLIConnectionError(
f"Cannot write to process that exited with error: {self._exit_error}"
) from self._exit_error
if self._exit_error:
raise CLIConnectionError(
f"Cannot write to process that exited with error: {self._exit_error}"
) from self._exit_error
try:
await self._stdin_stream.send(data)
except Exception as e:
self._ready = False # Mark as not ready (like TypeScript)
self._exit_error = CLIConnectionError(
f"Failed to write to process stdin: {e}"
)
raise self._exit_error from e
try:
await self._stdin_stream.send(data)
except Exception as e:
self._ready = False
self._exit_error = CLIConnectionError(
f"Failed to write to process stdin: {e}"
)
raise self._exit_error from e
async def end_input(self) -> None:
"""End the input stream (close stdin)."""
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
async with self._write_lock:
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport."""

View file

@ -693,3 +693,127 @@ class TestSubprocessCLITransport:
cmd = transport._build_command()
assert "--tools" not in cmd
def test_concurrent_writes_are_serialized(self):
"""Test that concurrent write() calls are serialized by the lock.
When parallel subagents invoke MCP tools, they trigger concurrent write()
calls. Without the _write_lock, trio raises BusyResourceError.
Uses the exact same stream chain as production:
FdStream -> SendStreamWrapper -> TextSendStream
"""
async def _test():
from anyio._backends._trio import SendStreamWrapper
from anyio.streams.text import TextSendStream
from trio.lowlevel import FdStream
transport = SubprocessCLITransport(
prompt="test",
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
)
# Create a pipe - FdStream is the same type used for process stdin
read_fd, write_fd = os.pipe()
# Exact same wrapping as production: FdStream -> SendStreamWrapper -> TextSendStream
fd_stream = FdStream(write_fd)
transport._ready = True
transport._process = MagicMock(returncode=None)
transport._stdin_stream = TextSendStream(SendStreamWrapper(fd_stream))
# Spawn concurrent writes - the lock should serialize them
num_writes = 10
errors: list[Exception] = []
async def do_write(i: int):
try:
await transport.write(f'{{"msg": {i}}}\n')
except Exception as e:
errors.append(e)
try:
async with anyio.create_task_group() as tg:
for i in range(num_writes):
tg.start_soon(do_write, i)
# All writes should succeed - the lock serializes them
assert len(errors) == 0, f"Got errors: {errors}"
finally:
os.close(read_fd)
await fd_stream.aclose()
anyio.run(_test, backend="trio")
def test_concurrent_writes_fail_without_lock(self):
"""Verify that without the lock, concurrent writes cause BusyResourceError.
Uses the exact same stream chain as production to prove the lock is necessary.
"""
async def _test():
from contextlib import asynccontextmanager
from anyio._backends._trio import SendStreamWrapper
from anyio.streams.text import TextSendStream
from trio.lowlevel import FdStream
transport = SubprocessCLITransport(
prompt="test",
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
)
# Create a pipe - FdStream is the same type used for process stdin
read_fd, write_fd = os.pipe()
# Exact same wrapping as production
fd_stream = FdStream(write_fd)
transport._ready = True
transport._process = MagicMock(returncode=None)
transport._stdin_stream = TextSendStream(SendStreamWrapper(fd_stream))
# Replace lock with no-op to trigger the race condition
class NoOpLock:
@asynccontextmanager
async def __call__(self):
yield
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
transport._write_lock = NoOpLock()
# Spawn concurrent writes - should fail without lock
num_writes = 10
errors: list[Exception] = []
async def do_write(i: int):
try:
await transport.write(f'{{"msg": {i}}}\n')
except Exception as e:
errors.append(e)
try:
async with anyio.create_task_group() as tg:
for i in range(num_writes):
tg.start_soon(do_write, i)
# Should have gotten errors due to concurrent access
assert len(errors) > 0, (
"Expected errors from concurrent access, but got none"
)
# Check that at least one error mentions the concurrent access
error_strs = [str(e) for e in errors]
assert any("another task" in s for s in error_strs), (
f"Expected 'another task' error, got: {error_strs}"
)
finally:
os.close(read_fd)
await fd_stream.aclose()
anyio.run(_test, backend="trio")