mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Remove unstable public APIs from SDK (#151)
Hide hooks, tool permission callbacks, and SDK MCP server APIs from public interface while keeping implementation code intact. These features are not yet stable and should not be documented or exposed to users. Changes: - Remove hook-related exports (HookCallback, HookContext, HookMatcher) from __all__ - Remove tool permission exports (CanUseTool, ToolPermissionContext) from __all__ - Remove SDK MCP exports (tool, create_sdk_mcp_server, SdkMcpTool) from __all__ - Delete examples using unstable APIs (tool_permission_callback.py, mcp_calculator.py) - Remove SDK MCP server documentation from README 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
99d13717d5
commit
e4feaf2e57
5 changed files with 1 additions and 727 deletions
84
README.md
84
README.md
|
|
@ -76,90 +76,6 @@ options = ClaudeCodeOptions(
|
|||
)
|
||||
```
|
||||
|
||||
### SDK MCP Servers (In-Process)
|
||||
|
||||
The SDK now supports in-process MCP servers that run directly within your Python application, eliminating the need for separate processes.
|
||||
|
||||
#### Creating a Simple Tool
|
||||
|
||||
```python
|
||||
from claude_code_sdk import tool, create_sdk_mcp_server
|
||||
|
||||
# Define a tool using the @tool decorator
|
||||
@tool("greet", "Greet a user", {"name": str})
|
||||
async def greet_user(args):
|
||||
return {
|
||||
"content": [
|
||||
{"type": "text", "text": f"Hello, {args['name']}!"}
|
||||
]
|
||||
}
|
||||
|
||||
# Create an SDK MCP server
|
||||
server = create_sdk_mcp_server(
|
||||
name="my-tools",
|
||||
version="1.0.0",
|
||||
tools=[greet_user]
|
||||
)
|
||||
|
||||
# Use it with Claude
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={"tools": server}
|
||||
)
|
||||
|
||||
async for message in query(prompt="Greet Alice", options=options):
|
||||
print(message)
|
||||
```
|
||||
|
||||
#### Benefits Over External MCP Servers
|
||||
|
||||
- **No subprocess management** - Runs in the same process as your application
|
||||
- **Better performance** - No IPC overhead for tool calls
|
||||
- **Simpler deployment** - Single Python process instead of multiple
|
||||
- **Easier debugging** - All code runs in the same process
|
||||
- **Type safety** - Direct Python function calls with type hints
|
||||
|
||||
#### Migration from External Servers
|
||||
|
||||
```python
|
||||
# BEFORE: External MCP server (separate process)
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={
|
||||
"calculator": {
|
||||
"type": "stdio",
|
||||
"command": "python",
|
||||
"args": ["-m", "calculator_server"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# AFTER: SDK MCP server (in-process)
|
||||
from my_tools import add, subtract # Your tool functions
|
||||
|
||||
calculator = create_sdk_mcp_server(
|
||||
name="calculator",
|
||||
tools=[add, subtract]
|
||||
)
|
||||
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={"calculator": calculator}
|
||||
)
|
||||
```
|
||||
|
||||
#### Mixed Server Support
|
||||
|
||||
You can use both SDK and external MCP servers together:
|
||||
|
||||
```python
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={
|
||||
"internal": sdk_server, # In-process SDK server
|
||||
"external": { # External subprocess server
|
||||
"type": "stdio",
|
||||
"command": "external-server"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
|
|
|
|||
|
|
@ -1,181 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Example: Calculator MCP Server.
|
||||
|
||||
This example demonstrates how to create an in-process MCP server with
|
||||
calculator tools using the Claude Code Python SDK.
|
||||
|
||||
Unlike external MCP servers that require separate processes, this server
|
||||
runs directly within your Python application, providing better performance
|
||||
and simpler deployment.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from claude_code_sdk import (
|
||||
ClaudeCodeOptions,
|
||||
create_sdk_mcp_server,
|
||||
query,
|
||||
tool,
|
||||
)
|
||||
|
||||
# Define calculator tools using the @tool decorator
|
||||
|
||||
@tool("add", "Add two numbers", {"a": float, "b": float})
|
||||
async def add_numbers(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add two numbers together."""
|
||||
result = args["a"] + args["b"]
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{args['a']} + {args['b']} = {result}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@tool("subtract", "Subtract one number from another", {"a": float, "b": float})
|
||||
async def subtract_numbers(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Subtract b from a."""
|
||||
result = args["a"] - args["b"]
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{args['a']} - {args['b']} = {result}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@tool("multiply", "Multiply two numbers", {"a": float, "b": float})
|
||||
async def multiply_numbers(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Multiply two numbers."""
|
||||
result = args["a"] * args["b"]
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{args['a']} × {args['b']} = {result}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@tool("divide", "Divide one number by another", {"a": float, "b": float})
|
||||
async def divide_numbers(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Divide a by b."""
|
||||
if args["b"] == 0:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Error: Division by zero is not allowed"
|
||||
}
|
||||
],
|
||||
"is_error": True
|
||||
}
|
||||
|
||||
result = args["a"] / args["b"]
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{args['a']} ÷ {args['b']} = {result}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@tool("sqrt", "Calculate square root", {"n": float})
|
||||
async def square_root(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Calculate the square root of a number."""
|
||||
n = args["n"]
|
||||
if n < 0:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Error: Cannot calculate square root of negative number {n}"
|
||||
}
|
||||
],
|
||||
"is_error": True
|
||||
}
|
||||
|
||||
import math
|
||||
result = math.sqrt(n)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"√{n} = {result}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@tool("power", "Raise a number to a power", {"base": float, "exponent": float})
|
||||
async def power(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Raise base to the exponent power."""
|
||||
result = args["base"] ** args["exponent"]
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{args['base']}^{args['exponent']} = {result}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run example calculations using the SDK MCP server."""
|
||||
|
||||
# Create the calculator server with all tools
|
||||
calculator = create_sdk_mcp_server(
|
||||
name="calculator",
|
||||
version="2.0.0",
|
||||
tools=[
|
||||
add_numbers,
|
||||
subtract_numbers,
|
||||
multiply_numbers,
|
||||
divide_numbers,
|
||||
square_root,
|
||||
power
|
||||
]
|
||||
)
|
||||
|
||||
# Configure Claude to use the calculator server
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={"calc": calculator},
|
||||
# Allow Claude to use calculator tools without permission prompts
|
||||
permission_mode="bypassPermissions"
|
||||
)
|
||||
|
||||
# Example prompts to demonstrate calculator usage
|
||||
prompts = [
|
||||
"Calculate 15 + 27",
|
||||
"What is 100 divided by 7?",
|
||||
"Calculate the square root of 144",
|
||||
"What is 2 raised to the power of 8?",
|
||||
"Calculate (12 + 8) * 3 - 10" # Complex calculation
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Prompt: {prompt}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
async for message in query(prompt=prompt, options=options):
|
||||
# Print the message content
|
||||
if hasattr(message, 'content'):
|
||||
for content_block in message.content:
|
||||
if hasattr(content_block, 'text'):
|
||||
print(f"Claude: {content_block.text}")
|
||||
elif hasattr(content_block, 'name'):
|
||||
print(f"Using tool: {content_block.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,158 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Example: Tool Permission Callbacks.
|
||||
|
||||
This example demonstrates how to use tool permission callbacks to control
|
||||
which tools Claude can use and modify their inputs.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from claude_code_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ClaudeSDKClient,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
ResultMessage,
|
||||
TextBlock,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
|
||||
# Track tool usage for demonstration
|
||||
tool_usage_log = []
|
||||
|
||||
|
||||
async def my_permission_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> PermissionResultAllow | PermissionResultDeny:
|
||||
"""Control tool permissions based on tool type and input."""
|
||||
|
||||
# Log the tool request
|
||||
tool_usage_log.append({
|
||||
"tool": tool_name,
|
||||
"input": input_data,
|
||||
"suggestions": context.suggestions
|
||||
})
|
||||
|
||||
print(f"\n🔧 Tool Permission Request: {tool_name}")
|
||||
print(f" Input: {json.dumps(input_data, indent=2)}")
|
||||
|
||||
# Always allow read operations
|
||||
if tool_name in ["Read", "Glob", "Grep"]:
|
||||
print(f" ✅ Automatically allowing {tool_name} (read-only operation)")
|
||||
return PermissionResultAllow()
|
||||
|
||||
# Deny write operations to system directories
|
||||
if tool_name in ["Write", "Edit", "MultiEdit"]:
|
||||
file_path = input_data.get("file_path", "")
|
||||
if file_path.startswith("/etc/") or file_path.startswith("/usr/"):
|
||||
print(f" ❌ Denying write to system directory: {file_path}")
|
||||
return PermissionResultDeny(
|
||||
message=f"Cannot write to system directory: {file_path}"
|
||||
)
|
||||
|
||||
# Redirect writes to a safe directory
|
||||
if not file_path.startswith("/tmp/") and not file_path.startswith("./"):
|
||||
safe_path = f"./safe_output/{file_path.split('/')[-1]}"
|
||||
print(f" ⚠️ Redirecting write from {file_path} to {safe_path}")
|
||||
modified_input = input_data.copy()
|
||||
modified_input["file_path"] = safe_path
|
||||
return PermissionResultAllow(
|
||||
updated_input=modified_input
|
||||
)
|
||||
|
||||
# Check dangerous bash commands
|
||||
if tool_name == "Bash":
|
||||
command = input_data.get("command", "")
|
||||
dangerous_commands = ["rm -rf", "sudo", "chmod 777", "dd if=", "mkfs"]
|
||||
|
||||
for dangerous in dangerous_commands:
|
||||
if dangerous in command:
|
||||
print(f" ❌ Denying dangerous command: {command}")
|
||||
return PermissionResultDeny(
|
||||
message=f"Dangerous command pattern detected: {dangerous}"
|
||||
)
|
||||
|
||||
# Allow but log the command
|
||||
print(f" ✅ Allowing bash command: {command}")
|
||||
return PermissionResultAllow()
|
||||
|
||||
# For all other tools, ask the user
|
||||
print(f" ❓ Unknown tool: {tool_name}")
|
||||
print(f" Input: {json.dumps(input_data, indent=6)}")
|
||||
user_input = input(" Allow this tool? (y/N): ").strip().lower()
|
||||
|
||||
if user_input in ("y", "yes"):
|
||||
return PermissionResultAllow()
|
||||
else:
|
||||
return PermissionResultDeny(
|
||||
message="User denied permission"
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run example with tool permission callbacks."""
|
||||
|
||||
print("=" * 60)
|
||||
print("Tool Permission Callback Example")
|
||||
print("=" * 60)
|
||||
print("\nThis example demonstrates how to:")
|
||||
print("1. Allow/deny tools based on type")
|
||||
print("2. Modify tool inputs for safety")
|
||||
print("3. Log tool usage")
|
||||
print("4. Prompt for unknown tools")
|
||||
print("=" * 60)
|
||||
|
||||
# Configure options with our callback
|
||||
options = ClaudeCodeOptions(
|
||||
can_use_tool=my_permission_callback,
|
||||
# Use default permission mode to ensure callbacks are invoked
|
||||
permission_mode="default",
|
||||
cwd="." # Set working directory
|
||||
)
|
||||
|
||||
# Create client and send a query that will use multiple tools
|
||||
async with ClaudeSDKClient(options) as client:
|
||||
print("\n📝 Sending query to Claude...")
|
||||
await client.query(
|
||||
"Please do the following:\n"
|
||||
"1. List the files in the current directory\n"
|
||||
"2. Create a simple Python hello world script at hello.py\n"
|
||||
"3. Run the script to test it"
|
||||
)
|
||||
|
||||
print("\n📨 Receiving response...")
|
||||
message_count = 0
|
||||
|
||||
async for message in client.receive_response():
|
||||
message_count += 1
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
# Print Claude's text responses
|
||||
for block in message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"\n💬 Claude: {block.text}")
|
||||
|
||||
elif isinstance(message, ResultMessage):
|
||||
print("\n✅ Task completed!")
|
||||
print(f" Duration: {message.duration_ms}ms")
|
||||
if message.total_cost_usd:
|
||||
print(f" Cost: ${message.total_cost_usd:.4f}")
|
||||
print(f" Messages processed: {message_count}")
|
||||
|
||||
# Print tool usage summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Tool Usage Summary")
|
||||
print("=" * 60)
|
||||
for i, usage in enumerate(tool_usage_log, 1):
|
||||
print(f"\n{i}. Tool: {usage['tool']}")
|
||||
print(f" Input: {json.dumps(usage['input'], indent=6)}")
|
||||
if usage['suggestions']:
|
||||
print(f" Suggestions: {usage['suggestions']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -16,12 +16,8 @@ from .client import ClaudeSDKClient
|
|||
from .query import query
|
||||
from .types import (
|
||||
AssistantMessage,
|
||||
CanUseTool,
|
||||
ClaudeCodeOptions,
|
||||
ContentBlock,
|
||||
HookCallback,
|
||||
HookContext,
|
||||
HookMatcher,
|
||||
McpSdkServerConfig,
|
||||
McpServerConfig,
|
||||
Message,
|
||||
|
|
@ -34,7 +30,6 @@ from .types import (
|
|||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolPermissionContext,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
|
|
@ -297,20 +292,11 @@ __all__ = [
|
|||
"ToolUseBlock",
|
||||
"ToolResultBlock",
|
||||
"ContentBlock",
|
||||
# Tool callbacks
|
||||
"CanUseTool",
|
||||
"ToolPermissionContext",
|
||||
# Permission results (keep these as they may be used by internal callbacks)
|
||||
"PermissionResult",
|
||||
"PermissionResultAllow",
|
||||
"PermissionResultDeny",
|
||||
"PermissionUpdate",
|
||||
"HookCallback",
|
||||
"HookContext",
|
||||
"HookMatcher",
|
||||
# MCP Server Support
|
||||
"create_sdk_mcp_server",
|
||||
"tool",
|
||||
"SdkMcpTool",
|
||||
# Errors
|
||||
"ClaudeSDKError",
|
||||
"CLIConnectionError",
|
||||
|
|
|
|||
|
|
@ -1,289 +0,0 @@
|
|||
"""Tests for tool permission callbacks and hook callbacks."""
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
ClaudeCodeOptions,
|
||||
HookContext,
|
||||
HookMatcher,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
from claude_code_sdk._internal.query import Query
|
||||
from claude_code_sdk._internal.transport import Transport
|
||||
|
||||
|
||||
class MockTransport(Transport):
|
||||
"""Mock transport for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.written_messages = []
|
||||
self.messages_to_read = []
|
||||
self._connected = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._connected = True
|
||||
|
||||
async def close(self) -> None:
|
||||
self._connected = False
|
||||
|
||||
async def write(self, data: str) -> None:
|
||||
self.written_messages.append(data)
|
||||
|
||||
async def end_input(self) -> None:
|
||||
pass
|
||||
|
||||
def read_messages(self):
|
||||
async def _read():
|
||||
for msg in self.messages_to_read:
|
||||
yield msg
|
||||
|
||||
return _read()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
|
||||
class TestToolPermissionCallbacks:
|
||||
"""Test tool permission callback functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_allow(self):
|
||||
"""Test callback that allows tool execution."""
|
||||
callback_invoked = False
|
||||
|
||||
async def allow_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
nonlocal callback_invoked
|
||||
callback_invoked = True
|
||||
assert tool_name == "TestTool"
|
||||
assert input_data == {"param": "value"}
|
||||
return PermissionResultAllow()
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=allow_callback,
|
||||
hooks=None,
|
||||
)
|
||||
|
||||
# Simulate control request
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-1",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "TestTool",
|
||||
"input": {"param": "value"},
|
||||
"permission_suggestions": [],
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check callback was invoked
|
||||
assert callback_invoked
|
||||
|
||||
# Check response was sent
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": true' in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_deny(self):
|
||||
"""Test callback that denies tool execution."""
|
||||
|
||||
async def deny_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultDeny:
|
||||
return PermissionResultDeny(message="Security policy violation")
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=deny_callback,
|
||||
hooks=None,
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-2",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "DangerousTool",
|
||||
"input": {"command": "rm -rf /"},
|
||||
"permission_suggestions": ["deny"],
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": false' in response
|
||||
assert '"reason": "Security policy violation"' in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_input_modification(self):
|
||||
"""Test callback that modifies tool input."""
|
||||
|
||||
async def modify_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
# Modify the input to add safety flag
|
||||
modified_input = input_data.copy()
|
||||
modified_input["safe_mode"] = True
|
||||
return PermissionResultAllow(updated_input=modified_input)
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=modify_callback,
|
||||
hooks=None,
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-3",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "WriteTool",
|
||||
"input": {"file_path": "/etc/passwd"},
|
||||
"permission_suggestions": [],
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response includes modified input
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": true' in response
|
||||
assert '"safe_mode": true' in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_exception_handling(self):
|
||||
"""Test that callback exceptions are properly handled."""
|
||||
|
||||
async def error_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
raise ValueError("Callback error")
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=error_callback,
|
||||
hooks=None,
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-5",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "TestTool",
|
||||
"input": {},
|
||||
"permission_suggestions": [],
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check error response was sent
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"subtype": "error"' in response
|
||||
assert "Callback error" in response
|
||||
|
||||
|
||||
class TestHookCallbacks:
|
||||
"""Test hook callback functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_execution(self):
|
||||
"""Test that hooks are called at appropriate times."""
|
||||
hook_calls = []
|
||||
|
||||
async def test_hook(
|
||||
input_data: dict, tool_use_id: str | None, context: HookContext
|
||||
) -> dict:
|
||||
hook_calls.append({"input": input_data, "tool_use_id": tool_use_id})
|
||||
return {"processed": True}
|
||||
|
||||
transport = MockTransport()
|
||||
|
||||
# Create hooks configuration
|
||||
hooks = {
|
||||
"tool_use_start": [{"matcher": {"tool": "TestTool"}, "hooks": [test_hook]}]
|
||||
}
|
||||
|
||||
query = Query(
|
||||
transport=transport, is_streaming_mode=True, can_use_tool=None, hooks=hooks
|
||||
)
|
||||
|
||||
# Manually register the hook callback to avoid needing the full initialize flow
|
||||
callback_id = "test_hook_0"
|
||||
query.hook_callbacks[callback_id] = test_hook
|
||||
|
||||
# Simulate hook callback request
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-hook-1",
|
||||
"request": {
|
||||
"subtype": "hook_callback",
|
||||
"callback_id": callback_id,
|
||||
"input": {"test": "data"},
|
||||
"tool_use_id": "tool-123",
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check hook was called
|
||||
assert len(hook_calls) == 1
|
||||
assert hook_calls[0]["input"] == {"test": "data"}
|
||||
assert hook_calls[0]["tool_use_id"] == "tool-123"
|
||||
|
||||
# Check response
|
||||
assert len(transport.written_messages) > 0
|
||||
last_response = transport.written_messages[-1]
|
||||
assert '"processed": true' in last_response
|
||||
|
||||
|
||||
class TestClaudeCodeOptionsIntegration:
|
||||
"""Test that callbacks work through ClaudeCodeOptions."""
|
||||
|
||||
def test_options_with_callbacks(self):
|
||||
"""Test creating options with callbacks."""
|
||||
|
||||
async def my_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
return PermissionResultAllow()
|
||||
|
||||
async def my_hook(
|
||||
input_data: dict, tool_use_id: str | None, context: HookContext
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
options = ClaudeCodeOptions(
|
||||
can_use_tool=my_callback,
|
||||
hooks={
|
||||
"tool_use_start": [
|
||||
HookMatcher(matcher={"tool": "Bash"}, hooks=[my_hook])
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
assert options.can_use_tool == my_callback
|
||||
assert "tool_use_start" in options.hooks
|
||||
assert len(options.hooks["tool_use_start"]) == 1
|
||||
assert options.hooks["tool_use_start"][0].hooks[0] == my_hook
|
||||
Loading…
Add table
Add a link
Reference in a new issue