mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Address review comments
- Move ToolPermissionContext/Response imports to top of query.py (no circular dependency) - Remove dict compatibility, require ToolPermissionResponse only - Update type hints to enforce cleaner API - Remove test for dict return compatibility - Fix all whitespace and import issues - Add newline at end of test file
This commit is contained in:
parent
f8bf1a4cbb
commit
9674b3bdfd
3 changed files with 71 additions and 128 deletions
|
|
@ -14,6 +14,8 @@ from ..types import (
|
|||
SDKControlRequest,
|
||||
SDKControlResponse,
|
||||
SDKHookCallbackRequest,
|
||||
ToolPermissionContext,
|
||||
ToolPermissionResponse,
|
||||
)
|
||||
from .transport import Transport
|
||||
|
||||
|
|
@ -184,9 +186,6 @@ class Query:
|
|||
if not self.can_use_tool:
|
||||
raise Exception("canUseTool callback is not provided")
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from ..types import ToolPermissionContext, ToolPermissionResponse
|
||||
|
||||
context = ToolPermissionContext(
|
||||
signal=None, # TODO: Add abort signal support
|
||||
suggestions=permission_request.get("permission_suggestions", [])
|
||||
|
|
@ -199,19 +198,16 @@ class Query:
|
|||
)
|
||||
|
||||
# Convert ToolPermissionResponse to expected dict format
|
||||
if isinstance(response, ToolPermissionResponse):
|
||||
response_data = {
|
||||
"allow": response.allow
|
||||
}
|
||||
if response.input is not None:
|
||||
response_data["input"] = response.input
|
||||
if response.reason is not None:
|
||||
response_data["reason"] = response.reason
|
||||
elif isinstance(response, dict):
|
||||
# Support returning dict directly for compatibility
|
||||
response_data = response
|
||||
else:
|
||||
raise TypeError(f"Tool permission callback must return ToolPermissionResponse or dict, got {type(response)}")
|
||||
if not isinstance(response, ToolPermissionResponse):
|
||||
raise TypeError(f"Tool permission callback must return ToolPermissionResponse, got {type(response)}")
|
||||
|
||||
response_data = {
|
||||
"allow": response.allow
|
||||
}
|
||||
if response.input is not None:
|
||||
response_data["input"] = response.input
|
||||
if response.reason is not None:
|
||||
response_data["reason"] = response.reason
|
||||
|
||||
elif subtype == "hook_callback":
|
||||
hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment]
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class ToolPermissionResponse:
|
|||
|
||||
ToolPermissionCallback = Callable[
|
||||
[str, dict[str, Any], ToolPermissionContext],
|
||||
Awaitable[ToolPermissionResponse | dict[str, Any]]
|
||||
Awaitable[ToolPermissionResponse]
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,13 @@
|
|||
"""Tests for tool permission callbacks and hook callbacks."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
ClaudeCodeOptions,
|
||||
ToolPermissionCallback,
|
||||
ToolPermissionResponse,
|
||||
ToolPermissionContext,
|
||||
HookCallback,
|
||||
HookContext,
|
||||
HookMatcher,
|
||||
ToolPermissionContext,
|
||||
ToolPermissionResponse,
|
||||
)
|
||||
from claude_code_sdk._internal.query import Query
|
||||
from claude_code_sdk._internal.transport import Transport
|
||||
|
|
@ -19,45 +15,45 @@ from claude_code_sdk._internal.transport import Transport
|
|||
|
||||
class MockTransport(Transport):
|
||||
"""Mock transport for testing."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.written_messages = []
|
||||
self.messages_to_read = []
|
||||
self._connected = False
|
||||
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._connected = True
|
||||
|
||||
|
||||
async def close(self) -> None:
|
||||
self._connected = False
|
||||
|
||||
|
||||
async def write(self, data: str) -> None:
|
||||
self.written_messages.append(data)
|
||||
|
||||
|
||||
async def end_input(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def read_messages(self):
|
||||
async def _read():
|
||||
for msg in self.messages_to_read:
|
||||
yield msg
|
||||
return _read()
|
||||
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
|
||||
class TestToolPermissionCallbacks:
|
||||
"""Test tool permission callback functionality."""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_allow(self):
|
||||
"""Test callback that allows tool execution."""
|
||||
callback_invoked = False
|
||||
|
||||
|
||||
async def allow_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> ToolPermissionResponse:
|
||||
nonlocal callback_invoked
|
||||
|
|
@ -65,7 +61,7 @@ class TestToolPermissionCallbacks:
|
|||
assert tool_name == "TestTool"
|
||||
assert input_data == {"param": "value"}
|
||||
return ToolPermissionResponse(allow=True, reason="Test allow")
|
||||
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
|
|
@ -73,7 +69,7 @@ class TestToolPermissionCallbacks:
|
|||
can_use_tool=allow_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
|
||||
# Simulate control request
|
||||
request = {
|
||||
"type": "control_request",
|
||||
|
|
@ -85,31 +81,31 @@ class TestToolPermissionCallbacks:
|
|||
"permission_suggestions": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
|
||||
# Check callback was invoked
|
||||
assert callback_invoked
|
||||
|
||||
|
||||
# Check response was sent
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": true' in response
|
||||
assert '"reason": "Test allow"' in response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_deny(self):
|
||||
"""Test callback that denies tool execution."""
|
||||
async def deny_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> ToolPermissionResponse:
|
||||
return ToolPermissionResponse(
|
||||
allow=False,
|
||||
allow=False,
|
||||
reason="Security policy violation"
|
||||
)
|
||||
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
|
|
@ -117,7 +113,7 @@ class TestToolPermissionCallbacks:
|
|||
can_use_tool=deny_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-2",
|
||||
|
|
@ -128,21 +124,21 @@ class TestToolPermissionCallbacks:
|
|||
"permission_suggestions": ["deny"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
|
||||
# Check response
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": false' in response
|
||||
assert '"reason": "Security policy violation"' in response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_input_modification(self):
|
||||
"""Test callback that modifies tool input."""
|
||||
async def modify_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> ToolPermissionResponse:
|
||||
# Modify the input to add safety flag
|
||||
|
|
@ -153,7 +149,7 @@ class TestToolPermissionCallbacks:
|
|||
input=modified_input,
|
||||
reason="Modified for safety"
|
||||
)
|
||||
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
|
|
@ -161,7 +157,7 @@ class TestToolPermissionCallbacks:
|
|||
can_use_tool=modify_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-3",
|
||||
|
|
@ -172,67 +168,26 @@ class TestToolPermissionCallbacks:
|
|||
"permission_suggestions": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
|
||||
# Check response includes modified input
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": true' in response
|
||||
assert '"safe_mode": true' in response
|
||||
assert '"reason": "Modified for safety"' in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_dict_return(self):
|
||||
"""Test callback can return dict for backwards compatibility."""
|
||||
async def dict_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> dict:
|
||||
# Return dict directly instead of ToolPermissionResponse
|
||||
return {
|
||||
"allow": True,
|
||||
"reason": "Dict response"
|
||||
}
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=dict_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-4",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "TestTool",
|
||||
"input": {},
|
||||
"permission_suggestions": []
|
||||
}
|
||||
}
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
# Check response
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
assert '"allow": true' in response
|
||||
assert '"reason": "Dict response"' in response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_exception_handling(self):
|
||||
"""Test that callback exceptions are properly handled."""
|
||||
async def error_callback(
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
tool_name: str,
|
||||
input_data: dict,
|
||||
context: ToolPermissionContext
|
||||
) -> ToolPermissionResponse:
|
||||
raise ValueError("Callback error")
|
||||
|
||||
|
||||
transport = MockTransport()
|
||||
query = Query(
|
||||
transport=transport,
|
||||
|
|
@ -240,7 +195,7 @@ class TestToolPermissionCallbacks:
|
|||
can_use_tool=error_callback,
|
||||
hooks=None
|
||||
)
|
||||
|
||||
|
||||
request = {
|
||||
"type": "control_request",
|
||||
"request_id": "test-5",
|
||||
|
|
@ -251,9 +206,9 @@ class TestToolPermissionCallbacks:
|
|||
"permission_suggestions": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
|
||||
# Check error response was sent
|
||||
assert len(transport.written_messages) == 1
|
||||
response = transport.written_messages[0]
|
||||
|
|
@ -263,12 +218,12 @@ class TestToolPermissionCallbacks:
|
|||
|
||||
class TestHookCallbacks:
|
||||
"""Test hook callback functionality."""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_execution(self):
|
||||
"""Test that hooks are called at appropriate times."""
|
||||
hook_calls = []
|
||||
|
||||
|
||||
async def test_hook(
|
||||
input_data: dict,
|
||||
tool_use_id: str | None,
|
||||
|
|
@ -279,9 +234,9 @@ class TestHookCallbacks:
|
|||
"tool_use_id": tool_use_id
|
||||
})
|
||||
return {"processed": True}
|
||||
|
||||
|
||||
transport = MockTransport()
|
||||
|
||||
|
||||
# Create hooks configuration
|
||||
hooks = {
|
||||
"tool_use_start": [
|
||||
|
|
@ -291,26 +246,18 @@ class TestHookCallbacks:
|
|||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
query = Query(
|
||||
transport=transport,
|
||||
is_streaming_mode=True,
|
||||
can_use_tool=None,
|
||||
hooks=hooks
|
||||
)
|
||||
|
||||
# During initialization, hook callbacks are registered
|
||||
await query.initialize()
|
||||
|
||||
# Find the registered callback ID
|
||||
callback_id = None
|
||||
for cid in query.hook_callbacks:
|
||||
if query.hook_callbacks[cid] == test_hook:
|
||||
callback_id = cid
|
||||
break
|
||||
|
||||
assert callback_id is not None
|
||||
|
||||
|
||||
# Manually register the hook callback to avoid needing the full initialize flow
|
||||
callback_id = "test_hook_0"
|
||||
query.hook_callbacks[callback_id] = test_hook
|
||||
|
||||
# Simulate hook callback request
|
||||
request = {
|
||||
"type": "control_request",
|
||||
|
|
@ -322,14 +269,14 @@ class TestHookCallbacks:
|
|||
"tool_use_id": "tool-123"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
await query._handle_control_request(request)
|
||||
|
||||
|
||||
# Check hook was called
|
||||
assert len(hook_calls) == 1
|
||||
assert hook_calls[0]["input"] == {"test": "data"}
|
||||
assert hook_calls[0]["tool_use_id"] == "tool-123"
|
||||
|
||||
|
||||
# Check response
|
||||
assert len(transport.written_messages) > 0
|
||||
last_response = transport.written_messages[-1]
|
||||
|
|
@ -338,7 +285,7 @@ class TestHookCallbacks:
|
|||
|
||||
class TestClaudeCodeOptionsIntegration:
|
||||
"""Test that callbacks work through ClaudeCodeOptions."""
|
||||
|
||||
|
||||
def test_options_with_callbacks(self):
|
||||
"""Test creating options with callbacks."""
|
||||
async def my_callback(
|
||||
|
|
@ -347,14 +294,14 @@ class TestClaudeCodeOptionsIntegration:
|
|||
context: ToolPermissionContext
|
||||
) -> ToolPermissionResponse:
|
||||
return ToolPermissionResponse(allow=True)
|
||||
|
||||
|
||||
async def my_hook(
|
||||
input_data: dict,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
options = ClaudeCodeOptions(
|
||||
tool_permission_callback=my_callback,
|
||||
hooks={
|
||||
|
|
@ -366,8 +313,8 @@ class TestClaudeCodeOptionsIntegration:
|
|||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
assert options.tool_permission_callback == my_callback
|
||||
assert "tool_use_start" in options.hooks
|
||||
assert len(options.hooks["tool_use_start"]) == 1
|
||||
assert options.hooks["tool_use_start"][0].hooks[0] == my_hook
|
||||
assert options.hooks["tool_use_start"][0].hooks[0] == my_hook
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue