opencode/internal/llm/agent/mcp-tools.go
2025-05-13 10:02:39 -05:00

198 lines
5.3 KiB
Go

package agent
import (
"context"
"encoding/json"
"fmt"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/permission"
"github.com/sst/opencode/internal/version"
"log/slog"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
)
// mcpTool wraps a single tool advertised by an MCP server so that it can be
// handed out as a tools.BaseTool (see NewMcpTool).
type mcpTool struct {
	// mcpName is the configured name of the MCP server; Info uses it as the
	// prefix of the exposed tool name.
	mcpName string
	// tool is the tool definition as reported by the server.
	tool mcp.Tool
	// mcpConfig holds the server settings used to open a client in Run.
	mcpConfig config.MCPServer
	// permissions is consulted in Run before the tool is executed.
	permissions permission.Service
}
// MCPClient is the set of MCP client operations this file relies on. Both
// runTool and getTools accept any value that implements it.
type MCPClient interface {
	// Initialize sends the initialize request to the server.
	Initialize(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error)
	// ListTools asks the server for the tools it offers.
	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
	// CallTool invokes one tool on the server.
	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
	// Close shuts the client down.
	Close() error
}
// Info returns the tool metadata. The exposed name is the MCP server name and
// the server-side tool name joined by an underscore; the description,
// parameters and required list are taken from the server's tool definition.
func (b *mcpTool) Info() tools.ToolInfo {
	qualifiedName := fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
	schema := b.tool.InputSchema

	var info tools.ToolInfo
	info.Name = qualifiedName
	info.Description = b.tool.Description
	info.Parameters = schema.Properties
	info.Required = schema.Required
	return info
}
// runTool initializes the given MCP client, calls toolName with the JSON
// encoded arguments in input, and converts the result into a tool response.
// The client is closed before runTool returns.
//
// input is expected to be a JSON object; an empty string means "no
// arguments". Text content items are returned verbatim, any other content is
// rendered with %v, and multiple items are joined with newlines. A result the
// server flags as an error is returned as an error response.
//
// Failures are reported inside the returned tools.ToolResponse; the error
// return value is always nil.
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
	defer c.Close()

	initRequest := mcp.InitializeRequest{}
	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
	initRequest.Params.ClientInfo = mcp.Implementation{
		Name:    "OpenCode",
		Version: version.Version,
	}
	if _, err := c.Initialize(ctx, initRequest); err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// An empty input is treated as a call without arguments instead of a
	// JSON parse failure.
	var args map[string]any
	if input != "" {
		if err := json.Unmarshal([]byte(input), &args); err != nil {
			return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
		}
	}

	toolRequest := mcp.CallToolRequest{}
	toolRequest.Params.Name = toolName
	toolRequest.Params.Arguments = args

	result, err := c.CallTool(ctx, toolRequest)
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// Keep every content item; previously each item overwrote the one before
	// it, so only the last part of a multi-part result was returned.
	output := ""
	for i, item := range result.Content {
		if i > 0 {
			output += "\n"
		}
		if text, ok := item.(mcp.TextContent); ok {
			output += text.Text
		} else {
			output += fmt.Sprintf("%v", item)
		}
	}

	// A tool-level failure is signalled through IsError rather than through
	// the call's error value, so surface it as an error response.
	if result.IsError {
		return tools.NewTextErrorResponse(output), nil
	}
	return tools.NewTextResponse(output), nil
}
// Run executes the wrapped MCP tool with the arguments in params.Input.
//
// It reads the session and message IDs from ctx and returns an error when
// either is empty. It then requests permission to execute the tool; a denial
// is reported as a "permission denied" error response. On approval, a new
// client is created for the configured transport (stdio or SSE) and the call
// is delegated to runTool, which also closes the client.
//
// Apart from the missing-ID case, failures are reported inside the returned
// tools.ToolResponse with a nil error.
func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing an MCP tool")
	}

	// Compute the qualified tool name once; it is used in both the
	// description and the permission request.
	toolName := b.Info().Name
	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", toolName, params.Input)
	granted := b.permissions.Request(
		ctx,
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			Path:        config.WorkingDirectory(),
			ToolName:    toolName,
			Action:      "execute",
			Description: permissionDescription,
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.NewTextErrorResponse("permission denied"), nil
	}

	switch b.mcpConfig.Type {
	case config.MCPStdio:
		c, err := client.NewStdioMCPClient(
			b.mcpConfig.Command,
			b.mcpConfig.Env,
			b.mcpConfig.Args...,
		)
		if err != nil {
			return tools.NewTextErrorResponse(err.Error()), nil
		}
		// The server is called with its own tool name, not the prefixed one.
		return runTool(ctx, c, b.tool.Name, params.Input)
	case config.MCPSse:
		c, err := client.NewSSEMCPClient(
			b.mcpConfig.URL,
			client.WithHeaders(b.mcpConfig.Headers),
		)
		if err != nil {
			return tools.NewTextErrorResponse(err.Error()), nil
		}
		return runTool(ctx, c, b.tool.Name, params.Input)
	}

	return tools.NewTextErrorResponse("invalid mcp type"), nil
}
// NewMcpTool builds a tools.BaseTool for one tool offered by the MCP server
// called name. permissions is used to ask for approval before execution and
// mcpConfig describes how to connect to the server.
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
	wrapped := new(mcpTool)
	wrapped.mcpName = name
	wrapped.tool = tool
	wrapped.mcpConfig = mcpConfig
	wrapped.permissions = permissions
	return wrapped
}
// mcpTools holds the tools collected by GetMcpTools. GetMcpTools appends to it
// during discovery and returns it directly on later calls once it is
// non-empty.
// NOTE(review): no synchronization is visible in this file; confirm callers
// do not invoke GetMcpTools concurrently.
var mcpTools []tools.BaseTool
// getTools initializes the MCP client c, lists the tools offered by the
// server and wraps each one with NewMcpTool using the server name, its
// configuration m and the permission service.
//
// The client is always closed before getTools returns, including when
// initialization or listing fails. Failures are logged and result in an empty
// slice.
func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
	// Register the close first so the client is released on every return
	// path; previously the early error returns skipped it.
	defer c.Close()

	var discovered []tools.BaseTool

	initRequest := mcp.InitializeRequest{}
	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
	initRequest.Params.ClientInfo = mcp.Implementation{
		Name:    "OpenCode",
		Version: version.Version,
	}
	if _, err := c.Initialize(ctx, initRequest); err != nil {
		slog.Error("error initializing mcp client", "error", err)
		return discovered
	}

	// Named "listed" rather than "tools" to avoid shadowing the tools package.
	listed, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		slog.Error("error listing tools", "error", err)
		return discovered
	}

	for _, t := range listed.Tools {
		discovered = append(discovered, NewMcpTool(name, t, permissions, m))
	}
	return discovered
}
// GetMcpTools returns the tools offered by every configured MCP server.
//
// If a previous call already collected at least one tool, that result is
// returned unchanged. Otherwise each server in the configuration is contacted
// through a client matching its transport type (stdio or SSE) and its tools
// are appended to the package-level mcpTools slice. Servers whose client
// cannot be created are logged and skipped; servers with any other transport
// type are ignored.
func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
	if len(mcpTools) != 0 {
		return mcpTools
	}

	for serverName, server := range config.Get().MCPServers {
		var (
			mcpClient MCPClient
			err       error
		)

		switch server.Type {
		case config.MCPStdio:
			mcpClient, err = client.NewStdioMCPClient(server.Command, server.Env, server.Args...)
		case config.MCPSse:
			mcpClient, err = client.NewSSEMCPClient(server.URL, client.WithHeaders(server.Headers))
		default:
			// Other transport types contribute no tools.
			continue
		}

		if err != nil {
			slog.Error("error creating mcp client", "error", err)
			continue
		}

		discovered := getTools(ctx, serverName, server, permissions, mcpClient)
		mcpTools = append(mcpTools, discovered...)
	}

	return mcpTools
}