mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Merge remote-tracking branch 'origin/dickson/control' into feat/sdk-mcp-server-support
This commit is contained in:
commit
ef997feb49
3 changed files with 139 additions and 21 deletions
|
|
@ -73,6 +73,33 @@ def create_mock_transport(with_init_response=True):
|
|||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
|
||||
# Keep checking for other control requests (like interrupt)
|
||||
last_check = len(written_messages)
|
||||
timeout_counter = 0
|
||||
while timeout_counter < 100: # Avoid infinite loop
|
||||
await asyncio.sleep(0.01)
|
||||
timeout_counter += 1
|
||||
|
||||
# Check for new messages
|
||||
for msg_str in written_messages[last_check:]:
|
||||
try:
|
||||
msg = json.loads(msg_str.strip())
|
||||
if msg.get("type") == "control_request":
|
||||
subtype = msg.get("request", {}).get("subtype")
|
||||
if subtype == "interrupt":
|
||||
# Send interrupt response
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
},
|
||||
}
|
||||
return # End after interrupt
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
last_check = len(written_messages)
|
||||
|
||||
# Then end the stream
|
||||
return
|
||||
|
||||
|
|
@ -326,8 +353,34 @@ class TestClaudeSDKClientStreaming:
|
|||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
# Mock the message stream with control protocol support
|
||||
async def mock_receive():
|
||||
# First handle initialization
|
||||
await asyncio.sleep(0.01)
|
||||
written = mock_transport.write.call_args_list
|
||||
for call in written:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if (
|
||||
msg.get("type") == "control_request"
|
||||
and msg.get("request", {}).get("subtype")
|
||||
== "initialize"
|
||||
):
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default",
|
||||
},
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
|
||||
# Then yield the actual messages
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
|
|
@ -383,8 +436,24 @@ class TestClaudeSDKClientStreaming:
|
|||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Interrupt is now handled via control protocol
|
||||
await client.interrupt()
|
||||
mock_transport.interrupt.assert_called_once()
|
||||
# Check that a control request was sent via write
|
||||
write_calls = mock_transport.write.call_args_list
|
||||
interrupt_found = False
|
||||
for call in write_calls:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if (
|
||||
msg.get("type") == "control_request"
|
||||
and msg.get("request", {}).get("subtype") == "interrupt"
|
||||
):
|
||||
interrupt_found = True
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
assert interrupt_found, "Interrupt control request not found"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
|
@ -433,8 +502,31 @@ class TestClaudeSDKClientStreaming:
|
|||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock receive to wait then yield messages
|
||||
# Mock receive to wait then yield messages with control protocol support
|
||||
async def mock_receive():
|
||||
# First handle initialization
|
||||
await asyncio.sleep(0.01)
|
||||
written = mock_transport.write.call_args_list
|
||||
for call in written:
|
||||
if call:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default"
|
||||
}
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
|
||||
# Then yield the actual messages
|
||||
await asyncio.sleep(0.1)
|
||||
yield {
|
||||
"type": "assistant",
|
||||
|
|
@ -579,8 +671,11 @@ class TestClaudeSDKClientEdgeCases:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
# Create a new mock transport for each call
|
||||
mock_transport_class.side_effect = [
|
||||
create_mock_transport(),
|
||||
create_mock_transport()
|
||||
]
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect()
|
||||
|
|
@ -631,8 +726,31 @@ class TestClaudeSDKClientEdgeCases:
|
|||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
# Mock the message stream with control protocol support
|
||||
async def mock_receive():
|
||||
# First handle initialization
|
||||
await asyncio.sleep(0.01)
|
||||
written = mock_transport.write.call_args_list
|
||||
for call in written:
|
||||
if call:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default"
|
||||
}
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
|
||||
# Then yield the actual messages
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class TestSubprocessBuffering:
|
|||
transport._stderr_stream = MockTextReceiveStream([]) # type: ignore[assignment]
|
||||
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.receive_messages():
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 2
|
||||
|
|
@ -97,7 +97,7 @@ class TestSubprocessBuffering:
|
|||
transport._stderr_stream = MockTextReceiveStream([])
|
||||
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.receive_messages():
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 2
|
||||
|
|
@ -127,7 +127,7 @@ class TestSubprocessBuffering:
|
|||
transport._stderr_stream = MockTextReceiveStream([])
|
||||
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.receive_messages():
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 2
|
||||
|
|
@ -173,7 +173,7 @@ class TestSubprocessBuffering:
|
|||
transport._stderr_stream = MockTextReceiveStream([])
|
||||
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.receive_messages():
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
|
|
@ -221,7 +221,7 @@ class TestSubprocessBuffering:
|
|||
transport._stderr_stream = MockTextReceiveStream([])
|
||||
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.receive_messages():
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
|
|
@ -252,7 +252,7 @@ class TestSubprocessBuffering:
|
|||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.receive_messages():
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert isinstance(exc_info.value, CLIJSONDecodeError)
|
||||
|
|
@ -293,7 +293,7 @@ class TestSubprocessBuffering:
|
|||
transport._stderr_stream = MockTextReceiveStream([])
|
||||
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.receive_messages():
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
|
|
|||
|
|
@ -112,8 +112,8 @@ class TestSubprocessCLITransport:
|
|||
assert "--resume" in cmd
|
||||
assert "session-123" in cmd
|
||||
|
||||
def test_connect_disconnect(self):
|
||||
"""Test connect and disconnect lifecycle."""
|
||||
def test_connect_close(self):
|
||||
"""Test connect and close lifecycle."""
|
||||
|
||||
async def _test():
|
||||
with patch("anyio.open_process") as mock_exec:
|
||||
|
|
@ -139,22 +139,22 @@ class TestSubprocessCLITransport:
|
|||
|
||||
await transport.connect()
|
||||
assert transport._process is not None
|
||||
assert transport.is_connected()
|
||||
assert transport.is_ready()
|
||||
|
||||
await transport.disconnect()
|
||||
await transport.close()
|
||||
mock_process.terminate.assert_called_once()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_receive_messages(self):
|
||||
"""Test parsing messages from CLI output."""
|
||||
# This test is simplified to just test the parsing logic
|
||||
def test_read_messages(self):
|
||||
"""Test reading messages from CLI output."""
|
||||
# This test is simplified to just test the transport creation
|
||||
# The full async stream handling is tested in integration tests
|
||||
transport = SubprocessCLITransport(
|
||||
prompt="test", options=ClaudeCodeOptions(), cli_path="/usr/bin/claude"
|
||||
)
|
||||
|
||||
# The actual message parsing is done by the client, not the transport
|
||||
# The transport now just provides raw message reading via read_messages()
|
||||
# So we just verify the transport can be created and basic structure is correct
|
||||
assert transport._prompt == "test"
|
||||
assert transport._cli_path == "/usr/bin/claude"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue