opencode/internal/message/message.go
2025-05-13 13:08:43 -05:00

503 lines
14 KiB
Go

package message
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/sst/opencode/internal/db"
"github.com/sst/opencode/internal/llm/models"
"github.com/sst/opencode/internal/pubsub"
)
// Message is the domain representation of a chat message stored in the
// database. It is produced from a db.Message row by service.fromDBItem.
type Message struct {
// ID is the message's unique identifier (Create assigns a random UUID string).
ID string
// Role identifies who authored the message (e.g. User).
Role MessageRole
// SessionID is the session this message belongs to.
SessionID string
// Parts are the ordered content parts; persisted as JSON via marshallParts.
Parts []ContentPart
// Model is the model that produced the message; empty is stored as NULL.
Model models.ModelID
// CreatedAt and UpdatedAt are parsed from the row's timestamp columns;
// fromDBItem substitutes time.Now() if a value cannot be parsed.
CreatedAt time.Time
UpdatedAt time.Time
}
// Event types published on the message broker.
const (
// EventMessageCreated is published by Create after a successful insert.
EventMessageCreated pubsub.EventType = "message_created"
// EventMessageUpdated is published by Update with the re-read message.
EventMessageUpdated pubsub.EventType = "message_updated"
// EventMessageDeleted is published by Delete and DeleteSessionMessages,
// carrying the message as it was before deletion.
EventMessageDeleted pubsub.EventType = "message_deleted"
)
// CreateMessageParams holds the caller-supplied fields for a new message.
// When Role is User and Parts contains no Finish part, Create appends one
// with FinishReasonEndTurn.
type CreateMessageParams struct {
Role MessageRole
Parts []ContentPart
// Model may be empty, in which case the model column is stored as NULL.
Model models.ModelID
}
// Service is the message persistence API. Mutating operations publish
// events (EventMessageCreated/Updated/Deleted) that subscribers receive
// through the embedded pubsub.Subscriber (presumably its Subscribe method).
type Service interface {
pubsub.Subscriber[Message]
// Create inserts a new message into the given session.
Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
// Update rewrites the parts (and finish time) of an existing message.
Update(ctx context.Context, message Message) (Message, error)
// Get loads a single message by ID.
Get(ctx context.Context, id string) (Message, error)
// List returns all messages in a session.
List(ctx context.Context, sessionID string) ([]Message, error)
// ListAfter returns a session's messages created after timestamp.
ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error)
// Delete removes a single message by ID.
Delete(ctx context.Context, id string) error
// DeleteSessionMessages removes every message in a session.
DeleteSessionMessages(ctx context.Context, sessionID string) error
}
// service is the database-backed implementation of Service.
type service struct {
// db holds the generated query methods used for all persistence.
db *db.Queries
// broker publishes message events to subscribers.
broker *pubsub.Broker[Message]
// mu is write-locked by the mutating methods and read-locked by
// Get, List and ListAfter.
mu sync.RWMutex
}
// globalMessageService is the package-level singleton installed by
// InitService and returned by GetService.
var globalMessageService *service

