From 92a2d42d36c121e0cf534c8734f95d0802e6877b Mon Sep 17 00:00:00 2001 From: Kashyap Murali Date: Mon, 1 Sep 2025 03:14:44 -0700 Subject: [PATCH] Apply ruff formatting and merge fixes from base branch - Applied ruff formatting to all files - Integrated async/sync fixes from dickson/control branch - Fixed import conflicts between TYPE_CHECKING and contextlib - All SDK MCP tests passing --- src/claude_code_sdk/__init__.py | 89 ++++++++++--------- src/claude_code_sdk/_internal/query.py | 34 +++---- .../_internal/transport/subprocess_cli.py | 2 +- src/claude_code_sdk/types.py | 4 +- tests/test_sdk_mcp_integration.py | 59 ++++-------- 5 files changed, 85 insertions(+), 103 deletions(-) diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index a26640a..0a9aee5 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, TypeVar from ._errors import ( ClaudeSDKError, @@ -33,11 +33,13 @@ from .types import ( # MCP Server Support -T = TypeVar('T') +T = TypeVar("T") + @dataclass class SdkMcpTool(Generic[T]): """Definition for an SDK MCP tool.""" + name: str description: str input_schema: type[T] | dict[str, Any] @@ -45,16 +47,14 @@ class SdkMcpTool(Generic[T]): def tool( - name: str, - description: str, - input_schema: type | dict[str, Any] + name: str, description: str, input_schema: type | dict[str, Any] ) -> Callable[[Callable[[Any], Awaitable[dict[str, Any]]]], SdkMcpTool]: """Decorator for defining MCP tools with type safety. - + Creates a tool that can be used with SDK MCP servers. The tool runs in-process within your Python application, providing better performance than external MCP servers. - + Args: name: Unique identifier for the tool. This is what Claude will use to reference the tool in function calls. @@ -65,55 +65,60 @@ def tool( - A dictionary mapping parameter names to types (e.g., {"text": str}) - A TypedDict class for more complex schemas - A JSON Schema dictionary for full validation - + Returns: A decorator function that wraps the tool implementation and returns an SdkMcpTool instance ready for use with create_sdk_mcp_server(). - + Example: Basic tool with simple schema: >>> @tool("greet", "Greet a user", {"name": str}) ... async def greet(args): ... return {"content": [{"type": "text", "text": f"Hello, {args['name']}!"}]} - + Tool with multiple parameters: >>> @tool("add", "Add two numbers", {"a": float, "b": float}) ... async def add_numbers(args): ... result = args["a"] + args["b"] ... return {"content": [{"type": "text", "text": f"Result: {result}"}]} - + Tool with error handling: >>> @tool("divide", "Divide two numbers", {"a": float, "b": float}) ... async def divide(args): ... if args["b"] == 0: ... return {"content": [{"type": "text", "text": "Error: Division by zero"}], "is_error": True} ... return {"content": [{"type": "text", "text": f"Result: {args['a'] / args['b']}"}]} - + Notes: - The tool function must be async (defined with async def) - The function receives a single dict argument with the input parameters - The function should return a dict with a "content" key containing the response - Errors can be indicated by including "is_error": True in the response """ + def decorator(handler: Callable[[Any], Awaitable[dict[str, Any]]]) -> SdkMcpTool: - return SdkMcpTool(name=name, description=description, input_schema=input_schema, handler=handler) + return SdkMcpTool( + name=name, + description=description, + input_schema=input_schema, + handler=handler, + ) + return decorator def create_sdk_mcp_server( - name: str, - version: str = "1.0.0", - tools: list[SdkMcpTool] | None = None + name: str, version: str = "1.0.0", tools: list[SdkMcpTool] | None = None ) -> McpSdkServerConfig: """Create an in-process MCP server that runs within your Python application. - + Unlike external MCP servers that run as separate processes, SDK MCP servers run directly in your application's process. This provides: - Better performance (no IPC overhead) - Simpler deployment (single process) - Easier debugging (same process) - Direct access to your application's state - + Args: name: Unique identifier for the server. This name is used to reference the server in the mcp_servers configuration. @@ -122,54 +127,54 @@ def create_sdk_mcp_server( tools: List of SdkMcpTool instances created with the @tool decorator. These are the functions that Claude can call through this server. If None or empty, the server will have no tools (rarely useful). - + Returns: McpSdkServerConfig: A configuration object that can be passed to ClaudeCodeOptions.mcp_servers. This config contains the server instance and metadata needed for the SDK to route tool calls. - + Example: Simple calculator server: >>> @tool("add", "Add numbers", {"a": float, "b": float}) ... async def add(args): ... return {"content": [{"type": "text", "text": f"Sum: {args['a'] + args['b']}"}]} - >>> + >>> >>> @tool("multiply", "Multiply numbers", {"a": float, "b": float}) ... async def multiply(args): ... return {"content": [{"type": "text", "text": f"Product: {args['a'] * args['b']}"}]} - >>> + >>> >>> calculator = create_sdk_mcp_server( ... name="calculator", ... version="2.0.0", ... tools=[add, multiply] ... ) - >>> + >>> >>> # Use with Claude >>> options = ClaudeCodeOptions( ... mcp_servers={"calc": calculator}, ... allowed_tools=["add", "multiply"] ... ) - + Server with application state access: >>> class DataStore: ... def __init__(self): ... self.items = [] - ... + ... >>> store = DataStore() - >>> + >>> >>> @tool("add_item", "Add item to store", {"item": str}) ... async def add_item(args): ... store.items.append(args["item"]) ... return {"content": [{"type": "text", "text": f"Added: {args['item']}"}]} - >>> + >>> >>> server = create_sdk_mcp_server("store", tools=[add_item]) - + Notes: - The server runs in the same process as your Python application - Tools have direct access to your application's variables and state - No subprocess or IPC overhead for tool calls - Server lifecycle is managed automatically by the SDK - + See Also: - tool(): Decorator for creating tool functions - ClaudeCodeOptions: Configuration for using servers with query() @@ -194,7 +199,10 @@ def create_sdk_mcp_server( # Convert input_schema to JSON Schema format if isinstance(tool_def.input_schema, dict): # Check if it's already a JSON schema - if "type" in tool_def.input_schema and "properties" in tool_def.input_schema: + if ( + "type" in tool_def.input_schema + and "properties" in tool_def.input_schema + ): schema = tool_def.input_schema else: # Simple dict mapping names to types - convert to JSON schema @@ -213,17 +221,19 @@ def create_sdk_mcp_server( schema = { "type": "object", "properties": properties, - "required": list(properties.keys()) + "required": list(properties.keys()), } else: # For TypedDict or other types, create basic schema schema = {"type": "object", "properties": {}} - tool_list.append(Tool( - name=tool_def.name, - description=tool_def.description, - inputSchema=schema - )) + tool_list.append( + Tool( + name=tool_def.name, + description=tool_def.description, + inputSchema=schema, + ) + ) return tool_list # Register call_tool handler to execute tools @@ -250,11 +260,8 @@ def create_sdk_mcp_server( return content # Return SDK server configuration - return McpSdkServerConfig( - type="sdk", - name=name, - instance=server - ) + return McpSdkServerConfig(type="sdk", name=name, instance=server) + __version__ = "0.0.20" diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index 083e374..0ab2740 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -199,11 +199,13 @@ class Query: # Handle SDK MCP request server_name = request_data.get("server_name") mcp_message = request_data.get("message") - + if not server_name or not mcp_message: raise Exception("Missing server_name or message for MCP request") - - response_data = await self._handle_sdk_mcp_request(server_name, mcp_message) + + response_data = await self._handle_sdk_mcp_request( + server_name, mcp_message + ) else: raise Exception(f"Unsupported control request subtype: {subtype}") @@ -278,8 +280,8 @@ class Query: "id": message.get("id"), "error": { "code": -32601, - "message": f"Server '{server_name}' not found" - } + "message": f"Server '{server_name}' not found", + }, } server = self.sdk_mcp_servers[server_name] @@ -296,37 +298,29 @@ class Query: return { "jsonrpc": "2.0", "id": message.get("id"), - "result": {"tools": [t.model_dump() for t in tools]} + "result": {"tools": [t.model_dump() for t in tools]}, } elif method == "tools/call": # Get the call_tool handler and call it handler = server.request_handlers.get("tools/call") if handler: - result = await handler(params.get("name"), params.get("arguments", {})) - return { - "jsonrpc": "2.0", - "id": message.get("id"), - "result": result - } + result = await handler( + params.get("name"), params.get("arguments", {}) + ) + return {"jsonrpc": "2.0", "id": message.get("id"), "result": result} # Method not found return { "jsonrpc": "2.0", "id": message.get("id"), - "error": { - "code": -32601, - "message": f"Method '{method}' not found" - } + "error": {"code": -32601, "message": f"Method '{method}' not found"}, } except Exception as e: return { "jsonrpc": "2.0", "id": message.get("id"), - "error": { - "code": -32603, - "message": str(e) - } + "error": {"code": -32603, "message": str(e)}, } async def interrupt(self) -> None: diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 7478ac4..79d8ff6 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -134,7 +134,7 @@ class SubprocessCLITransport(Transport): for name, config in self._options.mcp_servers.items() if not (isinstance(config, dict) and config.get("type") == "sdk") } - + # Only pass external servers to CLI if external_servers: cmd.extend( diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index f42336e..b4fc413 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -47,7 +47,9 @@ class McpSdkServerConfig(TypedDict): instance: "McpServer" -McpServerConfig = McpStdioServerConfig | McpSSEServerConfig | McpHttpServerConfig | McpSdkServerConfig +McpServerConfig = ( + McpStdioServerConfig | McpSSEServerConfig | McpHttpServerConfig | McpSdkServerConfig +) # Content block types diff --git a/tests/test_sdk_mcp_integration.py b/tests/test_sdk_mcp_integration.py index 44d3a2e..e991f38 100644 --- a/tests/test_sdk_mcp_integration.py +++ b/tests/test_sdk_mcp_integration.py @@ -25,26 +25,16 @@ async def test_sdk_mcp_server_handlers(): @tool("greet_user", "Greets a user by name", {"name": str}) async def greet_user(args: dict[str, Any]) -> dict[str, Any]: tool_executions.append({"name": "greet_user", "args": args}) - return { - "content": [ - {"type": "text", "text": f"Hello, {args['name']}!"} - ] - } + return {"content": [{"type": "text", "text": f"Hello, {args['name']}!"}]} @tool("add_numbers", "Adds two numbers", {"a": float, "b": float}) async def add_numbers(args: dict[str, Any]) -> dict[str, Any]: tool_executions.append({"name": "add_numbers", "args": args}) result = args["a"] + args["b"] - return { - "content": [ - {"type": "text", "text": f"The sum is {result}"} - ] - } + return {"content": [{"type": "text", "text": f"The sum is {result}"}]} server_config = create_sdk_mcp_server( - name="test-sdk-server", - version="1.0.0", - tools=[greet_user, add_numbers] + name="test-sdk-server", version="1.0.0", tools=[greet_user, add_numbers] ) # Verify server configuration @@ -79,9 +69,10 @@ async def test_sdk_mcp_server_handlers(): # Call greet_user - CallToolRequest wraps the call from mcp.types import CallToolRequestParams + greet_request = CallToolRequest( method="tools/call", - params=CallToolRequestParams(name="greet_user", arguments={"name": "Alice"}) + params=CallToolRequestParams(name="greet_user", arguments={"name": "Alice"}), ) result = await call_handler(greet_request) # Response is ServerResult with nested CallToolResult @@ -93,7 +84,7 @@ async def test_sdk_mcp_server_handlers(): # Call add_numbers add_request = CallToolRequest( method="tools/call", - params=CallToolRequestParams(name="add_numbers", arguments={"a": 5, "b": 3}) + params=CallToolRequestParams(name="add_numbers", arguments={"a": 5, "b": 3}), ) result = await call_handler(add_request) assert "8" in result.root.content[0].text @@ -106,6 +97,7 @@ async def test_sdk_mcp_server_handlers(): @pytest.mark.asyncio async def test_tool_creation(): """Test that tools can be created with proper schemas.""" + @tool("echo", "Echo input", {"input": str}) async def echo_tool(args: dict[str, Any]) -> dict[str, Any]: return {"output": args["input"]} @@ -124,6 +116,7 @@ async def test_tool_creation(): @pytest.mark.asyncio async def test_error_handling(): """Test that tool errors are properly handled.""" + @tool("fail", "Always fails", {}) async def fail_tool(args: dict[str, Any]) -> dict[str, Any]: raise ValueError("Expected error") @@ -133,20 +126,18 @@ async def test_error_handling(): await fail_tool.handler({}) # Test error handling through the server - server_config = create_sdk_mcp_server( - name="error-test", - tools=[fail_tool] - ) + server_config = create_sdk_mcp_server(name="error-test", tools=[fail_tool]) server = server_config["instance"] from mcp.types import CallToolRequest + call_handler = server.request_handlers[CallToolRequest] # The handler should return an error result, not raise from mcp.types import CallToolRequestParams + fail_request = CallToolRequest( - method="tools/call", - params=CallToolRequestParams(name="fail", arguments={}) + method="tools/call", params=CallToolRequestParams(name="fail", arguments={}) ) result = await call_handler(fail_request) # MCP SDK catches exceptions and returns error results @@ -157,28 +148,19 @@ async def test_error_handling(): @pytest.mark.asyncio async def test_mixed_servers(): """Test that SDK and external MCP servers can work together.""" + # Create an SDK server @tool("sdk_tool", "SDK tool", {}) async def sdk_tool(args: dict[str, Any]) -> dict[str, Any]: return {"result": "from SDK"} - sdk_server = create_sdk_mcp_server( - name="sdk-server", - tools=[sdk_tool] - ) + sdk_server = create_sdk_mcp_server(name="sdk-server", tools=[sdk_tool]) # Create configuration with both SDK and external servers - external_server = { - "type": "stdio", - "command": "echo", - "args": ["test"] - } + external_server = {"type": "stdio", "command": "echo", "args": ["test"]} options = ClaudeCodeOptions( - mcp_servers={ - "sdk": sdk_server, - "external": external_server - } + mcp_servers={"sdk": sdk_server, "external": external_server} ) # Verify both server types are in the configuration @@ -191,11 +173,7 @@ async def test_mixed_servers(): @pytest.mark.asyncio async def test_server_creation(): """Test that SDK MCP servers are created correctly.""" - server = create_sdk_mcp_server( - name="test-server", - version="2.0.0", - tools=[] - ) + server = create_sdk_mcp_server(name="test-server", version="2.0.0", tools=[]) # Verify server configuration assert server["type"] == "sdk" @@ -210,5 +188,6 @@ async def test_server_creation(): # With no tools, no handlers are registered if tools is empty from mcp.types import ListToolsRequest + # When no tools are provided, the handlers are not registered - assert ListToolsRequest not in instance.request_handlers \ No newline at end of file + assert ListToolsRequest not in instance.request_handlers