mirror of
https://github.com/sst/opencode.git
synced 2025-08-08 07:18:03 +00:00
feat: batch tool
This commit is contained in:
parent
5d9058eb74
commit
a65e593ab4
4 changed files with 498 additions and 21 deletions
|
@ -21,30 +21,41 @@ func PrimaryAgentTools(
|
|||
ctx := context.Background()
|
||||
mcpTools := GetMcpTools(ctx, permissions)
|
||||
|
||||
return append(
|
||||
[]tools.BaseTool{
|
||||
tools.NewBashTool(permissions),
|
||||
tools.NewEditTool(lspClients, permissions, history),
|
||||
tools.NewFetchTool(permissions),
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewViewTool(lspClients),
|
||||
tools.NewPatchTool(lspClients, permissions, history),
|
||||
tools.NewWriteTool(lspClients, permissions, history),
|
||||
tools.NewDiagnosticsTool(lspClients),
|
||||
tools.NewDefinitionTool(lspClients),
|
||||
tools.NewReferencesTool(lspClients),
|
||||
tools.NewDocSymbolsTool(lspClients),
|
||||
tools.NewWorkspaceSymbolsTool(lspClients),
|
||||
tools.NewCodeActionTool(lspClients),
|
||||
NewAgentTool(sessions, messages, lspClients),
|
||||
}, mcpTools...,
|
||||
)
|
||||
// Create the list of tools
|
||||
toolsList := []tools.BaseTool{
|
||||
tools.NewBashTool(permissions),
|
||||
tools.NewEditTool(lspClients, permissions, history),
|
||||
tools.NewFetchTool(permissions),
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewViewTool(lspClients),
|
||||
tools.NewPatchTool(lspClients, permissions, history),
|
||||
tools.NewWriteTool(lspClients, permissions, history),
|
||||
tools.NewDiagnosticsTool(lspClients),
|
||||
tools.NewDefinitionTool(lspClients),
|
||||
tools.NewReferencesTool(lspClients),
|
||||
tools.NewDocSymbolsTool(lspClients),
|
||||
tools.NewWorkspaceSymbolsTool(lspClients),
|
||||
tools.NewCodeActionTool(lspClients),
|
||||
NewAgentTool(sessions, messages, lspClients),
|
||||
}
|
||||
|
||||
// Create a map of tools for the batch tool
|
||||
toolsMap := make(map[string]tools.BaseTool)
|
||||
for _, tool := range toolsList {
|
||||
toolsMap[tool.Info().Name] = tool
|
||||
}
|
||||
|
||||
// Add the batch tool with access to all other tools
|
||||
toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
|
||||
|
||||
return append(toolsList, mcpTools...)
|
||||
}
|
||||
|
||||
func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
|
||||
return []tools.BaseTool{
|
||||
// Create the list of tools
|
||||
toolsList := []tools.BaseTool{
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
|
@ -54,4 +65,15 @@ func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
|
|||
tools.NewDocSymbolsTool(lspClients),
|
||||
tools.NewWorkspaceSymbolsTool(lspClients),
|
||||
}
|
||||
|
||||
// Create a map of tools for the batch tool
|
||||
toolsMap := make(map[string]tools.BaseTool)
|
||||
for _, tool := range toolsList {
|
||||
toolsMap[tool.Info().Name] = tool
|
||||
}
|
||||
|
||||
// Add the batch tool with access to all other tools
|
||||
toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
|
||||
|
||||
return toolsList
|
||||
}
|
||||
|
|
191
internal/llm/tools/batch.go
Normal file
191
internal/llm/tools/batch.go
Normal file
|
@ -0,0 +1,191 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type BatchToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Input json.RawMessage `json:"input"`
|
||||
}
|
||||
|
||||
type BatchParams struct {
|
||||
Calls []BatchToolCall `json:"calls"`
|
||||
}
|
||||
|
||||
type BatchToolResult struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolInput json.RawMessage `json:"tool_input"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
Error string `json:"error,omitempty"`
|
||||
// Added for better formatting and separation between results
|
||||
Separator string `json:"separator,omitempty"`
|
||||
}
|
||||
|
||||
type BatchResult struct {
|
||||
Results []BatchToolResult `json:"results"`
|
||||
}
|
||||
|
||||
type batchTool struct {
|
||||
tools map[string]BaseTool
|
||||
}
|
||||
|
||||
const (
|
||||
BatchToolName = "batch"
|
||||
BatchToolDescription = `Executes multiple tool calls in parallel and returns their results.
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to run multiple independent tool calls at once
|
||||
- Helpful for improving performance by parallelizing operations
|
||||
- Great for gathering information from multiple sources simultaneously
|
||||
|
||||
HOW TO USE:
|
||||
- Provide an array of tool calls, each with a name and input
|
||||
- Each tool call will be executed in parallel
|
||||
- Results are returned in the same order as the input calls
|
||||
|
||||
FEATURES:
|
||||
- Runs tool calls concurrently for better performance
|
||||
- Returns both results and errors for each call
|
||||
- Maintains the order of results to match input calls
|
||||
|
||||
LIMITATIONS:
|
||||
- All tools must be available in the current context
|
||||
- Complex error handling may be required for some use cases
|
||||
- Not suitable for tool calls that depend on each other's results
|
||||
|
||||
TIPS:
|
||||
- Use for independent operations like multiple file reads or searches
|
||||
- Great for batch operations like searching multiple directories
|
||||
- Combine with other tools for more complex workflows`
|
||||
)
|
||||
|
||||
func NewBatchTool(tools map[string]BaseTool) BaseTool {
|
||||
return &batchTool{
|
||||
tools: tools,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batchTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: BatchToolName,
|
||||
Description: BatchToolDescription,
|
||||
Parameters: map[string]any{
|
||||
"calls": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of tool calls to execute in parallel",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Name of the tool to call",
|
||||
},
|
||||
"input": map[string]any{
|
||||
"type": "object",
|
||||
"description": "Input parameters for the tool",
|
||||
},
|
||||
},
|
||||
"required": []string{"name", "input"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"calls"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params BatchParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
if len(params.Calls) == 0 {
|
||||
return NewTextErrorResponse("no tool calls provided"), nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]BatchToolResult, len(params.Calls))
|
||||
|
||||
for i, toolCall := range params.Calls {
|
||||
wg.Add(1)
|
||||
go func(index int, tc BatchToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
// Create separator for better visual distinction between results
|
||||
separator := ""
|
||||
if index > 0 {
|
||||
separator = fmt.Sprintf("\n%s\n", strings.Repeat("=", 80))
|
||||
}
|
||||
|
||||
result := BatchToolResult{
|
||||
ToolName: tc.Name,
|
||||
ToolInput: tc.Input,
|
||||
Separator: separator,
|
||||
}
|
||||
|
||||
tool, ok := b.tools[tc.Name]
|
||||
if !ok {
|
||||
result.Error = fmt.Sprintf("tool not found: %s", tc.Name)
|
||||
results[index] = result
|
||||
return
|
||||
}
|
||||
|
||||
// Create a proper ToolCall object
|
||||
callObj := ToolCall{
|
||||
ID: fmt.Sprintf("batch-%d", index),
|
||||
Name: tc.Name,
|
||||
Input: string(tc.Input),
|
||||
}
|
||||
|
||||
response, err := tool.Run(ctx, callObj)
|
||||
if err != nil {
|
||||
result.Error = fmt.Sprintf("error executing tool %s: %s", tc.Name, err)
|
||||
results[index] = result
|
||||
return
|
||||
}
|
||||
|
||||
// Standardize metadata format if present
|
||||
if response.Metadata != "" {
|
||||
var metadata map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(response.Metadata), &metadata); err == nil {
|
||||
// Add tool name to metadata for better context
|
||||
metadata["tool"] = tc.Name
|
||||
|
||||
// Re-marshal with consistent formatting
|
||||
if metadataBytes, err := json.MarshalIndent(metadata, "", " "); err == nil {
|
||||
response.Metadata = string(metadataBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the response to JSON
|
||||
responseJSON, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
result.Error = fmt.Sprintf("error marshaling response: %s", err)
|
||||
results[index] = result
|
||||
return
|
||||
}
|
||||
|
||||
result.Result = responseJSON
|
||||
results[index] = result
|
||||
}(i, toolCall)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
batchResult := BatchResult{
|
||||
Results: results,
|
||||
}
|
||||
|
||||
resultJSON, err := json.Marshal(batchResult)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error marshaling batch result: %s", err)), nil
|
||||
}
|
||||
|
||||
return NewTextResponse(string(resultJSON)), nil
|
||||
}
|
224
internal/llm/tools/batch_test.go
Normal file
224
internal/llm/tools/batch_test.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// MockTool is a simple tool implementation for testing
|
||||
type MockTool struct {
|
||||
name string
|
||||
description string
|
||||
response ToolResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *MockTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: m.name,
|
||||
Description: m.description,
|
||||
Parameters: map[string]any{},
|
||||
Required: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
return m.response, m.err
|
||||
}
|
||||
|
||||
func TestBatchTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("successful batch execution", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock tools
|
||||
mockTools := map[string]BaseTool{
|
||||
"tool1": &MockTool{
|
||||
name: "tool1",
|
||||
description: "Mock Tool 1",
|
||||
response: NewTextResponse("Tool 1 Response"),
|
||||
err: nil,
|
||||
},
|
||||
"tool2": &MockTool{
|
||||
name: "tool2",
|
||||
description: "Mock Tool 2",
|
||||
response: NewTextResponse("Tool 2 Response"),
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
// Create batch tool
|
||||
batchTool := NewBatchTool(mockTools)
|
||||
|
||||
// Create batch call
|
||||
input := `{
|
||||
"calls": [
|
||||
{
|
||||
"name": "tool1",
|
||||
"input": {}
|
||||
},
|
||||
{
|
||||
"name": "tool2",
|
||||
"input": {}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
call := ToolCall{
|
||||
ID: "test-batch",
|
||||
Name: "batch",
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Execute batch
|
||||
response, err := batchTool.Run(context.Background(), call)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ToolResponseTypeText, response.Type)
|
||||
assert.False(t, response.IsError)
|
||||
|
||||
// Parse the response
|
||||
var batchResult BatchResult
|
||||
err = json.Unmarshal([]byte(response.Content), &batchResult)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify batch results
|
||||
assert.Len(t, batchResult.Results, 2)
|
||||
assert.Empty(t, batchResult.Results[0].Error)
|
||||
assert.Empty(t, batchResult.Results[1].Error)
|
||||
assert.Empty(t, batchResult.Results[0].Separator)
|
||||
assert.NotEmpty(t, batchResult.Results[1].Separator)
|
||||
|
||||
// Verify individual results
|
||||
var result1 ToolResponse
|
||||
err = json.Unmarshal(batchResult.Results[0].Result, &result1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Tool 1 Response", result1.Content)
|
||||
|
||||
var result2 ToolResponse
|
||||
err = json.Unmarshal(batchResult.Results[1].Result, &result2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Tool 2 Response", result2.Content)
|
||||
})
|
||||
|
||||
t.Run("tool not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock tools
|
||||
mockTools := map[string]BaseTool{
|
||||
"tool1": &MockTool{
|
||||
name: "tool1",
|
||||
description: "Mock Tool 1",
|
||||
response: NewTextResponse("Tool 1 Response"),
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
// Create batch tool
|
||||
batchTool := NewBatchTool(mockTools)
|
||||
|
||||
// Create batch call with non-existent tool
|
||||
input := `{
|
||||
"calls": [
|
||||
{
|
||||
"name": "tool1",
|
||||
"input": {}
|
||||
},
|
||||
{
|
||||
"name": "nonexistent",
|
||||
"input": {}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
call := ToolCall{
|
||||
ID: "test-batch",
|
||||
Name: "batch",
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Execute batch
|
||||
response, err := batchTool.Run(context.Background(), call)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ToolResponseTypeText, response.Type)
|
||||
assert.False(t, response.IsError)
|
||||
|
||||
// Parse the response
|
||||
var batchResult BatchResult
|
||||
err = json.Unmarshal([]byte(response.Content), &batchResult)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify batch results
|
||||
assert.Len(t, batchResult.Results, 2)
|
||||
assert.Empty(t, batchResult.Results[0].Error)
|
||||
assert.Contains(t, batchResult.Results[1].Error, "tool not found: nonexistent")
|
||||
})
|
||||
|
||||
t.Run("empty calls", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create batch tool with empty tools map
|
||||
batchTool := NewBatchTool(map[string]BaseTool{})
|
||||
|
||||
// Create batch call with empty calls
|
||||
input := `{
|
||||
"calls": []
|
||||
}`
|
||||
|
||||
call := ToolCall{
|
||||
ID: "test-batch",
|
||||
Name: "batch",
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Execute batch
|
||||
response, err := batchTool.Run(context.Background(), call)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ToolResponseTypeText, response.Type)
|
||||
assert.True(t, response.IsError)
|
||||
assert.Contains(t, response.Content, "no tool calls provided")
|
||||
})
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create batch tool with empty tools map
|
||||
batchTool := NewBatchTool(map[string]BaseTool{})
|
||||
|
||||
// Create batch call with invalid JSON
|
||||
input := `{
|
||||
"calls": [
|
||||
{
|
||||
"name": "tool1",
|
||||
"input": {
|
||||
"invalid": json
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
call := ToolCall{
|
||||
ID: "test-batch",
|
||||
Name: "batch",
|
||||
Input: input,
|
||||
}
|
||||
|
||||
// Execute batch
|
||||
response, err := batchTool.Run(context.Background(), call)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ToolResponseTypeText, response.Type)
|
||||
assert.True(t, response.IsError)
|
||||
assert.Contains(t, response.Content, "error parsing parameters")
|
||||
})
|
||||
}
|
|
@ -266,6 +266,8 @@ func toolName(name string) string {
|
|||
return "Write"
|
||||
case tools.PatchToolName:
|
||||
return "Patch"
|
||||
case tools.BatchToolName:
|
||||
return "Batch"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
@ -292,6 +294,8 @@ func getToolAction(name string) string {
|
|||
return "Preparing write..."
|
||||
case tools.PatchToolName:
|
||||
return "Preparing patch..."
|
||||
case tools.BatchToolName:
|
||||
return "Running batch operations..."
|
||||
}
|
||||
return "Working..."
|
||||
}
|
||||
|
@ -443,6 +447,10 @@ func renderToolParams(paramWidth int, toolCall message.ToolCall) string {
|
|||
json.Unmarshal([]byte(toolCall.Input), ¶ms)
|
||||
filePath := removeWorkingDirPrefix(params.FilePath)
|
||||
return renderParams(paramWidth, filePath)
|
||||
case tools.BatchToolName:
|
||||
var params tools.BatchParams
|
||||
json.Unmarshal([]byte(toolCall.Input), ¶ms)
|
||||
return renderParams(paramWidth, fmt.Sprintf("%d parallel calls", len(params.Calls)))
|
||||
default:
|
||||
input := strings.ReplaceAll(toolCall.Input, "\n", " ")
|
||||
params = renderParams(paramWidth, input)
|
||||
|
@ -540,6 +548,38 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
|
|||
toMarkdown(resultContent, true, width),
|
||||
t.Background(),
|
||||
)
|
||||
case tools.BatchToolName:
|
||||
var batchResult tools.BatchResult
|
||||
if err := json.Unmarshal([]byte(resultContent), &batchResult); err != nil {
|
||||
return baseStyle.Width(width).Foreground(t.Error()).Render(fmt.Sprintf("Error parsing batch result: %s", err))
|
||||
}
|
||||
|
||||
var toolCalls []string
|
||||
for i, result := range batchResult.Results {
|
||||
toolName := toolName(result.ToolName)
|
||||
|
||||
// Format the tool input as a string
|
||||
inputStr := string(result.ToolInput)
|
||||
|
||||
// Format the result
|
||||
var resultStr string
|
||||
if result.Error != "" {
|
||||
resultStr = fmt.Sprintf("Error: %s", result.Error)
|
||||
} else {
|
||||
var toolResponse tools.ToolResponse
|
||||
if err := json.Unmarshal(result.Result, &toolResponse); err != nil {
|
||||
resultStr = "Error parsing tool response"
|
||||
} else {
|
||||
resultStr = truncateHeight(toolResponse.Content, 3)
|
||||
}
|
||||
}
|
||||
|
||||
// Format the tool call
|
||||
toolCall := fmt.Sprintf("%d. %s: %s\n %s", i+1, toolName, inputStr, resultStr)
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(strings.Join(toolCalls, "\n\n"))
|
||||
default:
|
||||
resultContent = fmt.Sprintf("```text\n%s\n```", resultContent)
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue