Merge remote-tracking branch 'origin/dickson/control' into feat/sdk-mcp-server-support

This commit is contained in:
Kashyap Murali 2025-09-01 03:48:33 -07:00
commit ef997feb49
No known key found for this signature in database
3 changed files with 139 additions and 21 deletions

View file

@ -73,6 +73,33 @@ def create_mock_transport(with_init_response=True):
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Keep checking for other control requests (like interrupt)
last_check = len(written_messages)
timeout_counter = 0
while timeout_counter < 100: # Avoid infinite loop
await asyncio.sleep(0.01)
timeout_counter += 1
# Check for new messages
for msg_str in written_messages[last_check:]:
try:
msg = json.loads(msg_str.strip())
if msg.get("type") == "control_request":
subtype = msg.get("request", {}).get("subtype")
if subtype == "interrupt":
# Send interrupt response
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
},
}
return # End after interrupt
except (json.JSONDecodeError, KeyError, AttributeError):
pass
last_check = len(written_messages)
# Then end the stream
return
@ -326,8 +353,34 @@ class TestClaudeSDKClientStreaming:
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
# Mock the message stream with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
yield {
"type": "assistant",
"message": {
@ -383,8 +436,24 @@ class TestClaudeSDKClientStreaming:
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
# Interrupt is now handled via control protocol
await client.interrupt()
mock_transport.interrupt.assert_called_once()
# Check that a control request was sent via write
write_calls = mock_transport.write.call_args_list
interrupt_found = False
for call in write_calls:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype") == "interrupt"
):
interrupt_found = True
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
assert interrupt_found, "Interrupt control request not found"
anyio.run(_test)
@ -433,8 +502,31 @@ class TestClaudeSDKClientStreaming:
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock receive to wait then yield messages
# Mock receive to wait then yield messages with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
if call:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default"
}
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
await asyncio.sleep(0.1)
yield {
"type": "assistant",
@ -579,8 +671,11 @@ class TestClaudeSDKClientEdgeCases:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Create a new mock transport for each call
mock_transport_class.side_effect = [
create_mock_transport(),
create_mock_transport()
]
client = ClaudeSDKClient()
await client.connect()
@ -631,8 +726,31 @@ class TestClaudeSDKClientEdgeCases:
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
# Mock the message stream with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
if call:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default"
}
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
yield {
"type": "assistant",
"message": {

View file

@ -63,7 +63,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([]) # type: ignore[assignment]
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 2
@ -97,7 +97,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 2
@ -127,7 +127,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 2
@ -173,7 +173,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 1
@ -221,7 +221,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 1
@ -252,7 +252,7 @@ class TestSubprocessBuffering:
with pytest.raises(Exception) as exc_info:
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert isinstance(exc_info.value, CLIJSONDecodeError)
@ -293,7 +293,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 3

View file

@ -112,8 +112,8 @@ class TestSubprocessCLITransport:
assert "--resume" in cmd
assert "session-123" in cmd
def test_connect_disconnect(self):
"""Test connect and disconnect lifecycle."""
def test_connect_close(self):
"""Test connect and close lifecycle."""
async def _test():
with patch("anyio.open_process") as mock_exec:
@ -139,22 +139,22 @@ class TestSubprocessCLITransport:
await transport.connect()
assert transport._process is not None
assert transport.is_connected()
assert transport.is_ready()
await transport.disconnect()
await transport.close()
mock_process.terminate.assert_called_once()
anyio.run(_test)
def test_receive_messages(self):
"""Test parsing messages from CLI output."""
# This test is simplified to just test the parsing logic
def test_read_messages(self):
"""Test reading messages from CLI output."""
# This test is simplified to just test the transport creation
# The full async stream handling is tested in integration tests
transport = SubprocessCLITransport(
prompt="test", options=ClaudeCodeOptions(), cli_path="/usr/bin/claude"
)
# The actual message parsing is done by the client, not the transport
# The transport now just provides raw message reading via read_messages()
# So we just verify the transport can be created and basic structure is correct
assert transport._prompt == "test"
assert transport._cli_path == "/usr/bin/claude"