feat: batch tool

This commit is contained in:
adamdottv 2025-05-15 12:44:16 -05:00
parent 5d9058eb74
commit a65e593ab4
No known key found for this signature in database
GPG key ID: 9CB48779AF150E75
4 changed files with 498 additions and 21 deletions

View file

@ -21,30 +21,41 @@ func PrimaryAgentTools(
ctx := context.Background()
mcpTools := GetMcpTools(ctx, permissions)
return append(
[]tools.BaseTool{
tools.NewBashTool(permissions),
tools.NewEditTool(lspClients, permissions, history),
tools.NewFetchTool(permissions),
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewViewTool(lspClients),
tools.NewPatchTool(lspClients, permissions, history),
tools.NewWriteTool(lspClients, permissions, history),
tools.NewDiagnosticsTool(lspClients),
tools.NewDefinitionTool(lspClients),
tools.NewReferencesTool(lspClients),
tools.NewDocSymbolsTool(lspClients),
tools.NewWorkspaceSymbolsTool(lspClients),
tools.NewCodeActionTool(lspClients),
NewAgentTool(sessions, messages, lspClients),
}, mcpTools...,
)
// Create the list of tools
toolsList := []tools.BaseTool{
tools.NewBashTool(permissions),
tools.NewEditTool(lspClients, permissions, history),
tools.NewFetchTool(permissions),
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewViewTool(lspClients),
tools.NewPatchTool(lspClients, permissions, history),
tools.NewWriteTool(lspClients, permissions, history),
tools.NewDiagnosticsTool(lspClients),
tools.NewDefinitionTool(lspClients),
tools.NewReferencesTool(lspClients),
tools.NewDocSymbolsTool(lspClients),
tools.NewWorkspaceSymbolsTool(lspClients),
tools.NewCodeActionTool(lspClients),
NewAgentTool(sessions, messages, lspClients),
}
// Create a map of tools for the batch tool
toolsMap := make(map[string]tools.BaseTool)
for _, tool := range toolsList {
toolsMap[tool.Info().Name] = tool
}
// Add the batch tool with access to all other tools
toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
return append(toolsList, mcpTools...)
}
func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
return []tools.BaseTool{
// Create the list of tools
toolsList := []tools.BaseTool{
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
@ -54,4 +65,15 @@ func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
tools.NewDocSymbolsTool(lspClients),
tools.NewWorkspaceSymbolsTool(lspClients),
}
// Create a map of tools for the batch tool
toolsMap := make(map[string]tools.BaseTool)
for _, tool := range toolsList {
toolsMap[tool.Info().Name] = tool
}
// Add the batch tool with access to all other tools
toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
return toolsList
}

191
internal/llm/tools/batch.go Normal file
View file

