package provider

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

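// openaiProvider implements the Provider interface on top of the
// OpenAI chat completions API. It carries the configured client,
// model, token limit, and the system message that is prepended to
// every conversation.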
type openaiProvider struct {
	client        openai.Client
	model         models.Model
	maxTokens     int64
	baseURL       string
	apiKey        string
	systemMessage string
}

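// OpenAIOption configures an openaiProvider during construction.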
type OpenAIOption func(*openaiProvider)

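// NewOpenAIProvider builds a Provider backed by the OpenAI API. It
// applies the given options over a default of 5000 max tokens and
// returns an error if no system message was configured. Typical use
// (apiKey and model supplied by the caller):
//
//	p, err := NewOpenAIProvider(
//		WithOpenAIKey(apiKey),
//		WithOpenAISystemMessage("You are a helpful assistant."),
//		WithOpenAIModel(model),
//	)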
func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
	provider := &openaiProvider{
		maxTokens: 5000,
	}

	for _, opt := range opts {
		opt(provider)
	}

	// Validate required options before constructing the client.
	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	clientOpts := []option.RequestOption{
		option.WithAPIKey(provider.apiKey),
	}
	if provider.baseURL != "" {
		clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
	}

	provider.client = openai.NewClient(clientOpts...)

	return provider, nil
}

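// WithOpenAISystemMessage sets the system message sent at the start
// of every conversation.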
func WithOpenAISystemMessage(message string) OpenAIOption {
	return func(p *openaiProvider) {
		p.systemMessage = message
	}
}

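// WithOpenAIMaxTokens overrides the default completion token limit.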
func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
	return func(p *openaiProvider) {
		p.maxTokens = maxTokens
	}
}

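// WithOpenAIModel selects the model used for completions.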
func WithOpenAIModel(model models.Model) OpenAIOption {
	return func(p *openaiProvider) {
		p.model = model
	}
}

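// WithOpenAIBaseURL points the client at a custom API endpoint, such
// as a proxy or an OpenAI-compatible server.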
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(p *openaiProvider) {
		p.baseURL = baseURL
	}
}

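// WithOpenAIKey sets the API key used to authenticate requests.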
func WithOpenAIKey(apiKey string) OpenAIOption {
	return func(p *openaiProvider) {
		p.apiKey = apiKey
	}
}

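// convertToOpenAIMessages maps internal messages to the OpenAI chat
// format, prepending the system message and translating user turns,
// assistant turns (including tool calls), and tool results.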
func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
	var chatMessages []openai.ChatCompletionMessageParamUnion

	chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				chatMessages = append(chatMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return chatMessages
}

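// convertToOpenAITools exposes the available tools as OpenAI function
// definitions with JSON-schema-style parameter descriptions.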
func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

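// extractTokenUsage converts OpenAI usage counters into a TokenUsage,
// reporting cached prompt tokens separately from fresh input tokens.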
func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
	cachedTokens := usage.PromptTokensDetails.CachedTokens
	inputTokens := usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

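// SendMessages performs a single, non-streaming completion request and
// returns the assistant's content, any tool calls, and token usage.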
func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
	}

	response, err := p.client.Chat.Completions.New(ctx, params)
	if err != nil {
		return nil, err
	}

	// Guard against an empty choice list before indexing into it.
	if len(response.Choices) == 0 {
		return nil, errors.New("no completion choices returned")
	}

	content := response.Choices[0].Message.Content

	var toolCalls []message.ToolCall
	if len(response.Choices[0].Message.ToolCalls) > 0 {
		toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
		for i, call := range response.Choices[0].Message.ToolCalls {
			toolCalls[i] = message.ToolCall{
				ID:    call.ID,
				Name:  call.Function.Name,
				Input: call.Function.Arguments,
				Type:  "function",
			}
		}
	}

	tokenUsage := p.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

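// StreamResponse starts a streaming completion and returns a channel of
// ProviderEvents: content deltas as they arrive, then either an error
// event or a final EventComplete carrying the accumulated response.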
func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}

	stream := p.client.Chat.Completions.NewStreaming(ctx, params)

	eventChan := make(chan ProviderEvent)

	toolCalls := make([]message.ToolCall, 0)
	go func() {
		defer close(eventChan)

		acc := openai.ChatCompletionAccumulator{}
		currentContent := ""

		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)

			if tool, ok := acc.JustFinishedToolCall(); ok {
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    tool.Id,
					Name:  tool.Name,
					Input: tool.Arguments,
					Type:  "function",
				})
			}

			for _, choice := range chunk.Choices {
				if choice.Delta.Content != "" {
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: choice.Delta.Content,
					}
					currentContent += choice.Delta.Content
				}
			}
		}

		if err := stream.Err(); err != nil {
			eventChan <- ProviderEvent{
				Type:  EventError,
				Error: err,
			}
			return
		}

		tokenUsage := p.extractTokenUsage(acc.Usage)

		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:   currentContent,
				ToolCalls: toolCalls,
				Usage:     tokenUsage,
			},
		}
	}()

	return eventChan, nil
}