diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index c29a751..0a3aae4 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -292,7 +292,7 @@ class Query: raise Exception(f"Control request timeout: {request.get('subtype')}") from e async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict: - """Handle an MCP request for an SDK server. + """Handle an MCP request for an SDK server using the MCP SDK's built-in routing. Args: server_name: Name of the SDK MCP server @@ -316,27 +316,71 @@ class Query: params = message.get("params", {}) try: - # Route to appropriate handler based on method + # Import MCP types dynamically to avoid dependency issues + from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + ListToolsRequest, + ) + + # Route based on method using MCP SDK's request types if method == "tools/list": - # Get the list_tools handler and call it - handler = server.request_handlers.get("tools/list") + # Create the proper request type + request = ListToolsRequest(method=method) + + # Get the handler and call it + handler = server.request_handlers.get(ListToolsRequest) if handler: - tools = await handler() + result = await handler(request) + # result is a ServerResult with nested ListToolsResult + tools_data = [ + { + "name": tool.name, + "description": tool.description, + "inputSchema": tool.inputSchema.model_dump() if tool.inputSchema else {} + } + for tool in result.root.tools + ] return { "jsonrpc": "2.0", "id": message.get("id"), - "result": {"tools": [t.model_dump() for t in tools]}, + "result": {"tools": tools_data} } - elif method == "tools/call": - # Get the call_tool handler and call it - handler = server.request_handlers.get("tools/call") - if handler: - result = await handler( - params.get("name"), params.get("arguments", {}) - ) - return {"jsonrpc": "2.0", "id": message.get("id"), "result": result} - # Method not found + elif method == "tools/call": + # Create the proper request type + request = CallToolRequest( + method=method, + params=CallToolRequestParams( + name=params.get("name"), + arguments=params.get("arguments", {}) + ) + ) + + # Get the handler and call it + handler = server.request_handlers.get(CallToolRequest) + if handler: + result = await handler(request) + # result is a ServerResult with nested CallToolResult + # Convert to JSONRPC response format + content = [] + for item in result.root.content: + if hasattr(item, 'text'): + content.append({"type": "text", "text": item.text}) + elif hasattr(item, 'data') and hasattr(item, 'mimeType'): + content.append({"type": "image", "data": item.data, "mimeType": item.mimeType}) + + response_data = {"content": content} + if hasattr(result.root, 'is_error') and result.root.is_error: + response_data["is_error"] = True + + return { + "jsonrpc": "2.0", + "id": message.get("id"), + "result": response_data + } + + # Method not found or no handler return { "jsonrpc": "2.0", "id": message.get("id"),