diff --git a/pyproject.toml b/pyproject.toml index 0967e8a..d2c4f6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,9 @@ dependencies = [ ] [project.optional-dependencies] +mcp = [ + "mcp>=0.1.0", +] dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.20.0", @@ -37,6 +40,7 @@ dev = [ "pytest-cov>=4.0.0", "mypy>=1.0.0", "ruff>=0.1.0", + "mcp>=0.1.0", # Include MCP for testing ] [project.urls] diff --git a/src/claude_code_sdk/_internal/query.py b/src/claude_code_sdk/_internal/query.py index 9d6139a..479a77f 100644 --- a/src/claude_code_sdk/_internal/query.py +++ b/src/claude_code_sdk/_internal/query.py @@ -19,6 +19,11 @@ from .transport import Transport if TYPE_CHECKING: from mcp.server import Server as McpServer + from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + ListToolsRequest, + ) logger = logging.getLogger(__name__) @@ -314,31 +319,31 @@ class Query: params = message.get("params", {}) try: - # For now, we'll use a simpler approach without MCP SDK types - # This avoids import issues in CI where mcp package isn't installed - - # Route based on method string directly + # Import MCP types at runtime + from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + ListToolsRequest, + ) + + # Route based on method using MCP SDK's request types if method == "tools/list": - # Try to get the handler - it should handle the raw request - handler = getattr(server, 'handle_list_tools', None) - if handler: - tools = await handler() - tools_data = [] - for tool in tools: - tool_dict = { - "name": tool.name if hasattr(tool, 'name') else str(tool), - "description": getattr(tool, 'description', ''), - } - if hasattr(tool, 'inputSchema') and tool.inputSchema: - schema = tool.inputSchema - if hasattr(schema, 'model_dump'): - tool_dict["inputSchema"] = schema.model_dump() - else: - tool_dict["inputSchema"] = {} - else: - tool_dict["inputSchema"] = {} - tools_data.append(tool_dict) + # Create the proper request type + request = ListToolsRequest(method=method) + # Get the handler and call it + handler = server.request_handlers.get(ListToolsRequest) + if handler: + result = await handler(request) + # result is a ServerResult with nested ListToolsResult + tools_data = [ + { + "name": tool.name, + "description": tool.description, + "inputSchema": tool.inputSchema.model_dump() if tool.inputSchema else {} + } + for tool in result.root.tools + ] return { "jsonrpc": "2.0", "id": message.get("id"), @@ -346,29 +351,30 @@ class Query: } elif method == "tools/call": - # Try to get the handler - handler = getattr(server, 'handle_call_tool', None) - if handler: - result = await handler( - params.get("name"), - params.get("arguments", {}) + # Create the proper request type + request = CallToolRequest( + method=method, + params=CallToolRequestParams( + name=params.get("name"), + arguments=params.get("arguments", {}) ) + ) - # Format the response + # Get the handler and call it + handler = server.request_handlers.get(CallToolRequest) + if handler: + result = await handler(request) + # result is a ServerResult with nested CallToolResult + # Convert to JSONRPC response format content = [] - if hasattr(result, 'content'): - for item in result.content: - if hasattr(item, 'text'): - content.append({"type": "text", "text": item.text}) - elif hasattr(item, 'data') and hasattr(item, 'mimeType'): - content.append({"type": "image", "data": item.data, "mimeType": item.mimeType}) - elif isinstance(result, str): - content.append({"type": "text", "text": result}) - elif isinstance(result, dict): - content.append({"type": "text", "text": str(result)}) + for item in result.root.content: + if hasattr(item, 'text'): + content.append({"type": "text", "text": item.text}) + elif hasattr(item, 'data') and hasattr(item, 'mimeType'): + content.append({"type": "image", "data": item.data, "mimeType": item.mimeType}) response_data = {"content": content} - if hasattr(result, 'is_error') and result.is_error: + if hasattr(result.root, 'is_error') and result.root.is_error: response_data["is_error"] = True return {