package message import ( "context" "database/sql" "encoding/json" "fmt" "log/slog" "strings" "sync" "time" "github.com/google/uuid" "github.com/sst/opencode/internal/db" "github.com/sst/opencode/internal/llm/models" "github.com/sst/opencode/internal/pubsub" ) type Message struct { ID string Role MessageRole SessionID string Parts []ContentPart Model models.ModelID CreatedAt time.Time UpdatedAt time.Time } const ( EventMessageCreated pubsub.EventType = "message_created" EventMessageUpdated pubsub.EventType = "message_updated" EventMessageDeleted pubsub.EventType = "message_deleted" ) type CreateMessageParams struct { Role MessageRole Parts []ContentPart Model models.ModelID } type Service interface { pubsub.Subscriber[Message] Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) Update(ctx context.Context, message Message) (Message, error) Get(ctx context.Context, id string) (Message, error) List(ctx context.Context, sessionID string) ([]Message, error) ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) Delete(ctx context.Context, id string) error DeleteSessionMessages(ctx context.Context, sessionID string) error } type service struct { db *db.Queries broker *pubsub.Broker[Message] mu sync.RWMutex } var globalMessageService *service func InitService(dbConn *sql.DB) error { if globalMessageService != nil { return fmt.Errorf("message service already initialized") } queries := db.New(dbConn) broker := pubsub.NewBroker[Message]() globalMessageService = &service{ db: queries, broker: broker, } return nil } func GetService() Service { if globalMessageService == nil { panic("message service not initialized. Call message.InitService() first.") } return globalMessageService } func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) { s.mu.Lock() defer s.mu.Unlock() isFinished := false for _, p := range params.Parts { if _, ok := p.(Finish); ok { isFinished = true break } } if params.Role == User && !isFinished { params.Parts = append(params.Parts, Finish{Reason: FinishReasonEndTurn, Time: time.Now()}) } partsJSON, err := marshallParts(params.Parts) if err != nil { return Message{}, fmt.Errorf("failed to marshal message parts: %w", err) } dbMsgParams := db.CreateMessageParams{ ID: uuid.New().String(), SessionID: sessionID, Role: string(params.Role), Parts: string(partsJSON), Model: sql.NullString{String: string(params.Model), Valid: params.Model != ""}, } dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams) if err != nil { return Message{}, fmt.Errorf("db.CreateMessage: %w", err) } message, err := s.fromDBItem(dbMessage) if err != nil { return Message{}, fmt.Errorf("failed to convert DB message: %w", err) } s.broker.Publish(EventMessageCreated, message) return message, nil } func (s *service) Update(ctx context.Context, message Message) (Message, error) { s.mu.Lock() defer s.mu.Unlock() if message.ID == "" { return Message{}, fmt.Errorf("cannot update message with empty ID") } partsJSON, err := marshallParts(message.Parts) if err != nil { return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err) } var dbFinishedAt sql.NullString finishPart := message.FinishPart() if finishPart != nil && !finishPart.Time.IsZero() { dbFinishedAt = sql.NullString{ String: finishPart.Time.UTC().Format(time.RFC3339Nano), Valid: true, } } // UpdatedAt is handled by the DB trigger (strftime('%s', 'now')) err = s.db.UpdateMessage(ctx, db.UpdateMessageParams{ ID: message.ID, Parts: string(partsJSON), FinishedAt: dbFinishedAt, }) if err != nil { return Message{}, fmt.Errorf("db.UpdateMessage: %w", err) } dbUpdatedMessage, err := s.db.GetMessage(ctx, message.ID) if err != nil { return Message{}, fmt.Errorf("failed to fetch message after update: %w", err) } updatedMessage, err := s.fromDBItem(dbUpdatedMessage) if err != nil { return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err) } s.broker.Publish(EventMessageUpdated, updatedMessage) return updatedMessage, nil } func (s *service) Get(ctx context.Context, id string) (Message, error) { s.mu.RLock() defer s.mu.RUnlock() dbMessage, err := s.db.GetMessage(ctx, id) if err != nil { if err == sql.ErrNoRows { return Message{}, fmt.Errorf("message with ID '%s' not found", id) } return Message{}, fmt.Errorf("db.GetMessage: %w", err) } return s.fromDBItem(dbMessage) } func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) { s.mu.RLock() defer s.mu.RUnlock() dbMessages, err := s.db.ListMessagesBySession(ctx, sessionID) if err != nil { return nil, fmt.Errorf("db.ListMessagesBySession: %w", err) } messages := make([]Message, len(dbMessages)) for i, dbMsg := range dbMessages { msg, convErr := s.fromDBItem(dbMsg) if convErr != nil { return nil, fmt.Errorf("failed to convert DB message at index %d: %w", i, convErr) } messages[i] = msg } return messages, nil } func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) { s.mu.RLock() defer s.mu.RUnlock() dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{ SessionID: sessionID, CreatedAt: timestamp.Format(time.RFC3339Nano), }) if err != nil { return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err) } messages := make([]Message, len(dbMessages)) for i, dbMsg := range dbMessages { msg, convErr := s.fromDBItem(dbMsg) if convErr != nil { return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr) } messages[i] = msg } return messages, nil } func (s *service) Delete(ctx context.Context, id string) error { s.mu.Lock() messageToPublish, err := s.getServiceForPublish(ctx, id) s.mu.Unlock() if err != nil { // If error was due to not found, it's not a critical failure for deletion intent if strings.Contains(err.Error(), "not found") { return nil // Or return the error if strictness is required } return err } s.mu.Lock() defer s.mu.Unlock() err = s.db.DeleteMessage(ctx, id) if err != nil { return fmt.Errorf("db.DeleteMessage: %w", err) } if messageToPublish != nil { s.broker.Publish(EventMessageDeleted, *messageToPublish) } return nil } func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) { dbMsg, err := s.db.GetMessage(ctx, id) if err != nil { return nil, err } msg, convErr := s.fromDBItem(dbMsg) if convErr != nil { return nil, fmt.Errorf("failed to convert DB message for publishing: %w", convErr) } return &msg, nil } func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error { s.mu.Lock() defer s.mu.Unlock() messagesToDelete, err := s.db.ListMessagesBySession(ctx, sessionID) if err != nil { return fmt.Errorf("failed to list messages for deletion: %w", err) } err = s.db.DeleteSessionMessages(ctx, sessionID) if err != nil { return fmt.Errorf("db.DeleteSessionMessages: %w", err) } for _, dbMsg := range messagesToDelete { msg, convErr := s.fromDBItem(dbMsg) if convErr == nil { s.broker.Publish(EventMessageDeleted, msg) } else { slog.Error("Failed to convert DB message for delete event publishing", "id", dbMsg.ID, "error", convErr) } } return nil } func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] { return s.broker.Subscribe(ctx) } func (s *service) fromDBItem(item db.Message) (Message, error) { parts, err := unmarshallParts([]byte(item.Parts)) if err != nil { return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts) } // Parse timestamps from ISO strings createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt) if err != nil { slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err) createdAt = time.Now() // Fallback } updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt) if err != nil { slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err) updatedAt = time.Now() // Fallback } msg := Message{ ID: item.ID, SessionID: item.SessionID, Role: MessageRole(item.Role), Parts: parts, Model: models.ModelID(item.Model.String), CreatedAt: createdAt, UpdatedAt: updatedAt, } return msg, nil } func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) { return GetService().Create(ctx, sessionID, params) } func Update(ctx context.Context, message Message) (Message, error) { return GetService().Update(ctx, message) } func Get(ctx context.Context, id string) (Message, error) { return GetService().Get(ctx, id) } func List(ctx context.Context, sessionID string) ([]Message, error) { return GetService().List(ctx, sessionID) } func ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) { return GetService().ListAfter(ctx, sessionID, timestamp) } func Delete(ctx context.Context, id string) error { return GetService().Delete(ctx, id) } func DeleteSessionMessages(ctx context.Context, sessionID string) error { return GetService().DeleteSessionMessages(ctx, sessionID) } func Subscribe(ctx context.Context) <-chan pubsub.Event[Message] { return GetService().Subscribe(ctx) } type partType string const ( reasoningType partType = "reasoning" textType partType = "text" imageURLType partType = "image_url" binaryType partType = "binary" toolCallType partType = "tool_call" toolResultType partType = "tool_result" finishType partType = "finish" ) type partWrapper struct { Type partType `json:"type"` Data json.RawMessage `json:"data"` } func marshallParts(parts []ContentPart) ([]byte, error) { wrappedParts := make([]json.RawMessage, len(parts)) for i, part := range parts { var typ partType var dataBytes []byte var err error switch p := part.(type) { case ReasoningContent: typ = reasoningType dataBytes, err = json.Marshal(p) case TextContent: typ = textType dataBytes, err = json.Marshal(p) case *TextContent: typ = textType dataBytes, err = json.Marshal(p) case ImageURLContent: typ = imageURLType dataBytes, err = json.Marshal(p) case BinaryContent: typ = binaryType dataBytes, err = json.Marshal(p) case ToolCall: typ = toolCallType dataBytes, err = json.Marshal(p) case ToolResult: typ = toolResultType dataBytes, err = json.Marshal(p) case Finish: typ = finishType var dbFinish DBFinish dbFinish.Reason = p.Reason dbFinish.Time = p.Time.UnixMilli() dataBytes, err = json.Marshal(dbFinish) default: return nil, fmt.Errorf("unknown part type for marshalling: %T", part) } if err != nil { return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err) } wrapper := struct { Type partType `json:"type"` Data json.RawMessage `json:"data"` }{Type: typ, Data: dataBytes} wrappedBytes, err := json.Marshal(wrapper) if err != nil { return nil, fmt.Errorf("failed to marshal part wrapper for type %s: %w", typ, err) } wrappedParts[i] = wrappedBytes } return json.Marshal(wrappedParts) } func unmarshallParts(data []byte) ([]ContentPart, error) { var rawMessages []json.RawMessage if err := json.Unmarshal(data, &rawMessages); err != nil { return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data)) } parts := make([]ContentPart, 0, len(rawMessages)) for _, rawPart := range rawMessages { var wrapper partWrapper if err := json.Unmarshal(rawPart, &wrapper); err != nil { // Fallback for old format where parts might be just TextContent string var text string if errText := json.Unmarshal(rawPart, &text); errText == nil { parts = append(parts, TextContent{Text: text}) continue } return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(rawPart)) } switch wrapper.Type { case reasoningType: var p ReasoningContent if err := json.Unmarshal(wrapper.Data, &p); err != nil { return nil, fmt.Errorf("unmarshal ReasoningContent: %w. Data: %s", err, string(wrapper.Data)) } parts = append(parts, p) case textType: var p TextContent if err := json.Unmarshal(wrapper.Data, &p); err != nil { return nil, fmt.Errorf("unmarshal TextContent: %w. Data: %s", err, string(wrapper.Data)) } parts = append(parts, p) case imageURLType: var p ImageURLContent if err := json.Unmarshal(wrapper.Data, &p); err != nil { return nil, fmt.Errorf("unmarshal ImageURLContent: %w. Data: %s", err, string(wrapper.Data)) } parts = append(parts, p) case binaryType: var p BinaryContent if err := json.Unmarshal(wrapper.Data, &p); err != nil { return nil, fmt.Errorf("unmarshal BinaryContent: %w. Data: %s", err, string(wrapper.Data)) } parts = append(parts, p) case toolCallType: var p ToolCall if err := json.Unmarshal(wrapper.Data, &p); err != nil { return nil, fmt.Errorf("unmarshal ToolCall: %w. Data: %s", err, string(wrapper.Data)) } parts = append(parts, p) case toolResultType: var p ToolResult if err := json.Unmarshal(wrapper.Data, &p); err != nil { return nil, fmt.Errorf("unmarshal ToolResult: %w. Data: %s", err, string(wrapper.Data)) } parts = append(parts, p) case finishType: var p DBFinish if err := json.Unmarshal(wrapper.Data, &p); err != nil { return nil, fmt.Errorf("unmarshal Finish: %w. Data: %s", err, string(wrapper.Data)) } parts = append(parts, Finish{Reason: FinishReason(p.Reason), Time: time.UnixMilli(p.Time)}) default: slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", wrapper.Type, "data", string(wrapper.Data)) // Fallback: if type is unknown or empty, try to parse data as TextContent directly var p TextContent if err := json.Unmarshal(wrapper.Data, &p); err == nil { parts = append(parts, p) } else { // If that also fails, log it but continue if possible, or return error slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", wrapper.Type, "data", string(wrapper.Data), "error", err) // Depending on strictness, you might return an error here: // return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err) } } } return parts, nil }