Merge branch 'dickson/control' into feat/sdk-mcp-server-support

This commit is contained in:
Kashyap Murali 2025-09-01 18:33:33 -07:00
commit 44f0d05fb7
No known key found for this signature in database
5 changed files with 184 additions and 82 deletions

View file

@ -1,6 +1,5 @@
"""Internal client implementation."""
import asyncio
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
@ -63,9 +62,10 @@ class InternalClient:
await query.initialize()
# Stream input if it's an AsyncIterable
if isinstance(prompt, AsyncIterable):
if isinstance(prompt, AsyncIterable) and query._tg:
# Start streaming in background
asyncio.create_task(query.stream_input(prompt))
# Create a task that will run in the background
query._tg.start_soon(query.stream_input, prompt)
# For string prompts, the prompt is already passed via CLI args
# Yield parsed messages

View file

@ -1,6 +1,5 @@
"""Query class for handling bidirectional control protocol."""
import asyncio
import json
import logging
import os
@ -8,6 +7,14 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from contextlib import suppress
from typing import TYPE_CHECKING, Any
import anyio
from ..types import (
SDKControlPermissionRequest,
SDKControlRequest,
SDKControlResponse,
SDKHookCallbackRequest,
)
from .transport import Transport
if TYPE_CHECKING:
@ -54,14 +61,17 @@ class Query:
self.sdk_mcp_servers = sdk_mcp_servers or {}
# Control protocol state
self.pending_control_responses: dict[str, asyncio.Future[dict[str, Any]]] = {}
self.pending_control_responses: dict[str, anyio.Event] = {}
self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
self.hook_callbacks: dict[str, Callable[..., Any]] = {}
self.next_callback_id = 0
self._request_counter = 0
# Message stream
self._message_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._read_task: asyncio.Task[None] | None = None
self._message_send, self._message_receive = anyio.create_memory_object_stream[
dict[str, Any]
](max_buffer_size=100)
self._tg: anyio.abc.TaskGroup | None = None
self._initialized = False
self._closed = False
self._initialization_result: dict[str, Any] | None = None
@ -108,8 +118,10 @@ class Query:
async def start(self) -> None:
"""Start reading messages from transport."""
if self._read_task is None:
self._read_task = asyncio.create_task(self._read_messages())
if self._tg is None:
self._tg = anyio.create_task_group()
await self._tg.__aenter__()
self._tg.start_soon(self._read_messages)
async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
@ -125,18 +137,22 @@ class Query:
response = message.get("response", {})
request_id = response.get("request_id")
if request_id in self.pending_control_responses:
future = self.pending_control_responses.pop(request_id)
event = self.pending_control_responses[request_id]
if response.get("subtype") == "error":
future.set_exception(
Exception(response.get("error", "Unknown error"))
self.pending_control_results[request_id] = Exception(
response.get("error", "Unknown error")
)
else:
future.set_result(response)
self.pending_control_results[request_id] = response
event.set()
continue
elif msg_type == "control_request":
# Handle incoming control requests from CLI
asyncio.create_task(self._handle_control_request(message))
# Cast message to SDKControlRequest for type safety
request: SDKControlRequest = message # type: ignore[assignment]
if self._tg:
self._tg.start_soon(self._handle_control_request, request)
continue
elif msg_type == "control_cancel_request":
@ -144,47 +160,49 @@ class Query:
# TODO: Implement cancellation support
continue
# Regular SDK messages go to the queue
await self._message_queue.put(message)
# Regular SDK messages go to the stream
await self._message_send.send(message)
except asyncio.CancelledError:
except anyio.get_cancelled_exc_class():
# Task was cancelled - this is expected behavior
logger.debug("Read task cancelled")
raise # Re-raise to properly handle cancellation
except Exception as e:
logger.error(f"Fatal error in message reader: {e}")
# Put error in queue so iterators can handle it
await self._message_queue.put({"type": "error", "error": str(e)})
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
finally:
# Always signal end of stream
await self._message_queue.put({"type": "end"})
await self._message_send.send({"type": "end"})
async def _handle_control_request(self, request: dict[str, Any]) -> None:
async def _handle_control_request(self, request: SDKControlRequest) -> None:
"""Handle incoming control request from CLI."""
request_id = request.get("request_id")
request_data = request.get("request", {})
subtype = request_data.get("subtype")
request_id = request["request_id"]
request_data = request["request"]
subtype = request_data["subtype"]
try:
response_data = {}
if subtype == "can_use_tool":
permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment]
# Handle tool permission request
if not self.can_use_tool:
raise Exception("canUseTool callback is not provided")
response_data = await self.can_use_tool(
request_data.get("tool_name"),
request_data.get("input"),
permission_request["tool_name"],
permission_request["input"],
{
"signal": None, # TODO: Add abort signal support
"suggestions": request_data.get("permission_suggestions"),
"suggestions": permission_request.get("permission_suggestions"),
},
)
elif subtype == "hook_callback":
hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment]
# Handle hook callback
callback_id = request_data.get("callback_id")
callback_id = hook_callback_request["callback_id"]
callback = self.hook_callbacks.get(callback_id)
if not callback:
raise Exception(f"No hook callback found for ID: {callback_id}")
@ -211,7 +229,7 @@ class Query:
raise Exception(f"Unsupported control request subtype: {subtype}")
# Send success response
response = {
success_response: SDKControlResponse = {
"type": "control_response",
"response": {
"subtype": "success",
@ -219,11 +237,11 @@ class Query:
"response": response_data,
},
}
await self.transport.write(json.dumps(response) + "\n")
await self.transport.write(json.dumps(success_response) + "\n")
except Exception as e:
# Send error response
response = {
error_response: SDKControlResponse = {
"type": "control_response",
"response": {
"subtype": "error",
@ -231,7 +249,7 @@ class Query:
"error": str(e),
},
}
await self.transport.write(json.dumps(response) + "\n")
await self.transport.write(json.dumps(error_response) + "\n")
async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send control request to CLI and wait for response."""
@ -242,9 +260,9 @@ class Query:
self._request_counter += 1
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
# Create future for response
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
self.pending_control_responses[request_id] = future
# Create event for response
event = anyio.Event()
self.pending_control_responses[request_id] = event
# Build and send request
control_request = {
@ -257,11 +275,20 @@ class Query:
# Wait for response
try:
response = await asyncio.wait_for(future, timeout=60.0)
result = response.get("response", {})
return result if isinstance(result, dict) else {}
except asyncio.TimeoutError as e:
with anyio.fail_after(60.0):
await event.wait()
result = self.pending_control_results.pop(request_id)
self.pending_control_responses.pop(request_id, None)
if isinstance(result, Exception):
raise result
response_data = result.get("response", {})
return response_data if isinstance(response_data, dict) else {}
except TimeoutError as e:
self.pending_control_responses.pop(request_id, None)
self.pending_control_results.pop(request_id, None)
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict:
@ -350,25 +377,24 @@ class Query:
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Receive SDK messages (not control messages)."""
while not self._closed:
message = await self._message_queue.get()
async with self._message_receive:
async for message in self._message_receive:
# Check for special messages
if message.get("type") == "end":
break
elif message.get("type") == "error":
raise Exception(message.get("error", "Unknown error"))
# Check for special messages
if message.get("type") == "end":
break
elif message.get("type") == "error":
raise Exception(message.get("error", "Unknown error"))
yield message
yield message
async def close(self) -> None:
"""Close the query and transport."""
self._closed = True
if self._read_task and not self._read_task.done():
self._read_task.cancel()
# Wait for task to complete cancellation
with suppress(asyncio.CancelledError):
await self._read_task
if self._tg:
self._tg.cancel_scope.cancel()
# Wait for task group to complete cancellation
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)
await self.transport.close()
# Make Query an async iterator

View file

@ -1,6 +1,5 @@
"""Claude SDK Client for interacting with Claude Code."""
import asyncio
import json
import os
from collections.abc import AsyncIterable, AsyncIterator
@ -138,8 +137,8 @@ class ClaudeSDKClient:
await self._query.initialize()
# If we have an initial prompt stream, start streaming it
if prompt is not None and isinstance(prompt, AsyncIterable):
asyncio.create_task(self._query.stream_input(prompt))
if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg:
self._query._tg.start_soon(self._query.stream_input, prompt)
async def receive_messages(self) -> AsyncIterator[Message]:
"""Receive all messages from Claude."""

View file

@ -156,27 +156,70 @@ class ClaudeCodeOptions:
) # Pass arbitrary CLI flags
# Control protocol types for initialization
@dataclass
class InitializationMessage:
"""Initialization message from the CLI."""
type: Literal["control_response"]
response: dict[str, Any]
# SDK Control Protocol
class SDKControlInterruptRequest(TypedDict):
subtype: Literal["interrupt"]
@dataclass
class ControlRequest:
"""Control request message."""
class SDKControlPermissionRequest(TypedDict):
subtype: Literal["can_use_tool"]
tool_name: str
input: dict[str, Any]
# TODO: Add PermissionUpdate type here
permission_suggestions: list[Any] | None
blocked_path: str | None
class SDKControlInitializeRequest(TypedDict):
subtype: Literal["initialize"]
# TODO: Use HookEvent names as the key.
hooks: dict[str, Any] | None
class SDKControlSetPermissionModeRequest(TypedDict):
subtype: Literal["set_permission_mode"]
# TODO: Add PermissionMode
mode: str
class SDKHookCallbackRequest(TypedDict):
subtype: Literal["hook_callback"]
callback_id: str
input: Any
tool_use_id: str | None
class SDKControlMcpMessageRequest(TypedDict):
subtype: Literal["mcp_message"]
server_name: str
message: Any
class SDKControlRequest(TypedDict):
type: Literal["control_request"]
request_id: str
request: dict[str, Any]
request: (
SDKControlInterruptRequest
| SDKControlPermissionRequest
| SDKControlInitializeRequest
| SDKControlSetPermissionModeRequest
| SDKHookCallbackRequest
| SDKControlMcpMessageRequest
)
@dataclass
class ControlResponse:
"""Control response message."""
class ControlResponse(TypedDict):
subtype: Literal["success"]
request_id: str
response: dict[str, Any] | None
class ControlErrorResponse(TypedDict):
subtype: Literal["error"]
request_id: str
error: str
class SDKControlResponse(TypedDict):
type: Literal["control_response"]
response: dict[str, Any]
response: ControlResponse | ControlErrorResponse

