test: add tests for can_use_tool callback features (issue #159)

- Add test for PermissionResultAllow with updatedPermissions support
- Add test for PermissionResultDeny with interrupt flag (deny and stop)
- Add test for PermissionResultDeny without interrupt (deny and continue)
- Export PermissionRuleValue type for creating permission updates

These tests verify the can_use_tool callback correctly handles:
1. The behavior/updatedInput field names (not allow/input)
2. The updatedPermissions field for "Always Allow" functionality
3. The interrupt flag for stopping vs continuing after denial

Fixes #159
This commit is contained in:
Claude 2025-12-17 12:28:34 +00:00
parent 27575ae2ca
commit 5df105d623
No known key found for this signature in database
2 changed files with 154 additions and 0 deletions

View file

@ -34,6 +34,7 @@ from .types import (
PermissionResult,
PermissionResultAllow,
PermissionResultDeny,
PermissionRuleValue,
PermissionUpdate,
PostToolUseHookInput,
PreCompactHookInput,
@ -327,6 +328,7 @@ __all__ = [
"PermissionResult",
"PermissionResultAllow",
"PermissionResultDeny",
"PermissionRuleValue",
"PermissionUpdate",
# Hook support
"HookCallback",

View file

@ -12,6 +12,8 @@ from claude_agent_sdk import (
HookMatcher,
PermissionResultAllow,
PermissionResultDeny,
PermissionRuleValue,
PermissionUpdate,
ToolPermissionContext,
)
from claude_agent_sdk._internal.query import Query
@ -207,6 +209,156 @@ class TestToolPermissionCallbacks:
assert '"subtype": "error"' in response
assert "Callback error" in response
@pytest.mark.asyncio
async def test_permission_callback_with_updated_permissions(self):
"""Test callback that returns allow with updated permissions (Always Allow)."""
async def allow_with_permissions_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultAllow:
# Return allow with permission updates for "Always Allow" functionality
return PermissionResultAllow(
updated_permissions=[
PermissionUpdate(
type="addRules",
behavior="allow",
rules=[
PermissionRuleValue(tool_name="Bash", rule_content=None)
],
destination="session",
)
]
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=allow_with_permissions_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-4",
"request": {
"subtype": "can_use_tool",
"tool_name": "Bash",
"input": {"command": "ls -la"},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check response includes updatedPermissions
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
response_data = json.loads(response)
# Get the nested response data
result = response_data["response"]["response"]
assert result.get("behavior") == "allow"
assert "updatedPermissions" in result
assert len(result["updatedPermissions"]) == 1
assert result["updatedPermissions"][0]["type"] == "addRules"
assert result["updatedPermissions"][0]["behavior"] == "allow"
assert result["updatedPermissions"][0]["destination"] == "session"
@pytest.mark.asyncio
async def test_permission_callback_deny_with_interrupt(self):
"""Test callback that denies with interrupt flag to stop execution."""
async def deny_with_interrupt_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultDeny:
# Deny and interrupt - stop the agent completely
return PermissionResultDeny(
message="Critical security violation - stopping agent",
interrupt=True,
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=deny_with_interrupt_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-5-interrupt",
"request": {
"subtype": "can_use_tool",
"tool_name": "DangerousTool",
"input": {"command": "rm -rf /"},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check response includes interrupt flag
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
response_data = json.loads(response)
# Get the nested response data
result = response_data["response"]["response"]
assert result.get("behavior") == "deny"
assert result.get("message") == "Critical security violation - stopping agent"
assert result.get("interrupt") is True
@pytest.mark.asyncio
async def test_permission_callback_deny_without_interrupt(self):
"""Test callback that denies without interrupt (deny and continue)."""
async def deny_without_interrupt_callback(
tool_name: str, input_data: dict, context: ToolPermissionContext
) -> PermissionResultDeny:
# Deny but don't interrupt - let the agent try a different approach
return PermissionResultDeny(
message="Tool not allowed, try a different approach",
interrupt=False,
)
transport = MockTransport()
query = Query(
transport=transport,
is_streaming_mode=True,
can_use_tool=deny_without_interrupt_callback,
hooks=None,
)
request = {
"type": "control_request",
"request_id": "test-6-no-interrupt",
"request": {
"subtype": "can_use_tool",
"tool_name": "SomeTool",
"input": {},
"permission_suggestions": [],
},
}
await query._handle_control_request(request)
# Check response does NOT include interrupt flag when False
assert len(transport.written_messages) == 1
response = transport.written_messages[0]
response_data = json.loads(response)
# Get the nested response data
result = response_data["response"]["response"]
assert result.get("behavior") == "deny"
assert result.get("message") == "Tool not allowed, try a different approach"
# interrupt should not be present when False
assert "interrupt" not in result
class TestHookCallbacks:
"""Test hook callback functionality."""