opencode/internal/llm/provider/gemini.go
2025-04-09 19:07:39 +02:00

401 lines
9.6 KiB
Go

package provider
import (
"context"
"encoding/json"
"errors"
"github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
// geminiProvider is the Provider implementation returned by
// NewGeminiProvider. It talks to Gemini through a genai.Client.
type geminiProvider struct {
	// client is the genai client created in NewGeminiProvider and released
	// by Close.
	client *genai.Client
	// apiKey is the credential handed to the genai client on construction.
	apiKey string

	// model selects the generative model; its APIModel field is what gets
	// requested from the client.
	model models.Model
	// systemMessage is installed as the system instruction of every request.
	// NewGeminiProvider rejects an empty value.
	systemMessage string
	// maxTokens caps the output tokens of a response (5000 unless overridden
	// through an option).
	maxTokens int32
}
// GeminiOption is a functional option that mutates a geminiProvider while
// NewGeminiProvider is constructing it.
type GeminiOption func(*geminiProvider)
// NewGeminiProvider builds a Gemini-backed Provider.
//
// The provider starts with a 5000 output-token limit, then every option in
// opts is applied in order. A non-empty system message is mandatory; without
// one an error is returned before any client is created. The genai client is
// created with the configured API key, and any error from that step is
// returned unchanged.
func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
	p := &geminiProvider{maxTokens: 5000}
	for i := range opts {
		opts[i](p)
	}

	if len(p.systemMessage) == 0 {
		return nil, errors.New("system message is required")
	}

	var err error
	if p.client, err = genai.NewClient(ctx, option.WithAPIKey(p.apiKey)); err != nil {
		return nil, err
	}
	return p, nil
}
// WithGeminiSystemMessage returns an option that stores message as the
// provider's system message. NewGeminiProvider fails when this is left empty.
func WithGeminiSystemMessage(message string) GeminiOption {
	apply := func(target *geminiProvider) {
		target.systemMessage = message
	}
	return apply
}
// WithGeminiMaxTokens returns an option that replaces the provider's
// output-token limit (5000 by default) with maxTokens.
func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
	apply := func(target *geminiProvider) {
		target.maxTokens = maxTokens
	}
	return apply
}
// WithGeminiModel returns an option that selects the model the provider
// requests from the Gemini client.
func WithGeminiModel(model models.Model) GeminiOption {
	apply := func(target *geminiProvider) {
		target.model = model
	}
	return apply
}
// WithGeminiKey returns an option that sets the API key used when the genai
// client is created.
func WithGeminiKey(apiKey string) GeminiOption {
	apply := func(target *geminiProvider) {
		target.apiKey = apiKey
	}
	return apply
}
// Close releases the underlying genai client. It is a no-op when the
// provider has no client.
func (p *geminiProvider) Close() {
	if p.client == nil {
		return
	}
	p.client.Close()
}
// convertToGeminiHistory translates internal conversation messages into the
// genai chat history format.
//
//   - User messages become one text part with role "user".
//   - Assistant messages become role "model" content holding the text (when
//     non-empty) followed by one genai.FunctionCall part per tool call.
//   - Tool messages become one role "function" content per tool result, each
//     carrying a genai.FunctionResponse named after the tool call whose ID
//     matches the result's ToolCallID (empty name when no call matches).
//
// Messages with any other role are skipped.
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
	// Index tool-call names by ID once, so each tool result is resolved with
	// a map lookup instead of rescanning the whole conversation.
	callNames := make(map[string]string)
	for _, m := range messages {
		if m.Role != message.Assistant {
			continue
		}
		for _, call := range m.ToolCalls() {
			callNames[call.ID] = call.Name
		}
	}

	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			history = append(history, &genai.Content{
				Parts: []genai.Part{genai.Text(msg.Content().String())},
				Role:  "user",
			})
		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []genai.Part{},
			}
			if text := msg.Content().String(); text != "" {
				content.Parts = append(content.Parts, genai.Text(text))
			}
			for _, call := range msg.ToolCalls() {
				// Best effort: input that is not a JSON object yields nil
				// args instead of dropping the call from the history.
				args, _ := parseJsonToMap(call.Input)
				content.Parts = append(content.Parts, genai.FunctionCall{
					Name: call.Name,
					Args: args,
				})
			}
			history = append(history, content)
		case message.Tool:
			for _, result := range msg.ToolResults() {
				// Send JSON-object results as structured data; anything else
				// (plain text, arrays, a literal null) is wrapped under
				// "result" so the content is never lost.
				response, err := parseJsonToMap(result.Content)
				if err != nil || response == nil {
					response = map[string]interface{}{"result": result.Content}
				}
				history = append(history, &genai.Content{
					Parts: []genai.Part{genai.FunctionResponse{
						Name:     callNames[result.ToolCallID],
						Response: response,
					}},
					Role: "function",
				})
			}
		}
	}
	return history
}
// extractTokenUsage copies the token counters from a Gemini response into a
// TokenUsage. A nil response or missing usage metadata yields the zero value.
// CacheCreationTokens always stays 0 because Gemini does not report it
// directly.
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
	var usage TokenUsage
	if resp == nil {
		return usage
	}
	meta := resp.UsageMetadata
	if meta == nil {
		return usage
	}
	usage.InputTokens = int64(meta.PromptTokenCount)
	usage.OutputTokens = int64(meta.CandidatesTokenCount)
	usage.CacheReadTokens = int64(meta.CachedContentTokenCount)
	return usage
}
// SendMessages sends the conversation to Gemini and waits for the complete
// reply.
//
// All messages except the last are installed as chat history; the text of
// the last message is sent as the new turn. Each tool is exposed to the model
// as its own genai.Tool holding one function declaration.
//
// The response is built from the first candidate only: its text parts are
// concatenated into Content, and every function call becomes a ToolCall with
// a freshly generated "call_<uuid>" ID and JSON-encoded arguments.
//
// An error is returned when messages is empty, when the API call fails, or
// when a function call's arguments cannot be encoded as JSON.
//
// NOTE(review): only the text content of the final message is sent. If the
// final message carries tool results, they are not forwarded by this method;
// confirm that callers always end with a text message.
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	if len(messages) == 0 {
		// Guard the messages[last] access below, which would panic.
		return nil, errors.New("no messages provided")
	}

	model := p.client.GenerativeModel(p.model.APIModel)
	model.SetMaxOutputTokens(p.maxTokens)
	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
	if len(tools) > 0 {
		for _, declaration := range p.convertToolsToGeminiFunctionDeclarations(tools) {
			model.Tools = append(model.Tools, &genai.Tool{
				FunctionDeclarations: []*genai.FunctionDeclaration{declaration},
			})
		}
	}

	last := len(messages) - 1
	chat := model.StartChat()
	chat.History = p.convertToGeminiHistory(messages[:last]) // Exclude last message
	resp, err := chat.SendMessage(ctx, genai.Text(messages[last].Content().String()))
	if err != nil {
		return nil, err
	}

	var content string
	var toolCalls []message.ToolCall
	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			switch v := part.(type) {
			case genai.Text:
				// A candidate may contain several text parts; keep all of
				// them, as the streaming path does.
				content += string(v)
			case genai.FunctionCall:
				args, err := json.Marshal(v.Args)
				if err != nil {
					return nil, err
				}
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    "call_" + uuid.New().String(),
					Name:  v.Name,
					Input: string(args),
					Type:  "function",
				})
			}
		}
	}

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     p.extractTokenUsage(resp),
	}, nil
}
// StreamResponse sends the conversation to Gemini and streams the reply as
// events on the returned channel.
//
// All messages except the last are installed as chat history; the text of
// the last message is sent as the new turn. Each tool is exposed to the model
// as its own genai.Tool holding one function declaration.
//
// A goroutine drains the response iterator and emits:
//   - EventContentDelta for every text part, in arrival order;
//   - EventError (after which the channel is closed) when the iterator fails
//     or a function call's arguments cannot be encoded as JSON;
//   - EventComplete once the stream ends, carrying the concatenated text,
//     the de-duplicated tool calls, the token usage of the last chunk and
//     that chunk's finish reason (empty when no candidate was received).
//
// The channel is unbuffered and is closed when the goroutine exits. The
// method itself returns an error only when messages is empty.
//
// NOTE(review): only the text content of the final message is sent. If the
// final message carries tool results, they are not forwarded by this method;
// confirm that callers always end with a text message.
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	if len(messages) == 0 {
		// Guard the messages[last] access below, which would panic.
		return nil, errors.New("no messages provided")
	}

	model := p.client.GenerativeModel(p.model.APIModel)
	model.SetMaxOutputTokens(p.maxTokens)
	model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
	if len(tools) > 0 {
		for _, declaration := range p.convertToolsToGeminiFunctionDeclarations(tools) {
			model.Tools = append(model.Tools, &genai.Tool{
				FunctionDeclarations: []*genai.FunctionDeclaration{declaration},
			})
		}
	}

	last := len(messages) - 1
	chat := model.StartChat()
	chat.History = p.convertToGeminiHistory(messages[:last]) // Exclude last message
	iter := chat.SendMessageStream(ctx, genai.Text(messages[last].Content().String()))

	eventChan := make(chan ProviderEvent)
	go func() {
		defer close(eventChan)

		var finalResp *genai.GenerateContentResponse
		currentContent := ""
		toolCalls := []message.ToolCall{}

		for {
			resp, err := iter.Next()
			if errors.Is(err, iterator.Done) {
				break
			}
			if err != nil {
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: err,
				}
				return
			}
			finalResp = resp

			if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil {
				continue
			}
			for _, part := range resp.Candidates[0].Content.Parts {
				switch v := part.(type) {
				case genai.Text:
					delta := string(v)
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: delta,
					}
					currentContent += delta
				case genai.FunctionCall:
					args, err := json.Marshal(v.Args)
					if err != nil {
						eventChan <- ProviderEvent{
							Type:  EventError,
							Error: err,
						}
						return
					}
					input := string(args)

					// The same call can be repeated across chunks; record
					// each distinct (name, arguments) pair only once.
					duplicate := false
					for _, existing := range toolCalls {
						if existing.Name == v.Name && existing.Input == input {
							duplicate = true
							break
						}
					}
					if !duplicate {
						toolCalls = append(toolCalls, message.ToolCall{
							ID:    "call_" + uuid.New().String(),
							Name:  v.Name,
							Input: input,
							Type:  "function",
						})
					}
				}
			}
		}

		// The stream may end without any chunk, or with a final chunk that
		// has no candidates; indexing Candidates[0] blindly would panic.
		finishReason := ""
		if finalResp != nil && len(finalResp.Candidates) > 0 {
			finishReason = finalResp.Candidates[0].FinishReason.String()
		}

		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:      currentContent,
				ToolCalls:    toolCalls,
				Usage:        p.extractTokenUsage(finalResp),
				FinishReason: finishReason,
			},
		}
	}()

	return eventChan, nil
}
// convertToolsToGeminiFunctionDeclarations builds one Gemini function
// declaration per tool, in the same order as tools. Each declaration takes
// its name and description from the tool's Info and describes the parameters
// as an object schema with the tool's properties and required list.
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
	result := make([]*genai.FunctionDeclaration, 0, len(tools))
	for _, t := range tools {
		meta := t.Info()
		params := &genai.Schema{
			Type:       genai.TypeObject,
			Properties: convertSchemaProperties(meta.Parameters),
			Required:   meta.Required,
		}
		result = append(result, &genai.FunctionDeclaration{
			Name:        meta.Name,
			Description: meta.Description,
			Parameters:  params,
		})
	}
	return result
}
// convertSchemaProperties converts a JSON-schema style property map into
// Gemini schemas, keeping the property names as keys. The result is always a
// non-nil map, even when parameters is nil or empty.
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
	converted := map[string]*genai.Schema{}
	for key := range parameters {
		converted[key] = convertToSchema(parameters[key])
	}
	return converted
}
// convertToSchema converts one JSON-schema style parameter description into
// a Gemini schema.
//
// The result defaults to a string schema; that default is returned as-is
// when param is not a map. A string "description" is copied over. When "type"
// is absent or not a string the schema stays a string schema; otherwise the
// type is mapped, "array" types get their items schema and "object" types get
// their nested properties (when "properties" is a map).
func convertToSchema(param interface{}) *genai.Schema {
	result := &genai.Schema{Type: genai.TypeString}

	spec, isMap := param.(map[string]interface{})
	if !isMap {
		return result
	}

	if text, hasText := spec["description"].(string); hasText {
		result.Description = text
	}

	// A missing key yields a nil interface, which fails the assertion just
	// like a non-string value does.
	kind, isString := spec["type"].(string)
	if !isString {
		return result
	}

	result.Type = mapJSONTypeToGenAI(kind)
	if kind == "array" {
		result.Items = processArrayItems(spec)
	} else if kind == "object" {
		if nested, hasNested := spec["properties"].(map[string]interface{}); hasNested {
			result.Properties = convertSchemaProperties(nested)
		}
	}
	return result
}
// processArrayItems returns the schema for the elements of an array
// parameter, built from its "items" entry. It returns nil when "items" is
// missing or is not a map.
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
	if element, found := paramMap["items"].(map[string]interface{}); found {
		return convertToSchema(element)
	}
	return nil
}
// mapJSONTypeToGenAI maps a JSON-schema type name onto the matching genai
// type. "string" and any unrecognised name map to genai.TypeString.
func mapJSONTypeToGenAI(jsonType string) genai.Type {
	// Start from the fallback; it also covers the explicit "string" case.
	mapped := genai.TypeString
	switch jsonType {
	case "number":
		mapped = genai.TypeNumber
	case "integer":
		mapped = genai.TypeInteger
	case "boolean":
		mapped = genai.TypeBoolean
	case "array":
		mapped = genai.TypeArray
	case "object":
		mapped = genai.TypeObject
	}
	return mapped
}
// parseJsonToMap decodes jsonStr as a JSON object into a generic map. The
// error from json.Unmarshal is returned unchanged, together with whatever the
// decoder left in the map (nil for invalid or non-object input). The JSON
// literal null decodes to a nil map without an error.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	decoded := map[string]interface{}(nil)
	if err := json.Unmarshal([]byte(jsonStr), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}