package provider import ( "context" "fmt" "github.com/sst/opencode/internal/llm/models" "github.com/sst/opencode/internal/llm/tools" "github.com/sst/opencode/internal/message" "log/slog" ) type EventType string const maxRetries = 8 const ( EventContentStart EventType = "content_start" EventToolUseStart EventType = "tool_use_start" EventToolUseDelta EventType = "tool_use_delta" EventToolUseStop EventType = "tool_use_stop" EventContentDelta EventType = "content_delta" EventThinkingDelta EventType = "thinking_delta" EventContentStop EventType = "content_stop" EventComplete EventType = "complete" EventError EventType = "error" EventWarning EventType = "warning" ) type TokenUsage struct { InputTokens int64 OutputTokens int64 CacheCreationTokens int64 CacheReadTokens int64 } type ProviderResponse struct { Content string ToolCalls []message.ToolCall Usage TokenUsage FinishReason message.FinishReason } type ProviderEvent struct { Type EventType Content string Thinking string Response *ProviderResponse ToolCall *message.ToolCall Error error } type Provider interface { SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent Model() models.Model MaxTokens() int64 } type providerClientOptions struct { apiKey string model models.Model maxTokens int64 systemMessage string anthropicOptions []AnthropicOption openaiOptions []OpenAIOption geminiOptions []GeminiOption bedrockOptions []BedrockOption } type ProviderClientOption func(*providerClientOptions) type ProviderClient interface { send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent } type baseProvider[C ProviderClient] struct { options providerClientOptions client C } func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) { clientOptions := providerClientOptions{} for _, o := range opts { o(&clientOptions) } switch providerName { case models.ProviderAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, client: newAnthropicClient(clientOptions), }, nil case models.ProviderOpenAI: return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderGemini: return &baseProvider[GeminiClient]{ options: clientOptions, client: newGeminiClient(clientOptions), }, nil case models.ProviderBedrock: return &baseProvider[BedrockClient]{ options: clientOptions, client: newBedrockClient(clientOptions), }, nil case models.ProviderGROQ: clientOptions.openaiOptions = append(clientOptions.openaiOptions, WithOpenAIBaseURL("https://api.groq.com/openai/v1"), ) return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderAzure: return &baseProvider[AzureClient]{ options: clientOptions, client: newAzureClient(clientOptions), }, nil case models.ProviderVertexAI: return &baseProvider[VertexAIClient]{ options: clientOptions, client: newVertexAIClient(clientOptions), }, nil case models.ProviderOpenRouter: clientOptions.openaiOptions = append(clientOptions.openaiOptions, WithOpenAIBaseURL("https://openrouter.ai/api/v1"), WithOpenAIExtraHeaders(map[string]string{ "HTTP-Referer": "opencode.ai", "X-Title": "OpenCode", }), ) return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderXAI: clientOptions.openaiOptions = append(clientOptions.openaiOptions, WithOpenAIBaseURL("https://api.x.ai/v1"), ) return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderMock: // TODO: implement mock client for test panic("not implemented") } return nil, fmt.Errorf("provider not supported: %s", providerName) } func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { for _, msg := range messages { // The message has no content if len(msg.Parts) == 0 { continue } cleaned = append(cleaned, msg) } return } func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { messages = p.cleanMessages(messages) response, err := p.client.send(ctx, messages, tools) if err == nil && response != nil { slog.Debug("API request token usage", "model", p.options.model.Name, "input_tokens", response.Usage.InputTokens, "output_tokens", response.Usage.OutputTokens, "cache_creation_tokens", response.Usage.CacheCreationTokens, "cache_read_tokens", response.Usage.CacheReadTokens, "total_tokens", response.Usage.InputTokens+response.Usage.OutputTokens) } return response, err } func (p *baseProvider[C]) Model() models.Model { return p.options.model } func (p *baseProvider[C]) MaxTokens() int64 { return p.options.maxTokens } func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { messages = p.cleanMessages(messages) eventChan := p.client.stream(ctx, messages, tools) // Create a new channel to intercept events wrappedChan := make(chan ProviderEvent) go func() { defer close(wrappedChan) for event := range eventChan { // Pass the event through wrappedChan <- event // Log token usage when we get the complete event if event.Type == EventComplete && event.Response != nil { slog.Debug("API streaming request token usage", "model", p.options.model.Name, "input_tokens", event.Response.Usage.InputTokens, "output_tokens", event.Response.Usage.OutputTokens, "cache_creation_tokens", event.Response.Usage.CacheCreationTokens, "cache_read_tokens", event.Response.Usage.CacheReadTokens, "total_tokens", event.Response.Usage.InputTokens+event.Response.Usage.OutputTokens) } } }() return wrappedChan } func WithAPIKey(apiKey string) ProviderClientOption { return func(options *providerClientOptions) { options.apiKey = apiKey } } func WithModel(model models.Model) ProviderClientOption { return func(options *providerClientOptions) { options.model = model } } func WithMaxTokens(maxTokens int64) ProviderClientOption { return func(options *providerClientOptions) { options.maxTokens = maxTokens } } func WithSystemMessage(systemMessage string) ProviderClientOption { return func(options *providerClientOptions) { options.systemMessage = systemMessage } } func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption { return func(options *providerClientOptions) { options.anthropicOptions = anthropicOptions } } func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption { return func(options *providerClientOptions) { options.openaiOptions = openaiOptions } } func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption { return func(options *providerClientOptions) { options.geminiOptions = geminiOptions } } func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { return func(options *providerClientOptions) { options.bedrockOptions = bedrockOptions } }