opencode/internal/llm/provider/gemini.go
mineo 87237b6462
feat: support VertexAI provider (#153)
* support: vertexai

fix

fix

set default for vertexai

added comment

fix

fix

* create schema

* fix README.md

* fix order

* added popularity

* set tools if tools exist

restore commentout

* fix comment

* set summarizer model
2025-05-15 13:35:06 -05:00

555 lines
14 KiB
Go

package provider
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/google/uuid"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/status"
"google.golang.org/genai"
"log/slog"
)
// geminiOptions holds Gemini-specific client settings applied via GeminiOption.
type geminiOptions struct {
	disableCache bool // when true, prompt caching is disabled
}

// GeminiOption mutates geminiOptions during client construction.
type GeminiOption func(*geminiOptions)

// geminiClient implements ProviderClient on top of Google's genai SDK.
type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

// GeminiClient is the provider-facing alias for a Gemini-backed client.
type GeminiClient ProviderClient
// newGeminiClient builds a Gemini-backed ProviderClient from the given
// provider options. If the underlying genai client cannot be created, the
// error is logged and nil is returned (matching the original contract).
func newGeminiClient(opts providerClientOptions) GeminiClient {
	var geminiOpts geminiOptions
	for _, apply := range opts.geminiOptions {
		apply(&geminiOpts)
	}

	cc := &genai.ClientConfig{
		APIKey:  opts.apiKey,
		Backend: genai.BackendGeminiAPI,
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		slog.Error("Failed to create Gemini client", "error", err)
		return nil
	}

	return &geminiClient{
		providerOptions: opts,
		options:         geminiOpts,
		client:          client,
	}
}
// convertMessages translates the internal message history into genai
// Content entries: user text/attachments, assistant text/tool calls, and
// tool results (re-associated with the originating call by ToolCallID).
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			parts := []*genai.Part{{Text: msg.Content().String()}}
			for _, binaryContent := range msg.BinaryContent() {
				// Pass the full IANA MIME type (e.g. "image/png").
				// The previous code split on "/" and sent only the
				// subtype ("png"), which is not a valid Blob MIME type.
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: binaryContent.MIMEType,
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})

		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []*genai.Part{},
			}
			if text := msg.Content().String(); text != "" {
				content.Parts = append(content.Parts, &genai.Part{Text: text})
			}
			for _, call := range msg.ToolCalls() {
				// Malformed input degrades to a nil args map.
				args, _ := parseJsonToMap(call.Input)
				content.Parts = append(content.Parts, &genai.Part{
					FunctionCall: &genai.FunctionCall{
						Name: call.Name,
						Args: args,
					},
				})
			}
			history = append(history, content)

		case message.Tool:
			for _, result := range msg.ToolResults() {
				// Prefer structured JSON output; fall back to wrapping
				// the raw text under a "result" key.
				response := map[string]interface{}{"result": result.Content}
				if parsed, err := parseJsonToMap(result.Content); err == nil {
					response = parsed
				}

				// Find the originating tool call so the function
				// response carries the correct function name. The
				// labeled break exits both loops once found (the old
				// bare break only left the inner loop).
				var toolCall message.ToolCall
			search:
				for _, m := range messages {
					if m.Role != message.Assistant {
						continue
					}
					for _, call := range m.ToolCalls() {
						if call.ID == result.ToolCallID {
							toolCall = call
							break search
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{{
						FunctionResponse: &genai.FunctionResponse{
							Name:     toolCall.Name,
							Response: response,
						},
					}},
					Role: "function",
				})
			}
		}
	}
	return history
}
// convertTools flattens the provider tool list into a single genai.Tool
// carrying one FunctionDeclaration per tool.
func (g *geminiClient) convertTools(toolList []tools.BaseTool) []*genai.Tool {
	decls := make([]*genai.FunctionDeclaration, 0, len(toolList))
	for _, t := range toolList {
		info := t.Info()
		decls = append(decls, &genai.FunctionDeclaration{
			Name:        info.Name,
			Description: info.Description,
			Parameters: &genai.Schema{
				Type:       genai.TypeObject,
				Properties: convertSchemaProperties(info.Parameters),
				Required:   info.Required,
			},
		})
	}
	return []*genai.Tool{{FunctionDeclarations: decls}}
}
// finishReason maps a genai finish reason onto the provider-neutral enum;
// anything other than stop/max-tokens is reported as unknown.
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return message.FinishReasonEndTurn
	case genai.FinishReasonMaxTokens:
		return message.FinishReasonMaxTokens
	default:
		return message.FinishReasonUnknown
	}
}
// send performs a single blocking generate call against Gemini, retrying
// on rate-limit errors with backoff, and converts the response into a
// ProviderResponse (text, tool calls, usage, finish reason).
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	geminiMessages := g.convertMessages(messages)
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but the last message.
	lastMsg := geminiMessages[len(geminiMessages)-1]

	// Named genConfig so it does not shadow the imported config package.
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		genConfig.Tools = g.convertTools(tools)
	}

	chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, genConfig, history)
	if err != nil {
		// Previously this error was discarded, which made the first
		// SendMessage call panic on a nil chat.
		return nil, fmt.Errorf("creating gemini chat: %w", err)
	}

	// Loop-invariant: the outgoing parts never change between retries.
	lastMsgParts := make([]genai.Part, 0, len(lastMsg.Parts))
	for _, part := range lastMsg.Parts {
		lastMsgParts = append(lastMsgParts, *part)
	}

	attempts := 0
	for {
		attempts++
		resp, err := chat.SendMessage(ctx, lastMsgParts...)
		if err != nil {
			retry, after, retryErr := g.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if !retry {
				// Return the real error; the old code returned a nil
				// retryErr here.
				return nil, err
			}
			duration := time.Duration(after) * time.Millisecond
			status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(duration):
				continue
			}
		}

		var (
			content   string
			toolCalls []message.ToolCall
		)
		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
			for _, part := range resp.Candidates[0].Content.Parts {
				switch {
				case part.Text != "":
					// Accumulate: the old code kept only the last
					// text part of the candidate.
					content += string(part.Text)
				case part.FunctionCall != nil:
					args, _ := json.Marshal(part.FunctionCall.Args)
					toolCalls = append(toolCalls, message.ToolCall{
						ID:       "call_" + uuid.New().String(),
						Name:     part.FunctionCall.Name,
						Input:    string(args),
						Type:     "function",
						Finished: true,
					})
				}
			}
		}

		finishReason := message.FinishReasonEndTurn
		if len(resp.Candidates) > 0 {
			finishReason = g.finishReason(resp.Candidates[0].FinishReason)
		}
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        g.usage(resp),
			FinishReason: finishReason,
		}, nil
	}
}
// stream starts a streaming generate call and forwards provider events
// (content start, deltas, tool calls, completion) on the returned channel.
// Rate-limit errors are retried with backoff; the channel is always closed
// when the goroutine exits.
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	geminiMessages := g.convertMessages(messages)
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	history := geminiMessages[:len(geminiMessages)-1] // All but the last message.
	lastMsg := geminiMessages[len(geminiMessages)-1]

	// Named genConfig so it does not shadow the imported config package.
	genConfig := &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
	}
	if len(tools) > 0 {
		genConfig.Tools = g.convertTools(tools)
	}

	eventChan := make(chan ProviderEvent)

	chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, genConfig, history)
	if err != nil {
		// Previously this error was discarded, which made the first
		// SendMessageStream call panic on a nil chat. Surface it as an
		// error event instead.
		go func() {
			defer close(eventChan)
			eventChan <- ProviderEvent{Type: EventError, Error: fmt.Errorf("creating gemini chat: %w", err)}
		}()
		return eventChan
	}

	// Loop-invariant: the outgoing parts never change between retries.
	lastMsgParts := make([]genai.Part, 0, len(lastMsg.Parts))
	for _, part := range lastMsg.Parts {
		lastMsgParts = append(lastMsgParts, *part)
	}

	go func() {
		defer close(eventChan)
		attempts := 0
		for {
			attempts++
			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse
			retryStream := false

			eventChan <- ProviderEvent{Type: EventContentStart}

			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if !retry {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
					duration := time.Duration(after) * time.Millisecond
					status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), status.WithDuration(duration))
					select {
					case <-ctx.Done():
						if ctx.Err() != nil {
							eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
						}
						return
					case <-time.After(duration):
						// The old code fell through here and
						// dereferenced a nil resp; instead abandon
						// this stream and start a fresh attempt.
						retryStream = true
					}
					break
				}

				finalResp = resp
				if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil {
					continue
				}
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := string(part.Text)
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: delta,
						}
						currentContent += delta
					case part.FunctionCall != nil:
						args, _ := json.Marshal(part.FunctionCall.Args)
						newCall := message.ToolCall{
							ID:       "call_" + uuid.New().String(),
							Name:     part.FunctionCall.Name,
							Input:    string(args),
							Type:     "function",
							Finished: true,
						}
						// Streamed chunks may repeat a call; keep
						// only the first occurrence.
						duplicate := false
						for _, existing := range toolCalls {
							if existing.Name == newCall.Name && existing.Input == newCall.Input {
								duplicate = true
								break
							}
						}
						if !duplicate {
							toolCalls = append(toolCalls, newCall)
						}
					}
				}
			}

			if retryStream {
				continue
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {
				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}
			// No response received and no retry requested: loop again,
			// preserving the original behavior for an empty stream.
		}
	}()
	return eventChan
}
// shouldRetry decides whether a failed Gemini call should be retried.
// It returns (retry, backoff in milliseconds, error); a non-nil error means
// the caller must give up. Only rate-limit failures are retried, with
// exponential backoff plus a fixed 20% headroom on top.
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// A closed stream is terminal and never retryable.
	if errors.Is(err, io.EOF) {
		return false, 0, err
	}

	// The Gemini SDK exposes no typed rate-limit error, so match on the
	// message text instead.
	if !contains(err.Error(), "rate limit", "quota exceeded", "too many requests") {
		return false, 0, err
	}

	// Exponential backoff: 2s, 4s, 8s, ... plus 20% headroom.
	backoffMs := 2000 * (1 << (attempts - 1))
	withHeadroom := backoffMs + int(float64(backoffMs)*0.2)
	return true, int64(withHeadroom), nil
}
// toolCalls extracts tool invocations from the first candidate of a
// completed (non-streaming) response.
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
	var toolCalls []message.ToolCall
	if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil {
		return toolCalls
	}
	for _, part := range resp.Candidates[0].Content.Parts {
		if part.FunctionCall == nil {
			continue
		}
		args, _ := json.Marshal(part.FunctionCall.Args)
		toolCalls = append(toolCalls, message.ToolCall{
			ID:    "call_" + uuid.New().String(),
			Name:  part.FunctionCall.Name,
			Input: string(args),
			Type:  "function",
			// Finished was previously omitted here, inconsistent with
			// the identical constructions in send/stream; a complete
			// response's calls are always final.
			Finished: true,
		})
	}
	return toolCalls
}
// usage extracts token accounting from a Gemini response; a nil response
// or missing metadata yields zeroed counts.
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
	if resp == nil || resp.UsageMetadata == nil {
		return TokenUsage{}
	}
	meta := resp.UsageMetadata
	return TokenUsage{
		InputTokens:         int64(meta.PromptTokenCount),
		OutputTokens:        int64(meta.CandidatesTokenCount),
		CacheCreationTokens: 0, // Not directly provided by Gemini.
		CacheReadTokens:     int64(meta.CachedContentTokenCount),
	}
}
// WithGeminiDisableCache returns an option that turns off prompt caching
// for the Gemini client.
func WithGeminiDisableCache() GeminiOption {
	return func(o *geminiOptions) {
		o.disableCache = true
	}
}
// Helper functions
// parseJsonToMap decodes a JSON object string into a generic string-keyed
// map. On failure the zero map (nil) is returned alongside the error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
		return parsed, err
	}
	return parsed, nil
}
// convertSchemaProperties maps each named JSON-schema parameter to its
// genai.Schema equivalent.
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
	out := make(map[string]*genai.Schema, len(parameters))
	for name, spec := range parameters {
		out[name] = convertToSchema(spec)
	}
	return out
}
// convertToSchema builds a genai.Schema from a single JSON-schema parameter
// description, defaulting to a string schema whenever the description is not
// a map or carries no usable "type" field.
func convertToSchema(param interface{}) *genai.Schema {
	out := &genai.Schema{Type: genai.TypeString}

	spec, isMap := param.(map[string]interface{})
	if !isMap {
		return out
	}

	if desc, ok := spec["description"].(string); ok {
		out.Description = desc
	}

	// Missing and non-string "type" are both treated as string.
	typeStr, ok := spec["type"].(string)
	if !ok {
		return out
	}
	out.Type = mapJSONTypeToGenAI(typeStr)

	switch typeStr {
	case "array":
		out.Items = processArrayItems(spec)
	case "object":
		if props, ok := spec["properties"].(map[string]interface{}); ok {
			out.Properties = convertSchemaProperties(props)
		}
	}
	return out
}
// processArrayItems converts an array parameter's "items" sub-schema,
// returning nil when items is absent or not an object.
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
	if items, ok := paramMap["items"].(map[string]interface{}); ok {
		return convertToSchema(items)
	}
	return nil
}
// mapJSONTypeToGenAI translates a JSON-schema type name into the genai
// type enum; "string" and any unrecognized name both map to TypeString.
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	switch jsonType {
	case "object":
		return genai.TypeObject
	case "array":
		return genai.TypeArray
	case "boolean":
		return genai.TypeBoolean
	case "integer":
		return genai.TypeInteger
	case "number":
		return genai.TypeNumber
	default:
		return genai.TypeString
	}
}
// contains reports whether s contains any of substrs, comparing
// case-insensitively.
func contains(s string, substrs ...string) bool {
	// Lowercase s once; the old code re-lowered it for every substring.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}