mirror of
https://github.com/sst/opencode.git
synced 2025-08-04 13:30:52 +00:00
initial working agent
This commit is contained in:
parent
e7258e38ae
commit
005b8ac167
6 changed files with 201 additions and 22 deletions
|
@ -109,8 +109,6 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
|
|||
}
|
||||
}
|
||||
|
||||
// Execute adds all child commands to the root command and sets flags appropriately.
|
||||
// This is called by main.main(). It only needs to happen once to the rootCmd.
|
||||
func Execute() {
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
|
@ -131,13 +129,14 @@ func loadConfig() {
|
|||
|
||||
// LLM
|
||||
viper.SetDefault("models.big", string(models.DefaultBigModel))
|
||||
viper.SetDefault("models.little", string(models.DefaultLittleModel))
|
||||
viper.SetDefault("models.small", string(models.DefaultLittleModel))
|
||||
viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY"))
|
||||
viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY"))
|
||||
viper.SetDefault("providers.groq.key", os.Getenv("GROQ_API_KEY"))
|
||||
viper.SetDefault("providers.common.max_tokens", 4000)
|
||||
|
||||
viper.SetDefault("agents.default", "coder")
|
||||
//
|
||||
|
||||
viper.ReadInConfig()
|
||||
|
||||
workdir, err := os.Getwd()
|
||||
|
|
31
internal/llm/agent/title.go
Normal file
31
internal/llm/agent/title.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func GenerateTitle(ctx context.Context, content string) (string, error) {
|
||||
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
out, err := model.Generate(
|
||||
ctx,
|
||||
[]*schema.Message{
|
||||
schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
|
||||
- ensure it is not more than 80 characters long
|
||||
- the title should be a summary of the user's message
|
||||
- do not use quotes or colons
|
||||
- the entire text you return will be used as the title`),
|
||||
schema.UserMessage(content),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return out.Content, nil
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/logging"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
|
@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|||
}
|
||||
|
||||
log.Printf("Request: %s", content)
|
||||
agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
|
||||
currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
|
@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|||
for _, m := range history {
|
||||
messages = append(messages, &m.MessageData)
|
||||
}
|
||||
|
||||
builder := callbacks.NewHandlerBuilder()
|
||||
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
i, ok := input.(*eModel.CallbackInput)
|
||||
|
@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|||
return ctx
|
||||
})
|
||||
|
||||
out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
|
||||
out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
|
@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|||
return
|
||||
}
|
||||
usage := out.ResponseMeta.Usage
|
||||
s.messages.Create(sessionID, *out)
|
||||
if usage != nil {
|
||||
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
|
||||
session, err := s.sessions.Get(sessionID)
|
||||
|
@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|||
session.PromptTokens += int64(usage.PromptTokens)
|
||||
session.CompletionTokens += int64(usage.CompletionTokens)
|
||||
// TODO: calculate cost
|
||||
model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
|
||||
session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
|
||||
float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
|
||||
var newTitle string
|
||||
if len(history) == 1 {
|
||||
// first message generate the title
|
||||
newTitle, err = agent.GenerateTitle(s.ctx, content)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if newTitle != "" {
|
||||
session.Title = newTitle
|
||||
}
|
||||
|
||||
_, err = s.sessions.Save(session)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
|
@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
|
|||
return
|
||||
}
|
||||
}
|
||||
s.messages.Create(sessionID, *out)
|
||||
}
|
||||
|
||||
func (s *service) SendRequest(sessionID string, content string) {
|
||||
|
|
|
@ -3,6 +3,7 @@ package models
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/claude"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
|
@ -16,10 +17,12 @@ type (
|
|||
)
|
||||
|
||||
type Model struct {
|
||||
ID ModelID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
APIModel string `json:"api_model"` // Actual value used when calling the API
|
||||
ID ModelID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
APIModel string `json:"api_model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -52,6 +55,9 @@ const (
|
|||
// Meta
|
||||
Llama3 ModelID = "llama-3"
|
||||
Llama270B ModelID = "llama-2-70b"
|
||||
// GROQ
|
||||
GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
|
||||
GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -61,6 +67,7 @@ const (
|
|||
ProviderXAI ModelProvider = "xai"
|
||||
ProviderDeepSeek ModelProvider = "deepseek"
|
||||
ProviderMeta ModelProvider = "meta"
|
||||
ProviderGroq ModelProvider = "groq"
|
||||
)
|
||||
|
||||
var SupportedModels = map[ModelID]Model{
|
||||
|
@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{
|
|||
APIModel: "gpt-4o",
|
||||
},
|
||||
GPT4oMini: {
|
||||
ID: GPT4oMini,
|
||||
Name: "GPT-4o Mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o-mini",
|
||||
ID: GPT4oMini,
|
||||
Name: "GPT-4o Mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: 0.150,
|
||||
CostPer1MOut: 0.600,
|
||||
},
|
||||
GPT45: {
|
||||
ID: GPT45,
|
||||
|
@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{
|
|||
Provider: ProviderMeta,
|
||||
APIModel: "llama-2-70b",
|
||||
},
|
||||
|
||||
// GROQ
|
||||
GroqLlama3SpecDec: {
|
||||
ID: GroqLlama3SpecDec,
|
||||
Name: "GROQ LLaMA 3 SpecDec",
|
||||
Provider: ProviderGroq,
|
||||
APIModel: "llama-3.3-70b-specdec",
|
||||
},
|
||||
GroqQwen32BCoder: {
|
||||
ID: GroqQwen32BCoder,
|
||||
Name: "GROQ Qwen 2.5 Coder 32B",
|
||||
Provider: ProviderGroq,
|
||||
APIModel: "qwen-2.5-coder-32b",
|
||||
},
|
||||
}
|
||||
|
||||
func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
|
||||
provider := SupportedModels[model].Provider
|
||||
log.Printf("Provider: %s", provider)
|
||||
maxTokens := viper.GetInt("providers.common.max_tokens")
|
||||
switch provider {
|
||||
case ProviderOpenAI:
|
||||
|
@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
|
|||
MaxTokens: maxTokens,
|
||||
})
|
||||
|
||||
case ProviderGroq:
|
||||
return openai.NewChatModel(ctx, &openai.ChatModelConfig{
|
||||
BaseURL: "https://api.groq.com/openai/v1",
|
||||
APIKey: viper.GetString("providers.groq.key"),
|
||||
Model: string(SupportedModels[model].APIModel),
|
||||
MaxTokens: &maxTokens,
|
||||
})
|
||||
|
||||
}
|
||||
return nil, errors.New("unsupported provider")
|
||||
}
|
||||
|
|
|
@ -1,22 +1,33 @@
|
|||
package repl
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
"github.com/kujtimiihoxha/termai/internal/session"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/layout"
|
||||
)
|
||||
|
||||
type MessagesCmp interface {
|
||||
tea.Model
|
||||
layout.Focusable
|
||||
layout.Bordered
|
||||
layout.Sizeable
|
||||
layout.Bindings
|
||||
}
|
||||
|
||||
type messagesCmp struct {
|
||||
app *app.App
|
||||
messages []message.Message
|
||||
session session.Session
|
||||
}
|
||||
|
||||
func (m *messagesCmp) Init() tea.Cmd {
|
||||
return nil
|
||||
viewport viewport.Model
|
||||
width int
|
||||
height int
|
||||
focused bool
|
||||
}
|
||||
|
||||
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
@ -25,6 +36,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
if msg.Type == pubsub.CreatedEvent {
|
||||
m.messages = append(m.messages, msg.Payload)
|
||||
}
|
||||
case pubsub.Event[session.Session]:
|
||||
if msg.Type == pubsub.UpdatedEvent {
|
||||
if m.session.ID == msg.Payload.ID {
|
||||
m.session = msg.Payload
|
||||
}
|
||||
}
|
||||
case SelectedSessionMsg:
|
||||
m.session, _ = m.app.Sessions.Get(msg.SessionID)
|
||||
m.messages, _ = m.app.Messages.List(m.session.ID)
|
||||
|
@ -40,7 +57,55 @@ func (i *messagesCmp) View() string {
|
|||
return lipgloss.JoinVertical(lipgloss.Top, stringMessages...)
|
||||
}
|
||||
|
||||
func NewMessagesCmp(app *app.App) tea.Model {
|
||||
// BindingKeys implements MessagesCmp.
|
||||
func (m *messagesCmp) BindingKeys() []key.Binding {
|
||||
return []key.Binding{}
|
||||
}
|
||||
|
||||
// Blur implements MessagesCmp.
|
||||
func (m *messagesCmp) Blur() tea.Cmd {
|
||||
m.focused = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// BorderText implements MessagesCmp.
|
||||
func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
|
||||
title := m.session.Title
|
||||
if len(title) > 20 {
|
||||
title = title[:20] + "..."
|
||||
}
|
||||
return map[layout.BorderPosition]string{
|
||||
layout.TopLeftBorder: title,
|
||||
}
|
||||
}
|
||||
|
||||
// Focus implements MessagesCmp.
|
||||
func (m *messagesCmp) Focus() tea.Cmd {
|
||||
m.focused = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSize implements MessagesCmp.
|
||||
func (m *messagesCmp) GetSize() (int, int) {
|
||||
return m.width, m.height
|
||||
}
|
||||
|
||||
// IsFocused implements MessagesCmp.
|
||||
func (m *messagesCmp) IsFocused() bool {
|
||||
return m.focused
|
||||
}
|
||||
|
||||
// SetSize implements MessagesCmp.
|
||||
func (m *messagesCmp) SetSize(width int, height int) {
|
||||
m.width = width
|
||||
m.height = height
|
||||
}
|
||||
|
||||
func (m *messagesCmp) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewMessagesCmp(app *app.App) MessagesCmp {
|
||||
return &messagesCmp{
|
||||
app: app,
|
||||
messages: []message.Message{},
|
||||
|
|
|
@ -2,6 +2,7 @@ package repl
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
|
@ -82,7 +83,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
items[i] = listItem{
|
||||
id: s.ID,
|
||||
title: s.Title,
|
||||
desc: fmt.Sprintf("Tokens: %d, Cost: %.2f", s.PromptTokens+s.CompletionTokens, s.Cost),
|
||||
desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost),
|
||||
}
|
||||
}
|
||||
return i, i.list.SetItems(items)
|
||||
|
@ -94,7 +95,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
s := item.(listItem)
|
||||
if s.id == msg.Payload.ID {
|
||||
s.title = msg.Payload.Title
|
||||
s.desc = fmt.Sprintf("Tokens: %d, Cost: %.2f", msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
|
||||
s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
|
||||
items[idx] = s
|
||||
break
|
||||
}
|
||||
|
@ -169,6 +170,32 @@ func (i *sessionsCmp) BindingKeys() []key.Binding {
|
|||
return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select)
|
||||
}
|
||||
|
||||
func formatTokensAndCost(tokens int64, cost float64) string {
|
||||
// Format tokens in human-readable format (e.g., 110K, 1.2M)
|
||||
var formattedTokens string
|
||||
switch {
|
||||
case tokens >= 1_000_000:
|
||||
formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000)
|
||||
case tokens >= 1_000:
|
||||
formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000)
|
||||
default:
|
||||
formattedTokens = fmt.Sprintf("%d", tokens)
|
||||
}
|
||||
|
||||
// Remove .0 suffix if present
|
||||
if strings.HasSuffix(formattedTokens, ".0K") {
|
||||
formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1)
|
||||
}
|
||||
if strings.HasSuffix(formattedTokens, ".0M") {
|
||||
formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1)
|
||||
}
|
||||
|
||||
// Format cost with $ symbol and 2 decimal places
|
||||
formattedCost := fmt.Sprintf("$%.2f", cost)
|
||||
|
||||
return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost)
|
||||
}
|
||||
|
||||
func NewSessionsCmp(app *app.App) SessionsCmp {
|
||||
listDelegate := list.NewDefaultDelegate()
|
||||
defaultItemStyle := list.NewDefaultItemStyles()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue