From 4c67548d8715aa8f7f51a0f1c86b63876ebe6e80 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Thu, 17 Jul 2025 13:44:31 -0700 Subject: [PATCH 01/19] Initial implementation of bidi streaming --- .../_internal/transport/subprocess_cli.py | 111 +++++++++++++++++- 1 file changed, 105 insertions(+), 6 deletions(-) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index c283f42..29b3967 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -4,14 +4,14 @@ import json import logging import os import shutil -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, AsyncIterable from pathlib import Path from subprocess import PIPE -from typing import Any +from typing import Any, Union import anyio from anyio.abc import Process -from anyio.streams.text import TextReceiveStream +from anyio.streams.text import TextReceiveStream, TextSendStream from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError @@ -28,17 +28,21 @@ class SubprocessCLITransport(Transport): def __init__( self, - prompt: str, + prompt: Union[str, AsyncIterable[dict[str, Any]]], options: ClaudeCodeOptions, cli_path: str | Path | None = None, ): self._prompt = prompt + self._is_streaming = not isinstance(prompt, str) self._options = options self._cli_path = str(cli_path) if cli_path else self._find_cli() self._cwd = str(options.cwd) if options.cwd else None self._process: Process | None = None self._stdout_stream: TextReceiveStream | None = None self._stderr_stream: TextReceiveStream | None = None + self._stdin_stream: TextSendStream | None = None + self._pending_control_responses: dict[str, Any] = {} + self._request_counter = 0 def _find_cli(self) -> str: """Find Claude Code CLI binary.""" @@ -116,7 +120,14 @@ class SubprocessCLITransport(Transport): ["--mcp-config", json.dumps({"mcpServers": self._options.mcp_servers})] ) - cmd.extend(["--print", self._prompt]) + # Add prompt handling based on mode + if self._is_streaming: + # Streaming mode: use --input-format stream-json + cmd.extend(["--input-format", "stream-json"]) + else: + # String mode: use --print with the prompt + cmd.extend(["--print", self._prompt]) + return cmd async def connect(self) -> None: @@ -126,9 +137,10 @@ class SubprocessCLITransport(Transport): cmd = self._build_command() try: + # Enable stdin pipe for both modes (but we'll close it for string mode) self._process = await anyio.open_process( cmd, - stdin=None, + stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=self._cwd, @@ -139,6 +151,18 @@ class SubprocessCLITransport(Transport): self._stdout_stream = TextReceiveStream(self._process.stdout) if self._process.stderr: self._stderr_stream = TextReceiveStream(self._process.stderr) + + # Handle stdin based on mode + if self._is_streaming: + # Streaming mode: keep stdin open and start streaming task + if self._process.stdin: + self._stdin_stream = TextSendStream(self._process.stdin) + # Start streaming messages to stdin + anyio.start_soon(self._stream_to_stdin) + else: + # String mode: close stdin immediately (backward compatible) + if self._process.stdin: + await self._process.stdin.aclose() except FileNotFoundError as e: # Check if the error comes from the working directory or the CLI @@ -169,10 +193,32 @@ class SubprocessCLITransport(Transport): self._process = None self._stdout_stream = None self._stderr_stream = None + self._stdin_stream = None async def send_request(self, messages: list[Any], options: dict[str, Any]) -> None: """Not used for CLI transport - args passed via command line.""" + async def _stream_to_stdin(self) -> None: + """Stream messages to stdin for streaming mode.""" + if not self._stdin_stream or not isinstance(self._prompt, AsyncIterable): + return + + try: + async for message in self._prompt: + if not self._stdin_stream: + break + await self._stdin_stream.send(json.dumps(message) + "\n") + + # Close stdin when done + if self._stdin_stream: + await self._stdin_stream.aclose() + self._stdin_stream = None + except Exception as e: + logger.debug(f"Error streaming to stdin: {e}") + if self._stdin_stream: + await self._stdin_stream.aclose() + self._stdin_stream = None + async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: """Receive messages from CLI.""" if not self._process or not self._stdout_stream: @@ -213,6 +259,15 @@ class SubprocessCLITransport(Transport): try: data = json.loads(json_buffer) json_buffer = "" + + # Handle control responses separately + if data.get("type") == "control_response": + request_id = data.get("response", {}).get("request_id") + if request_id and request_id in self._pending_control_responses: + # Store the response for the pending request + self._pending_control_responses[request_id] = data.get("response", {}) + continue + try: yield data except GeneratorExit: @@ -280,3 +335,47 @@ class SubprocessCLITransport(Transport): def is_connected(self) -> bool: """Check if subprocess is running.""" return self._process is not None and self._process.returncode is None + + async def interrupt(self) -> None: + """Send interrupt control request (only works in streaming mode).""" + if not self._is_streaming: + raise CLIConnectionError("Interrupt requires streaming mode (AsyncIterable prompt)") + + if not self._stdin_stream: + raise CLIConnectionError("Not connected or stdin not available") + + await self._send_control_request({"subtype": "interrupt"}) + + async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send a control request and wait for response.""" + if not self._stdin_stream: + raise CLIConnectionError("Stdin not available") + + # Generate unique request ID + self._request_counter += 1 + request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}" + + # Build control request + control_request = { + "type": "control_request", + "request_id": request_id, + "request": request + } + + # Send request + await self._stdin_stream.send(json.dumps(control_request) + "\n") + + # Wait for response with timeout + try: + with anyio.fail_after(30.0): # 30 second timeout + while request_id not in self._pending_control_responses: + await anyio.sleep(0.1) + + response = self._pending_control_responses.pop(request_id) + + if response.get("subtype") == "error": + raise CLIConnectionError(f"Control request failed: {response.get('error')}") + + return response + except TimeoutError: + raise CLIConnectionError("Control request timed out") from None From f4cff21590ebb6290fc63fe328d1e501ff05d377 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Fri, 18 Jul 2025 00:16:18 -0700 Subject: [PATCH 02/19] Finalize streaming impl --- .../_internal/transport/subprocess_cli.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 29b3967..b17f85e 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -157,8 +157,9 @@ class SubprocessCLITransport(Transport): # Streaming mode: keep stdin open and start streaming task if self._process.stdin: self._stdin_stream = TextSendStream(self._process.stdin) - # Start streaming messages to stdin - anyio.start_soon(self._stream_to_stdin) + # Start streaming messages to stdin in background + import asyncio + asyncio.create_task(self._stream_to_stdin()) else: # String mode: close stdin immediately (backward compatible) if self._process.stdin: @@ -209,10 +210,8 @@ class SubprocessCLITransport(Transport): break await self._stdin_stream.send(json.dumps(message) + "\n") - # Close stdin when done - if self._stdin_stream: - await self._stdin_stream.aclose() - self._stdin_stream = None + # Signal EOF but keep the stream open for control messages + # This matches the TypeScript implementation which calls stdin.end() except Exception as e: logger.debug(f"Error streaming to stdin: {e}") if self._stdin_stream: From 6dd12b0df8cddcafa22a265dfdc0fdf43474bb67 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 10:43:23 -0700 Subject: [PATCH 03/19] Implement proper client and bidi streaming --- examples/streaming_mode_example.py | 192 ++++++++++++++++ src/claude_code_sdk/__init__.py | 57 +---- src/claude_code_sdk/_internal/client.py | 74 +------ .../_internal/message_parser.py | 77 +++++++ .../_internal/transport/subprocess_cli.py | 66 ++++-- src/claude_code_sdk/client.py | 208 ++++++++++++++++++ src/claude_code_sdk/query.py | 99 +++++++++ 7 files changed, 627 insertions(+), 146 deletions(-) create mode 100644 examples/streaming_mode_example.py create mode 100644 src/claude_code_sdk/_internal/message_parser.py create mode 100644 src/claude_code_sdk/client.py create mode 100644 src/claude_code_sdk/query.py diff --git a/examples/streaming_mode_example.py b/examples/streaming_mode_example.py new file mode 100644 index 0000000..ed782a9 --- /dev/null +++ b/examples/streaming_mode_example.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +"""Example demonstrating streaming mode with bidirectional communication.""" + +import asyncio +from collections.abc import AsyncIterator + +from claude_code_sdk import ClaudeCodeOptions, ClaudeSDKClient, query + + +async def create_message_stream() -> AsyncIterator[dict]: + """Create an async stream of user messages.""" + # Example messages to send + messages = [ + { + "type": "user", + "message": { + "role": "user", + "content": "Hello! Please tell me a bit about Python async programming.", + }, + "parent_tool_use_id": None, + "session_id": "example-session-1", + }, + # Add a delay to simulate interactive conversation + None, # We'll use this as a signal to delay + { + "type": "user", + "message": { + "role": "user", + "content": "Can you give me a simple code example?", + }, + "parent_tool_use_id": None, + "session_id": "example-session-1", + }, + ] + + for msg in messages: + if msg is None: + await asyncio.sleep(2) # Simulate user thinking time + continue + yield msg + + +async def example_string_mode(): + """Example using traditional string mode (backward compatible).""" + print("=== String Mode Example ===") + + # Option 1: Using query function + async for message in query( + prompt="What is 2+2? Please give a brief answer.", options=ClaudeCodeOptions() + ): + print(f"Received: {type(message).__name__}") + if hasattr(message, "content"): + print(f" Content: {message.content}") + + print("Completed\n") + + +async def example_streaming_mode(): + """Example using new streaming mode with async iterable.""" + print("=== Streaming Mode Example ===") + + options = ClaudeCodeOptions() + + # Create message stream + message_stream = create_message_stream() + + # Use query with async iterable + message_count = 0 + async for message in query(prompt=message_stream, options=options): + message_count += 1 + msg_type = type(message).__name__ + + print(f"\nMessage #{message_count} ({msg_type}):") + + if hasattr(message, "content"): + content = message.content + if isinstance(content, list): + for block in content: + if hasattr(block, "text"): + print(f" {block.text}") + else: + print(f" {content}") + elif hasattr(message, "subtype"): + print(f" Subtype: {message.subtype}") + + print("\nCompleted") + + +async def example_with_context_manager(): + """Example using context manager for cleaner code.""" + print("=== Context Manager Example ===") + + # Simple one-shot query with automatic cleanup + async with ClaudeSDKClient() as client: + await client.send_message("What is the meaning of life?") + async for message in client.receive_messages(): + if hasattr(message, "content"): + print(f"Response: {message.content}") + + print("\nCompleted with automatic cleanup\n") + + +async def example_with_interrupt(): + """Example demonstrating interrupt functionality.""" + print("=== Streaming Mode with Interrupt Example ===") + + options = ClaudeCodeOptions() + client = ClaudeSDKClient(options=options) + + async def interruptible_stream(): + """Stream that we'll interrupt.""" + yield { + "type": "user", + "message": { + "role": "user", + "content": "Count to 1000 slowly, saying each number.", + }, + "parent_tool_use_id": None, + "session_id": "interrupt-example", + } + # Keep the stream open by waiting indefinitely + # This prevents stdin from being closed + await asyncio.Event().wait() + + try: + await client.connect(interruptible_stream()) + print("Connected - will interrupt after 3 seconds") + + # Create tasks for receiving and interrupting + async def receive_and_interrupt(): + # Start a background task to continuously receive messages + async def receive_messages(): + async for message in client.receive_messages(): + msg_type = type(message).__name__ + print(f"Received: {msg_type}") + + if hasattr(message, "content") and isinstance( + message.content, list + ): + for block in message.content: + if hasattr(block, "text"): + print(f" {block.text[:50]}...") # First 50 chars + + # Start receiving in background + receive_task = asyncio.create_task(receive_messages()) + + # Wait 3 seconds then interrupt + await asyncio.sleep(3) + print("\nSending interrupt signal...") + + try: + await client.interrupt() + print("Interrupt sent successfully") + except Exception as e: + print(f"Interrupt error: {e}") + + # Give some time to see any final messages + await asyncio.sleep(2) + + # Cancel the receive task + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + + await receive_and_interrupt() + + except Exception as e: + print(f"Error: {e}") + finally: + await client.disconnect() + print("\nDisconnected") + + +async def main(): + """Run all examples.""" + # Run string mode example + await example_string_mode() + + # Run streaming mode example + await example_streaming_mode() + + # Run context manager example + await example_with_context_manager() + + # Run interrupt example + await example_with_interrupt() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index b8a1152..dc84df1 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -1,7 +1,5 @@ """Claude SDK for Python.""" -import os -from collections.abc import AsyncIterator from ._errors import ( ClaudeSDKError, @@ -10,7 +8,8 @@ from ._errors import ( CLINotFoundError, ProcessError, ) -from ._internal.client import InternalClient +from .client import ClaudeSDKClient +from .query import query from .types import ( AssistantMessage, ClaudeCodeOptions, @@ -29,8 +28,9 @@ from .types import ( __version__ = "0.0.14" __all__ = [ - # Main function + # Main exports "query", + "ClaudeSDKClient", # Types "PermissionMode", "McpServerConfig", @@ -51,52 +51,3 @@ __all__ = [ "ProcessError", "CLIJSONDecodeError", ] - - -async def query( - *, prompt: str, options: ClaudeCodeOptions | None = None -) -> AsyncIterator[Message]: - """ - Query Claude Code. - - Python SDK for interacting with Claude Code. - - Args: - prompt: The prompt to send to Claude - options: Optional configuration (defaults to ClaudeCodeOptions() if None). - Set options.permission_mode to control tool execution: - - 'default': CLI prompts for dangerous tools - - 'acceptEdits': Auto-accept file edits - - 'bypassPermissions': Allow all tools (use with caution) - Set options.cwd for working directory. - - Yields: - Messages from the conversation - - - Example: - ```python - # Simple usage - async for message in query(prompt="Hello"): - print(message) - - # With options - async for message in query( - prompt="Hello", - options=ClaudeCodeOptions( - system_prompt="You are helpful", - cwd="/home/user" - ) - ): - print(message) - ``` - """ - if options is None: - options = ClaudeCodeOptions() - - os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py" - - client = InternalClient() - - async for message in client.process_query(prompt=prompt, options=options): - yield message diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index ef1070d..fb4eeb8 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -1,20 +1,10 @@ """Internal client implementation.""" -from collections.abc import AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator from typing import Any -from ..types import ( - AssistantMessage, - ClaudeCodeOptions, - ContentBlock, - Message, - ResultMessage, - SystemMessage, - TextBlock, - ToolResultBlock, - ToolUseBlock, - UserMessage, -) +from ..types import ClaudeCodeOptions, Message +from .message_parser import parse_message from .transport.subprocess_cli import SubprocessCLITransport @@ -25,7 +15,7 @@ class InternalClient: """Initialize the internal client.""" async def process_query( - self, prompt: str, options: ClaudeCodeOptions + self, prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions ) -> AsyncIterator[Message]: """Process a query through transport.""" @@ -35,63 +25,9 @@ class InternalClient: await transport.connect() async for data in transport.receive_messages(): - message = self._parse_message(data) + message = parse_message(data) if message: yield message finally: await transport.disconnect() - - def _parse_message(self, data: dict[str, Any]) -> Message | None: - """Parse message from CLI output, trusting the structure.""" - - match data["type"]: - case "user": - return UserMessage(content=data["message"]["content"]) - - case "assistant": - content_blocks: list[ContentBlock] = [] - for block in data["message"]["content"]: - match block["type"]: - case "text": - content_blocks.append(TextBlock(text=block["text"])) - case "tool_use": - content_blocks.append( - ToolUseBlock( - id=block["id"], - name=block["name"], - input=block["input"], - ) - ) - case "tool_result": - content_blocks.append( - ToolResultBlock( - tool_use_id=block["tool_use_id"], - content=block.get("content"), - is_error=block.get("is_error"), - ) - ) - - return AssistantMessage(content=content_blocks) - - case "system": - return SystemMessage( - subtype=data["subtype"], - data=data, - ) - - case "result": - return ResultMessage( - subtype=data["subtype"], - duration_ms=data["duration_ms"], - duration_api_ms=data["duration_api_ms"], - is_error=data["is_error"], - num_turns=data["num_turns"], - session_id=data["session_id"], - total_cost_usd=data.get("total_cost_usd"), - usage=data.get("usage"), - result=data.get("result"), - ) - - case _: - return None diff --git a/src/claude_code_sdk/_internal/message_parser.py b/src/claude_code_sdk/_internal/message_parser.py new file mode 100644 index 0000000..a2b88d2 --- /dev/null +++ b/src/claude_code_sdk/_internal/message_parser.py @@ -0,0 +1,77 @@ +"""Message parser for Claude Code SDK responses.""" + +from typing import Any + +from ..types import ( + AssistantMessage, + ContentBlock, + Message, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + + +def parse_message(data: dict[str, Any]) -> Message | None: + """ + Parse message from CLI output into typed Message objects. + + Args: + data: Raw message dictionary from CLI output + + Returns: + Parsed Message object or None if type is unrecognized + """ + match data["type"]: + case "user": + return UserMessage(content=data["message"]["content"]) + + case "assistant": + content_blocks: list[ContentBlock] = [] + for block in data["message"]["content"]: + match block["type"]: + case "text": + content_blocks.append(TextBlock(text=block["text"])) + case "tool_use": + content_blocks.append( + ToolUseBlock( + id=block["id"], + name=block["name"], + input=block["input"], + ) + ) + case "tool_result": + content_blocks.append( + ToolResultBlock( + tool_use_id=block["tool_use_id"], + content=block.get("content"), + is_error=block.get("is_error"), + ) + ) + + return AssistantMessage(content=content_blocks) + + case "system": + return SystemMessage( + subtype=data["subtype"], + data=data, + ) + + case "result": + return ResultMessage( + subtype=data["subtype"], + duration_ms=data["duration_ms"], + duration_api_ms=data["duration_api_ms"], + is_error=data["is_error"], + num_turns=data["num_turns"], + session_id=data["session_id"], + total_cost_usd=data.get("total_cost_usd"), + usage=data.get("usage"), + result=data.get("result"), + ) + + case _: + return None diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index b17f85e..22f3000 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -4,10 +4,10 @@ import json import logging import os import shutil -from collections.abc import AsyncIterator, AsyncIterable +from collections.abc import AsyncIterable, AsyncIterator from pathlib import Path from subprocess import PIPE -from typing import Any, Union +from typing import Any import anyio from anyio.abc import Process @@ -28,7 +28,7 @@ class SubprocessCLITransport(Transport): def __init__( self, - prompt: Union[str, AsyncIterable[dict[str, Any]]], + prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions, cli_path: str | Path | None = None, ): @@ -126,8 +126,8 @@ class SubprocessCLITransport(Transport): cmd.extend(["--input-format", "stream-json"]) else: # String mode: use --print with the prompt - cmd.extend(["--print", self._prompt]) - + cmd.extend(["--print", str(self._prompt)]) + return cmd async def connect(self) -> None: @@ -151,7 +151,7 @@ class SubprocessCLITransport(Transport): self._stdout_stream = TextReceiveStream(self._process.stdout) if self._process.stderr: self._stderr_stream = TextReceiveStream(self._process.stderr) - + # Handle stdin based on mode if self._is_streaming: # Streaming mode: keep stdin open and start streaming task @@ -197,21 +197,39 @@ class SubprocessCLITransport(Transport): self._stdin_stream = None async def send_request(self, messages: list[Any], options: dict[str, Any]) -> None: - """Not used for CLI transport - args passed via command line.""" + """Send additional messages in streaming mode.""" + if not self._is_streaming: + raise CLIConnectionError("send_request only works in streaming mode") + + if not self._stdin_stream: + raise CLIConnectionError("stdin not available - stream may have ended") + + # Send each message as a user message + for message in messages: + # Ensure message has required structure + if not isinstance(message, dict): + message = { + "type": "user", + "message": {"role": "user", "content": str(message)}, + "parent_tool_use_id": None, + "session_id": options.get("session_id", "default") + } + + await self._stdin_stream.send(json.dumps(message) + "\n") async def _stream_to_stdin(self) -> None: """Stream messages to stdin for streaming mode.""" if not self._stdin_stream or not isinstance(self._prompt, AsyncIterable): return - + try: async for message in self._prompt: if not self._stdin_stream: break await self._stdin_stream.send(json.dumps(message) + "\n") - - # Signal EOF but keep the stream open for control messages - # This matches the TypeScript implementation which calls stdin.end() + + # Don't close stdin - keep it open for send_request + # Users can explicitly call disconnect() when done except Exception as e: logger.debug(f"Error streaming to stdin: {e}") if self._stdin_stream: @@ -258,7 +276,7 @@ class SubprocessCLITransport(Transport): try: data = json.loads(json_buffer) json_buffer = "" - + # Handle control responses separately if data.get("type") == "control_response": request_id = data.get("response", {}).get("request_id") @@ -266,7 +284,7 @@ class SubprocessCLITransport(Transport): # Store the response for the pending request self._pending_control_responses[request_id] = data.get("response", {}) continue - + try: yield data except GeneratorExit: @@ -334,47 +352,47 @@ class SubprocessCLITransport(Transport): def is_connected(self) -> bool: """Check if subprocess is running.""" return self._process is not None and self._process.returncode is None - + async def interrupt(self) -> None: """Send interrupt control request (only works in streaming mode).""" if not self._is_streaming: raise CLIConnectionError("Interrupt requires streaming mode (AsyncIterable prompt)") - + if not self._stdin_stream: raise CLIConnectionError("Not connected or stdin not available") - + await self._send_control_request({"subtype": "interrupt"}) - + async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]: """Send a control request and wait for response.""" if not self._stdin_stream: raise CLIConnectionError("Stdin not available") - + # Generate unique request ID self._request_counter += 1 request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}" - + # Build control request control_request = { "type": "control_request", "request_id": request_id, "request": request } - + # Send request await self._stdin_stream.send(json.dumps(control_request) + "\n") - + # Wait for response with timeout try: with anyio.fail_after(30.0): # 30 second timeout while request_id not in self._pending_control_responses: await anyio.sleep(0.1) - + response = self._pending_control_responses.pop(request_id) - + if response.get("subtype") == "error": raise CLIConnectionError(f"Control request failed: {response.get('error')}") - + return response except TimeoutError: raise CLIConnectionError("Control request timed out") from None diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py new file mode 100644 index 0000000..a75eece --- /dev/null +++ b/src/claude_code_sdk/client.py @@ -0,0 +1,208 @@ +"""Claude SDK Client for interacting with Claude Code.""" + +import os +from collections.abc import AsyncIterable, AsyncIterator + +from ._errors import CLIConnectionError +from .types import ClaudeCodeOptions, Message, ResultMessage + + +class ClaudeSDKClient: + """ + Client for bidirectional, interactive conversations with Claude Code. + + This client provides full control over the conversation flow with support + for streaming, interrupts, and dynamic message sending. For simple one-shot + queries, consider using the query() function instead. + + Key features: + - **Bidirectional**: Send and receive messages at any time + - **Stateful**: Maintains conversation context across messages + - **Interactive**: Send follow-ups based on responses + - **Control flow**: Support for interrupts and session management + + When to use ClaudeSDKClient: + - Building chat interfaces or conversational UIs + - Interactive debugging or exploration sessions + - Multi-turn conversations with context + - When you need to react to Claude's responses + - Real-time applications with user input + - When you need interrupt capabilities + + When to use query() instead: + - Simple one-off questions + - Batch processing of prompts + - Fire-and-forget automation scripts + - When all inputs are known upfront + - Stateless operations + + Example - Interactive conversation: + ```python + # Automatically connects with empty stream for interactive use + async with ClaudeSDKClient() as client: + # Send initial message + await client.send_message("Let's solve a math problem step by step") + + # Receive and process response + async for message in client.receive_messages(): + if "ready" in str(message.content).lower(): + break + + # Send follow-up based on response + await client.send_message("What's 15% of 80?") + + # Continue conversation... + # Automatically disconnects + ``` + + Example - With interrupt: + ```python + async with ClaudeSDKClient() as client: + # Start a long task + await client.send_message("Count to 1000") + + # Interrupt after 2 seconds + await asyncio.sleep(2) + await client.interrupt() + + # Send new instruction + await client.send_message("Never mind, what's 2+2?") + ``` + + Example - Manual connection: + ```python + client = ClaudeSDKClient() + + # Connect with initial message stream + async def message_stream(): + yield {"type": "user", "message": {"role": "user", "content": "Hello"}} + + await client.connect(message_stream()) + + # Send additional messages dynamically + await client.send_message("What's the weather?") + + async for message in client.receive_messages(): + print(message) + + await client.disconnect() + ``` + """ + + def __init__(self, options: ClaudeCodeOptions | None = None): + """Initialize Claude SDK client.""" + if options is None: + options = ClaudeCodeOptions() + self.options = options + self._transport = None + os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client" + + async def connect(self, prompt: str | AsyncIterable[dict] | None = None) -> None: + """Connect to Claude with a prompt or message stream.""" + from ._internal.transport.subprocess_cli import SubprocessCLITransport + + # Auto-connect with empty async iterable if no prompt is provided + async def _empty_stream(): + # Never yields, but indicates that this function is an iterator and + # keeps the connection open. + if False: + yield + + self._transport = SubprocessCLITransport( + prompt=_empty_stream() if prompt is None else prompt, + options=self.options, + ) + await self._transport.connect() + + async def receive_messages(self) -> AsyncIterator[Message]: + """Receive all messages from Claude.""" + if not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + + from ._internal.message_parser import parse_message + + async for data in self._transport.receive_messages(): + message = parse_message(data) + if message: + yield message + + async def send_message(self, content: str, session_id: str = "default") -> None: + """Send a new message in streaming mode.""" + if not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + + message = { + "type": "user", + "message": {"role": "user", "content": content}, + "parent_tool_use_id": None, + "session_id": session_id, + } + + await self._transport.send_request([message], {"session_id": session_id}) + + async def interrupt(self) -> None: + """Send interrupt signal (only works with streaming mode).""" + if not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + await self._transport.interrupt() + + async def receive_response(self) -> tuple[list[Message], ResultMessage | None]: + """ + Receive a complete response from Claude, collecting all messages until ResultMessage. + + Compared to receive_messages(), this is a convenience method that + handles the common pattern of receiving messages until Claude completes + its response. It collects all messages and returns them along with the + final ResultMessage. + + Returns: + tuple: A tuple of (messages, result) where: + - messages: List of all messages received (UserMessage, AssistantMessage, SystemMessage) + - result: The final ResultMessage if received, None if stream ended without result + + Example: + ```python + async with ClaudeSDKClient() as client: + # First turn + await client.send_message("What's the capital of France?") + messages, result = await client.receive_response() + + # Extract assistant's response + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Second turn + await client.send_message("What's the population?") + messages, result = await client.receive_response() + # ... process response + ``` + """ + from .types import ResultMessage + + messages = [] + async for message in self.receive_messages(): + messages.append(message) + if isinstance(message, ResultMessage): + return messages, message + + # Stream ended without ResultMessage + return messages, None + + async def disconnect(self) -> None: + """Disconnect from Claude.""" + if self._transport: + await self._transport.disconnect() + self._transport = None + + async def __aenter__(self): + """Enter async context - automatically connects with empty stream for interactive use.""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit async context - always disconnects.""" + await self.disconnect() + return False diff --git a/src/claude_code_sdk/query.py b/src/claude_code_sdk/query.py new file mode 100644 index 0000000..4bd7e96 --- /dev/null +++ b/src/claude_code_sdk/query.py @@ -0,0 +1,99 @@ +"""Query function for one-shot interactions with Claude Code.""" + +import os +from collections.abc import AsyncIterable, AsyncIterator + +from ._internal.client import InternalClient +from .types import ClaudeCodeOptions, Message + + +async def query( + *, prompt: str | AsyncIterable[dict], options: ClaudeCodeOptions | None = None +) -> AsyncIterator[Message]: + """ + Query Claude Code for one-shot or unidirectional streaming interactions. + + This function is ideal for simple, stateless queries where you don't need + bidirectional communication or conversation management. For interactive, + stateful conversations, use ClaudeSDKClient instead. + + Key differences from ClaudeSDKClient: + - **Unidirectional**: Send all messages upfront, receive all responses + - **Stateless**: Each query is independent, no conversation state + - **Simple**: Fire-and-forget style, no connection management + - **No interrupts**: Cannot interrupt or send follow-up messages + + When to use query(): + - Simple one-off questions ("What is 2+2?") + - Batch processing of independent prompts + - Code generation or analysis tasks + - Automated scripts and CI/CD pipelines + - When you know all inputs upfront + + When to use ClaudeSDKClient: + - Interactive conversations with follow-ups + - Chat applications or REPL-like interfaces + - When you need to send messages based on responses + - When you need interrupt capabilities + - Long-running sessions with state + + Args: + prompt: The prompt to send to Claude. Can be a string for single-shot queries + or an AsyncIterable[dict] for streaming mode with continuous interaction. + In streaming mode, each dict should have the structure: + { + "type": "user", + "message": {"role": "user", "content": "..."}, + "parent_tool_use_id": None, + "session_id": "..." + } + options: Optional configuration (defaults to ClaudeCodeOptions() if None). + Set options.permission_mode to control tool execution: + - 'default': CLI prompts for dangerous tools + - 'acceptEdits': Auto-accept file edits + - 'bypassPermissions': Allow all tools (use with caution) + Set options.cwd for working directory. + + Yields: + Messages from the conversation + + Example - Simple query: + ```python + # One-off question + async for message in query(prompt="What is the capital of France?"): + print(message) + ``` + + Example - With options: + ```python + # Code generation with specific settings + async for message in query( + prompt="Create a Python web server", + options=ClaudeCodeOptions( + system_prompt="You are an expert Python developer", + cwd="/home/user/project" + ) + ): + print(message) + ``` + + Example - Streaming mode (still unidirectional): + ```python + async def prompts(): + yield {"type": "user", "message": {"role": "user", "content": "Hello"}} + yield {"type": "user", "message": {"role": "user", "content": "How are you?"}} + + # All prompts are sent, then all responses received + async for message in query(prompt=prompts()): + print(message) + ``` + """ + if options is None: + options = ClaudeCodeOptions() + + os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py" + + client = InternalClient() + + async for message in client.process_query(prompt=prompt, options=options): + yield message From 361c7e0be39ba51e016c75f16afd4e53041dbbe1 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 13:23:53 -0700 Subject: [PATCH 04/19] Working examples --- examples/streaming_mode.py | 328 ++++++++++++++++++ examples/streaming_mode_example.py | 192 ---------- examples/streaming_mode_ipython.py | 153 ++++++++ .../_internal/transport/subprocess_cli.py | 7 +- 4 files changed, 485 insertions(+), 195 deletions(-) create mode 100644 examples/streaming_mode.py delete mode 100644 examples/streaming_mode_example.py create mode 100644 examples/streaming_mode_ipython.py diff --git a/examples/streaming_mode.py b/examples/streaming_mode.py new file mode 100644 index 0000000..cfa4455 --- /dev/null +++ b/examples/streaming_mode.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +""" +Comprehensive examples of using ClaudeSDKClient for streaming mode. + +This file demonstrates various patterns for building applications with +the ClaudeSDKClient streaming interface. +""" + +import asyncio +import contextlib + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ClaudeSDKClient, + ResultMessage, + TextBlock, +) + + +async def example_basic_streaming(): + """Basic streaming with context manager.""" + print("=== Basic Streaming Example ===") + + async with ClaudeSDKClient() as client: + # Send a message + await client.send_message("What is 2+2?") + + # Receive complete response using the helper method + messages, result = await client.receive_response() + + # Extract text from assistant's response + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Print cost if available + if result and result.total_cost_usd: + print(f"Cost: ${result.total_cost_usd:.4f}") + + print("Session ended\n") + + +async def example_multi_turn_conversation(): + """Multi-turn conversation using receive_response helper.""" + print("=== Multi-Turn Conversation Example ===") + + async with ClaudeSDKClient() as client: + # First turn + print("User: What's the capital of France?") + await client.send_message("What's the capital of France?") + + messages, _ = await client.receive_response() + + # Extract and print response + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Second turn - follow-up + print("\nUser: What's the population of that city?") + await client.send_message("What's the population of that city?") + + messages, _ = await client.receive_response() + + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + print("\nConversation ended\n") + + +async def example_concurrent_responses(): + """Handle responses while sending new messages.""" + print("=== Concurrent Send/Receive Example ===") + + async with ClaudeSDKClient() as client: + # Background task to continuously receive messages + async def receive_messages(): + async for message in client.receive_messages(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Start receiving in background + receive_task = asyncio.create_task(receive_messages()) + + # Send multiple messages with delays + questions = [ + "What is 2 + 2?", + "What is the square root of 144?", + "What is 15% of 80?", + ] + + for question in questions: + print(f"\nUser: {question}") + await client.send_message(question) + await asyncio.sleep(3) # Wait between messages + + # Give time for final responses + await asyncio.sleep(2) + + # Clean up + receive_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await receive_task + + print("\nSession ended\n") + + +async def example_with_interrupt(): + """Demonstrate interrupt capability.""" + print("=== Interrupt Example ===") + print("IMPORTANT: Interrupts require active message consumption.") + + async with ClaudeSDKClient() as client: + # Start a long-running task + print("\nUser: Count from 1 to 100 slowly") + await client.send_message( + "Count from 1 to 100 slowly, with a brief pause between each number" + ) + + # Create a background task to consume messages + messages_received = [] + interrupt_sent = False + + async def consume_messages(): + """Consume messages in the background to enable interrupt processing.""" + async for message in client.receive_messages(): + messages_received.append(message) + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + # Print first few numbers + print(f"Claude: {block.text[:50]}...") + + # Stop when we get a result after interrupt + if isinstance(message, ResultMessage) and interrupt_sent: + break + + # Start consuming messages in the background + consume_task = asyncio.create_task(consume_messages()) + + # Wait 2 seconds then send interrupt + await asyncio.sleep(2) + print("\n[After 2 seconds, sending interrupt...]") + interrupt_sent = True + await client.interrupt() + + # Wait for the consume task to finish processing the interrupt + await consume_task + + # Send new instruction after interrupt + print("\nUser: Never mind, just tell me a quick joke") + await client.send_message("Never mind, just tell me a quick joke") + + # Get the joke + messages, result = await client.receive_response() + + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + print("\nSession ended\n") + + +async def example_manual_message_handling(): + """Manually handle message stream for custom logic.""" + print("=== Manual Message Handling Example ===") + + async with ClaudeSDKClient() as client: + await client.send_message( + "List 5 programming languages and their main use cases" + ) + + # Manually process messages with custom logic + languages_found = [] + + async for message in client.receive_messages(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + text = block.text + # Custom logic: extract language names + for lang in [ + "Python", + "JavaScript", + "Java", + "C++", + "Go", + "Rust", + "Ruby", + ]: + if lang in text and lang not in languages_found: + languages_found.append(lang) + print(f"Found language: {lang}") + + elif isinstance(message, ResultMessage): + print(f"\nTotal languages mentioned: {len(languages_found)}") + break + + print("\nSession ended\n") + + +async def example_with_options(): + """Use ClaudeCodeOptions to configure the client.""" + print("=== Custom Options Example ===") + + # Configure options + options = ClaudeCodeOptions( + allowed_tools=["Read", "Write"], # Allow file operations + max_thinking_tokens=10000, + system_prompt="You are a helpful coding assistant.", + ) + + async with ClaudeSDKClient(options=options) as client: + await client.send_message( + "Create a simple hello.txt file with a greeting message" + ) + + messages, result = await client.receive_response() + + tool_uses = [] + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif hasattr(block, "name"): # ToolUseBlock + tool_uses.append(getattr(block, "name", "")) + + if tool_uses: + print(f"\nTools used: {', '.join(tool_uses)}") + + print("\nSession ended\n") + + +async def example_error_handling(): + """Demonstrate proper error handling.""" + print("=== Error Handling Example ===") + + client = ClaudeSDKClient() + + try: + # Connect with custom stream + async def message_stream(): + yield { + "type": "user", + "message": {"role": "user", "content": "Hello"}, + "parent_tool_use_id": None, + "session_id": "error-demo", + } + + await client.connect(message_stream()) + + # Create a background task to consume messages (required for interrupt to work) + consume_task = None + + async def consume_messages(): + """Background message consumer.""" + async for msg in client.receive_messages(): + if isinstance(msg, AssistantMessage): + print("Received response from Claude") + + # Receive messages with timeout + try: + # Start consuming messages in background + consume_task = asyncio.create_task(consume_messages()) + + # Wait for response with timeout + await asyncio.wait_for(consume_task, timeout=30.0) + + except asyncio.TimeoutError: + print("Response timeout - sending interrupt") + # Note: interrupt requires active message consumption + # Since we're already consuming in the background task, interrupt will work + await client.interrupt() + + # Cancel the consume task + if consume_task: + consume_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await consume_task + + except Exception as e: + print(f"Error: {e}") + + finally: + # Always disconnect + await client.disconnect() + + print("\nSession ended\n") + + +async def main(): + """Run all examples.""" + examples = [ + example_basic_streaming, + example_multi_turn_conversation, + example_concurrent_responses, + example_with_interrupt, + example_manual_message_handling, + example_with_options, + example_error_handling, + ] + + for example in examples: + await example() + print("-" * 50 + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/streaming_mode_example.py b/examples/streaming_mode_example.py deleted file mode 100644 index ed782a9..0000000 --- a/examples/streaming_mode_example.py +++ /dev/null @@ -1,192 +0,0 @@ -#!/usr/bin/env python3 -"""Example demonstrating streaming mode with bidirectional communication.""" - -import asyncio -from collections.abc import AsyncIterator - -from claude_code_sdk import ClaudeCodeOptions, ClaudeSDKClient, query - - -async def create_message_stream() -> AsyncIterator[dict]: - """Create an async stream of user messages.""" - # Example messages to send - messages = [ - { - "type": "user", - "message": { - "role": "user", - "content": "Hello! Please tell me a bit about Python async programming.", - }, - "parent_tool_use_id": None, - "session_id": "example-session-1", - }, - # Add a delay to simulate interactive conversation - None, # We'll use this as a signal to delay - { - "type": "user", - "message": { - "role": "user", - "content": "Can you give me a simple code example?", - }, - "parent_tool_use_id": None, - "session_id": "example-session-1", - }, - ] - - for msg in messages: - if msg is None: - await asyncio.sleep(2) # Simulate user thinking time - continue - yield msg - - -async def example_string_mode(): - """Example using traditional string mode (backward compatible).""" - print("=== String Mode Example ===") - - # Option 1: Using query function - async for message in query( - prompt="What is 2+2? Please give a brief answer.", options=ClaudeCodeOptions() - ): - print(f"Received: {type(message).__name__}") - if hasattr(message, "content"): - print(f" Content: {message.content}") - - print("Completed\n") - - -async def example_streaming_mode(): - """Example using new streaming mode with async iterable.""" - print("=== Streaming Mode Example ===") - - options = ClaudeCodeOptions() - - # Create message stream - message_stream = create_message_stream() - - # Use query with async iterable - message_count = 0 - async for message in query(prompt=message_stream, options=options): - message_count += 1 - msg_type = type(message).__name__ - - print(f"\nMessage #{message_count} ({msg_type}):") - - if hasattr(message, "content"): - content = message.content - if isinstance(content, list): - for block in content: - if hasattr(block, "text"): - print(f" {block.text}") - else: - print(f" {content}") - elif hasattr(message, "subtype"): - print(f" Subtype: {message.subtype}") - - print("\nCompleted") - - -async def example_with_context_manager(): - """Example using context manager for cleaner code.""" - print("=== Context Manager Example ===") - - # Simple one-shot query with automatic cleanup - async with ClaudeSDKClient() as client: - await client.send_message("What is the meaning of life?") - async for message in client.receive_messages(): - if hasattr(message, "content"): - print(f"Response: {message.content}") - - print("\nCompleted with automatic cleanup\n") - - -async def example_with_interrupt(): - """Example demonstrating interrupt functionality.""" - print("=== Streaming Mode with Interrupt Example ===") - - options = ClaudeCodeOptions() - client = ClaudeSDKClient(options=options) - - async def interruptible_stream(): - """Stream that we'll interrupt.""" - yield { - "type": "user", - "message": { - "role": "user", - "content": "Count to 1000 slowly, saying each number.", - }, - "parent_tool_use_id": None, - "session_id": "interrupt-example", - } - # Keep the stream open by waiting indefinitely - # This prevents stdin from being closed - await asyncio.Event().wait() - - try: - await client.connect(interruptible_stream()) - print("Connected - will interrupt after 3 seconds") - - # Create tasks for receiving and interrupting - async def receive_and_interrupt(): - # Start a background task to continuously receive messages - async def receive_messages(): - async for message in client.receive_messages(): - msg_type = type(message).__name__ - print(f"Received: {msg_type}") - - if hasattr(message, "content") and isinstance( - message.content, list - ): - for block in message.content: - if hasattr(block, "text"): - print(f" {block.text[:50]}...") # First 50 chars - - # Start receiving in background - receive_task = asyncio.create_task(receive_messages()) - - # Wait 3 seconds then interrupt - await asyncio.sleep(3) - print("\nSending interrupt signal...") - - try: - await client.interrupt() - print("Interrupt sent successfully") - except Exception as e: - print(f"Interrupt error: {e}") - - # Give some time to see any final messages - await asyncio.sleep(2) - - # Cancel the receive task - receive_task.cancel() - try: - await receive_task - except asyncio.CancelledError: - pass - - await receive_and_interrupt() - - except Exception as e: - print(f"Error: {e}") - finally: - await client.disconnect() - print("\nDisconnected") - - -async def main(): - """Run all examples.""" - # Run string mode example - await example_string_mode() - - # Run streaming mode example - await example_streaming_mode() - - # Run context manager example - await example_with_context_manager() - - # Run interrupt example - await example_with_interrupt() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/streaming_mode_ipython.py b/examples/streaming_mode_ipython.py new file mode 100644 index 0000000..fc2f7cf --- /dev/null +++ b/examples/streaming_mode_ipython.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +IPython-friendly code snippets for ClaudeSDKClient streaming mode. + +These examples are designed to be copy-pasted directly into IPython. +Each example is self-contained and can be run independently. +""" + +# ============================================================================ +# BASIC STREAMING +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +async with ClaudeSDKClient() as client: + await client.send_message("What is 2+2?") + messages, result = await client.receive_response() + + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + +# ============================================================================ +# STREAMING WITH REAL-TIME DISPLAY +# ============================================================================ + +import asyncio +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +async with ClaudeSDKClient() as client: + async def receive_response(): + messages, _ = await client.receive_response() + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + await client.send_message("Tell me a short joke") + await receive_response() + await client.send_message("Now tell me a fun fact") + await receive_response() + + +# ============================================================================ +# PERSISTENT CLIENT FOR MULTIPLE QUESTIONS +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +# Create client +client = ClaudeSDKClient() +await client.connect() + + +# Helper to get response +async def get_response(): + messages, result = await client.receive_response() + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + +# Use it multiple times +await client.send_message("What's 2+2?") +await get_response() + +await client.send_message("What's 10*10?") +await get_response() + +# Don't forget to disconnect when done +await client.disconnect() + + +# ============================================================================ +# WITH INTERRUPT CAPABILITY +# ============================================================================ +# IMPORTANT: Interrupts require active message consumption. You must be +# consuming messages from the client for the interrupt to be processed. + +import asyncio +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage + +async with ClaudeSDKClient() as client: + print("\n--- Sending initial message ---\n") + + # Send a long-running task + await client.send_message("Count from 1 to 100 slowly using bash sleep") + + # Create a background task to consume messages + messages_received = [] + interrupt_sent = False + + async def consume_messages(): + async for msg in client.receive_messages(): + messages_received.append(msg) + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Check if we got a result after interrupt + if isinstance(msg, ResultMessage) and interrupt_sent: + break + + # Start consuming messages in the background + consume_task = asyncio.create_task(consume_messages()) + + # Wait a bit then send interrupt + await asyncio.sleep(10) + print("\n--- Sending interrupt ---\n") + interrupt_sent = True + await client.interrupt() + + # Wait for the consume task to finish + await consume_task + + # Send a new message after interrupt + print("\n--- After interrupt, sending new message ---\n") + await client.send_message("Just say 'Hello! I was interrupted.'") + messages, result = await client.receive_response() + + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + +# ============================================================================ +# ERROR HANDLING PATTERN +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +try: + async with ClaudeSDKClient() as client: + await client.send_message("Run a bash sleep command for 60 seconds") + + # Timeout after 30 seconds + messages, result = await asyncio.wait_for( + client.receive_response(), timeout=20.0 + ) + +except asyncio.TimeoutError: + print("Request timed out") +except Exception as e: + print(f"Error: {e}") \ No newline at end of file diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 22f3000..caff236 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -279,10 +279,11 @@ class SubprocessCLITransport(Transport): # Handle control responses separately if data.get("type") == "control_response": - request_id = data.get("response", {}).get("request_id") - if request_id and request_id in self._pending_control_responses: + response = data.get("response", {}) + request_id = response.get("request_id") + if request_id: # Store the response for the pending request - self._pending_control_responses[request_id] = data.get("response", {}) + self._pending_control_responses[request_id] = response continue try: From 6c25bf7d37cfa976f4620160ecd60f5dbf004d1d Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 13:44:53 -0700 Subject: [PATCH 05/19] Fix examples --- examples/streaming_mode.py | 106 ++++++++++------------------- examples/streaming_mode_ipython.py | 67 ++++++++++++------ src/claude_code_sdk/client.py | 45 ++++++------ 3 files changed, 103 insertions(+), 115 deletions(-) diff --git a/examples/streaming_mode.py b/examples/streaming_mode.py index cfa4455..239024d 100644 --- a/examples/streaming_mode.py +++ b/examples/streaming_mode.py @@ -13,6 +13,7 @@ from claude_code_sdk import ( AssistantMessage, ClaudeCodeOptions, ClaudeSDKClient, + CLIConnectionError, ResultMessage, TextBlock, ) @@ -27,18 +28,13 @@ async def example_basic_streaming(): await client.send_message("What is 2+2?") # Receive complete response using the helper method - messages, result = await client.receive_response() - - # Extract text from assistant's response - for msg in messages: + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): print(f"Claude: {block.text}") - - # Print cost if available - if result and result.total_cost_usd: - print(f"Cost: ${result.total_cost_usd:.4f}") + elif isinstance(msg, ResultMessage) and msg.total_cost_usd: + print(f"Cost: ${msg.total_cost_usd:.4f}") print("Session ended\n") @@ -52,32 +48,22 @@ async def example_multi_turn_conversation(): print("User: What's the capital of France?") await client.send_message("What's the capital of France?") - messages, _ = await client.receive_response() - # Extract and print response - for msg in messages: - if isinstance(msg, AssistantMessage): - for block in msg.content: - if isinstance(block, TextBlock): - print(f"Claude: {block.text}") + async for msg in client.receive_response(): + content_blocks = getattr(msg, 'content', []) + for block in content_blocks: + if isinstance(block, TextBlock): + print(f"{block.text}") # Second turn - follow-up print("\nUser: What's the population of that city?") await client.send_message("What's the population of that city?") - messages, _ = await client.receive_response() - - for msg in messages: - if isinstance(msg, AssistantMessage): - for block in msg.content: - if isinstance(block, TextBlock): - print(f"Claude: {block.text}") - - for msg in messages: - if isinstance(msg, AssistantMessage): - for block in msg.content: - if isinstance(block, TextBlock): - print(f"Claude: {block.text}") + async for msg in client.receive_response(): + content_blocks = getattr(msg, 'content', []) + for block in content_blocks: + if isinstance(block, TextBlock): + print(f"{block.text}") print("\nConversation ended\n") @@ -102,7 +88,7 @@ async def example_concurrent_responses(): questions = [ "What is 2 + 2?", "What is the square root of 144?", - "What is 15% of 80?", + "What is 10% of 80?", ] for question in questions: @@ -168,9 +154,7 @@ async def example_with_interrupt(): await client.send_message("Never mind, just tell me a quick joke") # Get the joke - messages, result = await client.receive_response() - - for msg in messages: + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): @@ -233,10 +217,8 @@ async def example_with_options(): "Create a simple hello.txt file with a greeting message" ) - messages, result = await client.receive_response() - tool_uses = [] - for msg in messages: + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): @@ -257,48 +239,34 @@ async def example_error_handling(): client = ClaudeSDKClient() try: - # Connect with custom stream - async def message_stream(): - yield { - "type": "user", - "message": {"role": "user", "content": "Hello"}, - "parent_tool_use_id": None, - "session_id": "error-demo", - } + await client.connect() - await client.connect(message_stream()) + # Send a message that will take time to process + await client.send_message("Run a bash sleep command for 60 seconds") - # Create a background task to consume messages (required for interrupt to work) - consume_task = None - - async def consume_messages(): - """Background message consumer.""" - async for msg in client.receive_messages(): - if isinstance(msg, AssistantMessage): - print("Received response from Claude") - - # Receive messages with timeout + # Try to receive response with a short timeout try: - # Start consuming messages in background - consume_task = asyncio.create_task(consume_messages()) - - # Wait for response with timeout - await asyncio.wait_for(consume_task, timeout=30.0) + messages = [] + async with asyncio.timeout(10.0): + async for msg in client.receive_response(): + messages.append(msg) + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text[:50]}...") + elif isinstance(msg, ResultMessage): + print("Received complete response") + break except asyncio.TimeoutError: - print("Response timeout - sending interrupt") - # Note: interrupt requires active message consumption - # Since we're already consuming in the background task, interrupt will work - await client.interrupt() + print("\nResponse timeout after 10 seconds - demonstrating graceful handling") + print(f"Received {len(messages)} messages before timeout") - # Cancel the consume task - if consume_task: - consume_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await consume_task + except CLIConnectionError as e: + print(f"Connection error: {e}") except Exception as e: - print(f"Error: {e}") + print(f"Unexpected error: {e}") finally: # Always disconnect diff --git a/examples/streaming_mode_ipython.py b/examples/streaming_mode_ipython.py index fc2f7cf..6b2b554 100644 --- a/examples/streaming_mode_ipython.py +++ b/examples/streaming_mode_ipython.py @@ -10,17 +10,18 @@ Each example is self-contained and can be run independently. # BASIC STREAMING # ============================================================================ -from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage async with ClaudeSDKClient() as client: await client.send_message("What is 2+2?") - messages, result = await client.receive_response() - for msg in messages: + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): print(f"Claude: {block.text}") + elif isinstance(msg, ResultMessage) and msg.total_cost_usd: + print(f"Cost: ${msg.total_cost_usd:.4f}") # ============================================================================ @@ -31,18 +32,17 @@ import asyncio from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock async with ClaudeSDKClient() as client: - async def receive_response(): - messages, _ = await client.receive_response() - for msg in messages: + async def send_and_receive(prompt): + await client.send_message(prompt) + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): print(f"Claude: {block.text}") - await client.send_message("Tell me a short joke") - await receive_response() - await client.send_message("Now tell me a fun fact") - await receive_response() + await send_and_receive("Tell me a short joke") + print("\n---\n") + await send_and_receive("Now tell me a fun fact") # ============================================================================ @@ -58,8 +58,7 @@ await client.connect() # Helper to get response async def get_response(): - messages, result = await client.receive_response() - for msg in messages: + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): @@ -123,9 +122,8 @@ async with ClaudeSDKClient() as client: # Send a new message after interrupt print("\n--- After interrupt, sending new message ---\n") await client.send_message("Just say 'Hello! I was interrupted.'") - messages, result = await client.receive_response() - for msg in messages: + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): @@ -142,12 +140,41 @@ try: async with ClaudeSDKClient() as client: await client.send_message("Run a bash sleep command for 60 seconds") - # Timeout after 30 seconds - messages, result = await asyncio.wait_for( - client.receive_response(), timeout=20.0 - ) + # Timeout after 20 seconds + messages = [] + async with asyncio.timeout(20.0): + async for msg in client.receive_response(): + messages.append(msg) + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") except asyncio.TimeoutError: - print("Request timed out") + print("Request timed out after 20 seconds") except Exception as e: - print(f"Error: {e}") \ No newline at end of file + print(f"Error: {e}") + + +# ============================================================================ +# COLLECTING ALL MESSAGES INTO A LIST +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage + +async with ClaudeSDKClient() as client: + await client.send_message("What are the primary colors?") + + # Collect all messages into a list + messages = [msg async for msg in client.receive_response()] + + # Process them afterwards + for msg in messages: + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(msg, ResultMessage): + print(f"Total messages: {len(messages)}") + if msg.total_cost_usd: + print(f"Cost: ${msg.total_cost_usd:.4f}") diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index a75eece..db7c494 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -146,50 +146,43 @@ class ClaudeSDKClient: raise CLIConnectionError("Not connected. Call connect() first.") await self._transport.interrupt() - async def receive_response(self) -> tuple[list[Message], ResultMessage | None]: + async def receive_response(self) -> AsyncIterator[Message]: """ - Receive a complete response from Claude, collecting all messages until ResultMessage. + Receive messages from Claude until a ResultMessage is received. - Compared to receive_messages(), this is a convenience method that - handles the common pattern of receiving messages until Claude completes - its response. It collects all messages and returns them along with the - final ResultMessage. + This is an async iterator that yields all messages including the final ResultMessage. + It's a convenience method over receive_messages() that automatically stops iteration + after receiving a ResultMessage. - Returns: - tuple: A tuple of (messages, result) where: - - messages: List of all messages received (UserMessage, AssistantMessage, SystemMessage) - - result: The final ResultMessage if received, None if stream ended without result + Yields: + Message: Each message received (UserMessage, AssistantMessage, SystemMessage, ResultMessage) Example: ```python async with ClaudeSDKClient() as client: - # First turn + # Send message and process response await client.send_message("What's the capital of France?") - messages, result = await client.receive_response() - # Extract assistant's response - for msg in messages: + async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): print(f"Claude: {block.text}") + elif isinstance(msg, ResultMessage): + print(f"Cost: ${msg.total_cost_usd:.4f}") + ``` - # Second turn - await client.send_message("What's the population?") - messages, result = await client.receive_response() - # ... process response + Note: + The iterator will automatically stop after yielding a ResultMessage. + If you need to collect all messages into a list, use: + ```python + messages = [msg async for msg in client.receive_response()] ``` """ - from .types import ResultMessage - - messages = [] async for message in self.receive_messages(): - messages.append(message) + yield message if isinstance(message, ResultMessage): - return messages, message - - # Stream ended without ResultMessage - return messages, None + return async def disconnect(self) -> None: """Disconnect from Claude.""" From 489677d614d3a86a8f057b9ce83d0aca10c5050b Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 13:57:52 -0700 Subject: [PATCH 06/19] Add tests --- .../_internal/transport/subprocess_cli.py | 6 +- tests/test_streaming_client.py | 674 ++++++++++++++++++ tests/test_transport.py | 6 + 3 files changed, 683 insertions(+), 3 deletions(-) create mode 100644 tests/test_streaming_client.py diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index caff236..5632c21 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -200,10 +200,10 @@ class SubprocessCLITransport(Transport): """Send additional messages in streaming mode.""" if not self._is_streaming: raise CLIConnectionError("send_request only works in streaming mode") - + if not self._stdin_stream: raise CLIConnectionError("stdin not available - stream may have ended") - + # Send each message as a user message for message in messages: # Ensure message has required structure @@ -214,7 +214,7 @@ class SubprocessCLITransport(Transport): "parent_tool_use_id": None, "session_id": options.get("session_id", "default") } - + await self._stdin_stream.send(json.dumps(message) + "\n") async def _stream_to_stdin(self) -> None: diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py new file mode 100644 index 0000000..4f62545 --- /dev/null +++ b/tests/test_streaming_client.py @@ -0,0 +1,674 @@ +"""Tests for ClaudeSDKClient streaming functionality and query() with async iterables.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import anyio +import pytest + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ClaudeSDKClient, + CLIConnectionError, + ResultMessage, + SystemMessage, + TextBlock, + UserMessage, + query, +) + + +class TestClaudeSDKClientStreaming: + """Test ClaudeSDKClient streaming functionality.""" + + def test_auto_connect_with_context_manager(self): + """Test automatic connection when using context manager.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + # Verify connect was called + mock_transport.connect.assert_called_once() + assert client._transport is mock_transport + + # Verify disconnect was called on exit + mock_transport.disconnect.assert_called_once() + + anyio.run(_test) + + def test_manual_connect_disconnect(self): + """Test manual connect and disconnect.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient() + await client.connect() + + # Verify connect was called + mock_transport.connect.assert_called_once() + assert client._transport is mock_transport + + await client.disconnect() + # Verify disconnect was called + mock_transport.disconnect.assert_called_once() + assert client._transport is None + + anyio.run(_test) + + def test_connect_with_string_prompt(self): + """Test connecting with a string prompt.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient() + await client.connect("Hello Claude") + + # Verify transport was created with string prompt + call_kwargs = mock_transport_class.call_args.kwargs + assert call_kwargs["prompt"] == "Hello Claude" + + anyio.run(_test) + + def test_connect_with_async_iterable(self): + """Test connecting with an async iterable.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async def message_stream(): + yield {"type": "user", "message": {"role": "user", "content": "Hi"}} + yield {"type": "user", "message": {"role": "user", "content": "Bye"}} + + client = ClaudeSDKClient() + stream = message_stream() + await client.connect(stream) + + # Verify transport was created with async iterable + call_kwargs = mock_transport_class.call_args.kwargs + # Should be the same async iterator + assert call_kwargs["prompt"] is stream + + anyio.run(_test) + + def test_send_message(self): + """Test sending a message.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.send_message("Test message") + + # Verify send_request was called with correct format + mock_transport.send_request.assert_called_once() + call_args = mock_transport.send_request.call_args + messages, options = call_args[0] + assert len(messages) == 1 + assert messages[0]["type"] == "user" + assert messages[0]["message"]["content"] == "Test message" + assert options["session_id"] == "default" + + anyio.run(_test) + + def test_send_message_with_session_id(self): + """Test sending a message with custom session ID.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.send_message("Test", session_id="custom-session") + + call_args = mock_transport.send_request.call_args + messages, options = call_args[0] + assert messages[0]["session_id"] == "custom-session" + assert options["session_id"] == "custom-session" + + anyio.run(_test) + + def test_send_message_not_connected(self): + """Test sending message when not connected raises error.""" + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.send_message("Test") + + anyio.run(_test) + + def test_receive_messages(self): + """Test receiving messages.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock the message stream + async def mock_receive(): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Hello!"}], + }, + } + yield { + "type": "user", + "message": {"role": "user", "content": "Hi there"}, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + messages = [] + async for msg in client.receive_messages(): + messages.append(msg) + if len(messages) == 2: + break + + assert len(messages) == 2 + assert isinstance(messages[0], AssistantMessage) + assert isinstance(messages[0].content[0], TextBlock) + assert messages[0].content[0].text == "Hello!" + assert isinstance(messages[1], UserMessage) + assert messages[1].content == "Hi there" + + anyio.run(_test) + + def test_receive_response(self): + """Test receive_response stops at ResultMessage.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock the message stream + async def mock_receive(): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Answer"}], + }, + } + yield { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 800, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + # This should not be yielded + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Should not see this"}], + }, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + messages = [] + async for msg in client.receive_response(): + messages.append(msg) + + # Should only get 2 messages (assistant + result) + assert len(messages) == 2 + assert isinstance(messages[0], AssistantMessage) + assert isinstance(messages[1], ResultMessage) + + anyio.run(_test) + + def test_interrupt(self): + """Test interrupt functionality.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.interrupt() + mock_transport.interrupt.assert_called_once() + + anyio.run(_test) + + def test_interrupt_not_connected(self): + """Test interrupt when not connected raises error.""" + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.interrupt() + + anyio.run(_test) + + def test_client_with_options(self): + """Test client initialization with options.""" + async def _test(): + options = ClaudeCodeOptions( + cwd="/custom/path", + allowed_tools=["Read", "Write"], + system_prompt="Be helpful", + ) + + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient(options=options) + await client.connect() + + # Verify options were passed to transport + call_kwargs = mock_transport_class.call_args.kwargs + assert call_kwargs["options"] is options + + anyio.run(_test) + + def test_concurrent_send_receive(self): + """Test concurrent sending and receiving messages.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock receive to wait then yield messages + async def mock_receive(): + await asyncio.sleep(0.1) + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Response 1"}], + }, + } + await asyncio.sleep(0.1) + yield { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 800, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + # Helper to get next message + async def get_next_message(): + return await client.receive_response().__anext__() + + # Start receiving in background + receive_task = asyncio.create_task(get_next_message()) + + # Send message while receiving + await client.send_message("Question 1") + + # Wait for first message + first_msg = await receive_task + assert isinstance(first_msg, AssistantMessage) + + anyio.run(_test) + + +class TestQueryWithAsyncIterable: + """Test query() function with async iterable inputs.""" + + def test_query_with_async_iterable(self): + """Test query with async iterable of messages.""" + async def _test(): + async def message_stream(): + yield {"type": "user", "message": {"role": "user", "content": "First"}} + yield {"type": "user", "message": {"role": "user", "content": "Second"}} + + with patch( + "claude_code_sdk.query.InternalClient" + ) as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock the async generator response + async def mock_process(): + yield AssistantMessage( + content=[TextBlock(text="Response to both messages")] + ) + yield ResultMessage( + subtype="success", + duration_ms=1000, + duration_api_ms=800, + is_error=False, + num_turns=2, + session_id="test", + total_cost_usd=0.002, + ) + + mock_client.process_query.return_value = mock_process() + + # Run query with async iterable + messages = [] + async for msg in query(prompt=message_stream()): + messages.append(msg) + + assert len(messages) == 2 + assert isinstance(messages[0], AssistantMessage) + assert isinstance(messages[1], ResultMessage) + + # Verify process_query was called with async iterable + call_kwargs = mock_client.process_query.call_args.kwargs + # The prompt should be an async generator + assert hasattr(call_kwargs["prompt"], "__aiter__") + + anyio.run(_test) + + def test_query_async_iterable_with_options(self): + """Test query with async iterable and custom options.""" + async def _test(): + async def complex_stream(): + yield { + "type": "user", + "message": {"role": "user", "content": "Setup"}, + "parent_tool_use_id": None, + "session_id": "session-1", + } + yield { + "type": "user", + "message": {"role": "user", "content": "Execute"}, + "parent_tool_use_id": None, + "session_id": "session-1", + } + + options = ClaudeCodeOptions( + cwd="/workspace", + permission_mode="acceptEdits", + max_turns=10, + ) + + with patch( + "claude_code_sdk.query.InternalClient" + ) as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock response + async def mock_process(): + yield AssistantMessage(content=[TextBlock(text="Done")]) + + mock_client.process_query.return_value = mock_process() + + # Run query + messages = [] + async for msg in query(prompt=complex_stream(), options=options): + messages.append(msg) + + # Verify options were passed + call_kwargs = mock_client.process_query.call_args.kwargs + assert call_kwargs["options"] is options + + anyio.run(_test) + + def test_query_empty_async_iterable(self): + """Test query with empty async iterable.""" + async def _test(): + async def empty_stream(): + # Never yields anything + if False: + yield + + with patch( + "claude_code_sdk.query.InternalClient" + ) as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock response + async def mock_process(): + yield SystemMessage( + subtype="info", + data={"message": "No input provided"} + ) + + mock_client.process_query.return_value = mock_process() + + # Run query with empty stream + messages = [] + async for msg in query(prompt=empty_stream()): + messages.append(msg) + + assert len(messages) == 1 + assert isinstance(messages[0], SystemMessage) + + anyio.run(_test) + + def test_query_async_iterable_with_delay(self): + """Test query with async iterable that has delays between yields.""" + async def _test(): + async def delayed_stream(): + yield {"type": "user", "message": {"role": "user", "content": "Start"}} + await asyncio.sleep(0.1) + yield {"type": "user", "message": {"role": "user", "content": "Middle"}} + await asyncio.sleep(0.1) + yield {"type": "user", "message": {"role": "user", "content": "End"}} + + with patch( + "claude_code_sdk.query.InternalClient" + ) as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Track if the stream was consumed + stream_consumed = False + + # Mock process_query to consume the input stream + async def mock_process_query(prompt, options): + nonlocal stream_consumed + # Consume the async iterable to trigger delays + items = [] + async for item in prompt: + items.append(item) + stream_consumed = True + # Then yield response + yield AssistantMessage( + content=[TextBlock(text="Processing all messages")] + ) + + mock_client.process_query = mock_process_query + + # Time the execution + import time + start_time = time.time() + messages = [] + async for msg in query(prompt=delayed_stream()): + messages.append(msg) + elapsed = time.time() - start_time + + # Should have taken at least 0.2 seconds due to delays + assert elapsed >= 0.2 + assert len(messages) == 1 + assert stream_consumed + + anyio.run(_test) + + def test_query_async_iterable_exception_handling(self): + """Test query handles exceptions in async iterable.""" + async def _test(): + async def failing_stream(): + yield {"type": "user", "message": {"role": "user", "content": "First"}} + raise ValueError("Stream error") + + with patch( + "claude_code_sdk.query.InternalClient" + ) as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # The internal client should receive the failing stream + # and handle the error appropriately + async def mock_process(): + # Simulate processing until error + yield AssistantMessage(content=[TextBlock(text="Error occurred")]) + + mock_client.process_query.return_value = mock_process() + + # Query should handle the error gracefully + messages = [] + async for msg in query(prompt=failing_stream()): + messages.append(msg) + + assert len(messages) == 1 + + anyio.run(_test) + + +class TestClaudeSDKClientEdgeCases: + """Test edge cases and error scenarios.""" + + def test_receive_messages_not_connected(self): + """Test receiving messages when not connected.""" + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + async for _ in client.receive_messages(): + pass + + anyio.run(_test) + + def test_receive_response_not_connected(self): + """Test receive_response when not connected.""" + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + async for _ in client.receive_response(): + pass + + anyio.run(_test) + + def test_double_connect(self): + """Test connecting twice.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + client = ClaudeSDKClient() + await client.connect() + # Second connect should create new transport + await client.connect() + + # Should have been called twice + assert mock_transport_class.call_count == 2 + + anyio.run(_test) + + def test_disconnect_without_connect(self): + """Test disconnecting without connecting first.""" + async def _test(): + client = ClaudeSDKClient() + # Should not raise error + await client.disconnect() + + anyio.run(_test) + + def test_context_manager_with_exception(self): + """Test context manager cleans up on exception.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + with pytest.raises(ValueError): + async with ClaudeSDKClient(): + raise ValueError("Test error") + + # Disconnect should still be called + mock_transport.disconnect.assert_called_once() + + anyio.run(_test) + + def test_receive_response_list_comprehension(self): + """Test collecting messages with list comprehension as shown in examples.""" + async def _test(): + with patch( + "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport_class.return_value = mock_transport + + # Mock the message stream + async def mock_receive(): + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "Hello"}], + }, + } + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "World"}], + }, + } + yield { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 800, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + mock_transport.receive_messages = mock_receive + + async with ClaudeSDKClient() as client: + # Test list comprehension pattern from docstring + messages = [msg async for msg in client.receive_response()] + + assert len(messages) == 3 + assert all(isinstance(msg, AssistantMessage | ResultMessage) for msg in messages) + assert isinstance(messages[-1], ResultMessage) + + anyio.run(_test) diff --git a/tests/test_transport.py b/tests/test_transport.py index c8d8e51..aa9e432 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -103,6 +103,12 @@ class TestSubprocessCLITransport: mock_process.wait = AsyncMock() mock_process.stdout = MagicMock() mock_process.stderr = MagicMock() + + # Mock stdin with aclose method + mock_stdin = MagicMock() + mock_stdin.aclose = AsyncMock() + mock_process.stdin = mock_stdin + mock_exec.return_value = mock_process transport = SubprocessCLITransport( From eeb0be9955f9f70a34c9d423352693eddf49f5a4 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 15:01:43 -0700 Subject: [PATCH 07/19] Close stdin for query() --- src/claude_code_sdk/_internal/client.py | 6 ++- .../_internal/transport/subprocess_cli.py | 9 +++- tests/test_streaming_client.py | 54 +++++++++++++++---- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index fb4eeb8..d40540f 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -19,7 +19,11 @@ class InternalClient: ) -> AsyncIterator[Message]: """Process a query through transport.""" - transport = SubprocessCLITransport(prompt=prompt, options=options) + transport = SubprocessCLITransport( + prompt=prompt, + options=options, + close_stdin_after_prompt=True + ) try: await transport.connect() diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 5632c21..701b686 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -31,6 +31,7 @@ class SubprocessCLITransport(Transport): prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions, cli_path: str | Path | None = None, + close_stdin_after_prompt: bool = False, ): self._prompt = prompt self._is_streaming = not isinstance(prompt, str) @@ -43,6 +44,7 @@ class SubprocessCLITransport(Transport): self._stdin_stream: TextSendStream | None = None self._pending_control_responses: dict[str, Any] = {} self._request_counter = 0 + self._close_stdin_after_prompt = close_stdin_after_prompt def _find_cli(self) -> str: """Find Claude Code CLI binary.""" @@ -228,8 +230,11 @@ class SubprocessCLITransport(Transport): break await self._stdin_stream.send(json.dumps(message) + "\n") - # Don't close stdin - keep it open for send_request - # Users can explicitly call disconnect() when done + # Close stdin after prompt if requested (e.g., for query() one-shot mode) + if self._close_stdin_after_prompt and self._stdin_stream: + await self._stdin_stream.aclose() + self._stdin_stream = None + # Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode) except Exception as e: logger.debug(f"Error streaming to stdin: {e}") if self._stdin_stream: diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 4f62545..ed83bcd 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -24,6 +24,7 @@ class TestClaudeSDKClientStreaming: def test_auto_connect_with_context_manager(self): """Test automatic connection when using context manager.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -43,6 +44,7 @@ class TestClaudeSDKClientStreaming: def test_manual_connect_disconnect(self): """Test manual connect and disconnect.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -66,6 +68,7 @@ class TestClaudeSDKClientStreaming: def test_connect_with_string_prompt(self): """Test connecting with a string prompt.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -84,6 +87,7 @@ class TestClaudeSDKClientStreaming: def test_connect_with_async_iterable(self): """Test connecting with an async iterable.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -93,7 +97,10 @@ class TestClaudeSDKClientStreaming: async def message_stream(): yield {"type": "user", "message": {"role": "user", "content": "Hi"}} - yield {"type": "user", "message": {"role": "user", "content": "Bye"}} + yield { + "type": "user", + "message": {"role": "user", "content": "Bye"}, + } client = ClaudeSDKClient() stream = message_stream() @@ -108,6 +115,7 @@ class TestClaudeSDKClientStreaming: def test_send_message(self): """Test sending a message.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -131,6 +139,7 @@ class TestClaudeSDKClientStreaming: def test_send_message_with_session_id(self): """Test sending a message with custom session ID.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -150,6 +159,7 @@ class TestClaudeSDKClientStreaming: def test_send_message_not_connected(self): """Test sending message when not connected raises error.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -159,6 +169,7 @@ class TestClaudeSDKClientStreaming: def test_receive_messages(self): """Test receiving messages.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -200,6 +211,7 @@ class TestClaudeSDKClientStreaming: def test_receive_response(self): """Test receive_response stops at ResultMessage.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -231,7 +243,9 @@ class TestClaudeSDKClientStreaming: "type": "assistant", "message": { "role": "assistant", - "content": [{"type": "text", "text": "Should not see this"}], + "content": [ + {"type": "text", "text": "Should not see this"} + ], }, } @@ -251,6 +265,7 @@ class TestClaudeSDKClientStreaming: def test_interrupt(self): """Test interrupt functionality.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -266,6 +281,7 @@ class TestClaudeSDKClientStreaming: def test_interrupt_not_connected(self): """Test interrupt when not connected raises error.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -275,6 +291,7 @@ class TestClaudeSDKClientStreaming: def test_client_with_options(self): """Test client initialization with options.""" + async def _test(): options = ClaudeCodeOptions( cwd="/custom/path", @@ -299,6 +316,7 @@ class TestClaudeSDKClientStreaming: def test_concurrent_send_receive(self): """Test concurrent sending and receiving messages.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -334,7 +352,7 @@ class TestClaudeSDKClientStreaming: # Helper to get next message async def get_next_message(): return await client.receive_response().__anext__() - + # Start receiving in background receive_task = asyncio.create_task(get_next_message()) @@ -353,13 +371,14 @@ class TestQueryWithAsyncIterable: def test_query_with_async_iterable(self): """Test query with async iterable of messages.""" + async def _test(): async def message_stream(): yield {"type": "user", "message": {"role": "user", "content": "First"}} yield {"type": "user", "message": {"role": "user", "content": "Second"}} with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -399,6 +418,7 @@ class TestQueryWithAsyncIterable: def test_query_async_iterable_with_options(self): """Test query with async iterable and custom options.""" + async def _test(): async def complex_stream(): yield { @@ -421,7 +441,7 @@ class TestQueryWithAsyncIterable: ) with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -445,6 +465,7 @@ class TestQueryWithAsyncIterable: def test_query_empty_async_iterable(self): """Test query with empty async iterable.""" + async def _test(): async def empty_stream(): # Never yields anything @@ -452,7 +473,7 @@ class TestQueryWithAsyncIterable: yield with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -460,8 +481,7 @@ class TestQueryWithAsyncIterable: # Mock response async def mock_process(): yield SystemMessage( - subtype="info", - data={"message": "No input provided"} + subtype="info", data={"message": "No input provided"} ) mock_client.process_query.return_value = mock_process() @@ -478,6 +498,7 @@ class TestQueryWithAsyncIterable: def test_query_async_iterable_with_delay(self): """Test query with async iterable that has delays between yields.""" + async def _test(): async def delayed_stream(): yield {"type": "user", "message": {"role": "user", "content": "Start"}} @@ -487,7 +508,7 @@ class TestQueryWithAsyncIterable: yield {"type": "user", "message": {"role": "user", "content": "End"}} with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -512,6 +533,7 @@ class TestQueryWithAsyncIterable: # Time the execution import time + start_time = time.time() messages = [] async for msg in query(prompt=delayed_stream()): @@ -527,13 +549,14 @@ class TestQueryWithAsyncIterable: def test_query_async_iterable_exception_handling(self): """Test query handles exceptions in async iterable.""" + async def _test(): async def failing_stream(): yield {"type": "user", "message": {"role": "user", "content": "First"}} raise ValueError("Stream error") with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -561,6 +584,7 @@ class TestClaudeSDKClientEdgeCases: def test_receive_messages_not_connected(self): """Test receiving messages when not connected.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -571,6 +595,7 @@ class TestClaudeSDKClientEdgeCases: def test_receive_response_not_connected(self): """Test receive_response when not connected.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -581,6 +606,7 @@ class TestClaudeSDKClientEdgeCases: def test_double_connect(self): """Test connecting twice.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -600,6 +626,7 @@ class TestClaudeSDKClientEdgeCases: def test_disconnect_without_connect(self): """Test disconnecting without connecting first.""" + async def _test(): client = ClaudeSDKClient() # Should not raise error @@ -609,6 +636,7 @@ class TestClaudeSDKClientEdgeCases: def test_context_manager_with_exception(self): """Test context manager cleans up on exception.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -627,6 +655,7 @@ class TestClaudeSDKClientEdgeCases: def test_receive_response_list_comprehension(self): """Test collecting messages with list comprehension as shown in examples.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -668,7 +697,10 @@ class TestClaudeSDKClientEdgeCases: messages = [msg async for msg in client.receive_response()] assert len(messages) == 3 - assert all(isinstance(msg, AssistantMessage | ResultMessage) for msg in messages) + assert all( + isinstance(msg, AssistantMessage | ResultMessage) + for msg in messages + ) assert isinstance(messages[-1], ResultMessage) anyio.run(_test) From 712948c2e7e6dbb312478e2ecde8ca60fcfeeee7 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 15:20:02 -0700 Subject: [PATCH 08/19] Fix test --- tests/test_streaming_client.py | 299 +++++++++++---------------------- 1 file changed, 102 insertions(+), 197 deletions(-) diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index ed83bcd..9dc131d 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -1,7 +1,11 @@ """Tests for ClaudeSDKClient streaming functionality and query() with async iterables.""" import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +import sys +import tempfile +import textwrap +from pathlib import Path +from unittest.mock import AsyncMock, patch import anyio import pytest @@ -12,11 +16,11 @@ from claude_code_sdk import ( ClaudeSDKClient, CLIConnectionError, ResultMessage, - SystemMessage, TextBlock, UserMessage, query, ) +from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport class TestClaudeSDKClientStreaming: @@ -369,6 +373,76 @@ class TestClaudeSDKClientStreaming: class TestQueryWithAsyncIterable: """Test query() function with async iterable inputs.""" + def _create_test_script( + self, expected_messages=None, response=None, should_error=False + ): + """Create a test script that validates CLI args and stdin messages. + + Args: + expected_messages: List of expected message content strings, or None to skip validation + response: Custom response to output, defaults to a success result + should_error: If True, script will exit with error after reading stdin + + Returns: + Path to the test script + """ + if response is None: + response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}' + + script_content = textwrap.dedent(""" + #!/usr/bin/env python3 + import sys + import json + import time + + # Check command line args + args = sys.argv[1:] + assert "--output-format" in args + assert "stream-json" in args + + # Read stdin messages + stdin_messages = [] + stdin_closed = False + try: + while True: + line = sys.stdin.readline() + if not line: + stdin_closed = True + break + stdin_messages.append(line.strip()) + except: + stdin_closed = True + """, + ) + + if expected_messages is not None: + script_content += textwrap.dedent(f""" + # Verify we got the expected messages + assert len(stdin_messages) == {len(expected_messages)} + """, + ) + for i, msg in enumerate(expected_messages): + script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n''' + + if should_error: + script_content += textwrap.dedent(""" + sys.exit(1) + """, + ) + else: + script_content += textwrap.dedent(f""" + # Output response + print('{response}') + """, + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + test_script = f.name + f.write(script_content) + + Path(test_script).chmod(0o755) + return test_script + def test_query_with_async_iterable(self): """Test query with async iterable of messages.""" @@ -377,204 +451,35 @@ class TestQueryWithAsyncIterable: yield {"type": "user", "message": {"role": "user", "content": "First"}} yield {"type": "user", "message": {"role": "user", "content": "Second"}} - with patch( - "claude_code_sdk._internal.client.InternalClient" - ) as mock_client_class: - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - # Mock the async generator response - async def mock_process(): - yield AssistantMessage( - content=[TextBlock(text="Response to both messages")] - ) - yield ResultMessage( - subtype="success", - duration_ms=1000, - duration_api_ms=800, - is_error=False, - num_turns=2, - session_id="test", - total_cost_usd=0.002, - ) - - mock_client.process_query.return_value = mock_process() - - # Run query with async iterable - messages = [] - async for msg in query(prompt=message_stream()): - messages.append(msg) - - assert len(messages) == 2 - assert isinstance(messages[0], AssistantMessage) - assert isinstance(messages[1], ResultMessage) - - # Verify process_query was called with async iterable - call_kwargs = mock_client.process_query.call_args.kwargs - # The prompt should be an async generator - assert hasattr(call_kwargs["prompt"], "__aiter__") - - anyio.run(_test) - - def test_query_async_iterable_with_options(self): - """Test query with async iterable and custom options.""" - - async def _test(): - async def complex_stream(): - yield { - "type": "user", - "message": {"role": "user", "content": "Setup"}, - "parent_tool_use_id": None, - "session_id": "session-1", - } - yield { - "type": "user", - "message": {"role": "user", "content": "Execute"}, - "parent_tool_use_id": None, - "session_id": "session-1", - } - - options = ClaudeCodeOptions( - cwd="/workspace", - permission_mode="acceptEdits", - max_turns=10, + test_script = self._create_test_script( + expected_messages=["First", "Second"] ) - with patch( - "claude_code_sdk._internal.client.InternalClient" - ) as mock_client_class: - mock_client = MagicMock() - mock_client_class.return_value = mock_client + try: + # Mock _build_command to return our test script + with patch.object( + SubprocessCLITransport, + "_build_command", + return_value=[ + sys.executable, + test_script, + "--output-format", + "stream-json", + "--verbose", + ], + ): + # Run query with async iterable + messages = [] + async for msg in query(prompt=message_stream()): + messages.append(msg) - # Mock response - async def mock_process(): - yield AssistantMessage(content=[TextBlock(text="Done")]) - - mock_client.process_query.return_value = mock_process() - - # Run query - messages = [] - async for msg in query(prompt=complex_stream(), options=options): - messages.append(msg) - - # Verify options were passed - call_kwargs = mock_client.process_query.call_args.kwargs - assert call_kwargs["options"] is options - - anyio.run(_test) - - def test_query_empty_async_iterable(self): - """Test query with empty async iterable.""" - - async def _test(): - async def empty_stream(): - # Never yields anything - if False: - yield - - with patch( - "claude_code_sdk._internal.client.InternalClient" - ) as mock_client_class: - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - # Mock response - async def mock_process(): - yield SystemMessage( - subtype="info", data={"message": "No input provided"} - ) - - mock_client.process_query.return_value = mock_process() - - # Run query with empty stream - messages = [] - async for msg in query(prompt=empty_stream()): - messages.append(msg) - - assert len(messages) == 1 - assert isinstance(messages[0], SystemMessage) - - anyio.run(_test) - - def test_query_async_iterable_with_delay(self): - """Test query with async iterable that has delays between yields.""" - - async def _test(): - async def delayed_stream(): - yield {"type": "user", "message": {"role": "user", "content": "Start"}} - await asyncio.sleep(0.1) - yield {"type": "user", "message": {"role": "user", "content": "Middle"}} - await asyncio.sleep(0.1) - yield {"type": "user", "message": {"role": "user", "content": "End"}} - - with patch( - "claude_code_sdk._internal.client.InternalClient" - ) as mock_client_class: - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - # Track if the stream was consumed - stream_consumed = False - - # Mock process_query to consume the input stream - async def mock_process_query(prompt, options): - nonlocal stream_consumed - # Consume the async iterable to trigger delays - items = [] - async for item in prompt: - items.append(item) - stream_consumed = True - # Then yield response - yield AssistantMessage( - content=[TextBlock(text="Processing all messages")] - ) - - mock_client.process_query = mock_process_query - - # Time the execution - import time - - start_time = time.time() - messages = [] - async for msg in query(prompt=delayed_stream()): - messages.append(msg) - elapsed = time.time() - start_time - - # Should have taken at least 0.2 seconds due to delays - assert elapsed >= 0.2 - assert len(messages) == 1 - assert stream_consumed - - anyio.run(_test) - - def test_query_async_iterable_exception_handling(self): - """Test query handles exceptions in async iterable.""" - - async def _test(): - async def failing_stream(): - yield {"type": "user", "message": {"role": "user", "content": "First"}} - raise ValueError("Stream error") - - with patch( - "claude_code_sdk._internal.client.InternalClient" - ) as mock_client_class: - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - # The internal client should receive the failing stream - # and handle the error appropriately - async def mock_process(): - # Simulate processing until error - yield AssistantMessage(content=[TextBlock(text="Error occurred")]) - - mock_client.process_query.return_value = mock_process() - - # Query should handle the error gracefully - messages = [] - async for msg in query(prompt=failing_stream()): - messages.append(msg) - - assert len(messages) == 1 + # Should get the result message + assert len(messages) == 1 + assert isinstance(messages[0], ResultMessage) + assert messages[0].subtype == "success" + finally: + # Clean up + Path(test_script).unlink() anyio.run(_test) From c95c077b9b65c44e71982f744d816ad5453a6b0d Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 15:21:02 -0700 Subject: [PATCH 09/19] Ruff --- src/claude_code_sdk/__init__.py | 1 - src/claude_code_sdk/_internal/client.py | 4 +--- .../_internal/transport/subprocess_cli.py | 13 +++++++++---- tests/test_streaming_client.py | 12 ++++++++---- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index dc84df1..1439937 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -1,6 +1,5 @@ """Claude SDK for Python.""" - from ._errors import ( ClaudeSDKError, CLIConnectionError, diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index d40540f..c1afa9e 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -20,9 +20,7 @@ class InternalClient: """Process a query through transport.""" transport = SubprocessCLITransport( - prompt=prompt, - options=options, - close_stdin_after_prompt=True + prompt=prompt, options=options, close_stdin_after_prompt=True ) try: diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 701b686..92b0743 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -161,6 +161,7 @@ class SubprocessCLITransport(Transport): self._stdin_stream = TextSendStream(self._process.stdin) # Start streaming messages to stdin in background import asyncio + asyncio.create_task(self._stream_to_stdin()) else: # String mode: close stdin immediately (backward compatible) @@ -214,7 +215,7 @@ class SubprocessCLITransport(Transport): "type": "user", "message": {"role": "user", "content": str(message)}, "parent_tool_use_id": None, - "session_id": options.get("session_id", "default") + "session_id": options.get("session_id", "default"), } await self._stdin_stream.send(json.dumps(message) + "\n") @@ -362,7 +363,9 @@ class SubprocessCLITransport(Transport): async def interrupt(self) -> None: """Send interrupt control request (only works in streaming mode).""" if not self._is_streaming: - raise CLIConnectionError("Interrupt requires streaming mode (AsyncIterable prompt)") + raise CLIConnectionError( + "Interrupt requires streaming mode (AsyncIterable prompt)" + ) if not self._stdin_stream: raise CLIConnectionError("Not connected or stdin not available") @@ -382,7 +385,7 @@ class SubprocessCLITransport(Transport): control_request = { "type": "control_request", "request_id": request_id, - "request": request + "request": request, } # Send request @@ -397,7 +400,9 @@ class SubprocessCLITransport(Transport): response = self._pending_control_responses.pop(request_id) if response.get("subtype") == "error": - raise CLIConnectionError(f"Control request failed: {response.get('error')}") + raise CLIConnectionError( + f"Control request failed: {response.get('error')}" + ) return response except TimeoutError: diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 9dc131d..cf1c6a5 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -389,7 +389,8 @@ class TestQueryWithAsyncIterable: if response is None: response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}' - script_content = textwrap.dedent(""" + script_content = textwrap.dedent( + """ #!/usr/bin/env python3 import sys import json @@ -416,7 +417,8 @@ class TestQueryWithAsyncIterable: ) if expected_messages is not None: - script_content += textwrap.dedent(f""" + script_content += textwrap.dedent( + f""" # Verify we got the expected messages assert len(stdin_messages) == {len(expected_messages)} """, @@ -425,12 +427,14 @@ class TestQueryWithAsyncIterable: script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n''' if should_error: - script_content += textwrap.dedent(""" + script_content += textwrap.dedent( + """ sys.exit(1) """, ) else: - script_content += textwrap.dedent(f""" + script_content += textwrap.dedent( + f""" # Output response print('{response}') """, From e65c2f417aabd82b2d405a8152592b95ed7e8e8b Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 15:25:34 -0700 Subject: [PATCH 10/19] Fix test --- tests/test_streaming_client.py | 146 ++++++++++++--------------------- 1 file changed, 51 insertions(+), 95 deletions(-) diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index cf1c6a5..c196c56 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -3,7 +3,6 @@ import asyncio import sys import tempfile -import textwrap from pathlib import Path from unittest.mock import AsyncMock, patch @@ -373,80 +372,6 @@ class TestClaudeSDKClientStreaming: class TestQueryWithAsyncIterable: """Test query() function with async iterable inputs.""" - def _create_test_script( - self, expected_messages=None, response=None, should_error=False - ): - """Create a test script that validates CLI args and stdin messages. - - Args: - expected_messages: List of expected message content strings, or None to skip validation - response: Custom response to output, defaults to a success result - should_error: If True, script will exit with error after reading stdin - - Returns: - Path to the test script - """ - if response is None: - response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}' - - script_content = textwrap.dedent( - """ - #!/usr/bin/env python3 - import sys - import json - import time - - # Check command line args - args = sys.argv[1:] - assert "--output-format" in args - assert "stream-json" in args - - # Read stdin messages - stdin_messages = [] - stdin_closed = False - try: - while True: - line = sys.stdin.readline() - if not line: - stdin_closed = True - break - stdin_messages.append(line.strip()) - except: - stdin_closed = True - """, - ) - - if expected_messages is not None: - script_content += textwrap.dedent( - f""" - # Verify we got the expected messages - assert len(stdin_messages) == {len(expected_messages)} - """, - ) - for i, msg in enumerate(expected_messages): - script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n''' - - if should_error: - script_content += textwrap.dedent( - """ - sys.exit(1) - """, - ) - else: - script_content += textwrap.dedent( - f""" - # Output response - print('{response}') - """, - ) - - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: - test_script = f.name - f.write(script_content) - - Path(test_script).chmod(0o755) - return test_script - def test_query_with_async_iterable(self): """Test query with async iterable of messages.""" @@ -455,32 +380,63 @@ class TestQueryWithAsyncIterable: yield {"type": "user", "message": {"role": "user", "content": "First"}} yield {"type": "user", "message": {"role": "user", "content": "Second"}} - test_script = self._create_test_script( - expected_messages=["First", "Second"] - ) + # Create a simple test script that validates stdin and outputs a result + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + test_script = f.name + f.write("""#!/usr/bin/env python3 +import sys +import json + +# Read stdin messages +stdin_messages = [] +while True: + line = sys.stdin.readline() + if not line: + break + stdin_messages.append(line.strip()) + +# Verify we got 2 messages +assert len(stdin_messages) == 2 +assert '"First"' in stdin_messages[0] +assert '"Second"' in stdin_messages[1] + +# Output a valid result +print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}') +""") + + Path(test_script).chmod(0o755) try: - # Mock _build_command to return our test script + # Mock _find_cli to return python executing our test script with patch.object( SubprocessCLITransport, - "_build_command", - return_value=[ - sys.executable, - test_script, - "--output-format", - "stream-json", - "--verbose", - ], + "_find_cli", + return_value=sys.executable ): - # Run query with async iterable - messages = [] - async for msg in query(prompt=message_stream()): - messages.append(msg) + # Mock _build_command to add our test script as first argument + original_build_command = SubprocessCLITransport._build_command + + def mock_build_command(self): + # Get original command + cmd = original_build_command(self) + # Replace the CLI path with python + script + cmd[0] = test_script + return cmd + + with patch.object( + SubprocessCLITransport, + "_build_command", + mock_build_command + ): + # Run query with async iterable + messages = [] + async for msg in query(prompt=message_stream()): + messages.append(msg) - # Should get the result message - assert len(messages) == 1 - assert isinstance(messages[0], ResultMessage) - assert messages[0].subtype == "success" + # Should get the result message + assert len(messages) == 1 + assert isinstance(messages[0], ResultMessage) + assert messages[0].subtype == "success" finally: # Clean up Path(test_script).unlink() From a813a4d66593a7c592eec8071cbe2f449454b694 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 15:26:30 -0700 Subject: [PATCH 11/19] Fix lint --- tests/test_streaming_client.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index c196c56..49a5291 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -403,30 +403,26 @@ assert '"Second"' in stdin_messages[1] # Output a valid result print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}') """) - + Path(test_script).chmod(0o755) try: # Mock _find_cli to return python executing our test script with patch.object( - SubprocessCLITransport, - "_find_cli", - return_value=sys.executable + SubprocessCLITransport, "_find_cli", return_value=sys.executable ): # Mock _build_command to add our test script as first argument original_build_command = SubprocessCLITransport._build_command - + def mock_build_command(self): # Get original command cmd = original_build_command(self) # Replace the CLI path with python + script cmd[0] = test_script return cmd - + with patch.object( - SubprocessCLITransport, - "_build_command", - mock_build_command + SubprocessCLITransport, "_build_command", mock_build_command ): # Run query with async iterable messages = [] From 952085283952d7ddf2faa959c3e6f7877d6f0e2d Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 18:47:07 -0700 Subject: [PATCH 12/19] Fix types --- .../_internal/transport/subprocess_cli.py | 2 +- src/claude_code_sdk/client.py | 16 +++++++++------- src/claude_code_sdk/query.py | 3 ++- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 92b0743..4121d91 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -42,7 +42,7 @@ class SubprocessCLITransport(Transport): self._stdout_stream: TextReceiveStream | None = None self._stderr_stream: TextReceiveStream | None = None self._stdin_stream: TextSendStream | None = None - self._pending_control_responses: dict[str, Any] = {} + self._pending_control_responses: dict[str, dict[str, Any]] = {} self._request_counter = 0 self._close_stdin_after_prompt = close_stdin_after_prompt diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index db7c494..213519e 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -2,6 +2,7 @@ import os from collections.abc import AsyncIterable, AsyncIterator +from typing import Any from ._errors import CLIConnectionError from .types import ClaudeCodeOptions, Message, ResultMessage @@ -94,19 +95,20 @@ class ClaudeSDKClient: if options is None: options = ClaudeCodeOptions() self.options = options - self._transport = None + self._transport: Any | None = None os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client" - async def connect(self, prompt: str | AsyncIterable[dict] | None = None) -> None: + async def connect(self, prompt: str | AsyncIterable[dict[str, Any]] | None = None) -> None: """Connect to Claude with a prompt or message stream.""" from ._internal.transport.subprocess_cli import SubprocessCLITransport # Auto-connect with empty async iterable if no prompt is provided - async def _empty_stream(): + async def _empty_stream() -> AsyncIterator[dict[str, Any]]: # Never yields, but indicates that this function is an iterator and # keeps the connection open. - if False: - yield + # This yield is never reached but makes this an async generator + return + yield {} # type: ignore[unreachable] self._transport = SubprocessCLITransport( prompt=_empty_stream() if prompt is None else prompt, @@ -190,12 +192,12 @@ class ClaudeSDKClient: await self._transport.disconnect() self._transport = None - async def __aenter__(self): + async def __aenter__(self) -> "ClaudeSDKClient": """Enter async context - automatically connects with empty stream for interactive use.""" await self.connect() return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: """Exit async context - always disconnects.""" await self.disconnect() return False diff --git a/src/claude_code_sdk/query.py b/src/claude_code_sdk/query.py index 4bd7e96..3732762 100644 --- a/src/claude_code_sdk/query.py +++ b/src/claude_code_sdk/query.py @@ -2,13 +2,14 @@ import os from collections.abc import AsyncIterable, AsyncIterator +from typing import Any from ._internal.client import InternalClient from .types import ClaudeCodeOptions, Message async def query( - *, prompt: str | AsyncIterable[dict], options: ClaudeCodeOptions | None = None + *, prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions | None = None ) -> AsyncIterator[Message]: """ Query Claude Code for one-shot or unidirectional streaming interactions. From 739a5723f9db6de61071b37e22122e1b759cc977 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 19:12:07 -0700 Subject: [PATCH 13/19] PR feedback --- .../_internal/message_parser.py | 106 +++++++++++------- .../_internal/transport/subprocess_cli.py | 7 +- src/claude_code_sdk/client.py | 59 ++++++---- 3 files changed, 114 insertions(+), 58 deletions(-) diff --git a/src/claude_code_sdk/_internal/message_parser.py b/src/claude_code_sdk/_internal/message_parser.py index a2b88d2..c5f4fc0 100644 --- a/src/claude_code_sdk/_internal/message_parser.py +++ b/src/claude_code_sdk/_internal/message_parser.py @@ -1,5 +1,6 @@ """Message parser for Claude Code SDK responses.""" +import logging from typing import Any from ..types import ( @@ -14,6 +15,8 @@ from ..types import ( UserMessage, ) +logger = logging.getLogger(__name__) + def parse_message(data: dict[str, Any]) -> Message | None: """ @@ -23,55 +26,82 @@ def parse_message(data: dict[str, Any]) -> Message | None: data: Raw message dictionary from CLI output Returns: - Parsed Message object or None if type is unrecognized + Parsed Message object or None if type is unrecognized or parsing fails """ - match data["type"]: + try: + message_type = data.get("type") + if not message_type: + logger.warning("Message missing 'type' field: %s", data) + return None + + except AttributeError: + logger.error("Invalid message data type (expected dict): %s", type(data)) + return None + + match message_type: case "user": - return UserMessage(content=data["message"]["content"]) + try: + return UserMessage(content=data["message"]["content"]) + except KeyError as e: + logger.error("Missing required field in user message: %s", e) + return None case "assistant": - content_blocks: list[ContentBlock] = [] - for block in data["message"]["content"]: - match block["type"]: - case "text": - content_blocks.append(TextBlock(text=block["text"])) - case "tool_use": - content_blocks.append( - ToolUseBlock( - id=block["id"], - name=block["name"], - input=block["input"], + try: + content_blocks: list[ContentBlock] = [] + for block in data["message"]["content"]: + match block["type"]: + case "text": + content_blocks.append(TextBlock(text=block["text"])) + case "tool_use": + content_blocks.append( + ToolUseBlock( + id=block["id"], + name=block["name"], + input=block["input"], + ) ) - ) - case "tool_result": - content_blocks.append( - ToolResultBlock( - tool_use_id=block["tool_use_id"], - content=block.get("content"), - is_error=block.get("is_error"), + case "tool_result": + content_blocks.append( + ToolResultBlock( + tool_use_id=block["tool_use_id"], + content=block.get("content"), + is_error=block.get("is_error"), + ) ) - ) - return AssistantMessage(content=content_blocks) + return AssistantMessage(content=content_blocks) + except KeyError as e: + logger.error("Missing required field in assistant message: %s", e) + return None case "system": - return SystemMessage( - subtype=data["subtype"], - data=data, - ) + try: + return SystemMessage( + subtype=data["subtype"], + data=data, + ) + except KeyError as e: + logger.error("Missing required field in system message: %s", e) + return None case "result": - return ResultMessage( - subtype=data["subtype"], - duration_ms=data["duration_ms"], - duration_api_ms=data["duration_api_ms"], - is_error=data["is_error"], - num_turns=data["num_turns"], - session_id=data["session_id"], - total_cost_usd=data.get("total_cost_usd"), - usage=data.get("usage"), - result=data.get("result"), - ) + try: + return ResultMessage( + subtype=data["subtype"], + duration_ms=data["duration_ms"], + duration_api_ms=data["duration_api_ms"], + is_error=data["is_error"], + num_turns=data["num_turns"], + session_id=data["session_id"], + total_cost_usd=data.get("total_cost_usd"), + usage=data.get("usage"), + result=data.get("result"), + ) + except KeyError as e: + logger.error("Missing required field in result message: %s", e) + return None case _: + logger.debug("Unknown message type: %s", message_type) return None diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 4121d91..6a22eec 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -296,7 +296,12 @@ class SubprocessCLITransport(Transport): yield data except GeneratorExit: return - except json.JSONDecodeError: + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse JSON from CLI output: {e}. Buffer content: {json_buffer[:200]}..." + ) + # Clear buffer to avoid repeated parse attempts on malformed data + json_buffer = "" continue except anyio.ClosedResourceError: diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 213519e..dbf9aa6 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -128,19 +128,37 @@ class ClaudeSDKClient: if message: yield message - async def send_message(self, content: str, session_id: str = "default") -> None: - """Send a new message in streaming mode.""" + async def send_message(self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default") -> None: + """ + Send a new message in streaming mode. + + Args: + prompt: Either a string message or an async iterable of message dictionaries + session_id: Session identifier for the conversation + """ if not self._transport: raise CLIConnectionError("Not connected. Call connect() first.") - message = { - "type": "user", - "message": {"role": "user", "content": content}, - "parent_tool_use_id": None, - "session_id": session_id, - } + # Handle string prompts + if isinstance(prompt, str): + message = { + "type": "user", + "message": {"role": "user", "content": prompt}, + "parent_tool_use_id": None, + "session_id": session_id, + } + await self._transport.send_request([message], {"session_id": session_id}) + else: + # Handle AsyncIterable prompts + messages = [] + async for msg in prompt: + # Ensure session_id is set on each message + if "session_id" not in msg: + msg["session_id"] = session_id + messages.append(msg) - await self._transport.send_request([message], {"session_id": session_id}) + if messages: + await self._transport.send_request(messages, {"session_id": session_id}) async def interrupt(self) -> None: """Send interrupt signal (only works with streaming mode).""" @@ -150,11 +168,17 @@ class ClaudeSDKClient: async def receive_response(self) -> AsyncIterator[Message]: """ - Receive messages from Claude until a ResultMessage is received. + Receive messages from Claude until and including a ResultMessage. - This is an async iterator that yields all messages including the final ResultMessage. - It's a convenience method over receive_messages() that automatically stops iteration - after receiving a ResultMessage. + This async iterator yields all messages in sequence and automatically terminates + after yielding a ResultMessage (which indicates the response is complete). + It's a convenience method over receive_messages() for single-response workflows. + + **Stopping Behavior:** + - Yields each message as it's received + - Terminates immediately after yielding a ResultMessage + - The ResultMessage IS included in the yielded messages + - If no ResultMessage is received, the iterator continues indefinitely Yields: Message: Each message received (UserMessage, AssistantMessage, SystemMessage, ResultMessage) @@ -162,7 +186,6 @@ class ClaudeSDKClient: Example: ```python async with ClaudeSDKClient() as client: - # Send message and process response await client.send_message("What's the capital of France?") async for msg in client.receive_response(): @@ -172,14 +195,12 @@ class ClaudeSDKClient: print(f"Claude: {block.text}") elif isinstance(msg, ResultMessage): print(f"Cost: ${msg.total_cost_usd:.4f}") + # Iterator will terminate after this message ``` Note: - The iterator will automatically stop after yielding a ResultMessage. - If you need to collect all messages into a list, use: - ```python - messages = [msg async for msg in client.receive_response()] - ``` + To collect all messages: `messages = [msg async for msg in client.receive_response()]` + The final message in the list will always be a ResultMessage. """ async for message in self.receive_messages(): yield message From 3b56577b2f53a7516def5dbf194e50b67b3fe077 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 19:12:15 -0700 Subject: [PATCH 14/19] CLAUDE.md --- CLAUDE.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..69f23fb --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,27 @@ +# Workflow + +```bash +# Lint and style +# Check for issues and fix automatically +python -m ruff check src/ test/ --fix +python -m ruff format src/ test/ + +# Typecheck (only done for src/) +python -m mypy src/ + +# Run all tests +python -m pytest tests/ + +# Run specific test file +python -m pytest tests/test_client.py +``` + +# Codebase Structure + +- `src/claude_code_sdk/` - Main package + - `client.py` - ClaudeSDKClient for interactive sessions + - `query.py` - One-shot query function + - `types.py` - Type definitions + - `_internal/` - Internal implementation details + - `transport/subprocess_cli.py` - CLI subprocess management + - `message_parser.py` - Message parsing logic From b57e05afa526042eee8b9b7ccca3939d1729ac89 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 19:57:17 -0700 Subject: [PATCH 15/19] Improve examples --- examples/streaming_mode.py | 226 +++++++++++++++++++++-------- examples/streaming_mode_ipython.py | 72 +++++++-- src/claude_code_sdk/client.py | 16 +- tests/test_streaming_client.py | 12 +- 4 files changed, 236 insertions(+), 90 deletions(-) mode change 100644 => 100755 examples/streaming_mode.py diff --git a/examples/streaming_mode.py b/examples/streaming_mode.py old mode 100644 new mode 100755 index 239024d..73eb410 --- a/examples/streaming_mode.py +++ b/examples/streaming_mode.py @@ -4,10 +4,20 @@ Comprehensive examples of using ClaudeSDKClient for streaming mode. This file demonstrates various patterns for building applications with the ClaudeSDKClient streaming interface. + +The queries are intentionally simplistic. In reality, a query can be a more +complex task that Claude SDK uses its agentic capabilities and tools (e.g. run +bash commands, edit files, search the web, fetch web content) to accomplish. + +Usage: +./examples/streaming_mode.py - List the examples +./examples/streaming_mode.py all - Run all examples +./examples/streaming_mode.py basic_streaming - Run a specific example """ import asyncio import contextlib +import sys from claude_code_sdk import ( AssistantMessage, @@ -15,28 +25,48 @@ from claude_code_sdk import ( ClaudeSDKClient, CLIConnectionError, ResultMessage, + SystemMessage, TextBlock, + UserMessage, ) +def display_message(msg): + """Standardized message display function. + + - UserMessage: "User: " + - AssistantMessage: "Claude: " + - SystemMessage: ignored + - ResultMessage: "Result ended" + cost if available + """ + if isinstance(msg, UserMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"User: {block.text}") + elif isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(msg, SystemMessage): + # Ignore system messages + pass + elif isinstance(msg, ResultMessage): + print("Result ended") + + async def example_basic_streaming(): """Basic streaming with context manager.""" print("=== Basic Streaming Example ===") async with ClaudeSDKClient() as client: - # Send a message - await client.send_message("What is 2+2?") + print("User: What is 2+2?") + await client.query("What is 2+2?") # Receive complete response using the helper method async for msg in client.receive_response(): - if isinstance(msg, AssistantMessage): - for block in msg.content: - if isinstance(block, TextBlock): - print(f"Claude: {block.text}") - elif isinstance(msg, ResultMessage) and msg.total_cost_usd: - print(f"Cost: ${msg.total_cost_usd:.4f}") + display_message(msg) - print("Session ended\n") + print("\n") async def example_multi_turn_conversation(): @@ -46,26 +76,20 @@ async def example_multi_turn_conversation(): async with ClaudeSDKClient() as client: # First turn print("User: What's the capital of France?") - await client.send_message("What's the capital of France?") + await client.query("What's the capital of France?") # Extract and print response async for msg in client.receive_response(): - content_blocks = getattr(msg, 'content', []) - for block in content_blocks: - if isinstance(block, TextBlock): - print(f"{block.text}") + display_message(msg) # Second turn - follow-up print("\nUser: What's the population of that city?") - await client.send_message("What's the population of that city?") + await client.query("What's the population of that city?") async for msg in client.receive_response(): - content_blocks = getattr(msg, 'content', []) - for block in content_blocks: - if isinstance(block, TextBlock): - print(f"{block.text}") + display_message(msg) - print("\nConversation ended\n") + print("\n") async def example_concurrent_responses(): @@ -76,10 +100,7 @@ async def example_concurrent_responses(): # Background task to continuously receive messages async def receive_messages(): async for message in client.receive_messages(): - if isinstance(message, AssistantMessage): - for block in message.content: - if isinstance(block, TextBlock): - print(f"Claude: {block.text}") + display_message(message) # Start receiving in background receive_task = asyncio.create_task(receive_messages()) @@ -93,7 +114,7 @@ async def example_concurrent_responses(): for question in questions: print(f"\nUser: {question}") - await client.send_message(question) + await client.query(question) await asyncio.sleep(3) # Wait between messages # Give time for final responses @@ -104,7 +125,7 @@ async def example_concurrent_responses(): with contextlib.suppress(asyncio.CancelledError): await receive_task - print("\nSession ended\n") + print("\n") async def example_with_interrupt(): @@ -115,7 +136,7 @@ async def example_with_interrupt(): async with ClaudeSDKClient() as client: # Start a long-running task print("\nUser: Count from 1 to 100 slowly") - await client.send_message( + await client.query( "Count from 1 to 100 slowly, with a brief pause between each number" ) @@ -132,10 +153,10 @@ async def example_with_interrupt(): if isinstance(block, TextBlock): # Print first few numbers print(f"Claude: {block.text[:50]}...") - - # Stop when we get a result after interrupt - if isinstance(message, ResultMessage) and interrupt_sent: - break + elif isinstance(message, ResultMessage): + display_message(message) + if interrupt_sent: + break # Start consuming messages in the background consume_task = asyncio.create_task(consume_messages()) @@ -151,16 +172,13 @@ async def example_with_interrupt(): # Send new instruction after interrupt print("\nUser: Never mind, just tell me a quick joke") - await client.send_message("Never mind, just tell me a quick joke") + await client.query("Never mind, just tell me a quick joke") # Get the joke async for msg in client.receive_response(): - if isinstance(msg, AssistantMessage): - for block in msg.content: - if isinstance(block, TextBlock): - print(f"Claude: {block.text}") + display_message(msg) - print("\nSession ended\n") + print("\n") async def example_manual_message_handling(): @@ -168,7 +186,7 @@ async def example_manual_message_handling(): print("=== Manual Message Handling Example ===") async with ClaudeSDKClient() as client: - await client.send_message( + await client.query( "List 5 programming languages and their main use cases" ) @@ -180,6 +198,7 @@ async def example_manual_message_handling(): for block in message.content: if isinstance(block, TextBlock): text = block.text + print(f"Claude: {text}") # Custom logic: extract language names for lang in [ "Python", @@ -193,12 +212,12 @@ async def example_manual_message_handling(): if lang in text and lang not in languages_found: languages_found.append(lang) print(f"Found language: {lang}") - elif isinstance(message, ResultMessage): - print(f"\nTotal languages mentioned: {len(languages_found)}") + display_message(message) + print(f"Total languages mentioned: {len(languages_found)}") break - print("\nSession ended\n") + print("\n") async def example_with_options(): @@ -213,23 +232,75 @@ async def example_with_options(): ) async with ClaudeSDKClient(options=options) as client: - await client.send_message( + print("User: Create a simple hello.txt file with a greeting message") + await client.query( "Create a simple hello.txt file with a greeting message" ) tool_uses = [] async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): + display_message(msg) for block in msg.content: - if isinstance(block, TextBlock): - print(f"Claude: {block.text}") - elif hasattr(block, "name"): # ToolUseBlock + if hasattr(block, "name") and not isinstance( + block, TextBlock + ): # ToolUseBlock tool_uses.append(getattr(block, "name", "")) + else: + display_message(msg) if tool_uses: - print(f"\nTools used: {', '.join(tool_uses)}") + print(f"Tools used: {', '.join(tool_uses)}") - print("\nSession ended\n") + print("\n") + + +async def example_async_iterable_prompt(): + """Demonstrate send_message with async iterable.""" + print("=== Async Iterable Prompt Example ===") + + async def create_message_stream(): + """Generate a stream of messages.""" + print("User: Hello! I have multiple questions.") + yield { + "type": "user", + "message": {"role": "user", "content": "Hello! I have multiple questions."}, + "parent_tool_use_id": None, + "session_id": "qa-session", + } + + print("User: First, what's the capital of Japan?") + yield { + "type": "user", + "message": { + "role": "user", + "content": "First, what's the capital of Japan?", + }, + "parent_tool_use_id": None, + "session_id": "qa-session", + } + + print("User: Second, what's 15% of 200?") + yield { + "type": "user", + "message": {"role": "user", "content": "Second, what's 15% of 200?"}, + "parent_tool_use_id": None, + "session_id": "qa-session", + } + + async with ClaudeSDKClient() as client: + # Send async iterable of messages + await client.query(create_message_stream()) + + # Receive the three responses + async for msg in client.receive_response(): + display_message(msg) + async for msg in client.receive_response(): + display_message(msg) + async for msg in client.receive_response(): + display_message(msg) + + print("\n") async def example_error_handling(): @@ -242,7 +313,8 @@ async def example_error_handling(): await client.connect() # Send a message that will take time to process - await client.send_message("Run a bash sleep command for 60 seconds") + print("User: Run a bash sleep command for 60 seconds") + await client.query("Run a bash sleep command for 60 seconds") # Try to receive response with a short timeout try: @@ -255,11 +327,13 @@ async def example_error_handling(): if isinstance(block, TextBlock): print(f"Claude: {block.text[:50]}...") elif isinstance(msg, ResultMessage): - print("Received complete response") + display_message(msg) break except asyncio.TimeoutError: - print("\nResponse timeout after 10 seconds - demonstrating graceful handling") + print( + "\nResponse timeout after 10 seconds - demonstrating graceful handling" + ) print(f"Received {len(messages)} messages before timeout") except CLIConnectionError as e: @@ -272,24 +346,48 @@ async def example_error_handling(): # Always disconnect await client.disconnect() - print("\nSession ended\n") + print("\n") async def main(): - """Run all examples.""" - examples = [ - example_basic_streaming, - example_multi_turn_conversation, - example_concurrent_responses, - example_with_interrupt, - example_manual_message_handling, - example_with_options, - example_error_handling, - ] + """Run all examples or a specific example based on command line argument.""" + examples = { + "basic_streaming": example_basic_streaming, + "multi_turn_conversation": example_multi_turn_conversation, + "concurrent_responses": example_concurrent_responses, + "with_interrupt": example_with_interrupt, + "manual_message_handling": example_manual_message_handling, + "with_options": example_with_options, + "async_iterable_prompt": example_async_iterable_prompt, + "error_handling": example_error_handling, + } - for example in examples: - await example() - print("-" * 50 + "\n") + if len(sys.argv) < 2: + # List available examples + print("Usage: python streaming_mode.py ") + print("\nAvailable examples:") + print(" all - Run all examples") + for name in examples: + print(f" {name}") + sys.exit(0) + + example_name = sys.argv[1] + + if example_name == "all": + # Run all examples + for example in examples.values(): + await example() + print("-" * 50 + "\n") + elif example_name in examples: + # Run specific example + await examples[example_name]() + else: + print(f"Error: Unknown example '{example_name}'") + print("\nAvailable examples:") + print(" all - Run all examples") + for name in examples: + print(f" {name}") + sys.exit(1) if __name__ == "__main__": diff --git a/examples/streaming_mode_ipython.py b/examples/streaming_mode_ipython.py index 6b2b554..7265afa 100644 --- a/examples/streaming_mode_ipython.py +++ b/examples/streaming_mode_ipython.py @@ -4,6 +4,10 @@ IPython-friendly code snippets for ClaudeSDKClient streaming mode. These examples are designed to be copy-pasted directly into IPython. Each example is self-contained and can be run independently. + +The queries are intentionally simplistic. In reality, a query can be a more +complex task that Claude SDK uses its agentic capabilities and tools (e.g. run +bash commands, edit files, search the web, fetch web content) to accomplish. """ # ============================================================================ @@ -13,15 +17,14 @@ Each example is self-contained and can be run independently. from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage async with ClaudeSDKClient() as client: - await client.send_message("What is 2+2?") + print("User: What is 2+2?") + await client.query("What is 2+2?") async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): print(f"Claude: {block.text}") - elif isinstance(msg, ResultMessage) and msg.total_cost_usd: - print(f"Cost: ${msg.total_cost_usd:.4f}") # ============================================================================ @@ -33,7 +36,8 @@ from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock async with ClaudeSDKClient() as client: async def send_and_receive(prompt): - await client.send_message(prompt) + print(f"User: {prompt}") + await client.query(prompt) async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: @@ -66,10 +70,12 @@ async def get_response(): # Use it multiple times -await client.send_message("What's 2+2?") +print("User: What's 2+2?") +await client.query("What's 2+2?") await get_response() -await client.send_message("What's 10*10?") +print("User: What's 10*10?") +await client.query("What's 10*10?") await get_response() # Don't forget to disconnect when done @@ -89,7 +95,8 @@ async with ClaudeSDKClient() as client: print("\n--- Sending initial message ---\n") # Send a long-running task - await client.send_message("Count from 1 to 100 slowly using bash sleep") + print("User: Count from 1 to 100, run bash sleep for 1 second in between") + await client.query("Count from 1 to 100, run bash sleep for 1 second in between") # Create a background task to consume messages messages_received = [] @@ -121,7 +128,7 @@ async with ClaudeSDKClient() as client: # Send a new message after interrupt print("\n--- After interrupt, sending new message ---\n") - await client.send_message("Just say 'Hello! I was interrupted.'") + await client.query("Just say 'Hello! I was interrupted.'") async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): @@ -138,7 +145,8 @@ from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock try: async with ClaudeSDKClient() as client: - await client.send_message("Run a bash sleep command for 60 seconds") + print("User: Run a bash sleep command for 60 seconds") + await client.query("Run a bash sleep command for 60 seconds") # Timeout after 20 seconds messages = [] @@ -156,6 +164,47 @@ except Exception as e: print(f"Error: {e}") +# ============================================================================ +# SENDING ASYNC ITERABLE OF MESSAGES +# ============================================================================ + +from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock + +async def message_generator(): + """Generate multiple messages as an async iterable.""" + print("User: I have two math questions.") + yield { + "type": "user", + "message": {"role": "user", "content": "I have two math questions."}, + "parent_tool_use_id": None, + "session_id": "math-session" + } + print("User: What is 25 * 4?") + yield { + "type": "user", + "message": {"role": "user", "content": "What is 25 * 4?"}, + "parent_tool_use_id": None, + "session_id": "math-session" + } + print("User: What is 100 / 5?") + yield { + "type": "user", + "message": {"role": "user", "content": "What is 100 / 5?"}, + "parent_tool_use_id": None, + "session_id": "math-session" + } + +async with ClaudeSDKClient() as client: + # Send async iterable instead of string + await client.query(message_generator()) + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # ============================================================================ # COLLECTING ALL MESSAGES INTO A LIST # ============================================================================ @@ -163,7 +212,8 @@ except Exception as e: from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage async with ClaudeSDKClient() as client: - await client.send_message("What are the primary colors?") + print("User: What are the primary colors?") + await client.query("What are the primary colors?") # Collect all messages into a list messages = [msg async for msg in client.receive_response()] @@ -176,5 +226,3 @@ async with ClaudeSDKClient() as client: print(f"Claude: {block.text}") elif isinstance(msg, ResultMessage): print(f"Total messages: {len(messages)}") - if msg.total_cost_usd: - print(f"Cost: ${msg.total_cost_usd:.4f}") diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index dbf9aa6..3c24cb1 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -42,7 +42,7 @@ class ClaudeSDKClient: # Automatically connects with empty stream for interactive use async with ClaudeSDKClient() as client: # Send initial message - await client.send_message("Let's solve a math problem step by step") + await client.query("Let's solve a math problem step by step") # Receive and process response async for message in client.receive_messages(): @@ -50,7 +50,7 @@ class ClaudeSDKClient: break # Send follow-up based on response - await client.send_message("What's 15% of 80?") + await client.query("What's 15% of 80?") # Continue conversation... # Automatically disconnects @@ -60,14 +60,14 @@ class ClaudeSDKClient: ```python async with ClaudeSDKClient() as client: # Start a long task - await client.send_message("Count to 1000") + await client.query("Count to 1000") # Interrupt after 2 seconds await asyncio.sleep(2) await client.interrupt() # Send new instruction - await client.send_message("Never mind, what's 2+2?") + await client.query("Never mind, what's 2+2?") ``` Example - Manual connection: @@ -81,7 +81,7 @@ class ClaudeSDKClient: await client.connect(message_stream()) # Send additional messages dynamically - await client.send_message("What's the weather?") + await client.query("What's the weather?") async for message in client.receive_messages(): print(message) @@ -128,9 +128,9 @@ class ClaudeSDKClient: if message: yield message - async def send_message(self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default") -> None: + async def query(self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default") -> None: """ - Send a new message in streaming mode. + Send a new request in streaming mode. Args: prompt: Either a string message or an async iterable of message dictionaries @@ -186,7 +186,7 @@ class ClaudeSDKClient: Example: ```python async with ClaudeSDKClient() as client: - await client.send_message("What's the capital of France?") + await client.query("What's the capital of France?") async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 49a5291..884d7c4 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -116,8 +116,8 @@ class TestClaudeSDKClientStreaming: anyio.run(_test) - def test_send_message(self): - """Test sending a message.""" + def test_query(self): + """Test sending a query.""" async def _test(): with patch( @@ -127,7 +127,7 @@ class TestClaudeSDKClientStreaming: mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: - await client.send_message("Test message") + await client.query("Test message") # Verify send_request was called with correct format mock_transport.send_request.assert_called_once() @@ -151,7 +151,7 @@ class TestClaudeSDKClientStreaming: mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: - await client.send_message("Test", session_id="custom-session") + await client.query("Test", session_id="custom-session") call_args = mock_transport.send_request.call_args messages, options = call_args[0] @@ -166,7 +166,7 @@ class TestClaudeSDKClientStreaming: async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): - await client.send_message("Test") + await client.query("Test") anyio.run(_test) @@ -360,7 +360,7 @@ class TestClaudeSDKClientStreaming: receive_task = asyncio.create_task(get_next_message()) # Send message while receiving - await client.send_message("Question 1") + await client.query("Question 1") # Wait for first message first_msg = await receive_task From 8e652d7d87009bb6a7824595d3af55a3bf0d231b Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 20:04:58 -0700 Subject: [PATCH 16/19] Fix lint and test --- CLAUDE.md | 4 ++-- .../_internal/transport/subprocess_cli.py | 9 +++------ src/claude_code_sdk/client.py | 8 ++++++-- src/claude_code_sdk/query.py | 4 +++- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 69f23fb..fb9ed47 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,8 +3,8 @@ ```bash # Lint and style # Check for issues and fix automatically -python -m ruff check src/ test/ --fix -python -m ruff format src/ test/ +python -m ruff check src/ tests/ --fix +python -m ruff format src/ tests/ # Typecheck (only done for src/) python -m mypy src/ diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 6a22eec..94c42fa 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -296,12 +296,9 @@ class SubprocessCLITransport(Transport): yield data except GeneratorExit: return - except json.JSONDecodeError as e: - logger.warning( - f"Failed to parse JSON from CLI output: {e}. Buffer content: {json_buffer[:200]}..." - ) - # Clear buffer to avoid repeated parse attempts on malformed data - json_buffer = "" + except json.JSONDecodeError: + # Don't clear buffer - we might be in the middle of a split JSON message + # The buffer will be cleared when we successfully parse or hit size limit continue except anyio.ClosedResourceError: diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 3c24cb1..a4c81ed 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -98,7 +98,9 @@ class ClaudeSDKClient: self._transport: Any | None = None os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client" - async def connect(self, prompt: str | AsyncIterable[dict[str, Any]] | None = None) -> None: + async def connect( + self, prompt: str | AsyncIterable[dict[str, Any]] | None = None + ) -> None: """Connect to Claude with a prompt or message stream.""" from ._internal.transport.subprocess_cli import SubprocessCLITransport @@ -128,7 +130,9 @@ class ClaudeSDKClient: if message: yield message - async def query(self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default") -> None: + async def query( + self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default" + ) -> None: """ Send a new request in streaming mode. diff --git a/src/claude_code_sdk/query.py b/src/claude_code_sdk/query.py index 3732762..ad77a1b 100644 --- a/src/claude_code_sdk/query.py +++ b/src/claude_code_sdk/query.py @@ -9,7 +9,9 @@ from .types import ClaudeCodeOptions, Message async def query( - *, prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions | None = None + *, + prompt: str | AsyncIterable[dict[str, Any]], + options: ClaudeCodeOptions | None = None, ) -> AsyncIterator[Message]: """ Query Claude Code for one-shot or unidirectional streaming interactions. From 3e7da418cea8d27c2b30a58118787feeda3a45c5 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 20:16:45 -0700 Subject: [PATCH 17/19] Fix json error handling --- src/claude_code_sdk/_errors.py | 8 ++ src/claude_code_sdk/_internal/client.py | 4 +- .../_internal/message_parser.py | 49 ++++---- .../_internal/transport/subprocess_cli.py | 5 +- src/claude_code_sdk/client.py | 4 +- tests/test_message_parser.py | 118 ++++++++++++++++++ 6 files changed, 159 insertions(+), 29 deletions(-) create mode 100644 tests/test_message_parser.py diff --git a/src/claude_code_sdk/_errors.py b/src/claude_code_sdk/_errors.py index e832757..8f3e759 100644 --- a/src/claude_code_sdk/_errors.py +++ b/src/claude_code_sdk/_errors.py @@ -44,3 +44,11 @@ class CLIJSONDecodeError(ClaudeSDKError): self.line = line self.original_error = original_error super().__init__(f"Failed to decode JSON: {line[:100]}...") + + +class MessageParseError(ClaudeSDKError): + """Raised when unable to parse a message from CLI output.""" + + def __init__(self, message: str, data: dict | None = None): + self.data = data + super().__init__(message) diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index c1afa9e..715dab5 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -27,9 +27,7 @@ class InternalClient: await transport.connect() async for data in transport.receive_messages(): - message = parse_message(data) - if message: - yield message + yield parse_message(data) finally: await transport.disconnect() diff --git a/src/claude_code_sdk/_internal/message_parser.py b/src/claude_code_sdk/_internal/message_parser.py index c5f4fc0..858e24f 100644 --- a/src/claude_code_sdk/_internal/message_parser.py +++ b/src/claude_code_sdk/_internal/message_parser.py @@ -3,6 +3,7 @@ import logging from typing import Any +from .._errors import MessageParseError from ..types import ( AssistantMessage, ContentBlock, @@ -18,7 +19,7 @@ from ..types import ( logger = logging.getLogger(__name__) -def parse_message(data: dict[str, Any]) -> Message | None: +def parse_message(data: dict[str, Any]) -> Message: """ Parse message from CLI output into typed Message objects. @@ -26,25 +27,29 @@ def parse_message(data: dict[str, Any]) -> Message | None: data: Raw message dictionary from CLI output Returns: - Parsed Message object or None if type is unrecognized or parsing fails - """ - try: - message_type = data.get("type") - if not message_type: - logger.warning("Message missing 'type' field: %s", data) - return None + Parsed Message object - except AttributeError: - logger.error("Invalid message data type (expected dict): %s", type(data)) - return None + Raises: + MessageParseError: If parsing fails or message type is unrecognized + """ + if not isinstance(data, dict): + raise MessageParseError( + f"Invalid message data type (expected dict, got {type(data).__name__})", + data, + ) + + message_type = data.get("type") + if not message_type: + raise MessageParseError("Message missing 'type' field", data) match message_type: case "user": try: return UserMessage(content=data["message"]["content"]) except KeyError as e: - logger.error("Missing required field in user message: %s", e) - return None + raise MessageParseError( + f"Missing required field in user message: {e}", data + ) from e case "assistant": try: @@ -72,8 +77,9 @@ def parse_message(data: dict[str, Any]) -> Message | None: return AssistantMessage(content=content_blocks) except KeyError as e: - logger.error("Missing required field in assistant message: %s", e) - return None + raise MessageParseError( + f"Missing required field in assistant message: {e}", data + ) from e case "system": try: @@ -82,8 +88,9 @@ def parse_message(data: dict[str, Any]) -> Message | None: data=data, ) except KeyError as e: - logger.error("Missing required field in system message: %s", e) - return None + raise MessageParseError( + f"Missing required field in system message: {e}", data + ) from e case "result": try: @@ -99,9 +106,9 @@ def parse_message(data: dict[str, Any]) -> Message | None: result=data.get("result"), ) except KeyError as e: - logger.error("Missing required field in result message: %s", e) - return None + raise MessageParseError( + f"Missing required field in result message: {e}", data + ) from e case _: - logger.debug("Unknown message type: %s", message_type) - return None + raise MessageParseError(f"Unknown message type: {message_type}", data) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 94c42fa..b39f903 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -297,8 +297,9 @@ class SubprocessCLITransport(Transport): except GeneratorExit: return except json.JSONDecodeError: - # Don't clear buffer - we might be in the middle of a split JSON message - # The buffer will be cleared when we successfully parse or hit size limit + # We are speculatively decoding the buffer until we get + # a full JSON object. If there is an actual issue, we + # raise an error after _MAX_BUFFER_SIZE. continue except anyio.ClosedResourceError: diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index a4c81ed..8e86ba7 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -126,9 +126,7 @@ class ClaudeSDKClient: from ._internal.message_parser import parse_message async for data in self._transport.receive_messages(): - message = parse_message(data) - if message: - yield message + yield parse_message(data) async def query( self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default" diff --git a/tests/test_message_parser.py b/tests/test_message_parser.py new file mode 100644 index 0000000..47bd521 --- /dev/null +++ b/tests/test_message_parser.py @@ -0,0 +1,118 @@ +"""Tests for message parser error handling.""" + +import pytest + +from claude_code_sdk._errors import MessageParseError +from claude_code_sdk._internal.message_parser import parse_message +from claude_code_sdk.types import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolUseBlock, + UserMessage, +) + + +class TestMessageParser: + """Test message parsing with the new exception behavior.""" + + def test_parse_valid_user_message(self): + """Test parsing a valid user message.""" + data = {"type": "user", "message": {"content": [{"type": "text", "text": "Hello"}]}} + message = parse_message(data) + assert isinstance(message, UserMessage) + + def test_parse_valid_assistant_message(self): + """Test parsing a valid assistant message.""" + data = { + "type": "assistant", + "message": { + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "tool_use", + "id": "tool_123", + "name": "Read", + "input": {"file_path": "/test.txt"}, + }, + ] + }, + } + message = parse_message(data) + assert isinstance(message, AssistantMessage) + assert len(message.content) == 2 + assert isinstance(message.content[0], TextBlock) + assert isinstance(message.content[1], ToolUseBlock) + + def test_parse_valid_system_message(self): + """Test parsing a valid system message.""" + data = {"type": "system", "subtype": "start"} + message = parse_message(data) + assert isinstance(message, SystemMessage) + assert message.subtype == "start" + + def test_parse_valid_result_message(self): + """Test parsing a valid result message.""" + data = { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 500, + "is_error": False, + "num_turns": 2, + "session_id": "session_123", + } + message = parse_message(data) + assert isinstance(message, ResultMessage) + assert message.subtype == "success" + + def test_parse_invalid_data_type(self): + """Test that non-dict data raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message("not a dict") # type: ignore + assert "Invalid message data type" in str(exc_info.value) + assert "expected dict, got str" in str(exc_info.value) + + def test_parse_missing_type_field(self): + """Test that missing 'type' field raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"message": {"content": []}}) + assert "Message missing 'type' field" in str(exc_info.value) + + def test_parse_unknown_message_type(self): + """Test that unknown message type raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "unknown_type"}) + assert "Unknown message type: unknown_type" in str(exc_info.value) + + def test_parse_user_message_missing_fields(self): + """Test that user message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "user"}) + assert "Missing required field in user message" in str(exc_info.value) + + def test_parse_assistant_message_missing_fields(self): + """Test that assistant message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "assistant"}) + assert "Missing required field in assistant message" in str(exc_info.value) + + def test_parse_system_message_missing_fields(self): + """Test that system message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "system"}) + assert "Missing required field in system message" in str(exc_info.value) + + def test_parse_result_message_missing_fields(self): + """Test that result message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "result", "subtype": "success"}) + assert "Missing required field in result message" in str(exc_info.value) + + def test_message_parse_error_contains_data(self): + """Test that MessageParseError contains the original data.""" + data = {"type": "unknown", "some": "data"} + with pytest.raises(MessageParseError) as exc_info: + parse_message(data) + assert exc_info.value.data == data From 5325dea9fd1694582f8f8454acf037233016e85d Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 20:19:40 -0700 Subject: [PATCH 18/19] Lint --- tests/test_message_parser.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_message_parser.py b/tests/test_message_parser.py index 47bd521..0eb4354 100644 --- a/tests/test_message_parser.py +++ b/tests/test_message_parser.py @@ -19,7 +19,10 @@ class TestMessageParser: def test_parse_valid_user_message(self): """Test parsing a valid user message.""" - data = {"type": "user", "message": {"content": [{"type": "text", "text": "Hello"}]}} + data = { + "type": "user", + "message": {"content": [{"type": "text", "text": "Hello"}]}, + } message = parse_message(data) assert isinstance(message, UserMessage) From e852710d8cc4d8f7628a93cd4c999a9e1dccd71d Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 20:43:07 -0700 Subject: [PATCH 19/19] Remove hardcoded timeout for control messages to match Typescript SDK --- src/claude_code_sdk/_errors.py | 4 +++- .../_internal/transport/subprocess_cli.py | 20 +++++++------------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/claude_code_sdk/_errors.py b/src/claude_code_sdk/_errors.py index 8f3e759..c86bf23 100644 --- a/src/claude_code_sdk/_errors.py +++ b/src/claude_code_sdk/_errors.py @@ -1,5 +1,7 @@ """Error types for Claude SDK.""" +from typing import Any + class ClaudeSDKError(Exception): """Base exception for all Claude SDK errors.""" @@ -49,6 +51,6 @@ class CLIJSONDecodeError(ClaudeSDKError): class MessageParseError(ClaudeSDKError): """Raised when unable to parse a message from CLI output.""" - def __init__(self, message: str, data: dict | None = None): + def __init__(self, message: str, data: dict[str, Any] | None = None): self.data = data super().__init__(message) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index b39f903..34b7034 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -394,19 +394,13 @@ class SubprocessCLITransport(Transport): # Send request await self._stdin_stream.send(json.dumps(control_request) + "\n") - # Wait for response with timeout - try: - with anyio.fail_after(30.0): # 30 second timeout - while request_id not in self._pending_control_responses: - await anyio.sleep(0.1) + # Wait for response + while request_id not in self._pending_control_responses: + await anyio.sleep(0.1) - response = self._pending_control_responses.pop(request_id) + response = self._pending_control_responses.pop(request_id) - if response.get("subtype") == "error": - raise CLIConnectionError( - f"Control request failed: {response.get('error')}" - ) + if response.get("subtype") == "error": + raise CLIConnectionError(f"Control request failed: {response.get('error')}") - return response - except TimeoutError: - raise CLIConnectionError("Control request timed out") from None + return response