diff --git a/examples/tool_permission_callback.py b/examples/tool_permission_callback.py new file mode 100644 index 0000000..ccff319 --- /dev/null +++ b/examples/tool_permission_callback.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +"""Example: Tool Permission Callbacks. + +This example demonstrates how to use tool permission callbacks to control +which tools Claude can use and modify their inputs. +""" + +import asyncio +import json + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ClaudeSDKClient, + PermissionResultAllow, + PermissionResultDeny, + ResultMessage, + TextBlock, + ToolPermissionContext, +) + +# Track tool usage for demonstration +tool_usage_log = [] + + +async def my_permission_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext +) -> PermissionResultAllow | PermissionResultDeny: + """Control tool permissions based on tool type and input.""" + + # Log the tool request + tool_usage_log.append({ + "tool": tool_name, + "input": input_data, + "suggestions": context.suggestions + }) + + print(f"\nšŸ”§ Tool Permission Request: {tool_name}") + print(f" Input: {json.dumps(input_data, indent=2)}") + + # Always allow read operations + if tool_name in ["Read", "Glob", "Grep"]: + print(f" āœ… Automatically allowing {tool_name} (read-only operation)") + return PermissionResultAllow() + + # Deny write operations to system directories + if tool_name in ["Write", "Edit", "MultiEdit"]: + file_path = input_data.get("file_path", "") + if file_path.startswith("/etc/") or file_path.startswith("/usr/"): + print(f" āŒ Denying write to system directory: {file_path}") + return PermissionResultDeny( + message=f"Cannot write to system directory: {file_path}" + ) + + # Redirect writes to a safe directory + if not file_path.startswith("/tmp/") and not file_path.startswith("./"): + safe_path = f"./safe_output/{file_path.split('/')[-1]}" + print(f" āš ļø Redirecting write from {file_path} to {safe_path}") + modified_input = input_data.copy() + modified_input["file_path"] = safe_path + return PermissionResultAllow( + updatedInput=modified_input + ) + + # Check dangerous bash commands + if tool_name == "Bash": + command = input_data.get("command", "") + dangerous_commands = ["rm -rf", "sudo", "chmod 777", "dd if=", "mkfs"] + + for dangerous in dangerous_commands: + if dangerous in command: + print(f" āŒ Denying dangerous command: {command}") + return PermissionResultDeny( + message=f"Dangerous command pattern detected: {dangerous}" + ) + + # Allow but log the command + print(f" āœ… Allowing bash command: {command}") + return PermissionResultAllow() + + # For all other tools, ask the user + print(f" ā“ Unknown tool: {tool_name}") + print(f" Input: {json.dumps(input_data, indent=6)}") + user_input = input(" Allow this tool? (y/N): ").strip().lower() + + if user_input in ("y", "yes"): + return PermissionResultAllow() + else: + return PermissionResultDeny( + message="User denied permission" + ) + + +async def main(): + """Run example with tool permission callbacks.""" + + print("=" * 60) + print("Tool Permission Callback Example") + print("=" * 60) + print("\nThis example demonstrates how to:") + print("1. Allow/deny tools based on type") + print("2. Modify tool inputs for safety") + print("3. Log tool usage") + print("4. Prompt for unknown tools") + print("=" * 60) + + # Configure options with our callback + options = ClaudeCodeOptions( + can_use_tool=my_permission_callback, + # Use default permission mode to ensure callbacks are invoked + permission_mode="default", + cwd="." # Set working directory + ) + + # Create client and send a query that will use multiple tools + async with ClaudeSDKClient(options) as client: + print("\nšŸ“ Sending query to Claude...") + await client.query( + "Please do the following:\n" + "1. List the files in the current directory\n" + "2. Create a simple Python hello world script at hello.py\n" + "3. Run the script to test it" + ) + + print("\nšŸ“Ø Receiving response...") + message_count = 0 + + async for message in client.receive_response(): + message_count += 1 + + if isinstance(message, AssistantMessage): + # Print Claude's text responses + for block in message.content: + if isinstance(block, TextBlock): + print(f"\nšŸ’¬ Claude: {block.text}") + + elif isinstance(message, ResultMessage): + print("\nāœ… Task completed!") + print(f" Duration: {message.duration_ms}ms") + if message.total_cost_usd: + print(f" Cost: ${message.total_cost_usd:.4f}") + print(f" Messages processed: {message_count}") + + # Print tool usage summary + print("\n" + "=" * 60) + print("Tool Usage Summary") + print("=" * 60) + for i, usage in enumerate(tool_usage_log, 1): + print(f"\n{i}. Tool: {usage['tool']}") + print(f" Input: {json.dumps(usage['input'], indent=6)}") + if usage['suggestions']: + print(f" Suggestions: {usage['suggestions']}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index 0a9aee5..f3b91b6 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -16,16 +16,25 @@ from .client import ClaudeSDKClient from .query import query from .types import ( AssistantMessage, + CanUseTool, ClaudeCodeOptions, ContentBlock, + HookCallback, + HookContext, + HookMatcher, McpSdkServerConfig, McpServerConfig, Message, PermissionMode, + PermissionResult, + PermissionResultAllow, + PermissionResultDeny, + PermissionUpdate, ResultMessage, SystemMessage, TextBlock, ThinkingBlock, + ToolPermissionContext, ToolResultBlock, ToolUseBlock, UserMessage, @@ -286,6 +295,16 @@ __all__ = [ "ToolUseBlock", "ToolResultBlock", "ContentBlock", + # Tool callbacks + "CanUseTool", + "ToolPermissionContext", + "PermissionResult", + "PermissionResultAllow", + "PermissionResultDeny", + "PermissionUpdate", + "HookCallback", + "HookContext", + "HookMatcher", # MCP Server Support "create_sdk_mcp_server", "tool", diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index 3e99511..d38de1b 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -19,6 +19,22 @@ class InternalClient: def __init__(self) -> None: """Initialize the internal client.""" + def _convert_hooks_to_internal_format( + self, hooks: dict[str, list] + ) -> dict[str, list[dict[str, Any]]]: + """Convert HookMatcher format to internal Query format.""" + internal_hooks = {} + for event, matchers in hooks.items(): + internal_hooks[event] = [] + for matcher in matchers: + # Convert HookMatcher to internal dict format + internal_matcher = { + "matcher": matcher.matcher if hasattr(matcher, 'matcher') else None, + "hooks": matcher.hooks if hasattr(matcher, 'hooks') else [] + } + internal_hooks[event].append(internal_matcher) + return internal_hooks + async def process_query( self, prompt: str | AsyncIterable[dict[str, Any]], @@ -48,8 +64,8 @@ class InternalClient: query = Query( transport=chosen_transport, is_streaming_mode=is_streaming, - can_use_tool=None, # TODO: Add support for can_use_tool callback - hooks=None, # TODO: Add support for hooks + can_use_tool=options.can_use_tool, + hooks=self._convert_hooks_to_internal_format(options.hooks) if options.hooks else None, sdk_mcp_servers=sdk_mcp_servers, ) diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index cb9056c..9cbb409 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -15,10 +15,14 @@ from mcp.types import ( ) from ..types import ( + PermissionResult, + PermissionResultAllow, + PermissionResultDeny, SDKControlPermissionRequest, SDKControlRequest, SDKControlResponse, SDKHookCallbackRequest, + ToolPermissionContext, ) from .transport import Transport @@ -195,15 +199,34 @@ class Query: if not self.can_use_tool: raise Exception("canUseTool callback is not provided") - response_data = await self.can_use_tool( + context = ToolPermissionContext( + signal=None, # TODO: Add abort signal support + suggestions=permission_request.get("permission_suggestions", []) + ) + + response = await self.can_use_tool( permission_request["tool_name"], permission_request["input"], - { - "signal": None, # TODO: Add abort signal support - "suggestions": permission_request.get("permission_suggestions"), - }, + context ) + # Convert PermissionResult to expected dict format + if isinstance(response, PermissionResultAllow): + response_data = { + "allow": True + } + if response.updatedInput is not None: + response_data["input"] = response.updatedInput + # TODO: Handle updatedPermissions when control protocol supports it + elif isinstance(response, PermissionResultDeny): + response_data = { + "allow": False, + "reason": response.message + } + # TODO: Handle interrupt flag when control protocol supports it + else: + raise TypeError(f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got {type(response)}") + elif subtype == "hook_callback": hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment] # Handle hook callback diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 13ff4d2..9f65e18 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -100,6 +100,22 @@ class ClaudeSDKClient: self._query: Any | None = None os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client" + def _convert_hooks_to_internal_format( + self, hooks: dict[str, list] + ) -> dict[str, list[dict[str, Any]]]: + """Convert HookMatcher format to internal Query format.""" + internal_hooks = {} + for event, matchers in hooks.items(): + internal_hooks[event] = [] + for matcher in matchers: + # Convert HookMatcher to internal dict format + internal_matcher = { + "matcher": matcher.matcher if hasattr(matcher, 'matcher') else None, + "hooks": matcher.hooks if hasattr(matcher, 'hooks') else [] + } + internal_hooks[event].append(internal_matcher) + return internal_hooks + async def connect( self, prompt: str | AsyncIterable[dict[str, Any]] | None = None ) -> None: @@ -135,8 +151,8 @@ class ClaudeSDKClient: self._query = Query( transport=self._transport, is_streaming_mode=True, # ClaudeSDKClient always uses streaming mode - can_use_tool=None, # TODO: Add support for can_use_tool callback - hooks=None, # TODO: Add support for hooks + can_use_tool=self.options.can_use_tool, + hooks=self._convert_hooks_to_internal_format(self.options.hooks) if self.options.hooks else None, sdk_mcp_servers=sdk_mcp_servers, ) diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index 22966dd..ab7d8ad 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -1,10 +1,14 @@ """Type definitions for Claude SDK.""" +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypedDict -from typing_extensions import NotRequired # For Python < 3.11 compatibility +try: + from typing import NotRequired # Python 3.11+ +except ImportError: + from typing_extensions import NotRequired # For Python < 3.11 compatibility if TYPE_CHECKING: from mcp.server import Server as McpServer @@ -13,6 +17,87 @@ if TYPE_CHECKING: PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"] +# Permission Update types (matching TypeScript SDK) +PermissionUpdateDestination = Literal[ + "userSettings", + "projectSettings", + "localSettings", + "session" +] + +PermissionBehavior = Literal["allow", "deny", "ask"] + +@dataclass +class PermissionRuleValue: + """Permission rule value.""" + toolName: str + ruleContent: str | None = None + +@dataclass +class PermissionUpdate: + """Permission update configuration.""" + type: Literal["addRules", "replaceRules", "removeRules", "setMode", "addDirectories", "removeDirectories"] + rules: list[PermissionRuleValue] | None = None + behavior: PermissionBehavior | None = None + mode: PermissionMode | None = None + directories: list[str] | None = None + destination: PermissionUpdateDestination | None = None + +# Tool callback types +@dataclass +class ToolPermissionContext: + """Context information for tool permission callbacks.""" + + signal: Any | None = None # Future: abort signal support + suggestions: list[PermissionUpdate] = field(default_factory=list) # Permission suggestions from CLI + + +# Match TypeScript's PermissionResult structure +@dataclass +class PermissionResultAllow: + """Allow permission result.""" + behavior: Literal["allow"] = "allow" + updatedInput: dict[str, Any] | None = None + updatedPermissions: list[PermissionUpdate] | None = None + +@dataclass +class PermissionResultDeny: + """Deny permission result.""" + behavior: Literal["deny"] = "deny" + message: str = "" + interrupt: bool = False + +PermissionResult = PermissionResultAllow | PermissionResultDeny + +CanUseTool = Callable[ + [str, dict[str, Any], ToolPermissionContext], + Awaitable[PermissionResult] +] + + +# Hook callback types +@dataclass +class HookContext: + """Context information for hook callbacks.""" + + signal: Any | None = None # Future: abort signal support + + +HookCallback = Callable[ + [dict[str, Any], str | None, HookContext], # input, tool_use_id, context + Awaitable[dict[str, Any]] # response data +] + + +# Hook matcher configuration +@dataclass +class HookMatcher: + """Hook matcher configuration.""" + + matcher: dict[str, Any] | None = None # Matcher criteria + hooks: list[HookCallback] = field(default_factory=list) # Callbacks to invoke + + # MCP Server config class McpStdioServerConfig(TypedDict): """MCP stdio server configuration.""" @@ -155,6 +240,12 @@ class ClaudeCodeOptions: default_factory=dict ) # Pass arbitrary CLI flags + # Tool permission callback + can_use_tool: CanUseTool | None = None + + # Hook configurations + hooks: dict[str, list[HookMatcher]] | None = None + # SDK Control Protocol class SDKControlInterruptRequest(TypedDict): @@ -169,7 +260,6 @@ class SDKControlPermissionRequest(TypedDict): permission_suggestions: list[Any] | None blocked_path: str | None - class SDKControlInitializeRequest(TypedDict): subtype: Literal["initialize"] # TODO: Use HookEvent names as the key. diff --git a/tests/test_tool_callbacks.py b/tests/test_tool_callbacks.py new file mode 100644 index 0000000..26663c4 --- /dev/null +++ b/tests/test_tool_callbacks.py @@ -0,0 +1,316 @@ +"""Tests for tool permission callbacks and hook callbacks.""" + +import pytest + +from claude_code_sdk import ( + ClaudeCodeOptions, + HookContext, + HookMatcher, + PermissionResultAllow, + PermissionResultDeny, + ToolPermissionContext, +) +from claude_code_sdk._internal.query import Query +from claude_code_sdk._internal.transport import Transport + + +class MockTransport(Transport): + """Mock transport for testing.""" + + def __init__(self): + self.written_messages = [] + self.messages_to_read = [] + self._connected = False + + async def connect(self) -> None: + self._connected = True + + async def close(self) -> None: + self._connected = False + + async def write(self, data: str) -> None: + self.written_messages.append(data) + + async def end_input(self) -> None: + pass + + def read_messages(self): + async def _read(): + for msg in self.messages_to_read: + yield msg + return _read() + + def is_ready(self) -> bool: + return self._connected + + +class TestToolPermissionCallbacks: + """Test tool permission callback functionality.""" + + @pytest.mark.asyncio + async def test_permission_callback_allow(self): + """Test callback that allows tool execution.""" + callback_invoked = False + + async def allow_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext + ) -> PermissionResultAllow: + nonlocal callback_invoked + callback_invoked = True + assert tool_name == "TestTool" + assert input_data == {"param": "value"} + return PermissionResultAllow() + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=allow_callback, + hooks=None + ) + + # Simulate control request + request = { + "type": "control_request", + "request_id": "test-1", + "request": { + "subtype": "can_use_tool", + "tool_name": "TestTool", + "input": {"param": "value"}, + "permission_suggestions": [] + } + } + + await query._handle_control_request(request) + + # Check callback was invoked + assert callback_invoked + + # Check response was sent + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"allow": true' in response + + @pytest.mark.asyncio + async def test_permission_callback_deny(self): + """Test callback that denies tool execution.""" + async def deny_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext + ) -> PermissionResultDeny: + return PermissionResultDeny( + message="Security policy violation" + ) + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=deny_callback, + hooks=None + ) + + request = { + "type": "control_request", + "request_id": "test-2", + "request": { + "subtype": "can_use_tool", + "tool_name": "DangerousTool", + "input": {"command": "rm -rf /"}, + "permission_suggestions": ["deny"] + } + } + + await query._handle_control_request(request) + + # Check response + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"allow": false' in response + assert '"reason": "Security policy violation"' in response + + @pytest.mark.asyncio + async def test_permission_callback_input_modification(self): + """Test callback that modifies tool input.""" + async def modify_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext + ) -> PermissionResultAllow: + # Modify the input to add safety flag + modified_input = input_data.copy() + modified_input["safe_mode"] = True + return PermissionResultAllow( + updatedInput=modified_input + ) + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=modify_callback, + hooks=None + ) + + request = { + "type": "control_request", + "request_id": "test-3", + "request": { + "subtype": "can_use_tool", + "tool_name": "WriteTool", + "input": {"file_path": "/etc/passwd"}, + "permission_suggestions": [] + } + } + + await query._handle_control_request(request) + + # Check response includes modified input + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"allow": true' in response + assert '"safe_mode": true' in response + + @pytest.mark.asyncio + async def test_callback_exception_handling(self): + """Test that callback exceptions are properly handled.""" + async def error_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext + ) -> PermissionResultAllow: + raise ValueError("Callback error") + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=error_callback, + hooks=None + ) + + request = { + "type": "control_request", + "request_id": "test-5", + "request": { + "subtype": "can_use_tool", + "tool_name": "TestTool", + "input": {}, + "permission_suggestions": [] + } + } + + await query._handle_control_request(request) + + # Check error response was sent + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"subtype": "error"' in response + assert "Callback error" in response + + +class TestHookCallbacks: + """Test hook callback functionality.""" + + @pytest.mark.asyncio + async def test_hook_execution(self): + """Test that hooks are called at appropriate times.""" + hook_calls = [] + + async def test_hook( + input_data: dict, + tool_use_id: str | None, + context: HookContext + ) -> dict: + hook_calls.append({ + "input": input_data, + "tool_use_id": tool_use_id + }) + return {"processed": True} + + transport = MockTransport() + + # Create hooks configuration + hooks = { + "tool_use_start": [ + { + "matcher": {"tool": "TestTool"}, + "hooks": [test_hook] + } + ] + } + + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=None, + hooks=hooks + ) + + # Manually register the hook callback to avoid needing the full initialize flow + callback_id = "test_hook_0" + query.hook_callbacks[callback_id] = test_hook + + # Simulate hook callback request + request = { + "type": "control_request", + "request_id": "test-hook-1", + "request": { + "subtype": "hook_callback", + "callback_id": callback_id, + "input": {"test": "data"}, + "tool_use_id": "tool-123" + } + } + + await query._handle_control_request(request) + + # Check hook was called + assert len(hook_calls) == 1 + assert hook_calls[0]["input"] == {"test": "data"} + assert hook_calls[0]["tool_use_id"] == "tool-123" + + # Check response + assert len(transport.written_messages) > 0 + last_response = transport.written_messages[-1] + assert '"processed": true' in last_response + + +class TestClaudeCodeOptionsIntegration: + """Test that callbacks work through ClaudeCodeOptions.""" + + def test_options_with_callbacks(self): + """Test creating options with callbacks.""" + async def my_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext + ) -> PermissionResultAllow: + return PermissionResultAllow() + + async def my_hook( + input_data: dict, + tool_use_id: str | None, + context: HookContext + ) -> dict: + return {} + + options = ClaudeCodeOptions( + can_use_tool=my_callback, + hooks={ + "tool_use_start": [ + HookMatcher( + matcher={"tool": "Bash"}, + hooks=[my_hook] + ) + ] + } + ) + + assert options.can_use_tool == my_callback + assert "tool_use_start" in options.hooks + assert len(options.hooks["tool_use_start"]) == 1 + assert options.hooks["tool_use_start"][0].hooks[0] == my_hook