diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index a590acd..5df1039 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -73,6 +73,33 @@ def create_mock_transport(with_init_response=True): except (json.JSONDecodeError, KeyError, AttributeError): pass + # Keep checking for other control requests (like interrupt) + last_check = len(written_messages) + timeout_counter = 0 + while timeout_counter < 100: # Avoid infinite loop + await asyncio.sleep(0.01) + timeout_counter += 1 + + # Check for new messages + for msg_str in written_messages[last_check:]: + try: + msg = json.loads(msg_str.strip()) + if msg.get("type") == "control_request": + subtype = msg.get("request", {}).get("subtype") + if subtype == "interrupt": + # Send interrupt response + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + }, + } + return # End after interrupt + except (json.JSONDecodeError, KeyError, AttributeError): + pass + last_check = len(written_messages) + # Then end the stream return @@ -326,8 +353,34 @@ class TestClaudeSDKClientStreaming: mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock the message stream + # Mock the message stream with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages yield { "type": "assistant", "message": { @@ -383,8 +436,24 @@ class TestClaudeSDKClientStreaming: mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: + # Interrupt is now handled via control protocol await client.interrupt() - mock_transport.interrupt.assert_called_once() + # Check that a control request was sent via write + write_calls = mock_transport.write.call_args_list + interrupt_found = False + for call in write_calls: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") == "interrupt" + ): + interrupt_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert interrupt_found, "Interrupt control request not found" anyio.run(_test)