mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
feat: Add tool permission and hook callbacks support (#143)
## Summary Adds comprehensive support for tool permission callbacks and hook callbacks to the Python SDK, enabling fine-grained control over tool execution and custom event handling. ## Key Changes - **Tool Permission Callbacks**: Control which tools Claude can use and modify their inputs - type with async support - with suggestions from CLI - for structured responses - **Hook Callbacks**: React to events in the Claude workflow - type for event handlers - for conditional hook execution - Support for tool_use_start, tool_use_end events - **Integration**: Full plumbing through ClaudeCodeOptions → Client → Query - **Examples**: Comprehensive example showing permission control patterns - **Tests**: Coverage for all callback scenarios ## Implementation Details - Callbacks are registered during initialization phase - Control protocol handles can_use_tool and hook_callback requests - Backwards compatible with dict returns for tool permissions - Proper error handling and type safety throughout Builds on top of #139's control protocol implementation. --------- Co-authored-by: Dickson Tsai <dickson@anthropic.com>
This commit is contained in:
parent
9ef57859af
commit
68f0d7aa7d
7 changed files with 649 additions and 11 deletions
158
examples/tool_permission_callback.py
Normal file
158
examples/tool_permission_callback.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Example: Tool Permission Callbacks.
|
||||
|
||||
This example demonstrates how to use tool permission callbacks to control
|
||||
which tools Claude can use and modify their inputs.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from claude_code_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ClaudeSDKClient,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
ResultMessage,
|
||||
TextBlock,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
|
||||
# Track tool usage for demonstration
|
||||
tool_usage_log = []
|
||||
|
||||
|
||||
async def my_permission_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> PermissionResultAllow | PermissionResultDeny:
|
||||
"""Control tool permissions based on tool type and input."""
|
||||
|
||||
# Log the tool request
|
||||
tool_usage_log.append({
|
||||
"tool": tool_name,
|
||||
"input": input_data,
|
||||
"suggestions": context.suggestions
|
||||
})
|
||||
|
||||
print(f"\n🔧 Tool Permission Request: {tool_name}")
|
||||
print(f" Input: {json.dumps(input_data, indent=2)}")
|
||||
|
||||
# Always allow read operations
|
||||
if tool_name in ["Read", "Glob", "Grep"]:
|
||||
print(f" ✅ Automatically allowing {tool_name} (read-only operation)")
|
||||
return PermissionResultAllow()
|
||||
|
||||
# Deny write operations to system directories
|
||||
if tool_name in ["Write", "Edit", "MultiEdit"]:
|
||||
file_path = input_data.get("file_path", "")
|
||||
if file_path.startswith("/etc/") or file_path.startswith("/usr/"):
|
||||
print(f" ❌ Denying write to system directory: {file_path}")
|
||||
return PermissionResultDeny(
|
||||
message=f"Cannot write to system directory: {file_path}"
|
||||
)
|
||||
|
||||
# Redirect writes to a safe directory
|
||||
if not file_path.startswith("/tmp/") and not file_path.startswith("./"):
|
||||
safe_path = f"./safe_output/{file_path.split('/')[-1]}"
|
||||
print(f" ⚠️ Redirecting write from {file_path} to {safe_path}")
|
||||
modified_input = input_data.copy()
|
||||
modified_input["file_path"] = safe_path
|
||||
return PermissionResultAllow(
|
||||
updatedInput=modified_input
|
||||
)
|
||||
|
||||
# Check dangerous bash commands
|
||||
if tool_name == "Bash":
|
||||
command = input_data.get("command", "")
|
||||
dangerous_commands = ["rm -rf", "sudo", "chmod 777", "dd if=", "mkfs"]
|
||||
|
||||
for dangerous in dangerous_commands:
|
||||
if dangerous in command:
|
||||
print(f" ❌ Denying dangerous command: {command}")
|
||||
return PermissionResultDeny(
|
||||
message=f"Dangerous command pattern detected: {dangerous}"
|
||||
)
|
||||
|
||||
# Allow but log the command
|
||||
print(f" ✅ Allowing bash command: {command}")
|
||||
return PermissionResultAllow()
|
||||
|
||||
# For all other tools, ask the user
|
||||
print(f" ❓ Unknown tool: {tool_name}")
|
||||
print(f" Input: {json.dumps(input_data, indent=6)}")
|
||||
user_input = input(" Allow this tool? (y/N): ").strip().lower()
|
||||
|
||||
if user_input in ("y", "yes"):
|
||||
return PermissionResultAllow()
|
||||
else:
|
||||
return PermissionResultDeny(
|
||||
message="User denied permission"
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run example with tool permission callbacks."""
|
||||
|
||||
print("=" * 60)
|
||||
print("Tool Permission Callback Example")
|
||||
print("=" * 60)
|
||||
print("\nThis example demonstrates how to:")
|
||||
print("1. Allow/deny tools based on type")
|
||||
print("2. Modify tool inputs for safety")
|
||||
print("3. Log tool usage")
|
||||
print("4. Prompt for unknown tools")
|
||||
print("=" * 60)
|
||||
|
||||
# Configure options with our callback
|
||||
options = ClaudeCodeOptions(
|
||||
can_use_tool=my_permission_callback,
|
||||
# Use default permission mode to ensure callbacks are invoked
|
||||
permission_mode="default",
|
||||
cwd="." # Set working directory
|
||||
)
|
||||
|
||||
# Create client and send a query that will use multiple tools
|
||||
async with ClaudeSDKClient(options) as client:
|
||||
print("\n📝 Sending query to Claude...")
|
||||
await client.query(
|
||||
"Please do the following:\n"
|
||||
"1. List the files in the current directory\n"
|
||||
"2. Create a simple Python hello world script at hello.py\n"
|
||||
"3. Run the script to test it"
|
||||
)
|
||||
|
||||
print("\n📨 Receiving response...")
|
||||
message_count = 0
|
||||
|
||||
async for message in client.receive_response():
|
||||
message_count += 1
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
# Print Claude's text responses
|
||||
for block in message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"\n💬 Claude: {block.text}")
|
||||
|
||||
elif isinstance(message, ResultMessage):
|
||||
print("\n✅ Task completed!")
|
||||
print(f" Duration: {message.duration_ms}ms")
|
||||
if message.total_cost_usd:
|
||||
print(f" Cost: ${message.total_cost_usd:.4f}")
|
||||
print(f" Messages processed: {message_count}")
|
||||
|
||||
# Print tool usage summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Tool Usage Summary")
|
||||
print("=" * 60)
|
||||
for i, usage in enumerate(tool_usage_log, 1):
|
||||
print(f"\n{i}. Tool: {usage['tool']}")
|
||||
print(f" Input: {json.dumps(usage['input'], indent=6)}")
|
||||
if usage['suggestions']:
|
||||
print(f" Suggestions: {usage['suggestions']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -16,16 +16,25 @@ from .client import ClaudeSDKClient
|
|||
from .query import query
|
||||
from .types import (
|
||||
AssistantMessage,
|
||||
CanUseTool,
|
||||
ClaudeCodeOptions,
|
||||
ContentBlock,
|
||||
HookCallback,
|
||||
HookContext,
|
||||
HookMatcher,
|
||||
McpSdkServerConfig,
|
||||
McpServerConfig,
|
||||
Message,
|
||||
PermissionMode,
|
||||
PermissionResult,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
PermissionUpdate,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolPermissionContext,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
|
|
@ -286,6 +295,16 @@ __all__ = [
|
|||
"ToolUseBlock",
|
||||
"ToolResultBlock",
|
||||
"ContentBlock",
|
||||
# Tool callbacks
|
||||
"CanUseTool",
|
||||
"ToolPermissionContext",
|
||||
"PermissionResult",
|
||||
"PermissionResultAllow",
|
||||
"PermissionResultDeny",
|
||||
"PermissionUpdate",
|
||||
"HookCallback",
|
||||
"HookContext",
|
||||
"HookMatcher",
|
||||
# MCP Server Support
|
||||
"create_sdk_mcp_server",
|
||||
"tool",
|
||||
|
|
|
|||
|
|
@ -19,6 +19,22 @@ class InternalClient:
|
|||
def __init__(self) -> None:
|
||||
"""Initialize the internal client."""
|
||||
|
||||
def _convert_hooks_to_internal_format(
|
||||
self, hooks: dict[str, list]
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Convert HookMatcher format to internal Query format."""
|
||||
internal_hooks = {}
|
||||
for event, matchers in hooks.items():
|
||||
internal_hooks[event] = []
|
||||
for matcher in matchers:
|
||||
# Convert HookMatcher to internal dict format
|
||||
internal_matcher = {
|
||||
"matcher": matcher.matcher if hasattr(matcher, 'matcher') else None,
|
||||
"hooks": matcher.hooks if hasattr(matcher, 'hooks') else []
|
||||
}
|
||||
internal_hooks[event].append(internal_matcher)
|
||||
return internal_hooks
|
||||
|
||||
async def process_query(
|
||||
self,
|
||||
prompt: str | AsyncIterable[dict[str, Any]],
|
||||
|
|
@ -48,8 +64,8 @@ class InternalClient:
|
|||
query = Query(
|
||||
transport=chosen_transport,
|
||||
is_streaming_mode=is_streaming,
|
||||
can_use_tool=None, # TODO: Add support for can_use_tool callback
|
||||
hooks=None, # TODO: Add support for hooks
|
||||
can_use_tool=options.can_use_tool,
|
||||
hooks=self._convert_hooks_to_internal_format(options.hooks) if options.hooks else None,
|
||||
sdk_mcp_servers=sdk_mcp_servers,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,10 +15,14 @@ from mcp.types import (
|
|||
)
|
||||
|
||||
from ..types import (
|
||||
PermissionResult,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
SDKControlPermissionRequest,
|
||||
SDKControlRequest,
|
||||
SDKControlResponse,
|
||||
SDKHookCallbackRequest,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
from .transport import Transport
|
||||
|
||||
|
|
@ -195,15 +199,34 @@ class Query:
|
|||
if not self.can_use_tool:
|
||||
raise Exception("canUseTool callback is not provided")
|
||||
|
||||
response_data = await self.can_use_tool(
|
||||
context = ToolPermissionContext(
|
||||
signal=None, # TODO: Add abort signal support
|
||||
suggestions=permission_request.get("permission_suggestions", [])
|
||||
)
|
||||
|
||||
response = await self.can_use_tool(
|
||||
permission_request["tool_name"],
|
||||
permission_request["input"],
|
||||
{
|
||||
"signal": None, # TODO: Add abort signal support
|
||||
"suggestions": permission_request.get("permission_suggestions"),
|
||||
},
|
||||
context
|
||||
)
|
||||
|
||||
# Convert PermissionResult to expected dict format
|
||||
if isinstance(response, PermissionResultAllow):
|
||||
response_data = {
|
||||
"allow": True
|
||||
}
|
||||
if response.updatedInput is not None:
|
||||
response_data["input"] = response.updatedInput
|
||||
# TODO: Handle updatedPermissions when control protocol supports it
|
||||
elif isinstance(response, PermissionResultDeny):
|
||||
response_data = {
|
||||
"allow": False,
|
||||
"reason": response.message
|
||||
}
|
||||
# TODO: Handle interrupt flag when control protocol supports it
|
||||
else:
|
||||
raise TypeError(f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got {type(response)}")
|
||||
|
||||
elif subtype == "hook_callback":
|
||||
hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment]
|
||||
# Handle hook callback
|
||||
|
|
|
|||
|
|
@ -100,6 +100,22 @@ class ClaudeSDKClient:
|
|||
self._query: Any | None = None
|
||||
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"
|
||||
|
||||
def _convert_hooks_to_internal_format(
|
||||
self, hooks: dict[str, list]
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Convert HookMatcher format to internal Query format."""
|
||||
internal_hooks = {}
|
||||
for event, matchers in hooks.items():
|
||||
internal_hooks[event] = []
|
||||
for matcher in matchers:
|
||||
# Convert HookMatcher to internal dict format
|
||||
internal_matcher = {
|
||||
"matcher": matcher.matcher if hasattr(matcher, 'matcher') else None,
|
||||
"hooks": matcher.hooks if hasattr(matcher, 'hooks') else []
|
||||
}
|
||||
internal_hooks[event].append(internal_matcher)
|
||||
return internal_hooks
|
||||
|
||||
async def connect(
|
||||
self, prompt: str | AsyncIterable[dict[str, Any]] | None = None
|
||||
) -> None:
|
||||
|
|
@ -135,8 +151,8 @@ class ClaudeSDKClient:
|
|||
self._query = Query(
|
||||
transport=self._transport,
|
||||
is_streaming_mode=True, # ClaudeSDKClient always uses streaming mode
|
||||
can_use_tool=None, # TODO: Add support for can_use_tool callback
|
||||
hooks=None, # TODO: Add support for hooks
|
||||
can_use_tool=self.options.can_use_tool,
|
||||
hooks=self._convert_hooks_to_internal_format(self.options.hooks) if self.options.hooks else None,
|
||||
sdk_mcp_servers=sdk_mcp_servers,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
"""Type definitions for Claude SDK."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict
|
||||
|
||||
from typing_extensions import NotRequired # For Python < 3.11 compatibility
|
||||
try:
|
||||
from typing import NotRequired # Python 3.11+
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired # For Python < 3.11 compatibility
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server import Server as McpServer
|
||||
|
|
@ -13,6 +17,87 @@ if TYPE_CHECKING:
|
|||
PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"]
|
||||
|
||||
|
||||
# Permission Update types (matching TypeScript SDK)
|
||||
PermissionUpdateDestination = Literal[
|
||||
"userSettings",
|
||||
"projectSettings",
|
||||
"localSettings",
|
||||
"session"
|
||||
]
|
||||
|
||||
PermissionBehavior = Literal["allow", "deny", "ask"]
|
||||
|
||||
@dataclass
|
||||
class PermissionRuleValue:
|
||||
"""Permission rule value."""
|
||||
toolName: str
|
||||
ruleContent: str | None = None
|
||||
|
||||
@dataclass
|
||||
class PermissionUpdate:
|
||||
"""Permission update configuration."""
|
||||
type: Literal["addRules", "replaceRules", "removeRules", "setMode", "addDirectories", "removeDirectories"]
|
||||
rules: list[PermissionRuleValue] | None = None
|
||||
behavior: PermissionBehavior | None = None
|
||||
mode: PermissionMode | None = None
|
||||
directories: list[str] | None = None
|
||||
destination: PermissionUpdateDestination | None = None
|
||||
|
||||
# Tool callback types
|
||||
@dataclass
|
||||
class ToolPermissionContext:
|
||||
"""Context information for tool permission callbacks."""
|
||||
|
||||
signal: Any | None = None # Future: abort signal support
|
||||
suggestions: list[PermissionUpdate] = field(default_factory=list) # Permission suggestions from CLI
|
||||
|
||||
|
||||
# Match TypeScript's PermissionResult structure
|
||||
@dataclass
|
||||
class PermissionResultAllow:
|
||||
"""Allow permission result."""
|
||||
behavior: Literal["allow"] = "allow"
|
||||
updatedInput: dict[str, Any] | None = None
|
||||
updatedPermissions: list[PermissionUpdate] | None = None
|
||||
|
||||
@dataclass
|
||||
class PermissionResultDeny:
|
||||
"""Deny permission result."""
|
||||
behavior: Literal["deny"] = "deny"
|
||||
message: str = ""
|
||||
interrupt: bool = False
|
||||
|
||||
PermissionResult = PermissionResultAllow | PermissionResultDeny
|
||||
|
||||
CanUseTool = Callable[
|
||||
[str, dict[str, Any], ToolPermissionContext],
|
||||
Awaitable[PermissionResult]
|
||||
]
|
||||
|
||||
|
||||
# Hook callback types
|
||||
@dataclass
|
||||
class HookContext:
|
||||
"""Context information for hook callbacks."""
|
||||
|
||||
signal: Any | None = None # Future: abort signal support
|
||||
|
||||
|
||||
HookCallback = Callable[
|
||||
[dict[str, Any], str | None, HookContext], # input, tool_use_id, context
|
||||
Awaitable[dict[str, Any]] # response data
|
||||
]
|
||||
|
||||
|
||||
# Hook matcher configuration
|
||||
@dataclass
|
||||
class HookMatcher:
|
||||
"""Hook matcher configuration."""
|
||||
|
||||
matcher: dict[str, Any] | None = None # Matcher criteria
|
||||
hooks: list[HookCallback] = field(default_factory=list) # Callbacks to invoke
|
||||
|
||||
|
||||
# MCP Server config
|
||||
class McpStdioServerConfig(TypedDict):
|
||||
"""MCP stdio server configuration."""
|
||||
|
|
@ -155,6 +240,12 @@ class ClaudeCodeOptions:
|
|||
default_factory=dict
|
||||
) # Pass arbitrary CLI flags
|
||||
|
||||
# Tool permission callback
|
||||
can_use_tool: CanUseTool | None = None
|
||||
|
||||
# Hook configurations
|
||||
hooks: dict[str, list[HookMatcher]] | None = None
|
||||
|
||||
|
||||
# SDK Control Protocol
|
||||
class SDKControlInterruptRequest(TypedDict):
|
||||
|
|
@ -169,7 +260,6 @@ class SDKControlPermissionRequest(TypedDict):
|
|||
permission_suggestions: list[Any] | None
|
||||
blocked_path: str | None
|
||||
|
||||
|
||||
class SDKControlInitializeRequest(TypedDict):
|
||||
subtype: Literal["initialize"]
|
||||
# TODO: Use HookEvent names as the key.
|
||||
|
|
|
|||
316
tests/test_tool_callbacks.py
Normal file
316
tests/test_tool_callbacks.py
Normal file
|
|
@ -0,0 +1,316 @@
|
|||
"""Tests for tool permission callbacks and hook callbacks."""
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
ClaudeCodeOptions,
|
||||
HookContext,
|
||||
HookMatcher,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
from claude_code_sdk._internal.query import Query
|
||||
from claude_code_sdk._internal.transport import Transport
|
||||
|
||||
|
||||
class MockTransport(Transport):
|
||||
"""Mock transport for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.written_messages = []
|
||||
self.messages_to_read = []
|
||||
self._connected = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._connected = True
|
||||
|
||||
async def close(self) -> None:
|
||||
self._connected = False
|
||||
|
||||
async def write(self, data: str) -> None:
|
||||
self.written_messages.append(data)
|
||||
|
||||
async def end_input(self) -> None:
|
||||
pass
|
||||
|
||||
def read_messages(self):
|
||||
async def _read():
|
||||
for msg in self.messages_to_read:
|
||||
yield msg
|
||||
return _read()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
|
||||
class TestToolPermissionCallbacks:
|
||||
"""Test tool permission callback functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_allow(self):
|
||||
"""Test callback that allows tool execution."""
|
||||
callback_invoked = False
|
||||
|
||||
async def allow_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
nonlocal callback_invoked
|
||||
callback_invoked = True
|
||||
assert tool_name == "TestTool"
|
||||
assert input_data == {"param": "value"}
|
||||
return PermissionResultAllow()
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=allow_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
# Simulate control request
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-1",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "TestTool",
|
||||
"input": {"param": "value"},
|
||||
"permission_suggestions": []
|
||||
}
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check callback was invoked
|
||||
assert callback_invoked
|
||||
|
||||
# Check response was sent
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": true' in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_deny(self):
|
||||
"""Test callback that denies tool execution."""
|
||||
async def deny_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> PermissionResultDeny:
|
||||
return PermissionResultDeny(
|
||||
message="Security policy violation"
|
||||
)
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=deny_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-2",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "DangerousTool",
|
||||
"input": {"command": "rm -rf /"},
|
||||
"permission_suggestions": ["deny"]
|
||||
}
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": false' in response
|
||||
assert '"reason": "Security policy violation"' in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_input_modification(self):
|
||||
"""Test callback that modifies tool input."""
|
||||
async def modify_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
# Modify the input to add safety flag
|
||||
modified_input = input_data.copy()
|
||||
modified_input["safe_mode"] = True
|
||||
return PermissionResultAllow(
|
||||
updatedInput=modified_input
|
||||
)
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=modify_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-3",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "WriteTool",
|
||||
"input": {"file_path": "/etc/passwd"},
|
||||
"permission_suggestions": []
|
||||
}
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response includes modified input
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": true' in response
|
||||
assert '"safe_mode": true' in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_exception_handling(self):
|
||||
"""Test that callback exceptions are properly handled."""
|
||||
async def error_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
raise ValueError("Callback error")
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=error_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-5",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "TestTool",
|
||||
"input": {},
|
||||
"permission_suggestions": []
|
||||
}
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check error response was sent
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"subtype": "error"' in response
|
||||
assert "Callback error" in response
|
||||
|
||||
|
||||
class TestHookCallbacks:
|
||||
"""Test hook callback functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_execution(self):
|
||||
"""Test that hooks are called at appropriate times."""
|
||||
hook_calls = []
|
||||
|
||||
async def test_hook(
|
||||
input_data: dict,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext
|
||||
) -> dict:
|
||||
hook_calls.append({
|
||||
"input": input_data,
|
||||
"tool_use_id": tool_use_id
|
||||
})
|
||||
return {"processed": True}
|
||||
|
||||
transport = MockTransport()
|
||||
|
||||
# Create hooks configuration
|
||||
hooks = {
|
||||
"tool_use_start": [
|
||||
{
|
||||
"matcher": {"tool": "TestTool"},
|
||||
"hooks": [test_hook]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=None,
|
||||
hooks=hooks
|
||||
)
|
||||
|
||||
# Manually register the hook callback to avoid needing the full initialize flow
|
||||
callback_id = "test_hook_0"
|
||||
query.hook_callbacks[callback_id] = test_hook
|
||||
|
||||
# Simulate hook callback request
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-hook-1",
|
||||
"request": {
|
||||
"subtype": "hook_callback",
|
||||
"callback_id": callback_id,
|
||||
"input": {"test": "data"},
|
||||
"tool_use_id": "tool-123"
|
||||
}
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check hook was called
|
||||
assert len(hook_calls) == 1
|
||||
assert hook_calls[0]["input"] == {"test": "data"}
|
||||
assert hook_calls[0]["tool_use_id"] == "tool-123"
|
||||
|
||||
# Check response
|
||||
assert len(transport.written_messages) > 0
|
||||
last_response = transport.written_messages[-1]
|
||||
assert '"processed": true' in last_response
|
||||
|
||||
|
||||
class TestClaudeCodeOptionsIntegration:
|
||||
"""Test that callbacks work through ClaudeCodeOptions."""
|
||||
|
||||
def test_options_with_callbacks(self):
|
||||
"""Test creating options with callbacks."""
|
||||
async def my_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
return PermissionResultAllow()
|
||||
|
||||
async def my_hook(
|
||||
input_data: dict,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
options = ClaudeCodeOptions(
|
||||
can_use_tool=my_callback,
|
||||
hooks={
|
||||
"tool_use_start": [
|
||||
HookMatcher(
|
||||
matcher={"tool": "Bash"},
|
||||
hooks=[my_hook]
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert options.can_use_tool == my_callback
|
||||
assert "tool_use_start" in options.hooks
|
||||
assert len(options.hooks["tool_use_start"]) == 1
|
||||
assert options.hooks["tool_use_start"][0].hooks[0] == my_hook
|
||||
Loading…
Add table
Add a link
Reference in a new issue