From a65e593ab4f35e1a647832ba36be2c696e1f5165 Mon Sep 17 00:00:00 2001 From: adamdottv <2363879+adamdottv@users.noreply.github.com> Date: Thu, 15 May 2025 12:44:16 -0500 Subject: [PATCH] feat: batch tool --- internal/llm/agent/tools.go | 64 ++++--- internal/llm/tools/batch.go | 191 ++++++++++++++++++++ internal/llm/tools/batch_test.go | 224 ++++++++++++++++++++++++ internal/tui/components/chat/message.go | 40 +++++ 4 files changed, 498 insertions(+), 21 deletions(-) create mode 100644 internal/llm/tools/batch.go create mode 100644 internal/llm/tools/batch_test.go diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index dba437bd..157b5bf5 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -21,30 +21,41 @@ func PrimaryAgentTools( ctx := context.Background() mcpTools := GetMcpTools(ctx, permissions) - return append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions, history), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewViewTool(lspClients), - tools.NewPatchTool(lspClients, permissions, history), - tools.NewWriteTool(lspClients, permissions, history), - tools.NewDiagnosticsTool(lspClients), - tools.NewDefinitionTool(lspClients), - tools.NewReferencesTool(lspClients), - tools.NewDocSymbolsTool(lspClients), - tools.NewWorkspaceSymbolsTool(lspClients), - tools.NewCodeActionTool(lspClients), - NewAgentTool(sessions, messages, lspClients), - }, mcpTools..., - ) + // Create the list of tools + toolsList := []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewViewTool(lspClients), + tools.NewPatchTool(lspClients, permissions, history), + tools.NewWriteTool(lspClients, permissions, history), + tools.NewDiagnosticsTool(lspClients), + tools.NewDefinitionTool(lspClients), + tools.NewReferencesTool(lspClients), + tools.NewDocSymbolsTool(lspClients), + tools.NewWorkspaceSymbolsTool(lspClients), + tools.NewCodeActionTool(lspClients), + NewAgentTool(sessions, messages, lspClients), + } + + // Create a map of tools for the batch tool + toolsMap := make(map[string]tools.BaseTool) + for _, tool := range toolsList { + toolsMap[tool.Info().Name] = tool + } + + // Add the batch tool with access to all other tools + toolsList = append(toolsList, tools.NewBatchTool(toolsMap)) + + return append(toolsList, mcpTools...) } func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { - return []tools.BaseTool{ + // Create the list of tools + toolsList := []tools.BaseTool{ tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), @@ -54,4 +65,15 @@ func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { tools.NewDocSymbolsTool(lspClients), tools.NewWorkspaceSymbolsTool(lspClients), } + + // Create a map of tools for the batch tool + toolsMap := make(map[string]tools.BaseTool) + for _, tool := range toolsList { + toolsMap[tool.Info().Name] = tool + } + + // Add the batch tool with access to all other tools + toolsList = append(toolsList, tools.NewBatchTool(toolsMap)) + + return toolsList } diff --git a/internal/llm/tools/batch.go b/internal/llm/tools/batch.go new file mode 100644 index 00000000..55101a50 --- /dev/null +++ b/internal/llm/tools/batch.go @@ -0,0 +1,191 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" +) + +type BatchToolCall struct { + Name string `json:"name"` + Input json.RawMessage `json:"input"` +} + +type BatchParams struct { + Calls []BatchToolCall `json:"calls"` +} + +type BatchToolResult struct { + ToolName string `json:"tool_name"` + ToolInput json.RawMessage `json:"tool_input"` + Result json.RawMessage `json:"result"` + Error string `json:"error,omitempty"` + // Added for better formatting and separation between results + Separator string `json:"separator,omitempty"` +} + +type BatchResult struct { + Results []BatchToolResult `json:"results"` +} + +type batchTool struct { + tools map[string]BaseTool +} + +const ( + BatchToolName = "batch" + BatchToolDescription = `Executes multiple tool calls in parallel and returns their results. + +WHEN TO USE THIS TOOL: +- Use when you need to run multiple independent tool calls at once +- Helpful for improving performance by parallelizing operations +- Great for gathering information from multiple sources simultaneously + +HOW TO USE: +- Provide an array of tool calls, each with a name and input +- Each tool call will be executed in parallel +- Results are returned in the same order as the input calls + +FEATURES: +- Runs tool calls concurrently for better performance +- Returns both results and errors for each call +- Maintains the order of results to match input calls + +LIMITATIONS: +- All tools must be available in the current context +- Complex error handling may be required for some use cases +- Not suitable for tool calls that depend on each other's results + +TIPS: +- Use for independent operations like multiple file reads or searches +- Great for batch operations like searching multiple directories +- Combine with other tools for more complex workflows` +) + +func NewBatchTool(tools map[string]BaseTool) BaseTool { + return &batchTool{ + tools: tools, + } +} + +func (b *batchTool) Info() ToolInfo { + return ToolInfo{ + Name: BatchToolName, + Description: BatchToolDescription, + Parameters: map[string]any{ + "calls": map[string]any{ + "type": "array", + "description": "Array of tool calls to execute in parallel", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "description": "Name of the tool to call", + }, + "input": map[string]any{ + "type": "object", + "description": "Input parameters for the tool", + }, + }, + "required": []string{"name", "input"}, + }, + }, + }, + Required: []string{"calls"}, + } +} + +func (b *batchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params BatchParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil + } + + if len(params.Calls) == 0 { + return NewTextErrorResponse("no tool calls provided"), nil + } + + var wg sync.WaitGroup + results := make([]BatchToolResult, len(params.Calls)) + + for i, toolCall := range params.Calls { + wg.Add(1) + go func(index int, tc BatchToolCall) { + defer wg.Done() + + // Create separator for better visual distinction between results + separator := "" + if index > 0 { + separator = fmt.Sprintf("\n%s\n", strings.Repeat("=", 80)) + } + + result := BatchToolResult{ + ToolName: tc.Name, + ToolInput: tc.Input, + Separator: separator, + } + + tool, ok := b.tools[tc.Name] + if !ok { + result.Error = fmt.Sprintf("tool not found: %s", tc.Name) + results[index] = result + return + } + + // Create a proper ToolCall object + callObj := ToolCall{ + ID: fmt.Sprintf("batch-%d", index), + Name: tc.Name, + Input: string(tc.Input), + } + + response, err := tool.Run(ctx, callObj) + if err != nil { + result.Error = fmt.Sprintf("error executing tool %s: %s", tc.Name, err) + results[index] = result + return + } + + // Standardize metadata format if present + if response.Metadata != "" { + var metadata map[string]interface{} + if err := json.Unmarshal([]byte(response.Metadata), &metadata); err == nil { + // Add tool name to metadata for better context + metadata["tool"] = tc.Name + + // Re-marshal with consistent formatting + if metadataBytes, err := json.MarshalIndent(metadata, "", " "); err == nil { + response.Metadata = string(metadataBytes) + } + } + } + + // Convert the response to JSON + responseJSON, err := json.Marshal(response) + if err != nil { + result.Error = fmt.Sprintf("error marshaling response: %s", err) + results[index] = result + return + } + + result.Result = responseJSON + results[index] = result + }(i, toolCall) + } + + wg.Wait() + + batchResult := BatchResult{ + Results: results, + } + + resultJSON, err := json.Marshal(batchResult) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("error marshaling batch result: %s", err)), nil + } + + return NewTextResponse(string(resultJSON)), nil +} \ No newline at end of file diff --git a/internal/llm/tools/batch_test.go b/internal/llm/tools/batch_test.go new file mode 100644 index 00000000..1d5f0564 --- /dev/null +++ b/internal/llm/tools/batch_test.go @@ -0,0 +1,224 @@ +package tools + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +// MockTool is a simple tool implementation for testing +type MockTool struct { + name string + description string + response ToolResponse + err error +} + +func (m *MockTool) Info() ToolInfo { + return ToolInfo{ + Name: m.name, + Description: m.description, + Parameters: map[string]any{}, + Required: []string{}, + } +} + +func (m *MockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + return m.response, m.err +} + +func TestBatchTool(t *testing.T) { + t.Parallel() + + t.Run("successful batch execution", func(t *testing.T) { + t.Parallel() + + // Create mock tools + mockTools := map[string]BaseTool{ + "tool1": &MockTool{ + name: "tool1", + description: "Mock Tool 1", + response: NewTextResponse("Tool 1 Response"), + err: nil, + }, + "tool2": &MockTool{ + name: "tool2", + description: "Mock Tool 2", + response: NewTextResponse("Tool 2 Response"), + err: nil, + }, + } + + // Create batch tool + batchTool := NewBatchTool(mockTools) + + // Create batch call + input := `{ + "calls": [ + { + "name": "tool1", + "input": {} + }, + { + "name": "tool2", + "input": {} + } + ] + }` + + call := ToolCall{ + ID: "test-batch", + Name: "batch", + Input: input, + } + + // Execute batch + response, err := batchTool.Run(context.Background(), call) + + // Verify results + assert.NoError(t, err) + assert.Equal(t, ToolResponseTypeText, response.Type) + assert.False(t, response.IsError) + + // Parse the response + var batchResult BatchResult + err = json.Unmarshal([]byte(response.Content), &batchResult) + assert.NoError(t, err) + + // Verify batch results + assert.Len(t, batchResult.Results, 2) + assert.Empty(t, batchResult.Results[0].Error) + assert.Empty(t, batchResult.Results[1].Error) + assert.Empty(t, batchResult.Results[0].Separator) + assert.NotEmpty(t, batchResult.Results[1].Separator) + + // Verify individual results + var result1 ToolResponse + err = json.Unmarshal(batchResult.Results[0].Result, &result1) + assert.NoError(t, err) + assert.Equal(t, "Tool 1 Response", result1.Content) + + var result2 ToolResponse + err = json.Unmarshal(batchResult.Results[1].Result, &result2) + assert.NoError(t, err) + assert.Equal(t, "Tool 2 Response", result2.Content) + }) + + t.Run("tool not found", func(t *testing.T) { + t.Parallel() + + // Create mock tools + mockTools := map[string]BaseTool{ + "tool1": &MockTool{ + name: "tool1", + description: "Mock Tool 1", + response: NewTextResponse("Tool 1 Response"), + err: nil, + }, + } + + // Create batch tool + batchTool := NewBatchTool(mockTools) + + // Create batch call with non-existent tool + input := `{ + "calls": [ + { + "name": "tool1", + "input": {} + }, + { + "name": "nonexistent", + "input": {} + } + ] + }` + + call := ToolCall{ + ID: "test-batch", + Name: "batch", + Input: input, + } + + // Execute batch + response, err := batchTool.Run(context.Background(), call) + + // Verify results + assert.NoError(t, err) + assert.Equal(t, ToolResponseTypeText, response.Type) + assert.False(t, response.IsError) + + // Parse the response + var batchResult BatchResult + err = json.Unmarshal([]byte(response.Content), &batchResult) + assert.NoError(t, err) + + // Verify batch results + assert.Len(t, batchResult.Results, 2) + assert.Empty(t, batchResult.Results[0].Error) + assert.Contains(t, batchResult.Results[1].Error, "tool not found: nonexistent") + }) + + t.Run("empty calls", func(t *testing.T) { + t.Parallel() + + // Create batch tool with empty tools map + batchTool := NewBatchTool(map[string]BaseTool{}) + + // Create batch call with empty calls + input := `{ + "calls": [] + }` + + call := ToolCall{ + ID: "test-batch", + Name: "batch", + Input: input, + } + + // Execute batch + response, err := batchTool.Run(context.Background(), call) + + // Verify results + assert.NoError(t, err) + assert.Equal(t, ToolResponseTypeText, response.Type) + assert.True(t, response.IsError) + assert.Contains(t, response.Content, "no tool calls provided") + }) + + t.Run("invalid input", func(t *testing.T) { + t.Parallel() + + // Create batch tool with empty tools map + batchTool := NewBatchTool(map[string]BaseTool{}) + + // Create batch call with invalid JSON + input := `{ + "calls": [ + { + "name": "tool1", + "input": { + "invalid": json + } + } + ] + }` + + call := ToolCall{ + ID: "test-batch", + Name: "batch", + Input: input, + } + + // Execute batch + response, err := batchTool.Run(context.Background(), call) + + // Verify results + assert.NoError(t, err) + assert.Equal(t, ToolResponseTypeText, response.Type) + assert.True(t, response.IsError) + assert.Contains(t, response.Content, "error parsing parameters") + }) +} \ No newline at end of file diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index 58c0aed4..f887337a 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -266,6 +266,8 @@ func toolName(name string) string { return "Write" case tools.PatchToolName: return "Patch" + case tools.BatchToolName: + return "Batch" } return name } @@ -292,6 +294,8 @@ func getToolAction(name string) string { return "Preparing write..." case tools.PatchToolName: return "Preparing patch..." + case tools.BatchToolName: + return "Running batch operations..." } return "Working..." } @@ -443,6 +447,10 @@ func renderToolParams(paramWidth int, toolCall message.ToolCall) string { json.Unmarshal([]byte(toolCall.Input), ¶ms) filePath := removeWorkingDirPrefix(params.FilePath) return renderParams(paramWidth, filePath) + case tools.BatchToolName: + var params tools.BatchParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + return renderParams(paramWidth, fmt.Sprintf("%d parallel calls", len(params.Calls))) default: input := strings.ReplaceAll(toolCall.Input, "\n", " ") params = renderParams(paramWidth, input) @@ -540,6 +548,38 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, toMarkdown(resultContent, true, width), t.Background(), ) + case tools.BatchToolName: + var batchResult tools.BatchResult + if err := json.Unmarshal([]byte(resultContent), &batchResult); err != nil { + return baseStyle.Width(width).Foreground(t.Error()).Render(fmt.Sprintf("Error parsing batch result: %s", err)) + } + + var toolCalls []string + for i, result := range batchResult.Results { + toolName := toolName(result.ToolName) + + // Format the tool input as a string + inputStr := string(result.ToolInput) + + // Format the result + var resultStr string + if result.Error != "" { + resultStr = fmt.Sprintf("Error: %s", result.Error) + } else { + var toolResponse tools.ToolResponse + if err := json.Unmarshal(result.Result, &toolResponse); err != nil { + resultStr = "Error parsing tool response" + } else { + resultStr = truncateHeight(toolResponse.Content, 3) + } + } + + // Format the tool call + toolCall := fmt.Sprintf("%d. %s: %s\n %s", i+1, toolName, inputStr, resultStr) + toolCalls = append(toolCalls, toolCall) + } + + return baseStyle.Width(width).Foreground(t.TextMuted()).Render(strings.Join(toolCalls, "\n\n")) default: resultContent = fmt.Sprintf("```text\n%s\n```", resultContent) return styles.ForceReplaceBackgroundWithLipgloss(