From 2d67166cae749d2d416dd110ab440c07f114b02e Mon Sep 17 00:00:00 2001
From: Carlos Cuevas <6290853+CarlosCuevas@users.noreply.github.com>
Date: Thu, 4 Dec 2025 17:27:01 -0500
Subject: [PATCH] fix: add write lock to prevent concurrent transport writes
 (#391)

## TL;DR

Adds a write lock to `SubprocessCLITransport` to prevent concurrent writes from parallel subagents.

---

## Overview

When multiple subagents run in parallel and invoke MCP tools, the CLI sends concurrent `control_request` messages, and each handler tries to write a response back to the subprocess stdin at the same time. Trio's `TextSendStream` does not allow concurrent sends from multiple tasks, so this raises `BusyResourceError`.

This PR adds an `anyio.Lock` around all write operations (`write()`, `end_input()`, and the stdin-closing part of `close()`). The lock serializes concurrent writes so they happen one at a time.

The `_ready` flag is now set inside the lock during `close()` to prevent a TOCTOU race where `write()` checks `_ready`, then `close()` clears it and closes the stream before `write()` actually sends data.
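
To make the pattern concrete, here is a minimal, self-contained sketch of the approach (illustrative only, not the SDK code; `LockedWriter` and the plain `RuntimeError` stand in for the real transport and its `CLIConnectionError`): the readiness check and the send run under the same lock that `close()` holds while it clears `_ready` and closes the stream.

```python
import anyio
from anyio.abc import ObjectSendStream


class LockedWriter:
    """Illustrative sketch of the locking pattern, not the actual transport."""

    def __init__(self, stream: ObjectSendStream[str]) -> None:
        self._stream = stream
        self._ready = True
        self._write_lock = anyio.Lock()  # shared by write(), end_input(), close()

    async def write(self, data: str) -> None:
        async with self._write_lock:
            # Check and send under one lock: close() cannot clear _ready and
            # close the stream between the check and the send.
            if not self._ready:
                raise RuntimeError("transport is not ready for writing")
            await self._stream.send(data)

    async def close(self) -> None:
        async with self._write_lock:
            self._ready = False  # cleared inside the lock to avoid the TOCTOU race
            await self._stream.aclose()
```

With this shape, concurrent writers simply queue on the lock instead of racing on the underlying stream.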
subprocess_cli.py:505"] --> B["acquire _write_lock
subprocess_cli.py:507"] B --> C["check _ready & stream
subprocess_cli.py:509"] C --> D["_stdin_stream.send()
subprocess_cli.py:523"] E["close()
subprocess_cli.py:458"] --> F["acquire _write_lock
subprocess_cli.py:478"] F --> G["set _ready = False
subprocess_cli.py:479"] G --> H["close _stdin_stream
subprocess_cli.py:481"] I["end_input()
subprocess_cli.py:531"] --> J["acquire _write_lock
subprocess_cli.py:533"] J --> K["close _stdin_stream
subprocess_cli.py:535"] ``` --- .../_internal/transport/subprocess_cli.py | 70 +++++---- tests/test_transport.py | 133 ++++++++++++++++++ 2 files changed, 167 insertions(+), 36 deletions(-) diff --git a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py index 26bd2ec..c7c7420 100644 --- a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py @@ -65,6 +65,7 @@ class SubprocessCLITransport(Transport): else _DEFAULT_MAX_BUFFER_SIZE ) self._temp_files: list[str] = [] # Track temporary files for cleanup + self._write_lock: anyio.Lock = anyio.Lock() def _find_cli(self) -> str: """Find Claude Code CLI binary.""" @@ -471,8 +472,6 @@ class SubprocessCLITransport(Transport): async def close(self) -> None: """Close the transport and clean up resources.""" - self._ready = False - # Clean up temporary files first (before early return) for temp_file in self._temp_files: with suppress(Exception): @@ -480,6 +479,7 @@ class SubprocessCLITransport(Transport): self._temp_files.clear() if not self._process: + self._ready = False return # Close stderr task group if active @@ -489,21 +489,19 @@ class SubprocessCLITransport(Transport): await self._stderr_task_group.__aexit__(None, None, None) self._stderr_task_group = None - # Close streams - if self._stdin_stream: - with suppress(Exception): - await self._stdin_stream.aclose() - self._stdin_stream = None + # Close stdin stream (acquire lock to prevent race with concurrent writes) + async with self._write_lock: + self._ready = False # Set inside lock to prevent TOCTOU with write() + if self._stdin_stream: + with suppress(Exception): + await self._stdin_stream.aclose() + self._stdin_stream = None if self._stderr_stream: with suppress(Exception): await self._stderr_stream.aclose() self._stderr_stream = None - if self._process.stdin: - with suppress(Exception): - await self._process.stdin.aclose() - # Terminate and wait for process if self._process.returncode is None: with suppress(ProcessLookupError): @@ -521,37 +519,37 @@ class SubprocessCLITransport(Transport): async def write(self, data: str) -> None: """Write raw data to the transport.""" - # Check if ready (like TypeScript) - if not self._ready or not self._stdin_stream: - raise CLIConnectionError("ProcessTransport is not ready for writing") + async with self._write_lock: + # All checks inside lock to prevent TOCTOU races with close()/end_input() + if not self._ready or not self._stdin_stream: + raise CLIConnectionError("ProcessTransport is not ready for writing") - # Check if process is still alive (like TypeScript) - if self._process and self._process.returncode is not None: - raise CLIConnectionError( - f"Cannot write to terminated process (exit code: {self._process.returncode})" - ) + if self._process and self._process.returncode is not None: + raise CLIConnectionError( + f"Cannot write to terminated process (exit code: {self._process.returncode})" + ) - # Check for exit errors (like TypeScript) - if self._exit_error: - raise CLIConnectionError( - f"Cannot write to process that exited with error: {self._exit_error}" - ) from self._exit_error + if self._exit_error: + raise CLIConnectionError( + f"Cannot write to process that exited with error: {self._exit_error}" + ) from self._exit_error - try: - await self._stdin_stream.send(data) - except Exception as e: - self._ready = False # Mark as not ready (like TypeScript) - self._exit_error = CLIConnectionError( - f"Failed to 
write to process stdin: {e}" - ) - raise self._exit_error from e + try: + await self._stdin_stream.send(data) + except Exception as e: + self._ready = False + self._exit_error = CLIConnectionError( + f"Failed to write to process stdin: {e}" + ) + raise self._exit_error from e async def end_input(self) -> None: """End the input stream (close stdin).""" - if self._stdin_stream: - with suppress(Exception): - await self._stdin_stream.aclose() - self._stdin_stream = None + async with self._write_lock: + if self._stdin_stream: + with suppress(Exception): + await self._stdin_stream.aclose() + self._stdin_stream = None def read_messages(self) -> AsyncIterator[dict[str, Any]]: """Read and parse messages from the transport.""" diff --git a/tests/test_transport.py b/tests/test_transport.py index c634fc2..fe9b6b2 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -693,3 +693,136 @@ class TestSubprocessCLITransport: cmd = transport._build_command() assert "--tools" not in cmd + + def test_concurrent_writes_are_serialized(self): + """Test that concurrent write() calls are serialized by the lock. + + When parallel subagents invoke MCP tools, they trigger concurrent write() + calls. Without the _write_lock, trio raises BusyResourceError. + + Uses a real subprocess with the same stream setup as production: + process.stdin -> TextSendStream + """ + + async def _test(): + import sys + from subprocess import PIPE + + from anyio.streams.text import TextSendStream + + # Create a real subprocess that consumes stdin (cross-platform) + process = await anyio.open_process( + [sys.executable, "-c", "import sys; sys.stdin.read()"], + stdin=PIPE, + stdout=PIPE, + stderr=PIPE, + ) + + try: + transport = SubprocessCLITransport( + prompt="test", + options=ClaudeAgentOptions(cli_path="/usr/bin/claude"), + ) + + # Same setup as production: TextSendStream wrapping process.stdin + transport._ready = True + transport._process = MagicMock(returncode=None) + transport._stdin_stream = TextSendStream(process.stdin) + + # Spawn concurrent writes - the lock should serialize them + num_writes = 10 + errors: list[Exception] = [] + + async def do_write(i: int): + try: + await transport.write(f'{{"msg": {i}}}\n') + except Exception as e: + errors.append(e) + + async with anyio.create_task_group() as tg: + for i in range(num_writes): + tg.start_soon(do_write, i) + + # All writes should succeed - the lock serializes them + assert len(errors) == 0, f"Got errors: {errors}" + finally: + process.terminate() + await process.wait() + + anyio.run(_test, backend="trio") + + def test_concurrent_writes_fail_without_lock(self): + """Verify that without the lock, concurrent writes cause BusyResourceError. + + Uses a real subprocess with the same stream setup as production. 
+        """
+
+        async def _test():
+            import sys
+            from contextlib import asynccontextmanager
+            from subprocess import PIPE
+
+            from anyio.streams.text import TextSendStream
+
+            # Create a real subprocess that consumes stdin (cross-platform)
+            process = await anyio.open_process(
+                [sys.executable, "-c", "import sys; sys.stdin.read()"],
+                stdin=PIPE,
+                stdout=PIPE,
+                stderr=PIPE,
+            )
+
+            try:
+                transport = SubprocessCLITransport(
+                    prompt="test",
+                    options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
+                )
+
+                # Same setup as production
+                transport._ready = True
+                transport._process = MagicMock(returncode=None)
+                transport._stdin_stream = TextSendStream(process.stdin)
+
+                # Replace lock with no-op to trigger the race condition
+                class NoOpLock:
+                    @asynccontextmanager
+                    async def __call__(self):
+                        yield
+
+                    async def __aenter__(self):
+                        return self
+
+                    async def __aexit__(self, *args):
+                        pass
+
+                transport._write_lock = NoOpLock()
+
+                # Spawn concurrent writes - should fail without lock
+                num_writes = 10
+                errors: list[Exception] = []
+
+                async def do_write(i: int):
+                    try:
+                        await transport.write(f'{{"msg": {i}}}\n')
+                    except Exception as e:
+                        errors.append(e)
+
+                async with anyio.create_task_group() as tg:
+                    for i in range(num_writes):
+                        tg.start_soon(do_write, i)
+
+                # Should have gotten errors due to concurrent access
+                assert len(errors) > 0, (
+                    "Expected errors from concurrent access, but got none"
+                )
+
+                # Check that at least one error mentions the concurrent access
+                error_strs = [str(e) for e in errors]
+                assert any("another task" in s for s in error_strs), (
+                    f"Expected 'another task' error, got: {error_strs}"
+                )
+            finally:
+                process.terminate()
+                await process.wait()
+
+        anyio.run(_test, backend="trio")