// InitService builds the package-level message service on top of dbConn.
// It returns an error if the service has already been initialized.
//
// NOTE(review): the nil check and the assignment are not synchronized; this
// assumes InitService is called once at startup from a single goroutine.
func InitService(dbConn *sql.DB) error {
	if globalMessageService == nil {
		globalMessageService = &service{
			db:     db.New(dbConn),
			broker: pubsub.NewBroker[Message](),
		}
		return nil
	}
	return fmt.Errorf("message service already initialized")
}
// GetService returns the singleton installed by InitService.
// It panics if InitService has not been called successfully.
func GetService() Service {
	svc := globalMessageService
	if svc == nil {
		panic("message service not initialized. Call message.InitService() first.")
	}
	return svc
}
// Create inserts a new message into sessionID and publishes
// EventMessageCreated with the stored result.
//
// User messages that carry no Finish part get one appended
// (FinishReasonEndTurn, stamped with the current time). The ID is a fresh
// random UUID, and an empty Model is stored as NULL.
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	isFinished := false
	for _, p := range params.Parts {
		if _, ok := p.(Finish); ok {
			isFinished = true
			break
		}
	}
	if params.Role == User && !isFinished {
		// Clip capacity before appending. Otherwise append could write the
		// Finish part into spare capacity of the caller's backing array and
		// clobber data visible through another slice sharing that array.
		n := len(params.Parts)
		params.Parts = append(params.Parts[:n:n], Finish{Reason: FinishReasonEndTurn, Time: time.Now()})
	}

	partsJSON, err := marshallParts(params.Parts)
	if err != nil {
		return Message{}, fmt.Errorf("failed to marshal message parts: %w", err)
	}

	dbMsgParams := db.CreateMessageParams{
		ID:        uuid.New().String(),
		SessionID: sessionID,
		Role:      string(params.Role),
		Parts:     string(partsJSON),
		Model:     sql.NullString{String: string(params.Model), Valid: params.Model != ""},
	}
	dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams)
	if err != nil {
		return Message{}, fmt.Errorf("db.CreateMessage: %w", err)
	}

	message, err := s.fromDBItem(dbMessage)
	if err != nil {
		return Message{}, fmt.Errorf("failed to convert DB message: %w", err)
	}
	s.broker.Publish(EventMessageCreated, message)
	return message, nil
}
// Update persists message.Parts for an existing message, re-reads the row,
// and publishes EventMessageUpdated with the refreshed value.
//
// If the message has a Finish part with a non-zero time, that time is also
// stored (UTC, RFC 3339 with nanoseconds) as finished_at. An empty ID is
// rejected before touching the database.
func (s *service) Update(ctx context.Context, message Message) (Message, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if message.ID == "" {
		return Message{}, fmt.Errorf("cannot update message with empty ID")
	}

	encoded, err := marshallParts(message.Parts)
	if err != nil {
		return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err)
	}

	params := db.UpdateMessageParams{
		ID:    message.ID,
		Parts: string(encoded),
	}
	if fp := message.FinishPart(); fp != nil && !fp.Time.IsZero() {
		params.FinishedAt = sql.NullString{
			String: fp.Time.UTC().Format(time.RFC3339Nano),
			Valid:  true,
		}
	}

	// UpdatedAt is handled by the DB trigger (strftime('%s', 'now')).
	// NOTE(review): fromDBItem parses updated_at as RFC 3339; confirm the
	// trigger's actual output format matches.
	if err := s.db.UpdateMessage(ctx, params); err != nil {
		return Message{}, fmt.Errorf("db.UpdateMessage: %w", err)
	}

	stored, err := s.db.GetMessage(ctx, message.ID)
	if err != nil {
		return Message{}, fmt.Errorf("failed to fetch message after update: %w", err)
	}
	refreshed, err := s.fromDBItem(stored)
	if err != nil {
		return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err)
	}

	s.broker.Publish(EventMessageUpdated, refreshed)
	return refreshed, nil
}
// Get loads a single message by ID under the read lock.
// A missing row yields a "message with ID '...' not found" error; any other
// database failure is wrapped with the db.GetMessage prefix.
func (s *service) Get(ctx context.Context, id string) (Message, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	row, err := s.db.GetMessage(ctx, id)
	switch {
	case err == sql.ErrNoRows:
		return Message{}, fmt.Errorf("message with ID '%s' not found", id)
	case err != nil:
		return Message{}, fmt.Errorf("db.GetMessage: %w", err)
	}
	return s.fromDBItem(row)
}
// List returns every message in sessionID, in the order returned by
// ListMessagesBySession. The first row that fails to convert aborts the
// whole call; the error reports that row's index.
func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	rows, err := s.db.ListMessagesBySession(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("db.ListMessagesBySession: %w", err)
	}

	out := make([]Message, 0, len(rows))
	for idx := range rows {
		converted, convErr := s.fromDBItem(rows[idx])
		if convErr != nil {
			return nil, fmt.Errorf("failed to convert DB message at index %d: %w", idx, convErr)
		}
		out = append(out, converted)
	}
	return out, nil
}
// ListAfter returns the messages of sessionID selected by
// ListMessagesBySessionAfter, presumably those with created_at after
// timestamp (the query itself lives in the db package).
//
// The bound is normalized to UTC before formatting, matching how Update
// writes finished_at. RFC 3339 strings only compare correctly as strings when
// they share an offset, so a local-zone time.Time (e.g. from time.Now())
// would otherwise yield a "+hh:mm" bound that mis-compares against stored
// "Z" values.
// NOTE(review): assumes created_at is stored as UTC RFC 3339 text. Also,
// RFC3339Nano trims trailing fractional zeros, so values with different
// precision can still mis-compare; confirm against the schema.
func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
		SessionID: sessionID,
		CreatedAt: timestamp.UTC().Format(time.RFC3339Nano),
	})
	if err != nil {
		return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)
	}

	messages := make([]Message, len(dbMessages))
	for i, dbMsg := range dbMessages {
		msg, convErr := s.fromDBItem(dbMsg)
		if convErr != nil {
			return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr)
		}
		messages[i] = msg
	}
	return messages, nil
}
// Delete removes the message with the given ID and publishes
// EventMessageDeleted carrying the message as it was before removal.
// Deleting a message that does not exist is treated as a no-op and
// returns nil.
func (s *service) Delete(ctx context.Context, id string) error {
	// Hold the write lock across the lookup and the delete so the message
	// cannot be modified between being snapshotted for the event and being
	// removed. The original released and re-acquired the lock in between.
	s.mu.Lock()
	defer s.mu.Unlock()

	messageToPublish, err := s.getServiceForPublish(ctx, id)
	if err != nil {
		// getServiceForPublish returns the raw db.GetMessage error, so a
		// missing row surfaces as sql.ErrNoRows, whose text ("sql: no rows in
		// result set") never contains "not found". The substring check is
		// kept for any other not-found style errors.
		if err == sql.ErrNoRows || strings.Contains(err.Error(), "not found") {
			return nil // Or return the error if strictness is required
		}
		return err
	}

	if err := s.db.DeleteMessage(ctx, id); err != nil {
		return fmt.Errorf("db.DeleteMessage: %w", err)
	}
	if messageToPublish != nil {
		s.broker.Publish(EventMessageDeleted, *messageToPublish)
	}
	return nil
}
// getServiceForPublish loads and converts the message with the given ID so it
// can be attached to a delete event. Database errors are returned unwrapped;
// conversion errors are wrapped. It takes no lock itself, so the caller is
// expected to hold s.mu.
func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) {
	row, err := s.db.GetMessage(ctx, id)
	if err != nil {
		return nil, err
	}
	msg, err := s.fromDBItem(row)
	if err != nil {
		return nil, fmt.Errorf("failed to convert DB message for publishing: %w", err)
	}
	return &msg, nil
}
// DeleteSessionMessages removes every message in sessionID and publishes one
// EventMessageDeleted per removed row. Rows that cannot be converted to a
// Message are logged and skipped for event publishing only; the deletion
// itself has already happened.
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Snapshot the rows first so the delete events can carry full payloads.
	doomed, err := s.db.ListMessagesBySession(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to list messages for deletion: %w", err)
	}
	if err := s.db.DeleteSessionMessages(ctx, sessionID); err != nil {
		return fmt.Errorf("db.DeleteSessionMessages: %w", err)
	}

	for _, row := range doomed {
		msg, convErr := s.fromDBItem(row)
		if convErr != nil {
			slog.Error("Failed to convert DB message for delete event publishing", "id", row.ID, "error", convErr)
			continue
		}
		s.broker.Publish(EventMessageDeleted, msg)
	}
	return nil
}
// Subscribe returns a channel of message events from the service's broker.
// NOTE(review): presumably the subscription ends when ctx is cancelled;
// confirm in the pubsub package.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
	events := s.broker.Subscribe(ctx)
	return events
}
// fromDBItem converts a database row into a Message.
//
// Parts are decoded with unmarshallParts; a decoding failure is returned as
// an error. created_at and updated_at are parsed as RFC 3339 (with optional
// fractional seconds). An unparsable timestamp is logged and replaced with
// time.Now() rather than failing the conversion.
//
// NOTE(review): Update's comment says updated_at is set by a trigger using
// strftime('%s', 'now') (Unix seconds). If so, RFC 3339 parsing always fails
// here and UpdatedAt falls back to time.Now(); confirm the schema.
func (s *service) fromDBItem(item db.Message) (Message, error) {
	parts, err := unmarshallParts([]byte(item.Parts))
	if err != nil {
		return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts)
	}

	// parseTimestamp parses one timestamp column. On failure it logs the
	// column name and raw value, then falls back to the current time.
	parseTimestamp := func(column, value string) time.Time {
		t, perr := time.Parse(time.RFC3339Nano, value)
		if perr != nil {
			slog.Error("Failed to parse "+column, "value", value, "error", perr)
			return time.Now() // Fallback
		}
		return t
	}

	// Previously an updated_at failure was logged as "created_at" with the
	// created_at value; each column now reports its own name and value.
	createdAt := parseTimestamp("created_at", item.CreatedAt)
	updatedAt := parseTimestamp("updated_at", item.UpdatedAt)

	return Message{
		ID:        item.ID,
		SessionID: item.SessionID,
		Role:      MessageRole(item.Role),
		Parts:     parts,
		Model:     models.ModelID(item.Model.String),
		CreatedAt: createdAt,
		UpdatedAt: updatedAt,
	}, nil
}
// The functions below are package-level conveniences that forward to the
// singleton returned by GetService. Each panics if InitService has not been
// called.

