mirror of
https://github.com/sst/opencode.git
synced 2025-08-03 21:28:14 +00:00
271 lines
5.8 KiB
Go
271 lines
5.8 KiB
Go
package message
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/kujtimiihoxha/termai/internal/db"
|
|
"github.com/kujtimiihoxha/termai/internal/pubsub"
|
|
)
|
|
|
|
type CreateMessageParams struct {
|
|
Role MessageRole
|
|
Parts []ContentPart
|
|
}
|
|
|
|
type Service interface {
|
|
pubsub.Suscriber[Message]
|
|
Create(sessionID string, params CreateMessageParams) (Message, error)
|
|
Update(message Message) error
|
|
Get(id string) (Message, error)
|
|
List(sessionID string) ([]Message, error)
|
|
Delete(id string) error
|
|
DeleteSessionMessages(sessionID string) error
|
|
}
|
|
|
|
type service struct {
|
|
*pubsub.Broker[Message]
|
|
q db.Querier
|
|
ctx context.Context
|
|
}
|
|
|
|
func NewService(ctx context.Context, q db.Querier) Service {
|
|
return &service{
|
|
Broker: pubsub.NewBroker[Message](),
|
|
q: q,
|
|
ctx: ctx,
|
|
}
|
|
}
|
|
|
|
func (s *service) Delete(id string) error {
|
|
message, err := s.Get(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = s.q.DeleteMessage(s.ctx, message.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.Publish(pubsub.DeletedEvent, message)
|
|
return nil
|
|
}
|
|
|
|
func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
|
|
if params.Role != Assistant {
|
|
params.Parts = append(params.Parts, Finish{
|
|
Reason: "stop",
|
|
})
|
|
}
|
|
partsJSON, err := marshallParts(params.Parts)
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
|
|
dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
|
|
ID: uuid.New().String(),
|
|
SessionID: sessionID,
|
|
Role: string(params.Role),
|
|
Parts: string(partsJSON),
|
|
})
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
message, err := s.fromDBItem(dbMessage)
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
s.Publish(pubsub.CreatedEvent, message)
|
|
return message, nil
|
|
}
|
|
|
|
func (s *service) DeleteSessionMessages(sessionID string) error {
|
|
messages, err := s.List(sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, message := range messages {
|
|
if message.SessionID == sessionID {
|
|
err = s.Delete(message.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *service) Update(message Message) error {
|
|
parts, err := marshallParts(message.Parts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
|
|
ID: message.ID,
|
|
Parts: string(parts),
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.Publish(pubsub.UpdatedEvent, message)
|
|
return nil
|
|
}
|
|
|
|
func (s *service) Get(id string) (Message, error) {
|
|
dbMessage, err := s.q.GetMessage(s.ctx, id)
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
return s.fromDBItem(dbMessage)
|
|
}
|
|
|
|
func (s *service) List(sessionID string) ([]Message, error) {
|
|
dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
messages := make([]Message, len(dbMessages))
|
|
for i, dbMessage := range dbMessages {
|
|
messages[i], err = s.fromDBItem(dbMessage)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
func (s *service) fromDBItem(item db.Message) (Message, error) {
|
|
parts, err := unmarshallParts([]byte(item.Parts))
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
return Message{
|
|
ID: item.ID,
|
|
SessionID: item.SessionID,
|
|
Role: MessageRole(item.Role),
|
|
Parts: parts,
|
|
CreatedAt: item.CreatedAt,
|
|
UpdatedAt: item.UpdatedAt,
|
|
}, nil
|
|
}
|
|
|
|
type partType string
|
|
|
|
const (
|
|
reasoningType partType = "reasoning"
|
|
textType partType = "text"
|
|
imageURLType partType = "image_url"
|
|
binaryType partType = "binary"
|
|
toolCallType partType = "tool_call"
|
|
toolResultType partType = "tool_result"
|
|
finishType partType = "finish"
|
|
)
|
|
|
|
type partWrapper struct {
|
|
Type partType `json:"type"`
|
|
Data ContentPart `json:"data"`
|
|
}
|
|
|
|
func marshallParts(parts []ContentPart) ([]byte, error) {
|
|
wrappedParts := make([]partWrapper, len(parts))
|
|
|
|
for i, part := range parts {
|
|
var typ partType
|
|
|
|
switch part.(type) {
|
|
case ReasoningContent:
|
|
typ = reasoningType
|
|
case TextContent:
|
|
typ = textType
|
|
case ImageURLContent:
|
|
typ = imageURLType
|
|
case BinaryContent:
|
|
typ = binaryType
|
|
case ToolCall:
|
|
typ = toolCallType
|
|
case ToolResult:
|
|
typ = toolResultType
|
|
case Finish:
|
|
typ = finishType
|
|
default:
|
|
return nil, fmt.Errorf("unknown part type: %T", part)
|
|
}
|
|
|
|
wrappedParts[i] = partWrapper{
|
|
Type: typ,
|
|
Data: part,
|
|
}
|
|
}
|
|
return json.Marshal(wrappedParts)
|
|
}
|
|
|
|
func unmarshallParts(data []byte) ([]ContentPart, error) {
|
|
temp := []json.RawMessage{}
|
|
|
|
if err := json.Unmarshal(data, &temp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
parts := make([]ContentPart, 0)
|
|
|
|
for _, rawPart := range temp {
|
|
var wrapper struct {
|
|
Type partType `json:"type"`
|
|
Data json.RawMessage `json:"data"`
|
|
}
|
|
|
|
if err := json.Unmarshal(rawPart, &wrapper); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch wrapper.Type {
|
|
case reasoningType:
|
|
part := ReasoningContent{}
|
|
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
|
return nil, err
|
|
}
|
|
parts = append(parts, part)
|
|
case textType:
|
|
part := TextContent{}
|
|
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
|
return nil, err
|
|
}
|
|
parts = append(parts, part)
|
|
case imageURLType:
|
|
part := ImageURLContent{}
|
|
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
|
return nil, err
|
|
}
|
|
case binaryType:
|
|
part := BinaryContent{}
|
|
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
|
return nil, err
|
|
}
|
|
parts = append(parts, part)
|
|
case toolCallType:
|
|
part := ToolCall{}
|
|
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
|
return nil, err
|
|
}
|
|
parts = append(parts, part)
|
|
case toolResultType:
|
|
part := ToolResult{}
|
|
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
|
return nil, err
|
|
}
|
|
parts = append(parts, part)
|
|
case finishType:
|
|
part := Finish{}
|
|
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
|
return nil, err
|
|
}
|
|
parts = append(parts, part)
|
|
default:
|
|
return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
|
|
}
|
|
|
|
}
|
|
|
|
return parts, nil
|
|
}
|