View file

@ -512,20 +512,24 @@ class TestClaudeSDKClientStreaming:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default"
}
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
await asyncio.sleep(0.1)
yield {
@ -592,9 +596,35 @@ while True:
line = sys.stdin.readline()
if not line:
break
stdin_messages.append(line.strip())
# Verify we got 2 messages
try:
msg = json.loads(line.strip())
# Handle control requests
if msg.get("type") == "control_request":
request_id = msg.get("request_id")
request = msg.get("request", {})
# Send control response for initialize
if request.get("subtype") == "initialize":
response = {
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": {
"commands": [],
"output_style": "default"
}
}
}
print(json.dumps(response))
sys.stdout.flush()
else:
stdin_messages.append(line.strip())
except:
stdin_messages.append(line.strip())
# Verify we got 2 user messages
assert len(stdin_messages) == 2
assert '"First"' in stdin_messages[0]
assert '"Second"' in stdin_messages[1]
@ -674,7 +704,7 @@ class TestClaudeSDKClientEdgeCases:
# Create a new mock transport for each call
mock_transport_class.side_effect = [
create_mock_transport(),
create_mock_transport()
create_mock_transport(),
]
client = ClaudeSDKClient()
@ -736,20 +766,24 @@ class TestClaudeSDKClientEdgeCases:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default"
}
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
yield {
"type": "assistant",