feat: Add tool permission and hook callbacks support

- Added ToolPermissionCallback type for controlling tool execution
- Added HookCallback type for intercepting tool events
- Added callbacks to ClaudeCodeOptions
- Thread callbacks through InternalClient and ClaudeSDKClient to Query
- Updated Query to handle tool permission requests with new types
- Support both ToolPermissionResponse and dict returns for compatibility
- Added example demonstrating permission control and input modification
- Added comprehensive tests for callbacks functionality
- Fixed NotRequired import for Python 3.11+ compatibility
This commit is contained in:
Kashyap Murali 2025-09-01 17:31:18 -07:00
parent 8e9548a1c5
commit a774031c34
No known key found for this signature in database
7 changed files with 671 additions and 10 deletions

View file

@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""Example: Tool Permission Callbacks.
This example demonstrates how to use tool permission callbacks to control
which tools Claude can use and modify their inputs.
"""
import asyncio
import json
from claude_code_sdk import (
ClaudeCodeOptions,
ClaudeSDKClient,
ToolPermissionCallback,
ToolPermissionResponse,
ToolPermissionContext,
TextBlock,
AssistantMessage,
ResultMessage,
)
# Track tool usage for demonstration
tool_usage_log = []
async def my_permission_callback(
tool_name: str,
input_data: dict,
context: ToolPermissionContext
) -> ToolPermissionResponse:
"""Control tool permissions based on tool type and input."""
# Log the tool request
tool_usage_log.append({
"tool": tool_name,
"input": input_data,
"suggestions": context.suggestions
})
print(f"\n🔧 Tool Permission Request: {tool_name}")
print(f" Input: {json.dumps(input_data, indent=2)}")
# Always allow read operations
if tool_name in ["Read", "Glob", "Grep"]:
print(f" ✅ Automatically allowing {tool_name} (read-only operation)")
return ToolPermissionResponse(
allow=True,
reason="Read operations are always allowed"
)
# Deny write operations to system directories
if tool_name in ["Write", "Edit", "MultiEdit"]:
file_path = input_data.get("file_path", "")
if file_path.startswith("/etc/") or file_path.startswith("/usr/"):
print(f" ❌ Denying write to system directory: {file_path}")
return ToolPermissionResponse(
allow=False,
reason=f"Cannot write to system directory: {file_path}"
)
# Redirect writes to a safe directory
if not file_path.startswith("/tmp/") and not file_path.startswith("./"):
safe_path = f"./safe_output/{file_path.split('/')[-1]}"
print(f" ⚠️ Redirecting write from {file_path} to {safe_path}")
modified_input = input_data.copy()
modified_input["file_path"] = safe_path
return ToolPermissionResponse(
allow=True,
input=modified_input,
reason=f"Redirected to safe path: {safe_path}"
)
# Check dangerous bash commands
if tool_name == "Bash":
command = input_data.get("command", "")
dangerous_commands = ["rm -rf", "sudo", "chmod 777", "dd if=", "mkfs"]
for dangerous in dangerous_commands:
if dangerous in command:
print(f" ❌ Denying dangerous command: {command}")
return ToolPermissionResponse(
allow=False,
reason=f"Dangerous command pattern detected: {dangerous}"
)
# Allow but log the command
print(f" ✅ Allowing bash command: {command}")
return ToolPermissionResponse(
allow=True,
reason="Command appears safe"
)
# For all other tools, ask the user
print(f" ❓ Unknown tool: {tool_name}")
print(f" Input: {json.dumps(input_data, indent=6)}")
user_input = input(" Allow this tool? (y/N): ").strip().lower()
return ToolPermissionResponse(
allow=user_input in ("y", "yes"),
reason=f"User {'approved' if user_input in ('y', 'yes') else 'denied'}"
)
async def main():
"""Run example with tool permission callbacks."""
print("=" * 60)
print("Tool Permission Callback Example")
print("=" * 60)
print("\nThis example demonstrates how to:")
print("1. Allow/deny tools based on type")
print("2. Modify tool inputs for safety")
print("3. Log tool usage")
print("4. Prompt for unknown tools")
print("=" * 60)
# Configure options with our callback
options = ClaudeCodeOptions(
tool_permission_callback=my_permission_callback,
# Use default permission mode to ensure callbacks are invoked
permission_mode="default",
cwd="." # Set working directory
)
# Create client and send a query that will use multiple tools
async with ClaudeSDKClient(options) as client:
print("\n📝 Sending query to Claude...")
await client.query(
"Please do the following:\n"
"1. List the files in the current directory\n"
"2. Create a simple Python hello world script at hello.py\n"
"3. Run the script to test it"
)
print("\n📨 Receiving response...")
message_count = 0
async for message in client.receive_response():
message_count += 1
if isinstance(message, AssistantMessage):
# Print Claude's text responses
for block in message.content:
if isinstance(block, TextBlock):
print(f"\n💬 Claude: {block.text}")
elif isinstance(message, ResultMessage):
print(f"\n✅ Task completed!")
print(f" Duration: {message.duration_ms}ms")
if message.total_cost_usd:
print(f" Cost: ${message.total_cost_usd:.4f}")
print(f" Messages processed: {message_count}")
# Print tool usage summary
print("\n" + "=" * 60)
print("Tool Usage Summary")
print("=" * 60)
for i, usage in enumerate(tool_usage_log, 1):
print(f"\n{i}. Tool: {usage['tool']}")
print(f" Input: {json.dumps(usage['input'], indent=6)}")
if usage['suggestions']:
print(f" Suggestions: {usage['suggestions']}")
if __name__ == "__main__":
asyncio.run(main())

View file

@ -14,6 +14,9 @@ from .types import (
AssistantMessage,
ClaudeCodeOptions,
ContentBlock,
HookCallback,
HookContext,
HookMatcher,
McpServerConfig,
Message,
PermissionMode,
@ -21,6 +24,9 @@ from .types import (
SystemMessage,
TextBlock,
ThinkingBlock,
ToolPermissionCallback,
ToolPermissionContext,
ToolPermissionResponse,
ToolResultBlock,
ToolUseBlock,
UserMessage,
@ -48,6 +54,13 @@ __all__ = [
"ToolUseBlock",
"ToolResultBlock",
"ContentBlock",
# Tool callbacks
"ToolPermissionCallback",
"ToolPermissionContext",
"ToolPermissionResponse",
"HookCallback",
"HookContext",
"HookMatcher",
# Errors
"ClaudeSDKError",
"CLIConnectionError",

View file

@ -19,6 +19,22 @@ class InternalClient:
def __init__(self) -> None:
"""Initialize the internal client."""
def _convert_hooks_to_internal_format(
self, hooks: dict[str, list]
) -> dict[str, list[dict[str, Any]]]:
"""Convert HookMatcher format to internal Query format."""
internal_hooks = {}
for event, matchers in hooks.items():
internal_hooks[event] = []
for matcher in matchers:
# Convert HookMatcher to internal dict format
internal_matcher = {
"matcher": matcher.matcher if hasattr(matcher, 'matcher') else None,
"hooks": matcher.hooks if hasattr(matcher, 'hooks') else []
}
internal_hooks[event].append(internal_matcher)
return internal_hooks
async def process_query(
self,
prompt: str | AsyncIterable[dict[str, Any]],
@ -41,8 +57,8 @@ class InternalClient:
query = Query(
transport=chosen_transport,
is_streaming_mode=is_streaming,
can_use_tool=None, # TODO: Add support for can_use_tool callback
hooks=None, # TODO: Add support for hooks
can_use_tool=options.tool_permission_callback,
hooks=self._convert_hooks_to_internal_format(options.hooks) if options.hooks else None,
)
try:

View file

@ -184,15 +184,35 @@ class Query:
if not self.can_use_tool:
raise Exception("canUseTool callback is not provided")
response_data = await self.can_use_tool(
# Import here to avoid circular dependency
from ..types import ToolPermissionContext, ToolPermissionResponse
context = ToolPermissionContext(
signal=None, # TODO: Add abort signal support
suggestions=permission_request.get("permission_suggestions", [])
)
response = await self.can_use_tool(
permission_request["tool_name"],
permission_request["input"],
{
"signal": None, # TODO: Add abort signal support
"suggestions": permission_request.get("permission_suggestions"),
},
context
)
# Convert ToolPermissionResponse to expected dict format
if isinstance(response, ToolPermissionResponse):
response_data = {
"allow": response.allow
}
if response.input is not None:
response_data["input"] = response.input
if response.reason is not None:
response_data["reason"] = response.reason
elif isinstance(response, dict):
# Support returning dict directly for compatibility
response_data = response
else:
raise TypeError(f"Tool permission callback must return ToolPermissionResponse or dict, got {type(response)}")
elif subtype == "hook_callback":
hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment]
# Handle hook callback

View file

@ -100,6 +100,22 @@ class ClaudeSDKClient:
self._query: Any | None = None
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"
def _convert_hooks_to_internal_format(
self, hooks: dict[str, list]
) -> dict[str, list[dict[str, Any]]]:
"""Convert HookMatcher format to internal Query format."""
internal_hooks = {}
for event, matchers in hooks.items():
internal_hooks[event] = []
for matcher in matchers:
# Convert HookMatcher to internal dict format
internal_matcher = {
"matcher": matcher.matcher if hasattr(matcher, 'matcher') else None,
"hooks": matcher.hooks if hasattr(matcher, 'hooks') else []
}
internal_hooks[event].append(internal_matcher)
return internal_hooks
async def connect(
self, prompt: str | AsyncIterable[dict[str, Any]] | None = None
) -> None:
@ -128,8 +144,8 @@ class ClaudeSDKClient:
self._query = Query(
transport=self._transport,
is_streaming_mode=True, # ClaudeSDKClient always uses streaming mode
can_use_tool=None, # TODO: Add support for can_use_tool callback
hooks=None, # TODO: Add support for hooks
can_use_tool=self.options.tool_permission_callback,
hooks=self._convert_hooks_to_internal_format(self.options.hooks) if self.options.hooks else None,
)
# Start reading messages and initialize

View file

@ -1,15 +1,66 @@
"""Type definitions for Claude SDK."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, TypedDict
from typing_extensions import NotRequired # For Python < 3.11 compatibility
try:
from typing import NotRequired # Python 3.11+
except ImportError:
from typing_extensions import NotRequired # For Python < 3.11 compatibility
# Permission modes
PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"]
# Tool callback types
@dataclass
class ToolPermissionContext:
"""Context information for tool permission callbacks."""
signal: Any | None = None # Future: abort signal support
suggestions: list[str] = field(default_factory=list) # Permission suggestions from CLI
@dataclass
class ToolPermissionResponse:
"""Response from tool permission callback."""
allow: bool
input: dict[str, Any] | None = None # Optional: modified input parameters
reason: str | None = None # Optional: reason for decision
ToolPermissionCallback = Callable[
[str, dict[str, Any], ToolPermissionContext],
Awaitable[ToolPermissionResponse | dict[str, Any]]
]
# Hook callback types
@dataclass
class HookContext:
"""Context information for hook callbacks."""
signal: Any | None = None # Future: abort signal support
HookCallback = Callable[
[dict[str, Any], str | None, HookContext], # input, tool_use_id, context
Awaitable[dict[str, Any]] # response data
]
# Hook matcher configuration
@dataclass
class HookMatcher:
"""Hook matcher configuration."""
matcher: dict[str, Any] | None = None # Matcher criteria
hooks: list[HookCallback] = field(default_factory=list) # Callbacks to invoke
# MCP Server config
class McpStdioServerConfig(TypedDict):
"""MCP stdio server configuration."""
@ -142,6 +193,12 @@ class ClaudeCodeOptions:
default_factory=dict
) # Pass arbitrary CLI flags
# Tool permission callback
tool_permission_callback: ToolPermissionCallback | None = None
# Hook configurations
hooks: dict[str, list[HookMatcher]] | None = None
# SDK Control Protocol
class SDKControlInterruptRequest(TypedDict):

