initial tool call stream

This commit is contained in:
Kujtim Hoxha 2025-04-19 16:35:45 +02:00
parent 2b5a33e476
commit 2de5127417
11 changed files with 261 additions and 136 deletions

View file

@ -380,6 +380,21 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
case provider.EventContentDelta:
assistantMsg.AppendContent(event.Content)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventToolUseStart:
assistantMsg.AddToolCall(*event.ToolCall)
return a.messages.Update(ctx, *assistantMsg)
// TODO: see how to handle this
// case provider.EventToolUseDelta:
// tm := time.Unix(assistantMsg.UpdatedAt, 0)
// assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
// if time.Since(tm) > 1000*time.Millisecond {
// err := a.messages.Update(ctx, *assistantMsg)
// assistantMsg.UpdatedAt = time.Now().Unix()
// return err
// }
case provider.EventToolUseStop:
assistantMsg.FinishToolCall(event.ToolCall.ID)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventError:
if errors.Is(event.Error, context.Canceled) {
logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
@ -456,6 +471,13 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
provider.WithReasoningEffort(agentConfig.ReasoningEffort),
),
)
} else if model.Provider == models.ProviderAnthropic && model.CanReason {
opts = append(
opts,
provider.WithAnthropicOptions(
provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
),
)
}
agentProvider, err := provider.NewProvider(
model.Provider,

View file

@ -93,8 +93,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
}
if len(blocks) == 0 {
logging.Warn("There is a message without content, investigate")
// This should never happend but we log this because we might have a bug in our cleanup method
logging.Warn("There is a message without content, investigate, this should not happen")
continue
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
@ -196,8 +195,8 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(preparedMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
// jsonData, _ := json.Marshal(preparedMessages)
// logging.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
for {
@ -243,8 +242,8 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
jsonData, _ := json.Marshal(preparedMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
// jsonData, _ := json.Marshal(preparedMessages)
// logging.Debug("Prepared messages", "messages", string(jsonData))
}
attempts := 0
eventChan := make(chan ProviderEvent)
@ -257,6 +256,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
)
accumulatedMessage := anthropic.Message{}
currentToolCallID := ""
for anthropicStream.Next() {
event := anthropicStream.Current()
err := accumulatedMessage.Accumulate(event)
@ -267,7 +267,19 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
switch event := event.AsAny().(type) {
case anthropic.ContentBlockStartEvent:
eventChan <- ProviderEvent{Type: EventContentStart}
if event.ContentBlock.Type == "text" {
eventChan <- ProviderEvent{Type: EventContentStart}
} else if event.ContentBlock.Type == "tool_use" {
currentToolCallID = event.ContentBlock.ID
eventChan <- ProviderEvent{
Type: EventToolUseStart,
ToolCall: &message.ToolCall{
ID: event.ContentBlock.ID,
Name: event.ContentBlock.Name,
Finished: false,
},
}
}
case anthropic.ContentBlockDeltaEvent:
if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
@ -280,11 +292,30 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
Type: EventContentDelta,
Content: event.Delta.Text,
}
} else if event.Delta.Type == "input_json_delta" {
if currentToolCallID != "" {
eventChan <- ProviderEvent{
Type: EventToolUseDelta,
ToolCall: &message.ToolCall{
ID: currentToolCallID,
Finished: false,
Input: event.Delta.JSON.PartialJSON.Raw(),
},
}
}
}
// TODO: check if we can somehow stream tool calls
case anthropic.ContentBlockStopEvent:
eventChan <- ProviderEvent{Type: EventContentStop}
if currentToolCallID != "" {
eventChan <- ProviderEvent{
Type: EventToolUseStop,
ToolCall: &message.ToolCall{
ID: currentToolCallID,
},
}
currentToolCallID = ""
} else {
eventChan <- ProviderEvent{Type: EventContentStop}
}
case anthropic.MessageStopEvent:
content := ""
@ -378,10 +409,11 @@ func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
toolCall := message.ToolCall{
ID: variant.ID,
Name: variant.Name,
Input: string(variant.Input),
Type: string(variant.Type),
ID: variant.ID,
Name: variant.Name,
Input: string(variant.Input),
Type: string(variant.Type),
Finished: true,
}
toolCalls = append(toolCalls, toolCall)
}

View file

@ -344,10 +344,11 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too
if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
for _, call := range completion.Choices[0].Message.ToolCalls {
toolCall := message.ToolCall{
ID: call.ID,
Name: call.Function.Name,
Input: call.Function.Arguments,
Type: "function",
ID: call.ID,
Name: call.Function.Name,
Input: call.Function.Arguments,
Type: "function",
Finished: true,
}
toolCalls = append(toolCalls, toolCall)
}

View file

@ -15,6 +15,9 @@ const maxRetries = 8
const (
EventContentStart EventType = "content_start"
EventToolUseStart EventType = "tool_use_start"
EventToolUseDelta EventType = "tool_use_delta"
EventToolUseStop EventType = "tool_use_stop"
EventContentDelta EventType = "content_delta"
EventThinkingDelta EventType = "thinking_delta"
EventContentStop EventType = "content_stop"
@ -43,8 +46,8 @@ type ProviderEvent struct {
Content string
Thinking string
Response *ProviderResponse
Error error
ToolCall *message.ToolCall
Error error
}
type Provider interface {
SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

View file

@ -233,6 +233,40 @@ func (m *Message) AppendReasoningContent(delta string) {
}
}
func (m *Message) FinishToolCall(toolCallID string) {
for i, part := range m.Parts {
if c, ok := part.(ToolCall); ok {
if c.ID == toolCallID {
m.Parts[i] = ToolCall{
ID: c.ID,
Name: c.Name,
Input: c.Input,
Type: c.Type,
Finished: true,
}
return
}
}
}
}
func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
for i, part := range m.Parts {
if c, ok := part.(ToolCall); ok {
if c.ID == toolCallID {
m.Parts[i] = ToolCall{
ID: c.ID,
Name: c.Name,
Input: c.Input + inputDelta,
Type: c.Type,
Finished: c.Finished,
}
return
}
}
}
}
func (m *Message) AddToolCall(tc ToolCall) {
for i, part := range m.Parts {
if c, ok := part.(ToolCall); ok {
@ -246,6 +280,15 @@ func (m *Message) AddToolCall(tc ToolCall) {
}
func (m *Message) SetToolCalls(tc []ToolCall) {
// remove any existing tool call part it could have multiple
parts := make([]ContentPart, 0)
for _, part := range m.Parts {
if _, ok := part.(ToolCall); ok {
continue
}
parts = append(parts, part)
}
m.Parts = parts
for _, toolCall := range tc {
m.Parts = append(m.Parts, toolCall)
}

View file

@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/kujtimiihoxha/opencode/internal/db"
@ -116,6 +117,7 @@ func (s *service) Update(ctx context.Context, message Message) error {
if err != nil {
return err
}
message.UpdatedAt = time.Now().Unix()
s.Publish(pubsub.UpdatedEvent, message)
return nil
}

View file

@ -7,13 +7,6 @@ import (
const bufferSize = 1024
type Logger interface {
Debug(msg string, args ...any)
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
}
// Broker allows clients to publish events and subscribe to events
type Broker[T any] struct {
subs map[chan Event[T]]struct{} // subscriptions

View file

@ -4,8 +4,6 @@ import (
"context"
"fmt"
"math"
"sync"
"time"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/spinner"
@ -13,7 +11,6 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/opencode/internal/app"
"github.com/kujtimiihoxha/opencode/internal/logging"
"github.com/kujtimiihoxha/opencode/internal/message"
"github.com/kujtimiihoxha/opencode/internal/pubsub"
"github.com/kujtimiihoxha/opencode/internal/session"
@ -35,89 +32,14 @@ type messagesCmp struct {
messages []message.Message
uiMessages []uiMessage
currentMsgID string
mutex sync.Mutex
cachedContent map[string]cacheItem
spinner spinner.Model
lastUpdate time.Time
rendering bool
}
type renderFinishedMsg struct{}
func (m *messagesCmp) Init() tea.Cmd {
return tea.Batch(m.viewport.Init())
}
func (m *messagesCmp) preloadSessions() tea.Cmd {
return func() tea.Msg {
m.mutex.Lock()
defer m.mutex.Unlock()
sessions, err := m.app.Sessions.List(context.Background())
if err != nil {
return util.ReportError(err)()
}
if len(sessions) == 0 {
return nil
}
if len(sessions) > 20 {
sessions = sessions[:20]
}
for _, s := range sessions {
messages, err := m.app.Messages.List(context.Background(), s.ID)
if err != nil {
return util.ReportError(err)()
}
if len(messages) == 0 {
continue
}
m.cacheSessionMessages(messages, m.width)
}
logging.Debug("preloaded sessions")
return func() tea.Msg {
return renderFinishedMsg{}
}
}
}
func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) {
pos := 0
if m.width == 0 {
return
}
for inx, msg := range messages {
switch msg.Role {
case message.User:
userMsg := renderUserMessage(
msg,
false,
width,
pos,
)
m.cachedContent[msg.ID] = cacheItem{
width: width,
content: []uiMessage{userMsg},
}
pos += userMsg.height + 1 // + 1 for spacing
case message.Assistant:
assistantMessages := renderAssistantMessage(
msg,
inx,
messages,
m.app.Messages,
"",
width,
pos,
)
for _, msg := range assistantMessages {
pos += msg.height + 1 // + 1 for spacing
}
m.cachedContent[msg.ID] = cacheItem{
width: width,
content: assistantMessages,
}
}
}
return tea.Batch(m.viewport.Init(), m.spinner.Tick)
}
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@ -360,21 +282,35 @@ func hasToolsWithoutResponse(messages []message.Message) bool {
break
}
}
if !found {
if !found && v.Finished {
return true
}
}
return false
}
func hasUnfinishedToolCalls(messages []message.Message) bool {
toolCalls := make([]message.ToolCall, 0)
for _, m := range messages {
toolCalls = append(toolCalls, m.ToolCalls()...)
}
for _, v := range toolCalls {
if !v.Finished {
return true
}
}
return false
}
func (m *messagesCmp) working() string {
text := ""
if m.IsAgentWorking() {
if m.IsAgentWorking() && len(m.messages) > 0 {
task := "Thinking..."
lastMessage := m.messages[len(m.messages)-1]
if hasToolsWithoutResponse(m.messages) {
task = "Waiting for tool response..."
} else if hasUnfinishedToolCalls(m.messages) {
task = "Building tool call..."
} else if !lastMessage.IsFinished() {
task = "Generating..."
}
@ -434,8 +370,7 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd {
delete(m.cachedContent, msg.ID)
}
m.uiMessages = make([]uiMessage, 0)
m.renderView()
return m.preloadSessions()
return nil
}
func (m *messagesCmp) GetSize() (int, int) {
@ -446,16 +381,16 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
if m.session.ID == session.ID {
return nil
}
m.session = session
messages, err := m.app.Messages.List(context.Background(), session.ID)
if err != nil {
return util.ReportError(err)
}
m.messages = messages
m.currentMsgID = m.messages[len(m.messages)-1].ID
delete(m.cachedContent, m.currentMsgID)
m.rendering = true
return func() tea.Msg {
m.session = session
messages, err := m.app.Messages.List(context.Background(), session.ID)
if err != nil {
return util.ReportError(err)
}
m.messages = messages
m.currentMsgID = m.messages[len(m.messages)-1].ID
delete(m.cachedContent, m.currentMsgID)
m.renderView()
return renderFinishedMsg{}
}

View file

@ -113,18 +113,10 @@ func renderAssistantMessage(
width int,
position int,
) []uiMessage {
// find the user message that is before this assistant message
var userMsg message.Message
for i := msgIndex - 1; i >= 0; i-- {
msg := allMessages[i]
if msg.Role == message.User {
userMsg = allMessages[i]
break
}
}
messages := []uiMessage{}
content := msg.Content().String()
thinking := msg.IsThinking()
thinkingContent := msg.ReasoningContent().Thinking
finished := msg.IsFinished()
finishData := msg.FinishPart()
info := []string{}
@ -133,7 +125,7 @@ func renderAssistantMessage(
if finished {
switch finishData.Reason {
case message.FinishReasonEndTurn:
took := formatTimeDifference(userMsg.CreatedAt, finishData.Time)
took := formatTimeDifference(msg.CreatedAt, finishData.Time)
info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
))
@ -166,6 +158,9 @@ func renderAssistantMessage(
})
position += messages[0].height
position++ // for the space
} else if thinking && thinkingContent != "" {
// Render the thinking content
content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width)
}
for i, toolCall := range msg.ToolCalls() {
@ -218,10 +213,40 @@ func toolName(name string) string {
return "View"
case tools.WriteToolName:
return "Write"
case tools.PatchToolName:
return "Patch"
}
return name
}
func getToolAction(name string) string {
switch name {
case agent.AgentToolName:
return "Preparing prompt..."
case tools.BashToolName:
return "Building command..."
case tools.EditToolName:
return "Preparing edit..."
case tools.FetchToolName:
return "Writing fetch..."
case tools.GlobToolName:
return "Finding files..."
case tools.GrepToolName:
return "Searching content..."
case tools.LSToolName:
return "Listing directory..."
case tools.SourcegraphToolName:
return "Searching code..."
case tools.ViewToolName:
return "Reading file..."
case tools.WriteToolName:
return "Preparing write..."
case tools.PatchToolName:
return "Preparing patch..."
}
return "Working..."
}
// renders params, params[0] (params[1]=params[2] ....)
func renderParams(paramsWidth int, params ...string) string {
if len(params) == 0 {
@ -490,8 +515,47 @@ func renderToolMessage(
if nested {
width = width - 3
}
style := styles.BaseStyle.
Width(width - 1).
BorderLeft(true).
BorderStyle(lipgloss.ThickBorder()).
PaddingLeft(1).
BorderForeground(styles.ForgroundDim)
response := findToolResponse(toolCall.ID, allMessages)
toolName := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%s: ", toolName(toolCall.Name)))
if !toolCall.Finished {
// Get a brief description of what the tool is doing
toolAction := getToolAction(toolCall.Name)
// toolInput := strings.ReplaceAll(toolCall.Input, "\n", " ")
// truncatedInput := toolInput
// if len(truncatedInput) > 10 {
// truncatedInput = truncatedInput[len(truncatedInput)-10:]
// }
//
// truncatedInput = styles.BaseStyle.
// Italic(true).
// Width(width - 2 - lipgloss.Width(toolName)).
// Background(styles.BackgroundDim).
// Foreground(styles.ForgroundMid).
// Render(truncatedInput)
progressText := styles.BaseStyle.
Width(width - 2 - lipgloss.Width(toolName)).
Foreground(styles.ForgroundDim).
Render(fmt.Sprintf("%s", toolAction))
content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolName, progressText))
toolMsg := uiMessage{
messageType: toolMessageType,
position: position,
height: lipgloss.Height(content),
content: content,
}
return toolMsg
}
params := renderToolParams(width-2-lipgloss.Width(toolName), toolCall)
responseContent := ""
if response != nil {
@ -504,12 +568,6 @@ func renderToolMessage(
Foreground(styles.ForgroundDim).
Render("Waiting for response...")
}
style := styles.BaseStyle.
Width(width - 1).
BorderLeft(true).
BorderStyle(lipgloss.ThickBorder()).
PaddingLeft(1).
BorderForeground(styles.ForgroundDim)
parts := []string{}
if !nested {

View file

@ -14,6 +14,10 @@ type SplitPaneLayout interface {
SetLeftPanel(panel Container) tea.Cmd
SetRightPanel(panel Container) tea.Cmd
SetBottomPanel(panel Container) tea.Cmd
ClearLeftPanel() tea.Cmd
ClearRightPanel() tea.Cmd
ClearBottomPanel() tea.Cmd
}
type splitPaneLayout struct {
@ -192,6 +196,30 @@ func (s *splitPaneLayout) SetBottomPanel(panel Container) tea.Cmd {
return nil
}
func (s *splitPaneLayout) ClearLeftPanel() tea.Cmd {
s.leftPanel = nil
if s.width > 0 && s.height > 0 {
return s.SetSize(s.width, s.height)
}
return nil
}
func (s *splitPaneLayout) ClearRightPanel() tea.Cmd {
s.rightPanel = nil
if s.width > 0 && s.height > 0 {
return s.SetSize(s.width, s.height)
}
return nil
}
func (s *splitPaneLayout) ClearBottomPanel() tea.Cmd {
s.bottomPanel = nil
if s.width > 0 && s.height > 0 {
return s.SetSize(s.width, s.height)
}
return nil
}
func (s *splitPaneLayout) BindingKeys() []key.Binding {
keys := []key.Binding{}
if s.leftPanel != nil {

View file

@ -57,6 +57,14 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if cmd != nil {
return p, cmd
}
case chat.SessionSelectedMsg:
if p.session.ID == "" {
cmd := p.setSidebar()
if cmd != nil {
cmds = append(cmds, cmd)
}
}
p.session = msg
case chat.EditorFocusMsg:
p.editingMode = bool(msg)
case tea.KeyMsg:
@ -91,7 +99,7 @@ func (p *chatPage) setSidebar() tea.Cmd {
}
func (p *chatPage) clearSidebar() tea.Cmd {
return p.layout.SetRightPanel(nil)
return p.layout.ClearRightPanel()
}
func (p *chatPage) sendMessage(text string) tea.Cmd {