fix: improve concurrent write test

- Move json import to top of file
- Simplify test by directly patching _stdin_stream.send
- Remove unnecessary call to original_send

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Ashwin Bhat 2025-11-20 10:14:14 -08:00
parent a914f05721
commit aa437fd298
No known key found for this signature in database

View file

@ -1,5 +1,6 @@
"""Tests for Claude SDK transport layer."""
import json
import os
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
@ -317,7 +318,6 @@ class TestSubprocessCLITransport:
def test_build_command_with_mcp_servers(self):
"""Test building CLI command with mcp_servers option."""
import json
mcp_servers = {
"test-server": {
@ -503,7 +503,6 @@ class TestSubprocessCLITransport:
def test_concurrent_writes_are_serialized(self):
"""Test that concurrent writes to stdin are serialized by the lock."""
import json
async def _test():
with patch("anyio.open_process") as mock_exec:
@ -520,27 +519,11 @@ class TestSubprocessCLITransport:
mock_process = MagicMock()
mock_process.returncode = None
mock_process.stdout = MagicMock()
mock_process.stdin = MagicMock()
mock_process.stdin.aclose = AsyncMock()
# Track write calls to verify serialization
write_events: list[tuple[str, str]] = [] # (event_type, data_preview)
async def mock_send(data: str):
"""Mock send that tracks start/end of each write."""
preview = data[:20]
write_events.append(("start", preview))
# Small delay to increase chance of interleaving if lock is broken
await anyio.sleep(0.001)
write_events.append(("end", preview))
mock_stdin = MagicMock()
mock_stdin.send = mock_send
mock_stdin.aclose = AsyncMock()
mock_process.stdin = mock_stdin
# Return version process first, then main process
mock_exec.side_effect = [mock_version_process, mock_process]
# Create transport in streaming mode (required for stdin writes)
async def dummy_stream():
yield {"type": "user"}
@ -552,30 +535,32 @@ class TestSubprocessCLITransport:
await transport.connect()
assert transport.is_ready()
# Spawn multiple concurrent writes
# Track write events to verify serialization
write_events: list[tuple[str, int]] = []
async def tracked_send(data: str):
msg_id = json.loads(data.strip())["id"]
write_events.append(("start", msg_id))
await anyio.sleep(0.001) # Yield to expose interleaving
write_events.append(("end", msg_id))
transport._stdin_stream.send = tracked_send
# Spawn concurrent writes
num_writes = 5
messages = [
json.dumps({"id": i, "data": "x" * 50}) + "\n"
for i in range(num_writes)
]
messages = [json.dumps({"id": i}) + "\n" for i in range(num_writes)]
async with anyio.create_task_group() as tg:
for msg in messages:
tg.start_soon(transport.write, msg)
# Verify serialization: events should be strictly alternating start/end pairs
# If writes interleaved, we'd see patterns like: start, start, end, end
# Verify: each start must be immediately followed by its end
# If interleaved, we'd see: start(0), start(1), end(0), end(1)
assert len(write_events) == num_writes * 2
for i in range(0, len(write_events), 2):
assert write_events[i][0] == "start", f"Expected start at index {i}"
assert write_events[i + 1][0] == "end", (
f"Expected end at index {i + 1}"
)
# Each start should be followed by its corresponding end
assert write_events[i][1] == write_events[i + 1][1], (
f"Mismatched start/end at index {i}"
)
assert write_events[i][0] == "start"
assert write_events[i + 1][0] == "end"
assert write_events[i][1] == write_events[i + 1][1]
await transport.close()