From 5df105d623c59b8f609a665ea3840de7edc1deab Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 17 Dec 2025 12:28:34 +0000 Subject: [PATCH] test: add tests for can_use_tool callback features (issue #159) - Add test for PermissionResultAllow with updatedPermissions support - Add test for PermissionResultDeny with interrupt flag (deny and stop) - Add test for PermissionResultDeny without interrupt (deny and continue) - Export PermissionRuleValue type for creating permission updates These tests verify the can_use_tool callback correctly handles: 1. The behavior/updatedInput field names (not allow/input) 2. The updatedPermissions field for "Always Allow" functionality 3. The interrupt flag for stopping vs continuing after denial Fixes #159 --- src/claude_agent_sdk/__init__.py | 2 + tests/test_tool_callbacks.py | 152 +++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 4898bc0..e53a5a3 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -34,6 +34,7 @@ from .types import ( PermissionResult, PermissionResultAllow, PermissionResultDeny, + PermissionRuleValue, PermissionUpdate, PostToolUseHookInput, PreCompactHookInput, @@ -327,6 +328,7 @@ __all__ = [ "PermissionResult", "PermissionResultAllow", "PermissionResultDeny", + "PermissionRuleValue", "PermissionUpdate", # Hook support "HookCallback", diff --git a/tests/test_tool_callbacks.py b/tests/test_tool_callbacks.py index 8ace3c8..b78aa06 100644 --- a/tests/test_tool_callbacks.py +++ b/tests/test_tool_callbacks.py @@ -12,6 +12,8 @@ from claude_agent_sdk import ( HookMatcher, PermissionResultAllow, PermissionResultDeny, + PermissionRuleValue, + PermissionUpdate, ToolPermissionContext, ) from claude_agent_sdk._internal.query import Query @@ -207,6 +209,156 @@ class TestToolPermissionCallbacks: assert '"subtype": "error"' in response assert "Callback error" in response + @pytest.mark.asyncio + async def test_permission_callback_with_updated_permissions(self): + """Test callback that returns allow with updated permissions (Always Allow).""" + + async def allow_with_permissions_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultAllow: + # Return allow with permission updates for "Always Allow" functionality + return PermissionResultAllow( + updated_permissions=[ + PermissionUpdate( + type="addRules", + behavior="allow", + rules=[ + PermissionRuleValue(tool_name="Bash", rule_content=None) + ], + destination="session", + ) + ] + ) + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=allow_with_permissions_callback, + hooks=None, + ) + + request = { + "type": "control_request", + "request_id": "test-4", + "request": { + "subtype": "can_use_tool", + "tool_name": "Bash", + "input": {"command": "ls -la"}, + "permission_suggestions": [], + }, + } + + await query._handle_control_request(request) + + # Check response includes updatedPermissions + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + response_data = json.loads(response) + + # Get the nested response data + result = response_data["response"]["response"] + + assert result.get("behavior") == "allow" + assert "updatedPermissions" in result + assert len(result["updatedPermissions"]) == 1 + assert result["updatedPermissions"][0]["type"] == "addRules" + assert result["updatedPermissions"][0]["behavior"] == "allow" + assert result["updatedPermissions"][0]["destination"] == "session" + + @pytest.mark.asyncio + async def test_permission_callback_deny_with_interrupt(self): + """Test callback that denies with interrupt flag to stop execution.""" + + async def deny_with_interrupt_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultDeny: + # Deny and interrupt - stop the agent completely + return PermissionResultDeny( + message="Critical security violation - stopping agent", + interrupt=True, + ) + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=deny_with_interrupt_callback, + hooks=None, + ) + + request = { + "type": "control_request", + "request_id": "test-5-interrupt", + "request": { + "subtype": "can_use_tool", + "tool_name": "DangerousTool", + "input": {"command": "rm -rf /"}, + "permission_suggestions": [], + }, + } + + await query._handle_control_request(request) + + # Check response includes interrupt flag + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + response_data = json.loads(response) + + # Get the nested response data + result = response_data["response"]["response"] + + assert result.get("behavior") == "deny" + assert result.get("message") == "Critical security violation - stopping agent" + assert result.get("interrupt") is True + + @pytest.mark.asyncio + async def test_permission_callback_deny_without_interrupt(self): + """Test callback that denies without interrupt (deny and continue).""" + + async def deny_without_interrupt_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultDeny: + # Deny but don't interrupt - let the agent try a different approach + return PermissionResultDeny( + message="Tool not allowed, try a different approach", + interrupt=False, + ) + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=deny_without_interrupt_callback, + hooks=None, + ) + + request = { + "type": "control_request", + "request_id": "test-6-no-interrupt", + "request": { + "subtype": "can_use_tool", + "tool_name": "SomeTool", + "input": {}, + "permission_suggestions": [], + }, + } + + await query._handle_control_request(request) + + # Check response does NOT include interrupt flag when False + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + response_data = json.loads(response) + + # Get the nested response data + result = response_data["response"]["response"] + + assert result.get("behavior") == "deny" + assert result.get("message") == "Tool not allowed, try a different approach" + # interrupt should not be present when False + assert "interrupt" not in result + class TestHookCallbacks: """Test hook callback functionality."""