refactor: Use MCP SDK's built-in request handling for SDK MCP servers

- Addresses Ashwin's review comment about letting MCP SDK do the routing
- Directly uses MCP Server's request_handlers instead of manual routing
- Creates proper MCP request types (ListToolsRequest, CallToolRequest)
- Leverages existing MCP SDK infrastructure for better maintainability
- No need for custom transport layer - uses handlers directly
- Future MCP methods will be automatically supported without code changes
This commit is contained in:
Kashyap Murali 2025-09-02 15:14:00 -07:00
parent 44f0d05fb7
commit 6f1513e7d2
No known key found for this signature in database

View file

@ -292,7 +292,7 @@ class Query:
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict:
"""Handle an MCP request for an SDK server.
"""Handle an MCP request for an SDK server using the MCP SDK's built-in routing.
Args:
server_name: Name of the SDK MCP server
@ -316,27 +316,71 @@ class Query:
params = message.get("params", {})
try:
# Route to appropriate handler based on method
# Import MCP types dynamically to avoid dependency issues
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
ListToolsRequest,
)
# Route based on method using MCP SDK's request types
if method == "tools/list":
# Get the list_tools handler and call it
handler = server.request_handlers.get("tools/list")
# Create the proper request type
request = ListToolsRequest(method=method)
# Get the handler and call it
handler = server.request_handlers.get(ListToolsRequest)
if handler:
tools = await handler()
result = await handler(request)
# result is a ServerResult with nested ListToolsResult
tools_data = [
{
"name": tool.name,
"description": tool.description,
"inputSchema": tool.inputSchema.model_dump() if tool.inputSchema else {}
}
for tool in result.root.tools
]
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": {"tools": [t.model_dump() for t in tools]},
"result": {"tools": tools_data}
}
elif method == "tools/call":
# Get the call_tool handler and call it
handler = server.request_handlers.get("tools/call")
if handler:
result = await handler(
params.get("name"), params.get("arguments", {})
)
return {"jsonrpc": "2.0", "id": message.get("id"), "result": result}
# Method not found
elif method == "tools/call":
# Create the proper request type
request = CallToolRequest(
method=method,
params=CallToolRequestParams(
name=params.get("name"),
arguments=params.get("arguments", {})
)
)
# Get the handler and call it
handler = server.request_handlers.get(CallToolRequest)
if handler:
result = await handler(request)
# result is a ServerResult with nested CallToolResult
# Convert to JSONRPC response format
content = []
for item in result.root.content:
if hasattr(item, 'text'):
content.append({"type": "text", "text": item.text})
elif hasattr(item, 'data') and hasattr(item, 'mimeType'):
content.append({"type": "image", "data": item.data, "mimeType": item.mimeType})
response_data = {"content": content}
if hasattr(result.root, 'is_error') and result.root.is_error:
response_data["is_error"] = True
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": response_data
}
# Method not found or no handler
return {
"jsonrpc": "2.0",
"id": message.get("id"),