@ -0,0 +1,191 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
)
type BatchToolCall struct {
Name string `json:"name"`
Input json.RawMessage `json:"input"`
}
type BatchParams struct {
Calls []BatchToolCall `json:"calls"`
}
type BatchToolResult struct {
ToolName string `json:"tool_name"`
ToolInput json.RawMessage `json:"tool_input"`
Result json.RawMessage `json:"result"`
Error string `json:"error,omitempty"`
// Added for better formatting and separation between results
Separator string `json:"separator,omitempty"`
}
type BatchResult struct {
Results []BatchToolResult `json:"results"`
}
type batchTool struct {
tools map[string]BaseTool
}
const (
BatchToolName = "batch"
BatchToolDescription = `Executes multiple tool calls in parallel and returns their results.
WHEN TO USE THIS TOOL:
- Use when you need to run multiple independent tool calls at once
- Helpful for improving performance by parallelizing operations
- Great for gathering information from multiple sources simultaneously
HOW TO USE:
- Provide an array of tool calls, each with a name and input
- Each tool call will be executed in parallel
- Results are returned in the same order as the input calls
FEATURES:
- Runs tool calls concurrently for better performance
- Returns both results and errors for each call
- Maintains the order of results to match input calls
LIMITATIONS:
- All tools must be available in the current context
- Complex error handling may be required for some use cases
- Not suitable for tool calls that depend on each other's results
TIPS:
- Use for independent operations like multiple file reads or searches
- Great for batch operations like searching multiple directories
- Combine with other tools for more complex workflows`
)
func NewBatchTool(tools map[string]BaseTool) BaseTool {
return &batchTool{
tools: tools,
}
}
func (b *batchTool) Info() ToolInfo {
return ToolInfo{
Name: BatchToolName,
Description: BatchToolDescription,
Parameters: map[string]any{
"calls": map[string]any{
"type": "array",
"description": "Array of tool calls to execute in parallel",
"items": map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"description": "Name of the tool to call",
},
"input": map[string]any{
"type": "object",
"description": "Input parameters for the tool",
},
},
"required": []string{"name", "input"},
},
},
},
Required: []string{"calls"},
}
}
func (b *batchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
var params BatchParams
if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
if len(params.Calls) == 0 {
return NewTextErrorResponse("no tool calls provided"), nil
}
var wg sync.WaitGroup
results := make([]BatchToolResult, len(params.Calls))
for i, toolCall := range params.Calls {
wg.Add(1)
go func(index int, tc BatchToolCall) {
defer wg.Done()
// Create separator for better visual distinction between results
separator := ""
if index > 0 {
separator = fmt.Sprintf("\n%s\n", strings.Repeat("=", 80))
}
result := BatchToolResult{
ToolName: tc.Name,
ToolInput: tc.Input,
Separator: separator,
}
tool, ok := b.tools[tc.Name]
if !ok {
result.Error = fmt.Sprintf("tool not found: %s", tc.Name)
results[index] = result
return
}
// Create a proper ToolCall object
callObj := ToolCall{
ID: fmt.Sprintf("batch-%d", index),
Name: tc.Name,
Input: string(tc.Input),
}
response, err := tool.Run(ctx, callObj)
if err != nil {
result.Error = fmt.Sprintf("error executing tool %s: %s", tc.Name, err)
results[index] = result
return
}
// Standardize metadata format if present
if response.Metadata != "" {
var metadata map[string]interface{}
if err := json.Unmarshal([]byte(response.Metadata), &metadata); err == nil {
// Add tool name to metadata for better context
metadata["tool"] = tc.Name
// Re-marshal with consistent formatting
if metadataBytes, err := json.MarshalIndent(metadata, "", " "); err == nil {
response.Metadata = string(metadataBytes)
}
}
}
// Convert the response to JSON
responseJSON, err := json.Marshal(response)
if err != nil {
result.Error = fmt.Sprintf("error marshaling response: %s", err)
results[index] = result
return
}
result.Result = responseJSON
results[index] = result
}(i, toolCall)
}
wg.Wait()
batchResult := BatchResult{
Results: results,
}
resultJSON, err := json.Marshal(batchResult)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error marshaling batch result: %s", err)), nil
}
return NewTextResponse(string(resultJSON)), nil
}

View file

@ -0,0 +1,224 @@
package tools
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
)
// MockTool is a simple tool implementation for testing
type MockTool struct {
name string
description string
response ToolResponse
err error
}
func (m *MockTool) Info() ToolInfo {
return ToolInfo{
Name: m.name,
Description: m.description,
Parameters: map[string]any{},
Required: []string{},
}
}
func (m *MockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
return m.response, m.err
}
func TestBatchTool(t *testing.T) {
t.Parallel()
t.Run("successful batch execution", func(t *testing.T) {
t.Parallel()
// Create mock tools
mockTools := map[string]BaseTool{
"tool1": &MockTool{
name: "tool1",
description: "Mock Tool 1",
response: NewTextResponse("Tool 1 Response"),
err: nil,
},
"tool2": &MockTool{
name: "tool2",
description: "Mock Tool 2",
response: NewTextResponse("Tool 2 Response"),
err: nil,
},
}
// Create batch tool
batchTool := NewBatchTool(mockTools)
// Create batch call
input := `{
"calls": [
{
"name": "tool1",
"input": {}
},
{
"name": "tool2",
"input": {}
}
]
}`
call := ToolCall{
ID: "test-batch",
Name: "batch",
Input: input,
}
// Execute batch
response, err := batchTool.Run(context.Background(), call)
// Verify results
assert.NoError(t, err)
assert.Equal(t, ToolResponseTypeText, response.Type)
assert.False(t, response.IsError)
// Parse the response
var batchResult BatchResult
err = json.Unmarshal([]byte(response.Content), &batchResult)
assert.NoError(t, err)
// Verify batch results
assert.Len(t, batchResult.Results, 2)
assert.Empty(t, batchResult.Results[0].Error)
assert.Empty(t, batchResult.Results[1].Error)
assert.Empty(t, batchResult.Results[0].Separator)
assert.NotEmpty(t, batchResult.Results[1].Separator)
// Verify individual results
var result1 ToolResponse
err = json.Unmarshal(batchResult.Results[0].Result, &result1)
assert.NoError(t, err)
assert.Equal(t, "Tool 1 Response", result1.Content)
var result2 ToolResponse
err = json.Unmarshal(batchResult.Results[1].Result, &result2)
assert.NoError(t, err)
assert.Equal(t, "Tool 2 Response", result2.Content)
})
t.Run("tool not found", func(t *testing.T) {
t.Parallel()
// Create mock tools
mockTools := map[string]BaseTool{
"tool1": &MockTool{
name: "tool1",
description: "Mock Tool 1",
response: NewTextResponse("Tool 1 Response"),
err: nil,
},
}
// Create batch tool
batchTool := NewBatchTool(mockTools)
// Create batch call with non-existent tool
input := `{
"calls": [
{
"name": "tool1",
"input": {}
},
{
"name": "nonexistent",
"input": {}
}
]
}`
call := ToolCall{
ID: "test-batch",
Name: "batch",
Input: input,
}
// Execute batch
response, err := batchTool.Run(context.Background(), call)
// Verify results
assert.NoError(t, err)
assert.Equal(t, ToolResponseTypeText, response.Type)
assert.False(t, response.IsError)
// Parse the response
var batchResult BatchResult
err = json.Unmarshal([]byte(response.Content), &batchResult)
assert.NoError(t, err)
// Verify batch results
assert.Len(t, batchResult.Results, 2)
assert.Empty(t, batchResult.Results[0].Error)
assert.Contains(t, batchResult.Results[1].Error, "tool not found: nonexistent")
})
t.Run("empty calls", func(t *testing.T) {
t.Parallel()
// Create batch tool with empty tools map
batchTool := NewBatchTool(map[string]BaseTool{})
// Create batch call with empty calls
input := `{
"calls": []
}`
call := ToolCall{
ID: "test-batch",
Name: "batch",
Input: input,
}
// Execute batch
response, err := batchTool.Run(context.Background(), call)
// Verify results
assert.NoError(t, err)
assert.Equal(t, ToolResponseTypeText, response.Type)
assert.True(t, response.IsError)
assert.Contains(t, response.Content, "no tool calls provided")
})
t.Run("invalid input", func(t *testing.T) {
t.Parallel()
// Create batch tool with empty tools map
batchTool := NewBatchTool(map[string]BaseTool{})
// Create batch call with invalid JSON
input := `{
"calls": [
{
"name": "tool1",
"input": {
"invalid": json
}
}
]
}`
call := ToolCall{
ID: "test-batch",
Name: "batch",
Input: input,
}
// Execute batch
response, err := batchTool.Run(context.Background(), call)
// Verify results
assert.NoError(t, err)
assert.Equal(t, ToolResponseTypeText, response.Type)
assert.True(t, response.IsError)
assert.Contains(t, response.Content, "error parsing parameters")
})
}

