This commit is contained in:
Dickson Tsai 2025-08-29 19:40:24 -07:00
parent d40261c8d5
commit b3718cde80
No known key found for this signature in database
6 changed files with 80 additions and 30 deletions

View file

@ -31,9 +31,7 @@ class InternalClient:
if transport is not None:
chosen_transport = transport
else:
chosen_transport = SubprocessCLITransport(
prompt=prompt, options=options
)
chosen_transport = SubprocessCLITransport(prompt=prompt, options=options)
# Connect transport
await chosen_transport.connect()
@ -59,6 +57,7 @@ class InternalClient:
if isinstance(prompt, AsyncIterable):
# Start streaming in background
import asyncio
asyncio.create_task(query.stream_input(prompt))
# For string prompts, the prompt is already passed via CLI args

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class Query:
"""Handles bidirectional control protocol on top of Transport.
This class manages:
- Control request/response routing
- Hook callbacks
@ -27,11 +27,14 @@ class Query:
self,
transport: Transport,
is_streaming_mode: bool,
can_use_tool: Callable[[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]] | None = None,
can_use_tool: Callable[
[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]
]
| None = None,
hooks: dict[str, list[dict[str, Any]]] | None = None,
):
"""Initialize Query with transport and callbacks.
Args:
transport: Low-level transport for I/O
is_streaming_mode: Whether using streaming (bidirectional) mode
@ -58,7 +61,7 @@ class Query:
async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.
Returns:
Initialize response with supported commands, or None if not streaming
"""
@ -78,10 +81,12 @@ class Query:
self.next_callback_id += 1
self.hook_callbacks[callback_id] = callback
callback_ids.append(callback_id)
hooks_config[event].append({
"matcher": matcher.get("matcher"),
"hookCallbackIds": callback_ids,
})
hooks_config[event].append(
{
"matcher": matcher.get("matcher"),
"hookCallbackIds": callback_ids,
}
)
# Send initialize request
request = {
@ -115,7 +120,9 @@ class Query:
if request_id in self.pending_control_responses:
future = self.pending_control_responses.pop(request_id)
if response.get("subtype") == "error":
future.set_exception(Exception(response.get("error", "Unknown error")))
future.set_exception(
Exception(response.get("error", "Unknown error"))
)
else:
future.set_result(response)
continue
@ -161,7 +168,7 @@ class Query:
{
"signal": None, # TODO: Add abort signal support
"suggestions": request_data.get("permission_suggestions"),
}
},
)
elif subtype == "hook_callback":
@ -174,7 +181,7 @@ class Query:
response_data = await callback(
request_data.get("input"),
request_data.get("tool_use_id"),
{"signal": None} # TODO: Add abort signal support
{"signal": None}, # TODO: Add abort signal support
)
else:
@ -187,7 +194,7 @@ class Query:
"subtype": "success",
"request_id": request_id,
"response": response_data,
}
},
}
await self.transport.write(json.dumps(response) + "\n")
@ -199,7 +206,7 @@ class Query:
"subtype": "error",
"request_id": request_id,
"error": str(e),
}
},
}
await self.transport.write(json.dumps(response) + "\n")
@ -240,10 +247,12 @@ class Query:
async def set_permission_mode(self, mode: str) -> None:
"""Change permission mode."""
await self._send_control_request({
"subtype": "set_permission_mode",
"mode": mode,
})
await self._send_control_request(
{
"subtype": "set_permission_mode",
"mode": mode,
}
)
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
"""Stream input messages to transport."""

View file

@ -12,16 +12,25 @@ class Transport(ABC):
(e.g., remote Claude Code connections). The Claude Code team may change or
or remove this abstract class in any future release. Custom implementations
must be updated to match interface changes.
This is a low-level transport interface that handles raw I/O with the Claude
process or service. The Query class builds on top of this to implement the
control protocol and message routing.
"""
@abstractmethod
async def connect(self) -> None:
"""Connect the transport and prepare for communication.
For subprocess transports, this starts the process.
For network transports, this establishes the connection.
"""
pass
@abstractmethod
async def write(self, data: str) -> None:
"""Write raw data to the transport.
Args:
data: Raw string data to write (typically JSON + newline)
"""
@ -30,7 +39,7 @@ class Transport(ABC):
@abstractmethod
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport.
Yields:
Parsed JSON messages from the transport
"""
@ -44,7 +53,7 @@ class Transport(ABC):
@abstractmethod
def is_ready(self) -> bool:
"""Check if transport is ready for communication.
Returns:
True if transport is ready to send/receive messages
"""

View file

@ -217,6 +217,7 @@ class SubprocessCLITransport(Transport):
if self._process.returncode is None:
from contextlib import suppress
with suppress(ProcessLookupError):
self._process.terminate()
# Note: We can't use async wait here since close() is sync
@ -251,6 +252,7 @@ class SubprocessCLITransport(Transport):
self._stdin_stream = None
if self._process and self._process.stdin:
from contextlib import suppress
with suppress(Exception):
# Mark stdin as closed - actual close will happen during cleanup
pass
@ -350,6 +352,10 @@ class SubprocessCLITransport(Transport):
def is_ready(self) -> bool:
"""Check if transport is ready for communication."""
return self._ready and self._process is not None and self._process.returncode is None
return (
self._ready
and self._process is not None
and self._process.returncode is None
)
# Remove interrupt and control request methods - these now belong in Query class

View file

@ -138,6 +138,7 @@ class ClaudeSDKClient:
# If we have an initial prompt stream, start streaming it
if prompt is not None and isinstance(prompt, AsyncIterable):
import asyncio
asyncio.create_task(self._query.stream_input(prompt))
async def receive_messages(self) -> AsyncIterator[Message]:
@ -187,18 +188,18 @@ class ClaudeSDKClient:
if not self._query:
raise CLIConnectionError("Not connected. Call connect() first.")
await self._query.interrupt()
async def get_server_info(self) -> dict[str, Any] | None:
"""Get server initialization info including available commands and output styles.
Returns initialization information from the Claude Code server including:
- Available commands (slash commands, system commands, etc.)
- Current and available output styles
- Server capabilities
Returns:
Dictionary with server info, or None if not in streaming mode
Example:
```python
async with ClaudeSDKClient() as client:
@ -211,7 +212,7 @@ class ClaudeSDKClient:
if not self._query:
raise CLIConnectionError("Not connected. Call connect() first.")
# Return the initialization result that was already obtained during connect
return getattr(self._query, '_initialization_result', None)
return getattr(self._query, "_initialization_result", None)
async def receive_response(self) -> AsyncIterator[Message]:
"""

View file

@ -141,3 +141,29 @@ class ClaudeCodeOptions:
extra_args: dict[str, str | None] = field(
default_factory=dict
) # Pass arbitrary CLI flags
# Control protocol types for initialization
@dataclass
class InitializationMessage:
"""Initialization message from the CLI."""
type: Literal["control_response"]
response: dict[str, Any]
@dataclass
class ControlRequest:
"""Control request message."""
type: Literal["control_request"]
request_id: str
request: dict[str, Any]
@dataclass
class ControlResponse:
"""Control response message."""
type: Literal["control_response"]
response: dict[str, Any]