diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index 44d8622..ccfc1e8 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -31,9 +31,7 @@ class InternalClient: if transport is not None: chosen_transport = transport else: - chosen_transport = SubprocessCLITransport( - prompt=prompt, options=options - ) + chosen_transport = SubprocessCLITransport(prompt=prompt, options=options) # Connect transport await chosen_transport.connect() @@ -59,6 +57,7 @@ class InternalClient: if isinstance(prompt, AsyncIterable): # Start streaming in background import asyncio + asyncio.create_task(query.stream_input(prompt)) # For string prompts, the prompt is already passed via CLI args diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index 04edc16..815522f 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) class Query: """Handles bidirectional control protocol on top of Transport. - + This class manages: - Control request/response routing - Hook callbacks @@ -27,11 +27,14 @@ class Query: self, transport: Transport, is_streaming_mode: bool, - can_use_tool: Callable[[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]] | None = None, + can_use_tool: Callable[ + [str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]] + ] + | None = None, hooks: dict[str, list[dict[str, Any]]] | None = None, ): """Initialize Query with transport and callbacks. - + Args: transport: Low-level transport for I/O is_streaming_mode: Whether using streaming (bidirectional) mode @@ -58,7 +61,7 @@ class Query: async def initialize(self) -> dict[str, Any] | None: """Initialize control protocol if in streaming mode. - + Returns: Initialize response with supported commands, or None if not streaming """ @@ -78,10 +81,12 @@ class Query: self.next_callback_id += 1 self.hook_callbacks[callback_id] = callback callback_ids.append(callback_id) - hooks_config[event].append({ - "matcher": matcher.get("matcher"), - "hookCallbackIds": callback_ids, - }) + hooks_config[event].append( + { + "matcher": matcher.get("matcher"), + "hookCallbackIds": callback_ids, + } + ) # Send initialize request request = { @@ -115,7 +120,9 @@ class Query: if request_id in self.pending_control_responses: future = self.pending_control_responses.pop(request_id) if response.get("subtype") == "error": - future.set_exception(Exception(response.get("error", "Unknown error"))) + future.set_exception( + Exception(response.get("error", "Unknown error")) + ) else: future.set_result(response) continue @@ -161,7 +168,7 @@ class Query: { "signal": None, # TODO: Add abort signal support "suggestions": request_data.get("permission_suggestions"), - } + }, ) elif subtype == "hook_callback": @@ -174,7 +181,7 @@ class Query: response_data = await callback( request_data.get("input"), request_data.get("tool_use_id"), - {"signal": None} # TODO: Add abort signal support + {"signal": None}, # TODO: Add abort signal support ) else: @@ -187,7 +194,7 @@ class Query: "subtype": "success", "request_id": request_id, "response": response_data, - } + }, } await self.transport.write(json.dumps(response) + "\n") @@ -199,7 +206,7 @@ class Query: "subtype": "error", "request_id": request_id, "error": str(e), - } + }, } await self.transport.write(json.dumps(response) + "\n") @@ -240,10 +247,12 @@ class Query: async def set_permission_mode(self, mode: str) -> None: """Change permission mode.""" - await self._send_control_request({ - "subtype": "set_permission_mode", - "mode": mode, - }) + await self._send_control_request( + { + "subtype": "set_permission_mode", + "mode": mode, + } + ) async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None: """Stream input messages to transport.""" diff --git a/src/claude_code_sdk/_internal/transport/__init__.py b/src/claude_code_sdk/_internal/transport/__init__.py index 5da45e1..9773996 100644 --- a/src/claude_code_sdk/_internal/transport/__init__.py +++ b/src/claude_code_sdk/_internal/transport/__init__.py @@ -12,16 +12,25 @@ class Transport(ABC): (e.g., remote Claude Code connections). The Claude Code team may change or or remove this abstract class in any future release. Custom implementations must be updated to match interface changes. - + This is a low-level transport interface that handles raw I/O with the Claude process or service. The Query class builds on top of this to implement the control protocol and message routing. """ + @abstractmethod + async def connect(self) -> None: + """Connect the transport and prepare for communication. + + For subprocess transports, this starts the process. + For network transports, this establishes the connection. + """ + pass + @abstractmethod async def write(self, data: str) -> None: """Write raw data to the transport. - + Args: data: Raw string data to write (typically JSON + newline) """ @@ -30,7 +39,7 @@ class Transport(ABC): @abstractmethod def read_messages(self) -> AsyncIterator[dict[str, Any]]: """Read and parse messages from the transport. - + Yields: Parsed JSON messages from the transport """ @@ -44,7 +53,7 @@ class Transport(ABC): @abstractmethod def is_ready(self) -> bool: """Check if transport is ready for communication. - + Returns: True if transport is ready to send/receive messages """ diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 9867788..2048d60 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -217,6 +217,7 @@ class SubprocessCLITransport(Transport): if self._process.returncode is None: from contextlib import suppress + with suppress(ProcessLookupError): self._process.terminate() # Note: We can't use async wait here since close() is sync @@ -251,6 +252,7 @@ class SubprocessCLITransport(Transport): self._stdin_stream = None if self._process and self._process.stdin: from contextlib import suppress + with suppress(Exception): # Mark stdin as closed - actual close will happen during cleanup pass @@ -350,6 +352,10 @@ class SubprocessCLITransport(Transport): def is_ready(self) -> bool: """Check if transport is ready for communication.""" - return self._ready and self._process is not None and self._process.returncode is None + return ( + self._ready + and self._process is not None + and self._process.returncode is None + ) # Remove interrupt and control request methods - these now belong in Query class diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 3e5bd0e..216b601 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -138,6 +138,7 @@ class ClaudeSDKClient: # If we have an initial prompt stream, start streaming it if prompt is not None and isinstance(prompt, AsyncIterable): import asyncio + asyncio.create_task(self._query.stream_input(prompt)) async def receive_messages(self) -> AsyncIterator[Message]: @@ -187,18 +188,18 @@ class ClaudeSDKClient: if not self._query: raise CLIConnectionError("Not connected. Call connect() first.") await self._query.interrupt() - + async def get_server_info(self) -> dict[str, Any] | None: """Get server initialization info including available commands and output styles. - + Returns initialization information from the Claude Code server including: - Available commands (slash commands, system commands, etc.) - Current and available output styles - Server capabilities - + Returns: Dictionary with server info, or None if not in streaming mode - + Example: ```python async with ClaudeSDKClient() as client: @@ -211,7 +212,7 @@ class ClaudeSDKClient: if not self._query: raise CLIConnectionError("Not connected. Call connect() first.") # Return the initialization result that was already obtained during connect - return getattr(self._query, '_initialization_result', None) + return getattr(self._query, "_initialization_result", None) async def receive_response(self) -> AsyncIterator[Message]: """ diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index 2c52907..2f0d0e8 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -141,3 +141,29 @@ class ClaudeCodeOptions: extra_args: dict[str, str | None] = field( default_factory=dict ) # Pass arbitrary CLI flags + + +# Control protocol types for initialization +@dataclass +class InitializationMessage: + """Initialization message from the CLI.""" + + type: Literal["control_response"] + response: dict[str, Any] + + +@dataclass +class ControlRequest: + """Control request message.""" + + type: Literal["control_request"] + request_id: str + request: dict[str, Any] + + +@dataclass +class ControlResponse: + """Control response message.""" + + type: Literal["control_response"] + response: dict[str, Any]