diff --git a/.gitignore b/.gitignore index c50630d..7e6d2df 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,7 @@ env/ *.swp *.swo *~ +**/.DS_Store # Testing .tox/ @@ -46,4 +47,4 @@ htmlcov/ .mypy_cache/ .dmypy.json dmypy.json -.pyre/ \ No newline at end of file +.pyre/ diff --git a/examples/streaming_mode.py b/examples/streaming_mode.py index ad177c9..333b488 100755 --- a/examples/streaming_mode.py +++ b/examples/streaming_mode.py @@ -340,6 +340,85 @@ async def example_bash_command(): print("\n") +async def example_control_protocol(): + """Demonstrate server info and interrupt capabilities.""" + print("=== Control Protocol Example ===") + print("Shows server info retrieval and interrupt capability\n") + + async with ClaudeSDKClient() as client: + # 1. Get server initialization info + print("1. Getting server info...") + server_info = await client.get_server_info() + + if server_info: + print("✓ Server info retrieved successfully!") + print(f" - Available commands: {len(server_info.get('commands', []))}") + print(f" - Output style: {server_info.get('output_style', 'unknown')}") + + # Show available output styles if present + styles = server_info.get('available_output_styles', []) + if styles: + print(f" - Available output styles: {', '.join(styles)}") + + # Show a few example commands + commands = server_info.get('commands', [])[:5] + if commands: + print(" - Example commands:") + for cmd in commands: + if isinstance(cmd, dict): + print(f" • {cmd.get('name', 'unknown')}") + else: + print("✗ No server info available (may not be in streaming mode)") + + print("\n2. Testing interrupt capability...") + + # Start a long-running task + print("User: Count from 1 to 20 slowly") + await client.query("Count from 1 to 20 slowly, pausing between each number") + + # Start consuming messages in background to enable interrupt + messages = [] + async def consume(): + async for msg in client.receive_response(): + messages.append(msg) + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + # Print first 50 chars to show progress + print(f"Claude: {block.text[:50]}...") + break + if isinstance(msg, ResultMessage): + break + + consume_task = asyncio.create_task(consume()) + + # Wait a moment then interrupt + await asyncio.sleep(2) + print("\n[Sending interrupt after 2 seconds...]") + + try: + await client.interrupt() + print("✓ Interrupt sent successfully") + except Exception as e: + print(f"✗ Interrupt failed: {e}") + + # Wait for task to complete + with contextlib.suppress(asyncio.CancelledError): + await consume_task + + # Send new query after interrupt + print("\nUser: Just say 'Hello!'") + await client.query("Just say 'Hello!'") + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + print("\n") + + async def example_error_handling(): """Demonstrate proper error handling.""" print("=== Error Handling Example ===") @@ -350,8 +429,8 @@ async def example_error_handling(): await client.connect() # Send a message that will take time to process - print("User: Run a bash sleep command for 60 seconds") - await client.query("Run a bash sleep command for 60 seconds") + print("User: Run a bash sleep command for 60 seconds not in the background") + await client.query("Run a bash sleep command for 60 seconds not in the background") # Try to receive response with a short timeout try: @@ -397,6 +476,7 @@ async def main(): "with_options": example_with_options, "async_iterable_prompt": example_async_iterable_prompt, "bash_command": example_bash_command, + "control_protocol": example_control_protocol, "error_handling": example_error_handling, } diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index 15d8e7d..6b5331f 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -8,6 +8,7 @@ from ..types import ( Message, ) from .message_parser import parse_message +from .query import Query from .transport import Transport from .transport.subprocess_cli import SubprocessCLITransport @@ -24,21 +25,44 @@ class InternalClient: options: ClaudeCodeOptions, transport: Transport | None = None, ) -> AsyncIterator[Message]: - """Process a query through transport.""" + """Process a query through transport and Query.""" - # Use provided transport or choose one based on configuration + # Use provided transport or create subprocess transport if transport is not None: chosen_transport = transport else: - chosen_transport = SubprocessCLITransport( - prompt=prompt, options=options, close_stdin_after_prompt=True - ) + chosen_transport = SubprocessCLITransport(prompt=prompt, options=options) + + # Connect transport + await chosen_transport.connect() + + # Create Query to handle control protocol + is_streaming = not isinstance(prompt, str) + query = Query( + transport=chosen_transport, + is_streaming_mode=is_streaming, + can_use_tool=None, # TODO: Add support for can_use_tool callback + hooks=None, # TODO: Add support for hooks + ) try: - await chosen_transport.connect() + # Start reading messages + await query.start() - async for data in chosen_transport.receive_messages(): + # Initialize if streaming + if is_streaming: + await query.initialize() + + # Stream input if it's an AsyncIterable + if isinstance(prompt, AsyncIterable) and query._tg: + # Start streaming in background + # Create a task that will run in the background + query._tg.start_soon(query.stream_input, prompt) + # For string prompts, the prompt is already passed via CLI args + + # Yield parsed messages + async for data in query.receive_messages(): yield parse_message(data) finally: - await chosen_transport.disconnect() + await query.close() diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py new file mode 100644 index 0000000..cf0b6d8 --- /dev/null +++ b/src/claude_code_sdk/_internal/query.py @@ -0,0 +1,332 @@ +"""Query class for handling bidirectional control protocol.""" + +import json +import logging +import os +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from contextlib import suppress +from typing import Any + +import anyio + +from ..types import ( + SDKControlPermissionRequest, + SDKControlRequest, + SDKControlResponse, + SDKHookCallbackRequest, +) +from .transport import Transport + +logger = logging.getLogger(__name__) + + +class Query: + """Handles bidirectional control protocol on top of Transport. + + This class manages: + - Control request/response routing + - Hook callbacks + - Tool permission callbacks + - Message streaming + - Initialization handshake + """ + + def __init__( + self, + transport: Transport, + is_streaming_mode: bool, + can_use_tool: Callable[ + [str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]] + ] + | None = None, + hooks: dict[str, list[dict[str, Any]]] | None = None, + ): + """Initialize Query with transport and callbacks. + + Args: + transport: Low-level transport for I/O + is_streaming_mode: Whether using streaming (bidirectional) mode + can_use_tool: Optional callback for tool permission requests + hooks: Optional hook configurations + """ + self.transport = transport + self.is_streaming_mode = is_streaming_mode + self.can_use_tool = can_use_tool + self.hooks = hooks or {} + + # Control protocol state + self.pending_control_responses: dict[str, anyio.Event] = {} + self.pending_control_results: dict[str, dict[str, Any] | Exception] = {} + self.hook_callbacks: dict[str, Callable[..., Any]] = {} + self.next_callback_id = 0 + self._request_counter = 0 + + # Message stream + self._message_send, self._message_receive = anyio.create_memory_object_stream[ + dict[str, Any] + ](max_buffer_size=100) + self._tg: anyio.abc.TaskGroup | None = None + self._initialized = False + self._closed = False + self._initialization_result: dict[str, Any] | None = None + + async def initialize(self) -> dict[str, Any] | None: + """Initialize control protocol if in streaming mode. + + Returns: + Initialize response with supported commands, or None if not streaming + """ + if not self.is_streaming_mode: + return None + + # Build hooks configuration for initialization + hooks_config: dict[str, Any] = {} + if self.hooks: + for event, matchers in self.hooks.items(): + if matchers: + hooks_config[event] = [] + for matcher in matchers: + callback_ids = [] + for callback in matcher.get("hooks", []): + callback_id = f"hook_{self.next_callback_id}" + self.next_callback_id += 1 + self.hook_callbacks[callback_id] = callback + callback_ids.append(callback_id) + hooks_config[event].append( + { + "matcher": matcher.get("matcher"), + "hookCallbackIds": callback_ids, + } + ) + + # Send initialize request + request = { + "subtype": "initialize", + "hooks": hooks_config if hooks_config else None, + } + + response = await self._send_control_request(request) + self._initialized = True + self._initialization_result = response # Store for later access + return response + + async def start(self) -> None: + """Start reading messages from transport.""" + if self._tg is None: + self._tg = anyio.create_task_group() + await self._tg.__aenter__() + self._tg.start_soon(self._read_messages) + + async def _read_messages(self) -> None: + """Read messages from transport and route them.""" + try: + async for message in self.transport.read_messages(): + if self._closed: + break + + msg_type = message.get("type") + + # Route control messages + if msg_type == "control_response": + response = message.get("response", {}) + request_id = response.get("request_id") + if request_id in self.pending_control_responses: + event = self.pending_control_responses[request_id] + if response.get("subtype") == "error": + self.pending_control_results[request_id] = Exception( + response.get("error", "Unknown error") + ) + else: + self.pending_control_results[request_id] = response + event.set() + continue + + elif msg_type == "control_request": + # Handle incoming control requests from CLI + # Cast message to SDKControlRequest for type safety + request: SDKControlRequest = message # type: ignore[assignment] + if self._tg: + self._tg.start_soon(self._handle_control_request, request) + continue + + elif msg_type == "control_cancel_request": + # Handle cancel requests + # TODO: Implement cancellation support + continue + + # Regular SDK messages go to the stream + await self._message_send.send(message) + + except anyio.get_cancelled_exc_class(): + # Task was cancelled - this is expected behavior + logger.debug("Read task cancelled") + raise # Re-raise to properly handle cancellation + except Exception as e: + logger.error(f"Fatal error in message reader: {e}") + # Put error in stream so iterators can handle it + await self._message_send.send({"type": "error", "error": str(e)}) + finally: + # Always signal end of stream + await self._message_send.send({"type": "end"}) + + async def _handle_control_request(self, request: SDKControlRequest) -> None: + """Handle incoming control request from CLI.""" + request_id = request["request_id"] + request_data = request["request"] + subtype = request_data["subtype"] + + try: + response_data = {} + + if subtype == "can_use_tool": + permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment] + # Handle tool permission request + if not self.can_use_tool: + raise Exception("canUseTool callback is not provided") + + response_data = await self.can_use_tool( + permission_request["tool_name"], + permission_request["input"], + { + "signal": None, # TODO: Add abort signal support + "suggestions": permission_request.get("permission_suggestions"), + }, + ) + + elif subtype == "hook_callback": + hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment] + # Handle hook callback + callback_id = hook_callback_request["callback_id"] + callback = self.hook_callbacks.get(callback_id) + if not callback: + raise Exception(f"No hook callback found for ID: {callback_id}") + + response_data = await callback( + request_data.get("input"), + request_data.get("tool_use_id"), + {"signal": None}, # TODO: Add abort signal support + ) + + else: + raise Exception(f"Unsupported control request subtype: {subtype}") + + # Send success response + success_response: SDKControlResponse = { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": response_data, + }, + } + await self.transport.write(json.dumps(success_response) + "\n") + + except Exception as e: + # Send error response + error_response: SDKControlResponse = { + "type": "control_response", + "response": { + "subtype": "error", + "request_id": request_id, + "error": str(e), + }, + } + await self.transport.write(json.dumps(error_response) + "\n") + + async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send control request to CLI and wait for response.""" + if not self.is_streaming_mode: + raise Exception("Control requests require streaming mode") + + # Generate unique request ID + self._request_counter += 1 + request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}" + + # Create event for response + event = anyio.Event() + self.pending_control_responses[request_id] = event + + # Build and send request + control_request = { + "type": "control_request", + "request_id": request_id, + "request": request, + } + + await self.transport.write(json.dumps(control_request) + "\n") + + # Wait for response + try: + with anyio.fail_after(60.0): + await event.wait() + + result = self.pending_control_results.pop(request_id) + self.pending_control_responses.pop(request_id, None) + + if isinstance(result, Exception): + raise result + + response_data = result.get("response", {}) + return response_data if isinstance(response_data, dict) else {} + except TimeoutError as e: + self.pending_control_responses.pop(request_id, None) + self.pending_control_results.pop(request_id, None) + raise Exception(f"Control request timeout: {request.get('subtype')}") from e + + async def interrupt(self) -> None: + """Send interrupt control request.""" + await self._send_control_request({"subtype": "interrupt"}) + + async def set_permission_mode(self, mode: str) -> None: + """Change permission mode.""" + await self._send_control_request( + { + "subtype": "set_permission_mode", + "mode": mode, + } + ) + + async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None: + """Stream input messages to transport.""" + try: + async for message in stream: + if self._closed: + break + await self.transport.write(json.dumps(message) + "\n") + # After all messages sent, end input + await self.transport.end_input() + except Exception as e: + logger.debug(f"Error streaming input: {e}") + + async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: + """Receive SDK messages (not control messages).""" + async with self._message_receive: + async for message in self._message_receive: + # Check for special messages + if message.get("type") == "end": + break + elif message.get("type") == "error": + raise Exception(message.get("error", "Unknown error")) + + yield message + + async def close(self) -> None: + """Close the query and transport.""" + self._closed = True + if self._tg: + self._tg.cancel_scope.cancel() + # Wait for task group to complete cancellation + with suppress(anyio.get_cancelled_exc_class()): + await self._tg.__aexit__(None, None, None) + await self.transport.close() + + # Make Query an async iterator + def __aiter__(self) -> AsyncIterator[dict[str, Any]]: + """Return async iterator for messages.""" + return self.receive_messages() + + async def __anext__(self) -> dict[str, Any]: + """Get next message.""" + async for message in self.receive_messages(): + return message + raise StopAsyncIteration diff --git a/src/claude_code_sdk/_internal/transport/__init__.py b/src/claude_code_sdk/_internal/transport/__init__.py index 09a10f8..6dedef6 100644 --- a/src/claude_code_sdk/_internal/transport/__init__.py +++ b/src/claude_code_sdk/_internal/transport/__init__.py @@ -12,33 +12,56 @@ class Transport(ABC): (e.g., remote Claude Code connections). The Claude Code team may change or or remove this abstract class in any future release. Custom implementations must be updated to match interface changes. + + This is a low-level transport interface that handles raw I/O with the Claude + process or service. The Query class builds on top of this to implement the + control protocol and message routing. """ @abstractmethod async def connect(self) -> None: - """Initialize connection.""" + """Connect the transport and prepare for communication. + + For subprocess transports, this starts the process. + For network transports, this establishes the connection. + """ pass @abstractmethod - async def disconnect(self) -> None: - """Close connection.""" + async def write(self, data: str) -> None: + """Write raw data to the transport. + + Args: + data: Raw string data to write (typically JSON + newline) + """ pass @abstractmethod - async def send_request( - self, messages: list[dict[str, Any]], options: dict[str, Any] - ) -> None: - """Send request to Claude.""" + def read_messages(self) -> AsyncIterator[dict[str, Any]]: + """Read and parse messages from the transport. + + Yields: + Parsed JSON messages from the transport + """ pass @abstractmethod - def receive_messages(self) -> AsyncIterator[dict[str, Any]]: - """Receive messages from Claude.""" + async def close(self) -> None: + """Close the transport connection and clean up resources.""" pass @abstractmethod - def is_connected(self) -> bool: - """Check if transport is connected.""" + def is_ready(self) -> bool: + """Check if transport is ready for communication. + + Returns: + True if transport is ready to send/receive messages + """ + pass + + @abstractmethod + async def end_input(self) -> None: + """End the input stream (close stdin for process transports).""" pass diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 619b68e..85aec10 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -7,6 +7,7 @@ import shutil import tempfile from collections import deque from collections.abc import AsyncIterable, AsyncIterator +from contextlib import suppress from pathlib import Path from subprocess import PIPE from typing import Any @@ -33,7 +34,6 @@ class SubprocessCLITransport(Transport): prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions, cli_path: str | Path | None = None, - close_stdin_after_prompt: bool = False, ): self._prompt = prompt self._is_streaming = not isinstance(prompt, str) @@ -44,11 +44,8 @@ class SubprocessCLITransport(Transport): self._stdout_stream: TextReceiveStream | None = None self._stderr_stream: TextReceiveStream | None = None self._stdin_stream: TextSendStream | None = None - self._pending_control_responses: dict[str, dict[str, Any]] = {} - self._request_counter = 0 - self._close_stdin_after_prompt = close_stdin_after_prompt - self._task_group: anyio.abc.TaskGroup | None = None self._stderr_file: Any = None # tempfile.NamedTemporaryFile + self._ready = False def _find_cli(self) -> str: """Find Claude Code CLI binary.""" @@ -174,7 +171,6 @@ class SubprocessCLITransport(Transport): mode="w+", prefix="claude_stderr_", suffix=".log", delete=False ) - # Enable stdin pipe for both modes (but we'll close it for string mode) # Merge environment variables: system -> user -> SDK required process_env = { **os.environ, @@ -197,19 +193,14 @@ class SubprocessCLITransport(Transport): if self._process.stdout: self._stdout_stream = TextReceiveStream(self._process.stdout) - # Handle stdin based on mode - if self._is_streaming: - # Streaming mode: keep stdin open and start streaming task - if self._process.stdin: - self._stdin_stream = TextSendStream(self._process.stdin) - # Start streaming messages to stdin in background - self._task_group = anyio.create_task_group() - await self._task_group.__aenter__() - self._task_group.start_soon(self._stream_to_stdin) - else: - # String mode: close stdin immediately (backward compatible) - if self._process.stdin: - await self._process.stdin.aclose() + # Setup stdin for streaming mode + if self._is_streaming and self._process.stdin: + self._stdin_stream = TextSendStream(self._process.stdin) + elif not self._is_streaming and self._process.stdin: + # String mode: close stdin immediately + await self._process.stdin.aclose() + + self._ready = True except FileNotFoundError as e: # Check if the error comes from the working directory or the CLI @@ -221,27 +212,31 @@ class SubprocessCLITransport(Transport): except Exception as e: raise CLIConnectionError(f"Failed to start Claude Code: {e}") from e - async def disconnect(self) -> None: - """Terminate subprocess.""" + async def close(self) -> None: + """Close the transport and clean up resources.""" + self._ready = False + if not self._process: return - # Cancel task group if it exists - if self._task_group: - self._task_group.cancel_scope.cancel() - await self._task_group.__aexit__(None, None, None) - self._task_group = None + # Close stdin first if it's still open + if self._stdin_stream: + with suppress(Exception): + await self._stdin_stream.aclose() + self._stdin_stream = None + if self._process.stdin: + with suppress(Exception): + await self._process.stdin.aclose() + + # Terminate and wait for process if self._process.returncode is None: - try: + with suppress(ProcessLookupError): self._process.terminate() - with anyio.fail_after(5.0): + # Wait for process to finish with timeout + with suppress(Exception): + # Just try to wait, but don't block if it fails await self._process.wait() - except TimeoutError: - self._process.kill() - await self._process.wait() - except ProcessLookupError: - pass # Clean up temp file if self._stderr_file: @@ -257,57 +252,35 @@ class SubprocessCLITransport(Transport): self._stderr_stream = None self._stdin_stream = None - async def send_request(self, messages: list[Any], options: dict[str, Any]) -> None: - """Send additional messages in streaming mode.""" - if not self._is_streaming: - raise CLIConnectionError("send_request only works in streaming mode") - + async def write(self, data: str) -> None: + """Write raw data to the transport.""" if not self._stdin_stream: - raise CLIConnectionError("stdin not available - stream may have ended") + raise CLIConnectionError("Cannot write: stdin not available") - # Send each message as a user message - for message in messages: - # Ensure message has required structure - if not isinstance(message, dict): - message = { - "type": "user", - "message": {"role": "user", "content": str(message)}, - "parent_tool_use_id": None, - "session_id": options.get("session_id", "default"), - } + await self._stdin_stream.send(data) - await self._stdin_stream.send(json.dumps(message) + "\n") - - async def _stream_to_stdin(self) -> None: - """Stream messages to stdin for streaming mode.""" - if not self._stdin_stream or not isinstance(self._prompt, AsyncIterable): - return - - try: - async for message in self._prompt: - if not self._stdin_stream: - break - await self._stdin_stream.send(json.dumps(message) + "\n") - - # Close stdin after prompt if requested (e.g., for query() one-shot mode) - if self._close_stdin_after_prompt and self._stdin_stream: + async def end_input(self) -> None: + """End the input stream (close stdin).""" + if self._stdin_stream: + with suppress(Exception): await self._stdin_stream.aclose() - self._stdin_stream = None - # Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode) - except Exception as e: - logger.debug(f"Error streaming to stdin: {e}") - if self._stdin_stream: - await self._stdin_stream.aclose() - self._stdin_stream = None + self._stdin_stream = None + if self._process and self._process.stdin: + with suppress(Exception): + await self._process.stdin.aclose() - async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: - """Receive messages from CLI.""" + def read_messages(self) -> AsyncIterator[dict[str, Any]]: + """Read and parse messages from the transport.""" + return self._read_messages_impl() + + async def _read_messages_impl(self) -> AsyncIterator[dict[str, Any]]: + """Internal implementation of read_messages.""" if not self._process or not self._stdout_stream: raise CLIConnectionError("Not connected") json_buffer = "" - # Process stdout messages first + # Process stdout messages try: async for line in self._stdout_stream: line_str = line.strip() @@ -336,20 +309,7 @@ class SubprocessCLITransport(Transport): try: data = json.loads(json_buffer) json_buffer = "" - - # Handle control responses separately - if data.get("type") == "control_response": - response = data.get("response", {}) - request_id = response.get("request_id") - if request_id: - # Store the response for the pending request - self._pending_control_responses[request_id] = response - continue - - try: - yield data - except GeneratorExit: - return + yield data except json.JSONDecodeError: # We are speculatively decoding the buffer until we get # a full JSON object. If there is an actual issue, we @@ -359,7 +319,7 @@ class SubprocessCLITransport(Transport): except anyio.ClosedResourceError: pass except GeneratorExit: - # Client disconnected - still need to clean up + # Client disconnected pass # Read stderr from temp file (keep only last N lines for memory efficiency) @@ -402,48 +362,12 @@ class SubprocessCLITransport(Transport): # Log stderr for debugging but don't fail on non-zero exit logger.debug(f"Process stderr: {stderr_output}") - def is_connected(self) -> bool: - """Check if subprocess is running.""" - return self._process is not None and self._process.returncode is None + def is_ready(self) -> bool: + """Check if transport is ready for communication.""" + return ( + self._ready + and self._process is not None + and self._process.returncode is None + ) - async def interrupt(self) -> None: - """Send interrupt control request (only works in streaming mode).""" - if not self._is_streaming: - raise CLIConnectionError( - "Interrupt requires streaming mode (AsyncIterable prompt)" - ) - - if not self._stdin_stream: - raise CLIConnectionError("Not connected or stdin not available") - - await self._send_control_request({"subtype": "interrupt"}) - - async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]: - """Send a control request and wait for response.""" - if not self._stdin_stream: - raise CLIConnectionError("Stdin not available") - - # Generate unique request ID - self._request_counter += 1 - request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}" - - # Build control request - control_request = { - "type": "control_request", - "request_id": request_id, - "request": request, - } - - # Send request - await self._stdin_stream.send(json.dumps(control_request) + "\n") - - # Wait for response - while request_id not in self._pending_control_responses: - await anyio.sleep(0.1) - - response = self._pending_control_responses.pop(request_id) - - if response.get("subtype") == "error": - raise CLIConnectionError(f"Control request failed: {response.get('error')}") - - return response + # Remove interrupt and control request methods - these now belong in Query class diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index cf668fd..3cfeb42 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -1,5 +1,6 @@ """Claude SDK Client for interacting with Claude Code.""" +import json import os from collections.abc import AsyncIterable, AsyncIterator from typing import Any @@ -96,12 +97,15 @@ class ClaudeSDKClient: options = ClaudeCodeOptions() self.options = options self._transport: Any | None = None + self._query: Any | None = None os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client" async def connect( self, prompt: str | AsyncIterable[dict[str, Any]] | None = None ) -> None: """Connect to Claude with a prompt or message stream.""" + + from ._internal.query import Query from ._internal.transport.subprocess_cli import SubprocessCLITransport # Auto-connect with empty async iterable if no prompt is provided @@ -112,20 +116,38 @@ class ClaudeSDKClient: return yield {} # type: ignore[unreachable] + actual_prompt = _empty_stream() if prompt is None else prompt + self._transport = SubprocessCLITransport( - prompt=_empty_stream() if prompt is None else prompt, + prompt=actual_prompt, options=self.options, ) await self._transport.connect() + # Create Query to handle control protocol + self._query = Query( + transport=self._transport, + is_streaming_mode=True, # ClaudeSDKClient always uses streaming mode + can_use_tool=None, # TODO: Add support for can_use_tool callback + hooks=None, # TODO: Add support for hooks + ) + + # Start reading messages and initialize + await self._query.start() + await self._query.initialize() + + # If we have an initial prompt stream, start streaming it + if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg: + self._query._tg.start_soon(self._query.stream_input, prompt) + async def receive_messages(self) -> AsyncIterator[Message]: """Receive all messages from Claude.""" - if not self._transport: + if not self._query: raise CLIConnectionError("Not connected. Call connect() first.") from ._internal.message_parser import parse_message - async for data in self._transport.receive_messages(): + async for data in self._query.receive_messages(): yield parse_message(data) async def query( @@ -138,7 +160,7 @@ class ClaudeSDKClient: prompt: Either a string message or an async iterable of message dictionaries session_id: Session identifier for the conversation """ - if not self._transport: + if not self._query or not self._transport: raise CLIConnectionError("Not connected. Call connect() first.") # Handle string prompts @@ -149,24 +171,45 @@ class ClaudeSDKClient: "parent_tool_use_id": None, "session_id": session_id, } - await self._transport.send_request([message], {"session_id": session_id}) + await self._transport.write(json.dumps(message) + "\n") else: - # Handle AsyncIterable prompts - messages = [] + # Handle AsyncIterable prompts - stream them async for msg in prompt: # Ensure session_id is set on each message if "session_id" not in msg: msg["session_id"] = session_id - messages.append(msg) - - if messages: - await self._transport.send_request(messages, {"session_id": session_id}) + await self._transport.write(json.dumps(msg) + "\n") async def interrupt(self) -> None: """Send interrupt signal (only works with streaming mode).""" - if not self._transport: + if not self._query: raise CLIConnectionError("Not connected. Call connect() first.") - await self._transport.interrupt() + await self._query.interrupt() + + async def get_server_info(self) -> dict[str, Any] | None: + """Get server initialization info including available commands and output styles. + + Returns initialization information from the Claude Code server including: + - Available commands (slash commands, system commands, etc.) + - Current and available output styles + - Server capabilities + + Returns: + Dictionary with server info, or None if not in streaming mode + + Example: + ```python + async with ClaudeSDKClient() as client: + info = await client.get_server_info() + if info: + print(f"Commands available: {len(info.get('commands', []))}") + print(f"Output style: {info.get('output_style', 'default')}") + ``` + """ + if not self._query: + raise CLIConnectionError("Not connected. Call connect() first.") + # Return the initialization result that was already obtained during connect + return getattr(self._query, "_initialization_result", None) async def receive_response(self) -> AsyncIterator[Message]: """ @@ -211,9 +254,10 @@ class ClaudeSDKClient: async def disconnect(self) -> None: """Disconnect from Claude.""" - if self._transport: - await self._transport.disconnect() - self._transport = None + if self._query: + await self._query.close() + self._query = None + self._transport = None async def __aenter__(self) -> "ClaudeSDKClient": """Enter async context - automatically connects with empty stream for interactive use.""" diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index 2c52907..1f61ec6 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -141,3 +141,72 @@ class ClaudeCodeOptions: extra_args: dict[str, str | None] = field( default_factory=dict ) # Pass arbitrary CLI flags + + +# SDK Control Protocol +class SDKControlInterruptRequest(TypedDict): + subtype: Literal["interrupt"] + + +class SDKControlPermissionRequest(TypedDict): + subtype: Literal["can_use_tool"] + tool_name: str + input: dict[str, Any] + # TODO: Add PermissionUpdate type here + permission_suggestions: list[Any] | None + blocked_path: str | None + + +class SDKControlInitializeRequest(TypedDict): + subtype: Literal["initialize"] + # TODO: Use HookEvent names as the key. + hooks: dict[str, Any] | None + + +class SDKControlSetPermissionModeRequest(TypedDict): + subtype: Literal["set_permission_mode"] + # TODO: Add PermissionMode + mode: str + + +class SDKHookCallbackRequest(TypedDict): + subtype: Literal["hook_callback"] + callback_id: str + input: Any + tool_use_id: str | None + + +class SDKControlMcpMessageRequest(TypedDict): + subtype: Literal["mcp_message"] + server_name: str + message: Any + + +class SDKControlRequest(TypedDict): + type: Literal["control_request"] + request_id: str + request: ( + SDKControlInterruptRequest + | SDKControlPermissionRequest + | SDKControlInitializeRequest + | SDKControlSetPermissionModeRequest + | SDKHookCallbackRequest + | SDKControlMcpMessageRequest + ) + + +class ControlResponse(TypedDict): + subtype: Literal["success"] + request_id: str + response: dict[str, Any] | None + + +class ControlErrorResponse(TypedDict): + subtype: Literal["error"] + request_id: str + error: str + + +class SDKControlResponse(TypedDict): + type: Literal["control_response"] + response: ControlResponse | ControlErrorResponse diff --git a/tests/test_client.py b/tests/test_client.py index 8156010..df1d087 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,6 @@ """Tests for Claude SDK client functionality.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import anyio @@ -102,9 +102,12 @@ class TestQueryFunction: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) options = ClaudeCodeOptions(cwd="/custom/path") messages = [] diff --git a/tests/test_integration.py b/tests/test_integration.py index aa6d12e..c3e4feb 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -3,7 +3,7 @@ These tests verify end-to-end functionality with mocked CLI responses. """ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import anyio import pytest @@ -52,9 +52,12 @@ class TestIntegration: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) # Run query messages = [] @@ -118,9 +121,12 @@ class TestIntegration: "total_cost_usd": 0.002, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) # Run query with tools enabled messages = [] @@ -185,9 +191,12 @@ class TestIntegration: }, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive mock_transport.connect = AsyncMock() - mock_transport.disconnect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) # Run query with continuation messages = [] diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index a9c2bb3..821ff96 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -1,10 +1,11 @@ """Tests for ClaudeSDKClient streaming functionality and query() with async iterables.""" import asyncio +import json import sys import tempfile from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import anyio import pytest @@ -22,6 +23,90 @@ from claude_code_sdk import ( from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport +def create_mock_transport(with_init_response=True): + """Create a properly configured mock transport. + + Args: + with_init_response: If True, automatically respond to initialization request + """ + mock_transport = AsyncMock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.write = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + + # Track written messages to simulate control protocol responses + written_messages = [] + + async def mock_write(data): + written_messages.append(data) + + mock_transport.write.side_effect = mock_write + + # Default read_messages to handle control protocol + async def control_protocol_generator(): + # Wait for initialization request if needed + if with_init_response: + # Wait a bit for the write to happen + await asyncio.sleep(0.01) + + # Check if initialization was requested + for msg_str in written_messages: + try: + msg = json.loads(msg_str.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") == "initialize" + ): + # Send initialization response + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Keep checking for other control requests (like interrupt) + last_check = len(written_messages) + timeout_counter = 0 + while timeout_counter < 100: # Avoid infinite loop + await asyncio.sleep(0.01) + timeout_counter += 1 + + # Check for new messages + for msg_str in written_messages[last_check:]: + try: + msg = json.loads(msg_str.strip()) + if msg.get("type") == "control_request": + subtype = msg.get("request", {}).get("subtype") + if subtype == "interrupt": + # Send interrupt response + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + }, + } + return # End after interrupt + except (json.JSONDecodeError, KeyError, AttributeError): + pass + last_check = len(written_messages) + + # Then end the stream + return + + mock_transport.read_messages = control_protocol_generator + return mock_transport + + class TestClaudeSDKClientStreaming: """Test ClaudeSDKClient streaming functionality.""" @@ -32,7 +117,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: @@ -41,7 +126,7 @@ class TestClaudeSDKClientStreaming: assert client._transport is mock_transport # Verify disconnect was called on exit - mock_transport.disconnect.assert_called_once() + mock_transport.close.assert_called_once() anyio.run(_test) @@ -52,7 +137,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport client = ClaudeSDKClient() @@ -64,7 +149,7 @@ class TestClaudeSDKClientStreaming: await client.disconnect() # Verify disconnect was called - mock_transport.disconnect.assert_called_once() + mock_transport.close.assert_called_once() assert client._transport is None anyio.run(_test) @@ -76,7 +161,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport client = ClaudeSDKClient() @@ -95,7 +180,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async def message_stream(): @@ -123,20 +208,30 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: await client.query("Test message") - # Verify send_request was called with correct format - mock_transport.send_request.assert_called_once() - call_args = mock_transport.send_request.call_args - messages, options = call_args[0] - assert len(messages) == 1 - assert messages[0]["type"] == "user" - assert messages[0]["message"]["content"] == "Test message" - assert options["session_id"] == "default" + # Verify write was called with correct format + # Should have at least 2 writes: init request and user message + assert mock_transport.write.call_count >= 2 + + # Find the user message in the write calls + user_msg_found = False + for call in mock_transport.write.call_args_list: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if msg.get("type") == "user": + assert msg["message"]["content"] == "Test message" + assert msg["session_id"] == "default" + user_msg_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert user_msg_found, "User message not found in write calls" anyio.run(_test) @@ -147,16 +242,25 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: await client.query("Test", session_id="custom-session") - call_args = mock_transport.send_request.call_args - messages, options = call_args[0] - assert messages[0]["session_id"] == "custom-session" - assert options["session_id"] == "custom-session" + # Find the user message with custom session ID + session_found = False + for call in mock_transport.write.call_args_list: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if msg.get("type") == "user": + assert msg["session_id"] == "custom-session" + session_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert session_found, "User message with custom session not found" anyio.run(_test) @@ -177,11 +281,37 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock the message stream + # Mock the message stream with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages yield { "type": "assistant", "message": { @@ -195,7 +325,7 @@ class TestClaudeSDKClientStreaming: "message": {"role": "user", "content": "Hi there"}, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: messages = [] @@ -220,11 +350,37 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock the message stream + # Mock the message stream with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages yield { "type": "assistant", "message": { @@ -255,7 +411,7 @@ class TestClaudeSDKClientStreaming: "model": "claude-opus-4-1-20250805", } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: messages = [] @@ -276,12 +432,28 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport async with ClaudeSDKClient() as client: + # Interrupt is now handled via control protocol await client.interrupt() - mock_transport.interrupt.assert_called_once() + # Check that a control request was sent via write + write_calls = mock_transport.write.call_args_list + interrupt_found = False + for call in write_calls: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") == "interrupt" + ): + interrupt_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert interrupt_found, "Interrupt control request not found" anyio.run(_test) @@ -308,7 +480,7 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport client = ClaudeSDKClient(options=options) @@ -327,11 +499,38 @@ class TestClaudeSDKClientStreaming: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock receive to wait then yield messages + # Mock receive to wait then yield messages with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + if call: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages await asyncio.sleep(0.1) yield { "type": "assistant", @@ -353,7 +552,7 @@ class TestClaudeSDKClientStreaming: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: # Helper to get next message @@ -397,9 +596,35 @@ while True: line = sys.stdin.readline() if not line: break - stdin_messages.append(line.strip()) -# Verify we got 2 messages + try: + msg = json.loads(line.strip()) + # Handle control requests + if msg.get("type") == "control_request": + request_id = msg.get("request_id") + request = msg.get("request", {}) + + # Send control response for initialize + if request.get("subtype") == "initialize": + response = { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": { + "commands": [], + "output_style": "default" + } + } + } + print(json.dumps(response)) + sys.stdout.flush() + else: + stdin_messages.append(line.strip()) + except: + stdin_messages.append(line.strip()) + +# Verify we got 2 user messages assert len(stdin_messages) == 2 assert '"First"' in stdin_messages[0] assert '"Second"' in stdin_messages[1] @@ -476,8 +701,11 @@ class TestClaudeSDKClientEdgeCases: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() - mock_transport_class.return_value = mock_transport + # Create a new mock transport for each call + mock_transport_class.side_effect = [ + create_mock_transport(), + create_mock_transport(), + ] client = ClaudeSDKClient() await client.connect() @@ -506,7 +734,7 @@ class TestClaudeSDKClientEdgeCases: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport with pytest.raises(ValueError): @@ -514,7 +742,7 @@ class TestClaudeSDKClientEdgeCases: raise ValueError("Test error") # Disconnect should still be called - mock_transport.disconnect.assert_called_once() + mock_transport.close.assert_called_once() anyio.run(_test) @@ -525,11 +753,38 @@ class TestClaudeSDKClientEdgeCases: with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" ) as mock_transport_class: - mock_transport = AsyncMock() + mock_transport = create_mock_transport() mock_transport_class.return_value = mock_transport - # Mock the message stream + # Mock the message stream with control protocol support async def mock_receive(): + # First handle initialization + await asyncio.sleep(0.01) + written = mock_transport.write.call_args_list + for call in written: + if call: + data = call[0][0] + try: + msg = json.loads(data.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + "output_style": "default", + }, + } + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + + # Then yield the actual messages yield { "type": "assistant", "message": { @@ -557,7 +812,7 @@ class TestClaudeSDKClientEdgeCases: "total_cost_usd": 0.001, } - mock_transport.receive_messages = mock_receive + mock_transport.read_messages = mock_receive async with ClaudeSDKClient() as client: # Test list comprehension pattern from docstring diff --git a/tests/test_subprocess_buffering.py b/tests/test_subprocess_buffering.py index 426d42e..05584e1 100644 --- a/tests/test_subprocess_buffering.py +++ b/tests/test_subprocess_buffering.py @@ -63,7 +63,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) # type: ignore[assignment] messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 2 @@ -97,7 +97,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 2 @@ -127,7 +127,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 2 @@ -173,7 +173,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 1 @@ -221,7 +221,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 1 @@ -252,7 +252,7 @@ class TestSubprocessBuffering: with pytest.raises(Exception) as exc_info: messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert isinstance(exc_info.value, CLIJSONDecodeError) @@ -293,7 +293,7 @@ class TestSubprocessBuffering: transport._stderr_stream = MockTextReceiveStream([]) messages: list[Any] = [] - async for msg in transport.receive_messages(): + async for msg in transport.read_messages(): messages.append(msg) assert len(messages) == 3 diff --git a/tests/test_transport.py b/tests/test_transport.py index aa6a8e9..50e093c 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -112,8 +112,8 @@ class TestSubprocessCLITransport: assert "--resume" in cmd assert "session-123" in cmd - def test_connect_disconnect(self): - """Test connect and disconnect lifecycle.""" + def test_connect_close(self): + """Test connect and close lifecycle.""" async def _test(): with patch("anyio.open_process") as mock_exec: @@ -139,22 +139,22 @@ class TestSubprocessCLITransport: await transport.connect() assert transport._process is not None - assert transport.is_connected() + assert transport.is_ready() - await transport.disconnect() + await transport.close() mock_process.terminate.assert_called_once() anyio.run(_test) - def test_receive_messages(self): - """Test parsing messages from CLI output.""" - # This test is simplified to just test the parsing logic + def test_read_messages(self): + """Test reading messages from CLI output.""" + # This test is simplified to just test the transport creation # The full async stream handling is tested in integration tests transport = SubprocessCLITransport( prompt="test", options=ClaudeCodeOptions(), cli_path="/usr/bin/claude" ) - # The actual message parsing is done by the client, not the transport + # The transport now just provides raw message reading via read_messages() # So we just verify the transport can be created and basic structure is correct assert transport._prompt == "test" assert transport._cli_path == "/usr/bin/claude"