View file

@ -266,6 +266,8 @@ func toolName(name string) string {
return "Write"
case tools.PatchToolName:
return "Patch"
case tools.BatchToolName:
return "Batch"
}
return name
}
@ -292,6 +294,8 @@ func getToolAction(name string) string {
return "Preparing write..."
case tools.PatchToolName:
return "Preparing patch..."
case tools.BatchToolName:
return "Running batch operations..."
}
return "Working..."
}
@ -443,6 +447,10 @@ func renderToolParams(paramWidth int, toolCall message.ToolCall) string {
json.Unmarshal([]byte(toolCall.Input), &params)
filePath := removeWorkingDirPrefix(params.FilePath)
return renderParams(paramWidth, filePath)
case tools.BatchToolName:
var params tools.BatchParams
json.Unmarshal([]byte(toolCall.Input), &params)
return renderParams(paramWidth, fmt.Sprintf("%d parallel calls", len(params.Calls)))
default:
input := strings.ReplaceAll(toolCall.Input, "\n", " ")
params = renderParams(paramWidth, input)
@ -540,6 +548,38 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
toMarkdown(resultContent, true, width),
t.Background(),
)
case tools.BatchToolName:
var batchResult tools.BatchResult
if err := json.Unmarshal([]byte(resultContent), &batchResult); err != nil {
return baseStyle.Width(width).Foreground(t.Error()).Render(fmt.Sprintf("Error parsing batch result: %s", err))
}
var toolCalls []string
for i, result := range batchResult.Results {
toolName := toolName(result.ToolName)
// Format the tool input as a string
inputStr := string(result.ToolInput)
// Format the result
var resultStr string
if result.Error != "" {
resultStr = fmt.Sprintf("Error: %s", result.Error)
} else {
var toolResponse tools.ToolResponse
if err := json.Unmarshal(result.Result, &toolResponse); err != nil {
resultStr = "Error parsing tool response"
} else {
resultStr = truncateHeight(toolResponse.Content, 3)
}
}
// Format the tool call
toolCall := fmt.Sprintf("%d. %s: %s\n %s", i+1, toolName, inputStr, resultStr)
toolCalls = append(toolCalls, toolCall)
}
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(strings.Join(toolCalls, "\n\n"))
default:
resultContent = fmt.Sprintf("```text\n%s\n```", resultContent)
return styles.ForceReplaceBackgroundWithLipgloss(