mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
fix: improve concurrent write test
- Move json import to top of file - Simplify test by directly patching _stdin_stream.send - Remove unnecessary call to original_send 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
a914f05721
commit
aa437fd298
1 changed files with 21 additions and 36 deletions
|
|
@ -1,5 +1,6 @@
|
|||
"""Tests for Claude SDK transport layer."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -317,7 +318,6 @@ class TestSubprocessCLITransport:
|
|||
|
||||
def test_build_command_with_mcp_servers(self):
|
||||
"""Test building CLI command with mcp_servers option."""
|
||||
import json
|
||||
|
||||
mcp_servers = {
|
||||
"test-server": {
|
||||
|
|
@ -503,7 +503,6 @@ class TestSubprocessCLITransport:
|
|||
|
||||
def test_concurrent_writes_are_serialized(self):
|
||||
"""Test that concurrent writes to stdin are serialized by the lock."""
|
||||
import json
|
||||
|
||||
async def _test():
|
||||
with patch("anyio.open_process") as mock_exec:
|
||||
|
|
@ -520,27 +519,11 @@ class TestSubprocessCLITransport:
|
|||
mock_process = MagicMock()
|
||||
mock_process.returncode = None
|
||||
mock_process.stdout = MagicMock()
|
||||
mock_process.stdin = MagicMock()
|
||||
mock_process.stdin.aclose = AsyncMock()
|
||||
|
||||
# Track write calls to verify serialization
|
||||
write_events: list[tuple[str, str]] = [] # (event_type, data_preview)
|
||||
|
||||
async def mock_send(data: str):
|
||||
"""Mock send that tracks start/end of each write."""
|
||||
preview = data[:20]
|
||||
write_events.append(("start", preview))
|
||||
# Small delay to increase chance of interleaving if lock is broken
|
||||
await anyio.sleep(0.001)
|
||||
write_events.append(("end", preview))
|
||||
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.send = mock_send
|
||||
mock_stdin.aclose = AsyncMock()
|
||||
mock_process.stdin = mock_stdin
|
||||
|
||||
# Return version process first, then main process
|
||||
mock_exec.side_effect = [mock_version_process, mock_process]
|
||||
|
||||
# Create transport in streaming mode (required for stdin writes)
|
||||
async def dummy_stream():
|
||||
yield {"type": "user"}
|
||||
|
||||
|
|
@ -552,30 +535,32 @@ class TestSubprocessCLITransport:
|
|||
await transport.connect()
|
||||
assert transport.is_ready()
|
||||
|
||||
# Spawn multiple concurrent writes
|
||||
# Track write events to verify serialization
|
||||
write_events: list[tuple[str, int]] = []
|
||||
|
||||
async def tracked_send(data: str):
|
||||
msg_id = json.loads(data.strip())["id"]
|
||||
write_events.append(("start", msg_id))
|
||||
await anyio.sleep(0.001) # Yield to expose interleaving
|
||||
write_events.append(("end", msg_id))
|
||||
|
||||
transport._stdin_stream.send = tracked_send
|
||||
|
||||
# Spawn concurrent writes
|
||||
num_writes = 5
|
||||
messages = [
|
||||
json.dumps({"id": i, "data": "x" * 50}) + "\n"
|
||||
for i in range(num_writes)
|
||||
]
|
||||
messages = [json.dumps({"id": i}) + "\n" for i in range(num_writes)]
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for msg in messages:
|
||||
tg.start_soon(transport.write, msg)
|
||||
|
||||
# Verify serialization: events should be strictly alternating start/end pairs
|
||||
# If writes interleaved, we'd see patterns like: start, start, end, end
|
||||
# Verify: each start must be immediately followed by its end
|
||||
# If interleaved, we'd see: start(0), start(1), end(0), end(1)
|
||||
assert len(write_events) == num_writes * 2
|
||||
|
||||
for i in range(0, len(write_events), 2):
|
||||
assert write_events[i][0] == "start", f"Expected start at index {i}"
|
||||
assert write_events[i + 1][0] == "end", (
|
||||
f"Expected end at index {i + 1}"
|
||||
)
|
||||
# Each start should be followed by its corresponding end
|
||||
assert write_events[i][1] == write_events[i + 1][1], (
|
||||
f"Mismatched start/end at index {i}"
|
||||
)
|
||||
assert write_events[i][0] == "start"
|
||||
assert write_events[i + 1][0] == "end"
|
||||
assert write_events[i][1] == write_events[i + 1][1]
|
||||
|
||||
await transport.close()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue