mirror of
https://github.com/sst/opencode.git
synced 2025-07-08 00:25:00 +00:00
rework llm
This commit is contained in:
parent
904061c243
commit
afd9ad0560
61 changed files with 5882 additions and 2074 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -42,3 +42,5 @@ debug.log
|
|||
.env.local
|
||||
|
||||
.termai
|
||||
|
||||
.termai.json
|
||||
|
|
101
README.md
101
README.md
|
@ -1,13 +1,100 @@
|
|||
# TermAI
|
||||
|
||||
**⚠️ WORK IN PROGRESS ⚠️**
|
||||
A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal.
|
||||
|
||||
This project is currently under active development.
|
||||
[](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy)
|
||||
|
||||
## Current Progress
|
||||
## Overview
|
||||
|
||||
- Initial CLI setup
|
||||
- Basic functionality implementation
|
||||
- Working on core features
|
||||
TermAI is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more.
|
||||
|
||||
More details coming soon.
|
||||
## Features
|
||||
|
||||
- **Interactive TUI**: Built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) for a smooth terminal experience
|
||||
- **Multiple AI Providers**: Support for OpenAI, Anthropic Claude, and Google Gemini models
|
||||
- **Session Management**: Save and manage multiple conversation sessions
|
||||
- **Tool Integration**: AI can execute commands, search files, and modify code
|
||||
- **Vim-like Editor**: Integrated editor with Vim keybindings for text input
|
||||
- **Persistent Storage**: SQLite database for storing conversations and sessions
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Coming soon
|
||||
go install github.com/kujtimiihoxha/termai@latest
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
TermAI looks for configuration in the following locations:
|
||||
|
||||
- `$HOME/.termai.json`
|
||||
- `$XDG_CONFIG_HOME/termai/.termai.json`
|
||||
- `./.termai.json` (local directory)
|
||||
|
||||
You can also use environment variables:
|
||||
|
||||
- `ANTHROPIC_API_KEY`: For Claude models
|
||||
- `OPENAI_API_KEY`: For OpenAI models
|
||||
- `GEMINI_API_KEY`: For Google Gemini models
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Start TermAI
|
||||
termai
|
||||
|
||||
# Start with debug logging
|
||||
termai -d
|
||||
```
|
||||
|
||||
### Keyboard Shortcuts
|
||||
|
||||
- `L`: View logs
|
||||
- `N`: Create new session
|
||||
- `Esc`: Close current view
|
||||
- `Backspace`: Go back
|
||||
- `Ctrl+C` or `q`: Quit
|
||||
- `?`: Toggle help
|
||||
|
||||
## Architecture
|
||||
|
||||
TermAI is built with a modular architecture:
|
||||
|
||||
- **cmd**: Command-line interface using Cobra
|
||||
- **internal/app**: Core application services
|
||||
- **internal/config**: Configuration management
|
||||
- **internal/db**: Database operations and migrations
|
||||
- **internal/llm**: LLM providers and tools integration
|
||||
- **internal/tui**: Terminal UI components and layouts
|
||||
- **internal/logging**: Logging infrastructure
|
||||
- **internal/message**: Message handling
|
||||
- **internal/session**: Session management
|
||||
|
||||
## Development
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Go 1.23.5 or higher
|
||||
|
||||
### Building from Source
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/kujtimiihoxha/termai.git
|
||||
cd termai
|
||||
|
||||
# Build
|
||||
go build -o termai
|
||||
|
||||
# Run
|
||||
./termai
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
[License information coming soon]
|
||||
|
||||
## Contributing
|
||||
|
||||
[Contribution guidelines coming soon]
|
||||
|
|
56
cmd/root.go
56
cmd/root.go
|
@ -7,12 +7,12 @@ import (
|
|||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/db"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui"
|
||||
zone "github.com/lrstanley/bubblezone"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
|
@ -25,11 +25,10 @@ var rootCmd = &cobra.Command{
|
|||
return nil
|
||||
}
|
||||
debug, _ := cmd.Flags().GetBool("debug")
|
||||
viper.Set("debug", debug)
|
||||
if debug {
|
||||
viper.Set("log.level", "debug")
|
||||
err := config.Load(debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := db.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -49,6 +48,8 @@ var rootCmd = &cobra.Command{
|
|||
defer unsub()
|
||||
|
||||
go func() {
|
||||
// Set this up once
|
||||
agent.GetMcpTools(ctx)
|
||||
for msg := range ch {
|
||||
tui.Send(msg)
|
||||
}
|
||||
|
@ -95,16 +96,6 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
|
|||
wg.Done()
|
||||
}()
|
||||
}
|
||||
{
|
||||
sub := app.LLM.Subscribe(ctx)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for ev := range sub {
|
||||
ch <- ev
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
{
|
||||
sub := app.Permissions.Subscribe(ctx)
|
||||
wg.Add(1)
|
||||
|
@ -129,40 +120,7 @@ func Execute() {
|
|||
}
|
||||
}
|
||||
|
||||
func loadConfig() {
|
||||
viper.SetConfigName(".termai")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("$HOME")
|
||||
viper.AddConfigPath("$XDG_CONFIG_HOME/termai")
|
||||
viper.AddConfigPath(".")
|
||||
viper.SetEnvPrefix("TERMAI")
|
||||
// SET DEFAULTS
|
||||
viper.SetDefault("log.level", "info")
|
||||
viper.SetDefault("data.dir", ".termai")
|
||||
|
||||
// LLM
|
||||
viper.SetDefault("models.big", string(models.DefaultBigModel))
|
||||
viper.SetDefault("models.small", string(models.DefaultLittleModel))
|
||||
|
||||
viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY"))
|
||||
viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY"))
|
||||
viper.SetDefault("providers.groq.key", os.Getenv("GROQ_API_KEY"))
|
||||
viper.SetDefault("providers.common.max_tokens", 4000)
|
||||
|
||||
viper.SetDefault("agents.default", "coder")
|
||||
|
||||
viper.ReadInConfig()
|
||||
|
||||
workdir, err := os.Getwd()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
viper.Set("wd", workdir)
|
||||
}
|
||||
|
||||
func init() {
|
||||
loadConfig()
|
||||
|
||||
rootCmd.Flags().BoolP("help", "h", false, "Help")
|
||||
rootCmd.Flags().BoolP("debug", "d", false, "Help")
|
||||
}
|
||||
|
|
86
go.mod
86
go.mod
|
@ -3,6 +3,7 @@ module github.com/kujtimiihoxha/termai
|
|||
go 1.23.5
|
||||
|
||||
require (
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2
|
||||
github.com/bmatcuk/doublestar/v4 v4.8.1
|
||||
github.com/catppuccin/go v0.3.0
|
||||
github.com/charmbracelet/bubbles v0.20.0
|
||||
|
@ -10,92 +11,70 @@ require (
|
|||
github.com/charmbracelet/glamour v0.9.1
|
||||
github.com/charmbracelet/huh v0.6.0
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/cloudwego/eino v0.3.17
|
||||
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250320062631-616205c32186
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250320062631-616205c32186
|
||||
github.com/go-logfmt/logfmt v0.6.0
|
||||
github.com/golang-migrate/migrate/v4 v4.18.2
|
||||
github.com/google/generative-ai-go v0.19.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/kujtimiihoxha/vimtea v0.0.3-0.20250317175717-9d8ba9c69840
|
||||
github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9
|
||||
github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231
|
||||
github.com/mark3labs/mcp-go v0.17.0
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6
|
||||
github.com/muesli/reflow v0.3.0
|
||||
github.com/muesli/termenv v0.16.0
|
||||
github.com/openai/openai-go v0.1.0-beta.2
|
||||
github.com/sergi/go-diff v1.3.1
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/spf13/viper v1.20.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
|
||||
google.golang.org/api v0.215.0
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.116.0 // indirect
|
||||
cloud.google.com/go/ai v0.8.0 // indirect
|
||||
cloud.google.com/go/auth v0.13.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/longrunning v0.5.7 // indirect
|
||||
github.com/alecthomas/chroma/v2 v2.15.0 // indirect
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.8 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.33.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.54 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.24.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 // indirect
|
||||
github.com/aws/smithy-go v1.22.1 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.12.2 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.8.0 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250305023926-469de0301955 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/getkin/kin-openapi v0.118.0 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.19.5 // indirect
|
||||
github.com/go-openapi/swag v0.19.5 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/invopop/yaml v0.1.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/nikolalohinski/gonja v1.5.3 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/perimeterx/marshmallow v1.1.4 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||
github.com/sashabaranov/go-openai v1.32.5 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.12.0 // indirect
|
||||
github.com/spf13/cast v1.7.1 // indirect
|
||||
|
@ -105,19 +84,32 @@ require (
|
|||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yargevad/filepathx v1.0.0 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.5 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
|
||||
go.opentelemetry.io/otel v1.29.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.29.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.29.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.11.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.design/x/clipboard v0.7.0 // indirect
|
||||
golang.org/x/crypto v0.33.0 // indirect
|
||||
golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
golang.org/x/image v0.14.0 // indirect
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
golang.org/x/oauth2 v0.25.0 // indirect
|
||||
golang.org/x/sync v0.12.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/term v0.30.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
golang.org/x/time v0.8.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
|
||||
google.golang.org/grpc v1.67.3 // indirect
|
||||
google.golang.org/protobuf v1.36.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
242
go.sum
242
go.sum
|
@ -1,66 +1,37 @@
|
|||
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
|
||||
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
|
||||
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
|
||||
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
|
||||
cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
|
||||
cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
|
||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
|
||||
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.15.0 h1:LxXTQHFoYrstG2nnV9y2X5O94sOBzf0CIUpSTbpxvMc=
|
||||
github.com/alecthomas/chroma/v2 v2.15.0/go.mod h1:gUhVLrPDXPtp/f+L1jo9xepo9gL4eLwRuGAunSZMkio=
|
||||
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
|
||||
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.8 h1:ss/c/eeyILgoK2sMsTJdcdLdhY3wZSt//+nanM41B9w=
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.8/go.mod h1:GJxtdOs9K4neo8Gg65CjJ7jNautmldGli5/OFNabOoo=
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60MteeW23iKeEtBoY7bYZk=
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.33.0 h1:Evgm4DI9imD81V0WwD+TN4DCwjUMdc94TrduMLbgZJs=
|
||||
github.com/aws/aws-sdk-go-v2 v1.33.0/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3/go.mod h1:UbnqO+zjqk3uIt9yCACHJ9IVNhyhOCnYk8yA19SAWrM=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.1 h1:JZhGawAyZ/EuJeBtbQYnaoftczcb2drR2Iq36Wgz4sQ=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.1/go.mod h1:7bR2YD5euaxBhzt2y/oDkt3uNRb6tjFp98GlTFueRwk=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.54 h1:4UmqeOqJPvdvASZWrKlhzpRahAulBfyTJQUaYy4+hEI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.54/go.mod h1:RTdfo0P0hbbTxIhmQrOsC/PquBZGabEPnCaxxKRPSnI=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24 h1:5grmdTdMsovn9kPZPI23Hhvp0ZyNm5cRO+IZFIYiAfw=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24/go.mod h1:zqi7TVKTswH3Ozq28PkmBmgzG1tona7mo9G2IJg4Cis=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28 h1:igORFSiH3bfq4lxKFkTSYDhJEUCYo6C8VKiWJjYwQuQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28/go.mod h1:3So8EA/aAYm36L7XIvCVwLa0s5N0P7o2b1oqnx/2R4g=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28 h1:1mOW9zAUMhTSrMDssEHS/ajx8JcAj/IcftzcmNlmVLI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28/go.mod h1:kGlXVIWDfvt2Ox5zEaNglmq0hXPHgQFNMix33Tw22jA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9 h1:TQmKDyETFGiXVhZfQ/I0cCFziqqX58pi4tKJGYGFSz0=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9/go.mod h1:HVLPK2iHQBUx7HfZeOQSEu3v2ubZaAY2YPbAm5/WUyY=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.24.11 h1:kuIyu4fTT38Kj7YCC7ouNbVZSSpqkZ+LzIfhCr6Dg+I=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.24.11/go.mod h1:Ro744S4fKiCCuZECXgOi760TiYylUM8ZBf6OGiZzJtY=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10 h1:l+dgv/64iVlQ3WsBbnn+JSbkj01jIi+SM0wYsj3y/hY=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10/go.mod h1:Fzsj6lZEb8AkTE5S68OhcbBqeWPsR8RnGuKPr8Todl8=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 h1:BRVDbewN6VZcwr+FBOszDKvYeXY1kJ+GGMCcpghlw0U=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.9/go.mod h1:f6vjfZER1M17Fokn0IzssOTMT2N8ZSq+7jnNF0tArvw=
|
||||
github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro=
|
||||
github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
|
||||
github.com/bmatcuk/doublestar/v4 v4.8.1 h1:54Bopc5c2cAvhLRAzqOGCYHYyhcDHsFF4wWIR5wKP38=
|
||||
github.com/bmatcuk/doublestar/v4 v4.8.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||
github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8=
|
||||
github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
|
||||
github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo=
|
||||
github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY=
|
||||
github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg=
|
||||
github.com/bytedance/sonic v1.12.2/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM=
|
||||
github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
|
||||
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
|
||||
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
|
||||
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
|
||||
|
@ -83,18 +54,6 @@ github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko
|
|||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/eino v0.3.17 h1:cRQUCLU6897cautWe1u3les1H3OILasUIhnHzxr2QcE=
|
||||
github.com/cloudwego/eino v0.3.17/go.mod h1:+kmJimGEcKuSI6OKhet7kBedkm1WUZS3H1QRazxgWUo=
|
||||
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250320062631-616205c32186 h1:GGneAI4dIuQaxXfgsaaytmFuIi57cFTQ0R7b8PDUOxg=
|
||||
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250320062631-616205c32186/go.mod h1:ABAc+C6D9zs6KZZKfuHUhsCi+YWFogbWlqsMkQ9wJYY=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250320062631-616205c32186 h1:4f7KLAI2/177oZZ2iJO2bpzeJU9KPg3TduVTR0ulzQk=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250320062631-616205c32186/go.mod h1:vs6irdysecD6QxymjsCTzAgN2JceOTPYPjTXfJCNcgQ=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250305023926-469de0301955 h1:fgvkmTqAalDfjdy3b6Ur2mh/KEwB9L2uvqS4MFgTOqc=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250305023926-469de0301955/go.mod h1:6CThw1XQx/ASXNt31yuvp0X4Yp4GprknQuIvP9VKDpw=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
|
@ -105,42 +64,37 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
|||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM=
|
||||
github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc=
|
||||
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
|
||||
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY=
|
||||
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
|
||||
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
|
||||
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
|
||||
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
|
||||
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18=
|
||||
github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic=
|
||||
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
|
||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
|
||||
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
|
@ -148,22 +102,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l
|
|||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc=
|
||||
github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
|
@ -171,8 +111,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kujtimiihoxha/vimtea v0.0.3-0.20250317175717-9d8ba9c69840 h1:AORwYXTzap8hg0zmTA5RWB/0fxv9F19dF42dCY0IsRc=
|
||||
github.com/kujtimiihoxha/vimtea v0.0.3-0.20250317175717-9d8ba9c69840/go.mod h1:VyCD1xYnYem+OHp9nzGNx8x7rCwaeB+2VSyOUgX8Zyc=
|
||||
github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 h1:xYfCLI8KUwmXDFp1pOpNX+XsQczQw9VbEuju1pQF5/A=
|
||||
github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9/go.mod h1:Ye+kIkTmPO5xuqCQ+PPHDTGIViRRoSpSIlcYgma8YlA=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
|
@ -181,12 +121,8 @@ github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 h1:9rjt7AfnrX
|
|||
github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231/go.mod h1:S5etECMx+sZnW0Gm100Ma9J1PgVCTgNyFaqGu2b08b4=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE=
|
||||
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930=
|
||||
github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
|
@ -196,19 +132,10 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
|
|||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
|
||||
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
|
@ -217,18 +144,10 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
|||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
|
||||
github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/openai/openai-go v0.1.0-beta.2 h1:Ra5nCFkbEl9w+UJwAciC4kqnIBUCcJazhmMA0/YN894=
|
||||
github.com/openai/openai-go v0.1.0-beta.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||
github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw=
|
||||
github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
|
@ -237,25 +156,13 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
|||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
||||
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
||||
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/sashabaranov/go-openai v1.32.5 h1:/eNVa8KzlE7mJdKPZDj6886MUzZQjoVHyn0sLvIt5qA=
|
||||
github.com/sashabaranov/go-openai v1.32.5/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
|
||||
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
|
||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
|
||||
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
|
||||
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
|
||||
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
|
||||
github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
||||
|
@ -269,16 +176,8 @@ github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
|
|||
github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY=
|
||||
github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
|
@ -293,66 +192,71 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
|||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo=
|
||||
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
|
||||
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc=
|
||||
github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic=
|
||||
github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk=
|
||||
github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
|
||||
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
|
||||
go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
|
||||
go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
|
||||
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
|
||||
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
|
||||
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
||||
golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4=
|
||||
golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||
golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo=
|
||||
golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
|
||||
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 h1:bFYqOIMdeiCEdzPJkLiOoMDzW/v3tjW4AA/RmUZYsL8=
|
||||
golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8=
|
||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg=
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
|
||||
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
|
||||
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
google.golang.org/api v0.215.0 h1:jdYF4qnyczlEz2ReWIsosNLDuzXyvFHJtI5gcr0J7t0=
|
||||
google.golang.org/api v0.215.0/go.mod h1:fta3CVtuJYOEdugLNWm6WodzOS8KdFckABwN4I40hzY=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 h1:CkkIfIt50+lT6NHAVoRYEyAvQGFM7xEwXUUywFvEb3Q=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576/go.mod h1:1R3kvZ1dtP3+4p4d3G8uJ8rFk/fWlScl38vanWACI08=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 h1:TqExAhdPaB60Ux47Cn0oLV07rGnxZzIsaRhQaqS666A=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA=
|
||||
google.golang.org/grpc v1.67.3 h1:OgPcDAFKHnH8X3O4WcO4XUc8GRDeKsKReqbQtiCj7N8=
|
||||
google.golang.org/grpc v1.67.3/go.mod h1:YGaHCc6Oap+FzBJTZLBzkGSYt/cvGPFTPxkn7QfSU8s=
|
||||
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
|
||||
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
|
|
|
@ -4,13 +4,12 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/db"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm"
|
||||
"github.com/kujtimiihoxha/termai/internal/logging"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/kujtimiihoxha/termai/internal/session"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
|
@ -19,7 +18,6 @@ type App struct {
|
|||
Sessions session.Service
|
||||
Messages message.Service
|
||||
Permissions permission.Service
|
||||
LLM llm.Service
|
||||
|
||||
Logger logging.Interface
|
||||
}
|
||||
|
@ -27,18 +25,17 @@ type App struct {
|
|||
func New(ctx context.Context, conn *sql.DB) *App {
|
||||
q := db.New(conn)
|
||||
log := logging.NewLogger(logging.Options{
|
||||
Level: viper.GetString("log.level"),
|
||||
Level: config.Get().Log.Level,
|
||||
})
|
||||
sessions := session.NewService(ctx, q)
|
||||
messages := message.NewService(ctx, q)
|
||||
llm := llm.NewService(ctx, log, sessions, messages)
|
||||
|
||||
return &App{
|
||||
Context: ctx,
|
||||
Sessions: sessions,
|
||||
Messages: messages,
|
||||
Permissions: permission.Default,
|
||||
LLM: llm,
|
||||
Logger: log,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
180
internal/config/config.go
Normal file
180
internal/config/config.go
Normal file
|
@ -0,0 +1,180 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type MCPType string
|
||||
|
||||
const (
|
||||
MCPStdio MCPType = "stdio"
|
||||
MCPSse MCPType = "sse"
|
||||
)
|
||||
|
||||
type MCPServer struct {
|
||||
Command string `json:"command"`
|
||||
Env []string `json:"env"`
|
||||
Args []string `json:"args"`
|
||||
Type MCPType `json:"type"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
// TODO: add permissions configuration
|
||||
// TODO: add the ability to specify the tools to import
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Coder models.ModelID `json:"coder"`
|
||||
CoderMaxTokens int64 `json:"coderMaxTokens"`
|
||||
|
||||
Task models.ModelID `json:"task"`
|
||||
TaskMaxTokens int64 `json:"taskMaxTokens"`
|
||||
// TODO: Maybe support multiple models for different purposes
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
APIKey string `json:"apiKey"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
Directory string `json:"directory"`
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Level string `json:"level"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Data *Data `json:"data,omitempty"`
|
||||
Log *Log `json:"log,omitempty"`
|
||||
MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
|
||||
Providers map[models.ModelProvider]Provider `json:"providers,omitempty"`
|
||||
|
||||
Model *Model `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
var cfg *Config
|
||||
|
||||
const (
|
||||
defaultDataDirectory = ".termai"
|
||||
defaultLogLevel = "info"
|
||||
defaultMaxTokens = int64(5000)
|
||||
termai = "termai"
|
||||
)
|
||||
|
||||
func Load(debug bool) error {
|
||||
if cfg != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
viper.SetConfigName(fmt.Sprintf(".%s", termai))
|
||||
viper.SetConfigType("json")
|
||||
viper.AddConfigPath("$HOME")
|
||||
viper.AddConfigPath(fmt.Sprintf("$XDG_CONFIG_HOME/%s", termai))
|
||||
viper.SetEnvPrefix(strings.ToUpper(termai))
|
||||
|
||||
// Add defaults
|
||||
viper.SetDefault("data.directory", defaultDataDirectory)
|
||||
if debug {
|
||||
viper.Set("log.level", "debug")
|
||||
} else {
|
||||
viper.SetDefault("log.level", defaultLogLevel)
|
||||
}
|
||||
|
||||
defaultModelSet := false
|
||||
if os.Getenv("ANTHROPIC_API_KEY") != "" {
|
||||
viper.SetDefault("providers.anthropic.apiKey", os.Getenv("ANTHROPIC_API_KEY"))
|
||||
viper.SetDefault("providers.anthropic.enabled", true)
|
||||
viper.SetDefault("model.coder", models.Claude37Sonnet)
|
||||
viper.SetDefault("model.task", models.Claude37Sonnet)
|
||||
defaultModelSet = true
|
||||
}
|
||||
if os.Getenv("OPENAI_API_KEY") != "" {
|
||||
viper.SetDefault("providers.openai.apiKey", os.Getenv("OPENAI_API_KEY"))
|
||||
viper.SetDefault("providers.openai.enabled", true)
|
||||
if !defaultModelSet {
|
||||
viper.SetDefault("model.coder", models.GPT4o)
|
||||
viper.SetDefault("model.task", models.GPT4o)
|
||||
defaultModelSet = true
|
||||
}
|
||||
}
|
||||
if os.Getenv("GEMINI_API_KEY") != "" {
|
||||
viper.SetDefault("providers.gemini.apiKey", os.Getenv("GEMINI_API_KEY"))
|
||||
viper.SetDefault("providers.gemini.enabled", true)
|
||||
if !defaultModelSet {
|
||||
viper.SetDefault("model.coder", models.GRMINI20Flash)
|
||||
viper.SetDefault("model.task", models.GRMINI20Flash)
|
||||
defaultModelSet = true
|
||||
}
|
||||
}
|
||||
if os.Getenv("GROQ_API_KEY") != "" {
|
||||
viper.SetDefault("providers.groq.apiKey", os.Getenv("GROQ_API_KEY"))
|
||||
viper.SetDefault("providers.groq.enabled", true)
|
||||
if !defaultModelSet {
|
||||
viper.SetDefault("model.coder", models.QWENQwq)
|
||||
viper.SetDefault("model.task", models.QWENQwq)
|
||||
defaultModelSet = true
|
||||
}
|
||||
}
|
||||
// TODO: add more providers
|
||||
cfg = &Config{}
|
||||
|
||||
err := viper.ReadInConfig()
|
||||
if err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
local := viper.New()
|
||||
local.SetConfigName(fmt.Sprintf(".%s", termai))
|
||||
local.SetConfigType("json")
|
||||
local.AddConfigPath(".")
|
||||
// load local config, this will override the global config
|
||||
if err = local.ReadInConfig(); err == nil {
|
||||
viper.MergeConfigMap(local.AllSettings())
|
||||
}
|
||||
viper.Unmarshal(cfg)
|
||||
|
||||
if cfg.Model != nil && cfg.Model.CoderMaxTokens <= 0 {
|
||||
cfg.Model.CoderMaxTokens = defaultMaxTokens
|
||||
}
|
||||
if cfg.Model != nil && cfg.Model.TaskMaxTokens <= 0 {
|
||||
cfg.Model.TaskMaxTokens = defaultMaxTokens
|
||||
}
|
||||
|
||||
for _, v := range cfg.MCPServers {
|
||||
if v.Type == "" {
|
||||
v.Type = MCPStdio
|
||||
}
|
||||
}
|
||||
|
||||
workdir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
viper.Set("wd", workdir)
|
||||
return nil
|
||||
}
|
||||
|
||||
func Get() *Config {
|
||||
if cfg == nil {
|
||||
err := Load(false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func WorkingDirectory() string {
|
||||
return viper.GetString("wd")
|
||||
}
|
||||
|
||||
func Write() error {
|
||||
return viper.WriteConfig()
|
||||
}
|
465
internal/config/config_test.go
Normal file
465
internal/config/config_test.go
Normal file
|
@ -0,0 +1,465 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
setupTest(t)
|
||||
|
||||
t.Run("loads configuration successfully", func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
|
||||
configContent := `{
|
||||
"data": {
|
||||
"directory": "custom-dir"
|
||||
},
|
||||
"log": {
|
||||
"level": "debug"
|
||||
},
|
||||
"mcpServers": {
|
||||
"test-server": {
|
||||
"command": "test-command",
|
||||
"env": ["TEST_ENV=value"],
|
||||
"args": ["--arg1", "--arg2"],
|
||||
"type": "stdio",
|
||||
"url": "",
|
||||
"headers": {}
|
||||
},
|
||||
"sse-server": {
|
||||
"command": "",
|
||||
"env": [],
|
||||
"args": [],
|
||||
"type": "sse",
|
||||
"url": "https://api.example.com/events",
|
||||
"headers": {
|
||||
"Authorization": "Bearer token123",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"apiKey": "test-api-key",
|
||||
"enabled": true
|
||||
}
|
||||
},
|
||||
"model": {
|
||||
"coder": "claude-3-haiku",
|
||||
"task": "claude-3-haiku"
|
||||
}
|
||||
}`
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := Get()
|
||||
assert.NotNil(t, config)
|
||||
assert.Equal(t, "custom-dir", config.Data.Directory)
|
||||
assert.Equal(t, "debug", config.Log.Level)
|
||||
|
||||
assert.Contains(t, config.MCPServers, "test-server")
|
||||
stdioServer := config.MCPServers["test-server"]
|
||||
assert.Equal(t, "test-command", stdioServer.Command)
|
||||
assert.Equal(t, []string{"TEST_ENV=value"}, stdioServer.Env)
|
||||
assert.Equal(t, []string{"--arg1", "--arg2"}, stdioServer.Args)
|
||||
assert.Equal(t, MCPStdio, stdioServer.Type)
|
||||
assert.Equal(t, "", stdioServer.URL)
|
||||
assert.Empty(t, stdioServer.Headers)
|
||||
|
||||
assert.Contains(t, config.MCPServers, "sse-server")
|
||||
sseServer := config.MCPServers["sse-server"]
|
||||
assert.Equal(t, "", sseServer.Command)
|
||||
assert.Empty(t, sseServer.Env)
|
||||
assert.Empty(t, sseServer.Args)
|
||||
assert.Equal(t, MCPSse, sseServer.Type)
|
||||
assert.Equal(t, "https://api.example.com/events", sseServer.URL)
|
||||
assert.Equal(t, map[string]string{
|
||||
"authorization": "Bearer token123",
|
||||
"content-type": "application/json",
|
||||
}, sseServer.Headers)
|
||||
|
||||
assert.Contains(t, config.Providers, models.ModelProvider("anthropic"))
|
||||
provider := config.Providers[models.ModelProvider("anthropic")]
|
||||
assert.Equal(t, "test-api-key", provider.APIKey)
|
||||
assert.True(t, provider.Enabled)
|
||||
|
||||
assert.NotNil(t, config.Model)
|
||||
assert.Equal(t, models.Claude3Haiku, config.Model.Coder)
|
||||
assert.Equal(t, models.Claude3Haiku, config.Model.Task)
|
||||
assert.Equal(t, defaultMaxTokens, config.Model.CoderMaxTokens)
|
||||
})
|
||||
|
||||
t.Run("loads configuration with environment variables", func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Setenv("ANTHROPIC_API_KEY", "env-anthropic-key")
|
||||
t.Setenv("OPENAI_API_KEY", "env-openai-key")
|
||||
t.Setenv("GEMINI_API_KEY", "env-gemini-key")
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := Get()
|
||||
assert.NotNil(t, config)
|
||||
|
||||
assert.Equal(t, defaultDataDirectory, config.Data.Directory)
|
||||
assert.Equal(t, defaultLogLevel, config.Log.Level)
|
||||
|
||||
assert.Contains(t, config.Providers, models.ModelProvider("anthropic"))
|
||||
assert.Equal(t, "env-anthropic-key", config.Providers[models.ModelProvider("anthropic")].APIKey)
|
||||
assert.True(t, config.Providers[models.ModelProvider("anthropic")].Enabled)
|
||||
|
||||
assert.Contains(t, config.Providers, models.ModelProvider("openai"))
|
||||
assert.Equal(t, "env-openai-key", config.Providers[models.ModelProvider("openai")].APIKey)
|
||||
assert.True(t, config.Providers[models.ModelProvider("openai")].Enabled)
|
||||
|
||||
assert.Contains(t, config.Providers, models.ModelProvider("gemini"))
|
||||
assert.Equal(t, "env-gemini-key", config.Providers[models.ModelProvider("gemini")].APIKey)
|
||||
assert.True(t, config.Providers[models.ModelProvider("gemini")].Enabled)
|
||||
|
||||
assert.Equal(t, models.Claude37Sonnet, config.Model.Coder)
|
||||
})
|
||||
|
||||
t.Run("local config overrides global config", func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
globalConfigPath := filepath.Join(homeDir, ".termai.json")
|
||||
globalConfig := `{
|
||||
"data": {
|
||||
"directory": "global-dir"
|
||||
},
|
||||
"log": {
|
||||
"level": "info"
|
||||
}
|
||||
}`
|
||||
err := os.WriteFile(globalConfigPath, []byte(globalConfig), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
workDir := t.TempDir()
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
defer os.Chdir(origDir)
|
||||
err = os.Chdir(workDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
localConfigPath := filepath.Join(workDir, ".termai.json")
|
||||
localConfig := `{
|
||||
"data": {
|
||||
"directory": "local-dir"
|
||||
},
|
||||
"log": {
|
||||
"level": "debug"
|
||||
}
|
||||
}`
|
||||
err = os.WriteFile(localConfigPath, []byte(localConfig), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := Get()
|
||||
assert.NotNil(t, config)
|
||||
|
||||
assert.Equal(t, "local-dir", config.Data.Directory)
|
||||
assert.Equal(t, "debug", config.Log.Level)
|
||||
})
|
||||
|
||||
t.Run("missing config file should not return error", func(t *testing.T) {
|
||||
emptyDir := t.TempDir()
|
||||
t.Setenv("HOME", emptyDir)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err := Load(false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("model priority and fallbacks", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
anthropicKey string
|
||||
openaiKey string
|
||||
geminiKey string
|
||||
expectedModel models.ModelID
|
||||
explicitModel models.ModelID
|
||||
useExplicitModel bool
|
||||
}{
|
||||
{
|
||||
name: "anthropic has priority",
|
||||
anthropicKey: "test-key",
|
||||
openaiKey: "test-key",
|
||||
geminiKey: "test-key",
|
||||
expectedModel: models.Claude37Sonnet,
|
||||
},
|
||||
{
|
||||
name: "fallback to openai when no anthropic",
|
||||
anthropicKey: "",
|
||||
openaiKey: "test-key",
|
||||
geminiKey: "test-key",
|
||||
expectedModel: models.GPT4o,
|
||||
},
|
||||
{
|
||||
name: "fallback to gemini when no others",
|
||||
anthropicKey: "",
|
||||
openaiKey: "",
|
||||
geminiKey: "test-key",
|
||||
expectedModel: models.GRMINI20Flash,
|
||||
},
|
||||
{
|
||||
name: "explicit model overrides defaults",
|
||||
anthropicKey: "test-key",
|
||||
openaiKey: "test-key",
|
||||
geminiKey: "test-key",
|
||||
explicitModel: models.GPT4o,
|
||||
useExplicitModel: true,
|
||||
expectedModel: models.GPT4o,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
|
||||
configContent := "{}"
|
||||
if tc.useExplicitModel {
|
||||
configContent = fmt.Sprintf(`{"model":{"coder":"%s"}}`, tc.explicitModel)
|
||||
}
|
||||
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.anthropicKey != "" {
|
||||
t.Setenv("ANTHROPIC_API_KEY", tc.anthropicKey)
|
||||
} else {
|
||||
t.Setenv("ANTHROPIC_API_KEY", "")
|
||||
}
|
||||
|
||||
if tc.openaiKey != "" {
|
||||
t.Setenv("OPENAI_API_KEY", tc.openaiKey)
|
||||
} else {
|
||||
t.Setenv("OPENAI_API_KEY", "")
|
||||
}
|
||||
|
||||
if tc.geminiKey != "" {
|
||||
t.Setenv("GEMINI_API_KEY", tc.geminiKey)
|
||||
} else {
|
||||
t.Setenv("GEMINI_API_KEY", "")
|
||||
}
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := Get()
|
||||
assert.NotNil(t, config)
|
||||
assert.Equal(t, tc.expectedModel, config.Model.Coder)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
t.Run("get returns same config instance", func(t *testing.T) {
|
||||
setupTest(t)
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
config1 := Get()
|
||||
require.NotNil(t, config1)
|
||||
|
||||
config2 := Get()
|
||||
require.NotNil(t, config2)
|
||||
|
||||
assert.Same(t, config1, config2)
|
||||
})
|
||||
|
||||
t.Run("get loads config if not loaded", func(t *testing.T) {
|
||||
setupTest(t)
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
configContent := `{"data":{"directory":"test-dir"}}`
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
config := Get()
|
||||
require.NotNil(t, config)
|
||||
assert.Equal(t, "test-dir", config.Data.Directory)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkingDirectory(t *testing.T) {
|
||||
t.Run("returns current working directory", func(t *testing.T) {
|
||||
setupTest(t)
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
wd := WorkingDirectory()
|
||||
expectedWd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedWd, wd)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
t.Run("writes config to file", func(t *testing.T) {
|
||||
setupTest(t)
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
viper.Set("data.directory", "modified-dir")
|
||||
|
||||
err = Write()
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(content), "modified-dir")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMCPType(t *testing.T) {
|
||||
t.Run("MCPType constants", func(t *testing.T) {
|
||||
assert.Equal(t, MCPType("stdio"), MCPStdio)
|
||||
assert.Equal(t, MCPType("sse"), MCPSse)
|
||||
})
|
||||
|
||||
t.Run("MCPType JSON unmarshaling", func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
|
||||
configContent := `{
|
||||
"mcpServers": {
|
||||
"stdio-server": {
|
||||
"type": "stdio"
|
||||
},
|
||||
"sse-server": {
|
||||
"type": "sse"
|
||||
},
|
||||
"invalid-server": {
|
||||
"type": "invalid"
|
||||
}
|
||||
}
|
||||
}`
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := Get()
|
||||
assert.NotNil(t, config)
|
||||
|
||||
assert.Equal(t, MCPStdio, config.MCPServers["stdio-server"].Type)
|
||||
assert.Equal(t, MCPSse, config.MCPServers["sse-server"].Type)
|
||||
assert.Equal(t, MCPType("invalid"), config.MCPServers["invalid-server"].Type)
|
||||
})
|
||||
|
||||
t.Run("default MCPType", func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
t.Setenv("HOME", homeDir)
|
||||
configPath := filepath.Join(homeDir, ".termai.json")
|
||||
|
||||
configContent := `{
|
||||
"mcpServers": {
|
||||
"test-server": {
|
||||
"command": "test-command"
|
||||
}
|
||||
}
|
||||
}`
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
|
||||
err = Load(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := Get()
|
||||
assert.NotNil(t, config)
|
||||
|
||||
assert.Equal(t, MCPType(""), config.MCPServers["test-server"].Type)
|
||||
})
|
||||
}
|
||||
|
||||
func setupTest(t *testing.T) {
|
||||
origHome := os.Getenv("HOME")
|
||||
origXdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
origAnthropicKey := os.Getenv("ANTHROPIC_API_KEY")
|
||||
origOpenAIKey := os.Getenv("OPENAI_API_KEY")
|
||||
origGeminiKey := os.Getenv("GEMINI_API_KEY")
|
||||
|
||||
t.Cleanup(func() {
|
||||
t.Setenv("HOME", origHome)
|
||||
t.Setenv("XDG_CONFIG_HOME", origXdgConfigHome)
|
||||
t.Setenv("ANTHROPIC_API_KEY", origAnthropicKey)
|
||||
t.Setenv("OPENAI_API_KEY", origOpenAIKey)
|
||||
t.Setenv("GEMINI_API_KEY", origGeminiKey)
|
||||
|
||||
cfg = nil
|
||||
viper.Reset()
|
||||
})
|
||||
}
|
|
@ -12,14 +12,14 @@ import (
|
|||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/logging"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var log = logging.Get()
|
||||
|
||||
func Connect() (*sql.DB, error) {
|
||||
dataDir := viper.GetString("data.dir")
|
||||
dataDir := config.Get().Data.Directory
|
||||
if dataDir == "" {
|
||||
return nil, fmt.Errorf("data.dir is not set")
|
||||
}
|
||||
|
|
|
@ -51,6 +51,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
|
|||
if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
|
||||
}
|
||||
if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
|
||||
}
|
||||
if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
|
||||
}
|
||||
|
@ -104,6 +107,11 @@ func (q *Queries) Close() error {
|
|||
err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.updateMessageStmt != nil {
|
||||
if cerr := q.updateMessageStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.updateSessionStmt != nil {
|
||||
if cerr := q.updateSessionStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
|
||||
|
@ -157,6 +165,7 @@ type Queries struct {
|
|||
getSessionByIDStmt *sql.Stmt
|
||||
listMessagesBySessionStmt *sql.Stmt
|
||||
listSessionsStmt *sql.Stmt
|
||||
updateMessageStmt *sql.Stmt
|
||||
updateSessionStmt *sql.Stmt
|
||||
}
|
||||
|
||||
|
@ -173,6 +182,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
|
|||
getSessionByIDStmt: q.getSessionByIDStmt,
|
||||
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
|
||||
listSessionsStmt: q.listSessionsStmt,
|
||||
updateMessageStmt: q.updateMessageStmt,
|
||||
updateSessionStmt: q.updateSessionStmt,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,34 +7,56 @@ package db
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const createMessage = `-- name: CreateMessage :one
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
session_id,
|
||||
message_data,
|
||||
role,
|
||||
finished,
|
||||
content,
|
||||
tool_calls,
|
||||
tool_results,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
)
|
||||
RETURNING id, session_id, message_data, created_at, updated_at
|
||||
RETURNING id, session_id, role, content, thinking, finished, tool_calls, tool_results, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateMessageParams struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
MessageData string `json:"message_data"`
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Role string `json:"role"`
|
||||
Finished bool `json:"finished"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls sql.NullString `json:"tool_calls"`
|
||||
ToolResults sql.NullString `json:"tool_results"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
|
||||
row := q.queryRow(ctx, q.createMessageStmt, createMessage, arg.ID, arg.SessionID, arg.MessageData)
|
||||
row := q.queryRow(ctx, q.createMessageStmt, createMessage,
|
||||
arg.ID,
|
||||
arg.SessionID,
|
||||
arg.Role,
|
||||
arg.Finished,
|
||||
arg.Content,
|
||||
arg.ToolCalls,
|
||||
arg.ToolResults,
|
||||
)
|
||||
var i Message
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.MessageData,
|
||||
&i.Role,
|
||||
&i.Content,
|
||||
&i.Thinking,
|
||||
&i.Finished,
|
||||
&i.ToolCalls,
|
||||
&i.ToolResults,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
|
@ -62,7 +84,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
|
|||
}
|
||||
|
||||
const getMessage = `-- name: GetMessage :one
|
||||
SELECT id, session_id, message_data, created_at, updated_at
|
||||
SELECT id, session_id, role, content, thinking, finished, tool_calls, tool_results, created_at, updated_at
|
||||
FROM messages
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
|
@ -73,7 +95,12 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
|
|||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.MessageData,
|
||||
&i.Role,
|
||||
&i.Content,
|
||||
&i.Thinking,
|
||||
&i.Finished,
|
||||
&i.ToolCalls,
|
||||
&i.ToolResults,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
|
@ -81,7 +108,7 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
|
|||
}
|
||||
|
||||
const listMessagesBySession = `-- name: ListMessagesBySession :many
|
||||
SELECT id, session_id, message_data, created_at, updated_at
|
||||
SELECT id, session_id, role, content, thinking, finished, tool_calls, tool_results, created_at, updated_at
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at ASC
|
||||
|
@ -99,7 +126,12 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
|
|||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.MessageData,
|
||||
&i.Role,
|
||||
&i.Content,
|
||||
&i.Thinking,
|
||||
&i.Finished,
|
||||
&i.ToolCalls,
|
||||
&i.ToolResults,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
|
@ -115,3 +147,36 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
|
|||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateMessage = `-- name: UpdateMessage :exec
|
||||
UPDATE messages
|
||||
SET
|
||||
content = ?,
|
||||
thinking = ?,
|
||||
tool_calls = ?,
|
||||
tool_results = ?,
|
||||
finished = ?,
|
||||
updated_at = strftime('%s', 'now')
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
type UpdateMessageParams struct {
|
||||
Content string `json:"content"`
|
||||
Thinking string `json:"thinking"`
|
||||
ToolCalls sql.NullString `json:"tool_calls"`
|
||||
ToolResults sql.NullString `json:"tool_results"`
|
||||
Finished bool `json:"finished"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error {
|
||||
_, err := q.exec(ctx, q.updateMessageStmt, updateMessage,
|
||||
arg.Content,
|
||||
arg.Thinking,
|
||||
arg.ToolCalls,
|
||||
arg.ToolResults,
|
||||
arg.Finished,
|
||||
arg.ID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
-- Sessions
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
parent_session_id TEXT,
|
||||
title TEXT NOT NULL,
|
||||
message_count INTEGER NOT NULL DEFAULT 0 CHECK (message_count >= 0),
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0 CHECK (prompt_tokens >= 0),
|
||||
|
@ -21,7 +22,12 @@ END;
|
|||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
message_data TEXT NOT NULL, -- JSON string of message content
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
thinking Text NOT NULL DEFAULT '',
|
||||
finished BOOLEAN NOT NULL DEFAULT 0,
|
||||
tool_calls TEXT,
|
||||
tool_results TEXT,
|
||||
created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
|
||||
updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds
|
||||
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE
|
||||
|
|
|
@ -4,21 +4,31 @@
|
|||
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
MessageData string `json:"message_data"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Thinking string `json:"thinking"`
|
||||
Finished bool `json:"finished"`
|
||||
ToolCalls sql.NullString `json:"tool_calls"`
|
||||
ToolResults sql.NullString `json:"tool_results"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
ParentSessionID sql.NullString `json:"parent_session_id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ type Querier interface {
|
|||
GetSessionByID(ctx context.Context, id string) (Session, error)
|
||||
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
|
||||
ListSessions(ctx context.Context) ([]Session, error)
|
||||
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
|
||||
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
|
||||
}
|
||||
|
||||
|
|
|
@ -7,11 +7,13 @@ package db
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const createSession = `-- name: CreateSession :one
|
||||
INSERT INTO sessions (
|
||||
id,
|
||||
parent_session_id,
|
||||
title,
|
||||
message_count,
|
||||
prompt_tokens,
|
||||
|
@ -26,23 +28,26 @@ INSERT INTO sessions (
|
|||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
strftime('%s', 'now'),
|
||||
strftime('%s', 'now')
|
||||
) RETURNING id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ID string `json:"id"`
|
||||
ParentSessionID sql.NullString `json:"parent_session_id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||
row := q.queryRow(ctx, q.createSessionStmt, createSession,
|
||||
arg.ID,
|
||||
arg.ParentSessionID,
|
||||
arg.Title,
|
||||
arg.MessageCount,
|
||||
arg.PromptTokens,
|
||||
|
@ -52,6 +57,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
|||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
|
@ -74,7 +80,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
|
|||
}
|
||||
|
||||
const getSessionByID = `-- name: GetSessionByID :one
|
||||
SELECT id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
FROM sessions
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
|
@ -84,6 +90,7 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
|
|||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
|
@ -96,8 +103,9 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
|
|||
}
|
||||
|
||||
const listSessions = `-- name: ListSessions :many
|
||||
SELECT id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
FROM sessions
|
||||
WHERE parent_session_id is NULL
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
|
@ -112,6 +120,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
|
|||
var i Session
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
|
@ -141,7 +150,7 @@ SET
|
|||
completion_tokens = ?,
|
||||
cost = ?
|
||||
WHERE id = ?
|
||||
RETURNING id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
`
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
|
@ -163,6 +172,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
|||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
|
|
|
@ -13,14 +13,29 @@ ORDER BY created_at ASC;
|
|||
INSERT INTO messages (
|
||||
id,
|
||||
session_id,
|
||||
message_data,
|
||||
role,
|
||||
finished,
|
||||
content,
|
||||
tool_calls,
|
||||
tool_results,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateMessage :exec
|
||||
UPDATE messages
|
||||
SET
|
||||
content = ?,
|
||||
thinking = ?,
|
||||
tool_calls = ?,
|
||||
tool_results = ?,
|
||||
finished = ?,
|
||||
updated_at = strftime('%s', 'now')
|
||||
WHERE id = ?;
|
||||
|
||||
-- name: DeleteMessage :exec
|
||||
DELETE FROM messages
|
||||
WHERE id = ?;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
-- name: CreateSession :one
|
||||
INSERT INTO sessions (
|
||||
id,
|
||||
parent_session_id,
|
||||
title,
|
||||
message_count,
|
||||
prompt_tokens,
|
||||
|
@ -15,6 +16,7 @@ INSERT INTO sessions (
|
|||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
strftime('%s', 'now'),
|
||||
strftime('%s', 'now')
|
||||
) RETURNING *;
|
||||
|
@ -27,6 +29,7 @@ WHERE id = ? LIMIT 1;
|
|||
-- name: ListSessions :many
|
||||
SELECT *
|
||||
FROM sessions
|
||||
WHERE parent_session_id is NULL
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: UpdateSession :one
|
||||
|
|
102
internal/llm/agent/agent-tool.go
Normal file
102
internal/llm/agent/agent-tool.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
)
|
||||
|
||||
type agentTool struct {
|
||||
parentSessionID string
|
||||
app *app.App
|
||||
}
|
||||
|
||||
const (
|
||||
AgentToolName = "agent"
|
||||
)
|
||||
|
||||
type AgentParams struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
func (b *agentTool) Info() tools.ToolInfo {
|
||||
return tools.ToolInfo{
|
||||
Name: AgentToolName,
|
||||
Description: "Launch a new agent that has access to the following tools: GlobTool, GrepTool, LS, View. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you. For example:\n\n- If you are searching for a keyword like \"config\" or \"logger\", or for questions like \"which file does X?\", the Agent tool is strongly recommended\n- If you want to read a specific file path, use the View or GlobTool tool instead of the Agent tool, to find the match more quickly\n- If you are searching for a specific class definition like \"class Foo\", use the GlobTool tool instead, to find the match more quickly\n\nUsage notes:\n1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses\n2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.\n3. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.\n4. The agent's outputs should generally be trusted\n5. IMPORTANT: The agent can not use Bash, Replace, Edit, so can not modify files. If you want to use these tools, use them directly instead of going through the agent.",
|
||||
Parameters: map[string]any{
|
||||
"prompt": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The task for the agent to perform",
|
||||
},
|
||||
},
|
||||
Required: []string{"prompt"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
|
||||
var params AgentParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
if params.Prompt == "" {
|
||||
return tools.NewTextErrorResponse("prompt is required"), nil
|
||||
}
|
||||
|
||||
agent, err := NewTaskAgent(b.app)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil
|
||||
}
|
||||
|
||||
session, err := b.app.Sessions.CreateTaskSession(call.ID, b.parentSessionID, "New Agent Session")
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
|
||||
}
|
||||
|
||||
err = agent.Generate(session.ID, params.Prompt)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
|
||||
}
|
||||
|
||||
messages, err := b.app.Messages.List(session.ID)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return tools.NewTextErrorResponse("no messages found"), nil
|
||||
}
|
||||
|
||||
response := messages[len(messages)-1]
|
||||
if response.Role != message.Assistant {
|
||||
return tools.NewTextErrorResponse("no assistant message found"), nil
|
||||
}
|
||||
|
||||
updatedSession, err := b.app.Sessions.Get(session.ID)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
|
||||
}
|
||||
parentSession, err := b.app.Sessions.Get(b.parentSessionID)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
|
||||
}
|
||||
|
||||
parentSession.Cost += updatedSession.Cost
|
||||
parentSession.PromptTokens += updatedSession.PromptTokens
|
||||
parentSession.CompletionTokens += updatedSession.CompletionTokens
|
||||
|
||||
_, err = b.app.Sessions.Save(parentSession)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
|
||||
}
|
||||
return tools.NewTextResponse(response.Content), nil
|
||||
}
|
||||
|
||||
func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool {
|
||||
return &agentTool{
|
||||
parentSessionID: parentSessionID,
|
||||
app: app,
|
||||
}
|
||||
}
|
|
@ -2,16 +2,353 @@ package agent
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/prompt"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/provider"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
)
|
||||
|
||||
func GetAgent(ctx context.Context, name string) (*react.Agent, string, error) {
|
||||
switch name {
|
||||
case "coder":
|
||||
agent, err := NewCoderAgent(ctx)
|
||||
return agent, CoderSystemPrompt(), err
|
||||
}
|
||||
return nil, "", fmt.Errorf("agent %s not found", name)
|
||||
type Agent interface {
|
||||
Generate(sessionID string, content string) error
|
||||
}
|
||||
|
||||
type agent struct {
|
||||
*app.App
|
||||
model models.Model
|
||||
tools []tools.BaseTool
|
||||
agent provider.Provider
|
||||
titleGenerator provider.Provider
|
||||
}
|
||||
|
||||
func (c *agent) handleTitleGeneration(sessionID, content string) {
|
||||
response, err := c.titleGenerator.SendMessages(
|
||||
c.Context,
|
||||
[]message.Message{
|
||||
{
|
||||
Role: message.User,
|
||||
Content: content,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
session, err := c.Sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if response.Content != "" {
|
||||
session.Title = response.Content
|
||||
c.Sessions.Save(session)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
|
||||
session, err := c.Sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
||||
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
|
||||
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
|
||||
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
||||
|
||||
session.Cost += cost
|
||||
session.CompletionTokens += usage.OutputTokens
|
||||
session.PromptTokens += usage.InputTokens
|
||||
|
||||
_, err = c.Sessions.Save(session)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *agent) processEvent(
|
||||
sessionID string,
|
||||
assistantMsg *message.Message,
|
||||
event provider.ProviderEvent,
|
||||
) error {
|
||||
switch event.Type {
|
||||
case provider.EventThinkingDelta:
|
||||
assistantMsg.Thinking += event.Thinking
|
||||
return c.Messages.Update(*assistantMsg)
|
||||
case provider.EventContentDelta:
|
||||
assistantMsg.Content += event.Content
|
||||
return c.Messages.Update(*assistantMsg)
|
||||
case provider.EventError:
|
||||
log.Println("error", event.Error)
|
||||
return event.Error
|
||||
|
||||
case provider.EventComplete:
|
||||
assistantMsg.ToolCalls = event.Response.ToolCalls
|
||||
err := c.Messages.Update(*assistantMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.TrackUsage(sessionID, c.model, event.Response.Usage)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
|
||||
var wg sync.WaitGroup
|
||||
toolResults := make([]message.ToolResult, len(toolCalls))
|
||||
mutex := &sync.Mutex{}
|
||||
|
||||
for i, tc := range toolCalls {
|
||||
wg.Add(1)
|
||||
go func(index int, toolCall message.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
response := ""
|
||||
isError := false
|
||||
found := false
|
||||
|
||||
for _, tool := range tls {
|
||||
if tool.Info().Name == toolCall.Name {
|
||||
found = true
|
||||
toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
||||
ID: toolCall.ID,
|
||||
Name: toolCall.Name,
|
||||
Input: toolCall.Input,
|
||||
})
|
||||
if toolErr != nil {
|
||||
response = fmt.Sprintf("error running tool: %s", toolErr)
|
||||
isError = true
|
||||
} else {
|
||||
response = toolResult.Content
|
||||
isError = toolResult.IsError
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
response = fmt.Sprintf("tool not found: %s", toolCall.Name)
|
||||
isError = true
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
toolResults[index] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: response,
|
||||
IsError: isError,
|
||||
}
|
||||
}(i, tc)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
func (c *agent) handleToolExecution(
|
||||
ctx context.Context,
|
||||
assistantMsg message.Message,
|
||||
) (*message.Message, error) {
|
||||
if len(assistantMsg.ToolCalls) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
|
||||
Role: message.Tool,
|
||||
ToolResults: toolResults,
|
||||
})
|
||||
|
||||
return &msg, err
|
||||
}
|
||||
|
||||
func (c *agent) generate(sessionID string, content string) error {
|
||||
messages, err := c.Messages.List(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
go c.handleTitleGeneration(sessionID, content)
|
||||
}
|
||||
|
||||
userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
|
||||
Role: message.User,
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messages = append(messages, userMsg)
|
||||
for {
|
||||
|
||||
eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
|
||||
Role: message.Assistant,
|
||||
Content: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for event := range eventChan {
|
||||
err = c.processEvent(sessionID, &assistantMsg, event)
|
||||
if err != nil {
|
||||
assistantMsg.Finished = true
|
||||
c.Messages.Update(assistantMsg)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
msg, err := c.handleToolExecution(c.Context, assistantMsg)
|
||||
assistantMsg.Finished = true
|
||||
c.Messages.Update(assistantMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(assistantMsg.ToolCalls) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
messages = append(messages, assistantMsg)
|
||||
if msg != nil {
|
||||
messages = append(messages, *msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
|
||||
maxTokens := config.Get().Model.CoderMaxTokens
|
||||
|
||||
providerConfig, ok := config.Get().Providers[model.Provider]
|
||||
if !ok || !providerConfig.Enabled {
|
||||
return nil, nil, errors.New("provider is not enabled")
|
||||
}
|
||||
var agentProvider provider.Provider
|
||||
var titleGenerator provider.Provider
|
||||
|
||||
switch model.Provider {
|
||||
case models.ProviderOpenAI:
|
||||
var err error
|
||||
agentProvider, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.CoderOpenAISystemPrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(maxTokens),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(80),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case models.ProviderAnthropic:
|
||||
var err error
|
||||
agentProvider, err = provider.NewAnthropicProvider(
|
||||
provider.WithAnthropicSystemMessage(
|
||||
prompt.CoderAnthropicSystemPrompt(),
|
||||
),
|
||||
provider.WithAnthropicMaxTokens(maxTokens),
|
||||
provider.WithAnthropicKey(providerConfig.APIKey),
|
||||
provider.WithAnthropicModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewAnthropicProvider(
|
||||
provider.WithAnthropicSystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithAnthropicMaxTokens(80),
|
||||
provider.WithAnthropicKey(providerConfig.APIKey),
|
||||
provider.WithAnthropicModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
case models.ProviderGemini:
|
||||
var err error
|
||||
agentProvider, err = provider.NewGeminiProvider(
|
||||
ctx,
|
||||
provider.WithGeminiSystemMessage(
|
||||
prompt.CoderOpenAISystemPrompt(),
|
||||
),
|
||||
provider.WithGeminiMaxTokens(int32(maxTokens)),
|
||||
provider.WithGeminiKey(providerConfig.APIKey),
|
||||
provider.WithGeminiModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewGeminiProvider(
|
||||
ctx,
|
||||
provider.WithGeminiSystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithGeminiMaxTokens(80),
|
||||
provider.WithGeminiKey(providerConfig.APIKey),
|
||||
provider.WithGeminiModel(model),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case models.ProviderGROQ:
|
||||
var err error
|
||||
agentProvider, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.CoderAnthropicSystemPrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(maxTokens),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
titleGenerator, err = provider.NewOpenAIProvider(
|
||||
provider.WithOpenAISystemMessage(
|
||||
prompt.TitlePrompt(),
|
||||
),
|
||||
provider.WithOpenAIMaxTokens(80),
|
||||
provider.WithOpenAIModel(model),
|
||||
provider.WithOpenAIKey(providerConfig.APIKey),
|
||||
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return agentProvider, titleGenerator, nil
|
||||
}
|
||||
|
|
|
@ -1,182 +1,67 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func coderTools() []tool.BaseTool {
|
||||
wd := viper.GetString("wd")
|
||||
return []tool.BaseTool{
|
||||
tools.NewAgentTool(wd),
|
||||
tools.NewBashTool(wd),
|
||||
tools.NewLsTool(wd),
|
||||
tools.NewGlobTool(wd),
|
||||
tools.NewViewTool(wd),
|
||||
tools.NewWriteTool(wd),
|
||||
tools.NewEditTool(wd),
|
||||
type coderAgent struct {
|
||||
*agent
|
||||
}
|
||||
|
||||
func (c *coderAgent) setAgentTool(sessionID string) {
|
||||
inx := -1
|
||||
for i, tool := range c.tools {
|
||||
if tool.Info().Name == AgentToolName {
|
||||
inx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if inx == -1 {
|
||||
c.tools = append(c.tools, NewAgentTool(sessionID, c.App))
|
||||
} else {
|
||||
c.tools[inx] = NewAgentTool(sessionID, c.App)
|
||||
}
|
||||
}
|
||||
|
||||
func NewCoderAgent(ctx context.Context) (*react.Agent, error) {
|
||||
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.big")))
|
||||
func (c *coderAgent) Generate(sessionID string, content string) error {
|
||||
c.setAgentTool(sessionID)
|
||||
return c.generate(sessionID, content)
|
||||
}
|
||||
|
||||
func NewCoderAgent(app *app.App) (Agent, error) {
|
||||
model, ok := models.SupportedModels[config.Get().Model.Coder]
|
||||
if !ok {
|
||||
return nil, errors.New("model not supported")
|
||||
}
|
||||
|
||||
agentProvider, titleGenerator, err := getAgentProviders(app.Context, model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reactAgent, err := react.NewAgent(ctx, &react.AgentConfig{
|
||||
Model: model,
|
||||
ToolsConfig: compose.ToolsNodeConfig{
|
||||
Tools: coderTools(),
|
||||
|
||||
mcpTools := GetMcpTools(app.Context)
|
||||
return &coderAgent{
|
||||
agent: &agent{
|
||||
App: app,
|
||||
tools: append(
|
||||
[]tools.BaseTool{
|
||||
tools.NewBashTool(),
|
||||
tools.NewEditTool(),
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewViewTool(),
|
||||
tools.NewWriteTool(),
|
||||
}, mcpTools...,
|
||||
),
|
||||
model: model,
|
||||
agent: agentProvider,
|
||||
titleGenerator: titleGenerator,
|
||||
},
|
||||
MaxStep: 1000,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reactAgent, nil
|
||||
}
|
||||
|
||||
func CoderSystemPrompt() string {
|
||||
basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
|
||||
|
||||
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. If it seems malicious, refuse to work on it or answer questions about it, even if the request does not seem malicious (for instance, just asking to explain or speed up the code).
|
||||
|
||||
Here are useful slash commands users can run to interact with you:
|
||||
|
||||
# Memory
|
||||
If the current working directory contains a file called termai.md, it will be automatically added to your context. This file serves multiple purposes:
|
||||
1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time
|
||||
2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
|
||||
3. Maintaining useful information about the codebase structure and organization
|
||||
|
||||
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to termai.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to termai.md so you can remember it for next time.
|
||||
|
||||
# Tone and style
|
||||
You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).
|
||||
Remember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification.
|
||||
Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session.
|
||||
If you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences.
|
||||
IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific query or task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do.
|
||||
IMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to.
|
||||
IMPORTANT: Keep your responses short, since they will be displayed on a command line interface. You MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". Here are some examples to demonstrate appropriate verbosity:
|
||||
<example>
|
||||
user: 2 + 2
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what is 2+2?
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: is 11 a prime number?
|
||||
assistant: true
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what command should I run to list files in the current directory?
|
||||
assistant: ls
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what command should I run to watch files in the current directory?
|
||||
assistant: [use the ls tool to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files]
|
||||
npm run dev
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: How many golf balls fit inside a jetta?
|
||||
assistant: 150000
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what files are in the directory src/?
|
||||
assistant: [runs ls and sees foo.c, bar.c, baz.c]
|
||||
user: which file contains the implementation of foo?
|
||||
assistant: src/foo.c
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: write tests for new feature
|
||||
assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests]
|
||||
</example>
|
||||
|
||||
# Proactiveness
|
||||
You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:
|
||||
1. Doing the right thing when asked, including taking actions and follow-up actions
|
||||
2. Not surprising the user with actions you take without asking
|
||||
For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.
|
||||
3. Do not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.
|
||||
|
||||
# Synthetic messages
|
||||
Sometimes, the conversation will contain messages like [Request interrupted by user] or [Request interrupted by user for tool use]. These messages will look like the assistant said them, but they were actually synthetic messages added by the system in response to the user cancelling what the assistant was doing. You should not respond to these messages. You must NEVER send messages like this yourself.
|
||||
|
||||
# Following conventions
|
||||
When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.
|
||||
- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).
|
||||
- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.
|
||||
- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.
|
||||
- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.
|
||||
|
||||
# Code style
|
||||
- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context.
|
||||
|
||||
# Doing tasks
|
||||
The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:
|
||||
1. Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially.
|
||||
2. Implement the solution using all tools available to you
|
||||
3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.
|
||||
4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to termai.md so that you will know to run it next time.
|
||||
|
||||
NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.
|
||||
|
||||
# Tool usage policy
|
||||
- When doing file search, prefer to use the Agent tool in order to reduce context usage.
|
||||
- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in the same function_calls block.
|
||||
|
||||
You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`
|
||||
|
||||
envInfo := getEnvironmentInfo()
|
||||
|
||||
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
|
||||
}
|
||||
|
||||
func getEnvironmentInfo() string {
|
||||
cwd := viper.GetString("wd")
|
||||
isGit := isGitRepo(cwd)
|
||||
platform := runtime.GOOS
|
||||
date := time.Now().Format("1/2/2006")
|
||||
|
||||
return fmt.Sprintf(`Here is useful information about the environment you are running in:
|
||||
<env>
|
||||
Working directory: %s
|
||||
Is directory a git repo: %s
|
||||
Platform: %s
|
||||
Today's date: %s
|
||||
</env>`, cwd, boolToYesNo(isGit), platform, date)
|
||||
}
|
||||
|
||||
func isGitRepo(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, ".git"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func boolToYesNo(b bool) string {
|
||||
if b {
|
||||
return "Yes"
|
||||
}
|
||||
return "No"
|
||||
}, nil
|
||||
}
|
||||
|
|
190
internal/llm/agent/mcp-tools.go
Normal file
190
internal/llm/agent/mcp-tools.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/kujtimiihoxha/termai/internal/version"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
type mcpTool struct {
|
||||
mcpName string
|
||||
tool mcp.Tool
|
||||
mcpConfig config.MCPServer
|
||||
}
|
||||
|
||||
type MCPClient interface {
|
||||
Initialize(
|
||||
ctx context.Context,
|
||||
request mcp.InitializeRequest,
|
||||
) (*mcp.InitializeResult, error)
|
||||
ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
|
||||
CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (b *mcpTool) Info() tools.ToolInfo {
|
||||
return tools.ToolInfo{
|
||||
Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
|
||||
Description: b.tool.Description,
|
||||
Parameters: b.tool.InputSchema.Properties,
|
||||
Required: b.tool.InputSchema.Required,
|
||||
}
|
||||
}
|
||||
|
||||
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
|
||||
defer c.Close()
|
||||
initRequest := mcp.InitializeRequest{}
|
||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
||||
Name: "termai",
|
||||
Version: version.Version,
|
||||
}
|
||||
|
||||
_, err := c.Initialize(ctx, initRequest)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
toolRequest := mcp.CallToolRequest{}
|
||||
toolRequest.Params.Name = toolName
|
||||
var args map[string]any
|
||||
if err = json.Unmarshal([]byte(input), &input); err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
toolRequest.Params.Arguments = args
|
||||
result, err := c.CallTool(ctx, toolRequest)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
output := ""
|
||||
for _, v := range result.Content {
|
||||
if v, ok := v.(mcp.TextContent); ok {
|
||||
output = v.Text
|
||||
} else {
|
||||
output = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
return tools.NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
|
||||
permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
|
||||
p := permission.Default.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
Path: config.WorkingDirectory(),
|
||||
ToolName: b.Info().Name,
|
||||
Action: "execute",
|
||||
Description: permissionDescription,
|
||||
Params: params.Input,
|
||||
},
|
||||
)
|
||||
if !p {
|
||||
return tools.NewTextErrorResponse("permission denied"), nil
|
||||
}
|
||||
|
||||
switch b.mcpConfig.Type {
|
||||
case config.MCPStdio:
|
||||
c, err := client.NewStdioMCPClient(
|
||||
b.mcpConfig.Command,
|
||||
b.mcpConfig.Env,
|
||||
b.mcpConfig.Args...,
|
||||
)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return runTool(ctx, c, b.tool.Name, params.Input)
|
||||
case config.MCPSse:
|
||||
c, err := client.NewSSEMCPClient(
|
||||
b.mcpConfig.URL,
|
||||
client.WithHeaders(b.mcpConfig.Headers),
|
||||
)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return runTool(ctx, c, b.tool.Name, params.Input)
|
||||
}
|
||||
|
||||
return tools.NewTextErrorResponse("invalid mcp type"), nil
|
||||
}
|
||||
|
||||
func NewMcpTool(name string, tool mcp.Tool, mcpConfig config.MCPServer) tools.BaseTool {
|
||||
return &mcpTool{
|
||||
mcpName: name,
|
||||
tool: tool,
|
||||
mcpConfig: mcpConfig,
|
||||
}
|
||||
}
|
||||
|
||||
var mcpTools []tools.BaseTool
|
||||
|
||||
func getTools(ctx context.Context, name string, m config.MCPServer, c MCPClient) []tools.BaseTool {
|
||||
var stdioTools []tools.BaseTool
|
||||
initRequest := mcp.InitializeRequest{}
|
||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
||||
Name: "termai",
|
||||
Version: version.Version,
|
||||
}
|
||||
|
||||
_, err := c.Initialize(ctx, initRequest)
|
||||
if err != nil {
|
||||
log.Printf("error initializing mcp client: %s", err)
|
||||
return stdioTools
|
||||
}
|
||||
toolsRequest := mcp.ListToolsRequest{}
|
||||
tools, err := c.ListTools(ctx, toolsRequest)
|
||||
if err != nil {
|
||||
log.Printf("error listing tools: %s", err)
|
||||
return stdioTools
|
||||
}
|
||||
for _, t := range tools.Tools {
|
||||
stdioTools = append(stdioTools, NewMcpTool(name, t, m))
|
||||
}
|
||||
defer c.Close()
|
||||
return stdioTools
|
||||
}
|
||||
|
||||
func GetMcpTools(ctx context.Context) []tools.BaseTool {
|
||||
if len(mcpTools) > 0 {
|
||||
return mcpTools
|
||||
}
|
||||
for name, m := range config.Get().MCPServers {
|
||||
switch m.Type {
|
||||
case config.MCPStdio:
|
||||
c, err := client.NewStdioMCPClient(
|
||||
m.Command,
|
||||
m.Env,
|
||||
m.Args...,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("error creating mcp client: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
|
||||
case config.MCPSse:
|
||||
c, err := client.NewSSEMCPClient(
|
||||
m.URL,
|
||||
client.WithHeaders(m.Headers),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("error creating mcp client: %s", err)
|
||||
continue
|
||||
}
|
||||
mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
|
||||
}
|
||||
}
|
||||
|
||||
return mcpTools
|
||||
}
|
44
internal/llm/agent/task.go
Normal file
44
internal/llm/agent/task.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
)
|
||||
|
||||
type taskAgent struct {
|
||||
*agent
|
||||
}
|
||||
|
||||
func (c *taskAgent) Generate(sessionID string, content string) error {
|
||||
return c.generate(sessionID, content)
|
||||
}
|
||||
|
||||
func NewTaskAgent(app *app.App) (Agent, error) {
|
||||
model, ok := models.SupportedModels[config.Get().Model.Coder]
|
||||
if !ok {
|
||||
return nil, errors.New("model not supported")
|
||||
}
|
||||
|
||||
agentProvider, titleGenerator, err := getAgentProviders(app.Context, model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &taskAgent{
|
||||
agent: &agent{
|
||||
App: app,
|
||||
tools: []tools.BaseTool{
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewViewTool(),
|
||||
},
|
||||
model: model,
|
||||
agent: agentProvider,
|
||||
titleGenerator: titleGenerator,
|
||||
},
|
||||
}, nil
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func GenerateTitle(ctx context.Context, content string) (string, error) {
|
||||
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
out, err := model.Generate(
|
||||
ctx,
|
||||
[]*schema.Message{
|
||||
schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
|
||||
- ensure it is not more than 80 characters long
|
||||
- the title should be a summary of the user's message
|
||||
- do not use quotes or colons
|
||||
- the entire text you return will be used as the title`),
|
||||
schema.UserMessage(content),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return out.Content, nil
|
||||
}
|
|
@ -1,229 +0,0 @@
|
|||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/logging"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
"github.com/kujtimiihoxha/termai/internal/session"
|
||||
|
||||
eModel "github.com/cloudwego/eino/components/model"
|
||||
enioAgent "github.com/cloudwego/eino/flow/agent"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
AgentRequestoEvent pubsub.EventType = "agent_request"
|
||||
AgentErrorEvent pubsub.EventType = "agent_error"
|
||||
AgentResponseEvent pubsub.EventType = "agent_response"
|
||||
)
|
||||
|
||||
type AgentMessageType int
|
||||
|
||||
const (
|
||||
AgentMessageTypeNewUserMessage AgentMessageType = iota
|
||||
AgentMessageTypeAgentResponse
|
||||
AgentMessageTypeError
|
||||
)
|
||||
|
||||
type agentID string
|
||||
|
||||
const (
|
||||
RootAgent agentID = "root"
|
||||
TaskAgent agentID = "task"
|
||||
)
|
||||
|
||||
type AgentEvent struct {
|
||||
ID string `json:"id"`
|
||||
Type AgentMessageType `json:"type"`
|
||||
AgentID agentID `json:"agent_id"`
|
||||
MessageID string `json:"message_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[AgentEvent]
|
||||
|
||||
SendRequest(sessionID string, content string)
|
||||
}
|
||||
type service struct {
|
||||
*pubsub.Broker[AgentEvent]
|
||||
Requests sync.Map
|
||||
ctx context.Context
|
||||
activeRequests sync.Map
|
||||
messages message.Service
|
||||
sessions session.Service
|
||||
logger logging.Interface
|
||||
}
|
||||
|
||||
func (s *service) handleRequest(id string, sessionID string, content string) {
|
||||
cancel, ok := s.activeRequests.Load(id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer cancel.(context.CancelFunc)()
|
||||
defer s.activeRequests.Delete(id)
|
||||
|
||||
history, err := s.messages.List(sessionID)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
messages := []*schema.Message{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: systemMessage,
|
||||
},
|
||||
}
|
||||
for _, m := range history {
|
||||
messages = append(messages, &m.MessageData)
|
||||
}
|
||||
|
||||
builder := callbacks.NewHandlerBuilder()
|
||||
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
i, ok := input.(*eModel.CallbackInput)
|
||||
if info.Component == "ChatModel" && ok {
|
||||
if len(messages) < len(i.Messages) {
|
||||
// find new messages
|
||||
newMessages := i.Messages[len(messages):]
|
||||
for _, m := range newMessages {
|
||||
_, err = s.messages.Create(sessionID, *m)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx
|
||||
})
|
||||
builder.OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||
return ctx
|
||||
})
|
||||
|
||||
out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
usage := out.ResponseMeta.Usage
|
||||
s.messages.Create(sessionID, *out)
|
||||
if usage != nil {
|
||||
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
|
||||
session, err := s.sessions.Get(sessionID)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
session.PromptTokens += int64(usage.PromptTokens)
|
||||
session.CompletionTokens += int64(usage.CompletionTokens)
|
||||
model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
|
||||
session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
|
||||
float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
|
||||
var newTitle string
|
||||
if len(history) == 1 {
|
||||
// first message generate the title
|
||||
newTitle, err = agent.GenerateTitle(s.ctx, content)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if newTitle != "" {
|
||||
session.Title = newTitle
|
||||
}
|
||||
|
||||
_, err = s.sessions.Save(session)
|
||||
if err != nil {
|
||||
s.Publish(AgentErrorEvent, AgentEvent{
|
||||
ID: id,
|
||||
Type: AgentMessageTypeError,
|
||||
AgentID: RootAgent,
|
||||
MessageID: "",
|
||||
SessionID: sessionID,
|
||||
Content: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) SendRequest(sessionID string, content string) {
|
||||
id := uuid.New().String()
|
||||
|
||||
_, cancel := context.WithTimeout(s.ctx, 5*time.Minute)
|
||||
s.activeRequests.Store(id, cancel)
|
||||
log.Printf("Request: %s", content)
|
||||
go s.handleRequest(id, sessionID, content)
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger logging.Interface, sessions session.Service, messages message.Service) Service {
|
||||
return &service{
|
||||
Broker: pubsub.NewBroker[AgentEvent](),
|
||||
ctx: ctx,
|
||||
sessions: sessions,
|
||||
messages: messages,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
|
@ -1,230 +1,122 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/claude"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type (
|
||||
ModelID string
|
||||
ModelProvider string
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ID ModelID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
APIModel string `json:"api_model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
ID ModelID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
APIModel string `json:"api_model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
|
||||
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultBigModel = Claude37Sonnet
|
||||
DefaultLittleModel = Claude37Sonnet
|
||||
)
|
||||
|
||||
// Model IDs
|
||||
const (
|
||||
// OpenAI
|
||||
GPT4o ModelID = "gpt-4o"
|
||||
GPT4oMini ModelID = "gpt-4o-mini"
|
||||
GPT45 ModelID = "gpt-4.5"
|
||||
O1 ModelID = "o1"
|
||||
O1Mini ModelID = "o1-mini"
|
||||
// Anthropic
|
||||
Claude35Sonnet ModelID = "claude-3.5-sonnet"
|
||||
Claude3Haiku ModelID = "claude-3-haiku"
|
||||
Claude37Sonnet ModelID = "claude-3.7-sonnet"
|
||||
// Google
|
||||
Gemini20Pro ModelID = "gemini-2.0-pro"
|
||||
Gemini15Flash ModelID = "gemini-1.5-flash"
|
||||
Gemini20Flash ModelID = "gemini-2.0-flash"
|
||||
// xAI
|
||||
Grok3 ModelID = "grok-3"
|
||||
Grok2Mini ModelID = "grok-2-mini"
|
||||
// DeepSeek
|
||||
DeepSeekR1 ModelID = "deepseek-r1"
|
||||
DeepSeekCoder ModelID = "deepseek-coder"
|
||||
// Meta
|
||||
Llama3 ModelID = "llama-3"
|
||||
Llama270B ModelID = "llama-2-70b"
|
||||
// OpenAI
|
||||
GPT4o ModelID = "gpt-4o"
|
||||
|
||||
// GEMINI
|
||||
GEMINI25 ModelID = "gemini-2.5"
|
||||
GRMINI20Flash ModelID = "gemini-2.0-flash"
|
||||
|
||||
// GROQ
|
||||
GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
|
||||
GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
|
||||
QWENQwq ModelID = "qwen-qwq"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderOpenAI ModelProvider = "openai"
|
||||
ProviderAnthropic ModelProvider = "anthropic"
|
||||
ProviderGoogle ModelProvider = "google"
|
||||
ProviderXAI ModelProvider = "xai"
|
||||
ProviderDeepSeek ModelProvider = "deepseek"
|
||||
ProviderMeta ModelProvider = "meta"
|
||||
ProviderGroq ModelProvider = "groq"
|
||||
ProviderGemini ModelProvider = "gemini"
|
||||
ProviderGROQ ModelProvider = "groq"
|
||||
)
|
||||
|
||||
var SupportedModels = map[ModelID]Model{
|
||||
// OpenAI
|
||||
GPT4o: {
|
||||
ID: GPT4o,
|
||||
Name: "GPT-4o",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o",
|
||||
},
|
||||
GPT4oMini: {
|
||||
ID: GPT4oMini,
|
||||
Name: "GPT-4o Mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: 0.150,
|
||||
CostPer1MOut: 0.600,
|
||||
},
|
||||
GPT45: {
|
||||
ID: GPT45,
|
||||
Name: "GPT-4.5",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.5",
|
||||
},
|
||||
O1: {
|
||||
ID: O1,
|
||||
Name: "o1",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1",
|
||||
},
|
||||
O1Mini: {
|
||||
ID: O1Mini,
|
||||
Name: "o1 Mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1-mini",
|
||||
},
|
||||
// Anthropic
|
||||
Claude35Sonnet: {
|
||||
ID: Claude35Sonnet,
|
||||
Name: "Claude 3.5 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3.5-sonnet",
|
||||
ID: Claude35Sonnet,
|
||||
Name: "Claude 3.5 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-5-sonnet-latest",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
},
|
||||
Claude3Haiku: {
|
||||
ID: Claude3Haiku,
|
||||
Name: "Claude 3 Haiku",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-haiku",
|
||||
ID: Claude3Haiku,
|
||||
Name: "Claude 3 Haiku",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-haiku-latest",
|
||||
CostPer1MIn: 0.80,
|
||||
CostPer1MInCached: 1,
|
||||
CostPer1MOutCached: 0.08,
|
||||
CostPer1MOut: 4,
|
||||
},
|
||||
Claude37Sonnet: {
|
||||
ID: Claude37Sonnet,
|
||||
Name: "Claude 3.7 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-7-sonnet-20250219",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MOut: 15.0,
|
||||
ID: Claude37Sonnet,
|
||||
Name: "Claude 3.7 Sonnet",
|
||||
Provider: ProviderAnthropic,
|
||||
APIModel: "claude-3-7-sonnet-latest",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
},
|
||||
// Google
|
||||
Gemini20Pro: {
|
||||
ID: Gemini20Pro,
|
||||
Name: "Gemini 2.0 Pro",
|
||||
Provider: ProviderGoogle,
|
||||
APIModel: "gemini-2.0-pro",
|
||||
|
||||
// OpenAI
|
||||
GPT4o: {
|
||||
ID: GPT4o,
|
||||
Name: "GPT-4o",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o",
|
||||
CostPer1MIn: 2.50,
|
||||
CostPer1MInCached: 1.25,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 10.00,
|
||||
},
|
||||
Gemini15Flash: {
|
||||
ID: Gemini15Flash,
|
||||
Name: "Gemini 1.5 Flash",
|
||||
Provider: ProviderGoogle,
|
||||
APIModel: "gemini-1.5-flash",
|
||||
|
||||
// GEMINI
|
||||
GEMINI25: {
|
||||
ID: GEMINI25,
|
||||
Name: "Gemini 2.5 Pro",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.5-pro-exp-03-25",
|
||||
CostPer1MIn: 0,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0,
|
||||
},
|
||||
Gemini20Flash: {
|
||||
ID: Gemini20Flash,
|
||||
Name: "Gemini 2.0 Flash",
|
||||
Provider: ProviderGoogle,
|
||||
APIModel: "gemini-2.0-flash",
|
||||
},
|
||||
// xAI
|
||||
Grok3: {
|
||||
ID: Grok3,
|
||||
Name: "Grok 3",
|
||||
Provider: ProviderXAI,
|
||||
APIModel: "grok-3",
|
||||
},
|
||||
Grok2Mini: {
|
||||
ID: Grok2Mini,
|
||||
Name: "Grok 2 Mini",
|
||||
Provider: ProviderXAI,
|
||||
APIModel: "grok-2-mini",
|
||||
},
|
||||
// DeepSeek
|
||||
DeepSeekR1: {
|
||||
ID: DeepSeekR1,
|
||||
Name: "DeepSeek R1",
|
||||
Provider: ProviderDeepSeek,
|
||||
APIModel: "deepseek-r1",
|
||||
},
|
||||
DeepSeekCoder: {
|
||||
ID: DeepSeekCoder,
|
||||
Name: "DeepSeek Coder",
|
||||
Provider: ProviderDeepSeek,
|
||||
APIModel: "deepseek-coder",
|
||||
},
|
||||
// Meta
|
||||
Llama3: {
|
||||
ID: Llama3,
|
||||
Name: "LLaMA 3",
|
||||
Provider: ProviderMeta,
|
||||
APIModel: "llama-3",
|
||||
},
|
||||
Llama270B: {
|
||||
ID: Llama270B,
|
||||
Name: "LLaMA 2 70B",
|
||||
Provider: ProviderMeta,
|
||||
APIModel: "llama-2-70b",
|
||||
|
||||
GRMINI20Flash: {
|
||||
ID: GRMINI20Flash,
|
||||
Name: "Gemini 2.0 Flash",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.0-flash",
|
||||
CostPer1MIn: 0.1,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0.025,
|
||||
CostPer1MOut: 0.4,
|
||||
},
|
||||
|
||||
// GROQ
|
||||
GroqLlama3SpecDec: {
|
||||
ID: GroqLlama3SpecDec,
|
||||
Name: "GROQ LLaMA 3 SpecDec",
|
||||
Provider: ProviderGroq,
|
||||
APIModel: "llama-3.3-70b-specdec",
|
||||
},
|
||||
GroqQwen32BCoder: {
|
||||
ID: GroqQwen32BCoder,
|
||||
Name: "GROQ Qwen 2.5 Coder 32B",
|
||||
Provider: ProviderGroq,
|
||||
APIModel: "qwen-2.5-coder-32b",
|
||||
QWENQwq: {
|
||||
ID: QWENQwq,
|
||||
Name: "Qwen Qwq",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "qwen-qwq-32b",
|
||||
CostPer1MIn: 0,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0,
|
||||
},
|
||||
}
|
||||
|
||||
func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
|
||||
provider := SupportedModels[model].Provider
|
||||
log.Printf("Provider: %s", provider)
|
||||
maxTokens := viper.GetInt("providers.common.max_tokens")
|
||||
switch provider {
|
||||
case ProviderOpenAI:
|
||||
return openai.NewChatModel(ctx, &openai.ChatModelConfig{
|
||||
APIKey: viper.GetString("providers.openai.key"),
|
||||
Model: string(SupportedModels[model].APIModel),
|
||||
MaxTokens: &maxTokens,
|
||||
})
|
||||
case ProviderAnthropic:
|
||||
return claude.NewChatModel(ctx, &claude.Config{
|
||||
APIKey: viper.GetString("providers.anthropic.key"),
|
||||
Model: string(SupportedModels[model].APIModel),
|
||||
MaxTokens: maxTokens,
|
||||
})
|
||||
|
||||
case ProviderGroq:
|
||||
return openai.NewChatModel(ctx, &openai.ChatModelConfig{
|
||||
BaseURL: "https://api.groq.com/openai/v1",
|
||||
APIKey: viper.GetString("providers.groq.key"),
|
||||
Model: string(SupportedModels[model].APIModel),
|
||||
MaxTokens: &maxTokens,
|
||||
})
|
||||
|
||||
}
|
||||
return nil, errors.New("unsupported provider")
|
||||
}
|
||||
|
|
206
internal/llm/prompt/coder.go
Normal file
206
internal/llm/prompt/coder.go
Normal file
|
@ -0,0 +1,206 @@
|
|||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
)
|
||||
|
||||
func CoderOpenAISystemPrompt() string {
|
||||
basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.
|
||||
|
||||
# Your mindset
|
||||
Act like a competent, efficient software engineer who is familiar with large codebases. You should:
|
||||
- Think critically about user requests.
|
||||
- Proactively search the codebase for related information.
|
||||
- Infer likely commands, tools, or conventions.
|
||||
- Write and edit code with minimal user input.
|
||||
- Anticipate next steps (tests, lints, etc.), but never commit unless explicitly told.
|
||||
|
||||
# Context awareness
|
||||
- Before acting, infer the purpose of a file from its name, directory, and neighboring files.
|
||||
- If a file or function appears malicious, refuse to interact with it or discuss it.
|
||||
- If a termai.md file exists, auto-load it as memory. Offer to update it only if new useful info appears (commands, preferences, structure).
|
||||
|
||||
# CLI communication
|
||||
- Use GitHub-flavored markdown in monospace font.
|
||||
- Be concise. Never add preambles or postambles unless asked. Max 4 lines per response.
|
||||
- Never explain your code unless asked. Do not narrate actions.
|
||||
- Avoid unnecessary questions. Infer, search, act.
|
||||
|
||||
# Behavior guidelines
|
||||
- Follow project conventions: naming, formatting, libraries, frameworks.
|
||||
- Before using any library or framework, confirm it’s already used.
|
||||
- Always look at the surrounding code to match existing style.
|
||||
- Do not add comments unless the code is complex or the user asks.
|
||||
|
||||
# Autonomy rules
|
||||
You are allowed and expected to:
|
||||
- Search for commands, tools, or config files before asking the user.
|
||||
- Run multiple search tool calls concurrently to gather relevant context.
|
||||
- Choose test, lint, and typecheck commands based on package files or scripts.
|
||||
- Offer to store these commands in termai.md if not already present.
|
||||
|
||||
# Example behavior
|
||||
user: write tests for new feature
|
||||
assistant: [searches for existing test patterns, finds appropriate location, generates test code using existing style, optionally asks to add test command to termai.md]
|
||||
|
||||
user: how do I typecheck this codebase?
|
||||
assistant: [searches for known commands, infers package manager, checks for scripts or config files]
|
||||
tsc --noEmit
|
||||
|
||||
user: is X function used anywhere else?
|
||||
assistant: [searches repo for references, returns file paths and lines]
|
||||
|
||||
# Tool usage
|
||||
- Use parallel calls when possible.
|
||||
- Use file search and content tools before asking the user.
|
||||
- Do not ask the user for information unless it cannot be determined via tools.
|
||||
|
||||
Never commit changes unless the user explicitly asks you to.`
|
||||
|
||||
envInfo := getEnvironmentInfo()
|
||||
|
||||
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
|
||||
}
|
||||
|
||||
func CoderAnthropicSystemPrompt() string {
|
||||
basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
|
||||
|
||||
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure.
|
||||
|
||||
# Memory
|
||||
If the current working directory contains a file called termai.md, it will be automatically added to your context. This file serves multiple purposes:
|
||||
1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time
|
||||
2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.)
|
||||
3. Maintaining useful information about the codebase structure and organization
|
||||
|
||||
When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to termai.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to termai.md so you can remember it for next time.
|
||||
|
||||
# Tone and style
|
||||
You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).
|
||||
Remember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification.
|
||||
Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session.
|
||||
If you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences.
|
||||
IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific query or task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do.
|
||||
IMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to.
|
||||
IMPORTANT: Keep your responses short, since they will be displayed on a command line interface. You MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". Here are some examples to demonstrate appropriate verbosity:
|
||||
<example>
|
||||
user: 2 + 2
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what is 2+2?
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: is 11 a prime number?
|
||||
assistant: true
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what command should I run to list files in the current directory?
|
||||
assistant: ls
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what command should I run to watch files in the current directory?
|
||||
assistant: [use the ls tool to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files]
|
||||
npm run dev
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: How many golf balls fit inside a jetta?
|
||||
assistant: 150000
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what files are in the directory src/?
|
||||
assistant: [runs ls and sees foo.c, bar.c, baz.c]
|
||||
user: which file contains the implementation of foo?
|
||||
assistant: src/foo.c
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: write tests for new feature
|
||||
assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests]
|
||||
</example>
|
||||
|
||||
# Proactiveness
|
||||
You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:
|
||||
1. Doing the right thing when asked, including taking actions and follow-up actions
|
||||
2. Not surprising the user with actions you take without asking
|
||||
For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.
|
||||
3. Do not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.
|
||||
|
||||
# Following conventions
|
||||
When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.
|
||||
- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).
|
||||
- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.
|
||||
- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.
|
||||
- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.
|
||||
|
||||
# Code style
|
||||
- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context.
|
||||
|
||||
# Doing tasks
|
||||
The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:
|
||||
1. Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially.
|
||||
2. Implement the solution using all tools available to you
|
||||
3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.
|
||||
4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to termai.md so that you will know to run it next time.
|
||||
|
||||
NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.
|
||||
|
||||
# Tool usage policy
|
||||
- When doing file search, prefer to use the Agent tool in order to reduce context usage.
|
||||
- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in the same function_calls block.
|
||||
|
||||
You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`
|
||||
|
||||
envInfo := getEnvironmentInfo()
|
||||
|
||||
return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
|
||||
}
|
||||
|
||||
func getEnvironmentInfo() string {
|
||||
cwd := config.WorkingDirectory()
|
||||
isGit := isGitRepo(cwd)
|
||||
platform := runtime.GOOS
|
||||
date := time.Now().Format("1/2/2006")
|
||||
ls := tools.NewLsTool()
|
||||
r, _ := ls.Run(context.Background(), tools.ToolCall{
|
||||
Input: `{"path":"."}`,
|
||||
})
|
||||
return fmt.Sprintf(`Here is useful information about the environment you are running in:
|
||||
<env>
|
||||
Working directory: %s
|
||||
Is directory a git repo: %s
|
||||
Platform: %s
|
||||
Today's date: %s
|
||||
</env>
|
||||
<project>
|
||||
%s
|
||||
</project>
|
||||
`, cwd, boolToYesNo(isGit), platform, date, r.Content)
|
||||
}
|
||||
|
||||
func isGitRepo(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, ".git"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func boolToYesNo(b bool) string {
|
||||
if b {
|
||||
return "Yes"
|
||||
}
|
||||
return "No"
|
||||
}
|
16
internal/llm/prompt/task.go
Normal file
16
internal/llm/prompt/task.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package prompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func TaskAgentSystemPrompt() string {
|
||||
agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question.
|
||||
|
||||
Notes:
|
||||
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
2. When relevant, share file names and code snippets relevant to the query
|
||||
3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.`
|
||||
|
||||
return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo())
|
||||
}
|
9
internal/llm/prompt/title.go
Normal file
9
internal/llm/prompt/title.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package prompt
|
||||
|
||||
func TitlePrompt() string {
|
||||
return `you will generate a short title based on the first message a user begins a conversation with
|
||||
- ensure it is not more than 50 characters long
|
||||
- the title should be a summary of the user's message
|
||||
- do not use quotes or colons
|
||||
- the entire text you return will be used as the title`
|
||||
}
|
309
internal/llm/provider/anthropic.go
Normal file
309
internal/llm/provider/anthropic.go
Normal file
|
@ -0,0 +1,309 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
)
|
||||
|
||||
type anthropicProvider struct {
|
||||
client anthropic.Client
|
||||
model models.Model
|
||||
maxTokens int64
|
||||
apiKey string
|
||||
systemMessage string
|
||||
}
|
||||
|
||||
type AnthropicOption func(*anthropicProvider)
|
||||
|
||||
func WithAnthropicSystemMessage(message string) AnthropicOption {
|
||||
return func(a *anthropicProvider) {
|
||||
a.systemMessage = message
|
||||
}
|
||||
}
|
||||
|
||||
func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
|
||||
return func(a *anthropicProvider) {
|
||||
a.maxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
func WithAnthropicModel(model models.Model) AnthropicOption {
|
||||
return func(a *anthropicProvider) {
|
||||
a.model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithAnthropicKey(apiKey string) AnthropicOption {
|
||||
return func(a *anthropicProvider) {
|
||||
a.apiKey = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
|
||||
provider := &anthropicProvider{
|
||||
maxTokens: 1024,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(provider)
|
||||
}
|
||||
|
||||
if provider.systemMessage == "" {
|
||||
return nil, errors.New("system message is required")
|
||||
}
|
||||
|
||||
provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
anthropicMessages := a.convertToAnthropicMessages(messages)
|
||||
anthropicTools := a.convertToAnthropicTools(tools)
|
||||
|
||||
response, err := a.client.Messages.New(ctx, anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(a.model.APIModel),
|
||||
MaxTokens: a.maxTokens,
|
||||
Temperature: anthropic.Float(0),
|
||||
Messages: anthropicMessages,
|
||||
Tools: anthropicTools,
|
||||
System: []anthropic.TextBlockParam{
|
||||
{
|
||||
Text: a.systemMessage,
|
||||
CacheControl: anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, block := range response.Content {
|
||||
if text, ok := block.AsAny().(anthropic.TextBlock); ok {
|
||||
content += text.Text
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := a.extractToolCalls(response.Content)
|
||||
tokenUsage := a.extractTokenUsage(response.Usage)
|
||||
|
||||
return &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
|
||||
anthropicMessages := a.convertToAnthropicMessages(messages)
|
||||
anthropicTools := a.convertToAnthropicTools(tools)
|
||||
|
||||
var thinkingParam anthropic.ThinkingConfigParamUnion
|
||||
lastMessage := messages[len(messages)-1]
|
||||
temperature := anthropic.Float(0)
|
||||
if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
|
||||
thinkingParam = anthropic.ThinkingConfigParamUnion{
|
||||
OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
|
||||
BudgetTokens: int64(float64(a.maxTokens) * 0.8),
|
||||
Type: "enabled",
|
||||
},
|
||||
}
|
||||
temperature = anthropic.Float(1)
|
||||
}
|
||||
|
||||
stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(a.model.APIModel),
|
||||
MaxTokens: a.maxTokens,
|
||||
Temperature: temperature,
|
||||
Messages: anthropicMessages,
|
||||
Tools: anthropicTools,
|
||||
Thinking: thinkingParam,
|
||||
System: []anthropic.TextBlockParam{
|
||||
{
|
||||
Text: a.systemMessage,
|
||||
CacheControl: anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
go func() {
|
||||
defer close(eventChan)
|
||||
|
||||
accumulatedMessage := anthropic.Message{}
|
||||
|
||||
for stream.Next() {
|
||||
event := stream.Current()
|
||||
err := accumulatedMessage.Accumulate(event)
|
||||
if err != nil {
|
||||
eventChan <- ProviderEvent{Type: EventError, Error: err}
|
||||
return
|
||||
}
|
||||
|
||||
switch event := event.AsAny().(type) {
|
||||
case anthropic.ContentBlockStartEvent:
|
||||
eventChan <- ProviderEvent{Type: EventContentStart}
|
||||
|
||||
case anthropic.ContentBlockDeltaEvent:
|
||||
if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventThinkingDelta,
|
||||
Thinking: event.Delta.Thinking,
|
||||
}
|
||||
} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventContentDelta,
|
||||
Content: event.Delta.Text,
|
||||
}
|
||||
}
|
||||
|
||||
case anthropic.ContentBlockStopEvent:
|
||||
eventChan <- ProviderEvent{Type: EventContentStop}
|
||||
|
||||
case anthropic.MessageStopEvent:
|
||||
content := ""
|
||||
for _, block := range accumulatedMessage.Content {
|
||||
if text, ok := block.AsAny().(anthropic.TextBlock); ok {
|
||||
content += text.Text
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := a.extractToolCalls(accumulatedMessage.Content)
|
||||
tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
|
||||
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventComplete,
|
||||
Response: &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stream.Err() != nil {
|
||||
eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
|
||||
}
|
||||
}()
|
||||
|
||||
return eventChan, nil
|
||||
}
|
||||
|
||||
func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
|
||||
var toolCalls []message.ToolCall
|
||||
|
||||
for _, block := range content {
|
||||
switch variant := block.AsAny().(type) {
|
||||
case anthropic.ToolUseBlock:
|
||||
toolCall := message.ToolCall{
|
||||
ID: variant.ID,
|
||||
Name: variant.Name,
|
||||
Input: string(variant.Input),
|
||||
Type: string(variant.Type),
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
|
||||
return TokenUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheCreationTokens: usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: usage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
|
||||
anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
|
||||
|
||||
for i, tool := range tools {
|
||||
info := tool.Info()
|
||||
toolParam := anthropic.ToolParam{
|
||||
Name: info.Name,
|
||||
Description: anthropic.String(info.Description),
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: info.Parameters,
|
||||
},
|
||||
}
|
||||
|
||||
if i == len(tools)-1 {
|
||||
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
}
|
||||
}
|
||||
|
||||
anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
|
||||
}
|
||||
|
||||
return anthropicTools
|
||||
}
|
||||
|
||||
func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
|
||||
anthropicMessages := make([]anthropic.MessageParam, len(messages))
|
||||
cachedBlocks := 0
|
||||
|
||||
for i, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
content := anthropic.NewTextBlock(msg.Content)
|
||||
if cachedBlocks < 2 {
|
||||
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
}
|
||||
cachedBlocks++
|
||||
}
|
||||
anthropicMessages[i] = anthropic.NewUserMessage(content)
|
||||
|
||||
case message.Assistant:
|
||||
blocks := []anthropic.ContentBlockParamUnion{}
|
||||
if msg.Content != "" {
|
||||
content := anthropic.NewTextBlock(msg.Content)
|
||||
if cachedBlocks < 2 {
|
||||
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
}
|
||||
cachedBlocks++
|
||||
}
|
||||
blocks = append(blocks, content)
|
||||
}
|
||||
|
||||
for _, toolCall := range msg.ToolCalls {
|
||||
var inputMap map[string]any
|
||||
err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
|
||||
}
|
||||
|
||||
anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
|
||||
|
||||
case message.Tool:
|
||||
results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
|
||||
for i, toolResult := range msg.ToolResults {
|
||||
results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
|
||||
}
|
||||
anthropicMessages[i] = anthropic.NewUserMessage(results...)
|
||||
}
|
||||
}
|
||||
|
||||
return anthropicMessages
|
||||
}
|
443
internal/llm/provider/gemini.go
Normal file
443
internal/llm/provider/gemini.go
Normal file
|
@ -0,0 +1,443 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
type geminiProvider struct {
|
||||
client *genai.Client
|
||||
model models.Model
|
||||
maxTokens int32
|
||||
apiKey string
|
||||
systemMessage string
|
||||
}
|
||||
|
||||
type GeminiOption func(*geminiProvider)
|
||||
|
||||
func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
|
||||
provider := &geminiProvider{
|
||||
maxTokens: 5000,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(provider)
|
||||
}
|
||||
|
||||
if provider.systemMessage == "" {
|
||||
return nil, errors.New("system message is required")
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider.client = client
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func WithGeminiSystemMessage(message string) GeminiOption {
|
||||
return func(p *geminiProvider) {
|
||||
p.systemMessage = message
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
|
||||
return func(p *geminiProvider) {
|
||||
p.maxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeminiModel(model models.Model) GeminiOption {
|
||||
return func(p *geminiProvider) {
|
||||
p.model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeminiKey(apiKey string) GeminiOption {
|
||||
return func(p *geminiProvider) {
|
||||
p.apiKey = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
func (p *geminiProvider) Close() {
|
||||
if p.client != nil {
|
||||
p.client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// convertToGeminiHistory converts the message history to Gemini's format
|
||||
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
|
||||
var history []*genai.Content
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
history = append(history, &genai.Content{
|
||||
Parts: []genai.Part{genai.Text(msg.Content)},
|
||||
Role: "user",
|
||||
})
|
||||
case message.Assistant:
|
||||
content := &genai.Content{
|
||||
Role: "model",
|
||||
Parts: []genai.Part{},
|
||||
}
|
||||
|
||||
// Handle regular content
|
||||
if msg.Content != "" {
|
||||
content.Parts = append(content.Parts, genai.Text(msg.Content))
|
||||
}
|
||||
|
||||
// Handle tool calls if any
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, call := range msg.ToolCalls {
|
||||
args, _ := parseJsonToMap(call.Input)
|
||||
content.Parts = append(content.Parts, genai.FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
history = append(history, content)
|
||||
case message.Tool:
|
||||
for _, result := range msg.ToolResults {
|
||||
// Parse response content to map if possible
|
||||
response := map[string]interface{}{"result": result.Content}
|
||||
parsed, err := parseJsonToMap(result.Content)
|
||||
if err == nil {
|
||||
response = parsed
|
||||
}
|
||||
var toolCall message.ToolCall
|
||||
for _, msg := range messages {
|
||||
if msg.Role == message.Assistant {
|
||||
for _, call := range msg.ToolCalls {
|
||||
if call.ID == result.ToolCallID {
|
||||
toolCall = call
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
history = append(history, &genai.Content{
|
||||
Parts: []genai.Part{genai.FunctionResponse{
|
||||
Name: toolCall.Name,
|
||||
Response: response,
|
||||
}},
|
||||
Role: "function",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
|
||||
func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
|
||||
declarations := make([]*genai.FunctionDeclaration, len(tools))
|
||||
|
||||
for i, tool := range tools {
|
||||
info := tool.Info()
|
||||
|
||||
// Convert parameters to genai.Schema format
|
||||
properties := make(map[string]*genai.Schema)
|
||||
for name, param := range info.Parameters {
|
||||
// Try to extract type and description from the parameter
|
||||
paramMap, ok := param.(map[string]interface{})
|
||||
if !ok {
|
||||
// Default to string if unable to determine type
|
||||
properties[name] = &genai.Schema{Type: genai.TypeString}
|
||||
continue
|
||||
}
|
||||
|
||||
schemaType := genai.TypeString // Default
|
||||
var description string
|
||||
var itemsTypeSchema *genai.Schema
|
||||
if typeVal, found := paramMap["type"]; found {
|
||||
if typeStr, ok := typeVal.(string); ok {
|
||||
switch typeStr {
|
||||
case "string":
|
||||
schemaType = genai.TypeString
|
||||
case "number":
|
||||
schemaType = genai.TypeNumber
|
||||
case "integer":
|
||||
schemaType = genai.TypeInteger
|
||||
case "boolean":
|
||||
schemaType = genai.TypeBoolean
|
||||
case "array":
|
||||
schemaType = genai.TypeArray
|
||||
items, found := paramMap["items"]
|
||||
if found {
|
||||
itemsMap, ok := items.(map[string]interface{})
|
||||
if ok {
|
||||
itemsType, found := itemsMap["type"]
|
||||
if found {
|
||||
itemsTypeStr, ok := itemsType.(string)
|
||||
if ok {
|
||||
switch itemsTypeStr {
|
||||
case "string":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeString,
|
||||
}
|
||||
case "number":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeNumber,
|
||||
}
|
||||
case "integer":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeInteger,
|
||||
}
|
||||
case "boolean":
|
||||
itemsTypeSchema = &genai.Schema{
|
||||
Type: genai.TypeBoolean,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case "object":
|
||||
schemaType = genai.TypeObject
|
||||
if _, found := paramMap["properties"]; !found {
|
||||
continue
|
||||
}
|
||||
// TODO: Add support for other types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if desc, found := paramMap["description"]; found {
|
||||
if descStr, ok := desc.(string); ok {
|
||||
description = descStr
|
||||
}
|
||||
}
|
||||
|
||||
properties[name] = &genai.Schema{
|
||||
Type: schemaType,
|
||||
Description: description,
|
||||
Items: itemsTypeSchema,
|
||||
}
|
||||
}
|
||||
|
||||
declarations[i] = &genai.FunctionDeclaration{
|
||||
Name: info.Name,
|
||||
Description: info.Description,
|
||||
Parameters: &genai.Schema{
|
||||
Type: genai.TypeObject,
|
||||
Properties: properties,
|
||||
Required: info.Required,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return declarations
|
||||
}
|
||||
|
||||
// extractTokenUsage extracts token usage information from Gemini's response
|
||||
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
|
||||
if resp == nil || resp.UsageMetadata == nil {
|
||||
return TokenUsage{}
|
||||
}
|
||||
|
||||
return TokenUsage{
|
||||
InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
|
||||
OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
|
||||
CacheCreationTokens: 0, // Not directly provided by Gemini
|
||||
CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessages sends a batch of messages to Gemini and returns the response
|
||||
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
// Create a generative model
|
||||
model := p.client.GenerativeModel(p.model.APIModel)
|
||||
model.SetMaxOutputTokens(p.maxTokens)
|
||||
|
||||
// Set system instruction
|
||||
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
|
||||
|
||||
// Set up tools if provided
|
||||
if len(tools) > 0 {
|
||||
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
|
||||
model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
|
||||
}
|
||||
|
||||
// Create chat session and set history
|
||||
chat := model.StartChat()
|
||||
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
|
||||
|
||||
// Get the most recent user message
|
||||
var lastUserMsg message.Message
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == message.User {
|
||||
lastUserMsg = messages[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Send the message
|
||||
resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process the response
|
||||
var content string
|
||||
var toolCalls []message.ToolCall
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
switch p := part.(type) {
|
||||
case genai.Text:
|
||||
content = string(p)
|
||||
case genai.FunctionCall:
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(p.Args)
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: id,
|
||||
Name: p.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract token usage
|
||||
tokenUsage := p.extractTokenUsage(resp)
|
||||
|
||||
return &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StreamResponse streams the response from Gemini
|
||||
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
|
||||
// Create a generative model
|
||||
model := p.client.GenerativeModel(p.model.APIModel)
|
||||
model.SetMaxOutputTokens(p.maxTokens)
|
||||
|
||||
// Set system instruction
|
||||
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
|
||||
|
||||
// Set up tools if provided
|
||||
if len(tools) > 0 {
|
||||
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
|
||||
for _, declaration := range declarations {
|
||||
model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
|
||||
}
|
||||
}
|
||||
|
||||
// Create chat session and set history
|
||||
chat := model.StartChat()
|
||||
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
|
||||
|
||||
lastUserMsg := messages[len(messages)-1]
|
||||
|
||||
// Start streaming
|
||||
iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))
|
||||
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
go func() {
|
||||
defer close(eventChan)
|
||||
|
||||
var finalResp *genai.GenerateContentResponse
|
||||
currentContent := ""
|
||||
toolCalls := []message.ToolCall{}
|
||||
|
||||
for {
|
||||
resp, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
var apiErr *googleapi.Error
|
||||
if errors.As(err, &apiErr) {
|
||||
log.Printf("%s", apiErr.Body)
|
||||
}
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventError,
|
||||
Error: err,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
finalResp = resp
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
switch p := part.(type) {
|
||||
case genai.Text:
|
||||
newText := string(p)
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventContentDelta,
|
||||
Content: newText,
|
||||
}
|
||||
currentContent += newText
|
||||
case genai.FunctionCall:
|
||||
// For function calls, we assume they come complete, not streamed in parts
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(p.Args)
|
||||
newCall := message.ToolCall{
|
||||
ID: id,
|
||||
Name: p.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
}
|
||||
|
||||
// Check if this is a new tool call
|
||||
isNew := true
|
||||
for _, existing := range toolCalls {
|
||||
if existing.Name == newCall.Name && existing.Input == newCall.Input {
|
||||
isNew = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNew {
|
||||
toolCalls = append(toolCalls, newCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract token usage from the final response
|
||||
tokenUsage := p.extractTokenUsage(finalResp)
|
||||
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventComplete,
|
||||
Response: &ProviderResponse{
|
||||
Content: currentContent,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
},
|
||||
}
|
||||
}()
|
||||
|
||||
return eventChan, nil
|
||||
}
|
||||
|
||||
// Helper function to parse JSON string into map
|
||||
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &result)
|
||||
return result, err
|
||||
}
|
278
internal/llm/provider/openai.go
Normal file
278
internal/llm/provider/openai.go
Normal file
|
@ -0,0 +1,278 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
type openaiProvider struct {
|
||||
client openai.Client
|
||||
model models.Model
|
||||
maxTokens int64
|
||||
baseURL string
|
||||
apiKey string
|
||||
systemMessage string
|
||||
}
|
||||
|
||||
type OpenAIOption func(*openaiProvider)
|
||||
|
||||
func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
|
||||
provider := &openaiProvider{
|
||||
maxTokens: 5000,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(provider)
|
||||
}
|
||||
|
||||
clientOpts := []option.RequestOption{
|
||||
option.WithAPIKey(provider.apiKey),
|
||||
}
|
||||
if provider.baseURL != "" {
|
||||
clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
|
||||
}
|
||||
|
||||
provider.client = openai.NewClient(clientOpts...)
|
||||
if provider.systemMessage == "" {
|
||||
return nil, errors.New("system message is required")
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func WithOpenAISystemMessage(message string) OpenAIOption {
|
||||
return func(p *openaiProvider) {
|
||||
p.systemMessage = message
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
|
||||
return func(p *openaiProvider) {
|
||||
p.maxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIModel(model models.Model) OpenAIOption {
|
||||
return func(p *openaiProvider) {
|
||||
p.model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
|
||||
return func(p *openaiProvider) {
|
||||
p.baseURL = baseURL
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIKey(apiKey string) OpenAIOption {
|
||||
return func(p *openaiProvider) {
|
||||
p.apiKey = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
|
||||
var chatMessages []openai.ChatCompletionMessageParamUnion
|
||||
|
||||
chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case message.User:
|
||||
chatMessages = append(chatMessages, openai.UserMessage(msg.Content))
|
||||
|
||||
case message.Assistant:
|
||||
assistantMsg := openai.ChatCompletionAssistantMessageParam{
|
||||
Role: "assistant",
|
||||
}
|
||||
|
||||
if msg.Content != "" {
|
||||
assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
|
||||
OfString: openai.String(msg.Content),
|
||||
}
|
||||
}
|
||||
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
|
||||
for i, call := range msg.ToolCalls {
|
||||
assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
|
||||
ID: call.ID,
|
||||
Type: "function",
|
||||
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
||||
Name: call.Name,
|
||||
Arguments: call.Input,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
|
||||
OfAssistant: &assistantMsg,
|
||||
})
|
||||
|
||||
case message.Tool:
|
||||
for _, result := range msg.ToolResults {
|
||||
chatMessages = append(chatMessages,
|
||||
openai.ToolMessage(result.Content, result.ToolCallID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chatMessages
|
||||
}
|
||||
|
||||
func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
|
||||
openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
|
||||
|
||||
for i, tool := range tools {
|
||||
info := tool.Info()
|
||||
openaiTools[i] = openai.ChatCompletionToolParam{
|
||||
Function: openai.FunctionDefinitionParam{
|
||||
Name: info.Name,
|
||||
Description: openai.String(info.Description),
|
||||
Parameters: openai.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": info.Parameters,
|
||||
"required": info.Required,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
|
||||
cachedTokens := int64(0)
|
||||
|
||||
cachedTokens = usage.PromptTokensDetails.CachedTokens
|
||||
inputTokens := usage.PromptTokens - cachedTokens
|
||||
|
||||
return TokenUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: usage.CompletionTokens,
|
||||
CacheCreationTokens: 0, // OpenAI doesn't provide this directly
|
||||
CacheReadTokens: cachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
chatMessages := p.convertToOpenAIMessages(messages)
|
||||
openaiTools := p.convertToOpenAITools(tools)
|
||||
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: openai.ChatModel(p.model.APIModel),
|
||||
Messages: chatMessages,
|
||||
MaxTokens: openai.Int(p.maxTokens),
|
||||
Tools: openaiTools,
|
||||
}
|
||||
|
||||
response, err := p.client.Chat.Completions.New(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := ""
|
||||
if response.Choices[0].Message.Content != "" {
|
||||
content = response.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
var toolCalls []message.ToolCall
|
||||
if len(response.Choices[0].Message.ToolCalls) > 0 {
|
||||
toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
|
||||
for i, call := range response.Choices[0].Message.ToolCalls {
|
||||
toolCalls[i] = message.ToolCall{
|
||||
ID: call.ID,
|
||||
Name: call.Function.Name,
|
||||
Input: call.Function.Arguments,
|
||||
Type: "function",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenUsage := p.extractTokenUsage(response.Usage)
|
||||
|
||||
return &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
|
||||
chatMessages := p.convertToOpenAIMessages(messages)
|
||||
openaiTools := p.convertToOpenAITools(tools)
|
||||
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: openai.ChatModel(p.model.APIModel),
|
||||
Messages: chatMessages,
|
||||
MaxTokens: openai.Int(p.maxTokens),
|
||||
Tools: openaiTools,
|
||||
StreamOptions: openai.ChatCompletionStreamOptionsParam{
|
||||
IncludeUsage: openai.Bool(true),
|
||||
},
|
||||
}
|
||||
|
||||
stream := p.client.Chat.Completions.NewStreaming(ctx, params)
|
||||
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
toolCalls := make([]message.ToolCall, 0)
|
||||
go func() {
|
||||
defer close(eventChan)
|
||||
|
||||
acc := openai.ChatCompletionAccumulator{}
|
||||
currentContent := ""
|
||||
|
||||
for stream.Next() {
|
||||
chunk := stream.Current()
|
||||
acc.AddChunk(chunk)
|
||||
|
||||
if tool, ok := acc.JustFinishedToolCall(); ok {
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: tool.Id,
|
||||
Name: tool.Name,
|
||||
Input: tool.Arguments,
|
||||
Type: "function",
|
||||
})
|
||||
}
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != "" {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventContentDelta,
|
||||
Content: choice.Delta.Content,
|
||||
}
|
||||
currentContent += choice.Delta.Content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventError,
|
||||
Error: err,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tokenUsage := p.extractTokenUsage(acc.Usage)
|
||||
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventComplete,
|
||||
Response: &ProviderResponse{
|
||||
Content: currentContent,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: tokenUsage,
|
||||
},
|
||||
}
|
||||
}()
|
||||
|
||||
return eventChan, nil
|
||||
}
|
48
internal/llm/provider/provider.go
Normal file
48
internal/llm/provider/provider.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
)
|
||||
|
||||
// EventType represents the type of streaming event
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventContentStart EventType = "content_start"
|
||||
EventContentDelta EventType = "content_delta"
|
||||
EventThinkingDelta EventType = "thinking_delta"
|
||||
EventContentStop EventType = "content_stop"
|
||||
EventComplete EventType = "complete"
|
||||
EventError EventType = "error"
|
||||
)
|
||||
|
||||
type TokenUsage struct {
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
CacheCreationTokens int64
|
||||
CacheReadTokens int64
|
||||
}
|
||||
|
||||
type ProviderResponse struct {
|
||||
Content string
|
||||
ToolCalls []message.ToolCall
|
||||
Usage TokenUsage
|
||||
}
|
||||
|
||||
type ProviderEvent struct {
|
||||
Type EventType
|
||||
Content string
|
||||
Thinking string
|
||||
ToolCall *message.ToolCall
|
||||
Error error
|
||||
Response *ProviderResponse
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
|
||||
|
||||
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
|
||||
}
|
|
@ -1,141 +0,0 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type agentTool struct {
|
||||
workingDir string
|
||||
}
|
||||
|
||||
const (
|
||||
AgentToolName = "agent"
|
||||
)
|
||||
|
||||
type AgentParams struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
func taskAgentTools() []tool.BaseTool {
|
||||
wd := viper.GetString("wd")
|
||||
return []tool.BaseTool{
|
||||
NewBashTool(wd),
|
||||
NewLsTool(wd),
|
||||
NewGlobTool(wd),
|
||||
NewViewTool(wd),
|
||||
NewWriteTool(wd),
|
||||
NewEditTool(wd),
|
||||
}
|
||||
}
|
||||
|
||||
func NewTaskAgent(ctx context.Context) (*react.Agent, error) {
|
||||
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.big")))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reactAgent, err := react.NewAgent(ctx, &react.AgentConfig{
|
||||
Model: model,
|
||||
ToolsConfig: compose.ToolsNodeConfig{
|
||||
Tools: taskAgentTools(),
|
||||
},
|
||||
MaxStep: 1000,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reactAgent, nil
|
||||
}
|
||||
|
||||
func TaskAgentSystemPrompt() string {
|
||||
agentPrompt := `You are an agent for Orbitowl. Given the user's prompt, you should use the tools available to you to answer the user's question.
|
||||
|
||||
Notes:
|
||||
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
2. When relevant, share file names and code snippets relevant to the query
|
||||
3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
|
||||
|
||||
Here is useful information about the environment you are running in:
|
||||
<env>
|
||||
Working directory: %s
|
||||
Platform: %s
|
||||
Today's date: %s
|
||||
</env>`
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
cwd = "unknown"
|
||||
}
|
||||
|
||||
platform := runtime.GOOS
|
||||
|
||||
switch platform {
|
||||
case "darwin":
|
||||
platform = "macos"
|
||||
case "windows":
|
||||
platform = "windows"
|
||||
case "linux":
|
||||
platform = "linux"
|
||||
}
|
||||
return fmt.Sprintf(agentPrompt, cwd, platform, time.Now().Format("1/2/2006"))
|
||||
}
|
||||
|
||||
func (b *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: AgentToolName,
|
||||
Desc: "Launch a new agent that has access to the following tools: GlobTool, GrepTool, LS, View, ReadNotebook. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you. For example:\n\n- If you are searching for a keyword like \"config\" or \"logger\", or for questions like \"which file does X?\", the Agent tool is strongly recommended\n- If you want to read a specific file path, use the View or GlobTool tool instead of the Agent tool, to find the match more quickly\n- If you are searching for a specific class definition like \"class Foo\", use the GlobTool tool instead, to find the match more quickly\n\nUsage notes:\n1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses\n2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.\n3. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.\n4. The agent's outputs should generally be trusted\n5. IMPORTANT: The agent can not use Bash, Replace, Edit, NotebookEditCell, so can not modify files. If you want to use these tools, use them directly instead of going through the agent.",
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"prompt": {
|
||||
Type: "string",
|
||||
Desc: "The task for the agent to perform",
|
||||
Required: true,
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *agentTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
var params AgentParams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if params.Prompt == "" {
|
||||
return "prompt is required", nil
|
||||
}
|
||||
|
||||
a, err := NewTaskAgent(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
out, err := a.Generate(
|
||||
ctx,
|
||||
[]*schema.Message{
|
||||
schema.SystemMessage(TaskAgentSystemPrompt()),
|
||||
schema.UserMessage(params.Prompt),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return out.Content, nil
|
||||
}
|
||||
|
||||
func NewAgentTool(wd string) tool.InvokableTool {
|
||||
return &agentTool{
|
||||
workingDir: wd,
|
||||
}
|
||||
}
|
|
@ -6,20 +6,17 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools/shell"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
)
|
||||
|
||||
type bashTool struct {
|
||||
workingDir string
|
||||
}
|
||||
type bashTool struct{}
|
||||
|
||||
const (
|
||||
BashToolName = "bash"
|
||||
|
||||
DefaultTimeout = 30 * 60 * 1000 // 30 minutes in milliseconds
|
||||
DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
|
||||
MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
|
||||
MaxOutputLength = 30000
|
||||
)
|
||||
|
@ -29,6 +26,11 @@ type BashParams struct {
|
|||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
type BashPermissionsParams struct {
|
||||
Command string `json:"command"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
var BannedCommands = []string{
|
||||
"alias", "curl", "curlie", "wget", "axel", "aria2c",
|
||||
"nc", "telnet", "lynx", "w3m", "links", "httpie", "xh",
|
||||
|
@ -40,29 +42,29 @@ var SafeReadOnlyCommands = []string{
|
|||
"whatis", //...
|
||||
}
|
||||
|
||||
func (b *bashTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: BashToolName,
|
||||
Desc: bashDescription(),
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"command": {
|
||||
Type: "string",
|
||||
Desc: "The command to execute",
|
||||
Required: true,
|
||||
func (b *bashTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: BashToolName,
|
||||
Description: bashDescription(),
|
||||
Parameters: map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The command to execute",
|
||||
},
|
||||
"timeout": {
|
||||
Type: "number",
|
||||
Desc: "Optional timeout in milliseconds (max 600000)",
|
||||
"timeout": map[string]any{
|
||||
"type": "number",
|
||||
"desription": "Optional timeout in milliseconds (max 600000)",
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
},
|
||||
Required: []string{"command"},
|
||||
}
|
||||
}
|
||||
|
||||
// Handle implements Tool.
|
||||
func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params BashParams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return "", err
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("invalid parameters"), nil
|
||||
}
|
||||
|
||||
if params.Timeout > MaxTimeout {
|
||||
|
@ -72,13 +74,13 @@ func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
|
|||
}
|
||||
|
||||
if params.Command == "" {
|
||||
return "missing command", nil
|
||||
return NewTextErrorResponse("missing command"), nil
|
||||
}
|
||||
|
||||
baseCmd := strings.Fields(params.Command)[0]
|
||||
for _, banned := range BannedCommands {
|
||||
if strings.EqualFold(baseCmd, banned) {
|
||||
return fmt.Sprintf("command '%s' is not allowed", baseCmd), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", baseCmd)), nil
|
||||
}
|
||||
}
|
||||
isSafeReadOnly := false
|
||||
|
@ -91,39 +93,21 @@ func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
|
|||
if !isSafeReadOnly {
|
||||
p := permission.Default.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
Path: b.workingDir,
|
||||
Path: config.WorkingDirectory(),
|
||||
ToolName: BashToolName,
|
||||
Action: "execute",
|
||||
Description: fmt.Sprintf("Execute command: %s", params.Command),
|
||||
Params: map[string]interface{}{
|
||||
"command": params.Command,
|
||||
"timeout": params.Timeout,
|
||||
},
|
||||
Params: BashPermissionsParams(params),
|
||||
},
|
||||
)
|
||||
if !p {
|
||||
return "", fmt.Errorf("permission denied for command: %s", params.Command)
|
||||
return NewTextErrorResponse("permission denied"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// p := b.permission.Request(permission.CreatePermissionRequest{
|
||||
// Path: b.workingDir,
|
||||
// ToolName: BashToolName,
|
||||
// Action: "execute",
|
||||
// Description: fmt.Sprintf("Execute command: %s", params.Command),
|
||||
// Params: map[string]any{
|
||||
// "command": params.Command,
|
||||
// "timeout": params.Timeout,
|
||||
// },
|
||||
// })
|
||||
// if !p {
|
||||
// return "", errors.New("permission denied")
|
||||
// }
|
||||
|
||||
shell := shell.GetPersistentShell(b.workingDir)
|
||||
shell := shell.GetPersistentShell(config.WorkingDirectory())
|
||||
stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil
|
||||
}
|
||||
|
||||
stdout = truncateOutput(stdout)
|
||||
|
@ -153,9 +137,9 @@ func (b *bashTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
|
|||
}
|
||||
|
||||
if stdout == "" {
|
||||
return "no output", nil
|
||||
return NewTextResponse("no output"), nil
|
||||
}
|
||||
return stdout, nil
|
||||
return NewTextResponse(stdout), nil
|
||||
}
|
||||
|
||||
func truncateOutput(content string) string {
|
||||
|
@ -327,8 +311,6 @@ Important:
|
|||
- Never update git config`, bannedCommandsStr, MaxOutputLength)
|
||||
}
|
||||
|
||||
func NewBashTool(workingDir string) tool.InvokableTool {
|
||||
return &bashTool{
|
||||
workingDir: workingDir,
|
||||
}
|
||||
func NewBashTool() BaseTool {
|
||||
return &bashTool{}
|
||||
}
|
||||
|
|
389
internal/llm/tools/bash_test.go
Normal file
389
internal/llm/tools/bash_test.go
Normal file
|
@ -0,0 +1,389 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBashTool_Info(t *testing.T) {
|
||||
tool := NewBashTool()
|
||||
info := tool.Info()
|
||||
|
||||
assert.Equal(t, BashToolName, info.Name)
|
||||
assert.NotEmpty(t, info.Description)
|
||||
assert.Contains(t, info.Parameters, "command")
|
||||
assert.Contains(t, info.Parameters, "timeout")
|
||||
assert.Contains(t, info.Required, "command")
|
||||
}
|
||||
|
||||
func TestBashTool_Run(t *testing.T) {
|
||||
// Setup a mock permission handler that always allows
|
||||
origPermission := permission.Default
|
||||
defer func() {
|
||||
permission.Default = origPermission
|
||||
}()
|
||||
permission.Default = newMockPermissionService(true)
|
||||
|
||||
// Save original working directory
|
||||
origWd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
os.Chdir(origWd)
|
||||
}()
|
||||
|
||||
t.Run("executes command successfully", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "echo 'Hello World'",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Hello World\n", response.Content)
|
||||
})
|
||||
|
||||
t.Run("handles invalid parameters", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
|
||||
tool := NewBashTool()
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: "invalid json",
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "invalid parameters")
|
||||
})
|
||||
|
||||
t.Run("handles missing command", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "missing command")
|
||||
})
|
||||
|
||||
t.Run("handles banned commands", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
|
||||
tool := NewBashTool()
|
||||
|
||||
for _, bannedCmd := range BannedCommands {
|
||||
params := BashParams{
|
||||
Command: bannedCmd + " arg1 arg2",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles safe read-only commands without permission check", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(false)
|
||||
|
||||
tool := NewBashTool()
|
||||
|
||||
// Test with a safe read-only command
|
||||
params := BashParams{
|
||||
Command: "echo 'test'",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test\n", response.Content)
|
||||
})
|
||||
|
||||
t.Run("handles permission denied", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(false)
|
||||
|
||||
tool := NewBashTool()
|
||||
|
||||
// Test with a command that requires permission
|
||||
params := BashParams{
|
||||
Command: "mkdir test_dir",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "permission denied")
|
||||
})
|
||||
|
||||
t.Run("handles command timeout", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "sleep 2",
|
||||
Timeout: 100, // 100ms timeout
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "aborted")
|
||||
})
|
||||
|
||||
t.Run("handles command with stderr output", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "echo 'error message' >&2",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "error message")
|
||||
})
|
||||
|
||||
t.Run("handles command with both stdout and stderr", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "echo 'stdout message' && echo 'stderr message' >&2",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "stdout message")
|
||||
assert.Contains(t, response.Content, "stderr message")
|
||||
})
|
||||
|
||||
t.Run("handles context cancellation", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "sleep 5",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Cancel the context after a short delay
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
response, err := tool.Run(ctx, call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "aborted")
|
||||
})
|
||||
|
||||
t.Run("respects max timeout", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "echo 'test'",
|
||||
Timeout: MaxTimeout + 1000, // Exceeds max timeout
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test\n", response.Content)
|
||||
})
|
||||
|
||||
t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewBashTool()
|
||||
params := BashParams{
|
||||
Command: "echo 'test'",
|
||||
Timeout: -100, // Negative timeout
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: BashToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test\n", response.Content)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTruncateOutput(t *testing.T) {
|
||||
t.Run("does not truncate short output", func(t *testing.T) {
|
||||
output := "short output"
|
||||
result := truncateOutput(output)
|
||||
assert.Equal(t, output, result)
|
||||
})
|
||||
|
||||
t.Run("truncates long output", func(t *testing.T) {
|
||||
// Create a string longer than MaxOutputLength
|
||||
longOutput := strings.Repeat("a\n", MaxOutputLength)
|
||||
result := truncateOutput(longOutput)
|
||||
|
||||
// Check that the result is shorter than the original
|
||||
assert.Less(t, len(result), len(longOutput))
|
||||
|
||||
// Check that the truncation message is included
|
||||
assert.Contains(t, result, "lines truncated")
|
||||
|
||||
// Check that we have the beginning and end of the original string
|
||||
assert.True(t, strings.HasPrefix(result, "a\n"))
|
||||
assert.True(t, strings.HasSuffix(result, "a\n"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCountLines(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "single line",
|
||||
input: "line1",
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple lines",
|
||||
input: "line1\nline2\nline3",
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "trailing newline",
|
||||
input: "line1\nline2\n",
|
||||
expected: 3, // Empty string after last newline counts as a line
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := countLines(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock permission service for testing
|
||||
type mockPermissionService struct {
|
||||
*pubsub.Broker[permission.PermissionRequest]
|
||||
allow bool
|
||||
}
|
||||
|
||||
func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
|
||||
// Not needed for tests
|
||||
}
|
||||
|
||||
func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
|
||||
// Not needed for tests
|
||||
}
|
||||
|
||||
func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
|
||||
// Not needed for tests
|
||||
}
|
||||
|
||||
func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
|
||||
return m.allow
|
||||
}
|
||||
|
||||
func newMockPermissionService(allow bool) permission.Service {
|
||||
return &mockPermissionService{
|
||||
Broker: pubsub.NewBroker[permission.PermissionRequest](),
|
||||
allow: allow,
|
||||
}
|
||||
}
|
||||
|
|
@ -3,22 +3,18 @@ package tools
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
)
|
||||
|
||||
type editTool struct {
|
||||
workingDir string
|
||||
}
|
||||
type editTool struct{}
|
||||
|
||||
const (
|
||||
EditToolName = "edit"
|
||||
|
@ -30,100 +26,72 @@ type EditParams struct {
|
|||
NewString string `json:"new_string"`
|
||||
}
|
||||
|
||||
func (b *editTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: EditToolName,
|
||||
Desc: `This is a tool for editing files. For moving or renaming files, you should generally use the Bash tool with the 'mv' command instead. For larger edits, use the Write tool to overwrite files. F.
|
||||
|
||||
Before using this tool:
|
||||
|
||||
1. Use the View tool to understand the file's contents and context
|
||||
|
||||
2. Verify the directory path is correct (only applicable when creating new files):
|
||||
- Use the LS tool to verify the parent directory exists and is the correct location
|
||||
|
||||
To make a file edit, provide the following:
|
||||
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
|
||||
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
|
||||
3. new_string: The edited text to replace the old_string
|
||||
|
||||
The tool will replace ONE occurrence of old_string with new_string in the specified file.
|
||||
|
||||
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
||||
|
||||
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
|
||||
- Include AT LEAST 3-5 lines of context BEFORE the change point
|
||||
- Include AT LEAST 3-5 lines of context AFTER the change point
|
||||
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
|
||||
|
||||
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
|
||||
- Make separate calls to this tool for each instance
|
||||
- Each call must uniquely identify its specific instance using extensive context
|
||||
|
||||
3. VERIFICATION: Before using this tool:
|
||||
- Check how many instances of the target text exist in the file
|
||||
- If multiple instances exist, gather enough context to uniquely identify each one
|
||||
- Plan separate tool calls for each instance
|
||||
|
||||
WARNING: If you do not follow these requirements:
|
||||
- The tool will fail if old_string matches multiple locations
|
||||
- The tool will fail if old_string doesn't match exactly (including whitespace)
|
||||
- You may change the wrong instance if you don't include enough context
|
||||
|
||||
When making edits:
|
||||
- Ensure the edit results in idiomatic, correct code
|
||||
- Do not leave the code in a broken state
|
||||
- Always use absolute file paths (starting with /)
|
||||
|
||||
If you want to create a new file, use:
|
||||
- A new file path, including dir name if needed
|
||||
- An empty old_string
|
||||
- The new file's contents as new_string
|
||||
|
||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`,
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"file_path": {
|
||||
Type: "string",
|
||||
Desc: "The absolute path to the file to modify",
|
||||
Required: true,
|
||||
},
|
||||
"old_string": {
|
||||
Type: "string",
|
||||
Desc: "The text to replace",
|
||||
Required: true,
|
||||
},
|
||||
"new_string": {
|
||||
Type: "string",
|
||||
Desc: "The text to replace it with",
|
||||
Required: true,
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
type EditPermissionsParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
OldString string `json:"old_string"`
|
||||
NewString string `json:"new_string"`
|
||||
Diff string `json:"diff"`
|
||||
}
|
||||
|
||||
func (b *editTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
func (e *editTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: EditToolName,
|
||||
Description: editDescription(),
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to modify",
|
||||
},
|
||||
"old_string": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace",
|
||||
},
|
||||
"new_string": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace it with",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "old_string", "new_string"},
|
||||
}
|
||||
}
|
||||
|
||||
// Run implements Tool.
|
||||
func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params EditParams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return "", err
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("invalid parameters"), nil
|
||||
}
|
||||
|
||||
if params.FilePath == "" {
|
||||
return "", errors.New("file_path is required")
|
||||
return NewTextErrorResponse("file_path is required"), nil
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(params.FilePath) {
|
||||
return "", fmt.Errorf("file path must be absolute, got: %s", params.FilePath)
|
||||
wd := config.WorkingDirectory()
|
||||
params.FilePath = filepath.Join(wd, params.FilePath)
|
||||
}
|
||||
|
||||
if params.OldString == "" {
|
||||
return createNewFile(params.FilePath, params.NewString)
|
||||
result, err := createNewFile(params.FilePath, params.NewString)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil
|
||||
}
|
||||
return NewTextErrorResponse(result), nil
|
||||
}
|
||||
|
||||
if params.NewString == "" {
|
||||
return deleteContent(params.FilePath, params.OldString)
|
||||
result, err := deleteContent(params.FilePath, params.OldString)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil
|
||||
}
|
||||
return NewTextErrorResponse(result), nil
|
||||
}
|
||||
|
||||
return replaceContent(params.FilePath, params.OldString, params.NewString)
|
||||
result, err := replaceContent(params.FilePath, params.OldString, params.NewString)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil
|
||||
}
|
||||
return NewTextResponse(result), nil
|
||||
}
|
||||
|
||||
func createNewFile(filePath, content string) (string, error) {
|
||||
|
@ -148,9 +116,11 @@ func createNewFile(filePath, content string) (string, error) {
|
|||
ToolName: EditToolName,
|
||||
Action: "create",
|
||||
Description: fmt.Sprintf("Create file %s", filePath),
|
||||
Params: map[string]interface{}{
|
||||
"file_path": filePath,
|
||||
"content": content,
|
||||
Params: EditPermissionsParams{
|
||||
FilePath: filePath,
|
||||
OldString: "",
|
||||
NewString: content,
|
||||
Diff: GenerateDiff("", content),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -166,19 +136,6 @@ func createNewFile(filePath, content string) (string, error) {
|
|||
recordFileWrite(filePath)
|
||||
recordFileRead(filePath)
|
||||
|
||||
// result := FileEditResult{
|
||||
// FilePath: filePath,
|
||||
// Created: true,
|
||||
// Updated: false,
|
||||
// Deleted: false,
|
||||
// Diff: generateDiff("", content),
|
||||
// }
|
||||
//
|
||||
// resultJSON, err := json.Marshal(result)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("failed to serialize result: %w", err)
|
||||
// }
|
||||
//
|
||||
return "File created: " + filePath, nil
|
||||
}
|
||||
|
||||
|
@ -231,9 +188,11 @@ func deleteContent(filePath, oldString string) (string, error) {
|
|||
ToolName: EditToolName,
|
||||
Action: "delete",
|
||||
Description: fmt.Sprintf("Delete content from file %s", filePath),
|
||||
Params: map[string]interface{}{
|
||||
"file_path": filePath,
|
||||
"content": content,
|
||||
Params: EditPermissionsParams{
|
||||
FilePath: filePath,
|
||||
OldString: oldString,
|
||||
NewString: "",
|
||||
Diff: GenerateDiff(oldContent, newContent),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -247,21 +206,7 @@ func deleteContent(filePath, oldString string) (string, error) {
|
|||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
|
||||
// result := FileEditResult{
|
||||
// FilePath: filePath,
|
||||
// Created: false,
|
||||
// Updated: true,
|
||||
// Deleted: true,
|
||||
// Diff: generateDiff(oldContent, newContent),
|
||||
// SnippetBefore: getContextSnippet(oldContent, index, len(oldString)),
|
||||
// SnippetAfter: getContextSnippet(newContent, index, 0),
|
||||
// }
|
||||
//
|
||||
// resultJSON, err := json.Marshal(result)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("failed to serialize result: %w", err)
|
||||
// }
|
||||
recordFileRead(filePath)
|
||||
|
||||
return "Content deleted from file: " + filePath, nil
|
||||
}
|
||||
|
@ -270,44 +215,45 @@ func replaceContent(filePath, oldString, newString string) (string, error) {
|
|||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Sprintf("file not found: %s", filePath), nil
|
||||
return "", fmt.Errorf("file not found: %s", filePath)
|
||||
}
|
||||
return fmt.Sprintf("failed to access file: %s", err), nil
|
||||
return "", fmt.Errorf("failed to access file: %w", err)
|
||||
}
|
||||
|
||||
if fileInfo.IsDir() {
|
||||
return fmt.Sprintf("path is a directory, not a file: %s", filePath), nil
|
||||
return "", fmt.Errorf("path is a directory, not a file: %s", filePath)
|
||||
}
|
||||
|
||||
if getLastReadTime(filePath).IsZero() {
|
||||
return "you must read the file before editing it. Use the View tool first", nil
|
||||
return "", fmt.Errorf("you must read the file before editing it. Use the View tool first")
|
||||
}
|
||||
|
||||
modTime := fileInfo.ModTime()
|
||||
lastRead := getLastReadTime(filePath)
|
||||
if modTime.After(lastRead) {
|
||||
return fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
|
||||
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)), nil
|
||||
return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
|
||||
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("failed to read file: %s", err), nil
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
oldContent := string(content)
|
||||
|
||||
index := strings.Index(oldContent, oldString)
|
||||
if index == -1 {
|
||||
return "old_string not found in file. Make sure it matches exactly, including whitespace and line breaks", nil
|
||||
return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
|
||||
}
|
||||
|
||||
lastIndex := strings.LastIndex(oldContent, oldString)
|
||||
if index != lastIndex {
|
||||
return "old_string appears multiple times in the file. Please provide more context to ensure a unique match", nil
|
||||
return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
|
||||
}
|
||||
|
||||
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
|
||||
diff := GenerateDiff(oldString, newContent)
|
||||
|
||||
p := permission.Default.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
|
@ -315,10 +261,11 @@ func replaceContent(filePath, oldString, newString string) (string, error) {
|
|||
ToolName: EditToolName,
|
||||
Action: "replace",
|
||||
Description: fmt.Sprintf("Replace content in file %s", filePath),
|
||||
Params: map[string]interface{}{
|
||||
"file_path": filePath,
|
||||
"old_string": oldString,
|
||||
"new_string": newString,
|
||||
Params: EditPermissionsParams{
|
||||
FilePath: filePath,
|
||||
OldString: oldString,
|
||||
NewString: newString,
|
||||
Diff: diff,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -328,93 +275,97 @@ func replaceContent(filePath, oldString, newString string) (string, error) {
|
|||
|
||||
err = os.WriteFile(filePath, []byte(newContent), 0o644)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("failed to write file: %s", err), nil
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
|
||||
// result := FileEditResult{
|
||||
// FilePath: filePath,
|
||||
// Created: false,
|
||||
// Updated: true,
|
||||
// Deleted: false,
|
||||
// Diff: generateDiff(oldContent, newContent),
|
||||
// SnippetBefore: getContextSnippet(oldContent, index, len(oldString)),
|
||||
// SnippetAfter: getContextSnippet(newContent, index, len(newString)),
|
||||
// }
|
||||
//
|
||||
// resultJSON, err := json.Marshal(result)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("failed to serialize result: %w", err)
|
||||
// }
|
||||
recordFileRead(filePath)
|
||||
|
||||
return "Content replaced in file: " + filePath, nil
|
||||
}
|
||||
|
||||
func getContextSnippet(content string, position, length int) string {
|
||||
contextLines := 3
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
lineIndex := 0
|
||||
currentPos := 0
|
||||
|
||||
for i, line := range lines {
|
||||
if currentPos <= position && position < currentPos+len(line)+1 {
|
||||
lineIndex = i
|
||||
break
|
||||
}
|
||||
currentPos += len(line) + 1 // +1 for the newline
|
||||
}
|
||||
|
||||
startLine := max(0, lineIndex-contextLines)
|
||||
endLine := min(len(lines), lineIndex+contextLines+1)
|
||||
|
||||
var snippetBuilder strings.Builder
|
||||
for i := startLine; i < endLine; i++ {
|
||||
if i == lineIndex {
|
||||
snippetBuilder.WriteString(fmt.Sprintf("> %s\n", lines[i]))
|
||||
} else {
|
||||
snippetBuilder.WriteString(fmt.Sprintf(" %s\n", lines[i]))
|
||||
}
|
||||
}
|
||||
|
||||
return snippetBuilder.String()
|
||||
}
|
||||
|
||||
func generateDiff(oldContent, newContent string) string {
|
||||
func GenerateDiff(oldContent, newContent string) string {
|
||||
dmp := diffmatchpatch.New()
|
||||
fileAdmp, fileBdmp, dmpStrings := dmp.DiffLinesToChars(oldContent, newContent)
|
||||
diffs := dmp.DiffMain(fileAdmp, fileBdmp, false)
|
||||
diffs = dmp.DiffCharsToLines(diffs, dmpStrings)
|
||||
diffs = dmp.DiffCleanupSemantic(diffs)
|
||||
buff := strings.Builder{}
|
||||
for _, diff := range diffs {
|
||||
text := diff.Text
|
||||
|
||||
diffs := dmp.DiffMain(oldContent, newContent, false)
|
||||
|
||||
patches := dmp.PatchMake(oldContent, diffs)
|
||||
patchText := dmp.PatchToText(patches)
|
||||
|
||||
if patchText == "" && (oldContent != newContent) {
|
||||
var result strings.Builder
|
||||
|
||||
result.WriteString("@@ Diff @@\n")
|
||||
for _, diff := range diffs {
|
||||
switch diff.Type {
|
||||
case diffmatchpatch.DiffInsert:
|
||||
result.WriteString("+ " + diff.Text + "\n")
|
||||
case diffmatchpatch.DiffDelete:
|
||||
result.WriteString("- " + diff.Text + "\n")
|
||||
case diffmatchpatch.DiffEqual:
|
||||
if len(diff.Text) > 40 {
|
||||
result.WriteString(" " + diff.Text[:20] + "..." + diff.Text[len(diff.Text)-20:] + "\n")
|
||||
} else {
|
||||
result.WriteString(" " + diff.Text + "\n")
|
||||
switch diff.Type {
|
||||
case diffmatchpatch.DiffInsert:
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
_, _ = buff.WriteString("+ " + line + "\n")
|
||||
}
|
||||
case diffmatchpatch.DiffDelete:
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
_, _ = buff.WriteString("- " + line + "\n")
|
||||
}
|
||||
case diffmatchpatch.DiffEqual:
|
||||
if len(text) > 40 {
|
||||
_, _ = buff.WriteString(" " + text[:20] + "..." + text[len(text)-20:] + "\n")
|
||||
} else {
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
_, _ = buff.WriteString(" " + line + "\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
return patchText
|
||||
return buff.String()
|
||||
}
|
||||
|
||||
func NewEditTool(workingDir string) tool.InvokableTool {
|
||||
return &editTool{
|
||||
workingDir: workingDir,
|
||||
}
|
||||
func editDescription() string {
|
||||
return `Edits files by replacing text, creating new files, or deleting content. For moving or renaming files, use the Bash tool with the 'mv' command instead. For larger file edits, use the FileWrite tool to overwrite files.
|
||||
|
||||
Before using this tool:
|
||||
|
||||
1. Use the FileRead tool to understand the file's contents and context
|
||||
|
||||
2. Verify the directory path is correct (only applicable when creating new files):
|
||||
- Use the LS tool to verify the parent directory exists and is the correct location
|
||||
|
||||
To make a file edit, provide the following:
|
||||
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
|
||||
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
|
||||
3. new_string: The edited text to replace the old_string
|
||||
|
||||
Special cases:
|
||||
- To create a new file: provide file_path and new_string, leave old_string empty
|
||||
- To delete content: provide file_path and old_string, leave new_string empty
|
||||
|
||||
The tool will replace ONE occurrence of old_string with new_string in the specified file.
|
||||
|
||||
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
||||
|
||||
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
|
||||
- Include AT LEAST 3-5 lines of context BEFORE the change point
|
||||
- Include AT LEAST 3-5 lines of context AFTER the change point
|
||||
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
|
||||
|
||||
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
|
||||
- Make separate calls to this tool for each instance
|
||||
- Each call must uniquely identify its specific instance using extensive context
|
||||
|
||||
3. VERIFICATION: Before using this tool:
|
||||
- Check how many instances of the target text exist in the file
|
||||
- If multiple instances exist, gather enough context to uniquely identify each one
|
||||
- Plan separate tool calls for each instance
|
||||
|
||||
WARNING: If you do not follow these requirements:
|
||||
- The tool will fail if old_string matches multiple locations
|
||||
- The tool will fail if old_string doesn't match exactly (including whitespace)
|
||||
- You may change the wrong instance if you don't include enough context
|
||||
|
||||
When making edits:
|
||||
- Ensure the edit results in idiomatic, correct code
|
||||
- Do not leave the code in a broken state
|
||||
- Always use absolute file paths (starting with /)
|
||||
|
||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
|
||||
}
|
||||
|
||||
func NewEditTool() BaseTool {
|
||||
return &editTool{}
|
||||
}
|
||||
|
|
|
@ -11,15 +11,11 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/bmatcuk/doublestar/v4"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
)
|
||||
|
||||
type globTool struct {
|
||||
workingDir string
|
||||
}
|
||||
type globTool struct{}
|
||||
|
||||
const (
|
||||
GlobToolName = "glob"
|
||||
|
@ -35,43 +31,44 @@ type GlobParams struct {
|
|||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
func (b *globTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: GlobToolName,
|
||||
Desc: `- Fast file pattern matching tool that works with any codebase size
|
||||
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||
- Returns matching file paths sorted by modification time
|
||||
- Use this tool when you need to find files by name patterns
|
||||
- When you are doing an open ended search that may require multiple rounds of globbing and grepping, use the Agent tool instead`,
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"pattern": {
|
||||
Type: "string",
|
||||
Desc: "The glob pattern to match files against",
|
||||
Required: true,
|
||||
func (g *globTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: GlobToolName,
|
||||
Description: globDescription(),
|
||||
Parameters: map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The glob pattern to match files against",
|
||||
},
|
||||
"path": {
|
||||
Type: "string",
|
||||
Desc: "The directory to search in. Defaults to the current working directory.",
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The directory to search in. Defaults to the current working directory.",
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
},
|
||||
Required: []string{"pattern"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *globTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
// Run implements Tool.
|
||||
func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params GlobParams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return fmt.Sprintf("error parsing parameters: %s", err), nil
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
if params.Pattern == "" {
|
||||
return NewTextErrorResponse("pattern is required"), nil
|
||||
}
|
||||
|
||||
// If path is empty, use current working directory
|
||||
searchPath := params.Path
|
||||
if searchPath == "" {
|
||||
searchPath = b.workingDir
|
||||
searchPath = config.WorkingDirectory()
|
||||
}
|
||||
|
||||
files, truncated, err := globFiles(params.Pattern, searchPath, 100)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error performing glob search: %s", err), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("error performing glob search: %s", err)), nil
|
||||
}
|
||||
|
||||
// Format the output for the assistant
|
||||
|
@ -81,11 +78,11 @@ func (b *globTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
|
|||
} else {
|
||||
output = strings.Join(files, "\n")
|
||||
if truncated {
|
||||
output += "\n(Results are truncated. Consider using a more specific path or pattern.)"
|
||||
output += "\n\n(Results are truncated. Consider using a more specific path or pattern.)"
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) {
|
||||
|
@ -167,8 +164,43 @@ func skipHidden(path string) bool {
|
|||
return base != "." && strings.HasPrefix(base, ".")
|
||||
}
|
||||
|
||||
func NewGlobTool(workingDir string) tool.InvokableTool {
|
||||
return &globTool{
|
||||
workingDir,
|
||||
}
|
||||
func globDescription() string {
|
||||
return `Fast file pattern matching tool that finds files by name and pattern, returning matching paths sorted by modification time (newest first).
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find files by name patterns or extensions
|
||||
- Great for finding specific file types across a directory structure
|
||||
- Useful for discovering files that match certain naming conventions
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a glob pattern to match against file paths
|
||||
- Optionally specify a starting directory (defaults to current working directory)
|
||||
- Results are sorted with most recently modified files first
|
||||
|
||||
GLOB PATTERN SYNTAX:
|
||||
- '*' matches any sequence of non-separator characters
|
||||
- '**' matches any sequence of characters, including separators
|
||||
- '?' matches any single non-separator character
|
||||
- '[...]' matches any character in the brackets
|
||||
- '[!...]' matches any character not in the brackets
|
||||
|
||||
COMMON PATTERN EXAMPLES:
|
||||
- '*.js' - Find all JavaScript files in the current directory
|
||||
- '**/*.js' - Find all JavaScript files in any subdirectory
|
||||
- 'src/**/*.{ts,tsx}' - Find all TypeScript files in the src directory
|
||||
- '*.{html,css,js}' - Find all HTML, CSS, and JS files
|
||||
|
||||
LIMITATIONS:
|
||||
- Results are limited to 100 files (newest first)
|
||||
- Does not search file contents (use Grep tool for that)
|
||||
- Hidden files (starting with '.') are skipped
|
||||
|
||||
TIPS:
|
||||
- For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep
|
||||
- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead
|
||||
- Always check if results are truncated and refine your search pattern if needed`
|
||||
}
|
||||
|
||||
func NewGlobTool() BaseTool {
|
||||
return &globTool{}
|
||||
}
|
||||
|
|
|
@ -13,18 +13,13 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
)
|
||||
|
||||
type grepTool struct {
|
||||
workingDir string
|
||||
}
|
||||
type grepTool struct{}
|
||||
|
||||
const (
|
||||
GrepToolName = "grep"
|
||||
|
||||
MaxGrepResults = 100
|
||||
)
|
||||
|
||||
type GrepParams struct {
|
||||
|
@ -38,83 +33,66 @@ type grepMatch struct {
|
|||
modTime time.Time
|
||||
}
|
||||
|
||||
func (b *grepTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: GrepToolName,
|
||||
Desc: `- Fast content search tool that works with any codebase size
|
||||
- Searches file contents using regular expressions
|
||||
- Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.)
|
||||
- Filter files by pattern with the include parameter (eg. "*.js", "*.{ts,tsx}")
|
||||
- Returns matching file paths sorted by modification time
|
||||
- Use this tool when you need to find files containing specific patterns
|
||||
- When you are doing an open ended search that may require multiple rounds of globbing and grepping, use the Agent tool instead`,
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"command": {
|
||||
Type: "string",
|
||||
Desc: "The command to execute",
|
||||
Required: true,
|
||||
func (g *grepTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: GrepToolName,
|
||||
Description: grepDescription(),
|
||||
Parameters: map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The regex pattern to search for in file contents",
|
||||
},
|
||||
"timeout": {
|
||||
Type: "number",
|
||||
Desc: "Optional timeout in milliseconds (max 600000)",
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The directory to search in. Defaults to the current working directory.",
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
"include": map[string]any{
|
||||
"type": "string",
|
||||
"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
|
||||
},
|
||||
},
|
||||
Required: []string{"pattern"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *grepTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
// Run implements Tool.
|
||||
func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params GrepParams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return "", err
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
if params.Pattern == "" {
|
||||
return NewTextErrorResponse("pattern is required"), nil
|
||||
}
|
||||
|
||||
// If path is empty, use current working directory
|
||||
searchPath := params.Path
|
||||
if searchPath == "" {
|
||||
var err error
|
||||
searchPath, err = os.Getwd()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("unable to get current working directory: %s", err), nil
|
||||
}
|
||||
searchPath = config.WorkingDirectory()
|
||||
}
|
||||
|
||||
matches, err := searchWithRipgrep(params.Pattern, searchPath, params.Include)
|
||||
matches, truncated, err := searchFiles(params.Pattern, searchPath, params.Include, 100)
|
||||
if err != nil {
|
||||
matches, err = searchFilesWithRegex(params.Pattern, searchPath, params.Include)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error searching files: %s", err), nil
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].modTime.After(matches[j].modTime)
|
||||
})
|
||||
|
||||
truncated := false
|
||||
if len(matches) > MaxGrepResults {
|
||||
truncated = true
|
||||
matches = matches[:MaxGrepResults]
|
||||
}
|
||||
|
||||
filenames := make([]string, len(matches))
|
||||
for i, m := range matches {
|
||||
filenames[i] = m.path
|
||||
return NewTextErrorResponse(fmt.Sprintf("error searching files: %s", err)), nil
|
||||
}
|
||||
|
||||
// Format the output for the assistant
|
||||
var output string
|
||||
if len(filenames) == 0 {
|
||||
if len(matches) == 0 {
|
||||
output = "No files found"
|
||||
} else {
|
||||
output = fmt.Sprintf("Found %d file%s\n%s",
|
||||
len(filenames),
|
||||
pluralize(len(filenames)),
|
||||
strings.Join(filenames, "\n"))
|
||||
len(matches),
|
||||
pluralize(len(matches)),
|
||||
strings.Join(matches, "\n"))
|
||||
|
||||
if truncated {
|
||||
output += "\n(Results are truncated. Consider using a more specific path or pattern.)"
|
||||
output += "\n\n(Results are truncated. Consider using a more specific path or pattern.)"
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func pluralize(count int) string {
|
||||
|
@ -124,6 +102,37 @@ func pluralize(count int) string {
|
|||
return "s"
|
||||
}
|
||||
|
||||
func searchFiles(pattern, rootPath, include string, limit int) ([]string, bool, error) {
|
||||
// First try using ripgrep if available for better performance
|
||||
matches, err := searchWithRipgrep(pattern, rootPath, include)
|
||||
if err != nil {
|
||||
// Fall back to manual regex search if ripgrep is not available
|
||||
matches, err = searchFilesWithRegex(pattern, rootPath, include)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Sort files by modification time (newest first)
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].modTime.After(matches[j].modTime)
|
||||
})
|
||||
|
||||
// Check if we need to truncate the results
|
||||
truncated := len(matches) > limit
|
||||
if truncated {
|
||||
matches = matches[:limit]
|
||||
}
|
||||
|
||||
// Extract just the paths
|
||||
results := make([]string, len(matches))
|
||||
for i, m := range matches {
|
||||
results[i] = m.path
|
||||
}
|
||||
|
||||
return results, truncated, nil
|
||||
}
|
||||
|
||||
func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) {
|
||||
_, err := exec.LookPath("rg")
|
||||
if err != nil {
|
||||
|
@ -140,6 +149,7 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) {
|
|||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
|
||||
// Exit code 1 means no matches, which isn't an error for our purposes
|
||||
return []grepMatch{}, nil
|
||||
}
|
||||
return nil, err
|
||||
|
@ -155,7 +165,7 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) {
|
|||
|
||||
fileInfo, err := os.Stat(line)
|
||||
if err != nil {
|
||||
continue
|
||||
continue // Skip files we can't access
|
||||
}
|
||||
|
||||
matches = append(matches, grepMatch{
|
||||
|
@ -186,20 +196,27 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error
|
|||
|
||||
err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil // Skip errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil // Skip directories
|
||||
}
|
||||
|
||||
// Skip hidden files
|
||||
if skipHidden(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check include pattern if provided
|
||||
if includePattern != nil && !includePattern.MatchString(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check file contents for the pattern
|
||||
match, err := fileContainsPattern(path, regex)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil // Skip files we can't read
|
||||
}
|
||||
|
||||
if match {
|
||||
|
@ -207,6 +224,11 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error
|
|||
path: path,
|
||||
modTime: info.ModTime(),
|
||||
})
|
||||
|
||||
// Check if we've hit the limit (collect double for sorting)
|
||||
if len(matches) >= 200 {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -232,11 +254,7 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, error)
|
|||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return false, scanner.Err()
|
||||
}
|
||||
|
||||
func globToRegex(glob string) string {
|
||||
|
@ -250,11 +268,46 @@ func globToRegex(glob string) string {
|
|||
return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
|
||||
})
|
||||
|
||||
return "^" + regexPattern + "$"
|
||||
return regexPattern
|
||||
}
|
||||
|
||||
func NewGrepTool(workingDir string) tool.InvokableTool {
|
||||
return &grepTool{
|
||||
workingDir,
|
||||
}
|
||||
func grepDescription() string {
|
||||
return `Fast content search tool that finds files containing specific text or patterns, returning matching file paths sorted by modification time (newest first).
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find files containing specific text or patterns
|
||||
- Great for searching code bases for function names, variable declarations, or error messages
|
||||
- Useful for finding all files that use a particular API or pattern
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a regex pattern to search for within file contents
|
||||
- Optionally specify a starting directory (defaults to current working directory)
|
||||
- Optionally provide an include pattern to filter which files to search
|
||||
- Results are sorted with most recently modified files first
|
||||
|
||||
REGEX PATTERN SYNTAX:
|
||||
- Supports standard regular expression syntax
|
||||
- 'function' searches for the literal text "function"
|
||||
- 'log\..*Error' finds text starting with "log." and ending with "Error"
|
||||
- 'import\s+.*\s+from' finds import statements in JavaScript/TypeScript
|
||||
|
||||
COMMON INCLUDE PATTERN EXAMPLES:
|
||||
- '*.js' - Only search JavaScript files
|
||||
- '*.{ts,tsx}' - Only search TypeScript files
|
||||
- '*.go' - Only search Go files
|
||||
|
||||
LIMITATIONS:
|
||||
- Results are limited to 100 files (newest first)
|
||||
- Performance depends on the number of files being searched
|
||||
- Very large binary files may be skipped
|
||||
- Hidden files (starting with '.') are skipped
|
||||
|
||||
TIPS:
|
||||
- For faster, more targeted searches, first use Glob to find relevant files, then use Grep
|
||||
- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead
|
||||
- Always check if results are truncated and refine your search pattern if needed`
|
||||
}
|
||||
|
||||
func NewGrepTool() BaseTool {
|
||||
return &grepTool{}
|
||||
}
|
||||
|
|
|
@ -8,19 +8,14 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
)
|
||||
|
||||
type lsTool struct {
|
||||
workingDir string
|
||||
}
|
||||
type lsTool struct{}
|
||||
|
||||
const (
|
||||
LSToolName = "ls"
|
||||
|
||||
MaxFiles = 1000
|
||||
TruncatedMessage = "There are more than 1000 files in the repository. Use the LS tool (passing a specific path), Bash tool, and other tools to explore nested directories. The first 1000 files and directories are included below:\n\n"
|
||||
MaxLSFiles = 1000
|
||||
)
|
||||
|
||||
type LSParams struct {
|
||||
|
@ -28,61 +23,82 @@ type LSParams struct {
|
|||
Ignore []string `json:"ignore"`
|
||||
}
|
||||
|
||||
func (b *lsTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: LSToolName,
|
||||
Desc: "Lists files and directories in a given path. The path parameter must be an absolute path, not a relative path. You can optionally provide an array of glob patterns to ignore with the ignore parameter. You should generally prefer the Glob and Grep tools, if you know which directories to search.",
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Desc: "The absolute path to the directory to list (must be absolute, not relative)",
|
||||
Required: true,
|
||||
},
|
||||
"ignore": {
|
||||
Type: "array",
|
||||
ElemInfo: &schema.ParameterInfo{
|
||||
Type: schema.String,
|
||||
Desc: "List of glob patterns to ignore",
|
||||
},
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
type TreeNode struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Type string `json:"type"` // "file" or "directory"
|
||||
Children []*TreeNode `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
func (b *lsTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
func (l *lsTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: LSToolName,
|
||||
Description: lsDescription(),
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the directory to list (defaults to current working directory)",
|
||||
},
|
||||
"ignore": map[string]any{
|
||||
"type": "array",
|
||||
"description": "List of glob patterns to ignore",
|
||||
"items": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
}
|
||||
}
|
||||
|
||||
// Run implements Tool.
|
||||
func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params LSParams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return "", err
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(params.Path) {
|
||||
return fmt.Sprintf("path must be absolute, got: %s", params.Path), nil
|
||||
// If path is empty, use current working directory
|
||||
searchPath := params.Path
|
||||
if searchPath == "" {
|
||||
searchPath = config.WorkingDirectory()
|
||||
}
|
||||
|
||||
files, err := b.listDirectory(params.Path)
|
||||
// Ensure the path is absolute
|
||||
if !filepath.IsAbs(searchPath) {
|
||||
searchPath = filepath.Join(config.WorkingDirectory(), searchPath)
|
||||
}
|
||||
|
||||
// Check if the path exists
|
||||
if _, err := os.Stat(searchPath); os.IsNotExist(err) {
|
||||
return NewTextErrorResponse(fmt.Sprintf("path does not exist: %s", searchPath)), nil
|
||||
}
|
||||
|
||||
files, truncated, err := listDirectory(searchPath, params.Ignore, MaxLSFiles)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error listing directory: %s", err), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("error listing directory: %s", err)), nil
|
||||
}
|
||||
|
||||
tree := createFileTree(files)
|
||||
output := printTree(tree, params.Path)
|
||||
output := printTree(tree, searchPath)
|
||||
|
||||
if len(files) >= MaxFiles {
|
||||
output = TruncatedMessage + output
|
||||
if truncated {
|
||||
output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func (b *lsTool) listDirectory(initialPath string) ([]string, error) {
|
||||
func listDirectory(initialPath string, ignorePatterns []string, limit int) ([]string, bool, error) {
|
||||
var results []string
|
||||
truncated := false
|
||||
|
||||
err := filepath.Walk(initialPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip files we don't have permission to access
|
||||
}
|
||||
|
||||
if shouldSkip(path) {
|
||||
if shouldSkip(path, ignorePatterns) {
|
||||
if info.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
@ -93,137 +109,212 @@ func (b *lsTool) listDirectory(initialPath string) ([]string, error) {
|
|||
if info.IsDir() {
|
||||
path = path + string(filepath.Separator)
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(b.workingDir, path)
|
||||
if err == nil {
|
||||
results = append(results, relPath)
|
||||
} else {
|
||||
results = append(results, path)
|
||||
}
|
||||
results = append(results, path)
|
||||
}
|
||||
|
||||
if len(results) >= MaxFiles {
|
||||
return fmt.Errorf("max files reached")
|
||||
if len(results) >= limit {
|
||||
truncated = true
|
||||
return filepath.SkipAll
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && err.Error() != "max files reached" {
|
||||
return nil, err
|
||||
if err != nil {
|
||||
return nil, truncated, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return results, truncated, nil
|
||||
}
|
||||
|
||||
func shouldSkip(path string) bool {
|
||||
func shouldSkip(path string, ignorePatterns []string) bool {
|
||||
base := filepath.Base(path)
|
||||
|
||||
// Skip hidden files and directories
|
||||
if base != "." && strings.HasPrefix(base, ".") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip common directories and files
|
||||
commonIgnored := []string{
|
||||
"__pycache__",
|
||||
"node_modules",
|
||||
"dist",
|
||||
"build",
|
||||
"target",
|
||||
"vendor",
|
||||
"bin",
|
||||
"obj",
|
||||
".git",
|
||||
".idea",
|
||||
".vscode",
|
||||
".DS_Store",
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
"*.pyd",
|
||||
"*.so",
|
||||
"*.dll",
|
||||
"*.exe",
|
||||
}
|
||||
|
||||
// Skip __pycache__ directories
|
||||
if strings.Contains(path, filepath.Join("__pycache__", "")) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check against common ignored patterns
|
||||
for _, ignored := range commonIgnored {
|
||||
if strings.HasSuffix(ignored, "/") {
|
||||
// Directory pattern
|
||||
if strings.Contains(path, filepath.Join(ignored[:len(ignored)-1], "")) {
|
||||
return true
|
||||
}
|
||||
} else if strings.HasPrefix(ignored, "*.") {
|
||||
// File extension pattern
|
||||
if strings.HasSuffix(base, ignored[1:]) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// Exact match
|
||||
if base == ignored {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check against ignore patterns
|
||||
for _, pattern := range ignorePatterns {
|
||||
matched, err := filepath.Match(pattern, base)
|
||||
if err == nil && matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type TreeNode struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Type string `json:"type"` // "file" or "directory"
|
||||
Children []TreeNode `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
func createFileTree(sortedPaths []string) []TreeNode {
|
||||
root := []TreeNode{}
|
||||
func createFileTree(sortedPaths []string) []*TreeNode {
|
||||
root := []*TreeNode{}
|
||||
pathMap := make(map[string]*TreeNode)
|
||||
|
||||
for _, path := range sortedPaths {
|
||||
parts := strings.Split(path, string(filepath.Separator))
|
||||
currentLevel := &root
|
||||
currentPath := ""
|
||||
var parentPath string
|
||||
|
||||
var cleanParts []string
|
||||
for _, part := range parts {
|
||||
if part != "" {
|
||||
cleanParts = append(cleanParts, part)
|
||||
}
|
||||
}
|
||||
parts = cleanParts
|
||||
|
||||
if len(parts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPath == "" {
|
||||
currentPath = part
|
||||
} else {
|
||||
currentPath = filepath.Join(currentPath, part)
|
||||
}
|
||||
|
||||
if _, exists := pathMap[currentPath]; exists {
|
||||
parentPath = currentPath
|
||||
continue
|
||||
}
|
||||
|
||||
isLastPart := i == len(parts)-1
|
||||
isDir := !isLastPart || strings.HasSuffix(path, string(filepath.Separator))
|
||||
|
||||
found := false
|
||||
for i := range *currentLevel {
|
||||
if (*currentLevel)[i].Name == part {
|
||||
found = true
|
||||
if (*currentLevel)[i].Children != nil {
|
||||
currentLevel = &(*currentLevel)[i].Children
|
||||
}
|
||||
break
|
||||
}
|
||||
nodeType := "file"
|
||||
if isDir {
|
||||
nodeType = "directory"
|
||||
}
|
||||
newNode := &TreeNode{
|
||||
Name: part,
|
||||
Path: currentPath,
|
||||
Type: nodeType,
|
||||
Children: []*TreeNode{},
|
||||
}
|
||||
|
||||
if !found {
|
||||
nodeType := "file"
|
||||
if isDir {
|
||||
nodeType = "directory"
|
||||
}
|
||||
pathMap[currentPath] = newNode
|
||||
|
||||
newNode := TreeNode{
|
||||
Name: part,
|
||||
Path: currentPath,
|
||||
Type: nodeType,
|
||||
}
|
||||
|
||||
if isDir {
|
||||
newNode.Children = []TreeNode{}
|
||||
*currentLevel = append(*currentLevel, newNode)
|
||||
currentLevel = &(*currentLevel)[len(*currentLevel)-1].Children
|
||||
} else {
|
||||
*currentLevel = append(*currentLevel, newNode)
|
||||
if i > 0 && parentPath != "" {
|
||||
if parent, ok := pathMap[parentPath]; ok {
|
||||
parent.Children = append(parent.Children, newNode)
|
||||
}
|
||||
} else {
|
||||
root = append(root, newNode)
|
||||
}
|
||||
|
||||
parentPath = currentPath
|
||||
}
|
||||
}
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
func printTree(tree []TreeNode, rootPath string) string {
|
||||
func printTree(tree []*TreeNode, rootPath string) string {
|
||||
var result strings.Builder
|
||||
|
||||
result.WriteString(fmt.Sprintf("- %s%s\n", rootPath, string(filepath.Separator)))
|
||||
|
||||
printTreeRecursive(&result, tree, 0, " ")
|
||||
for _, node := range tree {
|
||||
printNode(&result, node, 1)
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func printTreeRecursive(builder *strings.Builder, tree []TreeNode, level int, prefix string) {
|
||||
for _, node := range tree {
|
||||
linePrefix := prefix + "- "
|
||||
func printNode(builder *strings.Builder, node *TreeNode, level int) {
|
||||
indent := strings.Repeat(" ", level)
|
||||
|
||||
nodeName := node.Name
|
||||
if node.Type == "directory" {
|
||||
nodeName += string(filepath.Separator)
|
||||
}
|
||||
fmt.Fprintf(builder, "%s%s\n", linePrefix, nodeName)
|
||||
nodeName := node.Name
|
||||
if node.Type == "directory" {
|
||||
nodeName += string(filepath.Separator)
|
||||
}
|
||||
|
||||
if node.Type == "directory" && len(node.Children) > 0 {
|
||||
printTreeRecursive(builder, node.Children, level+1, prefix+" ")
|
||||
fmt.Fprintf(builder, "%s- %s\n", indent, nodeName)
|
||||
|
||||
if node.Type == "directory" && len(node.Children) > 0 {
|
||||
for _, child := range node.Children {
|
||||
printNode(builder, child, level+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewLsTool(workingDir string) tool.InvokableTool {
|
||||
return &lsTool{
|
||||
workingDir,
|
||||
}
|
||||
func lsDescription() string {
|
||||
return `Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization.
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to explore the structure of a directory
|
||||
- Helpful for understanding the organization of a project
|
||||
- Good first step when getting familiar with a new codebase
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a path to list (defaults to current working directory)
|
||||
- Optionally specify glob patterns to ignore
|
||||
- Results are displayed in a tree structure
|
||||
|
||||
FEATURES:
|
||||
- Displays a hierarchical view of files and directories
|
||||
- Automatically skips hidden files/directories (starting with '.')
|
||||
- Skips common system directories like __pycache__
|
||||
- Can filter out files matching specific patterns
|
||||
|
||||
LIMITATIONS:
|
||||
- Results are limited to 1000 files
|
||||
- Very large directories will be truncated
|
||||
- Does not show file sizes or permissions
|
||||
- Cannot recursively list all directories in a large project
|
||||
|
||||
TIPS:
|
||||
- Use Glob tool for finding files by name patterns instead of browsing
|
||||
- Use Grep tool for searching file contents
|
||||
- Combine with other tools for more effective exploration`
|
||||
}
|
||||
|
||||
func NewLsTool() BaseTool {
|
||||
return &lsTool{}
|
||||
}
|
||||
|
|
457
internal/llm/tools/ls_test.go
Normal file
457
internal/llm/tools/ls_test.go
Normal file
|
@ -0,0 +1,457 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLsTool_Info(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
info := tool.Info()
|
||||
|
||||
assert.Equal(t, LSToolName, info.Name)
|
||||
assert.NotEmpty(t, info.Description)
|
||||
assert.Contains(t, info.Parameters, "path")
|
||||
assert.Contains(t, info.Parameters, "ignore")
|
||||
assert.Contains(t, info.Required, "path")
|
||||
}
|
||||
|
||||
func TestLsTool_Run(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "ls_tool_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a test directory structure
|
||||
testDirs := []string{
|
||||
"dir1",
|
||||
"dir2",
|
||||
"dir2/subdir1",
|
||||
"dir2/subdir2",
|
||||
"dir3",
|
||||
"dir3/.hidden_dir",
|
||||
"__pycache__",
|
||||
}
|
||||
|
||||
testFiles := []string{
|
||||
"file1.txt",
|
||||
"file2.txt",
|
||||
"dir1/file3.txt",
|
||||
"dir2/file4.txt",
|
||||
"dir2/subdir1/file5.txt",
|
||||
"dir2/subdir2/file6.txt",
|
||||
"dir3/file7.txt",
|
||||
"dir3/.hidden_file.txt",
|
||||
"__pycache__/cache.pyc",
|
||||
".hidden_root_file.txt",
|
||||
}
|
||||
|
||||
// Create directories
|
||||
for _, dir := range testDirs {
|
||||
dirPath := filepath.Join(tempDir, dir)
|
||||
err := os.MkdirAll(dirPath, 0755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create files
|
||||
for _, file := range testFiles {
|
||||
filePath := filepath.Join(tempDir, file)
|
||||
err := os.WriteFile(filePath, []byte("test content"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("lists directory successfully", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: tempDir,
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that visible directories and files are included
|
||||
assert.Contains(t, response.Content, "dir1")
|
||||
assert.Contains(t, response.Content, "dir2")
|
||||
assert.Contains(t, response.Content, "dir3")
|
||||
assert.Contains(t, response.Content, "file1.txt")
|
||||
assert.Contains(t, response.Content, "file2.txt")
|
||||
|
||||
// Check that hidden files and directories are not included
|
||||
assert.NotContains(t, response.Content, ".hidden_dir")
|
||||
assert.NotContains(t, response.Content, ".hidden_file.txt")
|
||||
assert.NotContains(t, response.Content, ".hidden_root_file.txt")
|
||||
|
||||
// Check that __pycache__ is not included
|
||||
assert.NotContains(t, response.Content, "__pycache__")
|
||||
})
|
||||
|
||||
t.Run("handles non-existent path", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: filepath.Join(tempDir, "non_existent_dir"),
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "path does not exist")
|
||||
})
|
||||
|
||||
t.Run("handles empty path parameter", func(t *testing.T) {
|
||||
// For this test, we need to mock the config.WorkingDirectory function
|
||||
// Since we can't easily do that, we'll just check that the response doesn't contain an error message
|
||||
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: "",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The response should either contain a valid directory listing or an error
|
||||
// We'll just check that it's not empty
|
||||
assert.NotEmpty(t, response.Content)
|
||||
})
|
||||
|
||||
t.Run("handles invalid parameters", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: "invalid json",
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "error parsing parameters")
|
||||
})
|
||||
|
||||
t.Run("respects ignore patterns", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: tempDir,
|
||||
Ignore: []string{"file1.txt", "dir1"},
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The output format is a tree, so we need to check for specific patterns
|
||||
// Check that file1.txt is not directly mentioned
|
||||
assert.NotContains(t, response.Content, "- file1.txt")
|
||||
|
||||
// Check that dir1/ is not directly mentioned
|
||||
assert.NotContains(t, response.Content, "- dir1/")
|
||||
})
|
||||
|
||||
t.Run("handles relative path", func(t *testing.T) {
|
||||
// Save original working directory
|
||||
origWd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
os.Chdir(origWd)
|
||||
}()
|
||||
|
||||
// Change to a directory above the temp directory
|
||||
parentDir := filepath.Dir(tempDir)
|
||||
err = os.Chdir(parentDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: filepath.Base(tempDir),
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should list the temp directory contents
|
||||
assert.Contains(t, response.Content, "dir1")
|
||||
assert.Contains(t, response.Content, "file1.txt")
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldSkip(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
ignorePatterns []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "hidden file",
|
||||
path: "/path/to/.hidden_file",
|
||||
ignorePatterns: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "hidden directory",
|
||||
path: "/path/to/.hidden_dir",
|
||||
ignorePatterns: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pycache directory",
|
||||
path: "/path/to/__pycache__/file.pyc",
|
||||
ignorePatterns: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "node_modules directory",
|
||||
path: "/path/to/node_modules/package",
|
||||
ignorePatterns: []string{},
|
||||
expected: false, // The shouldSkip function doesn't directly check for node_modules in the path
|
||||
},
|
||||
{
|
||||
name: "normal file",
|
||||
path: "/path/to/normal_file.txt",
|
||||
ignorePatterns: []string{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "normal directory",
|
||||
path: "/path/to/normal_dir",
|
||||
ignorePatterns: []string{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ignored by pattern",
|
||||
path: "/path/to/ignore_me.txt",
|
||||
ignorePatterns: []string{"ignore_*.txt"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not ignored by pattern",
|
||||
path: "/path/to/keep_me.txt",
|
||||
ignorePatterns: []string{"ignore_*.txt"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := shouldSkip(tc.path, tc.ignorePatterns)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFileTree(t *testing.T) {
|
||||
paths := []string{
|
||||
"/path/to/file1.txt",
|
||||
"/path/to/dir1/file2.txt",
|
||||
"/path/to/dir1/subdir/file3.txt",
|
||||
"/path/to/dir2/file4.txt",
|
||||
}
|
||||
|
||||
tree := createFileTree(paths)
|
||||
|
||||
// Check the structure of the tree
|
||||
assert.Len(t, tree, 1) // Should have one root node
|
||||
|
||||
// Check the root node
|
||||
rootNode := tree[0]
|
||||
assert.Equal(t, "path", rootNode.Name)
|
||||
assert.Equal(t, "directory", rootNode.Type)
|
||||
assert.Len(t, rootNode.Children, 1)
|
||||
|
||||
// Check the "to" node
|
||||
toNode := rootNode.Children[0]
|
||||
assert.Equal(t, "to", toNode.Name)
|
||||
assert.Equal(t, "directory", toNode.Type)
|
||||
assert.Len(t, toNode.Children, 3) // file1.txt, dir1, dir2
|
||||
|
||||
// Find the dir1 node
|
||||
var dir1Node *TreeNode
|
||||
for _, child := range toNode.Children {
|
||||
if child.Name == "dir1" {
|
||||
dir1Node = child
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, dir1Node)
|
||||
assert.Equal(t, "directory", dir1Node.Type)
|
||||
assert.Len(t, dir1Node.Children, 2) // file2.txt and subdir
|
||||
}
|
||||
|
||||
func TestPrintTree(t *testing.T) {
|
||||
// Create a simple tree
|
||||
tree := []*TreeNode{
|
||||
{
|
||||
Name: "dir1",
|
||||
Path: "dir1",
|
||||
Type: "directory",
|
||||
Children: []*TreeNode{
|
||||
{
|
||||
Name: "file1.txt",
|
||||
Path: "dir1/file1.txt",
|
||||
Type: "file",
|
||||
},
|
||||
{
|
||||
Name: "subdir",
|
||||
Path: "dir1/subdir",
|
||||
Type: "directory",
|
||||
Children: []*TreeNode{
|
||||
{
|
||||
Name: "file2.txt",
|
||||
Path: "dir1/subdir/file2.txt",
|
||||
Type: "file",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "file3.txt",
|
||||
Path: "file3.txt",
|
||||
Type: "file",
|
||||
},
|
||||
}
|
||||
|
||||
result := printTree(tree, "/root")
|
||||
|
||||
// Check the output format
|
||||
assert.Contains(t, result, "- /root/")
|
||||
assert.Contains(t, result, " - dir1/")
|
||||
assert.Contains(t, result, " - file1.txt")
|
||||
assert.Contains(t, result, " - subdir/")
|
||||
assert.Contains(t, result, " - file2.txt")
|
||||
assert.Contains(t, result, " - file3.txt")
|
||||
}
|
||||
|
||||
func TestListDirectory(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "list_directory_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a test directory structure
|
||||
testDirs := []string{
|
||||
"dir1",
|
||||
"dir1/subdir1",
|
||||
".hidden_dir",
|
||||
}
|
||||
|
||||
testFiles := []string{
|
||||
"file1.txt",
|
||||
"file2.txt",
|
||||
"dir1/file3.txt",
|
||||
"dir1/subdir1/file4.txt",
|
||||
".hidden_file.txt",
|
||||
}
|
||||
|
||||
// Create directories
|
||||
for _, dir := range testDirs {
|
||||
dirPath := filepath.Join(tempDir, dir)
|
||||
err := os.MkdirAll(dirPath, 0755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create files
|
||||
for _, file := range testFiles {
|
||||
filePath := filepath.Join(tempDir, file)
|
||||
err := os.WriteFile(filePath, []byte("test content"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("lists files with no limit", func(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{}, 1000)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, truncated)
|
||||
|
||||
// Check that visible files and directories are included
|
||||
containsPath := func(paths []string, target string) bool {
|
||||
targetPath := filepath.Join(tempDir, target)
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(path, targetPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
assert.True(t, containsPath(files, "dir1"))
|
||||
assert.True(t, containsPath(files, "file1.txt"))
|
||||
assert.True(t, containsPath(files, "file2.txt"))
|
||||
assert.True(t, containsPath(files, "dir1/file3.txt"))
|
||||
|
||||
// Check that hidden files and directories are not included
|
||||
assert.False(t, containsPath(files, ".hidden_dir"))
|
||||
assert.False(t, containsPath(files, ".hidden_file.txt"))
|
||||
})
|
||||
|
||||
t.Run("respects limit and returns truncated flag", func(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{}, 2)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, truncated)
|
||||
assert.Len(t, files, 2)
|
||||
})
|
||||
|
||||
t.Run("respects ignore patterns", func(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{"*.txt"}, 1000)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, truncated)
|
||||
|
||||
// Check that no .txt files are included
|
||||
for _, file := range files {
|
||||
assert.False(t, strings.HasSuffix(file, ".txt"), "Found .txt file: %s", file)
|
||||
}
|
||||
|
||||
// But directories should still be included
|
||||
containsDir := false
|
||||
for _, file := range files {
|
||||
if strings.Contains(file, "dir1") {
|
||||
containsDir = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, containsDir)
|
||||
})
|
||||
}
|
|
@ -116,10 +116,10 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx
|
|||
}
|
||||
|
||||
tempDir := os.TempDir()
|
||||
stdoutFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stdout-%d", time.Now().UnixNano()))
|
||||
stderrFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stderr-%d", time.Now().UnixNano()))
|
||||
statusFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-status-%d", time.Now().UnixNano()))
|
||||
cwdFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-cwd-%d", time.Now().UnixNano()))
|
||||
stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano()))
|
||||
stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano()))
|
||||
statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano()))
|
||||
cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano()))
|
||||
|
||||
defer func() {
|
||||
os.Remove(stdoutFile)
|
||||
|
|
49
internal/llm/tools/tools.go
Normal file
49
internal/llm/tools/tools.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package tools
|
||||
|
||||
import "context"
|
||||
|
||||
type ToolInfo struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters map[string]any
|
||||
Required []string
|
||||
}
|
||||
|
||||
type toolResponseType string
|
||||
|
||||
const (
|
||||
ToolResponseTypeText toolResponseType = "text"
|
||||
ToolResponseTypeImage toolResponseType = "image"
|
||||
)
|
||||
|
||||
type ToolResponse struct {
|
||||
Type toolResponseType `json:"type"`
|
||||
Content string `json:"content"`
|
||||
IsError bool `json:"is_error"`
|
||||
}
|
||||
|
||||
func NewTextResponse(content string) ToolResponse {
|
||||
return ToolResponse{
|
||||
Type: ToolResponseTypeText,
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTextErrorResponse(content string) ToolResponse {
|
||||
return ToolResponse{
|
||||
Type: ToolResponseTypeText,
|
||||
Content: content,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
type BaseTool interface {
|
||||
Info() ToolInfo
|
||||
Run(ctx context.Context, params ToolCall) (ToolResponse, error)
|
||||
}
|
|
@ -10,77 +10,77 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
)
|
||||
|
||||
type viewTool struct {
|
||||
workingDir string
|
||||
}
|
||||
type viewTool struct{}
|
||||
|
||||
const (
|
||||
ViewToolName = "view"
|
||||
|
||||
MaxReadSize = 250 * 1024
|
||||
|
||||
ViewToolName = "view"
|
||||
MaxReadSize = 250 * 1024
|
||||
DefaultReadLimit = 2000
|
||||
|
||||
MaxLineLength = 2000
|
||||
MaxLineLength = 2000
|
||||
)
|
||||
|
||||
type ViewPatams struct {
|
||||
type ViewParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
Offset int `json:"offset"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
func (b *viewTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: ViewToolName,
|
||||
Desc: `Reads a file from the local filesystem. The file_path parameter must be an absolute path, not a relative path. By default, it reads up to 2000 lines starting from the beginning of the file. You can optionally specify a line offset and limit (especially handy for long files), but it's recommended to read the whole file by not providing these parameters. Any lines longer than 2000 characters will be truncated. For image files, the tool will display the image for you.`,
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"file_path": {
|
||||
Type: "string",
|
||||
Desc: "The absolute path to the file to read",
|
||||
Required: true,
|
||||
func (v *viewTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: ViewToolName,
|
||||
Description: viewDescription(),
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file to read",
|
||||
},
|
||||
"offset": {
|
||||
Type: "int",
|
||||
Desc: "The line number to start reading from. Only provide if the file is too large to read at once",
|
||||
"offset": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The line number to start reading from (0-based)",
|
||||
},
|
||||
"limit": {
|
||||
Type: "int",
|
||||
Desc: "The number of lines to read. Only provide if the file is too large to read at once.",
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The number of lines to read (defaults to 2000)",
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
},
|
||||
Required: []string{"file_path"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *viewTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
var params ViewPatams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return fmt.Sprintf("failed to parse parameters: %s", err), nil
|
||||
// Run implements Tool.
|
||||
func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params ViewParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
if params.FilePath == "" {
|
||||
return "file_path is required", nil
|
||||
return NewTextErrorResponse("file_path is required"), nil
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(params.FilePath) {
|
||||
return fmt.Sprintf("file path must be absolute, got: %s", params.FilePath), nil
|
||||
// Handle relative paths
|
||||
filePath := params.FilePath
|
||||
if !filepath.IsAbs(filePath) {
|
||||
filePath = filepath.Join(config.WorkingDirectory(), filePath)
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(params.FilePath)
|
||||
// Check if file exists
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
dir := filepath.Dir(params.FilePath)
|
||||
base := filepath.Base(params.FilePath)
|
||||
// Try to offer suggestions for similarly named files
|
||||
dir := filepath.Dir(filePath)
|
||||
base := filepath.Base(filePath)
|
||||
|
||||
dirEntries, dirErr := os.ReadDir(dir)
|
||||
if dirErr == nil {
|
||||
var suggestions []string
|
||||
for _, entry := range dirEntries {
|
||||
if strings.Contains(entry.Name(), base) || strings.Contains(base, entry.Name()) {
|
||||
if strings.Contains(strings.ToLower(entry.Name()), strings.ToLower(base)) ||
|
||||
strings.Contains(strings.ToLower(base), strings.ToLower(entry.Name())) {
|
||||
suggestions = append(suggestions, filepath.Join(dir, entry.Name()))
|
||||
if len(suggestions) >= 3 {
|
||||
break
|
||||
|
@ -89,43 +89,55 @@ func (b *viewTool) InvokableRun(ctx context.Context, args string, opts ...tool.O
|
|||
}
|
||||
|
||||
if len(suggestions) > 0 {
|
||||
return fmt.Sprintf("file not found: %s. Did you mean one of these?\n%s",
|
||||
params.FilePath, strings.Join(suggestions, "\n")), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("File not found: %s\n\nDid you mean one of these?\n%s",
|
||||
filePath, strings.Join(suggestions, "\n"))), nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("file not found: %s", params.FilePath), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("File not found: %s", filePath)), nil
|
||||
}
|
||||
return fmt.Sprintf("failed to access file: %s", err), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil
|
||||
}
|
||||
|
||||
// Check if it's a directory
|
||||
if fileInfo.IsDir() {
|
||||
return fmt.Sprintf("path is a directory, not a file: %s", params.FilePath), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if fileInfo.Size() > MaxReadSize {
|
||||
return fmt.Sprintf("file is too large (%d bytes). Maximum size is %d bytes",
|
||||
fileInfo.Size(), MaxReadSize), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("File is too large (%d bytes). Maximum size is %d bytes",
|
||||
fileInfo.Size(), MaxReadSize)), nil
|
||||
}
|
||||
|
||||
// Set default limit if not provided
|
||||
if params.Limit <= 0 {
|
||||
params.Limit = DefaultReadLimit
|
||||
}
|
||||
|
||||
isImage, _ := isImageFile(params.FilePath)
|
||||
// Check if it's an image file
|
||||
isImage, imageType := isImageFile(filePath)
|
||||
if isImage {
|
||||
// TODO: Implement image reading
|
||||
return "reading images is not supported", nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\nUse a different tool to process images", imageType)), nil
|
||||
}
|
||||
|
||||
content, _, err := readTextFile(params.FilePath, params.Offset, params.Limit)
|
||||
// Read the file content
|
||||
content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("failed to read file: %s", err), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil
|
||||
}
|
||||
|
||||
recordFileRead(params.FilePath)
|
||||
// Format the output with line numbers
|
||||
output := addLineNumbers(content, params.Offset+1)
|
||||
|
||||
return addLineNumbers(content, params.Offset+1), nil
|
||||
// Add a note if the content was truncated
|
||||
if lineCount > params.Offset+len(strings.Split(content, "\n")) {
|
||||
output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)",
|
||||
params.Offset+len(strings.Split(content, "\n")))
|
||||
}
|
||||
|
||||
recordFileRead(filePath)
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func addLineNumbers(content string, startLine int) string {
|
||||
|
@ -191,6 +203,11 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
|
|||
lines = append(lines, lineText)
|
||||
}
|
||||
|
||||
// Continue scanning to get total line count
|
||||
for scanner.Scan() {
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
@ -202,17 +219,17 @@ func isImageFile(filePath string) (bool, string) {
|
|||
ext := strings.ToLower(filepath.Ext(filePath))
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
return true, "jpeg"
|
||||
return true, "JPEG"
|
||||
case ".png":
|
||||
return true, "png"
|
||||
return true, "PNG"
|
||||
case ".gif":
|
||||
return true, "gif"
|
||||
return true, "GIF"
|
||||
case ".bmp":
|
||||
return true, "bmp"
|
||||
return true, "BMP"
|
||||
case ".svg":
|
||||
return true, "svg"
|
||||
return true, "SVG"
|
||||
case ".webp":
|
||||
return true, "webp"
|
||||
return true, "WebP"
|
||||
default:
|
||||
return false, ""
|
||||
}
|
||||
|
@ -240,8 +257,39 @@ func (s *LineScanner) Err() error {
|
|||
return s.scanner.Err()
|
||||
}
|
||||
|
||||
func NewViewTool(workingDir string) tool.InvokableTool {
|
||||
return &viewTool{
|
||||
workingDir,
|
||||
}
|
||||
func viewDescription() string {
|
||||
return `File viewing tool that reads and displays the contents of files with line numbers, allowing you to examine code, logs, or text data.
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to read the contents of a specific file
|
||||
- Helpful for examining source code, configuration files, or log files
|
||||
- Perfect for looking at text-based file formats
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the path to the file you want to view
|
||||
- Optionally specify an offset to start reading from a specific line
|
||||
- Optionally specify a limit to control how many lines are read
|
||||
|
||||
FEATURES:
|
||||
- Displays file contents with line numbers for easy reference
|
||||
- Can read from any position in a file using the offset parameter
|
||||
- Handles large files by limiting the number of lines read
|
||||
- Automatically truncates very long lines for better display
|
||||
- Suggests similar file names when the requested file isn't found
|
||||
|
||||
LIMITATIONS:
|
||||
- Maximum file size is 250KB
|
||||
- Default reading limit is 2000 lines
|
||||
- Lines longer than 2000 characters are truncated
|
||||
- Cannot display binary files or images
|
||||
- Images can be identified but not displayed
|
||||
|
||||
TIPS:
|
||||
- Use with Glob tool to first find files you want to view
|
||||
- For code exploration, first use Grep to find relevant files, then View to examine them
|
||||
- When viewing large files, use the offset parameter to read specific sections`
|
||||
}
|
||||
|
||||
func NewViewTool() BaseTool {
|
||||
return &viewTool{}
|
||||
}
|
||||
|
|
|
@ -6,17 +6,13 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
)
|
||||
|
||||
type writeTool struct {
|
||||
workingDir string
|
||||
}
|
||||
type writeTool struct{}
|
||||
|
||||
const (
|
||||
WriteToolName = "write"
|
||||
|
@ -27,139 +23,139 @@ type WriteParams struct {
|
|||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (b *writeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return &schema.ToolInfo{
|
||||
Name: WriteToolName,
|
||||
Desc: "Write a file to the local filesystem. Overwrites the existing file if there is one.\n\nBefore using this tool:\n\n1. Use the ReadFile tool to understand the file's contents and context\n\n2. Directory Verification (only applicable when creating new files):\n - Use the LS tool to verify the parent directory exists and is the correct location",
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
|
||||
"file_path": {
|
||||
Type: "string",
|
||||
Desc: "The absolute path to the file to write (must be absolute, not relative)",
|
||||
Required: true,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Desc: "The content to write to the file",
|
||||
Required: true,
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
type WritePermissionsParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (b *writeTool) InvokableRun(ctx context.Context, args string, opts ...tool.Option) (string, error) {
|
||||
func (w *writeTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: WriteToolName,
|
||||
Description: writeDescription(),
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file to write",
|
||||
},
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The content to write to the file",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "content"},
|
||||
}
|
||||
}
|
||||
|
||||
// Run implements Tool.
|
||||
func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params WriteParams
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
return "", fmt.Errorf("failed to parse parameters: %w", err)
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
if params.FilePath == "" {
|
||||
return "file_path is required", nil
|
||||
return NewTextErrorResponse("file_path is required"), nil
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(params.FilePath) {
|
||||
return fmt.Sprintf("file path must be absolute, got: %s", params.FilePath), nil
|
||||
if params.Content == "" {
|
||||
return NewTextErrorResponse("content is required"), nil
|
||||
}
|
||||
|
||||
// fileExists := false
|
||||
// oldContent := ""
|
||||
fileInfo, err := os.Stat(params.FilePath)
|
||||
// Handle relative paths
|
||||
filePath := params.FilePath
|
||||
if !filepath.IsAbs(filePath) {
|
||||
filePath = filepath.Join(config.WorkingDirectory(), filePath)
|
||||
}
|
||||
|
||||
// Check if file exists and is a directory
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err == nil {
|
||||
if fileInfo.IsDir() {
|
||||
return fmt.Sprintf("path is a directory, not a file: %s", params.FilePath), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
|
||||
}
|
||||
|
||||
// Check if file was modified since last read
|
||||
modTime := fileInfo.ModTime()
|
||||
lastRead := getLastReadTime(params.FilePath)
|
||||
lastRead := getLastReadTime(filePath)
|
||||
if modTime.After(lastRead) {
|
||||
return fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
|
||||
params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
|
||||
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
|
||||
}
|
||||
|
||||
// oldContentBytes, readErr := os.ReadFile(params.FilePath)
|
||||
// if readErr != nil {
|
||||
// oldContent = string(oldContentBytes)
|
||||
// }
|
||||
// Optional: Get old content for diff
|
||||
oldContent, readErr := os.ReadFile(filePath)
|
||||
if readErr == nil && string(oldContent) == params.Content {
|
||||
return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Sprintf("failed to access file: %s", err), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil
|
||||
}
|
||||
|
||||
// Create parent directories if needed
|
||||
dir := filepath.Dir(filePath)
|
||||
if err = os.MkdirAll(dir, 0o755); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil
|
||||
}
|
||||
p := permission.Default.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
Path: b.workingDir,
|
||||
Path: filePath,
|
||||
ToolName: WriteToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Write to file %s", params.FilePath),
|
||||
Params: map[string]interface{}{
|
||||
"file_path": params.FilePath,
|
||||
"contnet": params.Content,
|
||||
Action: "create",
|
||||
Description: fmt.Sprintf("Create file %s", filePath),
|
||||
Params: WritePermissionsParams{
|
||||
FilePath: filePath,
|
||||
Content: GenerateDiff("", params.Content),
|
||||
},
|
||||
},
|
||||
)
|
||||
if !p {
|
||||
return "", fmt.Errorf("permission denied")
|
||||
}
|
||||
dir := filepath.Dir(params.FilePath)
|
||||
if err = os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Sprintf("failed to create parent directories: %s", err), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("Permission denied to create file: %s", filePath)), nil
|
||||
}
|
||||
|
||||
err = os.WriteFile(params.FilePath, []byte(params.Content), 0o644)
|
||||
// Write the file
|
||||
err = os.WriteFile(filePath, []byte(params.Content), 0o644)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("failed to write file: %s", err), nil
|
||||
return NewTextErrorResponse(fmt.Sprintf("Failed to write file: %s", err)), nil
|
||||
}
|
||||
|
||||
recordFileWrite(params.FilePath)
|
||||
// Record the file write
|
||||
recordFileWrite(filePath)
|
||||
recordFileRead(filePath)
|
||||
|
||||
output := "File written: " + params.FilePath
|
||||
|
||||
// if fileExists && oldContent != params.Content {
|
||||
// output = generateSimpleDiff(oldContent, params.Content)
|
||||
// }
|
||||
|
||||
return output, nil
|
||||
return NewTextResponse(fmt.Sprintf("File successfully written: %s", filePath)), nil
|
||||
}
|
||||
|
||||
func generateSimpleDiff(oldContent, newContent string) string {
|
||||
if oldContent == newContent {
|
||||
return "[No changes]"
|
||||
}
|
||||
func writeDescription() string {
|
||||
return `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content.
|
||||
|
||||
oldLines := strings.Split(oldContent, "\n")
|
||||
newLines := strings.Split(newContent, "\n")
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to create a new file
|
||||
- Helpful for updating existing files with modified content
|
||||
- Perfect for saving generated code, configurations, or text data
|
||||
|
||||
var diffBuilder strings.Builder
|
||||
diffBuilder.WriteString(fmt.Sprintf("@@ -%d,+%d @@\n", len(oldLines), len(newLines)))
|
||||
HOW TO USE:
|
||||
- Provide the path to the file you want to write
|
||||
- Include the content to be written to the file
|
||||
- The tool will create any necessary parent directories
|
||||
|
||||
maxLines := max(len(oldLines), len(newLines))
|
||||
for i := range maxLines {
|
||||
oldLine := ""
|
||||
newLine := ""
|
||||
FEATURES:
|
||||
- Can create new files or overwrite existing ones
|
||||
- Creates parent directories automatically if they don't exist
|
||||
- Checks if the file has been modified since last read for safety
|
||||
- Avoids unnecessary writes when content hasn't changed
|
||||
|
||||
if i < len(oldLines) {
|
||||
oldLine = oldLines[i]
|
||||
}
|
||||
LIMITATIONS:
|
||||
- You should read a file before writing to it to avoid conflicts
|
||||
- Cannot append to files (rewrites the entire file)
|
||||
|
||||
if i < len(newLines) {
|
||||
newLine = newLines[i]
|
||||
}
|
||||
|
||||
if oldLine != newLine {
|
||||
if i < len(oldLines) {
|
||||
diffBuilder.WriteString(fmt.Sprintf("- %s\n", oldLine))
|
||||
}
|
||||
if i < len(newLines) {
|
||||
diffBuilder.WriteString(fmt.Sprintf("+ %s\n", newLine))
|
||||
}
|
||||
} else {
|
||||
diffBuilder.WriteString(fmt.Sprintf(" %s\n", oldLine))
|
||||
}
|
||||
}
|
||||
|
||||
return diffBuilder.String()
|
||||
TIPS:
|
||||
- Use the View tool first to examine existing files before modifying them
|
||||
- Use the LS tool to verify the correct location when creating new files
|
||||
- Combine with Glob and Grep tools to find and modify multiple files
|
||||
- Always include descriptive comments when making changes to existing code`
|
||||
}
|
||||
|
||||
func NewWriteTool(workingDir string) tool.InvokableTool {
|
||||
return &writeTool{
|
||||
workingDir: workingDir,
|
||||
}
|
||||
func NewWriteTool() BaseTool {
|
||||
return &writeTool{}
|
||||
}
|
||||
|
|
324
internal/llm/tools/write_test.go
Normal file
324
internal/llm/tools/write_test.go
Normal file
|
@ -0,0 +1,324 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteTool_Info(t *testing.T) {
|
||||
tool := NewWriteTool()
|
||||
info := tool.Info()
|
||||
|
||||
assert.Equal(t, WriteToolName, info.Name)
|
||||
assert.NotEmpty(t, info.Description)
|
||||
assert.Contains(t, info.Parameters, "file_path")
|
||||
assert.Contains(t, info.Parameters, "content")
|
||||
assert.Contains(t, info.Required, "file_path")
|
||||
assert.Contains(t, info.Required, "content")
|
||||
}
|
||||
|
||||
func TestWriteTool_Run(t *testing.T) {
|
||||
// Setup a mock permission handler that always allows
|
||||
origPermission := permission.Default
|
||||
defer func() {
|
||||
permission.Default = origPermission
|
||||
}()
|
||||
permission.Default = newMockPermissionService(true)
|
||||
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "write_tool_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
t.Run("creates a new file successfully", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
filePath := filepath.Join(tempDir, "new_file.txt")
|
||||
content := "This is a test content"
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "successfully written")
|
||||
|
||||
// Verify file was created with correct content
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, string(fileContent))
|
||||
})
|
||||
|
||||
t.Run("creates file with nested directories", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
|
||||
content := "Content in nested directory"
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "successfully written")
|
||||
|
||||
// Verify file was created with correct content
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, string(fileContent))
|
||||
})
|
||||
|
||||
t.Run("updates existing file", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
// Create a file first
|
||||
filePath := filepath.Join(tempDir, "existing_file.txt")
|
||||
initialContent := "Initial content"
|
||||
err := os.WriteFile(filePath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Record the file read to avoid modification time check failure
|
||||
recordFileRead(filePath)
|
||||
|
||||
// Update the file
|
||||
updatedContent := "Updated content"
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: updatedContent,
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "successfully written")
|
||||
|
||||
// Verify file was updated with correct content
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updatedContent, string(fileContent))
|
||||
})
|
||||
|
||||
t.Run("handles invalid parameters", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: "invalid json",
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "error parsing parameters")
|
||||
})
|
||||
|
||||
t.Run("handles missing file_path", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: "",
|
||||
Content: "Some content",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "file_path is required")
|
||||
})
|
||||
|
||||
t.Run("handles missing content", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: filepath.Join(tempDir, "file.txt"),
|
||||
Content: "",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "content is required")
|
||||
})
|
||||
|
||||
t.Run("handles writing to a directory path", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
// Create a directory
|
||||
dirPath := filepath.Join(tempDir, "test_dir")
|
||||
err := os.Mkdir(dirPath, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
params := WriteParams{
|
||||
FilePath: dirPath,
|
||||
Content: "Some content",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "Path is a directory")
|
||||
})
|
||||
|
||||
t.Run("handles permission denied", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(false)
|
||||
tool := NewWriteTool()
|
||||
|
||||
filePath := filepath.Join(tempDir, "permission_denied.txt")
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: "Content that should not be written",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "Permission denied")
|
||||
|
||||
// Verify file was not created
|
||||
_, err = os.Stat(filePath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
})
|
||||
|
||||
t.Run("detects file modified since last read", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
// Create a file
|
||||
filePath := filepath.Join(tempDir, "modified_file.txt")
|
||||
initialContent := "Initial content"
|
||||
err := os.WriteFile(filePath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Record an old read time
|
||||
fileRecordMutex.Lock()
|
||||
fileRecords[filePath] = fileRecord{
|
||||
path: filePath,
|
||||
readTime: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
fileRecordMutex.Unlock()
|
||||
|
||||
// Try to update the file
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: "Updated content",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "has been modified since it was last read")
|
||||
|
||||
// Verify file was not modified
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, initialContent, string(fileContent))
|
||||
})
|
||||
|
||||
t.Run("skips writing when content is identical", func(t *testing.T) {
|
||||
permission.Default = newMockPermissionService(true)
|
||||
tool := NewWriteTool()
|
||||
|
||||
// Create a file
|
||||
filePath := filepath.Join(tempDir, "identical_content.txt")
|
||||
content := "Content that won't change"
|
||||
err := os.WriteFile(filePath, []byte(content), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Record a read time
|
||||
recordFileRead(filePath)
|
||||
|
||||
// Try to write the same content
|
||||
params := WriteParams{
|
||||
FilePath: filePath,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: WriteToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "already contains the exact content")
|
||||
})
|
||||
}
|
|
@ -2,26 +2,65 @@ package message
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/kujtimiihoxha/termai/internal/db"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID string
|
||||
SessionID string
|
||||
MessageData schema.Message
|
||||
type MessageRole string
|
||||
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
const (
|
||||
Assistant MessageRole = "assistant"
|
||||
User MessageRole = "user"
|
||||
System MessageRole = "system"
|
||||
Tool MessageRole = "tool"
|
||||
)
|
||||
|
||||
type ToolResult struct {
|
||||
ToolCallID string
|
||||
Content string
|
||||
IsError bool
|
||||
// TODO: support for images
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string
|
||||
Name string
|
||||
Input string
|
||||
Type string
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID string
|
||||
SessionID string
|
||||
|
||||
// NEW
|
||||
Role MessageRole
|
||||
Content string
|
||||
Thinking string
|
||||
|
||||
Finished bool
|
||||
|
||||
ToolResults []ToolResult
|
||||
ToolCalls []ToolCall
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
}
|
||||
|
||||
type CreateMessageParams struct {
|
||||
Role MessageRole
|
||||
Content string
|
||||
ToolCalls []ToolCall
|
||||
ToolResults []ToolResult
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[Message]
|
||||
Create(sessionID string, messageData schema.Message) (Message, error)
|
||||
Create(sessionID string, params CreateMessageParams) (Message, error)
|
||||
Update(message Message) error
|
||||
Get(id string) (Message, error)
|
||||
List(sessionID string) ([]Message, error)
|
||||
Delete(id string) error
|
||||
|
@ -34,24 +73,6 @@ type service struct {
|
|||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *service) Create(sessionID string, messageData schema.Message) (Message, error) {
|
||||
messageDataJSON, err := json.Marshal(messageData)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
MessageData: string(messageDataJSON),
|
||||
})
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
message := s.fromDBItem(dbMessage)
|
||||
s.Publish(pubsub.CreatedEvent, message)
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(id string) error {
|
||||
message, err := s.Get(id)
|
||||
if err != nil {
|
||||
|
@ -65,6 +86,35 @@ func (s *service) Delete(id string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
|
||||
toolCallsStr, err := json.Marshal(params.ToolCalls)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
toolResultsStr, err := json.Marshal(params.ToolResults)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
Role: string(params.Role),
|
||||
Finished: params.Role != Assistant,
|
||||
Content: params.Content,
|
||||
ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
|
||||
ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
message, err := s.fromDBItem(dbMessage)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
s.Publish(pubsub.CreatedEvent, message)
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionMessages(sessionID string) error {
|
||||
messages, err := s.List(sessionID)
|
||||
if err != nil {
|
||||
|
@ -81,12 +131,36 @@ func (s *service) DeleteSessionMessages(sessionID string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Update(message Message) error {
|
||||
toolCallsStr, err := json.Marshal(message.ToolCalls)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
toolResultsStr, err := json.Marshal(message.ToolResults)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
|
||||
ID: message.ID,
|
||||
Content: message.Content,
|
||||
Thinking: message.Thinking,
|
||||
Finished: message.Finished,
|
||||
ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
|
||||
ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Publish(pubsub.UpdatedEvent, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Get(id string) (Message, error) {
|
||||
dbMessage, err := s.q.GetMessage(s.ctx, id)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
return s.fromDBItem(dbMessage), nil
|
||||
return s.fromDBItem(dbMessage)
|
||||
}
|
||||
|
||||
func (s *service) List(sessionID string) ([]Message, error) {
|
||||
|
@ -96,21 +170,43 @@ func (s *service) List(sessionID string) ([]Message, error) {
|
|||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMessage := range dbMessages {
|
||||
messages[i] = s.fromDBItem(dbMessage)
|
||||
messages[i], err = s.fromDBItem(dbMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Message) Message {
|
||||
var messageData schema.Message
|
||||
json.Unmarshal([]byte(item.MessageData), &messageData)
|
||||
func (s *service) fromDBItem(item db.Message) (Message, error) {
|
||||
toolCalls := make([]ToolCall, 0)
|
||||
if item.ToolCalls.Valid {
|
||||
err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
}
|
||||
|
||||
toolResults := make([]ToolResult, 0)
|
||||
if item.ToolResults.Valid {
|
||||
err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return Message{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID,
|
||||
MessageData: messageData,
|
||||
Role: MessageRole(item.Role),
|
||||
Content: item.Content,
|
||||
Thinking: item.Thinking,
|
||||
Finished: item.Finished,
|
||||
ToolCalls: toolCalls,
|
||||
ToolResults: toolResults,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, q db.Querier) Service {
|
||||
|
|
|
@ -65,6 +65,7 @@ func (s *permissionService) Deny(permission PermissionRequest) {
|
|||
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
permission := PermissionRequest{
|
||||
ID: uuid.New().String(),
|
||||
Path: opts.Path,
|
||||
ToolName: opts.ToolName,
|
||||
Description: opts.Description,
|
||||
Action: opts.Action,
|
||||
|
|
|
@ -2,6 +2,7 @@ package session
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/kujtimiihoxha/termai/internal/db"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
|
||||
type Session struct {
|
||||
ID string
|
||||
ParentSessionID string
|
||||
Title string
|
||||
MessageCount int64
|
||||
PromptTokens int64
|
||||
|
@ -22,6 +24,7 @@ type Session struct {
|
|||
type Service interface {
|
||||
pubsub.Suscriber[Session]
|
||||
Create(title string) (Session, error)
|
||||
CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error)
|
||||
Get(id string) (Session, error)
|
||||
List() ([]Session, error)
|
||||
Save(session Session) (Session, error)
|
||||
|
@ -47,6 +50,20 @@ func (s *service) Create(title string) (Session, error) {
|
|||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error) {
|
||||
dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{
|
||||
ID: toolCallID,
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
|
||||
Title: title,
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.CreatedEvent, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(id string) error {
|
||||
session, err := s.Get(id)
|
||||
if err != nil {
|
||||
|
@ -99,6 +116,7 @@ func (s *service) List() ([]Session, error) {
|
|||
func (s service) fromDBItem(item db.Session) Session {
|
||||
return Session{
|
||||
ID: item.ID,
|
||||
ParentSessionID: item.ParentSessionID.String,
|
||||
Title: item.Title,
|
||||
MessageCount: item.MessageCount,
|
||||
PromptTokens: item.PromptTokens,
|
||||
|
|
|
@ -14,7 +14,12 @@ type SizeableModel interface {
|
|||
}
|
||||
|
||||
type DialogMsg struct {
|
||||
Content SizeableModel
|
||||
Content SizeableModel
|
||||
WidthRatio float64
|
||||
HeightRatio float64
|
||||
|
||||
MinWidth int
|
||||
MinHeight int
|
||||
}
|
||||
|
||||
type DialogCloseMsg struct{}
|
||||
|
@ -36,7 +41,18 @@ type DialogCmp interface {
|
|||
}
|
||||
|
||||
type dialogCmp struct {
|
||||
content SizeableModel
|
||||
content SizeableModel
|
||||
screenWidth int
|
||||
screenHeight int
|
||||
|
||||
widthRatio float64
|
||||
heightRatio float64
|
||||
|
||||
minWidth int
|
||||
minHeight int
|
||||
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func (d *dialogCmp) Init() tea.Cmd {
|
||||
|
@ -45,8 +61,26 @@ func (d *dialogCmp) Init() tea.Cmd {
|
|||
|
||||
func (d *dialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
d.screenWidth = msg.Width
|
||||
d.screenHeight = msg.Height
|
||||
d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth)
|
||||
d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight)
|
||||
if d.content != nil {
|
||||
d.content.SetSize(d.width, d.height)
|
||||
}
|
||||
return d, nil
|
||||
case DialogMsg:
|
||||
d.content = msg.Content
|
||||
d.widthRatio = msg.WidthRatio
|
||||
d.heightRatio = msg.HeightRatio
|
||||
d.minWidth = msg.MinWidth
|
||||
d.minHeight = msg.MinHeight
|
||||
d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth)
|
||||
d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight)
|
||||
if d.content != nil {
|
||||
d.content.SetSize(d.width, d.height)
|
||||
}
|
||||
case DialogCloseMsg:
|
||||
d.content = nil
|
||||
return d, nil
|
||||
|
@ -75,8 +109,7 @@ func (d *dialogCmp) BindingKeys() []key.Binding {
|
|||
}
|
||||
|
||||
func (d *dialogCmp) View() string {
|
||||
w, h := d.content.GetSize()
|
||||
return lipgloss.NewStyle().Width(w).Height(h).Render(d.content.View())
|
||||
return lipgloss.NewStyle().Width(d.width).Height(d.height).Render(d.content.View())
|
||||
}
|
||||
|
||||
func NewDialogCmp() DialogCmp {
|
||||
|
|
|
@ -3,6 +3,8 @@ package core
|
|||
import (
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/kujtimiihoxha/termai/internal/config"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/models"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/styles"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/util"
|
||||
"github.com/kujtimiihoxha/termai/internal/version"
|
||||
|
@ -57,14 +59,19 @@ func (m statusCmp) View() string {
|
|||
Width(m.availableFooterMsgWidth()).
|
||||
Render(m.info)
|
||||
}
|
||||
|
||||
status += m.model()
|
||||
status += versionWidget
|
||||
return status
|
||||
}
|
||||
|
||||
func (m statusCmp) availableFooterMsgWidth() int {
|
||||
// -2 to accommodate padding
|
||||
return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget))
|
||||
return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget)-lipgloss.Width(m.model()))
|
||||
}
|
||||
|
||||
func (m statusCmp) model() string {
|
||||
model := models.SupportedModels[config.Get().Model.Coder]
|
||||
return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name)
|
||||
}
|
||||
|
||||
func NewStatusCmp() tea.Model {
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
package dialog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/tools"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/components/core"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/layout"
|
||||
|
@ -28,12 +33,6 @@ type PermissionResponseMsg struct {
|
|||
Action PermissionAction
|
||||
}
|
||||
|
||||
// Width and height constants for the dialog
|
||||
var (
|
||||
permissionWidth = 60
|
||||
permissionHeight = 10
|
||||
)
|
||||
|
||||
// PermissionDialog interface for permission dialog component
|
||||
type PermissionDialog interface {
|
||||
tea.Model
|
||||
|
@ -41,13 +40,28 @@ type PermissionDialog interface {
|
|||
layout.Bindings
|
||||
}
|
||||
|
||||
type keyMap struct {
|
||||
ChangeFocus key.Binding
|
||||
}
|
||||
|
||||
var keyMapValue = keyMap{
|
||||
ChangeFocus: key.NewBinding(
|
||||
key.WithKeys("tab"),
|
||||
key.WithHelp("tab", "change focus"),
|
||||
),
|
||||
}
|
||||
|
||||
// permissionDialogCmp is the implementation of PermissionDialog
|
||||
type permissionDialogCmp struct {
|
||||
form *huh.Form
|
||||
content string
|
||||
width int
|
||||
height int
|
||||
permission permission.PermissionRequest
|
||||
form *huh.Form
|
||||
width int
|
||||
height int
|
||||
permission permission.PermissionRequest
|
||||
windowSize tea.WindowSizeMsg
|
||||
r *glamour.TermRenderer
|
||||
contentViewPort viewport.Model
|
||||
isViewportFocus bool
|
||||
selectOption *huh.Select[string]
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) Init() tea.Cmd {
|
||||
|
@ -57,41 +71,101 @@ func (p *permissionDialogCmp) Init() tea.Cmd {
|
|||
func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
|
||||
// Process the form
|
||||
form, cmd := p.form.Update(msg)
|
||||
if f, ok := form.(*huh.Form); ok {
|
||||
p.form = f
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
p.windowSize = msg
|
||||
case tea.KeyMsg:
|
||||
if key.Matches(msg, keyMapValue.ChangeFocus) {
|
||||
p.isViewportFocus = !p.isViewportFocus
|
||||
if p.isViewportFocus {
|
||||
p.selectOption.Blur()
|
||||
} else {
|
||||
p.selectOption.Focus()
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
if p.isViewportFocus {
|
||||
viewPort, cmd := p.contentViewPort.Update(msg)
|
||||
p.contentViewPort = viewPort
|
||||
cmds = append(cmds, cmd)
|
||||
} else {
|
||||
form, cmd := p.form.Update(msg)
|
||||
if f, ok := form.(*huh.Form); ok {
|
||||
p.form = f
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
if p.form.State == huh.StateCompleted {
|
||||
// Get the selected action
|
||||
action := p.form.GetString("action")
|
||||
|
||||
// Close the dialog and return the response
|
||||
return p, tea.Batch(
|
||||
util.CmdHandler(core.DialogCloseMsg{}),
|
||||
util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if p.form.State == huh.StateCompleted {
|
||||
// Get the selected action
|
||||
action := p.form.GetString("action")
|
||||
|
||||
// Close the dialog and return the response
|
||||
return p, tea.Batch(
|
||||
util.CmdHandler(core.DialogCloseMsg{}),
|
||||
util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}),
|
||||
)
|
||||
}
|
||||
|
||||
return p, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) View() string {
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Width(p.width).
|
||||
Padding(1, 0).
|
||||
Foreground(styles.Text).
|
||||
Align(lipgloss.Center)
|
||||
func (p *permissionDialogCmp) render() string {
|
||||
form := p.form.View()
|
||||
keyStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Rosewater)
|
||||
valueStyle := lipgloss.NewStyle().Foreground(styles.Peach)
|
||||
|
||||
headerParts := []string{
|
||||
lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Tool:"), " ", valueStyle.Render(p.permission.ToolName)),
|
||||
" ",
|
||||
lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Path:"), " ", valueStyle.Render(p.permission.Path)),
|
||||
" ",
|
||||
}
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
|
||||
glamour.WithWordWrap(p.width-10),
|
||||
glamour.WithEmoji(),
|
||||
)
|
||||
content := ""
|
||||
switch p.permission.ToolName {
|
||||
case tools.BashToolName:
|
||||
pr := p.permission.Params.(tools.BashPermissionsParams)
|
||||
headerParts = append(headerParts, keyStyle.Render("Command:"))
|
||||
content, _ = r.Render(fmt.Sprintf("```bash\n%s\n```", pr.Command))
|
||||
case tools.EditToolName:
|
||||
pr := p.permission.Params.(tools.EditPermissionsParams)
|
||||
headerParts = append(headerParts, keyStyle.Render("Update:"))
|
||||
content, _ = r.Render(fmt.Sprintf("```diff\n%s\n```", pr.Diff))
|
||||
case tools.WriteToolName:
|
||||
pr := p.permission.Params.(tools.WritePermissionsParams)
|
||||
headerParts = append(headerParts, keyStyle.Render("Content:"))
|
||||
content, _ = r.Render(fmt.Sprintf("```diff\n%s\n```", pr.Content))
|
||||
default:
|
||||
content, _ = r.Render(p.permission.Description)
|
||||
}
|
||||
headerContent := lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...))
|
||||
p.contentViewPort.Width = p.width - 2 - 2
|
||||
p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1
|
||||
p.contentViewPort.SetContent(content)
|
||||
contentBorder := lipgloss.RoundedBorder()
|
||||
if p.isViewportFocus {
|
||||
contentBorder = lipgloss.DoubleBorder()
|
||||
}
|
||||
cotentStyle := lipgloss.NewStyle().MarginTop(1).Padding(0, 1).Border(contentBorder).BorderForeground(styles.Flamingo)
|
||||
|
||||
return lipgloss.JoinVertical(
|
||||
lipgloss.Center,
|
||||
contentStyle.Render(p.content),
|
||||
p.form.View(),
|
||||
lipgloss.Top,
|
||||
headerContent,
|
||||
cotentStyle.Render(p.contentViewPort.View()),
|
||||
form,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) View() string {
|
||||
return p.render()
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) GetSize() (int, int) {
|
||||
return p.width, p.height
|
||||
}
|
||||
|
@ -99,13 +173,14 @@ func (p *permissionDialogCmp) GetSize() (int, int) {
|
|||
func (p *permissionDialogCmp) SetSize(width int, height int) {
|
||||
p.width = width
|
||||
p.height = height
|
||||
p.form = p.form.WithWidth(width)
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) BindingKeys() []key.Binding {
|
||||
return p.form.KeyBinds()
|
||||
}
|
||||
|
||||
func newPermissionDialogCmp(permission permission.PermissionRequest, content string) PermissionDialog {
|
||||
func newPermissionDialogCmp(permission permission.PermissionRequest) PermissionDialog {
|
||||
// Create a note field for displaying the content
|
||||
|
||||
// Create select field for the permission options
|
||||
|
@ -116,14 +191,13 @@ func newPermissionDialogCmp(permission permission.PermissionRequest, content str
|
|||
huh.NewOption("Allow for this session", string(PermissionAllowForSession)),
|
||||
huh.NewOption("Deny", string(PermissionDeny)),
|
||||
).
|
||||
Title("Permission Request")
|
||||
Title("Select an action")
|
||||
|
||||
// Apply theme
|
||||
theme := styles.HuhTheme()
|
||||
|
||||
// Setup form width and height
|
||||
form := huh.NewForm(huh.NewGroup(selectOption)).
|
||||
WithWidth(permissionWidth - 2).
|
||||
WithShowHelp(false).
|
||||
WithTheme(theme).
|
||||
WithShowErrors(false)
|
||||
|
@ -132,25 +206,22 @@ func newPermissionDialogCmp(permission permission.PermissionRequest, content str
|
|||
selectOption.Focus()
|
||||
|
||||
return &permissionDialogCmp{
|
||||
permission: permission,
|
||||
form: form,
|
||||
content: content,
|
||||
width: permissionWidth,
|
||||
height: permissionHeight,
|
||||
permission: permission,
|
||||
form: form,
|
||||
selectOption: selectOption,
|
||||
}
|
||||
}
|
||||
|
||||
// NewPermissionDialogCmd creates a new permission dialog command
|
||||
func NewPermissionDialogCmd(permission permission.PermissionRequest, content string) tea.Cmd {
|
||||
permDialog := newPermissionDialogCmp(permission, content)
|
||||
func NewPermissionDialogCmd(permission permission.PermissionRequest) tea.Cmd {
|
||||
permDialog := newPermissionDialogCmp(permission)
|
||||
|
||||
// Create the dialog layout
|
||||
dialogPane := layout.NewSinglePane(
|
||||
permDialog.(*permissionDialogCmp),
|
||||
layout.WithSignlePaneSize(permissionWidth+2, permissionHeight+2),
|
||||
layout.WithSinglePaneBordered(true),
|
||||
layout.WithSinglePaneFocusable(true),
|
||||
layout.WithSinglePaneActiveColor(styles.Blue),
|
||||
layout.WithSinglePaneActiveColor(styles.Warning),
|
||||
layout.WithSignlePaneBorderText(map[layout.BorderPosition]string{
|
||||
layout.TopMiddleBorder: " Permission Required ",
|
||||
}),
|
||||
|
@ -158,10 +229,24 @@ func NewPermissionDialogCmd(permission permission.PermissionRequest, content str
|
|||
|
||||
// Focus the dialog
|
||||
dialogPane.Focus()
|
||||
widthRatio := 0.7
|
||||
heightRatio := 0.6
|
||||
minWidth := 100
|
||||
minHeight := 30
|
||||
|
||||
switch permission.ToolName {
|
||||
case tools.BashToolName:
|
||||
widthRatio = 0.5
|
||||
heightRatio = 0.3
|
||||
minWidth = 80
|
||||
minHeight = 20
|
||||
}
|
||||
// Return the dialog command
|
||||
return util.CmdHandler(core.DialogMsg{
|
||||
Content: dialogPane,
|
||||
Content: dialogPane,
|
||||
WidthRatio: widthRatio,
|
||||
HeightRatio: heightRatio,
|
||||
MinWidth: minWidth,
|
||||
MinHeight: minHeight,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package dialog
|
|||
import (
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/components/core"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/layout"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/styles"
|
||||
|
@ -14,11 +13,6 @@ import (
|
|||
|
||||
const question = "Are you sure you want to quit?"
|
||||
|
||||
var (
|
||||
width = lipgloss.Width(question) + 6
|
||||
height = 3
|
||||
)
|
||||
|
||||
type QuitDialog interface {
|
||||
tea.Model
|
||||
layout.Sizeable
|
||||
|
@ -37,8 +31,6 @@ func (q *quitDialogCmp) Init() tea.Cmd {
|
|||
|
||||
func (q *quitDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
|
||||
// Process the form
|
||||
form, cmd := q.form.Update(msg)
|
||||
if f, ok := form.(*huh.Form); ok {
|
||||
q.form = f
|
||||
|
@ -67,6 +59,7 @@ func (q *quitDialogCmp) GetSize() (int, int) {
|
|||
func (q *quitDialogCmp) SetSize(width int, height int) {
|
||||
q.width = width
|
||||
q.height = height
|
||||
q.form = q.form.WithWidth(width).WithHeight(height)
|
||||
}
|
||||
|
||||
func (q *quitDialogCmp) BindingKeys() []key.Binding {
|
||||
|
@ -84,28 +77,30 @@ func newQuitDialogCmp() QuitDialog {
|
|||
theme.Focused.FocusedButton = theme.Focused.FocusedButton.Background(styles.Warning)
|
||||
theme.Blurred.FocusedButton = theme.Blurred.FocusedButton.Background(styles.Warning)
|
||||
form := huh.NewForm(huh.NewGroup(confirm)).
|
||||
WithWidth(width).
|
||||
WithHeight(height).
|
||||
WithShowHelp(false).
|
||||
WithWidth(0).
|
||||
WithHeight(0).
|
||||
WithTheme(theme).
|
||||
WithShowErrors(false)
|
||||
confirm.Focus()
|
||||
return &quitDialogCmp{
|
||||
form: form,
|
||||
width: width,
|
||||
form: form,
|
||||
}
|
||||
}
|
||||
|
||||
func NewQuitDialogCmd() tea.Cmd {
|
||||
content := layout.NewSinglePane(
|
||||
newQuitDialogCmp().(*quitDialogCmp),
|
||||
layout.WithSignlePaneSize(width+2, height+2),
|
||||
layout.WithSinglePaneBordered(true),
|
||||
layout.WithSinglePaneFocusable(true),
|
||||
layout.WithSinglePaneActiveColor(styles.Warning),
|
||||
)
|
||||
content.Focus()
|
||||
return util.CmdHandler(core.DialogMsg{
|
||||
Content: content,
|
||||
Content: content,
|
||||
WidthRatio: 0.2,
|
||||
HeightRatio: 0.1,
|
||||
MinWidth: 40,
|
||||
MinHeight: 5,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
package messages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/layout"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/styles"
|
||||
)
|
||||
|
||||
const (
|
||||
maxHeight = 10
|
||||
)
|
||||
|
||||
type MessagesCmp interface {
|
||||
tea.Model
|
||||
layout.Focusable
|
||||
layout.Bordered
|
||||
layout.Sizeable
|
||||
}
|
||||
|
||||
type messageCmp struct {
|
||||
message message.Message
|
||||
width int
|
||||
height int
|
||||
focused bool
|
||||
expanded bool
|
||||
}
|
||||
|
||||
func (m *messageCmp) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *messageCmp) Update(tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *messageCmp) View() string {
|
||||
wrapper := layout.NewSinglePane(
|
||||
m,
|
||||
layout.WithSinglePaneBordered(true),
|
||||
layout.WithSinglePaneFocusable(true),
|
||||
layout.WithSinglePanePadding(1),
|
||||
layout.WithSinglePaneActiveColor(m.borderColor()),
|
||||
)
|
||||
if m.focused {
|
||||
wrapper.Focus()
|
||||
}
|
||||
wrapper.SetSize(m.width, m.height)
|
||||
return wrapper.View()
|
||||
}
|
||||
|
||||
func (m *messageCmp) Blur() tea.Cmd {
|
||||
m.focused = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *messageCmp) borderColor() lipgloss.TerminalColor {
|
||||
switch m.message.MessageData.Role {
|
||||
case schema.Assistant:
|
||||
return styles.Mauve
|
||||
case schema.User:
|
||||
return styles.Flamingo
|
||||
}
|
||||
return styles.Blue
|
||||
}
|
||||
|
||||
func (m *messageCmp) BorderText() map[layout.BorderPosition]string {
|
||||
role := ""
|
||||
icon := ""
|
||||
switch m.message.MessageData.Role {
|
||||
case schema.Assistant:
|
||||
role = "Assistant"
|
||||
icon = styles.BotIcon
|
||||
case schema.User:
|
||||
role = "User"
|
||||
icon = styles.UserIcon
|
||||
}
|
||||
return map[layout.BorderPosition]string{
|
||||
layout.TopLeftBorder: fmt.Sprintf("%s %s ", role, icon),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *messageCmp) Focus() tea.Cmd {
|
||||
m.focused = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *messageCmp) IsFocused() bool {
|
||||
return m.focused
|
||||
}
|
||||
|
||||
func (m *messageCmp) GetSize() (int, int) {
|
||||
return m.width, 0
|
||||
}
|
||||
|
||||
func (m *messageCmp) SetSize(width int, height int) {
|
||||
m.width = width
|
||||
}
|
||||
|
||||
func NewMessageCmp(msg message.Message) MessagesCmp {
|
||||
return &messageCmp{
|
||||
message: msg,
|
||||
}
|
||||
}
|
|
@ -6,10 +6,11 @@ import (
|
|||
"github.com/charmbracelet/bubbles/key"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/layout"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/styles"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/util"
|
||||
"github.com/kujtimiihoxha/vimtea"
|
||||
)
|
||||
|
||||
|
@ -112,7 +113,7 @@ func (m *editorCmp) BorderText() map[layout.BorderPosition]string {
|
|||
title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
|
||||
}
|
||||
return map[layout.BorderPosition]string{
|
||||
layout.TopLeftBorder: title,
|
||||
layout.BottomLeftBorder: title,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -137,9 +138,15 @@ func (m *editorCmp) SetSize(width int, height int) {
|
|||
|
||||
func (m *editorCmp) Send() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
messages, _ := m.app.Messages.List(m.sessionID)
|
||||
if hasUnfinishedMessages(messages) {
|
||||
return util.InfoMsg("Assistant is still working on the previous message")
|
||||
}
|
||||
a, _ := agent.NewCoderAgent(m.app)
|
||||
|
||||
content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
|
||||
m.app.Messages.Create(m.sessionID, *schema.UserMessage(content))
|
||||
m.app.LLM.SendRequest(m.sessionID, content)
|
||||
go a.Generate(m.sessionID, content)
|
||||
|
||||
return m.editor.Reset()
|
||||
}
|
||||
}
|
||||
|
@ -153,10 +160,11 @@ func (m *editorCmp) BindingKeys() []key.Binding {
|
|||
}
|
||||
|
||||
func NewEditorCmp(app *app.App) EditorCmp {
|
||||
editor := vimtea.NewEditor(
|
||||
vimtea.WithFileName("message.md"),
|
||||
)
|
||||
return &editorCmp{
|
||||
app: app,
|
||||
editor: vimtea.NewEditor(
|
||||
vimtea.WithFileName("message.md"),
|
||||
),
|
||||
app: app,
|
||||
editor: editor,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package repl
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
|
@ -10,8 +11,8 @@ import (
|
|||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm/agent"
|
||||
"github.com/kujtimiihoxha/termai/internal/message"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
"github.com/kujtimiihoxha/termai/internal/session"
|
||||
|
@ -28,30 +29,50 @@ type MessagesCmp interface {
|
|||
}
|
||||
|
||||
type messagesCmp struct {
|
||||
app *app.App
|
||||
messages []message.Message
|
||||
session session.Session
|
||||
viewport viewport.Model
|
||||
mdRenderer *glamour.TermRenderer
|
||||
width int
|
||||
height int
|
||||
focused bool
|
||||
cachedView string
|
||||
app *app.App
|
||||
messages []message.Message
|
||||
selectedMsgIdx int // Index of the selected message
|
||||
session session.Session
|
||||
viewport viewport.Model
|
||||
mdRenderer *glamour.TermRenderer
|
||||
width int
|
||||
height int
|
||||
focused bool
|
||||
cachedView string
|
||||
}
|
||||
|
||||
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case pubsub.Event[message.Message]:
|
||||
if msg.Type == pubsub.CreatedEvent {
|
||||
m.messages = append(m.messages, msg.Payload)
|
||||
m.renderView()
|
||||
m.viewport.GotoBottom()
|
||||
if msg.Payload.SessionID == m.session.ID {
|
||||
m.messages = append(m.messages, msg.Payload)
|
||||
m.renderView()
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
for _, v := range m.messages {
|
||||
for _, c := range v.ToolCalls {
|
||||
if c.ID == msg.Payload.SessionID {
|
||||
m.renderView()
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
|
||||
for i, v := range m.messages {
|
||||
if v.ID == msg.Payload.ID {
|
||||
m.messages[i] = msg.Payload
|
||||
m.renderView()
|
||||
if i == len(m.messages)-1 {
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
case pubsub.Event[session.Session]:
|
||||
if msg.Type == pubsub.UpdatedEvent {
|
||||
if m.session.ID == msg.Payload.ID {
|
||||
m.session = msg.Payload
|
||||
}
|
||||
if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
|
||||
m.session = msg.Payload
|
||||
}
|
||||
case SelectedSessionMsg:
|
||||
m.session, _ = m.app.Sessions.Get(msg.SessionID)
|
||||
|
@ -67,26 +88,24 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
func borderColor(role schema.RoleType) lipgloss.TerminalColor {
|
||||
func borderColor(role message.MessageRole) lipgloss.TerminalColor {
|
||||
switch role {
|
||||
case schema.Assistant:
|
||||
case message.Assistant:
|
||||
return styles.Mauve
|
||||
case schema.User:
|
||||
case message.User:
|
||||
return styles.Rosewater
|
||||
case schema.Tool:
|
||||
return styles.Peach
|
||||
}
|
||||
return styles.Blue
|
||||
}
|
||||
|
||||
func borderText(msgRole schema.RoleType, currentMessage int) map[layout.BorderPosition]string {
|
||||
func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
|
||||
role := ""
|
||||
icon := ""
|
||||
switch msgRole {
|
||||
case schema.Assistant:
|
||||
case message.Assistant:
|
||||
role = "Assistant"
|
||||
icon = styles.BotIcon
|
||||
case schema.User:
|
||||
case message.User:
|
||||
role = "User"
|
||||
icon = styles.UserIcon
|
||||
}
|
||||
|
@ -106,81 +125,259 @@ func borderText(msgRole schema.RoleType, currentMessage int) map[layout.BorderPo
|
|||
}
|
||||
}
|
||||
|
||||
func hasUnfinishedMessages(messages []message.Message) bool {
|
||||
if len(messages) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
if !msg.Finished {
|
||||
return true
|
||||
}
|
||||
}
|
||||
lastMessage := messages[len(messages)-1]
|
||||
return lastMessage.Role != message.Assistant
|
||||
}
|
||||
|
||||
func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
|
||||
allParts := []string{content}
|
||||
|
||||
leftPaddingValue := 4
|
||||
connectorStyle := lipgloss.NewStyle().
|
||||
Foreground(styles.Peach).
|
||||
Bold(true)
|
||||
|
||||
toolCallStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(styles.Peach).
|
||||
Width(m.width-leftPaddingValue-5).
|
||||
Padding(0, 1)
|
||||
|
||||
toolResultStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(styles.Green).
|
||||
Width(m.width-leftPaddingValue-5).
|
||||
Padding(0, 1)
|
||||
|
||||
leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
|
||||
|
||||
runningStyle := lipgloss.NewStyle().
|
||||
Foreground(styles.Peach).
|
||||
Bold(true)
|
||||
|
||||
renderTool := func(toolCall message.ToolCall) string {
|
||||
toolHeader := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(styles.Blue).
|
||||
Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
|
||||
|
||||
var paramLines []string
|
||||
var args map[string]interface{}
|
||||
var paramOrder []string
|
||||
|
||||
json.Unmarshal([]byte(toolCall.Input), &args)
|
||||
|
||||
for key := range args {
|
||||
paramOrder = append(paramOrder, key)
|
||||
}
|
||||
sort.Strings(paramOrder)
|
||||
|
||||
for _, name := range paramOrder {
|
||||
value := args[name]
|
||||
paramName := lipgloss.NewStyle().
|
||||
Foreground(styles.Peach).
|
||||
Bold(true).
|
||||
Render(name)
|
||||
|
||||
truncate := m.width - leftPaddingValue*2 - 10
|
||||
if len(fmt.Sprintf("%v", value)) > truncate {
|
||||
value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
|
||||
}
|
||||
paramValue := fmt.Sprintf("%v", value)
|
||||
paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue))
|
||||
}
|
||||
|
||||
paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
|
||||
|
||||
toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
|
||||
return toolCallStyle.Render(toolContent)
|
||||
}
|
||||
|
||||
findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
|
||||
for _, msg := range messages {
|
||||
if msg.Role == message.Tool {
|
||||
for _, result := range msg.ToolResults {
|
||||
if result.ToolCallID == toolCallID {
|
||||
return &result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
renderToolResult := func(result message.ToolResult) string {
|
||||
resultHeader := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(styles.Green).
|
||||
Render(fmt.Sprintf("%s Result", styles.CheckIcon))
|
||||
if result.IsError {
|
||||
resultHeader = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(styles.Red).
|
||||
Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
|
||||
}
|
||||
|
||||
truncate := 200
|
||||
content := result.Content
|
||||
if len(content) > truncate {
|
||||
content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
|
||||
}
|
||||
|
||||
resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content)
|
||||
return toolResultStyle.Render(resultContent)
|
||||
}
|
||||
|
||||
connector := connectorStyle.Render("└─> Tool Calls:")
|
||||
allParts = append(allParts, connector)
|
||||
|
||||
for _, toolCall := range tools {
|
||||
toolOutput := renderTool(toolCall)
|
||||
allParts = append(allParts, leftPadding.Render(toolOutput))
|
||||
|
||||
result := findToolResult(toolCall.ID, futureMessages)
|
||||
if result != nil {
|
||||
|
||||
resultOutput := renderToolResult(*result)
|
||||
allParts = append(allParts, leftPadding.Render(resultOutput))
|
||||
|
||||
} else if toolCall.Name == agent.AgentToolName {
|
||||
|
||||
runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
|
||||
allParts = append(allParts, leftPadding.Render(runningIndicator))
|
||||
taskSessionMessages, _ := m.app.Messages.List(toolCall.ID)
|
||||
for _, msg := range taskSessionMessages {
|
||||
if msg.Role == message.Assistant {
|
||||
for _, toolCall := range msg.ToolCalls {
|
||||
toolHeader := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(styles.Blue).
|
||||
Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
|
||||
|
||||
var paramLines []string
|
||||
var args map[string]interface{}
|
||||
var paramOrder []string
|
||||
|
||||
json.Unmarshal([]byte(toolCall.Input), &args)
|
||||
|
||||
for key := range args {
|
||||
paramOrder = append(paramOrder, key)
|
||||
}
|
||||
sort.Strings(paramOrder)
|
||||
|
||||
for _, name := range paramOrder {
|
||||
value := args[name]
|
||||
paramName := lipgloss.NewStyle().
|
||||
Foreground(styles.Peach).
|
||||
Bold(true).
|
||||
Render(name)
|
||||
|
||||
truncate := 50
|
||||
if len(fmt.Sprintf("%v", value)) > truncate {
|
||||
value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
|
||||
}
|
||||
paramValue := fmt.Sprintf("%v", value)
|
||||
paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue))
|
||||
}
|
||||
|
||||
paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
|
||||
toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
|
||||
toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent)
|
||||
allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
|
||||
allParts = append(allParts, " "+runningIndicator)
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range futureMessages {
|
||||
if msg.Content != "" {
|
||||
break
|
||||
}
|
||||
|
||||
for _, toolCall := range msg.ToolCalls {
|
||||
toolOutput := renderTool(toolCall)
|
||||
allParts = append(allParts, " "+strings.ReplaceAll(toolOutput, "\n", "\n "))
|
||||
|
||||
result := findToolResult(toolCall.ID, futureMessages)
|
||||
if result != nil {
|
||||
resultOutput := renderToolResult(*result)
|
||||
allParts = append(allParts, " "+strings.ReplaceAll(resultOutput, "\n", "\n "))
|
||||
} else {
|
||||
runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
|
||||
allParts = append(allParts, " "+runningIndicator)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lipgloss.JoinVertical(lipgloss.Left, allParts...)
|
||||
}
|
||||
|
||||
func (m *messagesCmp) renderView() {
|
||||
stringMessages := make([]string, 0)
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
|
||||
glamour.WithWordWrap(m.width-10),
|
||||
glamour.WithWordWrap(m.width-20),
|
||||
glamour.WithEmoji(),
|
||||
)
|
||||
textStyle := lipgloss.NewStyle().Width(m.width - 4)
|
||||
currentMessage := 1
|
||||
for _, msg := range m.messages {
|
||||
if msg.MessageData.Role == schema.Tool {
|
||||
continue
|
||||
}
|
||||
content := msg.MessageData.Content
|
||||
if content != "" {
|
||||
content, _ = r.Render(msg.MessageData.Content)
|
||||
stringMessages = append(stringMessages, layout.Borderize(
|
||||
textStyle.Render(content),
|
||||
layout.BorderOptions{
|
||||
InactiveBorder: lipgloss.DoubleBorder(),
|
||||
ActiveBorder: lipgloss.DoubleBorder(),
|
||||
ActiveColor: borderColor(msg.MessageData.Role),
|
||||
InactiveColor: borderColor(msg.MessageData.Role),
|
||||
EmbeddedText: borderText(msg.MessageData.Role, currentMessage),
|
||||
},
|
||||
))
|
||||
currentMessage++
|
||||
}
|
||||
for _, toolCall := range msg.MessageData.ToolCalls {
|
||||
resultInx := slices.IndexFunc(m.messages, func(m message.Message) bool {
|
||||
return m.MessageData.ToolCallID == toolCall.ID
|
||||
})
|
||||
content := fmt.Sprintf("**Arguments**\n```json\n%s\n```\n", toolCall.Function.Arguments)
|
||||
if resultInx == -1 {
|
||||
content += "Running..."
|
||||
} else {
|
||||
result := m.messages[resultInx].MessageData.Content
|
||||
if result != "" {
|
||||
lines := strings.Split(result, "\n")
|
||||
if len(lines) > 15 {
|
||||
result = strings.Join(lines[:15], "\n")
|
||||
}
|
||||
content += fmt.Sprintf("**Result**\n```\n%s\n```\n", result)
|
||||
if len(lines) > 15 {
|
||||
content += fmt.Sprintf("\n\n *...%d lines are truncated* ", len(lines)-15)
|
||||
}
|
||||
}
|
||||
displayedMsgCount := 0 // Track the actual displayed messages count
|
||||
|
||||
prevMessageWasUser := false
|
||||
for inx, msg := range m.messages {
|
||||
content := msg.Content
|
||||
if content != "" || prevMessageWasUser {
|
||||
if msg.Thinking != "" && content == "" {
|
||||
content = msg.Thinking
|
||||
} else if content == "" {
|
||||
content = "..."
|
||||
}
|
||||
content, _ = r.Render(content)
|
||||
stringMessages = append(stringMessages, layout.Borderize(
|
||||
|
||||
isSelected := inx == m.selectedMsgIdx
|
||||
|
||||
border := lipgloss.DoubleBorder()
|
||||
activeColor := borderColor(msg.Role)
|
||||
|
||||
if isSelected {
|
||||
activeColor = styles.Primary // Use primary color for selected message
|
||||
}
|
||||
|
||||
content = layout.Borderize(
|
||||
textStyle.Render(content),
|
||||
layout.BorderOptions{
|
||||
InactiveBorder: lipgloss.DoubleBorder(),
|
||||
ActiveBorder: lipgloss.DoubleBorder(),
|
||||
ActiveColor: borderColor(schema.Tool),
|
||||
InactiveColor: borderColor(schema.Tool),
|
||||
EmbeddedText: map[layout.BorderPosition]string{
|
||||
layout.TopLeftBorder: lipgloss.NewStyle().
|
||||
Padding(0, 1).
|
||||
Bold(true).
|
||||
Foreground(styles.Crust).
|
||||
Background(borderColor(schema.Tool)).
|
||||
Render(
|
||||
fmt.Sprintf("Tool [%s] %s ", toolCall.Function.Name, styles.ToolIcon),
|
||||
),
|
||||
layout.TopRightBorder: lipgloss.NewStyle().
|
||||
Padding(0, 1).
|
||||
Bold(true).
|
||||
Foreground(styles.Crust).
|
||||
Background(borderColor(schema.Tool)).
|
||||
Render(fmt.Sprintf("#%d ", currentMessage)),
|
||||
},
|
||||
InactiveBorder: border,
|
||||
ActiveBorder: border,
|
||||
ActiveColor: activeColor,
|
||||
InactiveColor: borderColor(msg.Role),
|
||||
EmbeddedText: borderText(msg.Role, currentMessage),
|
||||
},
|
||||
))
|
||||
)
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
content = m.renderMessageWithToolCall(content, msg.ToolCalls, m.messages[inx+1:])
|
||||
}
|
||||
stringMessages = append(stringMessages, content)
|
||||
currentMessage++
|
||||
displayedMsgCount++
|
||||
}
|
||||
if msg.Role == message.User && msg.Content != "" {
|
||||
prevMessageWasUser = true
|
||||
} else {
|
||||
prevMessageWasUser = false
|
||||
}
|
||||
}
|
||||
m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
|
||||
|
@ -191,7 +388,9 @@ func (m *messagesCmp) View() string {
|
|||
}
|
||||
|
||||
func (m *messagesCmp) BindingKeys() []key.Binding {
|
||||
return layout.KeyMapToSlice(m.viewport.KeyMap)
|
||||
keys := layout.KeyMapToSlice(m.viewport.KeyMap)
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (m *messagesCmp) Blur() tea.Cmd {
|
||||
|
@ -208,10 +407,17 @@ func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
|
|||
if m.focused {
|
||||
title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
|
||||
}
|
||||
return map[layout.BorderPosition]string{
|
||||
borderTest := map[layout.BorderPosition]string{
|
||||
layout.TopLeftBorder: title,
|
||||
layout.BottomRightBorder: formatTokensAndCost(m.session.CompletionTokens+m.session.PromptTokens, m.session.Cost),
|
||||
}
|
||||
if hasUnfinishedMessages(m.messages) {
|
||||
borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
|
||||
} else {
|
||||
borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
|
||||
}
|
||||
|
||||
return borderTest
|
||||
}
|
||||
|
||||
func (m *messagesCmp) Focus() tea.Cmd {
|
||||
|
@ -232,6 +438,7 @@ func (m *messagesCmp) SetSize(width int, height int) {
|
|||
m.height = height
|
||||
m.viewport.Width = width - 2 // padding
|
||||
m.viewport.Height = height - 2 // padding
|
||||
m.renderView()
|
||||
}
|
||||
|
||||
func (m *messagesCmp) Init() tea.Cmd {
|
||||
|
|
|
@ -89,7 +89,23 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
}
|
||||
return i, i.list.SetItems(items)
|
||||
case pubsub.Event[session.Session]:
|
||||
if msg.Type == pubsub.UpdatedEvent {
|
||||
if msg.Type == pubsub.CreatedEvent && msg.Payload.ParentSessionID == "" {
|
||||
// Check if the session is already in the list
|
||||
items := i.list.Items()
|
||||
for _, item := range items {
|
||||
s := item.(listItem)
|
||||
if s.id == msg.Payload.ID {
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
// insert the new session at the top of the list
|
||||
items = append([]list.Item{listItem{
|
||||
id: msg.Payload.ID,
|
||||
title: msg.Payload.Title,
|
||||
desc: formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost),
|
||||
}}, items...)
|
||||
return i, i.list.SetItems(items)
|
||||
} else if msg.Type == pubsub.UpdatedEvent {
|
||||
// update the session in the list
|
||||
items := i.list.Items()
|
||||
for idx, item := range items {
|
||||
|
@ -229,3 +245,4 @@ func NewSessionsCmp(app *app.App) SessionsCmp {
|
|||
focused: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -78,12 +78,12 @@ func (i *initPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
// Save configuration to file
|
||||
configPath := filepath.Join(os.Getenv("HOME"), ".termai.yaml")
|
||||
maxTokens, _ := strconv.Atoi(i.maxTokens)
|
||||
config := map[string]interface{}{
|
||||
config := map[string]any{
|
||||
"models": map[string]string{
|
||||
"big": i.bigModel,
|
||||
"small": i.smallModel,
|
||||
},
|
||||
"providers": map[string]interface{}{
|
||||
"providers": map[string]any{
|
||||
"openai": map[string]string{
|
||||
"key": i.openAIKey,
|
||||
},
|
||||
|
@ -192,8 +192,8 @@ func NewInitPage() tea.Model {
|
|||
// Init page with form
|
||||
initModel := &initPage{
|
||||
modelOpts: modelOpts,
|
||||
bigModel: string(models.DefaultBigModel),
|
||||
smallModel: string(models.DefaultLittleModel),
|
||||
bigModel: string(models.Claude37Sonnet),
|
||||
smallModel: string(models.Claude37Sonnet),
|
||||
maxTokens: "4000",
|
||||
dataDir: ".termai",
|
||||
agent: "coder",
|
||||
|
|
|
@ -8,5 +8,9 @@ const (
|
|||
ToolIcon string = ""
|
||||
UserIcon string = ""
|
||||
|
||||
CheckIcon string = "✓"
|
||||
ErrorIcon string = "✗"
|
||||
SpinnerIcon string = "..."
|
||||
|
||||
SleepIcon string = ""
|
||||
)
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const defaultMargin = 2
|
||||
const defaultMargin = 1
|
||||
|
||||
// Helper functions for style pointers
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
|
@ -25,7 +25,7 @@ var catppuccinDark = ansi.StyleConfig{
|
|||
Document: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
BlockPrefix: "\n",
|
||||
BlockSuffix: "\n",
|
||||
BlockSuffix: "",
|
||||
Color: stringPtr(dark.Text().Hex),
|
||||
},
|
||||
Margin: uintPtr(defaultMargin),
|
||||
|
@ -153,7 +153,7 @@ var catppuccinDark = ansi.StyleConfig{
|
|||
CodeBlock: ansi.StyleCodeBlock{
|
||||
StyleBlock: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: " ",
|
||||
Prefix: " ",
|
||||
Color: stringPtr(dark.Text().Hex),
|
||||
},
|
||||
|
||||
|
|
|
@ -20,8 +20,7 @@ var (
|
|||
DoubleBorder = Regular.Border(lipgloss.DoubleBorder())
|
||||
|
||||
// Colors
|
||||
White = lipgloss.Color("#ffffff")
|
||||
|
||||
White = lipgloss.Color("#ffffff")
|
||||
Surface0 = lipgloss.AdaptiveColor{
|
||||
Dark: dark.Surface0().Hex,
|
||||
Light: light.Surface0().Hex,
|
||||
|
|
|
@ -1,20 +1,15 @@
|
|||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/kujtimiihoxha/termai/internal/app"
|
||||
"github.com/kujtimiihoxha/termai/internal/llm"
|
||||
"github.com/kujtimiihoxha/termai/internal/permission"
|
||||
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/components/core"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/components/dialog"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/components/repl"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/layout"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/page"
|
||||
"github.com/kujtimiihoxha/termai/internal/tui/util"
|
||||
|
@ -52,9 +47,9 @@ var keys = keyMap{
|
|||
),
|
||||
}
|
||||
|
||||
var editorKeyMap = key.NewBinding(
|
||||
key.WithKeys("i"),
|
||||
key.WithHelp("i", "insert mode"),
|
||||
var replKeyMap = key.NewBinding(
|
||||
key.WithKeys("N"),
|
||||
key.WithHelp("N", "new session"),
|
||||
)
|
||||
|
||||
type appModel struct {
|
||||
|
@ -66,6 +61,7 @@ type appModel struct {
|
|||
status tea.Model
|
||||
help core.HelpCmp
|
||||
dialog core.DialogCmp
|
||||
app *app.App
|
||||
dialogVisible bool
|
||||
editorMode vimtea.EditorMode
|
||||
showHelp bool
|
||||
|
@ -79,19 +75,8 @@ func (a appModel) Init() tea.Cmd {
|
|||
|
||||
func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case pubsub.Event[llm.AgentEvent]:
|
||||
log.Println("AgentEvent")
|
||||
log.Println(msg)
|
||||
case pubsub.Event[permission.PermissionRequest]:
|
||||
return a, dialog.NewPermissionDialogCmd(
|
||||
msg.Payload,
|
||||
fmt.Sprintf(
|
||||
"Tool: %s\nAction: %s\nParams: %v",
|
||||
msg.Payload.ToolName,
|
||||
msg.Payload.Action,
|
||||
msg.Payload.Params,
|
||||
),
|
||||
)
|
||||
return a, dialog.NewPermissionDialogCmd(msg.Payload)
|
||||
case dialog.PermissionResponseMsg:
|
||||
switch msg.Action {
|
||||
case dialog.PermissionAllow:
|
||||
|
@ -104,6 +89,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
case vimtea.EditorModeMsg:
|
||||
a.editorMode = msg.Mode
|
||||
case tea.WindowSizeMsg:
|
||||
var cmds []tea.Cmd
|
||||
msg.Height -= 1 // Make space for the status bar
|
||||
a.width, a.height = msg.Width, msg.Height
|
||||
|
||||
|
@ -113,8 +99,14 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
a.help = uh.(core.HelpCmp)
|
||||
|
||||
p, cmd := a.pages[a.currentPage].Update(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
a.pages[a.currentPage] = p
|
||||
return a, cmd
|
||||
|
||||
d, cmd := a.dialog.Update(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
a.dialog = d.(core.DialogCmp)
|
||||
|
||||
return a, tea.Batch(cmds...)
|
||||
case core.DialogMsg:
|
||||
d, cmd := a.dialog.Update(msg)
|
||||
a.dialog = d.(core.DialogCmp)
|
||||
|
@ -145,6 +137,22 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
a.ToggleHelp()
|
||||
return a, nil
|
||||
}
|
||||
case key.Matches(msg, replKeyMap):
|
||||
if a.currentPage == page.ReplPage {
|
||||
sessions, err := a.app.Sessions.List()
|
||||
if err != nil {
|
||||
return a, util.CmdHandler(util.ErrorMsg(err))
|
||||
}
|
||||
lastSession := sessions[0]
|
||||
if lastSession.MessageCount == 0 {
|
||||
return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID})
|
||||
}
|
||||
s, err := a.app.Sessions.Create("New Session")
|
||||
if err != nil {
|
||||
return a, util.CmdHandler(util.ErrorMsg(err))
|
||||
}
|
||||
return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: s.ID})
|
||||
}
|
||||
case key.Matches(msg, keys.Logs):
|
||||
return a, a.moveToPage(page.LogsPage)
|
||||
case key.Matches(msg, keys.Help):
|
||||
|
@ -205,6 +213,9 @@ func (a appModel) View() string {
|
|||
if a.dialogVisible {
|
||||
bindings = append(bindings, a.dialog.BindingKeys()...)
|
||||
}
|
||||
if a.currentPage == page.ReplPage {
|
||||
bindings = append(bindings, replKeyMap)
|
||||
}
|
||||
a.help.SetBindings(bindings)
|
||||
components = append(components, a.help.View())
|
||||
}
|
||||
|
@ -231,14 +242,13 @@ func (a appModel) View() string {
|
|||
}
|
||||
|
||||
func New(app *app.App) tea.Model {
|
||||
// Check if config file exists, if not, start with init page
|
||||
homedir, _ := os.UserHomeDir()
|
||||
configPath := filepath.Join(homedir, ".termai.yaml")
|
||||
|
||||
// homedir, _ := os.UserHomeDir()
|
||||
// configPath := filepath.Join(homedir, ".termai.yaml")
|
||||
//
|
||||
startPage := page.ReplPage
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
startPage = page.InitPage
|
||||
}
|
||||
// if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
// startPage = page.InitPage
|
||||
// }
|
||||
|
||||
return &appModel{
|
||||
currentPage: startPage,
|
||||
|
@ -246,6 +256,7 @@ func New(app *app.App) tea.Model {
|
|||
status: core.NewStatusCmp(),
|
||||
help: core.NewHelpCmp(),
|
||||
dialog: core.NewDialogCmp(),
|
||||
app: app,
|
||||
pages: map[page.PageID]tea.Model{
|
||||
page.LogsPage: page.NewLogsPage(),
|
||||
page.InitPage: page.NewInitPage(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue