mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Add custom tool callbacks and e2e tests (#157)
## Summary This PR adds support for custom tool callbacks and comprehensive e2e testing for MCP calculator functionality. ## Key Features Added - **Custom tool permission callbacks** - Allow dynamic tool permission control via `can_use_tool` callback - **E2E test suite** - Real Claude API tests validating MCP tool execution end-to-end - **Fixed MCP calculator example** - Now properly uses `allowed_tools` for permission management ## Changes ### Custom Callbacks - Added `ToolPermissionContext` and `PermissionResult` types for tool permission handling - Implemented `can_use_tool` callback support in SDK client - Added comprehensive tests in `tests/test_tool_callbacks.py` ### E2E Testing Infrastructure - Created `e2e-tests/` directory with pytest-based test suite - `test_mcp_calculator.py` - Tests all calculator operations with real API calls - `conftest.py` - Pytest config with mandatory API key validation - GitHub Actions workflow for automated e2e testing on main branch - Comprehensive documentation in `e2e-tests/README.md` ### Bug Fixes - Fixed MCP calculator example to use `allowed_tools` instead of incorrect `permission_mode` - Resolved tool permission issues preventing MCP tools from executing ## Testing E2E tests require `ANTHROPIC_API_KEY` environment variable and will fail without it. Run locally: ```bash export ANTHROPIC_API_KEY=your-key python -m pytest e2e-tests/ -v -m e2e ``` Run unit tests including callback tests: ```bash python -m pytest tests/test_tool_callbacks.py -v ``` 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Kashyap Murali <kashyap@anthropic.com>
This commit is contained in:
parent
d3190f12d3
commit
839300404f
15 changed files with 1301 additions and 90 deletions
45
.github/workflows/test-e2e.yml
vendored
45
.github/workflows/test-e2e.yml
vendored
|
|
@ -1,45 +0,0 @@
|
|||
name: Claude Code E2E Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
|
||||
jobs:
|
||||
integration-test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Claude Code
|
||||
run: |
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Verify Claude Code installation
|
||||
run: claude -v
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e .
|
||||
|
||||
- name: Run quickstart example
|
||||
run: python examples/quick_start.py
|
||||
env:
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
|
||||
- name: Run streaming mode examples
|
||||
run: timeout 120 python examples/streaming_mode.py all
|
||||
env:
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
117
.github/workflows/test.yml
vendored
117
.github/workflows/test.yml
vendored
|
|
@ -4,34 +4,103 @@ on:
|
|||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- "main"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.10', '3.11', '3.12', '3.13']
|
||||
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e ".[dev]"
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
python -m pytest tests/ -v --cov=claude_code_sdk --cov-report=xml
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
fail_ci_if_error: false
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e ".[dev]"
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
python -m pytest tests/ -v --cov=claude_code_sdk --cov-report=xml
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
fail_ci_if_error: false
|
||||
|
||||
test-e2e:
|
||||
runs-on: ubuntu-latest
|
||||
needs: test # Run after unit tests pass
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Claude Code
|
||||
run: |
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Verify Claude Code installation
|
||||
run: claude -v
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e ".[dev]"
|
||||
|
||||
- name: Run end-to-end tests with real API
|
||||
env:
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
run: |
|
||||
python -m pytest e2e-tests/ -v -m e2e
|
||||
|
||||
test-examples:
|
||||
runs-on: ubuntu-latest
|
||||
needs: test-e2e # Run after e2e tests
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Claude Code
|
||||
run: |
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Verify Claude Code installation
|
||||
run: claude -v
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e .
|
||||
|
||||
- name: Run example scripts
|
||||
env:
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
run: |
|
||||
python examples/quick_start.py
|
||||
timeout 120 python examples/streaming_mode.py all
|
||||
|
|
|
|||
84
README.md
84
README.md
|
|
@ -76,6 +76,90 @@ options = ClaudeCodeOptions(
|
|||
)
|
||||
```
|
||||
|
||||
### SDK MCP Servers (In-Process)
|
||||
|
||||
The SDK now supports in-process MCP servers that run directly within your Python application, eliminating the need for separate processes.
|
||||
|
||||
#### Creating a Simple Tool
|
||||
|
||||
```python
|
||||
from claude_code_sdk import tool, create_sdk_mcp_server
|
||||
|
||||
# Define a tool using the @tool decorator
|
||||
@tool("greet", "Greet a user", {"name": str})
|
||||
async def greet_user(args):
|
||||
return {
|
||||
"content": [
|
||||
{"type": "text", "text": f"Hello, {args['name']}!"}
|
||||
]
|
||||
}
|
||||
|
||||
# Create an SDK MCP server
|
||||
server = create_sdk_mcp_server(
|
||||
name="my-tools",
|
||||
version="1.0.0",
|
||||
tools=[greet_user]
|
||||
)
|
||||
|
||||
# Use it with Claude
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={"tools": server}
|
||||
)
|
||||
|
||||
async for message in query(prompt="Greet Alice", options=options):
|
||||
print(message)
|
||||
```
|
||||
|
||||
#### Benefits Over External MCP Servers
|
||||
|
||||
- **No subprocess management** - Runs in the same process as your application
|
||||
- **Better performance** - No IPC overhead for tool calls
|
||||
- **Simpler deployment** - Single Python process instead of multiple
|
||||
- **Easier debugging** - All code runs in the same process
|
||||
- **Type safety** - Direct Python function calls with type hints
|
||||
|
||||
#### Migration from External Servers
|
||||
|
||||
```python
|
||||
# BEFORE: External MCP server (separate process)
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={
|
||||
"calculator": {
|
||||
"type": "stdio",
|
||||
"command": "python",
|
||||
"args": ["-m", "calculator_server"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# AFTER: SDK MCP server (in-process)
|
||||
from my_tools import add, subtract # Your tool functions
|
||||
|
||||
calculator = create_sdk_mcp_server(
|
||||
name="calculator",
|
||||
tools=[add, subtract]
|
||||
)
|
||||
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={"calculator": calculator}
|
||||
)
|
||||
```
|
||||
|
||||
#### Mixed Server Support
|
||||
|
||||
You can use both SDK and external MCP servers together:
|
||||
|
||||
```python
|
||||
options = ClaudeCodeOptions(
|
||||
mcp_servers={
|
||||
"internal": sdk_server, # In-process SDK server
|
||||
"external": { # External subprocess server
|
||||
"type": "stdio",
|
||||
"command": "external-server"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
|
|
|
|||
102
e2e-tests/README.md
Normal file
102
e2e-tests/README.md
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# End-to-End Tests for Claude Code SDK
|
||||
|
||||
This directory contains end-to-end tests that run against the actual Claude API to verify real-world functionality.
|
||||
|
||||
## Requirements
|
||||
|
||||
### API Key (REQUIRED)
|
||||
|
||||
These tests require a valid Anthropic API key. The tests will **fail** if `ANTHROPIC_API_KEY` is not set.
|
||||
|
||||
Set your API key before running tests:
|
||||
|
||||
```bash
|
||||
export ANTHROPIC_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
### Dependencies
|
||||
|
||||
Install the development dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
## Running the Tests
|
||||
|
||||
### Run all e2e tests:
|
||||
|
||||
```bash
|
||||
python -m pytest e2e-tests/ -v
|
||||
```
|
||||
|
||||
### Run with e2e marker only:
|
||||
|
||||
```bash
|
||||
python -m pytest e2e-tests/ -v -m e2e
|
||||
```
|
||||
|
||||
### Run a specific test:
|
||||
|
||||
```bash
|
||||
python -m pytest e2e-tests/test_sdk_mcp_tools.py::test_sdk_mcp_tool_execution -v
|
||||
```
|
||||
|
||||
## Cost Considerations
|
||||
|
||||
⚠️ **Important**: These tests make actual API calls to Claude, which incur costs based on your Anthropic pricing plan.
|
||||
|
||||
- Each test typically uses 1-3 API calls
|
||||
- Tests use simple prompts to minimize token usage
|
||||
- The complete test suite should cost less than $0.10 to run
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### SDK MCP Tool Tests (`test_sdk_mcp_tools.py`)

Tests in-process SDK MCP (Model Context Protocol) tools end to end:

- **test_sdk_mcp_tool_execution**: Verifies a tool listed in `allowed_tools` is actually executed
- **test_sdk_mcp_permission_enforcement**: Ensures a tool in `disallowed_tools` is not executed while an allowed one is
- **test_sdk_mcp_multiple_tools**: Verifies multiple tools can be used in sequence
- **test_sdk_mcp_without_permissions**: Verifies a tool is not executed when `allowed_tools` is omitted

### Tool Permission Callback Tests (`test_tool_permissions.py`)

- **test_permission_callback_gets_called**: Verifies the `can_use_tool` callback is invoked for the Write tool
|
||||
|
||||
Each test validates:
|
||||
1. Tools are actually called (ToolUseBlock present in response)
|
||||
2. Correct tool inputs are provided
|
||||
3. Expected results are returned
|
||||
4. Permission system is enforced
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
These tests run automatically on:
|
||||
- Pushes to `main` branch (via GitHub Actions)
|
||||
- Pull requests
|
||||
|
||||
The workflow uses `ANTHROPIC_API_KEY` from GitHub Secrets.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "ANTHROPIC_API_KEY environment variable is required" error
|
||||
- Set your API key: `export ANTHROPIC_API_KEY=sk-ant-...`
|
||||
- The tests will not skip - they require the key to run
|
||||
|
||||
### Tests timing out
|
||||
- Check your API key is valid and has quota available
|
||||
- Ensure network connectivity to api.anthropic.com
|
||||
|
||||
### Permission denied errors
|
||||
- Verify the `allowed_tools` parameter includes the necessary MCP tools
|
||||
- Check that tool names match the expected format (e.g., `mcp__calc__add`)
|
||||
|
||||
## Adding New E2E Tests
|
||||
|
||||
When adding new e2e tests:
|
||||
|
||||
1. Mark tests with `@pytest.mark.e2e` decorator
|
||||
2. Use the `api_key` fixture to ensure API key is available
|
||||
3. Keep prompts simple to minimize costs
|
||||
4. Verify actual tool execution, not just mocked responses
|
||||
5. Document any special setup requirements in this README
|
||||
30
e2e-tests/conftest.py
Normal file
30
e2e-tests/conftest.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
"""Pytest configuration for e2e tests."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def api_key():
    """Return ``ANTHROPIC_API_KEY``, failing the run when it is not set.

    The fixture is ``autouse`` so the check applies to every test collected
    under this directory, including tests that do not request ``api_key``
    explicitly. Tests that do request it still receive the key value.

    Returns:
        The non-empty value of the ``ANTHROPIC_API_KEY`` environment variable.
    """
    key = os.environ.get("ANTHROPIC_API_KEY")
    if not key:
        # Fail rather than skip: e2e results are meaningless without the key.
        pytest.fail(
            "ANTHROPIC_API_KEY environment variable is required for e2e tests. "
            "Set it before running: export ANTHROPIC_API_KEY=your-key-here"
        )
    return key
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop_policy():
    """Provide asyncio's current event loop policy to the async tests."""
    # Imported locally so this module only needs asyncio when the fixture runs.
    from asyncio import get_event_loop_policy

    return get_event_loop_policy()
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the ``e2e`` marker on the given pytest config object."""
    marker_line = "e2e: marks tests as e2e tests requiring API key"
    config.addinivalue_line("markers", marker_line)
|
||||
165
e2e-tests/test_sdk_mcp_tools.py
Normal file
165
e2e-tests/test_sdk_mcp_tools.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""End-to-end tests for SDK MCP (inline) tools with real Claude API calls.
|
||||
|
||||
These tests verify that SDK-created MCP tools work correctly through the full stack,
|
||||
focusing on tool execution mechanics rather than specific tool functionality.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ClaudeSDKClient,
|
||||
UserMessage,
|
||||
create_sdk_mcp_server,
|
||||
tool,
|
||||
)
|
||||
from claude_code_sdk.types import ToolResultBlock, ToolUseBlock
|
||||
|
||||
|
||||
@pytest.mark.e2e
@pytest.mark.asyncio
async def test_sdk_mcp_tool_execution():
    """Check that an SDK MCP tool listed in allowed_tools is really executed."""
    calls: list[str] = []

    @tool("echo", "Echo back the input text", {"text": str})
    async def echo_tool(args: dict[str, Any]) -> dict[str, Any]:
        """Record the invocation and echo the supplied text back."""
        calls.append("echo")
        reply = {"type": "text", "text": f"Echo: {args['text']}"}
        return {"content": [reply]}

    sdk_server = create_sdk_mcp_server(
        name="test", version="1.0.0", tools=[echo_tool]
    )
    opts = ClaudeCodeOptions(
        mcp_servers={"test": sdk_server},
        allowed_tools=["mcp__test__echo"],
    )

    async with ClaudeSDKClient(options=opts) as client:
        await client.query("Call the mcp__test__echo tool with any text")
        # Drain the stream; only the side effect recorded in ``calls`` matters.
        async for _ in client.receive_response():
            pass

    # The Python handler itself must have run.
    assert "echo" in calls, "Echo tool function was not executed"
|
||||
|
||||
|
||||
@pytest.mark.e2e
@pytest.mark.asyncio
async def test_sdk_mcp_permission_enforcement():
    """Check that a disallowed SDK MCP tool never runs while an allowed one does."""
    calls: list[str] = []

    @tool("echo", "Echo back the input text", {"text": str})
    async def echo_tool(args: dict[str, Any]) -> dict[str, Any]:
        """Record the invocation and echo the supplied text back."""
        calls.append("echo")
        reply = {"type": "text", "text": f"Echo: {args['text']}"}
        return {"content": [reply]}

    @tool("greet", "Greet a person by name", {"name": str})
    async def greet_tool(args: dict[str, Any]) -> dict[str, Any]:
        """Record the invocation and greet the named person."""
        calls.append("greet")
        reply = {"type": "text", "text": f"Hello, {args['name']}!"}
        return {"content": [reply]}

    sdk_server = create_sdk_mcp_server(
        name="test", version="1.0.0", tools=[echo_tool, greet_tool]
    )
    opts = ClaudeCodeOptions(
        mcp_servers={"test": sdk_server},
        disallowed_tools=["mcp__test__echo"],  # echo is blocked
        allowed_tools=["mcp__test__greet"],  # greet is permitted
    )

    async with ClaudeSDKClient(options=opts) as client:
        await client.query(
            "Use the echo tool to echo 'test' and use greet tool to greet 'Alice'"
        )
        # Drain the stream; only the side effects recorded in ``calls`` matter.
        async for _ in client.receive_response():
            pass

    # Judge by which handlers actually ran.
    assert "echo" not in calls, "Disallowed echo tool was executed"
    assert "greet" in calls, "Allowed greet tool was not executed"
|
||||
|
||||
|
||||
@pytest.mark.e2e
@pytest.mark.asyncio
async def test_sdk_mcp_multiple_tools():
    """Check that two SDK MCP tools on one server can both be executed."""
    calls: list[str] = []

    @tool("echo", "Echo back the input text", {"text": str})
    async def echo_tool(args: dict[str, Any]) -> dict[str, Any]:
        """Record the invocation and echo the supplied text back."""
        calls.append("echo")
        reply = {"type": "text", "text": f"Echo: {args['text']}"}
        return {"content": [reply]}

    @tool("greet", "Greet a person by name", {"name": str})
    async def greet_tool(args: dict[str, Any]) -> dict[str, Any]:
        """Record the invocation and greet the named person."""
        calls.append("greet")
        reply = {"type": "text", "text": f"Hello, {args['name']}!"}
        return {"content": [reply]}

    sdk_server = create_sdk_mcp_server(
        name="multi", version="1.0.0", tools=[echo_tool, greet_tool]
    )
    opts = ClaudeCodeOptions(
        mcp_servers={"multi": sdk_server},
        allowed_tools=["mcp__multi__echo", "mcp__multi__greet"],
    )

    async with ClaudeSDKClient(options=opts) as client:
        await client.query(
            "Call mcp__multi__echo with text='test' and mcp__multi__greet with name='Bob'"
        )
        # Drain the stream; only the side effects recorded in ``calls`` matter.
        async for _ in client.receive_response():
            pass

    # Each handler must have run at least once.
    assert "echo" in calls, "Echo tool was not executed"
    assert "greet" in calls, "Greet tool was not executed"
|
||||
|
||||
|
||||
@pytest.mark.e2e
@pytest.mark.asyncio
async def test_sdk_mcp_without_permissions():
    """Check that an SDK MCP tool does not run when allowed_tools is omitted."""
    calls: list[str] = []

    @tool("echo", "Echo back the input text", {"text": str})
    async def echo_tool(args: dict[str, Any]) -> dict[str, Any]:
        """Record the invocation and echo the supplied text back."""
        calls.append("echo")
        reply = {"type": "text", "text": f"Echo: {args['text']}"}
        return {"content": [reply]}

    sdk_server = create_sdk_mcp_server(
        name="noperm", version="1.0.0", tools=[echo_tool]
    )
    # allowed_tools is deliberately left out.
    opts = ClaudeCodeOptions(mcp_servers={"noperm": sdk_server})

    async with ClaudeSDKClient(options=opts) as client:
        await client.query("Call the mcp__noperm__echo tool")
        # Drain the stream; only the side effect recorded in ``calls`` matters.
        async for _ in client.receive_response():
            pass

    assert "echo" not in calls, "SDK MCP tool was executed"
|
||||
43
e2e-tests/test_tool_permissions.py
Normal file
43
e2e-tests/test_tool_permissions.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""End-to-end tests for tool permission callbacks with real Claude API calls."""
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
ClaudeCodeOptions,
|
||||
ClaudeSDKClient,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.e2e
@pytest.mark.asyncio
async def test_permission_callback_gets_called():
    """Check that the can_use_tool callback is invoked for the Write tool."""
    seen_tools = []

    async def permission_callback(
        tool_name: str,
        input_data: dict,
        context: ToolPermissionContext
    ) -> PermissionResultAllow | PermissionResultDeny:
        """Record the requested tool name and allow the call."""
        print(f"Permission callback called for: {tool_name}, input: {input_data}")
        seen_tools.append(tool_name)
        return PermissionResultAllow()

    opts = ClaudeCodeOptions(can_use_tool=permission_callback)

    async with ClaudeSDKClient(options=opts) as client:
        await client.query("Write 'hello world' to /tmp/test.txt")
        # Echo every message so a failing run shows the whole exchange.
        async for reply in client.receive_response():
            print(f"Got message: {reply}")

    print(f'Callback invocations: {seen_tools}')
    # The callback must have been consulted for the Write tool.
    assert "Write" in seen_tools, f"can_use_tool callback should have been invoked for Write tool, got: {seen_tools}"
|
||||
222
examples/mcp_calculator.py
Normal file
222
examples/mcp_calculator.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Example: Calculator MCP Server.
|
||||
|
||||
This example demonstrates how to create an in-process MCP server with
|
||||
calculator tools using the Claude Code Python SDK.
|
||||
|
||||
Unlike external MCP servers that require separate processes, this server
|
||||
runs directly within your Python application, providing better performance
|
||||
and simpler deployment.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from claude_code_sdk import (
|
||||
ClaudeCodeOptions,
|
||||
create_sdk_mcp_server,
|
||||
tool,
|
||||
)
|
||||
|
||||
# Define calculator tools using the @tool decorator
|
||||
|
||||
@tool("add", "Add two numbers", {"a": float, "b": float})
async def add_numbers(args: dict[str, Any]) -> dict[str, Any]:
    """Return a text result showing the sum of ``args["a"]`` and ``args["b"]``."""
    lhs, rhs = args["a"], args["b"]
    summary = f"{lhs} + {rhs} = {lhs + rhs}"
    return {"content": [{"type": "text", "text": summary}]}
|
||||
|
||||
|
||||
@tool("subtract", "Subtract one number from another", {"a": float, "b": float})
async def subtract_numbers(args: dict[str, Any]) -> dict[str, Any]:
    """Return a text result showing ``args["a"]`` minus ``args["b"]``."""
    lhs, rhs = args["a"], args["b"]
    summary = f"{lhs} - {rhs} = {lhs - rhs}"
    return {"content": [{"type": "text", "text": summary}]}
|
||||
|
||||
|
||||
@tool("multiply", "Multiply two numbers", {"a": float, "b": float})
async def multiply_numbers(args: dict[str, Any]) -> dict[str, Any]:
    """Return a text result showing the product of ``args["a"]`` and ``args["b"]``."""
    lhs, rhs = args["a"], args["b"]
    summary = f"{lhs} × {rhs} = {lhs * rhs}"
    return {"content": [{"type": "text", "text": summary}]}
|
||||
|
||||
|
||||
@tool("divide", "Divide one number by another", {"a": float, "b": float})
async def divide_numbers(args: dict[str, Any]) -> dict[str, Any]:
    """Return a text result showing ``args["a"]`` divided by ``args["b"]``.

    A zero divisor yields an error result (``is_error`` set) instead of
    raising.
    """
    divisor = args["b"]
    if divisor == 0:
        failure = {"type": "text", "text": "Error: Division by zero is not allowed"}
        return {"content": [failure], "is_error": True}

    dividend = args["a"]
    summary = f"{dividend} ÷ {divisor} = {dividend / divisor}"
    return {"content": [{"type": "text", "text": summary}]}
|
||||
|
||||
|
||||
@tool("sqrt", "Calculate square root", {"n": float})
async def square_root(args: dict[str, Any]) -> dict[str, Any]:
    """Return a text result showing the square root of ``args["n"]``.

    A negative input yields an error result (``is_error`` set) instead of
    raising.
    """
    value = args["n"]
    if value < 0:
        failure = f"Error: Cannot calculate square root of negative number {value}"
        return {"content": [{"type": "text", "text": failure}], "is_error": True}

    # Imported here, after the guard, as in the original example.
    from math import sqrt

    summary = f"√{value} = {sqrt(value)}"
    return {"content": [{"type": "text", "text": summary}]}
|
||||
|
||||
|
||||
@tool("power", "Raise a number to a power", {"base": float, "exponent": float})
async def power(args: dict[str, Any]) -> dict[str, Any]:
    """Return a text result showing ``args["base"]`` raised to ``args["exponent"]``.

    Inputs that have no finite real result produce an error result
    (``is_error`` set) rather than an exception or a complex number,
    matching how the divide and sqrt tools report bad input:

    * zero raised to a negative power,
    * a result too large to represent as a float,
    * a negative base with a fractional exponent.
    """
    base, exponent = args["base"], args["exponent"]
    try:
        result = base ** exponent
    except ZeroDivisionError:
        error = "Error: Zero cannot be raised to a negative power"
    except OverflowError:
        error = f"Error: {base}^{exponent} is too large to represent"
    else:
        # Python returns a complex number for e.g. (-8) ** 0.5.
        error = (
            f"Error: {base}^{exponent} has no real result"
            if isinstance(result, complex)
            else None
        )

    if error is not None:
        return {"content": [{"type": "text", "text": error}], "is_error": True}

    return {
        "content": [
            {
                "type": "text",
                "text": f"{base}^{exponent} = {result}"
            }
        ]
    }
|
||||
|
||||
|
||||
def display_message(msg):
    """Print a short, readable rendering of one SDK message.

    User and assistant messages are printed block by block, system messages
    are ignored, and a result message prints an end marker plus the cost
    when one is reported.
    """
    from claude_code_sdk import (
        AssistantMessage,
        ResultMessage,
        SystemMessage,
        TextBlock,
        ToolResultBlock,
        ToolUseBlock,
        UserMessage,
    )

    if isinstance(msg, UserMessage):
        for item in msg.content:
            if isinstance(item, TextBlock):
                print(f"User: {item.text}")
            elif isinstance(item, ToolResultBlock):
                # Keep only the first 100 items/characters of the tool output.
                preview = item.content[:100] if item.content else 'None'
                print(f"Tool Result: {preview}...")
        return

    if isinstance(msg, AssistantMessage):
        for item in msg.content:
            if isinstance(item, TextBlock):
                print(f"Claude: {item.text}")
            elif isinstance(item, ToolUseBlock):
                print(f"Using tool: {item.name}")
                # Show the arguments passed to the calculator tool, if any.
                if item.input:
                    print(f" Input: {item.input}")
        return

    if isinstance(msg, SystemMessage):
        # System messages are intentionally not displayed.
        return

    if isinstance(msg, ResultMessage):
        print("Result ended")
        if msg.total_cost_usd:
            print(f"Cost: ${msg.total_cost_usd:.6f}")
|
||||
|
||||
|
||||
async def main():
    """Send a series of calculator prompts to Claude through the SDK MCP server.

    A fresh streaming client is opened for each prompt and every message
    received is passed to ``display_message``.
    """
    from claude_code_sdk import ClaudeSDKClient

    # Bundle every calculator tool into one in-process server.
    calc_tools = [
        add_numbers,
        subtract_numbers,
        multiply_numbers,
        divide_numbers,
        square_root,
        power,
    ]
    calculator = create_sdk_mcp_server(
        name="calculator", version="2.0.0", tools=calc_tools
    )

    # Pre-approve all calculator MCP tools so no permission prompt is needed.
    tool_names = ("add", "subtract", "multiply", "divide", "sqrt", "power")
    options = ClaudeCodeOptions(
        mcp_servers={"calc": calculator},
        allowed_tools=[f"mcp__calc__{tool_name}" for tool_name in tool_names],
    )

    # Prompts that exercise the calculator; the last one needs several tools.
    prompts = [
        "List your tools",
        "Calculate 15 + 27",
        "What is 100 divided by 7?",
        "Calculate the square root of 144",
        "What is 2 raised to the power of 8?",
        "Calculate (12 + 8) * 3 - 10",
    ]

    banner = "=" * 50
    for prompt in prompts:
        print(f"\n{banner}")
        print(f"Prompt: {prompt}")
        print(banner)

        async with ClaudeSDKClient(options=options) as client:
            await client.query(prompt)
            async for reply in client.receive_response():
                display_message(reply)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
158
examples/tool_permission_callback.py
Normal file
158
examples/tool_permission_callback.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Example: Tool Permission Callbacks.
|
||||
|
||||
This example demonstrates how to use tool permission callbacks to control
|
||||
which tools Claude can use and modify their inputs.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from claude_code_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ClaudeSDKClient,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
ResultMessage,
|
||||
TextBlock,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
|
||||
# Track tool usage for demonstration
|
||||
tool_usage_log = []
|
||||
|
||||
|
||||
async def my_permission_callback(
    tool_name: str,
    input_data: dict,
    context: ToolPermissionContext
) -> PermissionResultAllow | PermissionResultDeny:
    """Control tool permissions based on tool type and input.

    Policy, applied in order:

    * ``Read``/``Glob``/``Grep`` are always allowed.
    * ``Write``/``Edit``/``MultiEdit`` are denied when the target resolves
      into ``/etc`` or ``/usr``, and are redirected to ``./safe_output/``
      unless the target is given under ``/tmp/`` or ``./`` and still
      resolves there once ``..`` segments are collapsed.
    * ``Bash`` is denied when the command contains a known dangerous
      pattern and allowed otherwise.
    * Everything else, including file writes that passed the checks above,
      is confirmed interactively; no answer (EOF) counts as a denial.

    Every request is appended to the module-level ``tool_usage_log``.

    Args:
        tool_name: Name of the tool Claude wants to run.
        input_data: Arguments Claude supplied for the tool.
        context: Permission context; only ``suggestions`` is read here.

    Returns:
        ``PermissionResultAllow`` (optionally carrying ``updated_input``) or
        ``PermissionResultDeny`` with an explanatory message.
    """
    import os.path

    # Log the tool request
    tool_usage_log.append({
        "tool": tool_name,
        "input": input_data,
        "suggestions": context.suggestions
    })

    print(f"\n🔧 Tool Permission Request: {tool_name}")
    print(f" Input: {json.dumps(input_data, indent=2)}")

    # Always allow read operations
    if tool_name in ["Read", "Glob", "Grep"]:
        print(f" ✅ Automatically allowing {tool_name} (read-only operation)")
        return PermissionResultAllow()

    # Deny write operations to system directories
    if tool_name in ["Write", "Edit", "MultiEdit"]:
        file_path = input_data.get("file_path", "")
        # Collapse ".." segments so "/tmp/../etc/passwd" or "./../../etc/x"
        # cannot slip past the prefix checks. The trailing "/" lets a bare
        # directory such as "/etc" match its own prefix.
        resolved = os.path.abspath(file_path) + "/"
        system_dirs = ("/etc/", "/usr/")
        if file_path.startswith(system_dirs) or resolved.startswith(system_dirs):
            print(f" ❌ Denying write to system directory: {file_path}")
            return PermissionResultDeny(
                message=f"Cannot write to system directory: {file_path}"
            )

        # Redirect writes to a safe directory unless the path is written as
        # /tmp/... or ./... AND still resolves under /tmp or the working
        # directory.
        safe_roots = ("/tmp/", os.getcwd().rstrip("/") + "/")
        looks_safe = file_path.startswith(("/tmp/", "./"))
        if not (looks_safe and resolved.startswith(safe_roots)):
            safe_path = f"./safe_output/{file_path.split('/')[-1]}"
            print(f" ⚠️ Redirecting write from {file_path} to {safe_path}")
            modified_input = input_data.copy()
            modified_input["file_path"] = safe_path
            return PermissionResultAllow(
                updated_input=modified_input
            )
        # NOTE: writes that stay under /tmp/ or ./ fall through to the
        # interactive prompt at the end of this function.

    # Check dangerous bash commands
    if tool_name == "Bash":
        command = input_data.get("command", "")
        dangerous_commands = ["rm -rf", "sudo", "chmod 777", "dd if=", "mkfs"]

        for dangerous in dangerous_commands:
            if dangerous in command:
                print(f" ❌ Denying dangerous command: {command}")
                return PermissionResultDeny(
                    message=f"Dangerous command pattern detected: {dangerous}"
                )

        # Allow but log the command
        print(f" ✅ Allowing bash command: {command}")
        return PermissionResultAllow()

    # For all other tools, ask the user
    print(f" ❓ Unknown tool: {tool_name}")
    print(f" Input: {json.dumps(input_data, indent=6)}")
    try:
        # input() blocks, so run it in a worker thread instead of stalling
        # the event loop while waiting for the user.
        answer = await asyncio.to_thread(input, " Allow this tool? (y/N): ")
    except EOFError:
        # No interactive stdin: fall back to the prompt's default of "N".
        answer = ""
    user_input = answer.strip().lower()

    if user_input in ("y", "yes"):
        return PermissionResultAllow()
    return PermissionResultDeny(
        message="User denied permission"
    )
|
||||
|
||||
|
||||
async def main():
    """Run the tool-permission-callback demo end to end.

    Prints an introductory banner, opens a ``ClaudeSDKClient`` configured
    with ``my_permission_callback``, sends a three-step request, echoes the
    streamed reply, and finally prints each entry of ``tool_usage_log``
    (a name defined outside this block).
    """
    rule = "=" * 60

    # Banner: one print per line, exactly as the user sees it.
    for banner_line in (
        rule,
        "Tool Permission Callback Example",
        rule,
        "\nThis example demonstrates how to:",
        "1. Allow/deny tools based on type",
        "2. Modify tool inputs for safety",
        "3. Log tool usage",
        "4. Prompt for unknown tools",
        rule,
    ):
        print(banner_line)

    # Wire the callback into the options. The default permission mode is
    # chosen so that the callback is actually consulted.
    options = ClaudeCodeOptions(
        can_use_tool=my_permission_callback,
        permission_mode="default",
        cwd=".",  # working directory for the session
    )

    # A request that should make Claude reach for several different tools.
    prompt = (
        "Please do the following:\n"
        "1. List the files in the current directory\n"
        "2. Create a simple Python hello world script at hello.py\n"
        "3. Run the script to test it"
    )

    async with ClaudeSDKClient(options) as client:
        print("\n📝 Sending query to Claude...")
        await client.query(prompt)

        print("\n📨 Receiving response...")
        received = 0

        async for message in client.receive_response():
            received += 1

            if isinstance(message, AssistantMessage):
                # Only text blocks are echoed; other block kinds are ignored.
                for text_block in (
                    b for b in message.content if isinstance(b, TextBlock)
                ):
                    print(f"\n💬 Claude: {text_block.text}")
                continue

            if isinstance(message, ResultMessage):
                print("\n✅ Task completed!")
                print(f" Duration: {message.duration_ms}ms")
                cost = message.total_cost_usd
                if cost:
                    print(f" Cost: ${cost:.4f}")
                print(f" Messages processed: {received}")

    # Summary of everything recorded in tool_usage_log.
    print("\n" + rule)
    print("Tool Usage Summary")
    print(rule)
    for position, entry in enumerate(tool_usage_log, start=1):
        print(f"\n{position}. Tool: {entry['tool']}")
        print(f" Input: {json.dumps(entry['input'], indent=6)}")
        suggestions = entry['suggestions']
        if suggestions:
            print(f" Suggestions: {suggestions}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -16,8 +16,12 @@ from .client import ClaudeSDKClient
|
|||
from .query import query
|
||||
from .types import (
|
||||
AssistantMessage,
|
||||
CanUseTool,
|
||||
ClaudeCodeOptions,
|
||||
ContentBlock,
|
||||
HookCallback,
|
||||
HookContext,
|
||||
HookMatcher,
|
||||
McpSdkServerConfig,
|
||||
McpServerConfig,
|
||||
Message,
|
||||
|
|
@ -30,6 +34,7 @@ from .types import (
|
|||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolPermissionContext,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
|
|
@ -292,11 +297,20 @@ __all__ = [
|
|||
"ToolUseBlock",
|
||||
"ToolResultBlock",
|
||||
"ContentBlock",
|
||||
# Permission results (keep these as they may be used by internal callbacks)
|
||||
# Tool callbacks
|
||||
"CanUseTool",
|
||||
"ToolPermissionContext",
|
||||
"PermissionResult",
|
||||
"PermissionResultAllow",
|
||||
"PermissionResultDeny",
|
||||
"PermissionUpdate",
|
||||
"HookCallback",
|
||||
"HookContext",
|
||||
"HookMatcher",
|
||||
# MCP Server Support
|
||||
"create_sdk_mcp_server",
|
||||
"tool",
|
||||
"SdkMcpTool",
|
||||
# Errors
|
||||
"ClaudeSDKError",
|
||||
"CLIConnectionError",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Internal client implementation."""
|
||||
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from ..types import (
|
||||
|
|
@ -43,19 +44,43 @@ class InternalClient:
|
|||
) -> AsyncIterator[Message]:
|
||||
"""Process a query through transport and Query."""
|
||||
|
||||
# Validate and configure permission settings (matching TypeScript SDK logic)
|
||||
configured_options = options
|
||||
if options.can_use_tool:
|
||||
# canUseTool callback requires streaming mode (AsyncIterable prompt)
|
||||
if isinstance(prompt, str):
|
||||
raise ValueError(
|
||||
"can_use_tool callback requires streaming mode. "
|
||||
"Please provide prompt as an AsyncIterable instead of a string."
|
||||
)
|
||||
|
||||
# canUseTool and permission_prompt_tool_name are mutually exclusive
|
||||
if options.permission_prompt_tool_name:
|
||||
raise ValueError(
|
||||
"can_use_tool callback cannot be used with permission_prompt_tool_name. "
|
||||
"Please use one or the other."
|
||||
)
|
||||
|
||||
# Automatically set permission_prompt_tool_name to "stdio" for control protocol
|
||||
configured_options = replace(options, permission_prompt_tool_name="stdio")
|
||||
|
||||
# Use provided transport or create subprocess transport
|
||||
if transport is not None:
|
||||
chosen_transport = transport
|
||||
else:
|
||||
chosen_transport = SubprocessCLITransport(prompt=prompt, options=options)
|
||||
chosen_transport = SubprocessCLITransport(
|
||||
prompt=prompt, options=configured_options
|
||||
)
|
||||
|
||||
# Connect transport
|
||||
await chosen_transport.connect()
|
||||
|
||||
# Extract SDK MCP servers from options
|
||||
# Extract SDK MCP servers from configured options
|
||||
sdk_mcp_servers = {}
|
||||
if options.mcp_servers and isinstance(options.mcp_servers, dict):
|
||||
for name, config in options.mcp_servers.items():
|
||||
if configured_options.mcp_servers and isinstance(
|
||||
configured_options.mcp_servers, dict
|
||||
):
|
||||
for name, config in configured_options.mcp_servers.items():
|
||||
if isinstance(config, dict) and config.get("type") == "sdk":
|
||||
sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item]
|
||||
|
||||
|
|
@ -64,9 +89,9 @@ class InternalClient:
|
|||
query = Query(
|
||||
transport=chosen_transport,
|
||||
is_streaming_mode=is_streaming,
|
||||
can_use_tool=options.can_use_tool,
|
||||
hooks=self._convert_hooks_to_internal_format(options.hooks)
|
||||
if options.hooks
|
||||
can_use_tool=configured_options.can_use_tool,
|
||||
hooks=self._convert_hooks_to_internal_format(configured_options.hooks)
|
||||
if configured_options.hooks
|
||||
else None,
|
||||
sdk_mcp_servers=sdk_mcp_servers,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -250,9 +250,11 @@ class Query:
|
|||
# Type narrowing - we've verified these are not None above
|
||||
assert isinstance(server_name, str)
|
||||
assert isinstance(mcp_message, dict)
|
||||
response_data = await self._handle_sdk_mcp_request(
|
||||
mcp_response = await self._handle_sdk_mcp_request(
|
||||
server_name, mcp_message
|
||||
)
|
||||
# Wrap the MCP response as expected by the control protocol
|
||||
response_data = {"mcp_response": mcp_response}
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported control request subtype: {subtype}")
|
||||
|
|
@ -357,7 +359,24 @@ class Query:
|
|||
#
|
||||
# This forces us to manually route methods. When Python MCP adds Transport
|
||||
# support, we can refactor to match the TypeScript approach.
|
||||
if method == "tools/list":
|
||||
if method == "initialize":
|
||||
# Handle MCP initialization - hardcoded for tools only, no listChanged
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message.get("id"),
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {} # Tools capability without listChanged
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": server.name,
|
||||
"version": server.version or "1.0.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
elif method == "tools/list":
|
||||
request = ListToolsRequest(method=method)
|
||||
handler = server.request_handlers.get(ListToolsRequest)
|
||||
if handler:
|
||||
|
|
@ -367,7 +386,11 @@ class Query:
|
|||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema.model_dump() # type: ignore[union-attr]
|
||||
"inputSchema": (
|
||||
tool.inputSchema.model_dump()
|
||||
if hasattr(tool.inputSchema, "model_dump")
|
||||
else tool.inputSchema
|
||||
)
|
||||
if tool.inputSchema
|
||||
else {},
|
||||
}
|
||||
|
|
@ -413,6 +436,10 @@ class Query:
|
|||
"result": response_data,
|
||||
}
|
||||
|
||||
elif method == "notifications/initialized":
|
||||
# Handle initialized notification - just acknowledge it
|
||||
return {"jsonrpc": "2.0", "result": {}}
|
||||
|
||||
# Add more methods here as MCP SDK adds them (resources, prompts, etc.)
|
||||
# This is the limitation Ashwin pointed out - we have to manually update
|
||||
|
||||
|
|
|
|||
|
|
@ -125,19 +125,25 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
if self._options.mcp_servers:
|
||||
if isinstance(self._options.mcp_servers, dict):
|
||||
# Filter out SDK servers - they're handled in-process
|
||||
external_servers = {
|
||||
name: config
|
||||
for name, config in self._options.mcp_servers.items()
|
||||
if not (isinstance(config, dict) and config.get("type") == "sdk")
|
||||
}
|
||||
# Process all servers, stripping instance field from SDK servers
|
||||
servers_for_cli: dict[str, Any] = {}
|
||||
for name, config in self._options.mcp_servers.items():
|
||||
if isinstance(config, dict) and config.get("type") == "sdk":
|
||||
# For SDK servers, pass everything except the instance field
|
||||
sdk_config: dict[str, object] = {
|
||||
k: v for k, v in config.items() if k != "instance"
|
||||
}
|
||||
servers_for_cli[name] = sdk_config
|
||||
else:
|
||||
# For external servers, pass as-is
|
||||
servers_for_cli[name] = config
|
||||
|
||||
# Only pass external servers to CLI
|
||||
if external_servers:
|
||||
# Pass all servers to CLI
|
||||
if servers_for_cli:
|
||||
cmd.extend(
|
||||
[
|
||||
"--mcp-config",
|
||||
json.dumps({"mcpServers": external_servers}),
|
||||
json.dumps({"mcpServers": servers_for_cli}),
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import json
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from ._errors import CLIConnectionError
|
||||
|
|
@ -134,9 +135,30 @@ class ClaudeSDKClient:
|
|||
|
||||
actual_prompt = _empty_stream() if prompt is None else prompt
|
||||
|
||||
# Validate and configure permission settings (matching TypeScript SDK logic)
|
||||
if self.options.can_use_tool:
|
||||
# canUseTool callback requires streaming mode (AsyncIterable prompt)
|
||||
if isinstance(prompt, str):
|
||||
raise ValueError(
|
||||
"can_use_tool callback requires streaming mode. "
|
||||
"Please provide prompt as an AsyncIterable instead of a string."
|
||||
)
|
||||
|
||||
# canUseTool and permission_prompt_tool_name are mutually exclusive
|
||||
if self.options.permission_prompt_tool_name:
|
||||
raise ValueError(
|
||||
"can_use_tool callback cannot be used with permission_prompt_tool_name. "
|
||||
"Please use one or the other."
|
||||
)
|
||||
|
||||
# Automatically set permission_prompt_tool_name to "stdio" for control protocol
|
||||
options = replace(self.options, permission_prompt_tool_name="stdio")
|
||||
else:
|
||||
options = self.options
|
||||
|
||||
self._transport = SubprocessCLITransport(
|
||||
prompt=actual_prompt,
|
||||
options=self.options,
|
||||
options=options,
|
||||
)
|
||||
await self._transport.connect()
|
||||
|
||||
|
|
|
|||
289
tests/test_tool_callbacks.py
Normal file
289
tests/test_tool_callbacks.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
"""Tests for tool permission callbacks and hook callbacks."""
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
ClaudeCodeOptions,
|
||||
HookContext,
|
||||
HookMatcher,
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
from claude_code_sdk._internal.query import Query
|
||||
from claude_code_sdk._internal.transport import Transport
|
||||
|
||||
|
||||
class MockTransport(Transport):
    """In-memory transport double used by the callback tests.

    Every payload handed to ``write`` is appended to ``written_messages``;
    ``read_messages`` replays whatever has been placed in
    ``messages_to_read``; connection state is tracked with a boolean flag.
    """

    def __init__(self):
        self._connected = False
        self.messages_to_read = []
        self.written_messages = []

    async def connect(self) -> None:
        """Mark the transport as connected."""
        self._connected = True

    async def close(self) -> None:
        """Mark the transport as disconnected."""
        self._connected = False

    async def write(self, data: str) -> None:
        """Record *data* so tests can inspect what was sent."""
        self.written_messages.append(data)

    async def end_input(self) -> None:
        """No-op: there is no input stream to close."""
        return None

    def read_messages(self):
        """Return an async generator over the queued ``messages_to_read``."""

        async def _replay():
            for queued in self.messages_to_read:
                yield queued

        return _replay()

    def is_ready(self) -> bool:
        """Report whether ``connect`` has been called without a later ``close``."""
        return self._connected
|
||||
|
||||
|
||||
class TestToolPermissionCallbacks:
    """Checks how ``Query`` answers ``can_use_tool`` control requests."""

    @staticmethod
    def _permission_request(request_id, tool_name, tool_input, suggestions):
        """Build a ``can_use_tool`` control-request envelope."""
        return {
            "type": "control_request",
            "request_id": request_id,
            "request": {
                "subtype": "can_use_tool",
                "tool_name": tool_name,
                "input": tool_input,
                "permission_suggestions": suggestions,
            },
        }

    @staticmethod
    async def _dispatch(callback, request):
        """Feed *request* to a fresh streaming ``Query`` and return its transport."""
        transport = MockTransport()
        query = Query(
            transport=transport,
            is_streaming_mode=True,
            can_use_tool=callback,
            hooks=None,
        )
        await query._handle_control_request(request)
        return transport

    @pytest.mark.asyncio
    async def test_permission_callback_allow(self):
        """An allowing callback is invoked and an allow response is written."""
        invocations = []

        async def allow_callback(
            tool_name: str, input_data: dict, context: ToolPermissionContext
        ) -> PermissionResultAllow:
            invocations.append(True)
            assert tool_name == "TestTool"
            assert input_data == {"param": "value"}
            return PermissionResultAllow()

        request = self._permission_request(
            "test-1", "TestTool", {"param": "value"}, []
        )
        transport = await self._dispatch(allow_callback, request)

        # The callback must have run...
        assert invocations

        # ...and exactly one response must have been written back.
        assert len(transport.written_messages) == 1
        payload = transport.written_messages[0]
        assert '"allow": true' in payload

    @pytest.mark.asyncio
    async def test_permission_callback_deny(self):
        """A denying callback produces a deny response carrying its message."""

        async def deny_callback(
            tool_name: str, input_data: dict, context: ToolPermissionContext
        ) -> PermissionResultDeny:
            return PermissionResultDeny(message="Security policy violation")

        request = self._permission_request(
            "test-2", "DangerousTool", {"command": "rm -rf /"}, ["deny"]
        )
        transport = await self._dispatch(deny_callback, request)

        assert len(transport.written_messages) == 1
        payload = transport.written_messages[0]
        assert '"allow": false' in payload
        assert '"reason": "Security policy violation"' in payload

    @pytest.mark.asyncio
    async def test_permission_callback_input_modification(self):
        """Input rewritten by the callback is echoed in the allow response."""

        async def modify_callback(
            tool_name: str, input_data: dict, context: ToolPermissionContext
        ) -> PermissionResultAllow:
            # Work on a copy so the incoming request data is left untouched.
            patched = input_data.copy()
            patched["safe_mode"] = True
            return PermissionResultAllow(updated_input=patched)

        request = self._permission_request(
            "test-3", "WriteTool", {"file_path": "/etc/passwd"}, []
        )
        transport = await self._dispatch(modify_callback, request)

        assert len(transport.written_messages) == 1
        payload = transport.written_messages[0]
        assert '"allow": true' in payload
        assert '"safe_mode": true' in payload

    @pytest.mark.asyncio
    async def test_callback_exception_handling(self):
        """An exception raised by the callback is reported as an error response."""

        async def error_callback(
            tool_name: str, input_data: dict, context: ToolPermissionContext
        ) -> PermissionResultAllow:
            raise ValueError("Callback error")

        request = self._permission_request("test-5", "TestTool", {}, [])
        transport = await self._dispatch(error_callback, request)

        assert len(transport.written_messages) == 1
        payload = transport.written_messages[0]
        assert '"subtype": "error"' in payload
        assert "Callback error" in payload
|
||||
|
||||
|
||||
class TestHookCallbacks:
    """Checks how ``Query`` dispatches ``hook_callback`` control requests."""

    @pytest.mark.asyncio
    async def test_hook_execution(self):
        """A registered hook receives the request's input and tool-use id."""
        recorded = []

        async def recording_hook(
            input_data: dict, tool_use_id: str | None, context: HookContext
        ) -> dict:
            recorded.append({"input": input_data, "tool_use_id": tool_use_id})
            return {"processed": True}

        transport = MockTransport()

        # Hook configuration handed to Query at construction time.
        hooks = {
            "tool_use_start": [
                {"matcher": {"tool": "TestTool"}, "hooks": [recording_hook]}
            ]
        }
        query = Query(
            transport=transport,
            is_streaming_mode=True,
            can_use_tool=None,
            hooks=hooks,
        )

        # Register the callback by hand so the full initialize flow is not needed.
        callback_id = "test_hook_0"
        query.hook_callbacks[callback_id] = recording_hook

        # Simulated hook-callback request.
        hook_payload = {
            "subtype": "hook_callback",
            "callback_id": callback_id,
            "input": {"test": "data"},
            "tool_use_id": "tool-123",
        }
        request = {
            "type": "control_request",
            "request_id": "test-hook-1",
            "request": hook_payload,
        }

        await query._handle_control_request(request)

        # The hook ran exactly once with the expected arguments.
        assert len(recorded) == 1
        first_call = recorded[0]
        assert first_call["input"] == {"test": "data"}
        assert first_call["tool_use_id"] == "tool-123"

        # The hook's return value appears in the most recent written message.
        assert len(transport.written_messages) > 0
        newest = transport.written_messages[-1]
        assert '"processed": true' in newest
|
||||
|
||||
|
||||
class TestClaudeCodeOptionsIntegration:
    """Checks that ``ClaudeCodeOptions`` stores the callbacks it is given."""

    def test_options_with_callbacks(self):
        """Both the permission callback and the hook are retrievable afterwards."""

        async def permit_everything(
            tool_name: str, input_data: dict, context: ToolPermissionContext
        ) -> PermissionResultAllow:
            return PermissionResultAllow()

        async def noop_hook(
            input_data: dict, tool_use_id: str | None, context: HookContext
        ) -> dict:
            return {}

        bash_matcher = HookMatcher(matcher={"tool": "Bash"}, hooks=[noop_hook])
        options = ClaudeCodeOptions(
            can_use_tool=permit_everything,
            hooks={"tool_use_start": [bash_matcher]},
        )

        assert options.can_use_tool == permit_everything
        assert "tool_use_start" in options.hooks
        start_matchers = options.hooks["tool_use_start"]
        assert len(start_matchers) == 1
        assert start_matchers[0].hooks[0] == noop_hook
|
||||
Loading…
Add table
Add a link
Reference in a new issue