View file

@ -0,0 +1,373 @@
"""Tests for tool permission callbacks and hook callbacks."""
import asyncio
from unittest.mock import AsyncMock, Mock
import pytest
from claude_code_sdk import (
ClaudeCodeOptions,
ToolPermissionCallback,
ToolPermissionResponse,
ToolPermissionContext,
HookCallback,
HookContext,
HookMatcher,
)
from claude_code_sdk._internal.query import Query
from claude_code_sdk._internal.transport import Transport
class MockTransport(Transport):
"""Mock transport for testing."""
def __init__(self):
self.written_messages = []
self.messages_to_read = []
self._connected = False
async def connect(self) -> None:
self._connected = True
async def close(self) -> None:
self._connected = False
async def write(self, data: str) -> None:
self.written_messages.append(data)
async def end_input(self) -> None:
pass
def read_messages(self):
async def _read():
for msg in self.messages_to_read:
yield msg
return _read()
def is_ready(self) -> bool:
return self._connected
class TestToolPermissionCallbacks:
"""Test tool permission callback functionality."""
@pytest.mark.asyncio
async def test_permission_callback_allow(self):
"""Test callback that allows tool execution."""
callback_invoked = False
async def allow_callback(
tool_name: str,
input_data: dict,
context: ToolPermissionContext
) -> ToolPermissionResponse:
nonlocal callback_invoked
callback_invoked = True
assert tool_name == "TestTool"
assert input_data == {"param": "value"}
return ToolPermissionResponse(allow=True, reason="Test allow")
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=allow_callback,
hooks=None
)
# Simulate control request
request = {
"type": "control_request",
"request_id": "test-1",
"request": {
"subtype": "can_use_tool",
"tool_name": "TestTool",
"input": {"param": "value"},
"permission_suggestions": []
}
}
await query._handle_control_request(request)
# Check callback was invoked
assert callback_invoked
# Check response was sent
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"allow": true' in response
assert '"reason": "Test allow"' in response
@pytest.mark.asyncio
async def test_permission_callback_deny(self):
"""Test callback that denies tool execution."""
async def deny_callback(
tool_name: str,
input_data: dict,
context: ToolPermissionContext
) -> ToolPermissionResponse:
return ToolPermissionResponse(
allow=False,
reason="Security policy violation"
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=deny_callback,
hooks=None
)
request = {
"type": "control_request",
"request_id": "test-2",
"request": {
"subtype": "can_use_tool",
"tool_name": "DangerousTool",
"input": {"command": "rm -rf /"},
"permission_suggestions": ["deny"]
}
}
await query._handle_control_request(request)
# Check response
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"allow": false' in response
assert '"reason": "Security policy violation"' in response
@pytest.mark.asyncio
async def test_permission_callback_input_modification(self):
"""Test callback that modifies tool input."""
async def modify_callback(
tool_name: str,
input_data: dict,
context: ToolPermissionContext
) -> ToolPermissionResponse:
# Modify the input to add safety flag
modified_input = input_data.copy()
modified_input["safe_mode"] = True
return ToolPermissionResponse(
allow=True,
input=modified_input,
reason="Modified for safety"
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=modify_callback,
hooks=None
)
request = {
"type": "control_request",
"request_id": "test-3",
"request": {
"subtype": "can_use_tool",
"tool_name": "WriteTool",
"input": {"file_path": "/etc/passwd"},
"permission_suggestions": []
}
}
await query._handle_control_request(request)
# Check response includes modified input
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"allow": true' in response
assert '"safe_mode": true' in response
assert '"reason": "Modified for safety"' in response
@pytest.mark.asyncio
async def test_permission_callback_dict_return(self):
"""Test callback can return dict for backwards compatibility."""
async def dict_callback(
tool_name: str,
input_data: dict,
context: ToolPermissionContext
) -> dict:
# Return dict directly instead of ToolPermissionResponse
return {
"allow": True,
"reason": "Dict response"
}
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=dict_callback,
hooks=None
)
request = {
"type": "control_request",
"request_id": "test-4",
"request": {
"subtype": "can_use_tool",
"tool_name": "TestTool",
"input": {},
"permission_suggestions": []
}
}
await query._handle_control_request(request)
# Check response
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"allow": true' in response
assert '"reason": "Dict response"' in response
@pytest.mark.asyncio
async def test_callback_exception_handling(self):
"""Test that callback exceptions are properly handled."""
async def error_callback(
tool_name: str,
input_data: dict,
context: ToolPermissionContext
) -> ToolPermissionResponse:
raise ValueError("Callback error")
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=error_callback,
hooks=None
)
request = {
"type": "control_request",
"request_id": "test-5",
"request": {
"subtype": "can_use_tool",
"tool_name": "TestTool",
"input": {},
"permission_suggestions": []
}
}
await query._handle_control_request(request)
# Check error response was sent
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"subtype": "error"' in response
assert "Callback error" in response
class TestHookCallbacks:
"""Test hook callback functionality."""
@pytest.mark.asyncio
async def test_hook_execution(self):
"""Test that hooks are called at appropriate times."""
hook_calls = []
async def test_hook(
input_data: dict,
tool_use_id: str | None,
context: HookContext
) -> dict:
hook_calls.append({
"input": input_data,
"tool_use_id": tool_use_id
})
return {"processed": True}
transport = MockTransport()
# Create hooks configuration
hooks = {
"tool_use_start": [
{
"matcher": {"tool": "TestTool"},
"hooks": [test_hook]
}
]
}
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=None,
hooks=hooks
)
# During initialization, hook callbacks are registered
await query.initialize()
# Find the registered callback ID
callback_id = None
for cid in query.hook_callbacks:
if query.hook_callbacks[cid] == test_hook:
callback_id = cid
break
assert callback_id is not None
# Simulate hook callback request
request = {
"type": "control_request",
"request_id": "test-hook-1",
"request": {
"subtype": "hook_callback",
"callback_id": callback_id,
"input": {"test": "data"},
"tool_use_id": "tool-123"
}
}
await query._handle_control_request(request)
# Check hook was called
assert len(hook_calls) == 1
assert hook_calls[0]["input"] == {"test": "data"}
assert hook_calls[0]["tool_use_id"] == "tool-123"
# Check response
assert len(transport.written_messages) > 0
last_response = transport.written_messages[-1]
assert '"processed": true' in last_response
class TestClaudeCodeOptionsIntegration:
"""Test that callbacks work through ClaudeCodeOptions."""
def test_options_with_callbacks(self):
"""Test creating options with callbacks."""
async def my_callback(
tool_name: str,
input_data: dict,
context: ToolPermissionContext
) -> ToolPermissionResponse:
return ToolPermissionResponse(allow=True)
async def my_hook(
input_data: dict,
tool_use_id: str | None,
context: HookContext
) -> dict:
return {}
options = ClaudeCodeOptions(
tool_permission_callback=my_callback,
hooks={
"tool_use_start": [
HookMatcher(
matcher={"tool": "Bash"},
hooks=[my_hook]
)
]
}
)
assert options.tool_permission_callback == my_callback
assert "tool_use_start" in options.hooks
assert len(options.hooks["tool_use_start"]) == 1
assert options.hooks["tool_use_start"][0].hooks[0] == my_hook