mirror of
https://github.com/sst/opencode.git
synced 2025-08-07 14:58:07 +00:00

* support: vertexai: set default for VertexAI, add comment, misc fixes
* create schema
* fix README.md
* fix order
* add popularity
* set tools only when tools exist; restore commented-out code
* fix comment
* set summarizer model
555 lines
14 KiB
Go
package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/sst/opencode/internal/config"
	"github.com/sst/opencode/internal/llm/tools"
	"github.com/sst/opencode/internal/message"
	"github.com/sst/opencode/internal/status"
	"google.golang.org/genai"
)

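// geminiOptions holds Gemini-specific settings applied through GeminiOption values.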
type geminiOptions struct {
	disableCache bool
}

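// GeminiOption mutates geminiOptions at client construction time.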
type GeminiOption func(*geminiOptions)

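// geminiClient implements ProviderClient on top of the google.golang.org/genai SDK.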
type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

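// GeminiClient is the Gemini-backed ProviderClient exposed to callers.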
type GeminiClient ProviderClient

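// newGeminiClient constructs a client against the public Gemini API backend,
// returning nil if the underlying genai client cannot be created.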
func newGeminiClient(opts providerClientOptions) GeminiClient {
	geminiOpts := geminiOptions{}
	for _, o := range opts.geminiOptions {
		o(&geminiOpts)
	}

	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		options:         geminiOpts,
		client:          client,
	}
}

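// convertMessages maps internal messages onto genai chat history: user text
// and attachments become "user" content, assistant text and tool calls become
// "model" content, and tool results become "function" responses named after
// the tool call they answer.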
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					// The Blob expects the full IANA MIME type, e.g. "image/png".
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []*genai.Part{},
			}

			if msg.Content().String() != "" {
				content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
			}

			for _, call := range msg.ToolCalls() {
				args, _ := parseJsonToMap(call.Input)
				content.Parts = append(content.Parts, &genai.Part{
					FunctionCall: &genai.FunctionCall{
						Name: call.Name,
						Args: args,
					},
				})
			}

			history = append(history, content)

		case message.Tool:
			for _, result := range msg.ToolResults() {
				// Prefer structured JSON results; fall back to wrapping the
				// raw content in a {"result": ...} object.
				response := map[string]interface{}{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}

				// Look up the assistant tool call this result answers so the
				// function response carries the matching function name.
				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}

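// convertTools flattens every tool definition into a single genai.Tool that
// carries one FunctionDeclaration per tool.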
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
	geminiTool := &genai.Tool{}
	geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))

	for _, tool := range tools {
		info := tool.Info()
		declaration := &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		}

		geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
	}

	return []*genai.Tool{geminiTool}
}

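// finishReason maps genai finish reasons onto the provider-agnostic enum,
// defaulting to unknown.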
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}

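// send issues a single blocking generation request and converts the response
// into a ProviderResponse, retrying with backoff when shouldRetry reports a
// rate-limit error.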
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
	if err != nil {
		return nil, err
	}

	attempts := 0
	for {
		attempts++
		var toolCalls []message.ToolCall

		var lastMsgParts []genai.Part
		for _, part := range lastMsg.Parts {
			lastMsgParts = append(lastMsgParts, *part)
		}
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		// On error, check whether the call can be retried before giving up.
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			duration := time.Duration(after) * time.Millisecond
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(duration):
					continue
				}
			}
			return nil, err
		}

		content := ""

		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					content += part.Text
				case part.FunctionCall != nil:
					id := "call_" + uuid.New().String()
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       id,
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}
		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}

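// stream issues a streaming generation request and returns a channel of
// provider events that is closed when the stream finishes or fails. An
// illustrative consumer (names as used elsewhere in this package):
//
//	for event := range g.stream(ctx, msgs, tools) {
//		switch event.Type {
//		case EventContentDelta:
//			print(event.Content)
//		case EventError:
//			// handle event.Error
//		}
//	}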
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but the last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	config := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
	if err != nil {
		go func() {
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
		}()
		return eventChan
	}

	go func() {
		defer close(eventChan)

	retryLoop:
		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			var lastMsgParts []genai.Part
			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					duration := time.Duration(after) * time.Millisecond
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}
							return
						case <-time.After(duration):
							// Restart the whole request rather than reading
							// further from the failed stream.
							continue retryLoop
						}
					}
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							delta := part.Text
							eventChan <- ProviderEvent{
								Type:    EventContentDelta,
								Content: delta,
							}
							currentContent += delta
						case part.FunctionCall != nil:
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							// Streaming may repeat a function call across
							// chunks; only record calls not seen before.
							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}
							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
		}
	}()

	return eventChan
}

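// shouldRetry reports whether err looks like a retryable rate-limit error and,
// if so, how many milliseconds to wait before the next attempt.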
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// An EOF means the stream ended; there is nothing to retry.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// Gemini doesn't expose a structured error type for rate limits,
	// so check the error message for common rate-limit indicators.
	errMsg := err.Error()
	if !contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
		return false, 0, err
	}

	// Exponential backoff with a deterministic 20% cushion on top
	// (a stand-in for randomized jitter).
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	return true, int64(retryMs), nil
}

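// toolCalls extracts function calls from a completed response, assigning each
// a freshly generated call ID.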
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			if part.FunctionCall != nil {
				id := "call_" + uuid.New().String()
				args, _ := json.Marshal(part.FunctionCall.Args)
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    id,
					Name:  part.FunctionCall.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	return toolCalls
}

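// usage converts genai usage metadata into TokenUsage. Gemini does not report
// cache-creation tokens, so that field is always zero.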
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}

	return TokenUsage{
		InputTokens:         int64(resp.UsageMetadata.PromptTokenCount),
		OutputTokens:        int64(resp.UsageMetadata.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini
		CacheReadTokens:     int64(resp.UsageMetadata.CachedContentTokenCount),
	}
}

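// WithGeminiDisableCache returns an option that disables prompt caching for
// the Gemini provider. Illustrative use (assuming the caller assembles
// providerClientOptions):
//
//	opts.geminiOptions = append(opts.geminiOptions, WithGeminiDisableCache())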
func WithGeminiDisableCache() GeminiOption {
	return func(options *geminiOptions) {
		options.disableCache = true
	}
}

// Helper functions

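// parseJsonToMap unmarshals a JSON object string into a generic map.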
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	return result, err
}

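// convertSchemaProperties converts a JSON-schema "properties" map into genai
// schemas, one per parameter.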
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

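// convertToSchema converts one JSON-schema parameter definition into a
// genai.Schema, falling back to a plain string schema when the shape is not
// recognized.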
func convertToSchema(param interface{}) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]interface{})
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]interface{}); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

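// processArrayItems converts an array parameter's "items" definition, or
// returns nil when none is present.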
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
	items, ok := paramMap["items"].(map[string]interface{})
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

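// mapJSONTypeToGenAI maps JSON schema type names onto genai types.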
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

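// contains reports whether s contains any of substrs, ignoring case.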
func contains(s string, substrs ...string) bool {
	for _, substr := range substrs {
		if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
			return true
		}
	}
	return false
}