fix: add write lock to prevent concurrent transport writes (#391)

## TL;DR
Adds a write lock to `SubprocessCLITransport` to prevent concurrent
writes from parallel subagents.

---

## Overview
When multiple subagents run in parallel and invoke MCP tools, the CLI
sends concurrent `control_request` messages. Each handler tries to write
a response back to the subprocess stdin at the same time. Trio's
`TextSendStream` isn't thread-safe for concurrent access, so this causes
`BusyResourceError`.

This PR adds an `anyio.Lock` around all write operations (`write()`,
`end_input()`, and the stdin-closing part of `close()`). The lock
serializes concurrent writes so they happen one at a time. The `_ready`
flag is now set inside the lock during `close()` to prevent a TOCTOU
race where `write()` checks `_ready`, then `close()` sets it and closes
the stream before `write()` actually sends data.

---

## Call Flow
```mermaid
flowchart TD
    A["write()<br/>subprocess_cli.py:505"] --> B["acquire _write_lock<br/>subprocess_cli.py:507"]
    B --> C["check _ready & stream<br/>subprocess_cli.py:509"]
    C --> D["_stdin_stream.send()<br/>subprocess_cli.py:523"]
    
    E["close()<br/>subprocess_cli.py:458"] --> F["acquire _write_lock<br/>subprocess_cli.py:478"]
    F --> G["set _ready = False<br/>subprocess_cli.py:479"]
    G --> H["close _stdin_stream<br/>subprocess_cli.py:481"]
    
    I["end_input()<br/>subprocess_cli.py:531"] --> J["acquire _write_lock<br/>subprocess_cli.py:533"]
    J --> K["close _stdin_stream<br/>subprocess_cli.py:535"]
```
This commit is contained in:
Carlos Cuevas 2025-12-04 17:27:01 -05:00 committed by GitHub
parent 00332f32dc
commit 2d67166cae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 167 additions and 36 deletions

View file

@ -65,6 +65,7 @@ class SubprocessCLITransport(Transport):
else _DEFAULT_MAX_BUFFER_SIZE
)
self._temp_files: list[str] = [] # Track temporary files for cleanup
self._write_lock: anyio.Lock = anyio.Lock()
def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
@ -471,8 +472,6 @@ class SubprocessCLITransport(Transport):
async def close(self) -> None:
"""Close the transport and clean up resources."""
self._ready = False
# Clean up temporary files first (before early return)
for temp_file in self._temp_files:
with suppress(Exception):
@ -480,6 +479,7 @@ class SubprocessCLITransport(Transport):
self._temp_files.clear()
if not self._process:
self._ready = False
return
# Close stderr task group if active
@ -489,21 +489,19 @@ class SubprocessCLITransport(Transport):
await self._stderr_task_group.__aexit__(None, None, None)
self._stderr_task_group = None
# Close streams
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
# Close stdin stream (acquire lock to prevent race with concurrent writes)
async with self._write_lock:
self._ready = False # Set inside lock to prevent TOCTOU with write()
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
if self._stderr_stream:
with suppress(Exception):
await self._stderr_stream.aclose()
self._stderr_stream = None
if self._process.stdin:
with suppress(Exception):
await self._process.stdin.aclose()
# Terminate and wait for process
if self._process.returncode is None:
with suppress(ProcessLookupError):
@ -521,37 +519,37 @@ class SubprocessCLITransport(Transport):
async def write(self, data: str) -> None:
"""Write raw data to the transport."""
# Check if ready (like TypeScript)
if not self._ready or not self._stdin_stream:
raise CLIConnectionError("ProcessTransport is not ready for writing")
async with self._write_lock:
# All checks inside lock to prevent TOCTOU races with close()/end_input()
if not self._ready or not self._stdin_stream:
raise CLIConnectionError("ProcessTransport is not ready for writing")
# Check if process is still alive (like TypeScript)
if self._process and self._process.returncode is not None:
raise CLIConnectionError(
f"Cannot write to terminated process (exit code: {self._process.returncode})"
)
if self._process and self._process.returncode is not None:
raise CLIConnectionError(
f"Cannot write to terminated process (exit code: {self._process.returncode})"
)
# Check for exit errors (like TypeScript)
if self._exit_error:
raise CLIConnectionError(
f"Cannot write to process that exited with error: {self._exit_error}"
) from self._exit_error
if self._exit_error:
raise CLIConnectionError(
f"Cannot write to process that exited with error: {self._exit_error}"
) from self._exit_error
try:
await self._stdin_stream.send(data)
except Exception as e:
self._ready = False # Mark as not ready (like TypeScript)
self._exit_error = CLIConnectionError(
f"Failed to write to process stdin: {e}"
)
raise self._exit_error from e
try:
await self._stdin_stream.send(data)
except Exception as e:
self._ready = False
self._exit_error = CLIConnectionError(
f"Failed to write to process stdin: {e}"
)
raise self._exit_error from e
async def end_input(self) -> None:
"""End the input stream (close stdin)."""
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
async with self._write_lock:
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport."""

View file

@ -693,3 +693,136 @@ class TestSubprocessCLITransport:
cmd = transport._build_command()
assert "--tools" not in cmd
def test_concurrent_writes_are_serialized(self):
"""Test that concurrent write() calls are serialized by the lock.
When parallel subagents invoke MCP tools, they trigger concurrent write()
calls. Without the _write_lock, trio raises BusyResourceError.
Uses a real subprocess with the same stream setup as production:
process.stdin -> TextSendStream
"""
async def _test():
import sys
from subprocess import PIPE
from anyio.streams.text import TextSendStream
# Create a real subprocess that consumes stdin (cross-platform)
process = await anyio.open_process(
[sys.executable, "-c", "import sys; sys.stdin.read()"],
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
try:
transport = SubprocessCLITransport(
prompt="test",
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
)
# Same setup as production: TextSendStream wrapping process.stdin
transport._ready = True
transport._process = MagicMock(returncode=None)
transport._stdin_stream = TextSendStream(process.stdin)
# Spawn concurrent writes - the lock should serialize them
num_writes = 10
errors: list[Exception] = []
async def do_write(i: int):
try:
await transport.write(f'{{"msg": {i}}}\n')
except Exception as e:
errors.append(e)
async with anyio.create_task_group() as tg:
for i in range(num_writes):
tg.start_soon(do_write, i)
# All writes should succeed - the lock serializes them
assert len(errors) == 0, f"Got errors: {errors}"
finally:
process.terminate()
await process.wait()
anyio.run(_test, backend="trio")
def test_concurrent_writes_fail_without_lock(self):
"""Verify that without the lock, concurrent writes cause BusyResourceError.
Uses a real subprocess with the same stream setup as production.
"""
async def _test():
import sys
from contextlib import asynccontextmanager
from subprocess import PIPE
from anyio.streams.text import TextSendStream
# Create a real subprocess that consumes stdin (cross-platform)
process = await anyio.open_process(
[sys.executable, "-c", "import sys; sys.stdin.read()"],
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
try:
transport = SubprocessCLITransport(
prompt="test",
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
)
# Same setup as production
transport._ready = True
transport._process = MagicMock(returncode=None)
transport._stdin_stream = TextSendStream(process.stdin)
# Replace lock with no-op to trigger the race condition
class NoOpLock:
@asynccontextmanager
async def __call__(self):
yield
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
transport._write_lock = NoOpLock()
# Spawn concurrent writes - should fail without lock
num_writes = 10
errors: list[Exception] = []
async def do_write(i: int):
try:
await transport.write(f'{{"msg": {i}}}\n')
except Exception as e:
errors.append(e)
async with anyio.create_task_group() as tg:
for i in range(num_writes):
tg.start_soon(do_write, i)
# Should have gotten errors due to concurrent access
assert len(errors) > 0, (
"Expected errors from concurrent access, but got none"
)
# Check that at least one error mentions the concurrent access
error_strs = [str(e) for e in errors]
assert any("another task" in s for s in error_strs), (
f"Expected 'another task' error, got: {error_strs}"
)
finally:
process.terminate()
await process.wait()
anyio.run(_test, backend="trio")