// Create adds a message to sessionID via the global service.
func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
	svc := GetService()
	return svc.Create(ctx, sessionID, params)
}

// Update rewrites an existing message via the global service.
func Update(ctx context.Context, message Message) (Message, error) {
	svc := GetService()
	return svc.Update(ctx, message)
}

// Get loads a message by ID via the global service.
func Get(ctx context.Context, id string) (Message, error) {
	svc := GetService()
	return svc.Get(ctx, id)
}

// List returns all messages of sessionID via the global service.
func List(ctx context.Context, sessionID string) ([]Message, error) {
	svc := GetService()
	return svc.List(ctx, sessionID)
}

// ListAfter returns messages of sessionID created after timestamp via the
// global service.
func ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
	svc := GetService()
	return svc.ListAfter(ctx, sessionID, timestamp)
}

// Delete removes a message by ID via the global service.
func Delete(ctx context.Context, id string) error {
	svc := GetService()
	return svc.Delete(ctx, id)
}

// DeleteSessionMessages removes all messages of sessionID via the global
// service.
func DeleteSessionMessages(ctx context.Context, sessionID string) error {
	svc := GetService()
	return svc.DeleteSessionMessages(ctx, sessionID)
}

// Subscribe returns a message event channel from the global service.
func Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
	svc := GetService()
	return svc.Subscribe(ctx)
}
// partType is the discriminator written to the "type" field of each
// serialized part, telling unmarshallParts which Go type to decode "data" into.
type partType string
const (
reasoningType partType = "reasoning"
textType partType = "text"
imageURLType partType = "image_url"
binaryType partType = "binary"
toolCallType partType = "tool_call"
toolResultType partType = "tool_result"
// finishType parts carry a DBFinish payload (time as Unix milliseconds).
finishType partType = "finish"
)
// partWrapper is the JSON envelope for one content part:
// {"type": <partType>, "data": <type-specific JSON>}.
// Data is kept raw so it can be decoded once Type is known.
type partWrapper struct {
Type partType `json:"type"`
Data json.RawMessage `json:"data"`
}
// marshallParts serializes parts into a JSON array of partWrapper envelopes,
// the format read back by unmarshallParts.
//
// Most parts are marshaled as-is. Finish is converted to DBFinish so its time
// is stored as Unix milliseconds. Both TextContent and *TextContent are
// accepted. Any other dynamic type is rejected with an error. An empty or nil
// slice encodes as "[]".
func marshallParts(parts []ContentPart) ([]byte, error) {
	// Reuse partWrapper rather than an identical anonymous struct, and
	// marshal the whole slice once instead of encoding each envelope
	// separately and then re-encoding it as a RawMessage.
	wrapped := make([]partWrapper, len(parts))
	for i, part := range parts {
		var typ partType
		var payload any = part
		switch p := part.(type) {
		case ReasoningContent:
			typ = reasoningType
		case TextContent, *TextContent:
			typ = textType
		case ImageURLContent:
			typ = imageURLType
		case BinaryContent:
			typ = binaryType
		case ToolCall:
			typ = toolCallType
		case ToolResult:
			typ = toolResultType
		case Finish:
			typ = finishType
			var dbFinish DBFinish
			dbFinish.Reason = p.Reason
			dbFinish.Time = p.Time.UnixMilli()
			payload = dbFinish
		default:
			return nil, fmt.Errorf("unknown part type for marshalling: %T", part)
		}

		dataBytes, err := json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err)
		}
		wrapped[i] = partWrapper{Type: typ, Data: dataBytes}
	}
	return json.Marshal(wrapped)
}
// unmarshallParts decodes the JSON array produced by marshallParts back into
// content parts, preserving order.
//
// Compatibility behavior:
//   - an array element that is a bare JSON string (old format) becomes a
//     TextContent with that text;
//   - an element with an unknown or empty "type" has its "data" decoded as a
//     TextContent. If that also fails, the error is logged and the element is
//     dropped rather than failing the whole decode.
//
// Any other decoding failure aborts and returns an error that includes the
// offending raw JSON.
//
// NOTE(review): an untyped object such as {"text":"hi"} decodes into a
// wrapper with empty Type and empty Data, so the TextContent fallback sees no
// data and drops it. Confirm whether such legacy objects exist.
func unmarshallParts(data []byte) ([]ContentPart, error) {
	var rawMessages []json.RawMessage
	if err := json.Unmarshal(data, &rawMessages); err != nil {
		return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data))
	}

	// decode unmarshals one part payload into dst, labelling failures with
	// the target type's name and the raw payload.
	decode := func(name string, payload json.RawMessage, dst any) error {
		if err := json.Unmarshal(payload, dst); err != nil {
			return fmt.Errorf("unmarshal %s: %w. Data: %s", name, err, string(payload))
		}
		return nil
	}

	parts := make([]ContentPart, 0, len(rawMessages))
	for _, raw := range rawMessages {
		var w partWrapper
		if err := json.Unmarshal(raw, &w); err != nil {
			// Fallback for old format where parts might be just TextContent string
			var legacy string
			if json.Unmarshal(raw, &legacy) == nil {
				parts = append(parts, TextContent{Text: legacy})
				continue
			}
			return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(raw))
		}

		switch w.Type {
		case reasoningType:
			var p ReasoningContent
			if err := decode("ReasoningContent", w.Data, &p); err != nil {
				return nil, err
			}
			parts = append(parts, p)
		case textType:
			var p TextContent
			if err := decode("TextContent", w.Data, &p); err != nil {
				return nil, err
			}
			parts = append(parts, p)
		case imageURLType:
			var p ImageURLContent
			if err := decode("ImageURLContent", w.Data, &p); err != nil {
				return nil, err
			}
			parts = append(parts, p)
		case binaryType:
			var p BinaryContent
			if err := decode("BinaryContent", w.Data, &p); err != nil {
				return nil, err
			}
			parts = append(parts, p)
		case toolCallType:
			var p ToolCall
			if err := decode("ToolCall", w.Data, &p); err != nil {
				return nil, err
			}
			parts = append(parts, p)
		case toolResultType:
			var p ToolResult
			if err := decode("ToolResult", w.Data, &p); err != nil {
				return nil, err
			}
			parts = append(parts, p)
		case finishType:
			var p DBFinish
			if err := decode("Finish", w.Data, &p); err != nil {
				return nil, err
			}
			// DBFinish stores the time as Unix milliseconds.
			parts = append(parts, Finish{Reason: FinishReason(p.Reason), Time: time.UnixMilli(p.Time)})
		default:
			slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", w.Type, "data", string(w.Data))
			var p TextContent
			if err := json.Unmarshal(w.Data, &p); err != nil {
				// Best-effort: log and drop this part instead of failing the whole decode.
				// Stricter alternative:
				// return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err)
				slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", w.Type, "data", string(w.Data), "error", err)
				continue
			}
			parts = append(parts, p)
		}
	}
	return parts, nil
}