diff --git a/tests/test_transport.py b/tests/test_transport.py index c5baaea..950123b 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1,5 +1,6 @@ """Tests for Claude SDK transport layer.""" +import json import os import uuid from unittest.mock import AsyncMock, MagicMock, patch @@ -317,7 +318,6 @@ class TestSubprocessCLITransport: def test_build_command_with_mcp_servers(self): """Test building CLI command with mcp_servers option.""" - import json mcp_servers = { "test-server": { @@ -503,7 +503,6 @@ class TestSubprocessCLITransport: def test_concurrent_writes_are_serialized(self): """Test that concurrent writes to stdin are serialized by the lock.""" - import json async def _test(): with patch("anyio.open_process") as mock_exec: @@ -520,27 +519,11 @@ class TestSubprocessCLITransport: mock_process = MagicMock() mock_process.returncode = None mock_process.stdout = MagicMock() + mock_process.stdin = MagicMock() + mock_process.stdin.aclose = AsyncMock() - # Track write calls to verify serialization - write_events: list[tuple[str, str]] = [] # (event_type, data_preview) - - async def mock_send(data: str): - """Mock send that tracks start/end of each write.""" - preview = data[:20] - write_events.append(("start", preview)) - # Small delay to increase chance of interleaving if lock is broken - await anyio.sleep(0.001) - write_events.append(("end", preview)) - - mock_stdin = MagicMock() - mock_stdin.send = mock_send - mock_stdin.aclose = AsyncMock() - mock_process.stdin = mock_stdin - - # Return version process first, then main process mock_exec.side_effect = [mock_version_process, mock_process] - # Create transport in streaming mode (required for stdin writes) async def dummy_stream(): yield {"type": "user"} @@ -552,30 +535,32 @@ class TestSubprocessCLITransport: await transport.connect() assert transport.is_ready() - # Spawn multiple concurrent writes + # Track write events to verify serialization + write_events: list[tuple[str, int]] = [] + + async def tracked_send(data: str): + msg_id = json.loads(data.strip())["id"] + write_events.append(("start", msg_id)) + await anyio.sleep(0.001) # Yield to expose interleaving + write_events.append(("end", msg_id)) + + transport._stdin_stream.send = tracked_send + + # Spawn concurrent writes num_writes = 5 - messages = [ - json.dumps({"id": i, "data": "x" * 50}) + "\n" - for i in range(num_writes) - ] + messages = [json.dumps({"id": i}) + "\n" for i in range(num_writes)] async with anyio.create_task_group() as tg: for msg in messages: tg.start_soon(transport.write, msg) - # Verify serialization: events should be strictly alternating start/end pairs - # If writes interleaved, we'd see patterns like: start, start, end, end + # Verify: each start must be immediately followed by its end + # If interleaved, we'd see: start(0), start(1), end(0), end(1) assert len(write_events) == num_writes * 2 - for i in range(0, len(write_events), 2): - assert write_events[i][0] == "start", f"Expected start at index {i}" - assert write_events[i + 1][0] == "end", ( - f"Expected end at index {i + 1}" - ) - # Each start should be followed by its corresponding end - assert write_events[i][1] == write_events[i + 1][1], ( - f"Mismatched start/end at index {i}" - ) + assert write_events[i][0] == "start" + assert write_events[i + 1][0] == "end" + assert write_events[i][1] == write_events[i + 1][1] await transport.close()