mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
test: add tests for can_use_tool callback features (issue #159)
- Add test for PermissionResultAllow with updatedPermissions support - Add test for PermissionResultDeny with interrupt flag (deny and stop) - Add test for PermissionResultDeny without interrupt (deny and continue) - Export PermissionRuleValue type for creating permission updates These tests verify the can_use_tool callback correctly handles: 1. The behavior/updatedInput field names (not allow/input) 2. The updatedPermissions field for "Always Allow" functionality 3. The interrupt flag for stopping vs continuing after denial Fixes #159
This commit is contained in:
parent
27575ae2ca
commit
5df105d623
2 changed files with 154 additions and 0 deletions
|
|
@ -34,6 +34,7 @@ from .types import (
|
|||
PermissionResult,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
PermissionRuleValue,
|
||||
PermissionUpdate,
|
||||
PostToolUseHookInput,
|
||||
PreCompactHookInput,
|
||||
|
|
@ -327,6 +328,7 @@ __all__ = [
|
|||
"PermissionResult",
|
||||
"PermissionResultAllow",
|
||||
"PermissionResultDeny",
|
||||
"PermissionRuleValue",
|
||||
"PermissionUpdate",
|
||||
# Hook support
|
||||
"HookCallback",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from claude_agent_sdk import (
|
|||
HookMatcher,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
PermissionRuleValue,
|
||||
PermissionUpdate,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
from claude_agent_sdk._internal.query import Query
|
||||
|
|
@ -207,6 +209,156 @@ class TestToolPermissionCallbacks:
|
|||
assert '"subtype": "error"' in response
|
||||
assert "Callback error" in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_with_updated_permissions(self):
|
||||
"""Test callback that returns allow with updated permissions (Always Allow)."""
|
||||
|
||||
async def allow_with_permissions_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultAllow:
|
||||
# Return allow with permission updates for "Always Allow" functionality
|
||||
return PermissionResultAllow(
|
||||
updated_permissions=[
|
||||
PermissionUpdate(
|
||||
type="addRules",
|
||||
behavior="allow",
|
||||
rules=[
|
||||
PermissionRuleValue(tool_name="Bash", rule_content=None)
|
||||
],
|
||||
destination="session",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=allow_with_permissions_callback,
|
||||
hooks=None,
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-4",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "Bash",
|
||||
"input": {"command": "ls -la"},
|
||||
"permission_suggestions": [],
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response includes updatedPermissions
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
response_data = json.loads(response)
|
||||
|
||||
# Get the nested response data
|
||||
result = response_data["response"]["response"]
|
||||
|
||||
assert result.get("behavior") == "allow"
|
||||
assert "updatedPermissions" in result
|
||||
assert len(result["updatedPermissions"]) == 1
|
||||
assert result["updatedPermissions"][0]["type"] == "addRules"
|
||||
assert result["updatedPermissions"][0]["behavior"] == "allow"
|
||||
assert result["updatedPermissions"][0]["destination"] == "session"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_deny_with_interrupt(self):
|
||||
"""Test callback that denies with interrupt flag to stop execution."""
|
||||
|
||||
async def deny_with_interrupt_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultDeny:
|
||||
# Deny and interrupt - stop the agent completely
|
||||
return PermissionResultDeny(
|
||||
message="Critical security violation - stopping agent",
|
||||
interrupt=True,
|
||||
)
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=deny_with_interrupt_callback,
|
||||
hooks=None,
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-5-interrupt",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "DangerousTool",
|
||||
"input": {"command": "rm -rf /"},
|
||||
"permission_suggestions": [],
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response includes interrupt flag
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
response_data = json.loads(response)
|
||||
|
||||
# Get the nested response data
|
||||
result = response_data["response"]["response"]
|
||||
|
||||
assert result.get("behavior") == "deny"
|
||||
assert result.get("message") == "Critical security violation - stopping agent"
|
||||
assert result.get("interrupt") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_deny_without_interrupt(self):
|
||||
"""Test callback that denies without interrupt (deny and continue)."""
|
||||
|
||||
async def deny_without_interrupt_callback(
|
||||
tool_name: str, input_data: dict, context: ToolPermissionContext
|
||||
) -> PermissionResultDeny:
|
||||
# Deny but don't interrupt - let the agent try a different approach
|
||||
return PermissionResultDeny(
|
||||
message="Tool not allowed, try a different approach",
|
||||
interrupt=False,
|
||||
)
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=deny_without_interrupt_callback,
|
||||
hooks=None,
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-6-no-interrupt",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "SomeTool",
|
||||
"input": {},
|
||||
"permission_suggestions": [],
|
||||
},
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response does NOT include interrupt flag when False
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
response_data = json.loads(response)
|
||||
|
||||
# Get the nested response data
|
||||
result = response_data["response"]["response"]
|
||||
|
||||
assert result.get("behavior") == "deny"
|
||||
assert result.get("message") == "Tool not allowed, try a different approach"
|
||||
# interrupt should not be present when False
|
||||
assert "interrupt" not in result
|
||||
|
||||
|
||||
class TestHookCallbacks:
|
||||
"""Test hook callback functionality."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue