opencode/internal/llm/provider/provider.go

package provider

import (
	"context"
	"fmt"
	"log/slog"

	"github.com/sst/opencode/internal/llm/models"
	"github.com/sst/opencode/internal/llm/tools"
	"github.com/sst/opencode/internal/message"
)
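
// EventType identifies the kind of event emitted while streaming a provider response.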
type EventType string
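
// maxRetries bounds retry attempts for provider requests (presumably consumed
// by the client implementations elsewhere in this package).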
const maxRetries = 8

const (
	EventContentStart  EventType = "content_start"
	EventToolUseStart  EventType = "tool_use_start"
	EventToolUseDelta  EventType = "tool_use_delta"
	EventToolUseStop   EventType = "tool_use_stop"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
	EventWarning       EventType = "warning"
)
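
// TokenUsage reports token counts for a single request, including prompt
// cache creation and cache read activity.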
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
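
// ProviderResponse is the final result of a completed request.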
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
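
// ProviderEvent is a single event on a streaming response channel; fields
// other than Type are populated according to the event kind.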
type ProviderEvent struct {
	Type     EventType
	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}
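
// Provider is the high-level interface the rest of the application uses to
// send messages to an LLM backend, either as a single request or as a stream.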
type Provider interface {
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
	Model() models.Model
	MaxTokens() int64
}
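
// providerClientOptions collects configuration shared by every provider
// client, plus provider-specific option lists applied by each client.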
type providerClientOptions struct {
	apiKey           string
	model            models.Model
	maxTokens        int64
	systemMessage    string
	anthropicOptions []AnthropicOption
	openaiOptions    []OpenAIOption
	geminiOptions    []GeminiOption
	bedrockOptions   []BedrockOption
}
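
// ProviderClientOption configures providerClientOptions; see the With*
// helpers below.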
type ProviderClientOption func(*providerClientOptions)
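
// ProviderClient is the low-level transport implemented by each
// provider-specific client.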
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}
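
// baseProvider wraps a concrete ProviderClient and adds behavior shared by
// all providers: message cleaning and token-usage logging.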
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
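
// NewProvider builds a Provider for the given provider name, applying any
// client options. OpenAI-compatible backends (Groq, OpenRouter, xAI) reuse
// the OpenAI client with a custom base URL.
//
// A minimal usage sketch (the model value and environment variable here are
// illustrative, not defined in this file):
//
//	p, err := NewProvider(models.ProviderOpenAI,
//		WithAPIKey(os.Getenv("OPENAI_API_KEY")),
//		WithModel(model), // a models.Model chosen by the caller
//		WithMaxTokens(4096),
//	)
//	if err != nil {
//		// the provider name was not recognized
//	}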
func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
	clientOptions := providerClientOptions{}
	for _, o := range opts {
		o(&clientOptions)
	}
	switch providerName {
	case models.ProviderAnthropic:
		return &baseProvider[AnthropicClient]{
			options: clientOptions,
			client:  newAnthropicClient(clientOptions),
		}, nil
	case models.ProviderOpenAI:
		return &baseProvider[OpenAIClient]{
			options: clientOptions,
			client:  newOpenAIClient(clientOptions),
		}, nil
	case models.ProviderGemini:
		return &baseProvider[GeminiClient]{
			options: clientOptions,
			client:  newGeminiClient(clientOptions),
		}, nil
	case models.ProviderBedrock:
		return &baseProvider[BedrockClient]{
			options: clientOptions,
			client:  newBedrockClient(clientOptions),
		}, nil
	case models.ProviderGROQ:
		// Groq exposes an OpenAI-compatible API, so reuse the OpenAI client
		// with a custom base URL.
		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
			WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		return &baseProvider[OpenAIClient]{
			options: clientOptions,
			client:  newOpenAIClient(clientOptions),
		}, nil
	case models.ProviderAzure:
		return &baseProvider[AzureClient]{
			options: clientOptions,
			client:  newAzureClient(clientOptions),
		}, nil
	case models.ProviderVertexAI:
		return &baseProvider[VertexAIClient]{
			options: clientOptions,
			client:  newVertexAIClient(clientOptions),
		}, nil
	case models.ProviderOpenRouter:
		// OpenRouter is OpenAI-compatible and expects attribution headers.
		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
			WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
			WithOpenAIExtraHeaders(map[string]string{
				"HTTP-Referer": "opencode.ai",
				"X-Title":      "OpenCode",
			}),
		)
		return &baseProvider[OpenAIClient]{
			options: clientOptions,
			client:  newOpenAIClient(clientOptions),
		}, nil
	case models.ProviderXAI:
		// xAI is OpenAI-compatible as well.
		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
			WithOpenAIBaseURL("https://api.x.ai/v1"),
		)
		return &baseProvider[OpenAIClient]{
			options: clientOptions,
			client:  newOpenAIClient(clientOptions),
		}, nil
	case models.ProviderMock:
		// TODO: implement a mock client for tests
		panic("not implemented")
	}
	return nil, fmt.Errorf("provider not supported: %s", providerName)
}
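
// cleanMessages filters out messages that have no content parts before they
// are sent to the provider.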
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
	for _, msg := range messages {
		// Skip messages that carry no content parts.
		if len(msg.Parts) == 0 {
			continue
		}
		cleaned = append(cleaned, msg)
	}
	return
}
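
// SendMessages sends a single blocking request and logs token usage when the
// request succeeds.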
func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	messages = p.cleanMessages(messages)
	response, err := p.client.send(ctx, messages, tools)
	if err == nil && response != nil {
		slog.Debug("API request token usage",
			"model", p.options.model.Name,
			"input_tokens", response.Usage.InputTokens,
			"output_tokens", response.Usage.OutputTokens,
			"cache_creation_tokens", response.Usage.CacheCreationTokens,
			"cache_read_tokens", response.Usage.CacheReadTokens,
			"total_tokens", response.Usage.InputTokens+response.Usage.OutputTokens)
	}
	return response, err
}
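
// Model returns the model this provider was configured with.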
func (p *baseProvider[C]) Model() models.Model {
	return p.options.model
}
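
// MaxTokens returns the configured maximum number of tokens for a response.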
func (p *baseProvider[C]) MaxTokens() int64 {
	return p.options.maxTokens
}
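
// StreamResponse streams events from the underlying client, forwarding each
// event unchanged and logging token usage once the EventComplete event
// carries a final response.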
func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	messages = p.cleanMessages(messages)
	eventChan := p.client.stream(ctx, messages, tools)

	// Create a new channel to intercept events
	wrappedChan := make(chan ProviderEvent)
	go func() {
		defer close(wrappedChan)
		for event := range eventChan {
			// Pass the event through
			wrappedChan <- event

			// Log token usage when we get the complete event
			if event.Type == EventComplete && event.Response != nil {
				slog.Debug("API streaming request token usage",
					"model", p.options.model.Name,
					"input_tokens", event.Response.Usage.InputTokens,
					"output_tokens", event.Response.Usage.OutputTokens,
					"cache_creation_tokens", event.Response.Usage.CacheCreationTokens,
					"cache_read_tokens", event.Response.Usage.CacheReadTokens,
					"total_tokens", event.Response.Usage.InputTokens+event.Response.Usage.OutputTokens)
			}
		}
	}()
	return wrappedChan
}
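
// WithAPIKey sets the API key used to authenticate with the provider.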
func WithAPIKey(apiKey string) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.apiKey = apiKey
	}
}
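
// WithModel sets the model used for requests.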
func WithModel(model models.Model) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.model = model
	}
}
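
// WithMaxTokens sets the maximum number of tokens the model may generate.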
func WithMaxTokens(maxTokens int64) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.maxTokens = maxTokens
	}
}
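
// WithSystemMessage sets the system prompt sent with requests.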
func WithSystemMessage(systemMessage string) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.systemMessage = systemMessage
	}
}
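
// WithAnthropicOptions sets Anthropic-specific client options.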
func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.anthropicOptions = anthropicOptions
	}
}
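
// WithOpenAIOptions sets OpenAI-specific client options; these also apply to
// the OpenAI-compatible providers (Groq, OpenRouter, xAI).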
func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.openaiOptions = openaiOptions
	}
}
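
// WithGeminiOptions sets Gemini-specific client options.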
func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.geminiOptions = geminiOptions
	}
}
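
// WithBedrockOptions sets Bedrock-specific client options.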
func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
	return func(options *providerClientOptions) {
		options.bedrockOptions = bedrockOptions
	}
}