Apply ruff formatting and merge fixes from base branch

- Applied ruff formatting to all files
- Integrated async/sync fixes from dickson/control branch
- Fixed import conflicts between TYPE_CHECKING and contextlib
- All SDK MCP tests passing
This commit is contained in:
Kashyap Murali 2025-09-01 03:14:44 -07:00
parent 76f6ed1d9c
commit 92a2d42d36
No known key found for this signature in database
5 changed files with 85 additions and 103 deletions

View file

@ -2,7 +2,7 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, Generic, TypeVar, Union
from typing import Any, Generic, TypeVar
from ._errors import (
ClaudeSDKError,
@ -33,11 +33,13 @@ from .types import (
# MCP Server Support
T = TypeVar('T')
T = TypeVar("T")
@dataclass
class SdkMcpTool(Generic[T]):
"""Definition for an SDK MCP tool."""
name: str
description: str
input_schema: type[T] | dict[str, Any]
@ -45,16 +47,14 @@ class SdkMcpTool(Generic[T]):
def tool(
name: str,
description: str,
input_schema: type | dict[str, Any]
name: str, description: str, input_schema: type | dict[str, Any]
) -> Callable[[Callable[[Any], Awaitable[dict[str, Any]]]], SdkMcpTool]:
"""Decorator for defining MCP tools with type safety.
Creates a tool that can be used with SDK MCP servers. The tool runs
in-process within your Python application, providing better performance
than external MCP servers.
Args:
name: Unique identifier for the tool. This is what Claude will use
to reference the tool in function calls.
@ -65,55 +65,60 @@ def tool(
- A dictionary mapping parameter names to types (e.g., {"text": str})
- A TypedDict class for more complex schemas
- A JSON Schema dictionary for full validation
Returns:
A decorator function that wraps the tool implementation and returns
an SdkMcpTool instance ready for use with create_sdk_mcp_server().
Example:
Basic tool with simple schema:
>>> @tool("greet", "Greet a user", {"name": str})
... async def greet(args):
... return {"content": [{"type": "text", "text": f"Hello, {args['name']}!"}]}
Tool with multiple parameters:
>>> @tool("add", "Add two numbers", {"a": float, "b": float})
... async def add_numbers(args):
... result = args["a"] + args["b"]
... return {"content": [{"type": "text", "text": f"Result: {result}"}]}
Tool with error handling:
>>> @tool("divide", "Divide two numbers", {"a": float, "b": float})
... async def divide(args):
... if args["b"] == 0:
... return {"content": [{"type": "text", "text": "Error: Division by zero"}], "is_error": True}
... return {"content": [{"type": "text", "text": f"Result: {args['a'] / args['b']}"}]}
Notes:
- The tool function must be async (defined with async def)
- The function receives a single dict argument with the input parameters
- The function should return a dict with a "content" key containing the response
- Errors can be indicated by including "is_error": True in the response
"""
def decorator(handler: Callable[[Any], Awaitable[dict[str, Any]]]) -> SdkMcpTool:
return SdkMcpTool(name=name, description=description, input_schema=input_schema, handler=handler)
return SdkMcpTool(
name=name,
description=description,
input_schema=input_schema,
handler=handler,
)
return decorator
def create_sdk_mcp_server(
name: str,
version: str = "1.0.0",
tools: list[SdkMcpTool] | None = None
name: str, version: str = "1.0.0", tools: list[SdkMcpTool] | None = None
) -> McpSdkServerConfig:
"""Create an in-process MCP server that runs within your Python application.
Unlike external MCP servers that run as separate processes, SDK MCP servers
run directly in your application's process. This provides:
- Better performance (no IPC overhead)
- Simpler deployment (single process)
- Easier debugging (same process)
- Direct access to your application's state
Args:
name: Unique identifier for the server. This name is used to reference
the server in the mcp_servers configuration.
@ -122,54 +127,54 @@ def create_sdk_mcp_server(
tools: List of SdkMcpTool instances created with the @tool decorator.
These are the functions that Claude can call through this server.
If None or empty, the server will have no tools (rarely useful).
Returns:
McpSdkServerConfig: A configuration object that can be passed to
ClaudeCodeOptions.mcp_servers. This config contains the server
instance and metadata needed for the SDK to route tool calls.
Example:
Simple calculator server:
>>> @tool("add", "Add numbers", {"a": float, "b": float})
... async def add(args):
... return {"content": [{"type": "text", "text": f"Sum: {args['a'] + args['b']}"}]}
>>>
>>>
>>> @tool("multiply", "Multiply numbers", {"a": float, "b": float})
... async def multiply(args):
... return {"content": [{"type": "text", "text": f"Product: {args['a'] * args['b']}"}]}
>>>
>>>
>>> calculator = create_sdk_mcp_server(
... name="calculator",
... version="2.0.0",
... tools=[add, multiply]
... )
>>>
>>>
>>> # Use with Claude
>>> options = ClaudeCodeOptions(
... mcp_servers={"calc": calculator},
... allowed_tools=["add", "multiply"]
... )
Server with application state access:
>>> class DataStore:
... def __init__(self):
... self.items = []
...
...
>>> store = DataStore()
>>>
>>>
>>> @tool("add_item", "Add item to store", {"item": str})
... async def add_item(args):
... store.items.append(args["item"])
... return {"content": [{"type": "text", "text": f"Added: {args['item']}"}]}
>>>
>>>
>>> server = create_sdk_mcp_server("store", tools=[add_item])
Notes:
- The server runs in the same process as your Python application
- Tools have direct access to your application's variables and state
- No subprocess or IPC overhead for tool calls
- Server lifecycle is managed automatically by the SDK
See Also:
- tool(): Decorator for creating tool functions
- ClaudeCodeOptions: Configuration for using servers with query()
@ -194,7 +199,10 @@ def create_sdk_mcp_server(
# Convert input_schema to JSON Schema format
if isinstance(tool_def.input_schema, dict):
# Check if it's already a JSON schema
if "type" in tool_def.input_schema and "properties" in tool_def.input_schema:
if (
"type" in tool_def.input_schema
and "properties" in tool_def.input_schema
):
schema = tool_def.input_schema
else:
# Simple dict mapping names to types - convert to JSON schema
@ -213,17 +221,19 @@ def create_sdk_mcp_server(
schema = {
"type": "object",
"properties": properties,
"required": list(properties.keys())
"required": list(properties.keys()),
}
else:
# For TypedDict or other types, create basic schema
schema = {"type": "object", "properties": {}}
tool_list.append(Tool(
name=tool_def.name,
description=tool_def.description,
inputSchema=schema
))
tool_list.append(
Tool(
name=tool_def.name,
description=tool_def.description,
inputSchema=schema,
)
)
return tool_list
# Register call_tool handler to execute tools
@ -250,11 +260,8 @@ def create_sdk_mcp_server(
return content
# Return SDK server configuration
return McpSdkServerConfig(
type="sdk",
name=name,
instance=server
)
return McpSdkServerConfig(type="sdk", name=name, instance=server)
__version__ = "0.0.20"

View file

@ -199,11 +199,13 @@ class Query:
# Handle SDK MCP request
server_name = request_data.get("server_name")
mcp_message = request_data.get("message")
if not server_name or not mcp_message:
raise Exception("Missing server_name or message for MCP request")
response_data = await self._handle_sdk_mcp_request(server_name, mcp_message)
response_data = await self._handle_sdk_mcp_request(
server_name, mcp_message
)
else:
raise Exception(f"Unsupported control request subtype: {subtype}")
@ -278,8 +280,8 @@ class Query:
"id": message.get("id"),
"error": {
"code": -32601,
"message": f"Server '{server_name}' not found"
}
"message": f"Server '{server_name}' not found",
},
}
server = self.sdk_mcp_servers[server_name]
@ -296,37 +298,29 @@ class Query:
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": {"tools": [t.model_dump() for t in tools]}
"result": {"tools": [t.model_dump() for t in tools]},
}
elif method == "tools/call":
# Get the call_tool handler and call it
handler = server.request_handlers.get("tools/call")
if handler:
result = await handler(params.get("name"), params.get("arguments", {}))
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": result
}
result = await handler(
params.get("name"), params.get("arguments", {})
)
return {"jsonrpc": "2.0", "id": message.get("id"), "result": result}
# Method not found
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"error": {
"code": -32601,
"message": f"Method '{method}' not found"
}
"error": {"code": -32601, "message": f"Method '{method}' not found"},
}
except Exception as e:
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"error": {
"code": -32603,
"message": str(e)
}
"error": {"code": -32603, "message": str(e)},
}
async def interrupt(self) -> None:

View file

@ -134,7 +134,7 @@ class SubprocessCLITransport(Transport):
for name, config in self._options.mcp_servers.items()
if not (isinstance(config, dict) and config.get("type") == "sdk")
}
# Only pass external servers to CLI
if external_servers:
cmd.extend(

View file

@ -47,7 +47,9 @@ class McpSdkServerConfig(TypedDict):
instance: "McpServer"
McpServerConfig = McpStdioServerConfig | McpSSEServerConfig | McpHttpServerConfig | McpSdkServerConfig
McpServerConfig = (
McpStdioServerConfig | McpSSEServerConfig | McpHttpServerConfig | McpSdkServerConfig
)
# Content block types

View file

@ -25,26 +25,16 @@ async def test_sdk_mcp_server_handlers():
@tool("greet_user", "Greets a user by name", {"name": str})
async def greet_user(args: dict[str, Any]) -> dict[str, Any]:
tool_executions.append({"name": "greet_user", "args": args})
return {
"content": [
{"type": "text", "text": f"Hello, {args['name']}!"}
]
}
return {"content": [{"type": "text", "text": f"Hello, {args['name']}!"}]}
@tool("add_numbers", "Adds two numbers", {"a": float, "b": float})
async def add_numbers(args: dict[str, Any]) -> dict[str, Any]:
tool_executions.append({"name": "add_numbers", "args": args})
result = args["a"] + args["b"]
return {
"content": [
{"type": "text", "text": f"The sum is {result}"}
]
}
return {"content": [{"type": "text", "text": f"The sum is {result}"}]}
server_config = create_sdk_mcp_server(
name="test-sdk-server",
version="1.0.0",
tools=[greet_user, add_numbers]
name="test-sdk-server", version="1.0.0", tools=[greet_user, add_numbers]
)
# Verify server configuration
@ -79,9 +69,10 @@ async def test_sdk_mcp_server_handlers():
# Call greet_user - CallToolRequest wraps the call
from mcp.types import CallToolRequestParams
greet_request = CallToolRequest(
method="tools/call",
params=CallToolRequestParams(name="greet_user", arguments={"name": "Alice"})
params=CallToolRequestParams(name="greet_user", arguments={"name": "Alice"}),
)
result = await call_handler(greet_request)
# Response is ServerResult with nested CallToolResult
@ -93,7 +84,7 @@ async def test_sdk_mcp_server_handlers():
# Call add_numbers
add_request = CallToolRequest(
method="tools/call",
params=CallToolRequestParams(name="add_numbers", arguments={"a": 5, "b": 3})
params=CallToolRequestParams(name="add_numbers", arguments={"a": 5, "b": 3}),
)
result = await call_handler(add_request)
assert "8" in result.root.content[0].text
@ -106,6 +97,7 @@ async def test_sdk_mcp_server_handlers():
@pytest.mark.asyncio
async def test_tool_creation():
"""Test that tools can be created with proper schemas."""
@tool("echo", "Echo input", {"input": str})
async def echo_tool(args: dict[str, Any]) -> dict[str, Any]:
return {"output": args["input"]}
@ -124,6 +116,7 @@ async def test_tool_creation():
@pytest.mark.asyncio
async def test_error_handling():
"""Test that tool errors are properly handled."""
@tool("fail", "Always fails", {})
async def fail_tool(args: dict[str, Any]) -> dict[str, Any]:
raise ValueError("Expected error")
@ -133,20 +126,18 @@ async def test_error_handling():
await fail_tool.handler({})
# Test error handling through the server
server_config = create_sdk_mcp_server(
name="error-test",
tools=[fail_tool]
)
server_config = create_sdk_mcp_server(name="error-test", tools=[fail_tool])
server = server_config["instance"]
from mcp.types import CallToolRequest
call_handler = server.request_handlers[CallToolRequest]
# The handler should return an error result, not raise
from mcp.types import CallToolRequestParams
fail_request = CallToolRequest(
method="tools/call",
params=CallToolRequestParams(name="fail", arguments={})
method="tools/call", params=CallToolRequestParams(name="fail", arguments={})
)
result = await call_handler(fail_request)
# MCP SDK catches exceptions and returns error results
@ -157,28 +148,19 @@ async def test_error_handling():
@pytest.mark.asyncio
async def test_mixed_servers():
"""Test that SDK and external MCP servers can work together."""
# Create an SDK server
@tool("sdk_tool", "SDK tool", {})
async def sdk_tool(args: dict[str, Any]) -> dict[str, Any]:
return {"result": "from SDK"}
sdk_server = create_sdk_mcp_server(
name="sdk-server",
tools=[sdk_tool]
)
sdk_server = create_sdk_mcp_server(name="sdk-server", tools=[sdk_tool])
# Create configuration with both SDK and external servers
external_server = {
"type": "stdio",
"command": "echo",
"args": ["test"]
}
external_server = {"type": "stdio", "command": "echo", "args": ["test"]}
options = ClaudeCodeOptions(
mcp_servers={
"sdk": sdk_server,
"external": external_server
}
mcp_servers={"sdk": sdk_server, "external": external_server}
)
# Verify both server types are in the configuration
@ -191,11 +173,7 @@ async def test_mixed_servers():
@pytest.mark.asyncio
async def test_server_creation():
"""Test that SDK MCP servers are created correctly."""
server = create_sdk_mcp_server(
name="test-server",
version="2.0.0",
tools=[]
)
server = create_sdk_mcp_server(name="test-server", version="2.0.0", tools=[])
# Verify server configuration
assert server["type"] == "sdk"
@ -210,5 +188,6 @@ async def test_server_creation():
# With no tools, no handlers are registered if tools is empty
from mcp.types import ListToolsRequest
# When no tools are provided, the handlers are not registered
assert ListToolsRequest not in instance.request_handlers
assert ListToolsRequest not in instance.request_handlers