From 839300404f89cf74a71abab2f9ad577d97cb8f1f Mon Sep 17 00:00:00 2001 From: Ashwin Bhat Date: Mon, 8 Sep 2025 08:51:40 -0700 Subject: [PATCH] Add custom tool callbacks and e2e tests (#157) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR adds support for custom tool callbacks and comprehensive e2e testing for MCP calculator functionality. ## Key Features Added - **Custom tool permission callbacks** - Allow dynamic tool permission control via `can_use_tool` callback - **E2E test suite** - Real Claude API tests validating MCP tool execution end-to-end - **Fixed MCP calculator example** - Now properly uses `allowed_tools` for permission management ## Changes ### Custom Callbacks - Added `ToolPermissionContext` and `PermissionResult` types for tool permission handling - Implemented `can_use_tool` callback support in SDK client - Added comprehensive tests in `tests/test_tool_callbacks.py` ### E2E Testing Infrastructure - Created `e2e-tests/` directory with pytest-based test suite - `test_mcp_calculator.py` - Tests all calculator operations with real API calls - `conftest.py` - Pytest config with mandatory API key validation - GitHub Actions workflow for automated e2e testing on main branch - Comprehensive documentation in `e2e-tests/README.md` ### Bug Fixes - Fixed MCP calculator example to use `allowed_tools` instead of incorrect `permission_mode` - Resolved tool permission issues preventing MCP tools from executing ## Testing E2E tests require `ANTHROPIC_API_KEY` environment variable and will fail without it. 
Run locally: ```bash export ANTHROPIC_API_KEY=your-key python -m pytest e2e-tests/ -v -m e2e ``` Run unit tests including callback tests: ```bash python -m pytest tests/test_tool_callbacks.py -v ``` 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude Co-authored-by: Kashyap Murali --- .github/workflows/test-e2e.yml | 45 --- .github/workflows/test.yml | 117 +++++-- README.md | 84 +++++ e2e-tests/README.md | 102 +++++++ e2e-tests/conftest.py | 30 ++ e2e-tests/test_sdk_mcp_tools.py | 165 ++++++++++ e2e-tests/test_tool_permissions.py | 43 +++ examples/mcp_calculator.py | 222 ++++++++++++++ examples/tool_permission_callback.py | 158 ++++++++++ src/claude_code_sdk/__init__.py | 16 +- src/claude_code_sdk/_internal/client.py | 39 ++- src/claude_code_sdk/_internal/query.py | 33 +- .../_internal/transport/subprocess_cli.py | 24 +- src/claude_code_sdk/client.py | 24 +- tests/test_tool_callbacks.py | 289 ++++++++++++++++++ 15 files changed, 1301 insertions(+), 90 deletions(-) delete mode 100644 .github/workflows/test-e2e.yml create mode 100644 e2e-tests/README.md create mode 100644 e2e-tests/conftest.py create mode 100644 e2e-tests/test_sdk_mcp_tools.py create mode 100644 e2e-tests/test_tool_permissions.py create mode 100644 examples/mcp_calculator.py create mode 100644 examples/tool_permission_callback.py create mode 100644 tests/test_tool_callbacks.py diff --git a/.github/workflows/test-e2e.yml b/.github/workflows/test-e2e.yml deleted file mode 100644 index c96cde6..0000000 --- a/.github/workflows/test-e2e.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Claude Code E2E Test - -on: - pull_request: - push: - branches: - - 'main' - -jobs: - integration-test: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12", "3.13"] - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name:
Install Claude Code - run: | - curl -fsSL https://claude.ai/install.sh | bash - echo "$HOME/.local/bin" >> $GITHUB_PATH - - - name: Verify Claude Code installation - run: claude -v - - - name: Install Python dependencies - run: | - python -m pip install --upgrade pip - pip install -e . - - - name: Run quickstart example - run: python examples/quick_start.py - env: - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - - - name: Run streaming mode examples - run: timeout 120 python examples/streaming_mode.py all - env: - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aca2b3d..9a3e24a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,34 +4,103 @@ on: pull_request: push: branches: - - 'main' + - "main" jobs: test: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.10', '3.11', '3.12', '3.13'] - + python-version: ["3.10", "3.11", "3.12", "3.13"] + steps: - - uses: actions/checkout@v4 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e ".[dev]" - - - name: Run tests - run: | - python -m pytest tests/ -v --cov=claude_code_sdk --cov-report=xml - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 - with: - file: ./coverage.xml - fail_ci_if_error: false \ No newline at end of file + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run tests + run: | + python -m pytest tests/ -v --cov=claude_code_sdk --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: 
./coverage.xml + fail_ci_if_error: false + + test-e2e: + runs-on: ubuntu-latest + needs: test # Run after unit tests pass + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Claude Code + run: | + curl -fsSL https://claude.ai/install.sh | bash + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Verify Claude Code installation + run: claude -v + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run end-to-end tests with real API + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + run: | + python -m pytest e2e-tests/ -v -m e2e + + test-examples: + runs-on: ubuntu-latest + needs: test-e2e # Run after e2e tests + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Claude Code + run: | + curl -fsSL https://claude.ai/install.sh | bash + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Verify Claude Code installation + run: claude -v + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + + - name: Run example scripts + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + run: | + python examples/quick_start.py + timeout 120 python examples/streaming_mode.py all diff --git a/README.md b/README.md index 214b953..bdcd0d6 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,90 @@ options = ClaudeCodeOptions( ) ``` +### SDK MCP Servers (In-Process) + +The SDK now supports in-process MCP servers that run directly within your Python application, eliminating the need for separate processes. 
+ +#### Creating a Simple Tool + +```python +from claude_code_sdk import ClaudeCodeOptions, create_sdk_mcp_server, query, tool + +# Define a tool using the @tool decorator +@tool("greet", "Greet a user", {"name": str}) +async def greet_user(args): + return { + "content": [ + {"type": "text", "text": f"Hello, {args['name']}!"} + ] + } + +# Create an SDK MCP server +server = create_sdk_mcp_server( + name="my-tools", + version="1.0.0", + tools=[greet_user] +) + +# Use it with Claude +options = ClaudeCodeOptions( + mcp_servers={"tools": server} +) + +async for message in query(prompt="Greet Alice", options=options): + print(message) +``` + +#### Benefits Over External MCP Servers + +- **No subprocess management** - Runs in the same process as your application +- **Better performance** - No IPC overhead for tool calls +- **Simpler deployment** - Single Python process instead of multiple +- **Easier debugging** - All code runs in the same process +- **Type safety** - Direct Python function calls with type hints + +#### Migration from External Servers + +```python +# BEFORE: External MCP server (separate process) +options = ClaudeCodeOptions( + mcp_servers={ + "calculator": { + "type": "stdio", + "command": "python", + "args": ["-m", "calculator_server"] + } + } +) + +# AFTER: SDK MCP server (in-process) +from my_tools import add, subtract # Your tool functions + +calculator = create_sdk_mcp_server( + name="calculator", + tools=[add, subtract] +) + +options = ClaudeCodeOptions( + mcp_servers={"calculator": calculator} +) +``` + +#### Mixed Server Support + +You can use both SDK and external MCP servers together: + +```python +options = ClaudeCodeOptions( + mcp_servers={ + "internal": sdk_server, # In-process SDK server + "external": { # External subprocess server + "type": "stdio", + "command": "external-server" + } + } +) +``` ## API Reference diff --git a/e2e-tests/README.md b/e2e-tests/README.md new file mode 100644 index 0000000..6dfe374 --- /dev/null +++ b/e2e-tests/README.md @@ -0,0 +1,102
@@ +# End-to-End Tests for Claude Code SDK + +This directory contains end-to-end tests that run against the actual Claude API to verify real-world functionality. + +## Requirements + +### API Key (REQUIRED) + +These tests require a valid Anthropic API key. The tests will **fail** if `ANTHROPIC_API_KEY` is not set. + +Set your API key before running tests: + +```bash +export ANTHROPIC_API_KEY="your-api-key-here" +``` + +### Dependencies + +Install the development dependencies: + +```bash +pip install -e ".[dev]" +``` + +## Running the Tests + +### Run all e2e tests: + +```bash +python -m pytest e2e-tests/ -v +``` + +### Run with e2e marker only: + +```bash +python -m pytest e2e-tests/ -v -m e2e +``` + +### Run a specific test: + +```bash +python -m pytest e2e-tests/test_sdk_mcp_tools.py::test_sdk_mcp_tool_execution -v +``` + +## Cost Considerations + +⚠️ **Important**: These tests make actual API calls to Claude, which incur costs based on your Anthropic pricing plan. + +- Each test typically uses 1-3 API calls +- Tests use simple prompts to minimize token usage +- The complete test suite should cost less than $0.10 to run + +## Test Coverage + +### SDK MCP Tool and Permission Tests (`test_sdk_mcp_tools.py`, `test_tool_permissions.py`) + +Tests the in-process SDK MCP (Model Context Protocol) tool integration +and the `can_use_tool` permission callback: + +- **test_sdk_mcp_tool_execution**: Verifies an SDK MCP tool listed in `allowed_tools` is actually executed +- **test_sdk_mcp_permission_enforcement**: Verifies `disallowed_tools` blocks a tool while an allowed tool still runs +- **test_sdk_mcp_multiple_tools**: Verifies multiple tools can be called in sequence +- **test_sdk_mcp_without_permissions**: Verifies a tool is not executed when `allowed_tools` is omitted +- **test_permission_callback_gets_called**: Verifies the `can_use_tool` callback is invoked + +Each test validates: +1. Tools are actually called (ToolUseBlock present in response) +2. Correct tool inputs are provided +3. Expected results are returned +4.
Permission system is enforced + +## CI/CD Integration + +These tests run automatically on: +- Pull requests (via GitHub Actions) +- Pushes to `main` branch (via GitHub Actions) + +The workflow uses `ANTHROPIC_API_KEY` from GitHub Secrets. + +## Troubleshooting + +### "ANTHROPIC_API_KEY environment variable is required" error +- Set your API key: `export ANTHROPIC_API_KEY=sk-ant-...` +- The tests will not skip - they require the key to run + +### Tests timing out +- Check your API key is valid and has quota available +- Ensure network connectivity to api.anthropic.com + +### Permission denied errors +- Verify the `allowed_tools` parameter includes the necessary MCP tools +- Check that tool names match the expected format (e.g., `mcp__calc__add`) + +## Adding New E2E Tests + +When adding new e2e tests: + +1. Mark tests with `@pytest.mark.e2e` decorator +2. Use the `api_key` fixture to ensure API key is available +3. Keep prompts simple to minimize costs +4. Verify actual tool execution, not just mocked responses +5. Document any special setup requirements in this README \ No newline at end of file diff --git a/e2e-tests/conftest.py b/e2e-tests/conftest.py new file mode 100644 index 0000000..392c213 --- /dev/null +++ b/e2e-tests/conftest.py @@ -0,0 +1,30 @@ +"""Pytest configuration for e2e tests.""" + +import os + +import pytest + + +@pytest.fixture(scope="session") +def api_key(): + """Ensure ANTHROPIC_API_KEY is set for e2e tests.""" + key = os.environ.get("ANTHROPIC_API_KEY") + if not key: + pytest.fail( + "ANTHROPIC_API_KEY environment variable is required for e2e tests.
" + "Set it before running: export ANTHROPIC_API_KEY=your-key-here" + ) + return key + + +@pytest.fixture(scope="session") +def event_loop_policy(): + """Use the default event loop policy for all async tests.""" + import asyncio + + return asyncio.get_event_loop_policy() + + +def pytest_configure(config): + """Add e2e marker.""" + config.addinivalue_line("markers", "e2e: marks tests as e2e tests requiring API key") \ No newline at end of file diff --git a/e2e-tests/test_sdk_mcp_tools.py b/e2e-tests/test_sdk_mcp_tools.py new file mode 100644 index 0000000..8c78b60 --- /dev/null +++ b/e2e-tests/test_sdk_mcp_tools.py @@ -0,0 +1,165 @@ +"""End-to-end tests for SDK MCP (inline) tools with real Claude API calls. + +These tests verify that SDK-created MCP tools work correctly through the full stack, +focusing on tool execution mechanics rather than specific tool functionality. +""" + +from typing import Any + +import pytest + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ClaudeSDKClient, + UserMessage, + create_sdk_mcp_server, + tool, +) +from claude_code_sdk.types import ToolResultBlock, ToolUseBlock + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_sdk_mcp_tool_execution(): + """Test that SDK MCP tools can be called and executed with allowed_tools.""" + executions = [] + + @tool("echo", "Echo back the input text", {"text": str}) + async def echo_tool(args: dict[str, Any]) -> dict[str, Any]: + """Echo back whatever text is provided.""" + executions.append("echo") + return {"content": [{"type": "text", "text": f"Echo: {args['text']}"}]} + + server = create_sdk_mcp_server( + name="test", + version="1.0.0", + tools=[echo_tool], + ) + + options = ClaudeCodeOptions( + mcp_servers={"test": server}, + allowed_tools=["mcp__test__echo"], + ) + + async with ClaudeSDKClient(options=options) as client: + await client.query("Call the mcp__test__echo tool with any text") + + async for message in client.receive_response(): + pass # Just consume 
messages + + # Check if the actual Python function was called + assert "echo" in executions, "Echo tool function was not executed" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_sdk_mcp_permission_enforcement(): + """Test that disallowed_tools prevents SDK MCP tool execution.""" + executions = [] + + @tool("echo", "Echo back the input text", {"text": str}) + async def echo_tool(args: dict[str, Any]) -> dict[str, Any]: + """Echo back whatever text is provided.""" + executions.append("echo") + return {"content": [{"type": "text", "text": f"Echo: {args['text']}"}]} + + @tool("greet", "Greet a person by name", {"name": str}) + async def greet_tool(args: dict[str, Any]) -> dict[str, Any]: + """Greet someone by name.""" + executions.append("greet") + return {"content": [{"type": "text", "text": f"Hello, {args['name']}!"}]} + + server = create_sdk_mcp_server( + name="test", + version="1.0.0", + tools=[echo_tool, greet_tool], + ) + + options = ClaudeCodeOptions( + mcp_servers={"test": server}, + disallowed_tools=["mcp__test__echo"], # Block echo tool + allowed_tools=["mcp__test__greet"], # But allow greet + ) + + async with ClaudeSDKClient(options=options) as client: + await client.query("Use the echo tool to echo 'test' and use greet tool to greet 'Alice'") + + async for message in client.receive_response(): + pass # Just consume messages + + # Check actual function executions + assert "echo" not in executions, "Disallowed echo tool was executed" + assert "greet" in executions, "Allowed greet tool was not executed" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_sdk_mcp_multiple_tools(): + """Test that multiple SDK MCP tools can be called in sequence.""" + executions = [] + + @tool("echo", "Echo back the input text", {"text": str}) + async def echo_tool(args: dict[str, Any]) -> dict[str, Any]: + """Echo back whatever text is provided.""" + executions.append("echo") + return {"content": [{"type": "text", "text": f"Echo: {args['text']}"}]} + + 
@tool("greet", "Greet a person by name", {"name": str}) + async def greet_tool(args: dict[str, Any]) -> dict[str, Any]: + """Greet someone by name.""" + executions.append("greet") + return {"content": [{"type": "text", "text": f"Hello, {args['name']}!"}]} + + server = create_sdk_mcp_server( + name="multi", + version="1.0.0", + tools=[echo_tool, greet_tool], + ) + + options = ClaudeCodeOptions( + mcp_servers={"multi": server}, + allowed_tools=["mcp__multi__echo", "mcp__multi__greet"], + ) + + async with ClaudeSDKClient(options=options) as client: + await client.query("Call mcp__multi__echo with text='test' and mcp__multi__greet with name='Bob'") + + async for message in client.receive_response(): + pass # Just consume messages + + # Both tools should have been executed + assert "echo" in executions, "Echo tool was not executed" + assert "greet" in executions, "Greet tool was not executed" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_sdk_mcp_without_permissions(): + """Test SDK MCP tool behavior without explicit allowed_tools.""" + executions = [] + + @tool("echo", "Echo back the input text", {"text": str}) + async def echo_tool(args: dict[str, Any]) -> dict[str, Any]: + """Echo back whatever text is provided.""" + executions.append("echo") + return {"content": [{"type": "text", "text": f"Echo: {args['text']}"}]} + + server = create_sdk_mcp_server( + name="noperm", + version="1.0.0", + tools=[echo_tool], + ) + + # No allowed_tools specified + options = ClaudeCodeOptions( + mcp_servers={"noperm": server}, + ) + + async with ClaudeSDKClient(options=options) as client: + await client.query("Call the mcp__noperm__echo tool") + + async for message in client.receive_response(): + pass # Just consume messages + + assert "echo" not in executions, "SDK MCP tool was executed" \ No newline at end of file diff --git a/e2e-tests/test_tool_permissions.py b/e2e-tests/test_tool_permissions.py new file mode 100644 index 0000000..c61794e --- /dev/null +++ 
b/e2e-tests/test_tool_permissions.py @@ -0,0 +1,43 @@ +"""End-to-end tests for tool permission callbacks with real Claude API calls.""" + +import pytest + +from claude_code_sdk import ( + ClaudeCodeOptions, + ClaudeSDKClient, + PermissionResultAllow, + PermissionResultDeny, + ToolPermissionContext, +) + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_permission_callback_gets_called(): + """Test that can_use_tool callback gets invoked.""" + callback_invocations = [] + + async def permission_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext + ) -> PermissionResultAllow | PermissionResultDeny: + """Track callback invocation.""" + print(f"Permission callback called for: {tool_name}, input: {input_data}") + callback_invocations.append(tool_name) + return PermissionResultAllow() + + options = ClaudeCodeOptions( + can_use_tool=permission_callback, + ) + + async with ClaudeSDKClient(options=options) as client: + await client.query("Write 'hello world' to /tmp/test.txt") + + async for message in client.receive_response(): + print(f"Got message: {message}") + pass # Just consume messages + + print(f'Callback invocations: {callback_invocations}') + # Verify callback was invoked + assert "Write" in callback_invocations, f"can_use_tool callback should have been invoked for Write tool, got: {callback_invocations}" \ No newline at end of file diff --git a/examples/mcp_calculator.py b/examples/mcp_calculator.py new file mode 100644 index 0000000..2b12bbc --- /dev/null +++ b/examples/mcp_calculator.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +"""Example: Calculator MCP Server. + +This example demonstrates how to create an in-process MCP server with +calculator tools using the Claude Code Python SDK. + +Unlike external MCP servers that require separate processes, this server +runs directly within your Python application, providing better performance +and simpler deployment. 
+""" + +import asyncio +from typing import Any + +from claude_code_sdk import ( + ClaudeCodeOptions, + create_sdk_mcp_server, + tool, +) + +# Define calculator tools using the @tool decorator + +@tool("add", "Add two numbers", {"a": float, "b": float}) +async def add_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Add two numbers together.""" + result = args["a"] + args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} + {args['b']} = {result}" + } + ] + } + + +@tool("subtract", "Subtract one number from another", {"a": float, "b": float}) +async def subtract_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Subtract b from a.""" + result = args["a"] - args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} - {args['b']} = {result}" + } + ] + } + + +@tool("multiply", "Multiply two numbers", {"a": float, "b": float}) +async def multiply_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Multiply two numbers.""" + result = args["a"] * args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} ร— {args['b']} = {result}" + } + ] + } + + +@tool("divide", "Divide one number by another", {"a": float, "b": float}) +async def divide_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Divide a by b.""" + if args["b"] == 0: + return { + "content": [ + { + "type": "text", + "text": "Error: Division by zero is not allowed" + } + ], + "is_error": True + } + + result = args["a"] / args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} รท {args['b']} = {result}" + } + ] + } + + +@tool("sqrt", "Calculate square root", {"n": float}) +async def square_root(args: dict[str, Any]) -> dict[str, Any]: + """Calculate the square root of a number.""" + n = args["n"] + if n < 0: + return { + "content": [ + { + "type": "text", + "text": f"Error: Cannot calculate square root of negative number {n}" + } + ], + "is_error": True + } + + import math + result = math.sqrt(n) + return 
{ + "content": [ + { + "type": "text", + "text": f"โˆš{n} = {result}" + } + ] + } + + +@tool("power", "Raise a number to a power", {"base": float, "exponent": float}) +async def power(args: dict[str, Any]) -> dict[str, Any]: + """Raise base to the exponent power.""" + result = args["base"] ** args["exponent"] + return { + "content": [ + { + "type": "text", + "text": f"{args['base']}^{args['exponent']} = {result}" + } + ] + } + + +def display_message(msg): + """Display message content in a clean format.""" + from claude_code_sdk import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, + ) + + if isinstance(msg, UserMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"User: {block.text}") + elif isinstance(block, ToolResultBlock): + print(f"Tool Result: {block.content[:100] if block.content else 'None'}...") + elif isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + elif isinstance(block, ToolUseBlock): + print(f"Using tool: {block.name}") + # Show tool inputs for calculator + if block.input: + print(f" Input: {block.input}") + elif isinstance(msg, SystemMessage): + # Ignore system messages + pass + elif isinstance(msg, ResultMessage): + print("Result ended") + if msg.total_cost_usd: + print(f"Cost: ${msg.total_cost_usd:.6f}") + + +async def main(): + """Run example calculations using the SDK MCP server with streaming client.""" + from claude_code_sdk import ClaudeSDKClient + + # Create the calculator server with all tools + calculator = create_sdk_mcp_server( + name="calculator", + version="2.0.0", + tools=[ + add_numbers, + subtract_numbers, + multiply_numbers, + divide_numbers, + square_root, + power + ] + ) + + # Configure Claude to use the calculator server with allowed tools + # Pre-approve all calculator MCP tools so they can be used without permission prompts + options = 
ClaudeCodeOptions( + mcp_servers={"calc": calculator}, + allowed_tools=[ + "mcp__calc__add", + "mcp__calc__subtract", + "mcp__calc__multiply", + "mcp__calc__divide", + "mcp__calc__sqrt", + "mcp__calc__power" + ] + ) + + # Example prompts to demonstrate calculator usage + prompts = [ + "List your tools", + "Calculate 15 + 27", + "What is 100 divided by 7?", + "Calculate the square root of 144", + "What is 2 raised to the power of 8?", + "Calculate (12 + 8) * 3 - 10" # Complex calculation + ] + + for prompt in prompts: + print(f"\n{'='*50}") + print(f"Prompt: {prompt}") + print(f"{'='*50}") + + async with ClaudeSDKClient(options=options) as client: + await client.query(prompt) + + async for message in client.receive_response(): + display_message(message) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tool_permission_callback.py b/examples/tool_permission_callback.py new file mode 100644 index 0000000..8efd879 --- /dev/null +++ b/examples/tool_permission_callback.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +"""Example: Tool Permission Callbacks. + +This example demonstrates how to use tool permission callbacks to control +which tools Claude can use and modify their inputs. 
+""" + +import asyncio +import json + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ClaudeSDKClient, + PermissionResultAllow, + PermissionResultDeny, + ResultMessage, + TextBlock, + ToolPermissionContext, +) + +# Track tool usage for demonstration +tool_usage_log = [] + + +async def my_permission_callback( + tool_name: str, + input_data: dict, + context: ToolPermissionContext +) -> PermissionResultAllow | PermissionResultDeny: + """Control tool permissions based on tool type and input.""" + + # Log the tool request + tool_usage_log.append({ + "tool": tool_name, + "input": input_data, + "suggestions": context.suggestions + }) + + print(f"\n๐Ÿ”ง Tool Permission Request: {tool_name}") + print(f" Input: {json.dumps(input_data, indent=2)}") + + # Always allow read operations + if tool_name in ["Read", "Glob", "Grep"]: + print(f" โœ… Automatically allowing {tool_name} (read-only operation)") + return PermissionResultAllow() + + # Deny write operations to system directories + if tool_name in ["Write", "Edit", "MultiEdit"]: + file_path = input_data.get("file_path", "") + if file_path.startswith("/etc/") or file_path.startswith("/usr/"): + print(f" โŒ Denying write to system directory: {file_path}") + return PermissionResultDeny( + message=f"Cannot write to system directory: {file_path}" + ) + + # Redirect writes to a safe directory + if not file_path.startswith("/tmp/") and not file_path.startswith("./"): + safe_path = f"./safe_output/{file_path.split('/')[-1]}" + print(f" โš ๏ธ Redirecting write from {file_path} to {safe_path}") + modified_input = input_data.copy() + modified_input["file_path"] = safe_path + return PermissionResultAllow( + updated_input=modified_input + ) + + # Check dangerous bash commands + if tool_name == "Bash": + command = input_data.get("command", "") + dangerous_commands = ["rm -rf", "sudo", "chmod 777", "dd if=", "mkfs"] + + for dangerous in dangerous_commands: + if dangerous in command: + print(f" โŒ Denying 
dangerous command: {command}") + return PermissionResultDeny( + message=f"Dangerous command pattern detected: {dangerous}" + ) + + # Allow but log the command + print(f" โœ… Allowing bash command: {command}") + return PermissionResultAllow() + + # For all other tools, ask the user + print(f" โ“ Unknown tool: {tool_name}") + print(f" Input: {json.dumps(input_data, indent=6)}") + user_input = input(" Allow this tool? (y/N): ").strip().lower() + + if user_input in ("y", "yes"): + return PermissionResultAllow() + else: + return PermissionResultDeny( + message="User denied permission" + ) + + +async def main(): + """Run example with tool permission callbacks.""" + + print("=" * 60) + print("Tool Permission Callback Example") + print("=" * 60) + print("\nThis example demonstrates how to:") + print("1. Allow/deny tools based on type") + print("2. Modify tool inputs for safety") + print("3. Log tool usage") + print("4. Prompt for unknown tools") + print("=" * 60) + + # Configure options with our callback + options = ClaudeCodeOptions( + can_use_tool=my_permission_callback, + # Use default permission mode to ensure callbacks are invoked + permission_mode="default", + cwd="." # Set working directory + ) + + # Create client and send a query that will use multiple tools + async with ClaudeSDKClient(options) as client: + print("\n๐Ÿ“ Sending query to Claude...") + await client.query( + "Please do the following:\n" + "1. List the files in the current directory\n" + "2. Create a simple Python hello world script at hello.py\n" + "3. 
Run the script to test it" + ) + + print("\n๐Ÿ“จ Receiving response...") + message_count = 0 + + async for message in client.receive_response(): + message_count += 1 + + if isinstance(message, AssistantMessage): + # Print Claude's text responses + for block in message.content: + if isinstance(block, TextBlock): + print(f"\n๐Ÿ’ฌ Claude: {block.text}") + + elif isinstance(message, ResultMessage): + print("\nโœ… Task completed!") + print(f" Duration: {message.duration_ms}ms") + if message.total_cost_usd: + print(f" Cost: ${message.total_cost_usd:.4f}") + print(f" Messages processed: {message_count}") + + # Print tool usage summary + print("\n" + "=" * 60) + print("Tool Usage Summary") + print("=" * 60) + for i, usage in enumerate(tool_usage_log, 1): + print(f"\n{i}. Tool: {usage['tool']}") + print(f" Input: {json.dumps(usage['input'], indent=6)}") + if usage['suggestions']: + print(f" Suggestions: {usage['suggestions']}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index aae45bb..3d93de3 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -16,8 +16,12 @@ from .client import ClaudeSDKClient from .query import query from .types import ( AssistantMessage, + CanUseTool, ClaudeCodeOptions, ContentBlock, + HookCallback, + HookContext, + HookMatcher, McpSdkServerConfig, McpServerConfig, Message, @@ -30,6 +34,7 @@ from .types import ( SystemMessage, TextBlock, ThinkingBlock, + ToolPermissionContext, ToolResultBlock, ToolUseBlock, UserMessage, @@ -292,11 +297,20 @@ __all__ = [ "ToolUseBlock", "ToolResultBlock", "ContentBlock", - # Permission results (keep these as they may be used by internal callbacks) + # Tool callbacks + "CanUseTool", + "ToolPermissionContext", "PermissionResult", "PermissionResultAllow", "PermissionResultDeny", "PermissionUpdate", + "HookCallback", + "HookContext", + "HookMatcher", + # MCP Server Support + "create_sdk_mcp_server", + 
"tool", + "SdkMcpTool", # Errors "ClaudeSDKError", "CLIConnectionError", diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index 1d05eb0..037b0bc 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -1,6 +1,7 @@ """Internal client implementation.""" from collections.abc import AsyncIterable, AsyncIterator +from dataclasses import replace from typing import Any from ..types import ( @@ -43,19 +44,43 @@ class InternalClient: ) -> AsyncIterator[Message]: """Process a query through transport and Query.""" + # Validate and configure permission settings (matching TypeScript SDK logic) + configured_options = options + if options.can_use_tool: + # canUseTool callback requires streaming mode (AsyncIterable prompt) + if isinstance(prompt, str): + raise ValueError( + "can_use_tool callback requires streaming mode. " + "Please provide prompt as an AsyncIterable instead of a string." + ) + + # canUseTool and permission_prompt_tool_name are mutually exclusive + if options.permission_prompt_tool_name: + raise ValueError( + "can_use_tool callback cannot be used with permission_prompt_tool_name. " + "Please use one or the other." 
+ ) + + # Automatically set permission_prompt_tool_name to "stdio" for control protocol + configured_options = replace(options, permission_prompt_tool_name="stdio") + # Use provided transport or create subprocess transport if transport is not None: chosen_transport = transport else: - chosen_transport = SubprocessCLITransport(prompt=prompt, options=options) + chosen_transport = SubprocessCLITransport( + prompt=prompt, options=configured_options + ) # Connect transport await chosen_transport.connect() - # Extract SDK MCP servers from options + # Extract SDK MCP servers from configured options sdk_mcp_servers = {} - if options.mcp_servers and isinstance(options.mcp_servers, dict): - for name, config in options.mcp_servers.items(): + if configured_options.mcp_servers and isinstance( + configured_options.mcp_servers, dict + ): + for name, config in configured_options.mcp_servers.items(): if isinstance(config, dict) and config.get("type") == "sdk": sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item] @@ -64,9 +89,9 @@ class InternalClient: query = Query( transport=chosen_transport, is_streaming_mode=is_streaming, - can_use_tool=options.can_use_tool, - hooks=self._convert_hooks_to_internal_format(options.hooks) - if options.hooks + can_use_tool=configured_options.can_use_tool, + hooks=self._convert_hooks_to_internal_format(configured_options.hooks) + if configured_options.hooks else None, sdk_mcp_servers=sdk_mcp_servers, ) diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index ea8045a..d83951e 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -250,9 +250,11 @@ class Query: # Type narrowing - we've verified these are not None above assert isinstance(server_name, str) assert isinstance(mcp_message, dict) - response_data = await self._handle_sdk_mcp_request( + mcp_response = await self._handle_sdk_mcp_request( server_name, mcp_message ) + # Wrap the MCP 
response as expected by the control protocol + response_data = {"mcp_response": mcp_response} else: raise Exception(f"Unsupported control request subtype: {subtype}") @@ -357,7 +359,24 @@ class Query: # # This forces us to manually route methods. When Python MCP adds Transport # support, we can refactor to match the TypeScript approach. - if method == "tools/list": + if method == "initialize": + # Handle MCP initialization - hardcoded for tools only, no listChanged + return { + "jsonrpc": "2.0", + "id": message.get("id"), + "result": { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} # Tools capability without listChanged + }, + "serverInfo": { + "name": server.name, + "version": server.version or "1.0.0", + }, + }, + } + + elif method == "tools/list": request = ListToolsRequest(method=method) handler = server.request_handlers.get(ListToolsRequest) if handler: @@ -367,7 +386,11 @@ class Query: { "name": tool.name, "description": tool.description, - "inputSchema": tool.inputSchema.model_dump() # type: ignore[union-attr] + "inputSchema": ( + tool.inputSchema.model_dump() + if hasattr(tool.inputSchema, "model_dump") + else tool.inputSchema + ) if tool.inputSchema else {}, } @@ -413,6 +436,10 @@ class Query: "result": response_data, } + elif method == "notifications/initialized": + # Handle initialized notification - just acknowledge it + return {"jsonrpc": "2.0", "result": {}} + # Add more methods here as MCP SDK adds them (resources, prompts, etc.) 
# This is the limitation Ashwin pointed out - we have to manually update diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 4802e9a..8e990b1 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -125,19 +125,25 @@ class SubprocessCLITransport(Transport): if self._options.mcp_servers: if isinstance(self._options.mcp_servers, dict): - # Filter out SDK servers - they're handled in-process - external_servers = { - name: config - for name, config in self._options.mcp_servers.items() - if not (isinstance(config, dict) and config.get("type") == "sdk") - } + # Process all servers, stripping instance field from SDK servers + servers_for_cli: dict[str, Any] = {} + for name, config in self._options.mcp_servers.items(): + if isinstance(config, dict) and config.get("type") == "sdk": + # For SDK servers, pass everything except the instance field + sdk_config: dict[str, object] = { + k: v for k, v in config.items() if k != "instance" + } + servers_for_cli[name] = sdk_config + else: + # For external servers, pass as-is + servers_for_cli[name] = config - # Only pass external servers to CLI - if external_servers: + # Pass all servers to CLI + if servers_for_cli: cmd.extend( [ "--mcp-config", - json.dumps({"mcpServers": external_servers}), + json.dumps({"mcpServers": servers_for_cli}), ] ) else: diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index 9b1a087..48666cd 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -3,6 +3,7 @@ import json import os from collections.abc import AsyncIterable, AsyncIterator +from dataclasses import replace from typing import Any from ._errors import CLIConnectionError @@ -134,9 +135,30 @@ class ClaudeSDKClient: actual_prompt = _empty_stream() if prompt is None else prompt + # Validate and configure permission settings (matching 
TypeScript SDK logic) + if self.options.can_use_tool: + # canUseTool callback requires streaming mode (AsyncIterable prompt) + if isinstance(prompt, str): + raise ValueError( + "can_use_tool callback requires streaming mode. " + "Please provide prompt as an AsyncIterable instead of a string." + ) + + # canUseTool and permission_prompt_tool_name are mutually exclusive + if self.options.permission_prompt_tool_name: + raise ValueError( + "can_use_tool callback cannot be used with permission_prompt_tool_name. " + "Please use one or the other." + ) + + # Automatically set permission_prompt_tool_name to "stdio" for control protocol + options = replace(self.options, permission_prompt_tool_name="stdio") + else: + options = self.options + self._transport = SubprocessCLITransport( prompt=actual_prompt, - options=self.options, + options=options, ) await self._transport.connect() diff --git a/tests/test_tool_callbacks.py b/tests/test_tool_callbacks.py new file mode 100644 index 0000000..769de13 --- /dev/null +++ b/tests/test_tool_callbacks.py @@ -0,0 +1,289 @@ +"""Tests for tool permission callbacks and hook callbacks.""" + +import pytest + +from claude_code_sdk import ( + ClaudeCodeOptions, + HookContext, + HookMatcher, + PermissionResultAllow, + PermissionResultDeny, + ToolPermissionContext, +) +from claude_code_sdk._internal.query import Query +from claude_code_sdk._internal.transport import Transport + + +class MockTransport(Transport): + """Mock transport for testing.""" + + def __init__(self): + self.written_messages = [] + self.messages_to_read = [] + self._connected = False + + async def connect(self) -> None: + self._connected = True + + async def close(self) -> None: + self._connected = False + + async def write(self, data: str) -> None: + self.written_messages.append(data) + + async def end_input(self) -> None: + pass + + def read_messages(self): + async def _read(): + for msg in self.messages_to_read: + yield msg + + return _read() + + def is_ready(self) -> bool: + 
return self._connected + + +class TestToolPermissionCallbacks: + """Test tool permission callback functionality.""" + + @pytest.mark.asyncio + async def test_permission_callback_allow(self): + """Test callback that allows tool execution.""" + callback_invoked = False + + async def allow_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultAllow: + nonlocal callback_invoked + callback_invoked = True + assert tool_name == "TestTool" + assert input_data == {"param": "value"} + return PermissionResultAllow() + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=allow_callback, + hooks=None, + ) + + # Simulate control request + request = { + "type": "control_request", + "request_id": "test-1", + "request": { + "subtype": "can_use_tool", + "tool_name": "TestTool", + "input": {"param": "value"}, + "permission_suggestions": [], + }, + } + + await query._handle_control_request(request) + + # Check callback was invoked + assert callback_invoked + + # Check response was sent + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"allow": true' in response + + @pytest.mark.asyncio + async def test_permission_callback_deny(self): + """Test callback that denies tool execution.""" + + async def deny_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultDeny: + return PermissionResultDeny(message="Security policy violation") + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=deny_callback, + hooks=None, + ) + + request = { + "type": "control_request", + "request_id": "test-2", + "request": { + "subtype": "can_use_tool", + "tool_name": "DangerousTool", + "input": {"command": "rm -rf /"}, + "permission_suggestions": ["deny"], + }, + } + + await query._handle_control_request(request) + + # Check response + assert 
len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"allow": false' in response + assert '"reason": "Security policy violation"' in response + + @pytest.mark.asyncio + async def test_permission_callback_input_modification(self): + """Test callback that modifies tool input.""" + + async def modify_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultAllow: + # Modify the input to add safety flag + modified_input = input_data.copy() + modified_input["safe_mode"] = True + return PermissionResultAllow(updated_input=modified_input) + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=modify_callback, + hooks=None, + ) + + request = { + "type": "control_request", + "request_id": "test-3", + "request": { + "subtype": "can_use_tool", + "tool_name": "WriteTool", + "input": {"file_path": "/etc/passwd"}, + "permission_suggestions": [], + }, + } + + await query._handle_control_request(request) + + # Check response includes modified input + assert len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"allow": true' in response + assert '"safe_mode": true' in response + + @pytest.mark.asyncio + async def test_callback_exception_handling(self): + """Test that callback exceptions are properly handled.""" + + async def error_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultAllow: + raise ValueError("Callback error") + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=error_callback, + hooks=None, + ) + + request = { + "type": "control_request", + "request_id": "test-5", + "request": { + "subtype": "can_use_tool", + "tool_name": "TestTool", + "input": {}, + "permission_suggestions": [], + }, + } + + await query._handle_control_request(request) + + # Check error response was sent + assert 
len(transport.written_messages) == 1 + response = transport.written_messages[0] + assert '"subtype": "error"' in response + assert "Callback error" in response + + +class TestHookCallbacks: + """Test hook callback functionality.""" + + @pytest.mark.asyncio + async def test_hook_execution(self): + """Test that hooks are called at appropriate times.""" + hook_calls = [] + + async def test_hook( + input_data: dict, tool_use_id: str | None, context: HookContext + ) -> dict: + hook_calls.append({"input": input_data, "tool_use_id": tool_use_id}) + return {"processed": True} + + transport = MockTransport() + + # Create hooks configuration + hooks = { + "tool_use_start": [{"matcher": {"tool": "TestTool"}, "hooks": [test_hook]}] + } + + query = Query( + transport=transport, is_streaming_mode=True, can_use_tool=None, hooks=hooks + ) + + # Manually register the hook callback to avoid needing the full initialize flow + callback_id = "test_hook_0" + query.hook_callbacks[callback_id] = test_hook + + # Simulate hook callback request + request = { + "type": "control_request", + "request_id": "test-hook-1", + "request": { + "subtype": "hook_callback", + "callback_id": callback_id, + "input": {"test": "data"}, + "tool_use_id": "tool-123", + }, + } + + await query._handle_control_request(request) + + # Check hook was called + assert len(hook_calls) == 1 + assert hook_calls[0]["input"] == {"test": "data"} + assert hook_calls[0]["tool_use_id"] == "tool-123" + + # Check response + assert len(transport.written_messages) > 0 + last_response = transport.written_messages[-1] + assert '"processed": true' in last_response + + +class TestClaudeCodeOptionsIntegration: + """Test that callbacks work through ClaudeCodeOptions.""" + + def test_options_with_callbacks(self): + """Test creating options with callbacks.""" + + async def my_callback( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultAllow: + return PermissionResultAllow() + + async def my_hook( + 
input_data: dict, tool_use_id: str | None, context: HookContext + ) -> dict: + return {} + + options = ClaudeCodeOptions( + can_use_tool=my_callback, + hooks={ + "tool_use_start": [ + HookMatcher(matcher={"tool": "Bash"}, hooks=[my_hook]) + ] + }, + ) + + assert options.can_use_tool == my_callback + assert "tool_use_start" in options.hooks + assert len(options.hooks["tool_use_start"]) == 1 + assert options.hooks["tool_use_start"][0].hooks[0] == my_hook