initial working agent

This commit is contained in:
Kujtim Hoxha 2025-03-24 11:47:39 +01:00
parent e7258e38ae
commit 005b8ac167
6 changed files with 201 additions and 22 deletions

View file

@ -109,8 +109,6 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
}
}
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
err := rootCmd.Execute()
if err != nil {
@ -131,13 +129,14 @@ func loadConfig() {
// LLM
viper.SetDefault("models.big", string(models.DefaultBigModel))
viper.SetDefault("models.little", string(models.DefaultLittleModel))
viper.SetDefault("models.small", string(models.DefaultLittleModel))
viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY"))
viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY"))
viper.SetDefault("providers.groq.key", os.Getenv("GROQ_API_KEY"))
viper.SetDefault("providers.common.max_tokens", 4000)
viper.SetDefault("agents.default", "coder")
//
viper.ReadInConfig()
workdir, err := os.Getwd()

View file

@ -0,0 +1,31 @@
package agent
import (
"context"
"github.com/cloudwego/eino/schema"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/spf13/viper"
)
func GenerateTitle(ctx context.Context, content string) (string, error) {
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
if err != nil {
return "", err
}
out, err := model.Generate(
ctx,
[]*schema.Message{
schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 80 characters long
- the title should be a summary of the user's message
- do not use quotes or colons
- the entire text you return will be used as the title`),
schema.UserMessage(content),
},
)
if err != nil {
return "", err
}
return out.Content, nil
}

View file

@ -11,6 +11,7 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
}
log.Printf("Request: %s", content)
agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
for _, m := range history {
messages = append(messages, &m.MessageData)
}
builder := callbacks.NewHandlerBuilder()
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
i, ok := input.(*eModel.CallbackInput)
@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return ctx
})
out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
usage := out.ResponseMeta.Usage
s.messages.Create(sessionID, *out)
if usage != nil {
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
session, err := s.sessions.Get(sessionID)
@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
session.PromptTokens += int64(usage.PromptTokens)
session.CompletionTokens += int64(usage.CompletionTokens)
// TODO: calculate cost
model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
var newTitle string
if len(history) == 1 {
// first message generate the title
newTitle, err = agent.GenerateTitle(s.ctx, content)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
Type: AgentMessageTypeError,
AgentID: RootAgent,
MessageID: "",
SessionID: sessionID,
Content: err.Error(),
})
return
}
}
if newTitle != "" {
session.Title = newTitle
}
_, err = s.sessions.Save(session)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
}
s.messages.Create(sessionID, *out)
}
func (s *service) SendRequest(sessionID string, content string) {

View file

@ -3,6 +3,7 @@ package models
import (
"context"
"errors"
"log"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/openai"
@ -16,10 +17,12 @@ type (
)
type Model struct {
ID ModelID `json:"id"`
Name string `json:"name"`
Provider ModelProvider `json:"provider"`
APIModel string `json:"api_model"` // Actual value used when calling the API
ID ModelID `json:"id"`
Name string `json:"name"`
Provider ModelProvider `json:"provider"`
APIModel string `json:"api_model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
}
const (
@ -52,6 +55,9 @@ const (
// Meta
Llama3 ModelID = "llama-3"
Llama270B ModelID = "llama-2-70b"
// GROQ
GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
)
const (
@ -61,6 +67,7 @@ const (
ProviderXAI ModelProvider = "xai"
ProviderDeepSeek ModelProvider = "deepseek"
ProviderMeta ModelProvider = "meta"
ProviderGroq ModelProvider = "groq"
)
var SupportedModels = map[ModelID]Model{
@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{
APIModel: "gpt-4o",
},
GPT4oMini: {
ID: GPT4oMini,
Name: "GPT-4o Mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4o-mini",
ID: GPT4oMini,
Name: "GPT-4o Mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4o-mini",
CostPer1MIn: 0.150,
CostPer1MOut: 0.600,
},
GPT45: {
ID: GPT45,
@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{
Provider: ProviderMeta,
APIModel: "llama-2-70b",
},
// GROQ
GroqLlama3SpecDec: {
ID: GroqLlama3SpecDec,
Name: "GROQ LLaMA 3 SpecDec",
Provider: ProviderGroq,
APIModel: "llama-3.3-70b-specdec",
},
GroqQwen32BCoder: {
ID: GroqQwen32BCoder,
Name: "GROQ Qwen 2.5 Coder 32B",
Provider: ProviderGroq,
APIModel: "qwen-2.5-coder-32b",
},
}
func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
provider := SupportedModels[model].Provider
log.Printf("Provider: %s", provider)
maxTokens := viper.GetInt("providers.common.max_tokens")
switch provider {
case ProviderOpenAI:
@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
MaxTokens: maxTokens,
})
case ProviderGroq:
return openai.NewChatModel(ctx, &openai.ChatModelConfig{
BaseURL: "https://api.groq.com/openai/v1",
APIKey: viper.GetString("providers.groq.key"),
Model: string(SupportedModels[model].APIModel),
MaxTokens: &maxTokens,
})
}
return nil, errors.New("unsupported provider")
}

View file

@ -1,22 +1,33 @@
package repl
import (
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/session"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
)
type MessagesCmp interface {
tea.Model
layout.Focusable
layout.Bordered
layout.Sizeable
layout.Bindings
}
type messagesCmp struct {
app *app.App
messages []message.Message
session session.Session
}
func (m *messagesCmp) Init() tea.Cmd {
return nil
viewport viewport.Model
width int
height int
focused bool
}
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@ -25,6 +36,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if msg.Type == pubsub.CreatedEvent {
m.messages = append(m.messages, msg.Payload)
}
case pubsub.Event[session.Session]:
if msg.Type == pubsub.UpdatedEvent {
if m.session.ID == msg.Payload.ID {
m.session = msg.Payload
}
}
case SelectedSessionMsg:
m.session, _ = m.app.Sessions.Get(msg.SessionID)
m.messages, _ = m.app.Messages.List(m.session.ID)
@ -40,7 +57,55 @@ func (i *messagesCmp) View() string {
return lipgloss.JoinVertical(lipgloss.Top, stringMessages...)
}
func NewMessagesCmp(app *app.App) tea.Model {
// BindingKeys implements MessagesCmp.
func (m *messagesCmp) BindingKeys() []key.Binding {
return []key.Binding{}
}
// Blur implements MessagesCmp.
func (m *messagesCmp) Blur() tea.Cmd {
m.focused = false
return nil
}
// BorderText implements MessagesCmp.
func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
title := m.session.Title
if len(title) > 20 {
title = title[:20] + "..."
}
return map[layout.BorderPosition]string{
layout.TopLeftBorder: title,
}
}
// Focus implements MessagesCmp.
func (m *messagesCmp) Focus() tea.Cmd {
m.focused = true
return nil
}
// GetSize implements MessagesCmp.
func (m *messagesCmp) GetSize() (int, int) {
return m.width, m.height
}
// IsFocused implements MessagesCmp.
func (m *messagesCmp) IsFocused() bool {
return m.focused
}
// SetSize implements MessagesCmp.
func (m *messagesCmp) SetSize(width int, height int) {
m.width = width
m.height = height
}
func (m *messagesCmp) Init() tea.Cmd {
return nil
}
func NewMessagesCmp(app *app.App) MessagesCmp {
return &messagesCmp{
app: app,
messages: []message.Message{},

View file

@ -2,6 +2,7 @@ package repl
import (
"fmt"
"strings"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
@ -82,7 +83,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
items[i] = listItem{
id: s.ID,
title: s.Title,
desc: fmt.Sprintf("Tokens: %d, Cost: %.2f", s.PromptTokens+s.CompletionTokens, s.Cost),
desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost),
}
}
return i, i.list.SetItems(items)
@ -94,7 +95,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
s := item.(listItem)
if s.id == msg.Payload.ID {
s.title = msg.Payload.Title
s.desc = fmt.Sprintf("Tokens: %d, Cost: %.2f", msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
items[idx] = s
break
}
@ -169,6 +170,32 @@ func (i *sessionsCmp) BindingKeys() []key.Binding {
return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select)
}
func formatTokensAndCost(tokens int64, cost float64) string {
// Format tokens in human-readable format (e.g., 110K, 1.2M)
var formattedTokens string
switch {
case tokens >= 1_000_000:
formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000)
case tokens >= 1_000:
formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000)
default:
formattedTokens = fmt.Sprintf("%d", tokens)
}
// Remove .0 suffix if present
if strings.HasSuffix(formattedTokens, ".0K") {
formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1)
}
if strings.HasSuffix(formattedTokens, ".0M") {
formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1)
}
// Format cost with $ symbol and 2 decimal places
formattedCost := fmt.Sprintf("$%.2f", cost)
return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost)
}
func NewSessionsCmp(app *app.App) SessionsCmp {
listDelegate := list.NewDefaultDelegate()
defaultItemStyle := list.NewDefaultItemStyles()