From 05e932d2d19fc217f5701e1cf1ec45c76aaacd62 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Tue, 2 Sep 2025 06:35:00 +0900 Subject: [PATCH] Use anyio, not asyncio, in src --- src/claude_code_sdk/_internal/client.py | 6 +- src/claude_code_sdk/_internal/query.py | 92 +++++++++++++++---------- src/claude_code_sdk/client.py | 5 +- src/claude_code_sdk/types.py | 15 ++-- tests/test_streaming_client.py | 52 +++++++++++--- 5 files changed, 110 insertions(+), 60 deletions(-) diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index 373695d..6b5331f 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -1,6 +1,5 @@ """Internal client implementation.""" -import asyncio from collections.abc import AsyncIterable, AsyncIterator from typing import Any @@ -55,9 +54,10 @@ class InternalClient: await query.initialize() # Stream input if it's an AsyncIterable - if isinstance(prompt, AsyncIterable): + if isinstance(prompt, AsyncIterable) and query._tg: # Start streaming in background - asyncio.create_task(query.stream_input(prompt)) + # Create a task that will run in the background + query._tg.start_soon(query.stream_input, prompt) # For string prompts, the prompt is already passed via CLI args # Yield parsed messages diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index ce645b2..cf0b6d8 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -1,6 +1,5 @@ """Query class for handling bidirectional control protocol.""" -import asyncio import json import logging import os @@ -8,6 +7,8 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from contextlib import suppress from typing import Any +import anyio + from ..types import ( SDKControlPermissionRequest, SDKControlRequest, @@ -54,14 +55,17 @@ class Query: self.hooks = hooks or {} # Control protocol state - self.pending_control_responses: dict[str, asyncio.Future[dict[str, Any]]] = {} + self.pending_control_responses: dict[str, anyio.Event] = {} + self.pending_control_results: dict[str, dict[str, Any] | Exception] = {} self.hook_callbacks: dict[str, Callable[..., Any]] = {} self.next_callback_id = 0 self._request_counter = 0 # Message stream - self._message_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() - self._read_task: asyncio.Task[None] | None = None + self._message_send, self._message_receive = anyio.create_memory_object_stream[ + dict[str, Any] + ](max_buffer_size=100) + self._tg: anyio.abc.TaskGroup | None = None self._initialized = False self._closed = False self._initialization_result: dict[str, Any] | None = None @@ -108,8 +112,10 @@ class Query: async def start(self) -> None: """Start reading messages from transport.""" - if self._read_task is None: - self._read_task = asyncio.create_task(self._read_messages()) + if self._tg is None: + self._tg = anyio.create_task_group() + await self._tg.__aenter__() + self._tg.start_soon(self._read_messages) async def _read_messages(self) -> None: """Read messages from transport and route them.""" @@ -125,20 +131,22 @@ class Query: response = message.get("response", {}) request_id = response.get("request_id") if request_id in self.pending_control_responses: - future = self.pending_control_responses.pop(request_id) + event = self.pending_control_responses[request_id] if response.get("subtype") == "error": - future.set_exception( - Exception(response.get("error", "Unknown error")) + self.pending_control_results[request_id] = Exception( + response.get("error", "Unknown error") ) else: - future.set_result(response) + self.pending_control_results[request_id] = response + event.set() continue elif msg_type == "control_request": # Handle incoming control requests from CLI # Cast message to SDKControlRequest for type safety request: SDKControlRequest = message # type: ignore[assignment] - asyncio.create_task(self._handle_control_request(request)) + if self._tg: + self._tg.start_soon(self._handle_control_request, request) continue elif msg_type == "control_cancel_request": @@ -146,20 +154,20 @@ class Query: # TODO: Implement cancellation support continue - # Regular SDK messages go to the queue - await self._message_queue.put(message) + # Regular SDK messages go to the stream + await self._message_send.send(message) - except asyncio.CancelledError: + except anyio.get_cancelled_exc_class(): # Task was cancelled - this is expected behavior logger.debug("Read task cancelled") raise # Re-raise to properly handle cancellation except Exception as e: logger.error(f"Fatal error in message reader: {e}") - # Put error in queue so iterators can handle it - await self._message_queue.put({"type": "error", "error": str(e)}) + # Put error in stream so iterators can handle it + await self._message_send.send({"type": "error", "error": str(e)}) finally: # Always signal end of stream - await self._message_queue.put({"type": "end"}) + await self._message_send.send({"type": "end"}) async def _handle_control_request(self, request: SDKControlRequest) -> None: """Handle incoming control request from CLI.""" @@ -234,9 +242,9 @@ class Query: self._request_counter += 1 request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}" - # Create future for response - future: asyncio.Future[dict[str, Any]] = asyncio.Future() - self.pending_control_responses[request_id] = future + # Create event for response + event = anyio.Event() + self.pending_control_responses[request_id] = event # Build and send request control_request = { @@ -249,11 +257,20 @@ class Query: # Wait for response try: - response = await asyncio.wait_for(future, timeout=60.0) - result = response.get("response", {}) - return result if isinstance(result, dict) else {} - except asyncio.TimeoutError as e: + with anyio.fail_after(60.0): + await event.wait() + + result = self.pending_control_results.pop(request_id) self.pending_control_responses.pop(request_id, None) + + if isinstance(result, Exception): + raise result + + response_data = result.get("response", {}) + return response_data if isinstance(response_data, dict) else {} + except TimeoutError as e: + self.pending_control_responses.pop(request_id, None) + self.pending_control_results.pop(request_id, None) raise Exception(f"Control request timeout: {request.get('subtype')}") from e async def interrupt(self) -> None: @@ -283,25 +300,24 @@ class Query: async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: """Receive SDK messages (not control messages).""" - while not self._closed: - message = await self._message_queue.get() + async with self._message_receive: + async for message in self._message_receive: + # Check for special messages + if message.get("type") == "end": + break + elif message.get("type") == "error": + raise Exception(message.get("error", "Unknown error")) - # Check for special messages - if message.get("type") == "end": - break - elif message.get("type") == "error": - raise Exception(message.get("error", "Unknown error")) - - yield message + yield message async def close(self) -> None: """Close the query and transport.""" self._closed = True - if self._read_task and not self._read_task.done(): - self._read_task.cancel() - # Wait for task to complete cancellation - with suppress(asyncio.CancelledError): - await self._read_task + if self._tg: + self._tg.cancel_scope.cancel() + # Wait for task group to complete cancellation + with suppress(anyio.get_cancelled_exc_class()): + await self._tg.__aexit__(None, None, None) await self.transport.close() # Make Query an async iterator diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 2875d70..3cfeb42 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -1,6 +1,5 @@ """Claude SDK Client for interacting with Claude Code.""" -import asyncio import json import os from collections.abc import AsyncIterable, AsyncIterator @@ -138,8 +137,8 @@ class ClaudeSDKClient: await self._query.initialize() # If we have an initial prompt stream, start streaming it - if prompt is not None and isinstance(prompt, AsyncIterable): - asyncio.create_task(self._query.stream_input(prompt)) + if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg: + self._query._tg.start_soon(self._query.stream_input, prompt) async def receive_messages(self) -> AsyncIterator[Message]: """Receive all messages from Claude.""" diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index 1faf59f..1f61ec6 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -194,18 +194,19 @@ class SDKControlRequest(TypedDict): | SDKControlMcpMessageRequest ) + class ControlResponse(TypedDict): - subtype: Literal['success'] - request_id: str - response: dict[str, Any] | None + subtype: Literal["success"] + request_id: str + response: dict[str, Any] | None class ControlErrorResponse(TypedDict): - subtype: Literal['error'] - request_id: str - error: str + subtype: Literal["error"] + request_id: str + error: str class SDKControlResponse(TypedDict): - type: Literal['control_response'] + type: Literal["control_response"] response: ControlResponse | ControlErrorResponse diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 15b5250..821ff96 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -512,15 +512,19 @@ class TestClaudeSDKClientStreaming: data = call[0][0] try: msg = json.loads(data.strip()) - if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize": + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): yield { "type": "control_response", "response": { "request_id": msg.get("request_id"), "subtype": "success", "commands": [], - "output_style": "default" - } + "output_style": "default", + }, } break except (json.JSONDecodeError, KeyError, AttributeError): @@ -592,9 +596,35 @@ while True: line = sys.stdin.readline() if not line: break - stdin_messages.append(line.strip()) -# Verify we got 2 messages + try: + msg = json.loads(line.strip()) + # Handle control requests + if msg.get("type") == "control_request": + request_id = msg.get("request_id") + request = msg.get("request", {}) + + # Send control response for initialize + if request.get("subtype") == "initialize": + response = { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": { + "commands": [], + "output_style": "default" + } + } + } + print(json.dumps(response)) + sys.stdout.flush() + else: + stdin_messages.append(line.strip()) + except: + stdin_messages.append(line.strip()) + +# Verify we got 2 user messages assert len(stdin_messages) == 2 assert '"First"' in stdin_messages[0] assert '"Second"' in stdin_messages[1] @@ -674,7 +704,7 @@ class TestClaudeSDKClientEdgeCases: # Create a new mock transport for each call mock_transport_class.side_effect = [ create_mock_transport(), - create_mock_transport() + create_mock_transport(), ] client = ClaudeSDKClient() @@ -736,15 +766,19 @@ class TestClaudeSDKClientEdgeCases: data = call[0][0] try: msg = json.loads(data.strip()) - if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize": + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") + == "initialize" + ): yield { "type": "control_response", "response": { "request_id": msg.get("request_id"), "subtype": "success", "commands": [], - "output_style": "default" - } + "output_style": "default", + }, } break except (json.JSONDecodeError, KeyError, AttributeError):