From 95775ac657cac5290d5020dc73ef0aca438f657d Mon Sep 17 00:00:00 2001 From: Kashyap Murali Date: Sun, 31 Aug 2025 20:46:28 -0700 Subject: [PATCH] feat: Add SDK MCP server support on top of control protocol - Rebased SDK MCP implementation onto PR #139's control protocol refactoring - Moved SDK MCP handling from Transport to Query class for proper layering - Transport now only filters SDK servers from CLI config - Query class handles SDK MCP control requests via _handle_sdk_mcp_request - Added tool decorator and create_sdk_mcp_server API functions - Added McpSdkServerConfig type definition - Updated documentation with SDK MCP examples - Added integration tests for SDK MCP functionality This implementation properly layers SDK MCP on top of the bidirectional control protocol from PR #139, ensuring clean separation between transport (I/O) and query (protocol) layers. --- README.md | 85 +++++++ examples/mcp_calculator.py | 181 ++++++++++++++ src/claude_code_sdk/__init__.py | 235 ++++++++++++++++++ src/claude_code_sdk/_internal/client.py | 8 + src/claude_code_sdk/_internal/query.py | 85 ++++++- .../_internal/transport/subprocess_cli.py | 22 +- src/claude_code_sdk/types.py | 15 +- tests/test_sdk_mcp_integration.py | 214 ++++++++++++++++ 8 files changed, 835 insertions(+), 10 deletions(-) create mode 100644 examples/mcp_calculator.py create mode 100644 tests/test_sdk_mcp_integration.py diff --git a/README.md b/README.md index fd91924..bdcd0d6 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,91 @@ options = ClaudeCodeOptions( ) ``` +### SDK MCP Servers (In-Process) + +The SDK now supports in-process MCP servers that run directly within your Python application, eliminating the need for separate processes. + +#### Creating a Simple Tool + +```python +from claude_code_sdk import tool, create_sdk_mcp_server + +# Define a tool using the @tool decorator +@tool("greet", "Greet a user", {"name": str}) +async def greet_user(args): + return { + "content": [ + {"type": "text", "text": f"Hello, {args['name']}!"} + ] + } + +# Create an SDK MCP server +server = create_sdk_mcp_server( + name="my-tools", + version="1.0.0", + tools=[greet_user] +) + +# Use it with Claude +options = ClaudeCodeOptions( + mcp_servers={"tools": server} +) + +async for message in query(prompt="Greet Alice", options=options): + print(message) +``` + +#### Benefits Over External MCP Servers + +- **No subprocess management** - Runs in the same process as your application +- **Better performance** - No IPC overhead for tool calls +- **Simpler deployment** - Single Python process instead of multiple +- **Easier debugging** - All code runs in the same process +- **Type safety** - Direct Python function calls with type hints + +#### Migration from External Servers + +```python +# BEFORE: External MCP server (separate process) +options = ClaudeCodeOptions( + mcp_servers={ + "calculator": { + "type": "stdio", + "command": "python", + "args": ["-m", "calculator_server"] + } + } +) + +# AFTER: SDK MCP server (in-process) +from my_tools import add, subtract # Your tool functions + +calculator = create_sdk_mcp_server( + name="calculator", + tools=[add, subtract] +) + +options = ClaudeCodeOptions( + mcp_servers={"calculator": calculator} +) +``` + +#### Mixed Server Support + +You can use both SDK and external MCP servers together: + +```python +options = ClaudeCodeOptions( + mcp_servers={ + "internal": sdk_server, # In-process SDK server + "external": { # External subprocess server + "type": "stdio", + "command": "external-server" + } + } +) +``` + ## API Reference ### `query(prompt, options=None)` diff --git a/examples/mcp_calculator.py b/examples/mcp_calculator.py new file mode 100644 index 0000000..4c516b9 --- /dev/null +++ b/examples/mcp_calculator.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +"""Example: Calculator MCP Server. + +This example demonstrates how to create an in-process MCP server with +calculator tools using the Claude Code Python SDK. + +Unlike external MCP servers that require separate processes, this server +runs directly within your Python application, providing better performance +and simpler deployment. +""" + +import asyncio +from typing import Any + +from claude_code_sdk import ( + ClaudeCodeOptions, + create_sdk_mcp_server, + query, + tool, +) + +# Define calculator tools using the @tool decorator + +@tool("add", "Add two numbers", {"a": float, "b": float}) +async def add_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Add two numbers together.""" + result = args["a"] + args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} + {args['b']} = {result}" + } + ] + } + + +@tool("subtract", "Subtract one number from another", {"a": float, "b": float}) +async def subtract_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Subtract b from a.""" + result = args["a"] - args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} - {args['b']} = {result}" + } + ] + } + + +@tool("multiply", "Multiply two numbers", {"a": float, "b": float}) +async def multiply_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Multiply two numbers.""" + result = args["a"] * args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} × {args['b']} = {result}" + } + ] + } + + +@tool("divide", "Divide one number by another", {"a": float, "b": float}) +async def divide_numbers(args: dict[str, Any]) -> dict[str, Any]: + """Divide a by b.""" + if args["b"] == 0: + return { + "content": [ + { + "type": "text", + "text": "Error: Division by zero is not allowed" + } + ], + "is_error": True + } + + result = args["a"] / args["b"] + return { + "content": [ + { + "type": "text", + "text": f"{args['a']} ÷ {args['b']} = {result}" + } + ] + } + + +@tool("sqrt", "Calculate square root", {"n": float}) +async def square_root(args: dict[str, Any]) -> dict[str, Any]: + """Calculate the square root of a number.""" + n = args["n"] + if n < 0: + return { + "content": [ + { + "type": "text", + "text": f"Error: Cannot calculate square root of negative number {n}" + } + ], + "is_error": True + } + + import math + result = math.sqrt(n) + return { + "content": [ + { + "type": "text", + "text": f"√{n} = {result}" + } + ] + } + + +@tool("power", "Raise a number to a power", {"base": float, "exponent": float}) +async def power(args: dict[str, Any]) -> dict[str, Any]: + """Raise base to the exponent power.""" + result = args["base"] ** args["exponent"] + return { + "content": [ + { + "type": "text", + "text": f"{args['base']}^{args['exponent']} = {result}" + } + ] + } + + +async def main(): + """Run example calculations using the SDK MCP server.""" + + # Create the calculator server with all tools + calculator = create_sdk_mcp_server( + name="calculator", + version="2.0.0", + tools=[ + add_numbers, + subtract_numbers, + multiply_numbers, + divide_numbers, + square_root, + power + ] + ) + + # Configure Claude to use the calculator server + options = ClaudeCodeOptions( + mcp_servers={"calc": calculator}, + # Allow Claude to use calculator tools without permission prompts + permission_mode="bypassPermissions" + ) + + # Example prompts to demonstrate calculator usage + prompts = [ + "Calculate 15 + 27", + "What is 100 divided by 7?", + "Calculate the square root of 144", + "What is 2 raised to the power of 8?", + "Calculate (12 + 8) * 3 - 10" # Complex calculation + ] + + for prompt in prompts: + print(f"\n{'='*50}") + print(f"Prompt: {prompt}") + print(f"{'='*50}") + + async for message in query(prompt=prompt, options=options): + # Print the message content + if hasattr(message, 'content'): + for content_block in message.content: + if hasattr(content_block, 'text'): + print(f"Claude: {content_block.text}") + elif hasattr(content_block, 'name'): + print(f"Using tool: {content_block.name}") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index f2b9bdb..a26640a 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -1,5 +1,9 @@ """Claude SDK for Python.""" +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any, Generic, TypeVar, Union + from ._errors import ( ClaudeSDKError, CLIConnectionError, @@ -14,6 +18,7 @@ from .types import ( AssistantMessage, ClaudeCodeOptions, ContentBlock, + McpSdkServerConfig, McpServerConfig, Message, PermissionMode, @@ -26,6 +31,231 @@ from .types import ( UserMessage, ) +# MCP Server Support + +T = TypeVar('T') + +@dataclass +class SdkMcpTool(Generic[T]): + """Definition for an SDK MCP tool.""" + name: str + description: str + input_schema: type[T] | dict[str, Any] + handler: Callable[[T], Awaitable[dict[str, Any]]] + + +def tool( + name: str, + description: str, + input_schema: type | dict[str, Any] +) -> Callable[[Callable[[Any], Awaitable[dict[str, Any]]]], SdkMcpTool]: + """Decorator for defining MCP tools with type safety. + + Creates a tool that can be used with SDK MCP servers. The tool runs + in-process within your Python application, providing better performance + than external MCP servers. + + Args: + name: Unique identifier for the tool. This is what Claude will use + to reference the tool in function calls. + description: Human-readable description of what the tool does. + This helps Claude understand when to use the tool. + input_schema: Schema defining the tool's input parameters. + Can be either: + - A dictionary mapping parameter names to types (e.g., {"text": str}) + - A TypedDict class for more complex schemas + - A JSON Schema dictionary for full validation + + Returns: + A decorator function that wraps the tool implementation and returns + an SdkMcpTool instance ready for use with create_sdk_mcp_server(). + + Example: + Basic tool with simple schema: + >>> @tool("greet", "Greet a user", {"name": str}) + ... async def greet(args): + ... return {"content": [{"type": "text", "text": f"Hello, {args['name']}!"}]} + + Tool with multiple parameters: + >>> @tool("add", "Add two numbers", {"a": float, "b": float}) + ... async def add_numbers(args): + ... result = args["a"] + args["b"] + ... return {"content": [{"type": "text", "text": f"Result: {result}"}]} + + Tool with error handling: + >>> @tool("divide", "Divide two numbers", {"a": float, "b": float}) + ... async def divide(args): + ... if args["b"] == 0: + ... return {"content": [{"type": "text", "text": "Error: Division by zero"}], "is_error": True} + ... return {"content": [{"type": "text", "text": f"Result: {args['a'] / args['b']}"}]} + + Notes: + - The tool function must be async (defined with async def) + - The function receives a single dict argument with the input parameters + - The function should return a dict with a "content" key containing the response + - Errors can be indicated by including "is_error": True in the response + """ + def decorator(handler: Callable[[Any], Awaitable[dict[str, Any]]]) -> SdkMcpTool: + return SdkMcpTool(name=name, description=description, input_schema=input_schema, handler=handler) + return decorator + + +def create_sdk_mcp_server( + name: str, + version: str = "1.0.0", + tools: list[SdkMcpTool] | None = None +) -> McpSdkServerConfig: + """Create an in-process MCP server that runs within your Python application. + + Unlike external MCP servers that run as separate processes, SDK MCP servers + run directly in your application's process. This provides: + - Better performance (no IPC overhead) + - Simpler deployment (single process) + - Easier debugging (same process) + - Direct access to your application's state + + Args: + name: Unique identifier for the server. This name is used to reference + the server in the mcp_servers configuration. + version: Server version string. Defaults to "1.0.0". This is for + informational purposes and doesn't affect functionality. + tools: List of SdkMcpTool instances created with the @tool decorator. + These are the functions that Claude can call through this server. + If None or empty, the server will have no tools (rarely useful). + + Returns: + McpSdkServerConfig: A configuration object that can be passed to + ClaudeCodeOptions.mcp_servers. This config contains the server + instance and metadata needed for the SDK to route tool calls. + + Example: + Simple calculator server: + >>> @tool("add", "Add numbers", {"a": float, "b": float}) + ... async def add(args): + ... return {"content": [{"type": "text", "text": f"Sum: {args['a'] + args['b']}"}]} + >>> + >>> @tool("multiply", "Multiply numbers", {"a": float, "b": float}) + ... async def multiply(args): + ... return {"content": [{"type": "text", "text": f"Product: {args['a'] * args['b']}"}]} + >>> + >>> calculator = create_sdk_mcp_server( + ... name="calculator", + ... version="2.0.0", + ... tools=[add, multiply] + ... ) + >>> + >>> # Use with Claude + >>> options = ClaudeCodeOptions( + ... mcp_servers={"calc": calculator}, + ... allowed_tools=["add", "multiply"] + ... ) + + Server with application state access: + >>> class DataStore: + ... def __init__(self): + ... self.items = [] + ... + >>> store = DataStore() + >>> + >>> @tool("add_item", "Add item to store", {"item": str}) + ... async def add_item(args): + ... store.items.append(args["item"]) + ... return {"content": [{"type": "text", "text": f"Added: {args['item']}"}]} + >>> + >>> server = create_sdk_mcp_server("store", tools=[add_item]) + + Notes: + - The server runs in the same process as your Python application + - Tools have direct access to your application's variables and state + - No subprocess or IPC overhead for tool calls + - Server lifecycle is managed automatically by the SDK + + See Also: + - tool(): Decorator for creating tool functions + - ClaudeCodeOptions: Configuration for using servers with query() + """ + from mcp.server import Server + from mcp.types import TextContent, Tool + + # Create MCP server instance + server = Server(name, version=version) + + # Register tools if provided + if tools: + # Store tools for access in handlers + tool_map = {tool_def.name: tool_def for tool_def in tools} + + # Register list_tools handler to expose available tools + @server.list_tools() + async def list_tools() -> list[Tool]: + """Return the list of available tools.""" + tool_list = [] + for tool_def in tools: + # Convert input_schema to JSON Schema format + if isinstance(tool_def.input_schema, dict): + # Check if it's already a JSON schema + if "type" in tool_def.input_schema and "properties" in tool_def.input_schema: + schema = tool_def.input_schema + else: + # Simple dict mapping names to types - convert to JSON schema + properties = {} + for param_name, param_type in tool_def.input_schema.items(): + if param_type is str: + properties[param_name] = {"type": "string"} + elif param_type is int: + properties[param_name] = {"type": "integer"} + elif param_type is float: + properties[param_name] = {"type": "number"} + elif param_type is bool: + properties[param_name] = {"type": "boolean"} + else: + properties[param_name] = {"type": "string"} # Default + schema = { + "type": "object", + "properties": properties, + "required": list(properties.keys()) + } + else: + # For TypedDict or other types, create basic schema + schema = {"type": "object", "properties": {}} + + tool_list.append(Tool( + name=tool_def.name, + description=tool_def.description, + inputSchema=schema + )) + return tool_list + + # Register call_tool handler to execute tools + @server.call_tool() + async def call_tool(name: str, arguments: dict) -> Any: + """Execute a tool by name with given arguments.""" + if name not in tool_map: + raise ValueError(f"Tool '{name}' not found") + + tool_def = tool_map[name] + # Call the tool's handler with arguments + result = await tool_def.handler(arguments) + + # Convert result to MCP format + # The decorator expects us to return the content, not a CallToolResult + # It will wrap our return value in CallToolResult + content = [] + if "content" in result: + for item in result["content"]: + if item.get("type") == "text": + content.append(TextContent(type="text", text=item["text"])) + + # Return just the content list - the decorator wraps it + return content + + # Return SDK server configuration + return McpSdkServerConfig( + type="sdk", + name=name, + instance=server + ) + __version__ = "0.0.20" __all__ = [ @@ -37,6 +267,7 @@ __all__ = [ # Types "PermissionMode", "McpServerConfig", + "McpSdkServerConfig", "UserMessage", "AssistantMessage", "SystemMessage", @@ -48,6 +279,10 @@ __all__ = [ "ToolUseBlock", "ToolResultBlock", "ContentBlock", + # MCP Server Support + "create_sdk_mcp_server", + "tool", + "SdkMcpTool", # Errors "ClaudeSDKError", "CLIConnectionError", diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index ccfc1e8..1b27a33 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -36,6 +36,13 @@ class InternalClient: # Connect transport await chosen_transport.connect() + # Extract SDK MCP servers from options + sdk_mcp_servers = {} + if options.mcp_servers and isinstance(options.mcp_servers, dict): + for name, config in options.mcp_servers.items(): + if isinstance(config, dict) and config.get("type") == "sdk": + sdk_mcp_servers[name] = config["instance"] + # Create Query to handle control protocol is_streaming = not isinstance(prompt, str) query = Query( @@ -43,6 +50,7 @@ class InternalClient: is_streaming_mode=is_streaming, can_use_tool=None, # TODO: Add support for can_use_tool callback hooks=None, # TODO: Add support for hooks + sdk_mcp_servers=sdk_mcp_servers, ) try: diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index 815522f..b477f27 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -5,10 +5,13 @@ import json import logging import os from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable -from typing import Any +from typing import TYPE_CHECKING, Any from .transport import Transport +if TYPE_CHECKING: + from mcp.server import Server as McpServer + logger = logging.getLogger(__name__) @@ -32,6 +35,7 @@ class Query: ] | None = None, hooks: dict[str, list[dict[str, Any]]] | None = None, + sdk_mcp_servers: dict[str, "McpServer"] | None = None, ): """Initialize Query with transport and callbacks. @@ -40,11 +44,13 @@ class Query: is_streaming_mode: Whether using streaming (bidirectional) mode can_use_tool: Optional callback for tool permission requests hooks: Optional hook configurations + sdk_mcp_servers: Optional SDK MCP server instances """ self.transport = transport self.is_streaming_mode = is_streaming_mode self.can_use_tool = can_use_tool self.hooks = hooks or {} + self.sdk_mcp_servers = sdk_mcp_servers or {} # Control protocol state self.pending_control_responses: dict[str, asyncio.Future[dict[str, Any]]] = {} @@ -184,6 +190,16 @@ class Query: {"signal": None}, # TODO: Add abort signal support ) + elif subtype == "mcp_request": + # Handle SDK MCP request + server_name = request_data.get("server_name") + mcp_message = request_data.get("message") + + if not server_name or not mcp_message: + raise Exception("Missing server_name or message for MCP request") + + response_data = await self._handle_sdk_mcp_request(server_name, mcp_message) + else: raise Exception(f"Unsupported control request subtype: {subtype}") @@ -241,6 +257,73 @@ class Query: self.pending_control_responses.pop(request_id, None) raise Exception(f"Control request timeout: {request.get('subtype')}") from e + async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict: + """Handle an MCP request for an SDK server. + + Args: + server_name: Name of the SDK MCP server + message: The JSONRPC message + + Returns: + The response message + """ + if server_name not in self.sdk_mcp_servers: + return { + "jsonrpc": "2.0", + "id": message.get("id"), + "error": { + "code": -32601, + "message": f"Server '{server_name}' not found" + } + } + + server = self.sdk_mcp_servers[server_name] + method = message.get("method") + params = message.get("params", {}) + + try: + # Route to appropriate handler based on method + if method == "tools/list": + # Get the list_tools handler and call it + handler = server.request_handlers.get("tools/list") + if handler: + tools = await handler() + return { + "jsonrpc": "2.0", + "id": message.get("id"), + "result": {"tools": [t.model_dump() for t in tools]} + } + elif method == "tools/call": + # Get the call_tool handler and call it + handler = server.request_handlers.get("tools/call") + if handler: + result = await handler(params.get("name"), params.get("arguments", {})) + return { + "jsonrpc": "2.0", + "id": message.get("id"), + "result": result + } + + # Method not found + return { + "jsonrpc": "2.0", + "id": message.get("id"), + "error": { + "code": -32601, + "message": f"Method '{method}' not found" + } + } + + except Exception as e: + return { + "jsonrpc": "2.0", + "id": message.get("id"), + "error": { + "code": -32603, + "message": str(e) + } + } + async def interrupt(self) -> None: """Send interrupt control request.""" await self._send_control_request({"subtype": "interrupt"}) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 2048d60..53f0e27 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -127,13 +127,21 @@ class SubprocessCLITransport(Transport): if self._options.mcp_servers: if isinstance(self._options.mcp_servers, dict): - # Dict format: serialize to JSON - cmd.extend( - [ - "--mcp-config", - json.dumps({"mcpServers": self._options.mcp_servers}), - ] - ) + # Filter out SDK servers - they're handled in-process + external_servers = { + name: config + for name, config in self._options.mcp_servers.items() + if not (isinstance(config, dict) and config.get("type") == "sdk") + } + + # Only pass external servers to CLI + if external_servers: + cmd.extend( + [ + "--mcp-config", + json.dumps({"mcpServers": external_servers}), + ] + ) else: # String or Path format: pass directly as file path or JSON string cmd.extend(["--mcp-config", str(self._options.mcp_servers)]) diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index 2f0d0e8..f42336e 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -2,10 +2,13 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Literal, TypedDict +from typing import TYPE_CHECKING, Any, Literal, TypedDict from typing_extensions import NotRequired # For Python < 3.11 compatibility +if TYPE_CHECKING: + from mcp.server import Server as McpServer + # Permission modes PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"] @@ -36,7 +39,15 @@ class McpHttpServerConfig(TypedDict): headers: NotRequired[dict[str, str]] -McpServerConfig = McpStdioServerConfig | McpSSEServerConfig | McpHttpServerConfig +class McpSdkServerConfig(TypedDict): + """SDK MCP server configuration.""" + + type: Literal["sdk"] + name: str + instance: "McpServer" + + +McpServerConfig = McpStdioServerConfig | McpSSEServerConfig | McpHttpServerConfig | McpSdkServerConfig # Content block types diff --git a/tests/test_sdk_mcp_integration.py b/tests/test_sdk_mcp_integration.py new file mode 100644 index 0000000..44d3a2e --- /dev/null +++ b/tests/test_sdk_mcp_integration.py @@ -0,0 +1,214 @@ +"""Integration tests for SDK MCP server support. + +This test file verifies that SDK MCP servers work correctly through the full stack, +matching the TypeScript SDK test/sdk.test.ts pattern. +""" + +from typing import Any + +import pytest + +from claude_code_sdk import ( + ClaudeCodeOptions, + create_sdk_mcp_server, + tool, +) + + +@pytest.mark.asyncio +async def test_sdk_mcp_server_handlers(): + """Test that SDK MCP server handlers are properly registered.""" + # Track tool executions + tool_executions: list[dict[str, Any]] = [] + + # Create SDK MCP server with multiple tools + @tool("greet_user", "Greets a user by name", {"name": str}) + async def greet_user(args: dict[str, Any]) -> dict[str, Any]: + tool_executions.append({"name": "greet_user", "args": args}) + return { + "content": [ + {"type": "text", "text": f"Hello, {args['name']}!"} + ] + } + + @tool("add_numbers", "Adds two numbers", {"a": float, "b": float}) + async def add_numbers(args: dict[str, Any]) -> dict[str, Any]: + tool_executions.append({"name": "add_numbers", "args": args}) + result = args["a"] + args["b"] + return { + "content": [ + {"type": "text", "text": f"The sum is {result}"} + ] + } + + server_config = create_sdk_mcp_server( + name="test-sdk-server", + version="1.0.0", + tools=[greet_user, add_numbers] + ) + + # Verify server configuration + assert server_config["type"] == "sdk" + assert server_config["name"] == "test-sdk-server" + assert "instance" in server_config + + # Get the server instance + server = server_config["instance"] + + # Import the request types to check handlers + from mcp.types import CallToolRequest, ListToolsRequest + + # Verify handlers are registered + assert ListToolsRequest in server.request_handlers + assert CallToolRequest in server.request_handlers + + # Test list_tools handler - the decorator wraps our function + list_handler = server.request_handlers[ListToolsRequest] + request = ListToolsRequest(method="tools/list") + response = await list_handler(request) + # Response is ServerResult with nested ListToolsResult + assert len(response.root.tools) == 2 + + # Check tool definitions + tool_names = [t.name for t in response.root.tools] + assert "greet_user" in tool_names + assert "add_numbers" in tool_names + + # Test call_tool handler + call_handler = server.request_handlers[CallToolRequest] + + # Call greet_user - CallToolRequest wraps the call + from mcp.types import CallToolRequestParams + greet_request = CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="greet_user", arguments={"name": "Alice"}) + ) + result = await call_handler(greet_request) + # Response is ServerResult with nested CallToolResult + assert result.root.content[0].text == "Hello, Alice!" + assert len(tool_executions) == 1 + assert tool_executions[0]["name"] == "greet_user" + assert tool_executions[0]["args"]["name"] == "Alice" + + # Call add_numbers + add_request = CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="add_numbers", arguments={"a": 5, "b": 3}) + ) + result = await call_handler(add_request) + assert "8" in result.root.content[0].text + assert len(tool_executions) == 2 + assert tool_executions[1]["name"] == "add_numbers" + assert tool_executions[1]["args"]["a"] == 5 + assert tool_executions[1]["args"]["b"] == 3 + + +@pytest.mark.asyncio +async def test_tool_creation(): + """Test that tools can be created with proper schemas.""" + @tool("echo", "Echo input", {"input": str}) + async def echo_tool(args: dict[str, Any]) -> dict[str, Any]: + return {"output": args["input"]} + + # Verify tool was created + assert echo_tool.name == "echo" + assert echo_tool.description == "Echo input" + assert echo_tool.input_schema == {"input": str} + assert callable(echo_tool.handler) + + # Test the handler works + result = await echo_tool.handler({"input": "test"}) + assert result == {"output": "test"} + + +@pytest.mark.asyncio +async def test_error_handling(): + """Test that tool errors are properly handled.""" + @tool("fail", "Always fails", {}) + async def fail_tool(args: dict[str, Any]) -> dict[str, Any]: + raise ValueError("Expected error") + + # Verify the tool raises an error when called directly + with pytest.raises(ValueError, match="Expected error"): + await fail_tool.handler({}) + + # Test error handling through the server + server_config = create_sdk_mcp_server( + name="error-test", + tools=[fail_tool] + ) + + server = server_config["instance"] + from mcp.types import CallToolRequest + call_handler = server.request_handlers[CallToolRequest] + + # The handler should return an error result, not raise + from mcp.types import CallToolRequestParams + fail_request = CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name="fail", arguments={}) + ) + result = await call_handler(fail_request) + # MCP SDK catches exceptions and returns error results + assert result.root.isError + assert "Expected error" in str(result.root.content[0].text) + + +@pytest.mark.asyncio +async def test_mixed_servers(): + """Test that SDK and external MCP servers can work together.""" + # Create an SDK server + @tool("sdk_tool", "SDK tool", {}) + async def sdk_tool(args: dict[str, Any]) -> dict[str, Any]: + return {"result": "from SDK"} + + sdk_server = create_sdk_mcp_server( + name="sdk-server", + tools=[sdk_tool] + ) + + # Create configuration with both SDK and external servers + external_server = { + "type": "stdio", + "command": "echo", + "args": ["test"] + } + + options = ClaudeCodeOptions( + mcp_servers={ + "sdk": sdk_server, + "external": external_server + } + ) + + # Verify both server types are in the configuration + assert "sdk" in options.mcp_servers + assert "external" in options.mcp_servers + assert options.mcp_servers["sdk"]["type"] == "sdk" + assert options.mcp_servers["external"]["type"] == "stdio" + + +@pytest.mark.asyncio +async def test_server_creation(): + """Test that SDK MCP servers are created correctly.""" + server = create_sdk_mcp_server( + name="test-server", + version="2.0.0", + tools=[] + ) + + # Verify server configuration + assert server["type"] == "sdk" + assert server["name"] == "test-server" + assert "instance" in server + assert server["instance"] is not None + + # Verify the server instance has the right attributes + instance = server["instance"] + assert instance.name == "test-server" + assert instance.version == "2.0.0" + + # With no tools, no handlers are registered if tools is empty + from mcp.types import ListToolsRequest + # When no tools are provided, the handlers are not registered + assert ListToolsRequest not in instance.request_handlers \ No newline at end of file