From baee8bae42a688995b1f4d40bf5696c1a59859e2 Mon Sep 17 00:00:00 2001 From: Ashwin Bhat Date: Sun, 31 Aug 2025 21:59:47 -0700 Subject: [PATCH 1/3] Reorganize imports in client.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move json import to top-level imports - Add asyncio and contextlib.suppress imports for future use - Remove inline json import from query method 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/claude_code_sdk/client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 216b601..a78189f 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -1,7 +1,10 @@ """Claude SDK Client for interacting with Claude Code.""" +import asyncio +import json import os from collections.abc import AsyncIterable, AsyncIterator +from contextlib import suppress from typing import Any from ._errors import CLIConnectionError @@ -164,8 +167,6 @@ class ClaudeSDKClient: if not self._query or not self._transport: raise CLIConnectionError("Not connected. Call connect() first.") - import json - # Handle string prompts if isinstance(prompt, str): message = { From 1adacaffe816b7ca0b33cac946defbfa9f31c149 Mon Sep 17 00:00:00 2001 From: Ashwin Bhat Date: Sun, 31 Aug 2025 22:04:21 -0700 Subject: [PATCH 2/3] Revert "Reorganize imports in client.py" This reverts commit baee8bae42a688995b1f4d40bf5696c1a59859e2. --- src/claude_code_sdk/client.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index a78189f..216b601 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -1,10 +1,7 @@ """Claude SDK Client for interacting with Claude Code.""" -import asyncio -import json import os from collections.abc import AsyncIterable, AsyncIterator -from contextlib import suppress from typing import Any from ._errors import CLIConnectionError @@ -167,6 +164,8 @@ class ClaudeSDKClient: if not self._query or not self._transport: raise CLIConnectionError("Not connected. Call connect() first.") + import json + # Handle string prompts if isinstance(prompt, str): message = { From 105da74f6805f3d1204045c85acf31d236d08f1c Mon Sep 17 00:00:00 2001 From: Kashyap Murali Date: Mon, 1 Sep 2025 03:00:14 -0700 Subject: [PATCH 3/3] clean up sdk refactor --- src/claude_code_sdk/_internal/client.py | 5 +- src/claude_code_sdk/_internal/query.py | 18 +- .../_internal/transport/__init__.py | 4 +- .../_internal/transport/subprocess_cli.py | 33 ++-- src/claude_code_sdk/client.py | 8 +- tests/test_client.py | 9 +- tests/test_integration.py | 23 ++- tests/test_streaming_client.py | 173 ++++++++++++++---- 8 files changed, 201 insertions(+), 72 deletions(-) diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index ccfc1e8..373695d 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -1,5 +1,6 @@ """Internal client implementation.""" +import asyncio from collections.abc import AsyncIterable, AsyncIterator from typing import Any @@ -56,8 +57,6 @@ class InternalClient: # Stream input if it's an AsyncIterable if isinstance(prompt, AsyncIterable): # Start streaming in background - import asyncio - asyncio.create_task(query.stream_input(prompt)) # For string prompts, the prompt is already passed via CLI args @@ -66,4 +65,4 @@ class InternalClient: yield parse_message(data) finally: - query.close() + await query.close() diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index 815522f..f9df256 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -5,6 +5,7 @@ import json import logging import os from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from contextlib import suppress from typing import Any from .transport import Transport @@ -140,12 +141,16 @@ class Query: # Regular SDK messages go to the queue await self._message_queue.put(message) + except asyncio.CancelledError: + # Task was cancelled - this is expected behavior + logger.debug("Read task cancelled") + raise # Re-raise to properly handle cancellation except Exception as e: - logger.debug(f"Error reading messages: {e}") + logger.error(f"Fatal error in message reader: {e}") # Put error in queue so iterators can handle it await self._message_queue.put({"type": "error", "error": str(e)}) finally: - # Signal end of stream + # Always signal end of stream await self._message_queue.put({"type": "end"}) async def _handle_control_request(self, request: dict[str, Any]) -> None: @@ -262,7 +267,7 @@ class Query: break await self.transport.write(json.dumps(message) + "\n") # After all messages sent, end input - self.transport.end_input() + await self.transport.end_input() except Exception as e: logger.debug(f"Error streaming input: {e}") @@ -279,12 +284,15 @@ class Query: yield message - def close(self) -> None: + async def close(self) -> None: """Close the query and transport.""" self._closed = True if self._read_task and not self._read_task.done(): self._read_task.cancel() - self.transport.close() + # Wait for task to complete cancellation + with suppress(asyncio.CancelledError): + await self._read_task + await self.transport.close() # Make Query an async iterator def __aiter__(self) -> AsyncIterator[dict[str, Any]]: diff --git a/src/claude_code_sdk/_internal/transport/__init__.py b/src/claude_code_sdk/_internal/transport/__init__.py index 9773996..6dedef6 100644 --- a/src/claude_code_sdk/_internal/transport/__init__.py +++ b/src/claude_code_sdk/_internal/transport/__init__.py @@ -46,7 +46,7 @@ class Transport(ABC): pass @abstractmethod - def close(self) -> None: + async def close(self) -> None: """Close the transport connection and clean up resources.""" pass @@ -60,7 +60,7 @@ class Transport(ABC): pass @abstractmethod - def end_input(self) -> None: + async def end_input(self) -> None: """End the input stream (close stdin for process transports).""" pass diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 2048d60..b7be26a 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -7,6 +7,7 @@ import shutil import tempfile from collections import deque from collections.abc import AsyncIterable, AsyncIterator +from contextlib import suppress from pathlib import Path from subprocess import PIPE from typing import Any @@ -208,20 +209,31 @@ class SubprocessCLITransport(Transport): except Exception as e: raise CLIConnectionError(f"Failed to start Claude Code: {e}") from e - def close(self) -> None: + async def close(self) -> None: """Close the transport and clean up resources.""" self._ready = False if not self._process: return - if self._process.returncode is None: - from contextlib import suppress + # Close stdin first if it's still open + if self._stdin_stream: + with suppress(Exception): + await self._stdin_stream.aclose() + self._stdin_stream = None + if self._process.stdin: + with suppress(Exception): + await self._process.stdin.aclose() + + # Terminate and wait for process + if self._process.returncode is None: with suppress(ProcessLookupError): self._process.terminate() - # Note: We can't use async wait here since close() is sync - # The process will be cleaned up by the OS + # Wait for process to finish with timeout + with suppress(Exception): + # Just try to wait, but don't block if it fails + await self._process.wait() # Clean up temp file if self._stderr_file: @@ -244,18 +256,15 @@ class SubprocessCLITransport(Transport): await self._stdin_stream.send(data) - def end_input(self) -> None: + async def end_input(self) -> None: """End the input stream (close stdin).""" if self._stdin_stream: - # Note: We can't use async aclose here since end_input() is sync - # Just mark it as None and let cleanup happen later + with suppress(Exception): + await self._stdin_stream.aclose() self._stdin_stream = None if self._process and self._process.stdin: - from contextlib import suppress - with suppress(Exception): - # Mark stdin as closed - actual close will happen during cleanup - pass + await self._process.stdin.aclose() def read_messages(self) -> AsyncIterator[dict[str, Any]]: """Read and parse messages from the transport.""" diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 216b601..2875d70 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -1,5 +1,7 @@ """Claude SDK Client for interacting with Claude Code.""" +import asyncio +import json import os from collections.abc import AsyncIterable, AsyncIterator from typing import Any @@ -137,8 +139,6 @@ class ClaudeSDKClient: # If we have an initial prompt stream, start streaming it if prompt is not None and isinstance(prompt, AsyncIterable): - import asyncio - asyncio.create_task(self._query.stream_input(prompt)) async def receive_messages(self) -> AsyncIterator[Message]: @@ -164,8 +164,6 @@ class ClaudeSDKClient: if not self._query or not self._transport: raise CLIConnectionError("Not connected. Call connect() first.") - import json - # Handle string prompts if isinstance(prompt, str): message = { @@ -258,7 +256,7 @@ class ClaudeSDKClient: async def disconnect(self) -> None: """Disconnect from Claude.""" if self._query: - self._query.close() + await self._query.close() self._query = None self._transport = None diff --git a/tests/test_client.py b/tests/test_client.py index 8156010..df1d087 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,6 @@ """Tests for Claude SDK client functionality.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import anyio @@ -102,9 +102,12 @@ class TestQueryFunction: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) options = ClaudeCodeOptions(cwd="/custom/path") messages = [] diff --git a/tests/test_integration.py b/tests/test_integration.py index aa6d12e..c3e4feb 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -3,7 +3,7 @@ These tests verify end-to-end functionality with mocked CLI responses. """ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import anyio import pytest @@ -52,9 +52,12 @@ class TestIntegration: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) # Run query messages = [] @@ -118,9 +121,12 @@ class TestIntegration: "total_cost_usd": 0.002, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) # Run query with tools enabled messages = [] @@ -185,9 +191,12 @@ class TestIntegration: }, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) # Run query with continuation messages = [] diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index a9c2bb3..a590acd 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -1,10 +1,11 @@ """Tests for ClaudeSDKClient streaming functionality and query() with async iterables.""" import asyncio +import json import sys import tempfile from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import anyio import pytest @@ -22,6 +23,63 @@ from claude_code_sdk import ( from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport +def create_mock_transport(with_init_response=True): + """Create a properly configured mock transport. + + Args: + with_init_response: If True, automatically respond to initialization request + """ + mock_transport = AsyncMock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + + # Track written messages to simulate control protocol responses + written_messages = [] + + async def mock_write(data): + written_messages.append(data) + + mock_transport.write.side_effect = mock_write + + # Default read_messages to handle control protocol + async def control_protocol_generator(): + # Wait for initialization request if needed + if with_init_response: + # Wait a bit for the write to happen + await asyncio.sleep(0.01) + + # Check if initialization was requested + for msg_str in written_messages: + try: + msg = json.loads(msg_str.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") == "initialize" + ): + # Send initialization response + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then end the stream + return + + mock_transport.read_messages = control_protocol_generator + return mock_transport + + class TestClaudeSDKClientStreaming: """Test ClaudeSDKClient streaming functionality.""" @@ -32,7 +90,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: @@ -41,7 +99,7 @@ class TestClaudeSDKClientStreaming: assert client._transport is mock_transport # Verify disconnect was called on exit - mock_transport.disconnect.assert_called_once() + mock_transport.close.assert_called_once() anyio.run(_test) @@ -52,7 +110,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport client = ClaudeSDKClient() @@ -64,7 +122,7 @@ class TestClaudeSDKClientStreaming: await client.disconnect() # Verify disconnect was called - mock_transport.disconnect.assert_called_once() + mock_transport.close.assert_called_once() assert client._transport is None anyio.run(_test) @@ -76,7 +134,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport client = ClaudeSDKClient() @@ -95,7 +153,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async def message_stream(): @@ -123,20 +181,30 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: await client.query("Test message") - # Verify send_request was called with correct format - mock_transport.send_request.assert_called_once() - call_args = mock_transport.send_request.call_args - messages, options = call_args[0] - assert len(messages) == 1 - assert messages[0]["type"] == "user" - assert messages[0]["message"]["content"] == "Test message" - assert options["session_id"] == "default" + # Verify write was called with correct format + # Should have at least 2 writes: init request and user message + assert mock_transport.write.call_count >= 2 + + # Find the user message in the write calls + user_msg_found = False + for call in mock_transport.write.call_args_list: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if msg.get("type") == "user": + assert msg["message"]["content"] == "Test message" + assert msg["session_id"] == "default" + user_msg_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert user_msg_found, "User message not found in write calls" anyio.run(_test) @@ -147,16 +215,25 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: await client.query("Test", session_id="custom-session") - call_args = mock_transport.send_request.call_args - messages, options = call_args[0] - assert messages[0]["session_id"] == "custom-session" - assert options["session_id"] == "custom-session" + # Find the user message with custom session ID + session_found = False + for call in mock_transport.write.call_args_list: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if msg.get("type") == "user": + assert msg["session_id"] == "custom-session" + session_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert session_found, "User message with custom session not found" anyio.run(_test) @@ -177,11 +254,37 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock the message stream + # Mock the message stream with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages yield { "type": "assistant", "message": { @@ -195,7 +298,7 @@ class TestClaudeSDKClientStreaming: "message": {"role": "user", "content": "Hi there"}, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: messages = [] @@ -220,7 +323,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport # Mock the message stream @@ -255,7 +358,7 @@ class TestClaudeSDKClientStreaming: "model": "claude-opus-4-1-20250805", } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: messages = [] @@ -276,7 +379,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: @@ -308,7 +411,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport client = ClaudeSDKClient(options=options) @@ -327,7 +430,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport # Mock receive to wait then yield messages @@ -353,7 +456,7 @@ class TestClaudeSDKClientStreaming: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: # Helper to get next message @@ -476,7 +579,7 @@ class TestClaudeSDKClientEdgeCases: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport client = ClaudeSDKClient() @@ -506,7 +609,7 @@ class TestClaudeSDKClientEdgeCases: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport with pytest.raises(ValueError): @@ -514,7 +617,7 @@ class TestClaudeSDKClientEdgeCases: raise ValueError("Test error") # Disconnect should still be called - mock_transport.disconnect.assert_called_once() + mock_transport.close.assert_called_once() anyio.run(_test) @@ -525,7 +628,7 @@ class TestClaudeSDKClientEdgeCases: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport # Mock the message stream @@ -557,7 +660,7 @@ class TestClaudeSDKClientEdgeCases: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: # Test list comprehension pattern from docstring