mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
mypy
This commit is contained in:
parent
d40261c8d5
commit
b3718cde80
6 changed files with 80 additions and 30 deletions
|
|
@ -31,9 +31,7 @@ class InternalClient:
|
|||
if transport is not None:
|
||||
chosen_transport = transport
|
||||
else:
|
||||
chosen_transport = SubprocessCLITransport(
|
||||
prompt=prompt, options=options
|
||||
)
|
||||
chosen_transport = SubprocessCLITransport(prompt=prompt, options=options)
|
||||
|
||||
# Connect transport
|
||||
await chosen_transport.connect()
|
||||
|
|
@ -59,6 +57,7 @@ class InternalClient:
|
|||
if isinstance(prompt, AsyncIterable):
|
||||
# Start streaming in background
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(query.stream_input(prompt))
|
||||
# For string prompts, the prompt is already passed via CLI args
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class Query:
|
||||
"""Handles bidirectional control protocol on top of Transport.
|
||||
|
||||
|
||||
This class manages:
|
||||
- Control request/response routing
|
||||
- Hook callbacks
|
||||
|
|
@ -27,11 +27,14 @@ class Query:
|
|||
self,
|
||||
transport: Transport,
|
||||
is_streaming_mode: bool,
|
||||
can_use_tool: Callable[[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]] | None = None,
|
||||
can_use_tool: Callable[
|
||||
[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]
|
||||
]
|
||||
| None = None,
|
||||
hooks: dict[str, list[dict[str, Any]]] | None = None,
|
||||
):
|
||||
"""Initialize Query with transport and callbacks.
|
||||
|
||||
|
||||
Args:
|
||||
transport: Low-level transport for I/O
|
||||
is_streaming_mode: Whether using streaming (bidirectional) mode
|
||||
|
|
@ -58,7 +61,7 @@ class Query:
|
|||
|
||||
async def initialize(self) -> dict[str, Any] | None:
|
||||
"""Initialize control protocol if in streaming mode.
|
||||
|
||||
|
||||
Returns:
|
||||
Initialize response with supported commands, or None if not streaming
|
||||
"""
|
||||
|
|
@ -78,10 +81,12 @@ class Query:
|
|||
self.next_callback_id += 1
|
||||
self.hook_callbacks[callback_id] = callback
|
||||
callback_ids.append(callback_id)
|
||||
hooks_config[event].append({
|
||||
"matcher": matcher.get("matcher"),
|
||||
"hookCallbackIds": callback_ids,
|
||||
})
|
||||
hooks_config[event].append(
|
||||
{
|
||||
"matcher": matcher.get("matcher"),
|
||||
"hookCallbackIds": callback_ids,
|
||||
}
|
||||
)
|
||||
|
||||
# Send initialize request
|
||||
request = {
|
||||
|
|
@ -115,7 +120,9 @@ class Query:
|
|||
if request_id in self.pending_control_responses:
|
||||
future = self.pending_control_responses.pop(request_id)
|
||||
if response.get("subtype") == "error":
|
||||
future.set_exception(Exception(response.get("error", "Unknown error")))
|
||||
future.set_exception(
|
||||
Exception(response.get("error", "Unknown error"))
|
||||
)
|
||||
else:
|
||||
future.set_result(response)
|
||||
continue
|
||||
|
|
@ -161,7 +168,7 @@ class Query:
|
|||
{
|
||||
"signal": None, # TODO: Add abort signal support
|
||||
"suggestions": request_data.get("permission_suggestions"),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
elif subtype == "hook_callback":
|
||||
|
|
@ -174,7 +181,7 @@ class Query:
|
|||
response_data = await callback(
|
||||
request_data.get("input"),
|
||||
request_data.get("tool_use_id"),
|
||||
{"signal": None} # TODO: Add abort signal support
|
||||
{"signal": None}, # TODO: Add abort signal support
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
@ -187,7 +194,7 @@ class Query:
|
|||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": response_data,
|
||||
}
|
||||
},
|
||||
}
|
||||
await self.transport.write(json.dumps(response) + "\n")
|
||||
|
||||
|
|
@ -199,7 +206,7 @@ class Query:
|
|||
"subtype": "error",
|
||||
"request_id": request_id,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
}
|
||||
await self.transport.write(json.dumps(response) + "\n")
|
||||
|
||||
|
|
@ -240,10 +247,12 @@ class Query:
|
|||
|
||||
async def set_permission_mode(self, mode: str) -> None:
|
||||
"""Change permission mode."""
|
||||
await self._send_control_request({
|
||||
"subtype": "set_permission_mode",
|
||||
"mode": mode,
|
||||
})
|
||||
await self._send_control_request(
|
||||
{
|
||||
"subtype": "set_permission_mode",
|
||||
"mode": mode,
|
||||
}
|
||||
)
|
||||
|
||||
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
|
||||
"""Stream input messages to transport."""
|
||||
|
|
|
|||
|
|
@ -12,16 +12,25 @@ class Transport(ABC):
|
|||
(e.g., remote Claude Code connections). The Claude Code team may change or
|
||||
or remove this abstract class in any future release. Custom implementations
|
||||
must be updated to match interface changes.
|
||||
|
||||
|
||||
This is a low-level transport interface that handles raw I/O with the Claude
|
||||
process or service. The Query class builds on top of this to implement the
|
||||
control protocol and message routing.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Connect the transport and prepare for communication.
|
||||
|
||||
For subprocess transports, this starts the process.
|
||||
For network transports, this establishes the connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, data: str) -> None:
|
||||
"""Write raw data to the transport.
|
||||
|
||||
|
||||
Args:
|
||||
data: Raw string data to write (typically JSON + newline)
|
||||
"""
|
||||
|
|
@ -30,7 +39,7 @@ class Transport(ABC):
|
|||
@abstractmethod
|
||||
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Read and parse messages from the transport.
|
||||
|
||||
|
||||
Yields:
|
||||
Parsed JSON messages from the transport
|
||||
"""
|
||||
|
|
@ -44,7 +53,7 @@ class Transport(ABC):
|
|||
@abstractmethod
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if transport is ready for communication.
|
||||
|
||||
|
||||
Returns:
|
||||
True if transport is ready to send/receive messages
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -217,6 +217,7 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
if self._process.returncode is None:
|
||||
from contextlib import suppress
|
||||
|
||||
with suppress(ProcessLookupError):
|
||||
self._process.terminate()
|
||||
# Note: We can't use async wait here since close() is sync
|
||||
|
|
@ -251,6 +252,7 @@ class SubprocessCLITransport(Transport):
|
|||
self._stdin_stream = None
|
||||
if self._process and self._process.stdin:
|
||||
from contextlib import suppress
|
||||
|
||||
with suppress(Exception):
|
||||
# Mark stdin as closed - actual close will happen during cleanup
|
||||
pass
|
||||
|
|
@ -350,6 +352,10 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if transport is ready for communication."""
|
||||
return self._ready and self._process is not None and self._process.returncode is None
|
||||
return (
|
||||
self._ready
|
||||
and self._process is not None
|
||||
and self._process.returncode is None
|
||||
)
|
||||
|
||||
# Remove interrupt and control request methods - these now belong in Query class
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ class ClaudeSDKClient:
|
|||
# If we have an initial prompt stream, start streaming it
|
||||
if prompt is not None and isinstance(prompt, AsyncIterable):
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(self._query.stream_input(prompt))
|
||||
|
||||
async def receive_messages(self) -> AsyncIterator[Message]:
|
||||
|
|
@ -187,18 +188,18 @@ class ClaudeSDKClient:
|
|||
if not self._query:
|
||||
raise CLIConnectionError("Not connected. Call connect() first.")
|
||||
await self._query.interrupt()
|
||||
|
||||
|
||||
async def get_server_info(self) -> dict[str, Any] | None:
|
||||
"""Get server initialization info including available commands and output styles.
|
||||
|
||||
|
||||
Returns initialization information from the Claude Code server including:
|
||||
- Available commands (slash commands, system commands, etc.)
|
||||
- Current and available output styles
|
||||
- Server capabilities
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with server info, or None if not in streaming mode
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
async with ClaudeSDKClient() as client:
|
||||
|
|
@ -211,7 +212,7 @@ class ClaudeSDKClient:
|
|||
if not self._query:
|
||||
raise CLIConnectionError("Not connected. Call connect() first.")
|
||||
# Return the initialization result that was already obtained during connect
|
||||
return getattr(self._query, '_initialization_result', None)
|
||||
return getattr(self._query, "_initialization_result", None)
|
||||
|
||||
async def receive_response(self) -> AsyncIterator[Message]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -141,3 +141,29 @@ class ClaudeCodeOptions:
|
|||
extra_args: dict[str, str | None] = field(
|
||||
default_factory=dict
|
||||
) # Pass arbitrary CLI flags
|
||||
|
||||
|
||||
# Control protocol types for initialization
|
||||
@dataclass
|
||||
class InitializationMessage:
|
||||
"""Initialization message from the CLI."""
|
||||
|
||||
type: Literal["control_response"]
|
||||
response: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlRequest:
|
||||
"""Control request message."""
|
||||
|
||||
type: Literal["control_request"]
|
||||
request_id: str
|
||||
request: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlResponse:
|
||||
"""Control response message."""
|
||||
|
||||
type: Literal["control_response"]
|
||||
response: dict[str, Any]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue