diff --git a/e2e-tests/test_hooks.py b/e2e-tests/test_hooks.py index a01b65c..fda60e9 100644 --- a/e2e-tests/test_hooks.py +++ b/e2e-tests/test_hooks.py @@ -6,6 +6,7 @@ from claude_agent_sdk import ( ClaudeAgentOptions, ClaudeSDKClient, HookContext, + HookInput, HookJSONOutput, HookMatcher, ) @@ -18,7 +19,7 @@ async def test_hook_with_permission_decision_and_reason(): hook_invocations = [] async def test_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """Hook that uses permissionDecision and reason fields.""" tool_name = input_data.get("tool_name", "") @@ -73,7 +74,7 @@ async def test_hook_with_continue_and_stop_reason(): hook_invocations = [] async def post_tool_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """PostToolUse hook that stops execution with stopReason.""" tool_name = input_data.get("tool_name", "") @@ -114,7 +115,7 @@ async def test_hook_with_additional_context(): hook_invocations = [] async def context_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """Hook that provides additional context.""" hook_invocations.append("context_added") diff --git a/examples/hooks.py b/examples/hooks.py index e533ac7..a8001d4 100644 --- a/examples/hooks.py +++ b/examples/hooks.py @@ -19,6 +19,7 @@ from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient from claude_agent_sdk.types import ( AssistantMessage, HookContext, + HookInput, HookJSONOutput, HookMatcher, Message, @@ -43,7 +44,7 @@ def display_message(msg: Message) -> None: ##### Hook callback functions async def check_bash_command( - input_data: dict[str, Any], tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """Prevent certain bash commands from being executed.""" tool_name = input_data["tool_name"] @@ -70,7 +71,7 @@ async def check_bash_command( async def add_custom_instructions( - input_data: dict[str, Any], tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """Add custom instructions when a session starts.""" return { @@ -82,7 +83,7 @@ async def add_custom_instructions( async def review_tool_output( - input_data: dict[str, Any], tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """Review tool output and provide additional context or warnings.""" tool_response = input_data.get("tool_response", "") @@ -102,7 +103,7 @@ async def review_tool_output( async def strict_approval_hook( - input_data: dict[str, Any], tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """Demonstrates using permissionDecision to control tool execution.""" tool_name = input_data.get("tool_name") @@ -135,7 +136,7 @@ async def strict_approval_hook( async def stop_on_error_hook( - input_data: dict[str, Any], tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: """Demonstrates using continue=False to stop execution on certain conditions.""" tool_response = input_data.get("tool_response", "") diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 45eccec..6c44747 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -18,11 +18,13 @@ from .query import query from .types import ( AgentDefinition, AssistantMessage, + BaseHookInput, CanUseTool, ClaudeAgentOptions, ContentBlock, HookCallback, HookContext, + HookInput, HookJSONOutput, HookMatcher, McpSdkServerConfig, @@ -33,8 +35,13 @@ from .types import ( PermissionResultAllow, PermissionResultDeny, PermissionUpdate, + PostToolUseHookInput, + PreCompactHookInput, + PreToolUseHookInput, ResultMessage, SettingSource, + StopHookInput, + SubagentStopHookInput, SystemMessage, TextBlock, ThinkingBlock, @@ -42,6 +49,7 @@ from .types import ( ToolResultBlock, ToolUseBlock, UserMessage, + UserPromptSubmitHookInput, ) # MCP Server Support @@ -307,8 +315,17 @@ __all__ = [ "PermissionResultAllow", "PermissionResultDeny", "PermissionUpdate", + # Hook support "HookCallback", "HookContext", + "HookInput", + "BaseHookInput", + "PreToolUseHookInput", + "PostToolUseHookInput", + "UserPromptSubmitHookInput", + "StopHookInput", + "SubagentStopHookInput", + "PreCompactHookInput", "HookJSONOutput", "HookMatcher", # Agent support diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index 3095dfd..82a57ad 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -157,6 +157,73 @@ HookEvent = ( ) +# Hook input types - strongly typed for each hook event +class BaseHookInput(TypedDict): + """Base hook input fields present across many hook events.""" + + session_id: str + transcript_path: str + cwd: str + permission_mode: NotRequired[str] + + +class PreToolUseHookInput(BaseHookInput): + """Input data for PreToolUse hook events.""" + + hook_event_name: Literal["PreToolUse"] + tool_name: str + tool_input: dict[str, Any] + + +class PostToolUseHookInput(BaseHookInput): + """Input data for PostToolUse hook events.""" + + hook_event_name: Literal["PostToolUse"] + tool_name: str + tool_input: dict[str, Any] + tool_response: Any + + +class UserPromptSubmitHookInput(BaseHookInput): + """Input data for UserPromptSubmit hook events.""" + + hook_event_name: Literal["UserPromptSubmit"] + prompt: str + + +class StopHookInput(BaseHookInput): + """Input data for Stop hook events.""" + + hook_event_name: Literal["Stop"] + stop_hook_active: bool + + +class SubagentStopHookInput(BaseHookInput): + """Input data for SubagentStop hook events.""" + + hook_event_name: Literal["SubagentStop"] + stop_hook_active: bool + + +class PreCompactHookInput(BaseHookInput): + """Input data for PreCompact hook events.""" + + hook_event_name: Literal["PreCompact"] + trigger: Literal["manual", "auto"] + custom_instructions: str | None + + +# Union type for all hook inputs +HookInput = ( + PreToolUseHookInput + | PostToolUseHookInput + | UserPromptSubmitHookInput + | StopHookInput + | SubagentStopHookInput + | PreCompactHookInput +) + + # Hook-specific output types class PreToolUseHookSpecificOutput(TypedDict): """Hook-specific output for PreToolUse events.""" @@ -265,21 +332,22 @@ class SyncHookJSONOutput(TypedDict): HookJSONOutput = AsyncHookJSONOutput | SyncHookJSONOutput -@dataclass -class HookContext: - """Context information for hook callbacks.""" +class HookContext(TypedDict): + """Context information for hook callbacks. - signal: Any | None = None # Future: abort signal support + Fields: + signal: Reserved for future abort signal support. Currently always None. + """ + + signal: Any | None # Future: abort signal support HookCallback = Callable[ # HookCallback input parameters: - # - input - # See https://docs.anthropic.com/en/docs/claude-code/hooks#hook-input for - # the type of 'input', the first value. - # - tool_use_id - # - context - [dict[str, Any], str | None, HookContext], + # - input: Strongly-typed hook input with discriminated unions based on hook_event_name + # - tool_use_id: Optional tool use identifier + # - context: Hook context with abort signal support (currently placeholder) + [HookInput, str | None, HookContext], Awaitable[HookJSONOutput], ] diff --git a/tests/test_tool_callbacks.py b/tests/test_tool_callbacks.py index 4987ede..8ace3c8 100644 --- a/tests/test_tool_callbacks.py +++ b/tests/test_tool_callbacks.py @@ -7,6 +7,7 @@ import pytest from claude_agent_sdk import ( ClaudeAgentOptions, HookContext, + HookInput, HookJSONOutput, HookMatcher, PermissionResultAllow, @@ -216,7 +217,7 @@ class TestHookCallbacks: hook_calls = [] async def test_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> dict: hook_calls.append({"input": input_data, "tool_use_id": tool_use_id}) return {"processed": True} @@ -266,7 +267,7 @@ class TestHookCallbacks: # Test all SyncHookJSONOutput fields together async def comprehensive_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: return { # Control fields @@ -349,7 +350,7 @@ class TestHookCallbacks: """Test AsyncHookJSONOutput type with proper async fields.""" async def async_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: # Test that async hooks properly use async_ and asyncTimeout fields return { @@ -399,7 +400,7 @@ class TestHookCallbacks: """Test that Python-safe field names (async_, continue_) are converted to CLI format (async, continue).""" async def conversion_test_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> HookJSONOutput: # Return both async_ and continue_ to test conversion return { @@ -468,7 +469,7 @@ class TestClaudeAgentOptionsIntegration: return PermissionResultAllow() async def my_hook( - input_data: dict, tool_use_id: str | None, context: HookContext + input_data: HookInput, tool_use_id: str | None, context: HookContext ) -> dict: return {}