diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index fe2b78e..3e99511 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -1,6 +1,5 @@ """Internal client implementation.""" -import asyncio from collections.abc import AsyncIterable, AsyncIterator from typing import Any @@ -63,9 +62,10 @@ class InternalClient: await query.initialize() # Stream input if it's an AsyncIterable - if isinstance(prompt, AsyncIterable): + if isinstance(prompt, AsyncIterable) and query._tg: # Start streaming in background - asyncio.create_task(query.stream_input(prompt)) + # Create a task that will run in the background + query._tg.start_soon(query.stream_input, prompt) # For string prompts, the prompt is already passed via CLI args # Yield parsed messages diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index 0ab2740..c29a751 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -1,6 +1,5 @@ """Query class for handling bidirectional control protocol.""" -import asyncio import json import logging import os @@ -8,6 +7,14 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from contextlib import suppress from typing import TYPE_CHECKING, Any +import anyio + +from ..types import ( + SDKControlPermissionRequest, + SDKControlRequest, + SDKControlResponse, + SDKHookCallbackRequest, +) from .transport import Transport if TYPE_CHECKING: @@ -54,14 +61,17 @@ class Query: self.sdk_mcp_servers = sdk_mcp_servers or {} # Control protocol state - self.pending_control_responses: dict[str, asyncio.Future[dict[str, Any]]] = {} + self.pending_control_responses: dict[str, anyio.Event] = {} + self.pending_control_results: dict[str, dict[str, Any] | Exception] = {} self.hook_callbacks: dict[str, Callable[..., Any]] = {} self.next_callback_id = 0 self._request_counter = 0 # Message stream - self._message_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() - self._read_task: asyncio.Task[None] | None = None + self._message_send, self._message_receive = anyio.create_memory_object_stream[ + dict[str, Any] + ](max_buffer_size=100) + self._tg: anyio.abc.TaskGroup | None = None self._initialized = False self._closed = False self._initialization_result: dict[str, Any] | None = None @@ -108,8 +118,10 @@ class Query: async def start(self) -> None: """Start reading messages from transport.""" - if self._read_task is None: - self._read_task = asyncio.create_task(self._read_messages()) + if self._tg is None: + self._tg = anyio.create_task_group() + await self._tg.__aenter__() + self._tg.start_soon(self._read_messages) async def _read_messages(self) -> None: """Read messages from transport and route them.""" @@ -125,18 +137,22 @@ class Query: response = message.get("response", {}) request_id = response.get("request_id") if request_id in self.pending_control_responses: - future = self.pending_control_responses.pop(request_id) + event = self.pending_control_responses[request_id] if response.get("subtype") == "error": - future.set_exception( - Exception(response.get("error", "Unknown error")) + self.pending_control_results[request_id] = Exception( + response.get("error", "Unknown error") ) else: - future.set_result(response) + self.pending_control_results[request_id] = response + event.set() continue elif msg_type == "control_request": # Handle incoming control requests from CLI - asyncio.create_task(self._handle_control_request(message)) + # Cast message to SDKControlRequest for type safety + request: SDKControlRequest = message # type: ignore[assignment] + if self._tg: + self._tg.start_soon(self._handle_control_request, request) continue elif msg_type == "control_cancel_request": @@ -144,47 +160,49 @@ class Query: # TODO: Implement cancellation support continue - # Regular SDK messages go to the queue - await self._message_queue.put(message) + # Regular SDK messages go to the stream + await self._message_send.send(message) - except asyncio.CancelledError: + except anyio.get_cancelled_exc_class(): # Task was cancelled - this is expected behavior logger.debug("Read task cancelled") raise # Re-raise to properly handle cancellation except Exception as e: logger.error(f"Fatal error in message reader: {e}") - # Put error in queue so iterators can handle it - await self._message_queue.put({"type": "error", "error": str(e)}) + # Put error in stream so iterators can handle it + await self._message_send.send({"type": "error", "error": str(e)}) finally: # Always signal end of stream - await self._message_queue.put({"type": "end"}) + await self._message_send.send({"type": "end"}) - async def _handle_control_request(self, request: dict[str, Any]) -> None: + async def _handle_control_request(self, request: SDKControlRequest) -> None: """Handle incoming control request from CLI.""" - request_id = request.get("request_id") - request_data = request.get("request", {}) - subtype = request_data.get("subtype") + request_id = request["request_id"] + request_data = request["request"] + subtype = request_data["subtype"] try: response_data = {} if subtype == "can_use_tool": + permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment] # Handle tool permission request if not self.can_use_tool: raise Exception("canUseTool callback is not provided") response_data = await self.can_use_tool( - request_data.get("tool_name"), - request_data.get("input"), + permission_request["tool_name"], + permission_request["input"], { "signal": None, # TODO: Add abort signal support - "suggestions": request_data.get("permission_suggestions"), + "suggestions": permission_request.get("permission_suggestions"), }, ) elif subtype == "hook_callback": + hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment] # Handle hook callback - callback_id = request_data.get("callback_id") + callback_id = hook_callback_request["callback_id"] callback = self.hook_callbacks.get(callback_id) if not callback: raise Exception(f"No hook callback found for ID: {callback_id}") @@ -211,7 +229,7 @@ class Query: raise Exception(f"Unsupported control request subtype: {subtype}") # Send success response - response = { + success_response: SDKControlResponse = { "type": "control_response", "response": { "subtype": "success", @@ -219,11 +237,11 @@ class Query: "response": response_data, }, } - await self.transport.write(json.dumps(response) + "\n") + await self.transport.write(json.dumps(success_response) + "\n") except Exception as e: # Send error response - response = { + error_response: SDKControlResponse = { "type": "control_response", "response": { "subtype": "error", @@ -231,7 +249,7 @@ class Query: "error": str(e), }, } - await self.transport.write(json.dumps(response) + "\n") + await self.transport.write(json.dumps(error_response) + "\n") async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]: """Send control request to CLI and wait for response.""" @@ -242,9 +260,9 @@ class Query: self._request_counter += 1 request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}" - # Create future for response - future: asyncio.Future[dict[str, Any]] = asyncio.Future() - self.pending_control_responses[request_id] = future + # Create event for response + event = anyio.Event() + self.pending_control_responses[request_id] = event # Build and send request control_request = { @@ -257,11 +275,20 @@ class Query: # Wait for response try: - response = await asyncio.wait_for(future, timeout=60.0) - result = response.get("response", {}) - return result if isinstance(result, dict) else {} - except asyncio.TimeoutError as e: + with anyio.fail_after(60.0): + await event.wait() + + result = self.pending_control_results.pop(request_id) self.pending_control_responses.pop(request_id, None) + + if isinstance(result, Exception): + raise result + + response_data = result.get("response", {}) + return response_data if isinstance(response_data, dict) else {} + except TimeoutError as e: + self.pending_control_responses.pop(request_id, None) + self.pending_control_results.pop(request_id, None) raise Exception(f"Control request timeout: {request.get('subtype')}") from e async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict: @@ -350,25 +377,24 @@ class Query: async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: """Receive SDK messages (not control messages).""" - while not self._closed: - message = await self._message_queue.get() + async with self._message_receive: + async for message in self._message_receive: + # Check for special messages + if message.get("type") == "end": + break + elif message.get("type") == "error": + raise Exception(message.get("error", "Unknown error")) - # Check for special messages - if message.get("type") == "end": - break - elif message.get("type") == "error": - raise Exception(message.get("error", "Unknown error")) - - yield message + yield message async def close(self) -> None: """Close the query and transport.""" self._closed = True - if self._read_task and not self._read_task.done(): - self._read_task.cancel() - # Wait for task to complete cancellation - with suppress(asyncio.CancelledError): - await self._read_task + if self._tg: + self._tg.cancel_scope.cancel() + # Wait for task group to complete cancellation + with suppress(anyio.get_cancelled_exc_class()): + await self._tg.__aexit__(None, None, None) await self.transport.close() # Make Query an async iterator diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 2875d70..3cfeb42 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -1,6 +1,5 @@ """Claude SDK Client for interacting with Claude Code.""" -import asyncio import json import os from collections.abc import AsyncIterable, AsyncIterator @@ -138,8 +137,8 @@ class ClaudeSDKClient: await self._query.initialize() # If we have an initial prompt stream, start streaming it - if prompt is not None and isinstance(prompt, AsyncIterable): - asyncio.create_task(self._query.stream_input(prompt)) + if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg: + self._query._tg.start_soon(self._query.stream_input, prompt) async def receive_messages(self) -> AsyncIterator[Message]: """Receive all messages from Claude.""" diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index b4fc413..22966dd 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -156,27 +156,70 @@ class ClaudeCodeOptions: ) # Pass arbitrary CLI flags -# Control protocol types for initialization -@dataclass -class InitializationMessage: - """Initialization message from the CLI.""" - - type: Literal["control_response"] - response: dict[str, Any] +# SDK Control Protocol +class SDKControlInterruptRequest(TypedDict): + subtype: Literal["interrupt"] -@dataclass -class ControlRequest: - """Control request message.""" +class SDKControlPermissionRequest(TypedDict): + subtype: Literal["can_use_tool"] + tool_name: str + input: dict[str, Any] + # TODO: Add PermissionUpdate type here + permission_suggestions: list[Any] | None + blocked_path: str | None + +class SDKControlInitializeRequest(TypedDict): + subtype: Literal["initialize"] + # TODO: Use HookEvent names as the key. + hooks: dict[str, Any] | None + + +class SDKControlSetPermissionModeRequest(TypedDict): + subtype: Literal["set_permission_mode"] + # TODO: Add PermissionMode + mode: str + + +class SDKHookCallbackRequest(TypedDict): + subtype: Literal["hook_callback"] + callback_id: str + input: Any + tool_use_id: str | None + + +class SDKControlMcpMessageRequest(TypedDict): + subtype: Literal["mcp_message"] + server_name: str + message: Any + + +class SDKControlRequest(TypedDict): type: Literal["control_request"] request_id: str - request: dict[str, Any] + request: ( + SDKControlInterruptRequest + | SDKControlPermissionRequest + | SDKControlInitializeRequest + | SDKControlSetPermissionModeRequest + | SDKHookCallbackRequest + | SDKControlMcpMessageRequest + ) -@dataclass -class ControlResponse: - """Control response message.""" +class ControlResponse(TypedDict): + subtype: Literal["success"] + request_id: str + response: dict[str, Any] | None + +class ControlErrorResponse(TypedDict): + subtype: Literal["error"] + request_id: str + error: str + + +class SDKControlResponse(TypedDict): type: Literal["control_response"] - response: dict[str, Any] + response: ControlResponse | ControlErrorResponse diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 214869a..821ff96 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -512,20 +512,24 @@ class TestClaudeSDKClientStreaming: data = call[0][0] try: msg = json.loads(data.strip()) - if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize": + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): yield { "type": "control_response", "response": { "request_id": msg.get("request_id"), "subtype": "success", "commands": [], - "output_style": "default" - } + "output_style": "default", + }, } break except (json.JSONDecodeError, KeyError, AttributeError): pass - + # Then yield the actual messages await asyncio.sleep(0.1) yield { @@ -592,9 +596,35 @@ while True: line = sys.stdin.readline() if not line: break - stdin_messages.append(line.strip()) -# Verify we got 2 messages + try: + msg = json.loads(line.strip()) + # Handle control requests + if msg.get("type") == "control_request": + request_id = msg.get("request_id") + request = msg.get("request", {}) + + # Send control response for initialize + if request.get("subtype") == "initialize": + response = { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": { + "commands": [], + "output_style": "default" + } + } + } + print(json.dumps(response)) + sys.stdout.flush() + else: + stdin_messages.append(line.strip()) + except: + stdin_messages.append(line.strip()) + +# Verify we got 2 user messages assert len(stdin_messages) == 2 assert '"First"' in stdin_messages[0] assert '"Second"' in stdin_messages[1] @@ -674,7 +704,7 @@ class TestClaudeSDKClientEdgeCases: # Create a new mock transport for each call mock_transport_class.side_effect = [ create_mock_transport(), - create_mock_transport() + create_mock_transport(), ] client = ClaudeSDKClient() @@ -736,20 +766,24 @@ class TestClaudeSDKClientEdgeCases: data = call[0][0] try: msg = json.loads(data.strip()) - if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize": + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): yield { "type": "control_response", "response": { "request_id": msg.get("request_id"), "subtype": "success", "commands": [], - "output_style": "default" - } + "output_style": "default", + }, } break except (json.JSONDecodeError, KeyError, AttributeError): pass - + # Then yield the actual messages yield { "type": "assistant",