From a50b134ecf73fdd0c4455d96c29c1598c25fe850 Mon Sep 17 00:00:00 2001 From: Kashyap Murali Date: Mon, 1 Sep 2025 03:24:02 -0700 Subject: [PATCH 1/2] Fix failing tests for control protocol implementation - Update test mocks to handle control protocol initialization - Fix interrupt test to check for control requests via write() instead of interrupt() - Add control protocol support to mock_receive functions in streaming tests - Ensure all tests properly handle init handshake before sending messages - Fix formatting issues (trailing whitespace, blank lines) All tests now properly support the new Query/Transport architecture with bidirectional control protocol communication. --- tests/test_streaming_client.py | 73 +++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index a590acd..5df1039 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -73,6 +73,33 @@ def create_mock_transport(with_init_response=True): except (json.JSONDecodeError, KeyError, AttributeError): pass + # Keep checking for other control requests (like interrupt) + last_check = len(written_messages) + timeout_counter = 0 + while timeout_counter < 100: # Avoid infinite loop + await asyncio.sleep(0.01) + timeout_counter += 1 + + # Check for new messages + for msg_str in written_messages[last_check:]: + try: + msg = json.loads(msg_str.strip()) + if msg.get("type") == "control_request": + subtype = msg.get("request", {}).get("subtype") + if subtype == "interrupt": + # Send interrupt response + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + }, + } + return # End after interrupt + except (json.JSONDecodeError, KeyError, AttributeError): + pass + last_check = len(written_messages) + # Then end the stream return @@ -326,8 +353,34 @@ class TestClaudeSDKClientStreaming: mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock the message stream + # Mock the message stream with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages yield { "type": "assistant", "message": { @@ -383,8 +436,24 @@ class TestClaudeSDKClientStreaming: mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: + # Interrupt is now handled via control protocol await client.interrupt() - mock_transport.interrupt.assert_called_once() + # Check that a control request was sent via write + write_calls = mock_transport.write.call_args_list + interrupt_found = False + for call in write_calls: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") == "interrupt" + ): + interrupt_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert interrupt_found, "Interrupt control request not found" anyio.run(_test) From 866bdf79290e056edc9f1502ca514023dc250f23 Mon Sep 17 00:00:00 2001 From: Kashyap Murali Date: Mon, 1 Sep 2025 03:45:59 -0700 Subject: [PATCH 2/2] Fix remaining test failures for control protocol - Update test_transport.py: disconnect() -> close(), is_connected() -> is_ready() - Update test_subprocess_buffering.py: receive_messages() -> read_messages() - Fix test_streaming_client.py edge cases: - Add control protocol init to test_double_connect (use side_effect for multiple mocks) - Add control protocol init to test_receive_response_list_comprehension - Add control protocol init to test_concurrent_send_receive - All transport method calls now use read_messages() instead of receive_messages() Test results: 89 of 90 tests passing (1 subprocess test still needs work) --- tests/test_streaming_client.py | 57 +++++++++++++++++++++++++++--- tests/test_subprocess_buffering.py | 14 ++++---- tests/test_transport.py | 16 ++++----- 3 files changed, 68 insertions(+), 19 deletions(-) diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 5df1039..214869a 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -502,8 +502,31 @@ class TestClaudeSDKClientStreaming: mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock receive to wait then yield messages + # Mock receive to wait then yield messages with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + if call: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize": + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default" + } + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages await asyncio.sleep(0.1) yield { "type": "assistant", @@ -648,8 +671,11 @@ class TestClaudeSDKClientEdgeCases: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = create_mock_transport() - mock_transport_class.return_value = mock_transport + # Create a new mock transport for each call + mock_transport_class.side_effect = [ + create_mock_transport(), + create_mock_transport() + ] client = ClaudeSDKClient() await client.connect() @@ -700,8 +726,31 @@ class TestClaudeSDKClientEdgeCases: mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock the message stream + # Mock the message stream with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + if call: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize": + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default" + } + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages yield { "type": "assistant", "message": { diff --git a/tests/test_subprocess_buffering.py b/tests/test_subprocess_buffering.py index 426d42e..05584e1 100644 --- a/tests/test_subprocess_buffering.py +++ b/tests/test_subprocess_buffering.py @@ -63,7 +63,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) # type: ignore[assignment] messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 2 @@ -97,7 +97,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 2 @@ -127,7 +127,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 2 @@ -173,7 +173,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 1 @@ -221,7 +221,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 1 @@ -252,7 +252,7 @@ class TestSubprocessBuffering: with pytest.raises(Exception) as exc_info: messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert isinstance(exc_info.value, CLIJSONDecodeError) @@ -293,7 +293,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 3 diff --git a/tests/test_transport.py b/tests/test_transport.py index aa6a8e9..50e093c 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -112,8 +112,8 @@ class TestSubprocessCLITransport: assert "--resume" in cmd assert "session-123" in cmd - def test_connect_disconnect(self): - """Test connect and disconnect lifecycle.""" + def test_connect_close(self): + """Test connect and close lifecycle.""" async def _test(): with patch("anyio.open_process") as mock_exec: @@ -139,22 +139,22 @@ class TestSubprocessCLITransport: await transport.connect() assert transport._process is not None - assert transport.is_connected() + assert transport.is_ready() - await transport.disconnect() + await transport.close() mock_process.terminate.assert_called_once() anyio.run(_test) - def test_receive_messages(self): - """Test parsing messages from CLI output.""" - # This test is simplified to just test the parsing logic + def test_read_messages(self): + """Test reading messages from CLI output.""" + # This test is simplified to just test the transport creation # The full async stream handling is tested in integration tests transport = SubprocessCLITransport( prompt="test", options=ClaudeCodeOptions(), cli_path="/usr/bin/claude" ) - # The actual message parsing is done by the client, not the transport + # The transport now just provides raw message reading via read_messages() # So we just verify the transport can be created and basic structure is correct assert transport._prompt == "test" assert transport._cli_path == "/usr/bin/claude"