claude-code-sdk-python/tests/test_tool_callbacks.py
Claude 5df105d623
test: add tests for can_use_tool callback features (issue #159)
- Add test for PermissionResultAllow with updatedPermissions support
- Add test for PermissionResultDeny with interrupt flag (deny and stop)
- Add test for PermissionResultDeny without interrupt (deny and continue)
- Export PermissionRuleValue type for creating permission updates

These tests verify the can_use_tool callback correctly handles:
1. The behavior/updatedInput field names (not allow/input)
2. The updatedPermissions field for "Always Allow" functionality
3. The interrupt flag for stopping vs continuing after denial

Fixes #159
2025-12-17 12:28:34 +00:00

640 lines
22 KiB
Python

"""Tests for tool permission callbacks and hook callbacks."""
import json
import pytest
from claude_agent_sdk import (
ClaudeAgentOptions,
HookContext,
HookInput,
HookJSONOutput,
HookMatcher,
PermissionResultAllow,
PermissionResultDeny,
PermissionRuleValue,
PermissionUpdate,
ToolPermissionContext,
)
from claude_agent_sdk._internal.query import Query
from claude_agent_sdk._internal.transport import Transport
class MockTransport(Transport):
"""Mock transport for testing."""
def __init__(self):
self.written_messages = []
self.messages_to_read = []
self._connected = False
async def connect(self) -> None:
self._connected = True
async def close(self) -> None:
self._connected = False
async def write(self, data: str) -> None:
self.written_messages.append(data)
async def end_input(self) -> None:
pass
def read_messages(self):
async def _read():
for msg in self.messages_to_read:
yield msg
return _read()
def is_ready(self) -> bool:
return self._connected
class TestToolPermissionCallbacks:
"""Test tool permission callback functionality."""
@pytest.mark.asyncio
async def test_permission_callback_allow(self):
"""Test callback that allows tool execution."""
callback_invoked = False
async def allow_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultAllow:
nonlocal callback_invoked
callback_invoked = True
assert tool_name == "TestTool"
assert input_data == {"param": "value"}
return PermissionResultAllow()
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=allow_callback,
hooks=None,
)
# Simulate control request
request = {
"type": "control_request",
"request_id": "test-1",
"request": {
"subtype": "can_use_tool",
"tool_name": "TestTool",
"input": {"param": "value"},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check callback was invoked
assert callback_invoked
# Check response was sent
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"behavior": "allow"' in response
@pytest.mark.asyncio
async def test_permission_callback_deny(self):
"""Test callback that denies tool execution."""
async def deny_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultDeny:
return PermissionResultDeny(message="Security policy violation")
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=deny_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-2",
"request": {
"subtype": "can_use_tool",
"tool_name": "DangerousTool",
"input": {"command": "rm -rf /"},
"permission_suggestions": ["deny"],
},
}
await query._handle_control_request(request)
# Check response
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"behavior": "deny"' in response
assert '"message": "Security policy violation"' in response
@pytest.mark.asyncio
async def test_permission_callback_input_modification(self):
"""Test callback that modifies tool input."""
async def modify_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultAllow:
# Modify the input to add safety flag
modified_input = input_data.copy()
modified_input["safe_mode"] = True
return PermissionResultAllow(updated_input=modified_input)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=modify_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-3",
"request": {
"subtype": "can_use_tool",
"tool_name": "WriteTool",
"input": {"file_path": "/etc/passwd"},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check response includes modified input
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"behavior": "allow"' in response
assert '"safe_mode": true' in response
@pytest.mark.asyncio
async def test_callback_exception_handling(self):
"""Test that callback exceptions are properly handled."""
async def error_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultAllow:
raise ValueError("Callback error")
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=error_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-5",
"request": {
"subtype": "can_use_tool",
"tool_name": "TestTool",
"input": {},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check error response was sent
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
assert '"subtype": "error"' in response
assert "Callback error" in response
@pytest.mark.asyncio
async def test_permission_callback_with_updated_permissions(self):
"""Test callback that returns allow with updated permissions (Always Allow)."""
async def allow_with_permissions_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultAllow:
# Return allow with permission updates for "Always Allow" functionality
return PermissionResultAllow(
updated_permissions=[
PermissionUpdate(
type="addRules",
behavior="allow",
rules=[
PermissionRuleValue(tool_name="Bash", rule_content=None)
],
destination="session",
)
]
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=allow_with_permissions_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-4",
"request": {
"subtype": "can_use_tool",
"tool_name": "Bash",
"input": {"command": "ls -la"},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check response includes updatedPermissions
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
response_data = json.loads(response)
# Get the nested response data
result = response_data["response"]["response"]
assert result.get("behavior") == "allow"
assert "updatedPermissions" in result
assert len(result["updatedPermissions"]) == 1
assert result["updatedPermissions"][0]["type"] == "addRules"
assert result["updatedPermissions"][0]["behavior"] == "allow"
assert result["updatedPermissions"][0]["destination"] == "session"
@pytest.mark.asyncio
async def test_permission_callback_deny_with_interrupt(self):
"""Test callback that denies with interrupt flag to stop execution."""
async def deny_with_interrupt_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultDeny:
# Deny and interrupt - stop the agent completely
return PermissionResultDeny(
message="Critical security violation - stopping agent",
interrupt=True,
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=deny_with_interrupt_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-5-interrupt",
"request": {
"subtype": "can_use_tool",
"tool_name": "DangerousTool",
"input": {"command": "rm -rf /"},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check response includes interrupt flag
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
response_data = json.loads(response)
# Get the nested response data
result = response_data["response"]["response"]
assert result.get("behavior") == "deny"
assert result.get("message") == "Critical security violation - stopping agent"
assert result.get("interrupt") is True
@pytest.mark.asyncio
async def test_permission_callback_deny_without_interrupt(self):
"""Test callback that denies without interrupt (deny and continue)."""
async def deny_without_interrupt_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultDeny:
# Deny but don't interrupt - let the agent try a different approach
return PermissionResultDeny(
message="Tool not allowed, try a different approach",
interrupt=False,
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=deny_without_interrupt_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-6-no-interrupt",
"request": {
"subtype": "can_use_tool",
"tool_name": "SomeTool",
"input": {},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check response does NOT include interrupt flag when False
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
response_data = json.loads(response)
# Get the nested response data
result = response_data["response"]["response"]
assert result.get("behavior") == "deny"
assert result.get("message") == "Tool not allowed, try a different approach"
# interrupt should not be present when False
assert "interrupt" not in result
class TestHookCallbacks:
"""Test hook callback functionality."""
@pytest.mark.asyncio
async def test_hook_execution(self):
"""Test that hooks are called at appropriate times."""
hook_calls = []
async def test_hook(
input_data: HookInput, tool_use_id: str | None, context: HookContext
) -> dict:
hook_calls.append({"input": input_data, "tool_use_id": tool_use_id})
return {"processed": True}
transport = MockTransport()
# Create hooks configuration
hooks = {
"tool_use_start": [{"matcher": {"tool": "TestTool"}, "hooks": [test_hook]}]
}
query = Query(
transport=transport, is_streaming_mode=True, can_use_tool=None, hooks=hooks
)
# Manually register the hook callback to avoid needing the full initialize flow
callback_id = "test_hook_0"
query.hook_callbacks[callback_id] = test_hook
# Simulate hook callback request
request = {
"type": "control_request",
"request_id": "test-hook-1",
"request": {
"subtype": "hook_callback",
"callback_id": callback_id,
"input": {"test": "data"},
"tool_use_id": "tool-123",
},
}
await query._handle_control_request(request)
# Check hook was called
assert len(hook_calls) == 1
assert hook_calls[0]["input"] == {"test": "data"}
assert hook_calls[0]["tool_use_id"] == "tool-123"
# Check response
assert len(transport.written_messages) > 0
last_response = transport.written_messages[-1]
assert '"processed": true' in last_response
@pytest.mark.asyncio
async def test_hook_output_fields(self):
"""Test that all SyncHookJSONOutput fields are properly handled."""
# Test all SyncHookJSONOutput fields together
async def comprehensive_hook(
input_data: HookInput, tool_use_id: str | None, context: HookContext
) -> HookJSONOutput:
return {
# Control fields
"continue_": True,
"suppressOutput": False,
"stopReason": "Test stop reason",
# Decision fields
"decision": "block",
"systemMessage": "Test system message",
"reason": "Test reason for blocking",
# Hook-specific output with all PreToolUse fields
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Security policy violation",
"updatedInput": {"modified": "input"},
},
}
transport = MockTransport()
hooks = {
"PreToolUse": [
{"matcher": {"tool": "TestTool"}, "hooks": [comprehensive_hook]}
]
}
query = Query(
transport=transport, is_streaming_mode=True, can_use_tool=None, hooks=hooks
)
callback_id = "test_comprehensive_hook"
query.hook_callbacks[callback_id] = comprehensive_hook
request = {
"type": "control_request",
"request_id": "test-comprehensive",
"request": {
"subtype": "hook_callback",
"callback_id": callback_id,
"input": {"test": "data"},
"tool_use_id": "tool-456",
},
}
await query._handle_control_request(request)
# Check response contains all the fields
assert len(transport.written_messages) > 0
last_response = transport.written_messages[-1]
# Parse the JSON response
response_data = json.loads(last_response)
# The hook result is nested at response.response
result = response_data["response"]["response"]
# Verify control fields are present and converted to CLI format
assert result.get("continue") is True, (
"continue_ should be converted to continue"
)
assert "continue_" not in result, "continue_ should not appear in CLI output"
assert result.get("suppressOutput") is False
assert result.get("stopReason") == "Test stop reason"
# Verify decision fields are present
assert result.get("decision") == "block"
assert result.get("reason") == "Test reason for blocking"
assert result.get("systemMessage") == "Test system message"
# Verify hook-specific output is present
hook_output = result.get("hookSpecificOutput", {})
assert hook_output.get("hookEventName") == "PreToolUse"
assert hook_output.get("permissionDecision") == "deny"
assert (
hook_output.get("permissionDecisionReason") == "Security policy violation"
)
assert "updatedInput" in hook_output
@pytest.mark.asyncio
async def test_async_hook_output(self):
"""Test AsyncHookJSONOutput type with proper async fields."""
async def async_hook(
input_data: HookInput, tool_use_id: str | None, context: HookContext
) -> HookJSONOutput:
# Test that async hooks properly use async_ and asyncTimeout fields
return {
"async_": True,
"asyncTimeout": 5000,
}
transport = MockTransport()
hooks = {"PreToolUse": [{"matcher": None, "hooks": [async_hook]}]}
query = Query(
transport=transport, is_streaming_mode=True, can_use_tool=None, hooks=hooks
)
callback_id = "test_async_hook"
query.hook_callbacks[callback_id] = async_hook
request = {
"type": "control_request",
"request_id": "test-async",
"request": {
"subtype": "hook_callback",
"callback_id": callback_id,
"input": {"test": "async_data"},
"tool_use_id": None,
},
}
await query._handle_control_request(request)
# Check response contains async fields
assert len(transport.written_messages) > 0
last_response = transport.written_messages[-1]
# Parse the JSON response
response_data = json.loads(last_response)
# The hook result is nested at response.response
result = response_data["response"]["response"]
# The SDK should convert async_ to "async" for CLI compatibility
assert result.get("async") is True, "async_ should be converted to async"
assert "async_" not in result, "async_ should not appear in CLI output"
assert result.get("asyncTimeout") == 5000
@pytest.mark.asyncio
async def test_field_name_conversion(self):
"""Test that Python-safe field names (async_, continue_) are converted to CLI format (async, continue)."""
async def conversion_test_hook(
input_data: HookInput, tool_use_id: str | None, context: HookContext
) -> HookJSONOutput:
# Return both async_ and continue_ to test conversion
return {
"async_": True,
"asyncTimeout": 10000,
"continue_": False,
"stopReason": "Testing field conversion",
"systemMessage": "Fields should be converted",
}
transport = MockTransport()
hooks = {"PreToolUse": [{"matcher": None, "hooks": [conversion_test_hook]}]}
query = Query(
transport=transport, is_streaming_mode=True, can_use_tool=None, hooks=hooks
)
callback_id = "test_conversion"
query.hook_callbacks[callback_id] = conversion_test_hook
request = {
"type": "control_request",
"request_id": "test-conversion",
"request": {
"subtype": "hook_callback",
"callback_id": callback_id,
"input": {"test": "data"},
"tool_use_id": None,
},
}
await query._handle_control_request(request)
# Check response has converted field names
assert len(transport.written_messages) > 0
last_response = transport.written_messages[-1]
response_data = json.loads(last_response)
result = response_data["response"]["response"]
# Verify async_ was converted to async
assert result.get("async") is True, "async_ should be converted to async"
assert "async_" not in result, "async_ should not appear in output"
# Verify continue_ was converted to continue
assert result.get("continue") is False, (
"continue_ should be converted to continue"
)
assert "continue_" not in result, "continue_ should not appear in output"
# Verify other fields are unchanged
assert result.get("asyncTimeout") == 10000
assert result.get("stopReason") == "Testing field conversion"
assert result.get("systemMessage") == "Fields should be converted"
class TestClaudeAgentOptionsIntegration:
"""Test that callbacks work through ClaudeAgentOptions."""
def test_options_with_callbacks(self):
"""Test creating options with callbacks."""
async def my_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultAllow:
return PermissionResultAllow()
async def my_hook(
input_data: HookInput, tool_use_id: str | None, context: HookContext
) -> dict:
return {}
options = ClaudeAgentOptions(
can_use_tool=my_callback,
hooks={
"tool_use_start": [
HookMatcher(matcher={"tool": "Bash"}, hooks=[my_hook])
]
},
)
assert options.can_use_tool == my_callback
assert "tool_use_start" in options.hooks
assert len(options.hooks["tool_use_start"]) == 1
assert options.hooks["tool_use_start"][0].hooks[0] == my_hook