mirror of
https://github.com/sst/opencode.git
synced 2025-08-04 05:28:16 +00:00
wip: refactoring
This commit is contained in:
parent
f100777199
commit
ed9fba99c9
13 changed files with 1342 additions and 931 deletions
|
@ -4,9 +4,11 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
|
@ -27,218 +29,338 @@ type File struct {
|
|||
UpdatedAt int64
|
||||
}
|
||||
|
||||
const (
|
||||
EventFileCreated pubsub.EventType = "history_file_created"
|
||||
EventFileVersionCreated pubsub.EventType = "history_file_version_created"
|
||||
EventFileUpdated pubsub.EventType = "history_file_updated"
|
||||
EventFileDeleted pubsub.EventType = "history_file_deleted"
|
||||
EventSessionFilesDeleted pubsub.EventType = "history_session_files_deleted"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[File]
|
||||
pubsub.Subscriber[File]
|
||||
|
||||
Create(ctx context.Context, sessionID, path, content string) (File, error)
|
||||
CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
|
||||
Get(ctx context.Context, id string) (File, error)
|
||||
GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
|
||||
GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error)
|
||||
GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
|
||||
ListBySession(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListVersionsByPath(ctx context.Context, path string) ([]File, error)
|
||||
Update(ctx context.Context, file File) (File, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
DeleteSessionFiles(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
*pubsub.Broker[File]
|
||||
db *sql.DB
|
||||
q *db.Queries
|
||||
db *db.Queries
|
||||
sqlDB *sql.DB
|
||||
broker *pubsub.Broker[File]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewService(q *db.Queries, db *sql.DB) Service {
|
||||
return &service{
|
||||
Broker: pubsub.NewBroker[File](),
|
||||
q: q,
|
||||
db: db,
|
||||
var globalHistoryService *service
|
||||
|
||||
func InitService(sqlDatabase *sql.DB) error {
|
||||
if globalHistoryService != nil {
|
||||
return fmt.Errorf("history service already initialized")
|
||||
}
|
||||
queries := db.New(sqlDatabase)
|
||||
broker := pubsub.NewBroker[File]()
|
||||
|
||||
globalHistoryService = &service{
|
||||
db: queries,
|
||||
sqlDB: sqlDatabase,
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalHistoryService == nil {
|
||||
panic("history service not initialized. Call history.InitService() first.")
|
||||
}
|
||||
return globalHistoryService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
|
||||
return s.createWithVersion(ctx, sessionID, path, content, InitialVersion, EventFileCreated)
|
||||
}
|
||||
|
||||
func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
// Get the latest version for this path
|
||||
files, err := s.q.ListFilesByPath(ctx, path)
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
s.mu.RLock()
|
||||
files, err := s.db.ListFilesByPath(ctx, path)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("db.ListFilesByPath for next version: %w", err)
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
// No previous versions, create initial
|
||||
return s.Create(ctx, sessionID, path, content)
|
||||
}
|
||||
latestVersionNumber := 0
|
||||
if len(files) > 0 {
|
||||
// Sort to be absolutely sure about the latest version globally for this path
|
||||
slices.SortFunc(files, func(a, b db.File) int {
|
||||
if strings.HasPrefix(a.Version, "v") && strings.HasPrefix(b.Version, "v") {
|
||||
vA, _ := strconv.Atoi(a.Version[1:])
|
||||
vB, _ := strconv.Atoi(b.Version[1:])
|
||||
return vB - vA // Descending to get latest first
|
||||
}
|
||||
if a.Version == InitialVersion && b.Version != InitialVersion {
|
||||
return 1 // initial comes after vX
|
||||
}
|
||||
if b.Version == InitialVersion && a.Version != InitialVersion {
|
||||
return -1
|
||||
}
|
||||
return int(b.CreatedAt - a.CreatedAt) // Fallback to timestamp
|
||||
})
|
||||
|
||||
// Get the latest version
|
||||
latestFile := files[0] // Files are ordered by created_at DESC
|
||||
latestVersion := latestFile.Version
|
||||
|
||||
// Generate the next version
|
||||
var nextVersion string
|
||||
if latestVersion == InitialVersion {
|
||||
nextVersion = "v1"
|
||||
} else if strings.HasPrefix(latestVersion, "v") {
|
||||
versionNum, err := strconv.Atoi(latestVersion[1:])
|
||||
if err != nil {
|
||||
// If we can't parse the version, just use a timestamp-based version
|
||||
nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
|
||||
} else {
|
||||
nextVersion = fmt.Sprintf("v%d", versionNum+1)
|
||||
latestFile := files[0]
|
||||
if strings.HasPrefix(latestFile.Version, "v") {
|
||||
vNum, parseErr := strconv.Atoi(latestFile.Version[1:])
|
||||
if parseErr == nil {
|
||||
latestVersionNumber = vNum
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If the version format is unexpected, use a timestamp-based version
|
||||
nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
|
||||
}
|
||||
|
||||
return s.createWithVersion(ctx, sessionID, path, content, nextVersion)
|
||||
nextVersionStr := fmt.Sprintf("v%d", latestVersionNumber+1)
|
||||
return s.createWithVersion(ctx, sessionID, path, content, nextVersionStr, EventFileVersionCreated)
|
||||
}
|
||||
|
||||
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string) (File, error) {
|
||||
// Maximum number of retries for transaction conflicts
|
||||
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string, eventType pubsub.EventType) (File, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
const maxRetries = 3
|
||||
var file File
|
||||
var err error
|
||||
|
||||
// Retry loop for transaction conflicts
|
||||
for attempt := range maxRetries {
|
||||
// Start a transaction
|
||||
tx, txErr := s.db.Begin()
|
||||
tx, txErr := s.sqlDB.BeginTx(ctx, nil)
|
||||
if txErr != nil {
|
||||
return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
|
||||
}
|
||||
qtx := s.db.WithTx(tx)
|
||||
|
||||
// Create a new queries instance with the transaction
|
||||
qtx := s.q.WithTx(tx)
|
||||
|
||||
// Try to create the file within the transaction
|
||||
dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{
|
||||
dbFile, createErr := qtx.CreateFile(ctx, db.CreateFileParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
Path: path,
|
||||
Content: content,
|
||||
Version: version,
|
||||
})
|
||||
if txErr != nil {
|
||||
// Rollback the transaction
|
||||
tx.Rollback()
|
||||
|
||||
// Check if this is a uniqueness constraint violation
|
||||
if strings.Contains(txErr.Error(), "UNIQUE constraint failed") {
|
||||
if createErr != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
slog.Error("Failed to rollback transaction on create error", "error", rbErr)
|
||||
}
|
||||
if strings.Contains(createErr.Error(), "UNIQUE constraint failed: files.path, files.session_id, files.version") {
|
||||
if attempt < maxRetries-1 {
|
||||
// If we have retries left, generate a new version and try again
|
||||
slog.Warn("Unique constraint violation for file version, retrying with incremented version", "path", path, "session", sessionID, "attempted_version", version, "attempt", attempt+1)
|
||||
// Increment version string like v1, v2, v3...
|
||||
if strings.HasPrefix(version, "v") {
|
||||
versionNum, parseErr := strconv.Atoi(version[1:])
|
||||
numPart := version[1:]
|
||||
num, parseErr := strconv.Atoi(numPart)
|
||||
if parseErr == nil {
|
||||
version = fmt.Sprintf("v%d", versionNum+1)
|
||||
continue
|
||||
version = fmt.Sprintf("v%d", num+1)
|
||||
continue // Retry with new version
|
||||
}
|
||||
}
|
||||
// If we can't parse the version, use a timestamp-based version
|
||||
version = fmt.Sprintf("v%d", time.Now().Unix())
|
||||
// Fallback if version is not "vX" or parsing failed
|
||||
version = fmt.Sprintf("%s-retry%d", version, attempt+1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return File{}, txErr
|
||||
return File{}, fmt.Errorf("db.CreateFile within transaction: %w", createErr)
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if txErr = tx.Commit(); txErr != nil {
|
||||
return File{}, fmt.Errorf("failed to commit transaction: %w", txErr)
|
||||
if commitErr := tx.Commit(); commitErr != nil {
|
||||
return File{}, fmt.Errorf("failed to commit transaction: %w", commitErr)
|
||||
}
|
||||
|
||||
file = s.fromDBItem(dbFile)
|
||||
s.Publish(pubsub.CreatedEvent, file)
|
||||
return file, nil
|
||||
s.broker.Publish(eventType, file)
|
||||
return file, nil // Success
|
||||
}
|
||||
|
||||
return file, err
|
||||
return File{}, fmt.Errorf("failed to create file after %d retries due to version conflicts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (File, error) {
|
||||
dbFile, err := s.q.GetFile(ctx, id)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFile, err := s.db.GetFile(ctx, id)
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
if err == sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("file with ID '%s' not found", id)
|
||||
}
|
||||
return File{}, fmt.Errorf("db.GetFile: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
|
||||
dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
|
||||
func (s *service) GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// sqlc doesn't directly support GetyByPathAndVersionAndSession
|
||||
// We list and filter. This could be optimized with a custom query if performance is an issue.
|
||||
allFilesForPath, err := s.db.ListFilesByPath(ctx, path)
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("db.ListFilesByPath for GetByPathAndVersion: %w", err)
|
||||
}
|
||||
|
||||
for _, dbFile := range allFilesForPath {
|
||||
if dbFile.SessionID == sessionID && dbFile.Version == version {
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
}
|
||||
return File{}, fmt.Errorf("file not found for session '%s', path '%s', version '%s'", sessionID, path, version)
|
||||
}
|
||||
|
||||
func (s *service) GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
// GetFileByPathAndSession in sqlc already orders by created_at DESC and takes LIMIT 1
|
||||
dbFile, err := s.db.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
|
||||
Path: path,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
if err == sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("no file found for path '%s' in session '%s'", path, sessionID)
|
||||
}
|
||||
return File{}, fmt.Errorf("db.GetFileByPathAndSession: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
|
||||
dbFiles, err := s.q.ListFilesBySession(ctx, sessionID)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListFilesBySession(ctx, sessionID) // Assumes this orders by created_at ASC
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("db.ListFilesBySession: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbFile := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbFile)
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
|
||||
dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListLatestSessionFiles(ctx, sessionID) // Uses the specific sqlc query
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("db.ListLatestSessionFiles: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbFile := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbFile)
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListFilesByPath(ctx, path) // sqlc query orders by created_at DESC
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListFilesByPath: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, file File) (File, error) {
|
||||
dbFile, err := s.q.UpdateFile(ctx, db.UpdateFileParams{
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if file.ID == "" {
|
||||
return File{}, fmt.Errorf("cannot update file with empty ID")
|
||||
}
|
||||
// UpdatedAt is handled by DB trigger
|
||||
dbFile, err := s.db.UpdateFile(ctx, db.UpdateFileParams{
|
||||
ID: file.ID,
|
||||
Content: file.Content,
|
||||
Version: file.Version,
|
||||
})
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
return File{}, fmt.Errorf("db.UpdateFile: %w", err)
|
||||
}
|
||||
updatedFile := s.fromDBItem(dbFile)
|
||||
s.Publish(pubsub.UpdatedEvent, updatedFile)
|
||||
s.broker.Publish(EventFileUpdated, updatedFile)
|
||||
return updatedFile, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
file, err := s.Get(ctx, id)
|
||||
s.mu.Lock()
|
||||
fileToPublish, err := s.getServiceForPublish(ctx, id) // Use internal method with appropriate locking
|
||||
s.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
slog.Warn("Attempted to delete non-existent file history", "id", id)
|
||||
return nil // Or return specific error if needed
|
||||
}
|
||||
return err
|
||||
}
|
||||
err = s.q.DeleteFile(ctx, id)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteFile(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("db.DeleteFile: %w", err)
|
||||
}
|
||||
if fileToPublish != nil {
|
||||
s.broker.Publish(EventFileDeleted, *fileToPublish)
|
||||
}
|
||||
s.Publish(pubsub.DeletedEvent, file)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
|
||||
files, err := s.ListBySession(ctx, sessionID)
|
||||
// getServiceForPublish is an internal helper for Delete
|
||||
func (s *service) getServiceForPublish(ctx context.Context, id string) (*File, error) {
|
||||
// Assumes outer lock is NOT held or caller manages it.
|
||||
// For GetFile, it has its own RLock.
|
||||
dbFile, err := s.db.GetFile(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
for _, file := range files {
|
||||
err = s.Delete(ctx, file.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file := s.fromDBItem(dbFile)
|
||||
return &file, nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
|
||||
s.mu.Lock() // Lock for the entire operation
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Get files first for publishing events
|
||||
filesToDelete, err := s.db.ListFilesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.ListFilesBySession for deletion: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.DeleteSessionFiles(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSessionFiles: %w", err)
|
||||
}
|
||||
|
||||
for _, dbFile := range filesToDelete {
|
||||
file := s.fromDBItem(dbFile)
|
||||
s.broker.Publish(EventFileDeleted, file) // Individual delete events
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[File] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.File) File {
|
||||
return File{
|
||||
ID: item.ID,
|
||||
|
@ -246,7 +368,45 @@ func (s *service) fromDBItem(item db.File) File {
|
|||
Path: item.Path,
|
||||
Content: item.Content,
|
||||
Version: item.Version,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
CreatedAt: item.CreatedAt * 1000, // DB stores seconds, Go struct uses milliseconds
|
||||
UpdatedAt: item.UpdatedAt * 1000, // DB stores seconds, Go struct uses milliseconds
|
||||
}
|
||||
}
|
||||
|
||||
// --- Package-Level Wrapper Functions ---
|
||||
func Create(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
return GetService().Create(ctx, sessionID, path, content)
|
||||
}
|
||||
func CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
|
||||
return GetService().CreateVersion(ctx, sessionID, path, content)
|
||||
}
|
||||
func Get(ctx context.Context, id string) (File, error) {
|
||||
return GetService().Get(ctx, id)
|
||||
}
|
||||
func GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
|
||||
return GetService().GetByPathAndVersion(ctx, sessionID, path, version)
|
||||
}
|
||||
func GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
|
||||
return GetService().GetLatestByPathAndSession(ctx, path, sessionID)
|
||||
}
|
||||
func ListBySession(ctx context.Context, sessionID string) ([]File, error) {
|
||||
return GetService().ListBySession(ctx, sessionID)
|
||||
}
|
||||
func ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
|
||||
return GetService().ListLatestSessionFiles(ctx, sessionID)
|
||||
}
|
||||
func ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
|
||||
return GetService().ListVersionsByPath(ctx, path)
|
||||
}
|
||||
func Update(ctx context.Context, file File) (File, error) {
|
||||
return GetService().Update(ctx, file)
|
||||
}
|
||||
func Delete(ctx context.Context, id string) error {
|
||||
return GetService().Delete(ctx, id)
|
||||
}
|
||||
func DeleteSessionFiles(ctx context.Context, sessionID string) error {
|
||||
return GetService().DeleteSessionFiles(ctx, sessionID)
|
||||
}
|
||||
func SubscribeToEvents(ctx context.Context) <-chan pubsub.Event[File] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
|
|
@ -1,51 +1,282 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/status"
|
||||
"github.com/go-logfmt/logfmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
// "github.com/opencode-ai/opencode/internal/status"
|
||||
)
|
||||
|
||||
type Log struct {
|
||||
ID string
|
||||
SessionID string
|
||||
Timestamp int64
|
||||
Level string
|
||||
Message string
|
||||
Attributes map[string]string
|
||||
CreatedAt int64
|
||||
}
|
||||
|
||||
const (
|
||||
EventLogCreated pubsub.EventType = "log_created"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Subscriber[Log]
|
||||
|
||||
Create(ctx context.Context, log Log) error
|
||||
ListBySession(ctx context.Context, sessionID string) ([]Log, error)
|
||||
ListAll(ctx context.Context, limit int) ([]Log, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
db *db.Queries
|
||||
broker *pubsub.Broker[Log]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var globalLoggingService *service
|
||||
|
||||
func InitService(dbConn *sql.DB) error {
|
||||
if globalLoggingService != nil {
|
||||
return fmt.Errorf("logging service already initialized")
|
||||
}
|
||||
queries := db.New(dbConn)
|
||||
broker := pubsub.NewBroker[Log]()
|
||||
|
||||
globalLoggingService = &service{
|
||||
db: queries,
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalLoggingService == nil {
|
||||
panic("logging service not initialized. Call logging.InitService() first.")
|
||||
}
|
||||
return globalLoggingService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, log Log) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if log.ID == "" {
|
||||
log.ID = uuid.New().String()
|
||||
}
|
||||
if log.Timestamp == 0 {
|
||||
log.Timestamp = time.Now().UnixMilli()
|
||||
}
|
||||
if log.CreatedAt == 0 {
|
||||
log.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
if log.Level == "" {
|
||||
log.Level = "info"
|
||||
}
|
||||
|
||||
var attributesJSON sql.NullString
|
||||
if len(log.Attributes) > 0 {
|
||||
attributesBytes, err := json.Marshal(log.Attributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal log attributes: %w", err)
|
||||
}
|
||||
attributesJSON = sql.NullString{String: string(attributesBytes), Valid: true}
|
||||
}
|
||||
|
||||
err := s.db.CreateLog(ctx, db.CreateLogParams{
|
||||
ID: log.ID,
|
||||
SessionID: sql.NullString{String: log.SessionID, Valid: log.SessionID != ""},
|
||||
Timestamp: log.Timestamp / 1000,
|
||||
Level: log.Level,
|
||||
Message: log.Message,
|
||||
Attributes: attributesJSON,
|
||||
CreatedAt: log.CreatedAt / 1000,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.CreateLog: %w", err)
|
||||
}
|
||||
|
||||
s.broker.Publish(EventLogCreated, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbLogs, err := s.db.ListLogsBySession(ctx, sql.NullString{String: sessionID, Valid: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListLogsBySession: %w", err)
|
||||
}
|
||||
return s.fromDBItems(dbLogs)
|
||||
}
|
||||
|
||||
func (s *service) ListAll(ctx context.Context, limit int) ([]Log, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbLogs, err := s.db.ListAllLogs(ctx, int64(limit))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListAllLogs: %w", err)
|
||||
}
|
||||
return s.fromDBItems(dbLogs)
|
||||
}
|
||||
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Log] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItems(items []db.Log) ([]Log, error) {
|
||||
logs := make([]Log, len(items))
|
||||
for i, item := range items {
|
||||
log := Log{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID.String,
|
||||
Timestamp: item.Timestamp * 1000,
|
||||
Level: item.Level,
|
||||
Message: item.Message,
|
||||
CreatedAt: item.CreatedAt * 1000,
|
||||
}
|
||||
if item.Attributes.Valid && item.Attributes.String != "" {
|
||||
if err := json.Unmarshal([]byte(item.Attributes.String), &log.Attributes); err != nil {
|
||||
slog.Error("Failed to unmarshal log attributes", "log_id", item.ID, "error", err)
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
} else {
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
logs[i] = log
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func Create(ctx context.Context, log Log) error {
|
||||
return GetService().Create(ctx, log)
|
||||
}
|
||||
|
||||
func ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
|
||||
return GetService().ListBySession(ctx, sessionID)
|
||||
}
|
||||
|
||||
func ListAll(ctx context.Context, limit int) ([]Log, error) {
|
||||
return GetService().ListAll(ctx, limit)
|
||||
}
|
||||
|
||||
func SubscribeToEvents(ctx context.Context) <-chan pubsub.Event[Log] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
||||
type slogWriter struct{}
|
||||
|
||||
func (sw *slogWriter) Write(p []byte) (n int, err error) {
|
||||
// Example: time=2024-05-09T12:34:56.789-05:00 level=INFO msg="User request" session=xyz foo=bar
|
||||
d := logfmt.NewDecoder(bytes.NewReader(p))
|
||||
for d.ScanRecord() {
|
||||
logEntry := Log{
|
||||
Attributes: make(map[string]string),
|
||||
}
|
||||
hasTimestamp := false
|
||||
|
||||
for d.ScanKeyval() {
|
||||
key := string(d.Key())
|
||||
value := string(d.Value())
|
||||
|
||||
switch key {
|
||||
case "time":
|
||||
parsedTime, timeErr := time.Parse(time.RFC3339Nano, value)
|
||||
if timeErr != nil {
|
||||
parsedTime, timeErr = time.Parse(time.RFC3339, value)
|
||||
if timeErr != nil {
|
||||
slog.Error("Failed to parse time in slog writer", "value", value, "error", timeErr)
|
||||
logEntry.Timestamp = time.Now().UnixMilli()
|
||||
hasTimestamp = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
logEntry.Timestamp = parsedTime.UnixMilli()
|
||||
hasTimestamp = true
|
||||
case "level":
|
||||
logEntry.Level = strings.ToLower(value)
|
||||
case "msg", "message":
|
||||
logEntry.Message = value
|
||||
case "session_id", "session", "sid":
|
||||
logEntry.SessionID = value
|
||||
default:
|
||||
logEntry.Attributes[key] = value
|
||||
}
|
||||
}
|
||||
if d.Err() != nil {
|
||||
return len(p), fmt.Errorf("logfmt.ScanRecord: %w", d.Err())
|
||||
}
|
||||
|
||||
if !hasTimestamp {
|
||||
logEntry.Timestamp = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
// Create log entry via the service (non-blocking or handle error appropriately)
|
||||
// Using context.Background() as this is a low-level logging write.
|
||||
go func(le Log) { // Run in a goroutine to avoid blocking slog
|
||||
if err := Create(context.Background(), le); err != nil {
|
||||
// Log internal error using a more primitive logger to avoid loops
|
||||
fmt.Fprintf(os.Stderr, "ERROR [logging.slogWriter]: failed to persist log: %v\n", err)
|
||||
}
|
||||
}(logEntry)
|
||||
}
|
||||
if d.Err() != nil {
|
||||
return len(p), fmt.Errorf("logfmt.ScanRecord final: %w", d.Err())
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func NewSlogWriter() io.Writer {
|
||||
return &slogWriter{}
|
||||
}
|
||||
|
||||
// RecoverPanic is a common function to handle panics gracefully.
|
||||
// It logs the error, creates a panic log file with stack trace,
|
||||
// and executes an optional cleanup function before returning.
|
||||
// and executes an optional cleanup function.
|
||||
func RecoverPanic(name string, cleanup func()) {
|
||||
if r := recover(); r != nil {
|
||||
// Log the panic
|
||||
errorMsg := fmt.Sprintf("Panic in %s: %v", name, r)
|
||||
// Use slog directly here, as our service might be the one panicking.
|
||||
slog.Error(errorMsg)
|
||||
status.Error(errorMsg)
|
||||
// status.Error(errorMsg)
|
||||
|
||||
// Create a timestamped panic log file
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp)
|
||||
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to create panic log: %v", err)
|
||||
errMsg := fmt.Sprintf("Failed to create panic log file '%s': %v", filename, err)
|
||||
slog.Error(errMsg)
|
||||
status.Error(errMsg)
|
||||
// status.Error(errMsg)
|
||||
} else {
|
||||
defer file.Close()
|
||||
|
||||
// Write panic information and stack trace
|
||||
fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r)
|
||||
fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339))
|
||||
fmt.Fprintf(file, "Stack Trace:\n%s\n", debug.Stack())
|
||||
|
||||
fmt.Fprintf(file, "Stack Trace:\n%s\n", string(debug.Stack())) // Capture stack trace
|
||||
infoMsg := fmt.Sprintf("Panic details written to %s", filename)
|
||||
slog.Info(infoMsg)
|
||||
status.Info(infoMsg)
|
||||
// status.Info(infoMsg)
|
||||
}
|
||||
|
||||
// Execute cleanup function if provided
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager handles logging management
|
||||
type Manager struct {
|
||||
service Service
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Global instance of the logging manager
|
||||
var globalManager *Manager
|
||||
|
||||
// InitManager initializes the global logging manager with the provided service
|
||||
func InitManager(service Service) {
|
||||
globalManager = &Manager{
|
||||
service: service,
|
||||
}
|
||||
|
||||
// Subscribe to log events if needed
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
_ = service.Subscribe(ctx) // Just subscribing to keep the channel open
|
||||
}()
|
||||
}
|
||||
|
||||
// GetService returns the logging service
|
||||
func GetService() Service {
|
||||
if globalManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
globalManager.mu.RLock()
|
||||
defer globalManager.mu.RUnlock()
|
||||
|
||||
return globalManager.service
|
||||
}
|
||||
|
||||
func Create(ctx context.Context, log Log) error {
|
||||
if globalManager == nil {
|
||||
return nil
|
||||
}
|
||||
return globalManager.service.Create(ctx, log)
|
||||
}
|
||||
|
|
@ -1,167 +0,0 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
// Log represents a log entry in the system
|
||||
type Log struct {
|
||||
ID string
|
||||
SessionID string
|
||||
Timestamp int64
|
||||
Level string
|
||||
Message string
|
||||
Attributes map[string]string
|
||||
CreatedAt int64
|
||||
}
|
||||
|
||||
// Service defines the interface for log operations
|
||||
type Service interface {
|
||||
pubsub.Suscriber[Log]
|
||||
Create(ctx context.Context, log Log) error
|
||||
ListBySession(ctx context.Context, sessionID string) ([]Log, error)
|
||||
ListAll(ctx context.Context, limit int) ([]Log, error)
|
||||
}
|
||||
|
||||
// service implements the Service interface
|
||||
type service struct {
|
||||
*pubsub.Broker[Log]
|
||||
q db.Querier
|
||||
}
|
||||
|
||||
// NewService creates a new logging service
|
||||
func NewService(q db.Querier) Service {
|
||||
broker := pubsub.NewBroker[Log]()
|
||||
return &service{
|
||||
Broker: broker,
|
||||
q: q,
|
||||
}
|
||||
}
|
||||
|
||||
// Create adds a new log entry to the database
|
||||
func (s *service) Create(ctx context.Context, log Log) error {
|
||||
// Generate ID if not provided
|
||||
if log.ID == "" {
|
||||
log.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
// Set timestamp if not provided
|
||||
if log.Timestamp == 0 {
|
||||
log.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Set created_at if not provided
|
||||
if log.CreatedAt == 0 {
|
||||
log.CreatedAt = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Convert attributes to JSON string
|
||||
var attributesJSON sql.NullString
|
||||
if len(log.Attributes) > 0 {
|
||||
attributesBytes, err := json.Marshal(log.Attributes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
attributesJSON = sql.NullString{
|
||||
String: string(attributesBytes),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert session ID to SQL nullable string
|
||||
var sessionID sql.NullString
|
||||
if log.SessionID != "" {
|
||||
sessionID = sql.NullString{
|
||||
String: log.SessionID,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Insert log into database
|
||||
err := s.q.CreateLog(ctx, db.CreateLogParams{
|
||||
ID: log.ID,
|
||||
SessionID: sessionID,
|
||||
Timestamp: log.Timestamp,
|
||||
Level: log.Level,
|
||||
Message: log.Message,
|
||||
Attributes: attributesJSON,
|
||||
CreatedAt: log.CreatedAt,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Publish event
|
||||
s.Publish(pubsub.CreatedEvent, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBySession retrieves logs for a specific session
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
|
||||
dbLogs, err := s.q.ListLogsBySession(ctx, sql.NullString{
|
||||
String: sessionID,
|
||||
Valid: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs := make([]Log, len(dbLogs))
|
||||
for i, dbLog := range dbLogs {
|
||||
logs[i] = s.fromDBItem(dbLog)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// ListAll retrieves all logs with a limit
|
||||
func (s *service) ListAll(ctx context.Context, limit int) ([]Log, error) {
|
||||
dbLogs, err := s.q.ListAllLogs(ctx, int64(limit))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs := make([]Log, len(dbLogs))
|
||||
for i, dbLog := range dbLogs {
|
||||
logs[i] = s.fromDBItem(dbLog)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// fromDBItem converts a database log item to a Log struct
|
||||
func (s *service) fromDBItem(item db.Log) Log {
|
||||
log := Log{
|
||||
ID: item.ID,
|
||||
Timestamp: item.Timestamp,
|
||||
Level: item.Level,
|
||||
Message: item.Message,
|
||||
CreatedAt: item.CreatedAt,
|
||||
}
|
||||
|
||||
// Convert session ID if valid
|
||||
if item.SessionID.Valid {
|
||||
log.SessionID = item.SessionID.String
|
||||
}
|
||||
|
||||
// Parse attributes JSON if present
|
||||
if item.Attributes.Valid {
|
||||
attributes := make(map[string]string)
|
||||
if err := json.Unmarshal([]byte(item.Attributes.String), &attributes); err == nil {
|
||||
log.Attributes = attributes
|
||||
} else {
|
||||
// Initialize empty map if parsing fails
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
} else {
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
return log
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-logfmt/logfmt"
|
||||
"github.com/opencode-ai/opencode/internal/session"
|
||||
)
|
||||
|
||||
type writer struct{}
|
||||
|
||||
func (w *writer) Write(p []byte) (int, error) {
|
||||
d := logfmt.NewDecoder(bytes.NewReader(p))
|
||||
for d.ScanRecord() {
|
||||
msg := Log{}
|
||||
|
||||
for d.ScanKeyval() {
|
||||
switch string(d.Key()) {
|
||||
case "time":
|
||||
parsed, err := time.Parse(time.RFC3339, string(d.Value()))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing time: %w", err)
|
||||
}
|
||||
msg.Timestamp = parsed.UnixMilli()
|
||||
case "level":
|
||||
msg.Level = strings.ToLower(string(d.Value()))
|
||||
case "msg":
|
||||
msg.Message = string(d.Value())
|
||||
default:
|
||||
if msg.Attributes == nil {
|
||||
msg.Attributes = make(map[string]string)
|
||||
}
|
||||
msg.Attributes[string(d.Key())] = string(d.Value())
|
||||
}
|
||||
}
|
||||
|
||||
msg.SessionID = session.CurrentSessionID()
|
||||
Create(context.Background(), msg)
|
||||
}
|
||||
if d.Err() != nil {
|
||||
return 0, d.Err()
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func NewWriter() *writer {
|
||||
w := &writer{}
|
||||
return w
|
||||
}
|
|
@ -5,6 +5,9 @@ import (
|
|||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -13,6 +16,12 @@ import (
|
|||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
const (
|
||||
EventMessageCreated pubsub.EventType = "message_created"
|
||||
EventMessageUpdated pubsub.EventType = "message_updated"
|
||||
EventMessageDeleted pubsub.EventType = "message_deleted"
|
||||
)
|
||||
|
||||
type CreateMessageParams struct {
|
||||
Role MessageRole
|
||||
Parts []ContentPart
|
||||
|
@ -20,163 +29,345 @@ type CreateMessageParams struct {
|
|||
}
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[Message]
|
||||
pubsub.Subscriber[Message]
|
||||
|
||||
Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
|
||||
Update(ctx context.Context, message Message) error
|
||||
Update(ctx context.Context, message Message) (Message, error)
|
||||
Get(ctx context.Context, id string) (Message, error)
|
||||
List(ctx context.Context, sessionID string) ([]Message, error)
|
||||
ListAfter(ctx context.Context, sessionID string, timestamp int64) ([]Message, error)
|
||||
ListAfter(ctx context.Context, sessionID string, timestampMillis int64) ([]Message, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
DeleteSessionMessages(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
*pubsub.Broker[Message]
|
||||
q db.Querier
|
||||
db *db.Queries
|
||||
broker *pubsub.Broker[Message]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewService(q db.Querier) Service {
|
||||
return &service{
|
||||
Broker: pubsub.NewBroker[Message](),
|
||||
q: q,
|
||||
}
|
||||
}
|
||||
var globalMessageService *service
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
message, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
func InitService(dbConn *sql.DB) error {
|
||||
if globalMessageService != nil {
|
||||
return fmt.Errorf("message service already initialized")
|
||||
}
|
||||
err = s.q.DeleteMessage(ctx, message.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
queries := db.New(dbConn)
|
||||
broker := pubsub.NewBroker[Message]()
|
||||
|
||||
globalMessageService = &service{
|
||||
db: queries,
|
||||
broker: broker,
|
||||
}
|
||||
s.Publish(pubsub.DeletedEvent, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
|
||||
if params.Role != Assistant {
|
||||
params.Parts = append(params.Parts, Finish{
|
||||
Reason: "stop",
|
||||
})
|
||||
func GetService() Service {
|
||||
if globalMessageService == nil {
|
||||
panic("message service not initialized. Call message.InitService() first.")
|
||||
}
|
||||
return globalMessageService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
isFinished := false
|
||||
for _, p := range params.Parts {
|
||||
if _, ok := p.(Finish); ok {
|
||||
isFinished = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if params.Role == User && !isFinished {
|
||||
params.Parts = append(params.Parts, Finish{Reason: FinishReasonEndTurn, Time: time.Now().UnixMilli()})
|
||||
}
|
||||
|
||||
partsJSON, err := marshallParts(params.Parts)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
return Message{}, fmt.Errorf("failed to marshal message parts: %w", err)
|
||||
}
|
||||
dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
|
||||
|
||||
dbMsgParams := db.CreateMessageParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
Role: string(params.Role),
|
||||
Parts: string(partsJSON),
|
||||
Model: sql.NullString{String: string(params.Model), Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
Model: sql.NullString{String: string(params.Model), Valid: params.Model != ""},
|
||||
}
|
||||
|
||||
dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("db.CreateMessage: %w", err)
|
||||
}
|
||||
|
||||
message, err := s.fromDBItem(dbMessage)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
return Message{}, fmt.Errorf("failed to convert DB message: %w", err)
|
||||
}
|
||||
s.Publish(pubsub.CreatedEvent, message)
|
||||
|
||||
s.broker.Publish(EventMessageCreated, message)
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
|
||||
messages, err := s.List(ctx, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *service) Update(ctx context.Context, message Message) (Message, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if message.ID == "" {
|
||||
return Message{}, fmt.Errorf("cannot update message with empty ID")
|
||||
}
|
||||
for _, message := range messages {
|
||||
if message.SessionID == sessionID {
|
||||
err = s.Delete(ctx, message.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
partsJSON, err := marshallParts(message.Parts)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err)
|
||||
}
|
||||
|
||||
var dbFinishedAt sql.NullInt64
|
||||
finishPart := message.FinishPart()
|
||||
if finishPart != nil && finishPart.Time > 0 {
|
||||
dbFinishedAt = sql.NullInt64{
|
||||
Int64: finishPart.Time / 1000, // Convert Milliseconds from Go struct to Seconds for DB
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, message Message) error {
|
||||
parts, err := marshallParts(message.Parts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
finishedAt := sql.NullInt64{}
|
||||
if f := message.FinishPart(); f != nil {
|
||||
finishedAt.Int64 = f.Time
|
||||
finishedAt.Valid = true
|
||||
}
|
||||
err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
|
||||
// UpdatedAt is handled by the DB trigger (strftime('%s', 'now'))
|
||||
err = s.db.UpdateMessage(ctx, db.UpdateMessageParams{
|
||||
ID: message.ID,
|
||||
Parts: string(parts),
|
||||
FinishedAt: finishedAt,
|
||||
Parts: string(partsJSON),
|
||||
FinishedAt: dbFinishedAt,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return Message{}, fmt.Errorf("db.UpdateMessage: %w", err)
|
||||
}
|
||||
message.UpdatedAt = time.Now().Unix()
|
||||
s.Publish(pubsub.UpdatedEvent, message)
|
||||
return nil
|
||||
|
||||
dbUpdatedMessage, err := s.db.GetMessage(ctx, message.ID)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("failed to fetch message after update: %w", err)
|
||||
}
|
||||
updatedMessage, err := s.fromDBItem(dbUpdatedMessage)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err)
|
||||
}
|
||||
|
||||
s.broker.Publish(EventMessageUpdated, updatedMessage)
|
||||
return updatedMessage, nil
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (Message, error) {
|
||||
dbMessage, err := s.q.GetMessage(ctx, id)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessage, err := s.db.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
if err == sql.ErrNoRows {
|
||||
return Message{}, fmt.Errorf("message with ID '%s' not found", id)
|
||||
}
|
||||
return Message{}, fmt.Errorf("db.GetMessage: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbMessage)
|
||||
}
|
||||
|
||||
func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
|
||||
dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessages, err := s.db.ListMessagesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("db.ListMessagesBySession: %w", err)
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMessage := range dbMessages {
|
||||
messages[i], err = s.fromDBItem(dbMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
for i, dbMsg := range dbMessages {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message at index %d: %w", i, convErr)
|
||||
}
|
||||
messages[i] = msg
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp int64) ([]Message, error) {
|
||||
dbMessages, err := s.q.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
|
||||
func (s *service) ListAfter(ctx context.Context, sessionID string, timestampMillis int64) ([]Message, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
timestampSeconds := timestampMillis / 1000 // Convert to seconds for DB query
|
||||
|
||||
dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
|
||||
SessionID: sessionID,
|
||||
CreatedAt: timestamp,
|
||||
CreatedAt: timestampSeconds,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMsg := range dbMessages {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr)
|
||||
}
|
||||
messages[i] = msg
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
messageToPublish, err := s.getServiceForPublish(ctx, id)
|
||||
s.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
// If error was due to not found, it's not a critical failure for deletion intent
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil // Or return the error if strictness is required
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteMessage(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteMessage: %w", err)
|
||||
}
|
||||
|
||||
if messageToPublish != nil {
|
||||
s.broker.Publish(EventMessageDeleted, *messageToPublish)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) {
|
||||
dbMsg, err := s.db.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMessage := range dbMessages {
|
||||
messages[i], err = s.fromDBItem(dbMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message for publishing: %w", convErr)
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
messagesToDelete, err := s.db.ListMessagesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list messages for deletion: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.DeleteSessionMessages(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSessionMessages: %w", err)
|
||||
}
|
||||
|
||||
for _, dbMsg := range messagesToDelete {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr == nil {
|
||||
s.broker.Publish(EventMessageDeleted, msg)
|
||||
} else {
|
||||
slog.Error("Failed to convert DB message for delete event publishing", "id", dbMsg.ID, "error", convErr)
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Message) (Message, error) {
|
||||
parts, err := unmarshallParts([]byte(item.Parts))
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts)
|
||||
}
|
||||
return Message{
|
||||
|
||||
// DB stores created_at, updated_at, finished_at as Unix seconds.
|
||||
// Go struct Message stores them as Unix milliseconds.
|
||||
createdAtMillis := item.CreatedAt * 1000
|
||||
updatedAtMillis := item.UpdatedAt * 1000
|
||||
|
||||
msg := Message{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID,
|
||||
Role: MessageRole(item.Role),
|
||||
Parts: parts,
|
||||
Model: models.ModelID(item.Model.String),
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
}, nil
|
||||
CreatedAt: createdAtMillis,
|
||||
UpdatedAt: updatedAtMillis,
|
||||
}
|
||||
|
||||
// Ensure Finish part in msg.Parts reflects the item.FinishedAt state
|
||||
// if item.FinishedAt is the source of truth for the "overall message finished time".
|
||||
// The `unmarshallParts` should already create a Finish part if it's in the JSON.
|
||||
// This logic reconciles the DB column with the JSON parts.
|
||||
var existingFinishPart *Finish
|
||||
var finishPartIndex = -1
|
||||
|
||||
for i, p := range msg.Parts {
|
||||
if fp, ok := p.(Finish); ok {
|
||||
existingFinishPart = &fp
|
||||
finishPartIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if item.FinishedAt.Valid && item.FinishedAt.Int64 > 0 {
|
||||
dbFinishTimeMillis := item.FinishedAt.Int64 * 1000
|
||||
if existingFinishPart != nil {
|
||||
// If a Finish part exists from JSON, update its time if DB's time is different.
|
||||
// This assumes DB `finished_at` is the ultimate source of truth for when the message truly finished.
|
||||
if existingFinishPart.Time != dbFinishTimeMillis {
|
||||
slog.Debug("Aligning Finish part time with DB finished_at", "message_id", msg.ID, "json_finish_time", existingFinishPart.Time, "db_finish_time", dbFinishTimeMillis)
|
||||
existingFinishPart.Time = dbFinishTimeMillis
|
||||
msg.Parts[finishPartIndex] = *existingFinishPart
|
||||
}
|
||||
} else {
|
||||
// If no Finish part in JSON but DB says it's finished, add one.
|
||||
// We might not know the original FinishReason here, so use a sensible default or leave it to be set by Update.
|
||||
// This scenario should be less common if `Update` always ensures a Finish part for finished messages.
|
||||
slog.Debug("Synthesizing Finish part from DB finished_at", "message_id", msg.ID)
|
||||
msg.Parts = append(msg.Parts, Finish{Reason: FinishReasonEndTurn, Time: dbFinishTimeMillis})
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
|
||||
return GetService().Create(ctx, sessionID, params)
|
||||
}
|
||||
|
||||
func Update(ctx context.Context, message Message) (Message, error) {
|
||||
return GetService().Update(ctx, message)
|
||||
}
|
||||
|
||||
func Get(ctx context.Context, id string) (Message, error) {
|
||||
return GetService().Get(ctx, id)
|
||||
}
|
||||
|
||||
func List(ctx context.Context, sessionID string) ([]Message, error) {
|
||||
return GetService().List(ctx, sessionID)
|
||||
}
|
||||
|
||||
func ListAfter(ctx context.Context, sessionID string, timestampMillis int64) ([]Message, error) {
|
||||
return GetService().ListAfter(ctx, sessionID, timestampMillis)
|
||||
}
|
||||
|
||||
func Delete(ctx context.Context, id string) error {
|
||||
return GetService().Delete(ctx, id)
|
||||
}
|
||||
|
||||
func DeleteSessionMessages(ctx context.Context, sessionID string) error {
|
||||
return GetService().DeleteSessionMessages(ctx, sessionID)
|
||||
}
|
||||
|
||||
func SubscribeToEvents(ctx context.Context) <-chan pubsub.Event[Message] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
||||
type partType string
|
||||
|
@ -192,109 +383,143 @@ const (
|
|||
)
|
||||
|
||||
type partWrapper struct {
|
||||
Type partType `json:"type"`
|
||||
Data ContentPart `json:"data"`
|
||||
Type partType `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func marshallParts(parts []ContentPart) ([]byte, error) {
|
||||
wrappedParts := make([]partWrapper, len(parts))
|
||||
|
||||
wrappedParts := make([]json.RawMessage, len(parts))
|
||||
for i, part := range parts {
|
||||
var typ partType
|
||||
var dataBytes []byte
|
||||
var err error
|
||||
|
||||
switch part.(type) {
|
||||
switch p := part.(type) {
|
||||
case ReasoningContent:
|
||||
typ = reasoningType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case TextContent:
|
||||
typ = textType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case *TextContent:
|
||||
typ = textType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ImageURLContent:
|
||||
typ = imageURLType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case BinaryContent:
|
||||
typ = binaryType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ToolCall:
|
||||
typ = toolCallType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ToolResult:
|
||||
typ = toolResultType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case Finish:
|
||||
typ = finishType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown part type: %T", part)
|
||||
return nil, fmt.Errorf("unknown part type for marshalling: %T", part)
|
||||
}
|
||||
|
||||
wrappedParts[i] = partWrapper{
|
||||
Type: typ,
|
||||
Data: part,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err)
|
||||
}
|
||||
wrapper := struct {
|
||||
Type partType `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}{Type: typ, Data: dataBytes}
|
||||
wrappedBytes, err := json.Marshal(wrapper)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal part wrapper for type %s: %w", typ, err)
|
||||
}
|
||||
wrappedParts[i] = wrappedBytes
|
||||
}
|
||||
return json.Marshal(wrappedParts)
|
||||
}
|
||||
|
||||
func unmarshallParts(data []byte) ([]ContentPart, error) {
|
||||
temp := []json.RawMessage{}
|
||||
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return nil, err
|
||||
var rawMessages []json.RawMessage
|
||||
if err := json.Unmarshal(data, &rawMessages); err != nil {
|
||||
// Handle case where 'parts' might be a single object if not an array initially
|
||||
// This was a fallback, if your DB always stores an array, this might not be needed.
|
||||
var singleRawMessage json.RawMessage
|
||||
if errSingle := json.Unmarshal(data, &singleRawMessage); errSingle == nil {
|
||||
rawMessages = []json.RawMessage{singleRawMessage}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]ContentPart, 0)
|
||||
|
||||
for _, rawPart := range temp {
|
||||
var wrapper struct {
|
||||
Type partType `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
parts := make([]ContentPart, 0, len(rawMessages))
|
||||
for _, rawPart := range rawMessages {
|
||||
var wrapper partWrapper
|
||||
if err := json.Unmarshal(rawPart, &wrapper); err != nil {
|
||||
return nil, err
|
||||
// Fallback for old format where parts might be just TextContent string
|
||||
var text string
|
||||
if errText := json.Unmarshal(rawPart, &text); errText == nil {
|
||||
parts = append(parts, TextContent{Text: text})
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(rawPart))
|
||||
}
|
||||
|
||||
switch wrapper.Type {
|
||||
case reasoningType:
|
||||
part := ReasoningContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ReasoningContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ReasoningContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case textType:
|
||||
part := TextContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p TextContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal TextContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case imageURLType:
|
||||
part := ImageURLContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ImageURLContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ImageURLContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, p)
|
||||
case binaryType:
|
||||
part := BinaryContent{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p BinaryContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal BinaryContent: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case toolCallType:
|
||||
part := ToolCall{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ToolCall
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ToolCall: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case toolResultType:
|
||||
part := ToolResult{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p ToolResult
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal ToolResult: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
case finishType:
|
||||
part := Finish{}
|
||||
if err := json.Unmarshal(wrapper.Data, &part); err != nil {
|
||||
return nil, err
|
||||
var p Finish
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal Finish: %w. Data: %s", err, string(wrapper.Data))
|
||||
}
|
||||
parts = append(parts, part)
|
||||
parts = append(parts, p)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
|
||||
slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", wrapper.Type, "data", string(wrapper.Data))
|
||||
// Fallback: if type is unknown or empty, try to parse data as TextContent directly
|
||||
var p TextContent
|
||||
if err := json.Unmarshal(wrapper.Data, &p); err == nil {
|
||||
parts = append(parts, p)
|
||||
} else {
|
||||
// If that also fails, log it but continue if possible, or return error
|
||||
slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", wrapper.Type, "data", string(wrapper.Data), "error", err)
|
||||
// Depending on strictness, you might return an error here:
|
||||
// return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return parts, nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
package permission
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
|
@ -32,56 +36,141 @@ type PermissionRequest struct {
|
|||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type PermissionResponse struct {
|
||||
Request PermissionRequest
|
||||
Granted bool
|
||||
}
|
||||
|
||||
const (
|
||||
EventPermissionRequested pubsub.EventType = "permission_requested"
|
||||
EventPermissionGranted pubsub.EventType = "permission_granted"
|
||||
EventPermissionDenied pubsub.EventType = "permission_denied"
|
||||
EventPermissionPersisted pubsub.EventType = "permission_persisted"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[PermissionRequest]
|
||||
GrantPersistant(permission PermissionRequest)
|
||||
Grant(permission PermissionRequest)
|
||||
Deny(permission PermissionRequest)
|
||||
Request(opts CreatePermissionRequest) bool
|
||||
AutoApproveSession(sessionID string)
|
||||
pubsub.Subscriber[PermissionRequest]
|
||||
SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse]
|
||||
|
||||
GrantPersistant(ctx context.Context, permission PermissionRequest)
|
||||
Grant(ctx context.Context, permission PermissionRequest)
|
||||
Deny(ctx context.Context, permission PermissionRequest)
|
||||
Request(ctx context.Context, opts CreatePermissionRequest) bool
|
||||
AutoApproveSession(ctx context.Context, sessionID string)
|
||||
IsAutoApproved(ctx context.Context, sessionID string) bool
|
||||
}
|
||||
|
||||
type permissionService struct {
|
||||
*pubsub.Broker[PermissionRequest]
|
||||
broker *pubsub.Broker[PermissionRequest]
|
||||
responseBroker *pubsub.Broker[PermissionResponse]
|
||||
|
||||
sessionPermissions []PermissionRequest
|
||||
sessionPermissions map[string][]PermissionRequest
|
||||
pendingRequests sync.Map
|
||||
autoApproveSessions []string
|
||||
autoApproveSessions map[string]bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *permissionService) GrantPersistant(permission PermissionRequest) {
|
||||
var globalPermissionService *permissionService
|
||||
|
||||
func InitService() error {
|
||||
if globalPermissionService != nil {
|
||||
return fmt.Errorf("permission service already initialized")
|
||||
}
|
||||
globalPermissionService = &permissionService{
|
||||
broker: pubsub.NewBroker[PermissionRequest](),
|
||||
responseBroker: pubsub.NewBroker[PermissionResponse](),
|
||||
sessionPermissions: make(map[string][]PermissionRequest),
|
||||
autoApproveSessions: make(map[string]bool),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() *permissionService {
|
||||
if globalPermissionService == nil {
|
||||
panic("permission service not initialized. Call permission.InitService() first.")
|
||||
}
|
||||
return globalPermissionService
|
||||
}
|
||||
|
||||
func (s *permissionService) GrantPersistant(ctx context.Context, permission PermissionRequest) {
|
||||
s.mu.Lock()
|
||||
s.sessionPermissions[permission.SessionID] = append(s.sessionPermissions[permission.SessionID], permission)
|
||||
s.mu.Unlock()
|
||||
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- true
|
||||
select {
|
||||
case respCh.(chan bool) <- true:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending grant persistent response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.sessionPermissions = append(s.sessionPermissions, permission)
|
||||
s.responseBroker.Publish(EventPermissionPersisted, PermissionResponse{Request: permission, Granted: true})
|
||||
}
|
||||
|
||||
func (s *permissionService) Grant(permission PermissionRequest) {
|
||||
func (s *permissionService) Grant(ctx context.Context, permission PermissionRequest) {
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- true
|
||||
select {
|
||||
case respCh.(chan bool) <- true:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending grant response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.responseBroker.Publish(EventPermissionGranted, PermissionResponse{Request: permission, Granted: true})
|
||||
}
|
||||
|
||||
func (s *permissionService) Deny(permission PermissionRequest) {
|
||||
func (s *permissionService) Deny(ctx context.Context, permission PermissionRequest) {
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- false
|
||||
select {
|
||||
case respCh.(chan bool) <- false:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending deny response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.responseBroker.Publish(EventPermissionDenied, PermissionResponse{Request: permission, Granted: false})
|
||||
}
|
||||
|
||||
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
if slices.Contains(s.autoApproveSessions, opts.SessionID) {
|
||||
func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) bool {
|
||||
s.mu.RLock()
|
||||
if s.autoApproveSessions[opts.SessionID] {
|
||||
s.mu.RUnlock()
|
||||
return true
|
||||
}
|
||||
dir := filepath.Dir(opts.Path)
|
||||
if dir == "." {
|
||||
dir = config.WorkingDirectory()
|
||||
|
||||
requestPath := opts.Path
|
||||
if !filepath.IsAbs(requestPath) {
|
||||
requestPath = filepath.Join(config.WorkingDirectory(), requestPath)
|
||||
}
|
||||
permission := PermissionRequest{
|
||||
requestPath = filepath.Clean(requestPath)
|
||||
|
||||
if permissions, ok := s.sessionPermissions[opts.SessionID]; ok {
|
||||
for _, p := range permissions {
|
||||
storedPath := p.Path
|
||||
if !filepath.IsAbs(storedPath) {
|
||||
storedPath = filepath.Join(config.WorkingDirectory(), storedPath)
|
||||
}
|
||||
storedPath = filepath.Clean(storedPath)
|
||||
|
||||
if p.ToolName == opts.ToolName && p.Action == opts.Action &&
|
||||
(requestPath == storedPath || strings.HasPrefix(requestPath, storedPath+string(filepath.Separator))) {
|
||||
s.mu.RUnlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
normalizedPath := opts.Path
|
||||
if !filepath.IsAbs(normalizedPath) {
|
||||
normalizedPath = filepath.Join(config.WorkingDirectory(), normalizedPath)
|
||||
}
|
||||
normalizedPath = filepath.Clean(normalizedPath)
|
||||
|
||||
permissionReq := PermissionRequest{
|
||||
ID: uuid.New().String(),
|
||||
Path: dir,
|
||||
Path: normalizedPath,
|
||||
SessionID: opts.SessionID,
|
||||
ToolName: opts.ToolName,
|
||||
Description: opts.Description,
|
||||
|
@ -89,31 +178,69 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
|||
Params: opts.Params,
|
||||
}
|
||||
|
||||
for _, p := range s.sessionPermissions {
|
||||
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
respCh := make(chan bool, 1)
|
||||
s.pendingRequests.Store(permissionReq.ID, respCh)
|
||||
defer s.pendingRequests.Delete(permissionReq.ID)
|
||||
|
||||
s.pendingRequests.Store(permission.ID, respCh)
|
||||
defer s.pendingRequests.Delete(permission.ID)
|
||||
s.broker.Publish(EventPermissionRequested, permissionReq)
|
||||
|
||||
s.Publish(pubsub.CreatedEvent, permission)
|
||||
|
||||
// Wait for the response with a timeout
|
||||
resp := <-respCh
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *permissionService) AutoApproveSession(sessionID string) {
|
||||
s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
|
||||
}
|
||||
|
||||
func NewPermissionService() Service {
|
||||
return &permissionService{
|
||||
Broker: pubsub.NewBroker[PermissionRequest](),
|
||||
sessionPermissions: make([]PermissionRequest, 0),
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
return resp
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Permission request timed out or context cancelled", "request_id", permissionReq.ID, "tool", opts.ToolName)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *permissionService) AutoApproveSession(ctx context.Context, sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.autoApproveSessions[sessionID] = true
|
||||
}
|
||||
|
||||
func (s *permissionService) IsAutoApproved(ctx context.Context, sessionID string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.autoApproveSessions[sessionID]
|
||||
}
|
||||
|
||||
func (s *permissionService) Subscribe(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *permissionService) SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
|
||||
return s.responseBroker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func GrantPersistant(ctx context.Context, permission PermissionRequest) {
|
||||
GetService().GrantPersistant(ctx, permission)
|
||||
}
|
||||
|
||||
func Grant(ctx context.Context, permission PermissionRequest) {
|
||||
GetService().Grant(ctx, permission)
|
||||
}
|
||||
|
||||
func Deny(ctx context.Context, permission PermissionRequest) {
|
||||
GetService().Deny(ctx, permission)
|
||||
}
|
||||
|
||||
func Request(ctx context.Context, opts CreatePermissionRequest) bool {
|
||||
return GetService().Request(ctx, opts)
|
||||
}
|
||||
|
||||
func AutoApproveSession(ctx context.Context, sessionID string) {
|
||||
GetService().AutoApproveSession(ctx, sessionID)
|
||||
}
|
||||
|
||||
func IsAutoApproved(ctx context.Context, sessionID string) bool {
|
||||
return GetService().IsAutoApproved(ctx, sessionID)
|
||||
}
|
||||
|
||||
func SubscribeToRequests(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
||||
func SubscribeToResponses(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
|
||||
return GetService().SubscribeToResponseEvents(ctx)
|
||||
}
|
||||
|
|
|
@ -2,136 +2,112 @@ package pubsub
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const bufferSize = 1000
|
||||
const defaultChannelBufferSize = 100
|
||||
|
||||
type Broker[T any] struct {
|
||||
subs map[chan Event[T]]struct{}
|
||||
mu sync.RWMutex
|
||||
done chan struct{}
|
||||
subCount int
|
||||
maxEvents int
|
||||
subs map[chan Event[T]]context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewBroker[T any]() *Broker[T] {
|
||||
return NewBrokerWithOptions[T](bufferSize, 1000)
|
||||
}
|
||||
|
||||
func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] {
|
||||
b := &Broker[T]{
|
||||
subs: make(map[chan Event[T]]struct{}),
|
||||
done: make(chan struct{}),
|
||||
subCount: 0,
|
||||
maxEvents: maxEvents,
|
||||
return &Broker[T]{
|
||||
subs: make(map[chan Event[T]]context.CancelFunc),
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Broker[T]) Shutdown() {
|
||||
select {
|
||||
case <-b.done: // Already closed
|
||||
return
|
||||
default:
|
||||
close(b.done)
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
for ch := range b.subs {
|
||||
delete(b.subs, ch)
|
||||
close(ch)
|
||||
if b.isClosed {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
b.isClosed = true
|
||||
|
||||
b.subCount = 0
|
||||
for ch, cancel := range b.subs {
|
||||
cancel()
|
||||
close(ch)
|
||||
delete(b.subs, ch)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
slog.Debug("PubSub broker shut down", "type", fmt.Sprintf("%T", *new(T)))
|
||||
}
|
||||
|
||||
func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-b.done:
|
||||
ch := make(chan Event[T])
|
||||
close(ch)
|
||||
return ch
|
||||
default:
|
||||
if b.isClosed {
|
||||
closedCh := make(chan Event[T])
|
||||
close(closedCh)
|
||||
return closedCh
|
||||
}
|
||||
|
||||
sub := make(chan Event[T], bufferSize)
|
||||
b.subs[sub] = struct{}{}
|
||||
b.subCount++
|
||||
subCtx, subCancel := context.WithCancel(ctx)
|
||||
subscriberChannel := make(chan Event[T], defaultChannelBufferSize)
|
||||
b.subs[subscriberChannel] = subCancel
|
||||
|
||||
// Only start a goroutine if the context can actually be canceled
|
||||
if ctx.Done() != nil {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
go func() {
|
||||
<-subCtx.Done()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if _, ok := b.subs[subscriberChannel]; ok {
|
||||
close(subscriberChannel)
|
||||
delete(b.subs, subscriberChannel)
|
||||
}
|
||||
}()
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return subscriberChannel
|
||||
}
|
||||
|
||||
select {
|
||||
case <-b.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
func (b *Broker[T]) Publish(eventType EventType, payload T) {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if _, exists := b.subs[sub]; exists {
|
||||
delete(b.subs, sub)
|
||||
close(sub)
|
||||
b.subCount--
|
||||
}
|
||||
}()
|
||||
if b.isClosed {
|
||||
slog.Warn("Attempted to publish on a closed pubsub broker", "type", eventType, "payload_type", fmt.Sprintf("%T", payload))
|
||||
return
|
||||
}
|
||||
|
||||
return sub
|
||||
event := Event[T]{Type: eventType, Payload: payload}
|
||||
|
||||
for ch := range b.subs {
|
||||
// Non-blocking send with a fallback to a goroutine to prevent slow subscribers
|
||||
// from blocking the publisher.
|
||||
select {
|
||||
case ch <- event:
|
||||
// Successfully sent
|
||||
default:
|
||||
// Subscriber channel is full or receiver is slow.
|
||||
// Send in a new goroutine to avoid blocking the publisher.
|
||||
// This might lead to out-of-order delivery for this specific slow subscriber.
|
||||
go func(sChan chan Event[T], ev Event[T]) {
|
||||
// Re-check if broker is closed before attempting send in goroutine
|
||||
b.mu.RLock()
|
||||
isBrokerClosed := b.isClosed
|
||||
b.mu.RUnlock()
|
||||
if isBrokerClosed {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sChan <- ev:
|
||||
case <-time.After(2 * time.Second): // Timeout for slow subscriber
|
||||
slog.Warn("PubSub: Dropped event for slow subscriber after timeout", "type", ev.Type)
|
||||
}
|
||||
}(ch, event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker[T]) GetSubscriberCount() int {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.subCount
|
||||
}
|
||||
|
||||
func (b *Broker[T]) Publish(t EventType, payload T) {
|
||||
b.mu.RLock()
|
||||
select {
|
||||
case <-b.done:
|
||||
b.mu.RUnlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
subscribers := make([]chan Event[T], 0, len(b.subs))
|
||||
for sub := range b.subs {
|
||||
subscribers = append(subscribers, sub)
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
|
||||
event := Event[T]{Type: t, Payload: payload}
|
||||
|
||||
for _, sub := range subscribers {
|
||||
select {
|
||||
case sub <- event:
|
||||
// Successfully sent
|
||||
case <-b.done:
|
||||
// Broker is shutting down
|
||||
return
|
||||
default:
|
||||
// Channel is full, but we don't want to block
|
||||
// Log this situation or consider other strategies
|
||||
// For now, we'll create a new goroutine to ensure delivery
|
||||
go func(ch chan Event[T], evt Event[T]) {
|
||||
select {
|
||||
case ch <- evt:
|
||||
// Successfully sent
|
||||
case <-b.done:
|
||||
// Broker is shutting down
|
||||
return
|
||||
}
|
||||
}(sub, event)
|
||||
}
|
||||
}
|
||||
return len(b.subs)
|
||||
}
|
||||
|
|
|
@ -2,27 +2,23 @@ package pubsub
|
|||
|
||||
import "context"
|
||||
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
CreatedEvent EventType = "created"
|
||||
UpdatedEvent EventType = "updated"
|
||||
DeletedEvent EventType = "deleted"
|
||||
EventTypeCreated EventType = "created"
|
||||
EventTypeUpdated EventType = "updated"
|
||||
EventTypeDeleted EventType = "deleted"
|
||||
)
|
||||
|
||||
type Suscriber[T any] interface {
|
||||
Subscribe(context.Context) <-chan Event[T]
|
||||
type Event[T any] struct {
|
||||
Type EventType
|
||||
Payload T
|
||||
}
|
||||
|
||||
type (
|
||||
// EventType identifies the type of event
|
||||
EventType string
|
||||
type Subscriber[T any] interface {
|
||||
Subscribe(ctx context.Context) <-chan Event[T]
|
||||
}
|
||||
|
||||
// Event represents an event in the lifecycle of a resource
|
||||
Event[T any] struct {
|
||||
Type EventType
|
||||
Payload T
|
||||
}
|
||||
|
||||
Publisher[T any] interface {
|
||||
Publish(EventType, T)
|
||||
}
|
||||
)
|
||||
type Publisher[T any] interface {
|
||||
Publish(eventType EventType, payload T)
|
||||
}
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Manager handles session management, tracking the currently active session.
|
||||
type Manager struct {
|
||||
currentSessionID string
|
||||
service Service
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Global instance of the session manager
|
||||
var globalManager *Manager
|
||||
|
||||
// InitManager initializes the global session manager with the provided service.
|
||||
func InitManager(service Service) {
|
||||
globalManager = &Manager{
|
||||
currentSessionID: "",
|
||||
service: service,
|
||||
}
|
||||
|
||||
// Subscribe to session events to handle session deletions
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
eventCh := service.Subscribe(ctx)
|
||||
for event := range eventCh {
|
||||
if event.Type == pubsub.DeletedEvent && event.Payload.ID == CurrentSessionID() {
|
||||
// If the current session is deleted, clear the current session
|
||||
SetCurrentSession("")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SetCurrentSession changes the active session to the one with the specified ID.
|
||||
func SetCurrentSession(sessionID string) {
|
||||
if globalManager == nil {
|
||||
slog.Warn("Session manager not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
globalManager.mu.Lock()
|
||||
defer globalManager.mu.Unlock()
|
||||
|
||||
globalManager.currentSessionID = sessionID
|
||||
slog.Debug("Current session changed", "sessionID", sessionID)
|
||||
}
|
||||
|
||||
// CurrentSessionID returns the ID of the currently active session.
|
||||
func CurrentSessionID() string {
|
||||
if globalManager == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// globalManager.mu.RLock()
|
||||
// defer globalManager.mu.RUnlock()
|
||||
|
||||
return globalManager.currentSessionID
|
||||
}
|
||||
|
||||
// CurrentSession returns the currently active session.
|
||||
// If no session is set or the session cannot be found, it returns nil.
|
||||
func CurrentSession() *Session {
|
||||
if globalManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionID := CurrentSessionID()
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
session, err := globalManager.service.Get(context.Background(), sessionID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &session
|
||||
}
|
|
@ -3,12 +3,16 @@ package session
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/opencode-ai/opencode/internal/db"
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
// Session represents a conversation session.
|
||||
type Session struct {
|
||||
ID string
|
||||
ParentSessionID string
|
||||
|
@ -23,128 +27,185 @@ type Session struct {
|
|||
UpdatedAt int64
|
||||
}
|
||||
|
||||
// --- Events ---
|
||||
|
||||
const (
|
||||
EventSessionCreated pubsub.EventType = "session_created"
|
||||
EventSessionUpdated pubsub.EventType = "session_updated"
|
||||
EventSessionDeleted pubsub.EventType = "session_deleted"
|
||||
)
|
||||
|
||||
// --- Service Definition ---
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[Session]
|
||||
pubsub.Subscriber[Session]
|
||||
|
||||
Create(ctx context.Context, title string) (Session, error)
|
||||
CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
|
||||
CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
|
||||
Get(ctx context.Context, id string) (Session, error)
|
||||
List(ctx context.Context) ([]Session, error)
|
||||
Save(ctx context.Context, session Session) (Session, error)
|
||||
Update(ctx context.Context, session Session) (Session, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
*pubsub.Broker[Session]
|
||||
q db.Querier
|
||||
db *db.Queries
|
||||
broker *pubsub.Broker[Session]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var globalSessionService *service
|
||||
|
||||
func InitService(dbConn *sql.DB) error {
|
||||
if globalSessionService != nil {
|
||||
return fmt.Errorf("session service already initialized")
|
||||
}
|
||||
queries := db.New(dbConn)
|
||||
broker := pubsub.NewBroker[Session]()
|
||||
|
||||
globalSessionService = &service{
|
||||
db: queries,
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalSessionService == nil {
|
||||
panic("session service not initialized. Call session.InitService() first.")
|
||||
}
|
||||
return globalSessionService
|
||||
}
|
||||
|
||||
// --- Service Methods ---
|
||||
|
||||
func (s *service) Create(ctx context.Context, title string) (Session, error) {
|
||||
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if title == "" {
|
||||
title = "New Session - " + time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
dbSessParams := db.CreateSessionParams{
|
||||
ID: uuid.New().String(),
|
||||
Title: title,
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("db.CreateSession: %w", err)
|
||||
}
|
||||
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.CreatedEvent, session)
|
||||
s.broker.Publish(EventSessionCreated, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
|
||||
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if title == "" {
|
||||
title = "Task Session - " + time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = uuid.New().String()
|
||||
}
|
||||
|
||||
dbSessParams := db.CreateSessionParams{
|
||||
ID: toolCallID,
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: parentSessionID != ""},
|
||||
Title: title,
|
||||
})
|
||||
}
|
||||
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
return Session{}, fmt.Errorf("db.CreateTaskSession: %w", err)
|
||||
}
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.CreatedEvent, session)
|
||||
s.broker.Publish(EventSessionCreated, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
|
||||
dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
|
||||
ID: "title-" + parentSessionID,
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
|
||||
Title: "Generate a title",
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.CreatedEvent, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
session, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.q.DeleteSession(ctx, session.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Publish(pubsub.DeletedEvent, session)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (Session, error) {
|
||||
dbSession, err := s.q.GetSessionByID(ctx, id)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbSession, err := s.db.GetSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
if err == sql.ErrNoRows {
|
||||
return Session{}, fmt.Errorf("session ID '%s' not found", id)
|
||||
}
|
||||
return Session{}, fmt.Errorf("db.GetSessionByID: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbSession), nil
|
||||
}
|
||||
|
||||
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
|
||||
summary := sql.NullString{}
|
||||
if session.Summary != "" {
|
||||
summary.String = session.Summary
|
||||
summary.Valid = true
|
||||
func (s *service) List(ctx context.Context) ([]Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbSessions, err := s.db.ListSessions(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListSessions: %w", err)
|
||||
}
|
||||
|
||||
summarizedAt := sql.NullInt64{}
|
||||
if session.SummarizedAt != 0 {
|
||||
summarizedAt.Int64 = session.SummarizedAt
|
||||
summarizedAt.Valid = true
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSess := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSess)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
|
||||
func (s *service) Update(ctx context.Context, session Session) (Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if session.ID == "" {
|
||||
return Session{}, fmt.Errorf("cannot update session with empty ID")
|
||||
}
|
||||
params := db.UpdateSessionParams{
|
||||
ID: session.ID,
|
||||
Title: session.Title,
|
||||
PromptTokens: session.PromptTokens,
|
||||
CompletionTokens: session.CompletionTokens,
|
||||
Cost: session.Cost,
|
||||
Summary: summary,
|
||||
SummarizedAt: summarizedAt,
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
Summary: sql.NullString{String: session.Summary, Valid: session.Summary != ""},
|
||||
SummarizedAt: sql.NullInt64{Int64: session.SummarizedAt, Valid: session.SummarizedAt > 0},
|
||||
}
|
||||
|
||||
session = s.fromDBItem(dbSession)
|
||||
s.Publish(pubsub.UpdatedEvent, session)
|
||||
return session, nil
|
||||
dbSession, err := s.db.UpdateSession(ctx, params)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("db.UpdateSession: %w", err)
|
||||
}
|
||||
updatedSession := s.fromDBItem(dbSession)
|
||||
s.broker.Publish(EventSessionUpdated, updatedSession)
|
||||
return updatedSession, nil
|
||||
}
|
||||
|
||||
func (s *service) List(ctx context.Context) ([]Session, error) {
|
||||
dbSessions, err := s.q.ListSessions(ctx)
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
dbSess, err := s.db.GetSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
s.mu.Unlock()
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("session ID '%s' not found for deletion", id)
|
||||
}
|
||||
return fmt.Errorf("db.GetSessionByID before delete: %w", err)
|
||||
}
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSession := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSession)
|
||||
sessionToPublish := s.fromDBItem(dbSess)
|
||||
s.mu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteSession(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSession: %w", err)
|
||||
}
|
||||
return sessions, nil
|
||||
s.broker.Publish(EventSessionDeleted, sessionToPublish)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s service) fromDBItem(item db.Session) Session {
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Session) Session {
|
||||
return Session{
|
||||
ID: item.ID,
|
||||
ParentSessionID: item.ParentSessionID.String,
|
||||
|
@ -160,10 +221,30 @@ func (s service) fromDBItem(item db.Session) Session {
|
|||
}
|
||||
}
|
||||
|
||||
func NewService(q db.Querier) Service {
|
||||
broker := pubsub.NewBroker[Session]()
|
||||
return &service{
|
||||
broker,
|
||||
q,
|
||||
}
|
||||
func Create(ctx context.Context, title string) (Session, error) {
|
||||
return GetService().Create(ctx, title)
|
||||
}
|
||||
|
||||
func CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
|
||||
return GetService().CreateTaskSession(ctx, toolCallID, parentSessionID, title)
|
||||
}
|
||||
|
||||
func Get(ctx context.Context, id string) (Session, error) {
|
||||
return GetService().Get(ctx, id)
|
||||
}
|
||||
|
||||
func List(ctx context.Context) ([]Session, error) {
|
||||
return GetService().List(ctx)
|
||||
}
|
||||
|
||||
func Update(ctx context.Context, session Session) (Session, error) {
|
||||
return GetService().Update(ctx, session)
|
||||
}
|
||||
|
||||
func Delete(ctx context.Context, id string) error {
|
||||
return GetService().Delete(ctx, id)
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
package status
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager handles status message management
|
||||
type Manager struct {
|
||||
service Service
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Global instance of the status manager
|
||||
var globalManager *Manager
|
||||
|
||||
// InitManager initializes the global status manager with the provided service
|
||||
func InitManager(service Service) {
|
||||
globalManager = &Manager{
|
||||
service: service,
|
||||
}
|
||||
|
||||
// Subscribe to status events for any global handling if needed
|
||||
// go func() {
|
||||
// ctx := context.Background()
|
||||
// _ = service.Subscribe(ctx)
|
||||
// }()
|
||||
|
||||
slog.Debug("Status manager initialized")
|
||||
}
|
||||
|
||||
// GetService returns the status service from the global manager
|
||||
func GetService() Service {
|
||||
if globalManager == nil {
|
||||
slog.Warn("Status manager not initialized, initializing with default service")
|
||||
InitManager(NewService())
|
||||
}
|
||||
|
||||
globalManager.mu.RLock()
|
||||
defer globalManager.mu.RUnlock()
|
||||
|
||||
return globalManager.service
|
||||
}
|
||||
|
||||
// Info publishes an info level status message using the global manager
|
||||
func Info(message string) {
|
||||
GetService().Info(message)
|
||||
}
|
||||
|
||||
// Warn publishes a warning level status message using the global manager
|
||||
func Warn(message string) {
|
||||
GetService().Warn(message)
|
||||
}
|
||||
|
||||
// Error publishes an error level status message using the global manager
|
||||
func Error(message string) {
|
||||
GetService().Error(message)
|
||||
}
|
||||
|
||||
// Debug publishes a debug level status message using the global manager
|
||||
func Debug(message string) {
|
||||
GetService().Debug(message)
|
||||
}
|
||||
|
|
@ -1,35 +1,36 @@
|
|||
package status
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
// Level represents the severity level of a status message
|
||||
type Level string
|
||||
|
||||
const (
|
||||
// LevelInfo represents an informational status message
|
||||
LevelInfo Level = "info"
|
||||
// LevelWarn represents a warning status message
|
||||
LevelWarn Level = "warn"
|
||||
// LevelError represents an error status message
|
||||
LevelInfo Level = "info"
|
||||
LevelWarn Level = "warn"
|
||||
LevelError Level = "error"
|
||||
// LevelDebug represents a debug status message
|
||||
LevelDebug Level = "debug"
|
||||
)
|
||||
|
||||
// StatusMessage represents a status update to be displayed in the UI
|
||||
type StatusMessage struct {
|
||||
Level Level `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Service defines the interface for the status service
|
||||
const (
|
||||
EventStatusPublished pubsub.EventType = "status_published"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[StatusMessage]
|
||||
pubsub.Subscriber[StatusMessage]
|
||||
|
||||
Info(message string)
|
||||
Warn(message string)
|
||||
Error(message string)
|
||||
|
@ -37,44 +38,75 @@ type Service interface {
|
|||
}
|
||||
|
||||
type service struct {
|
||||
*pubsub.Broker[StatusMessage]
|
||||
broker *pubsub.Broker[StatusMessage]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var globalStatusService *service
|
||||
|
||||
func InitService() error {
|
||||
if globalStatusService != nil {
|
||||
return fmt.Errorf("status service already initialized")
|
||||
}
|
||||
broker := pubsub.NewBroker[StatusMessage]()
|
||||
globalStatusService = &service{
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalStatusService == nil {
|
||||
panic("status service not initialized. Call status.InitService() at application startup.")
|
||||
}
|
||||
return globalStatusService
|
||||
}
|
||||
|
||||
// Info publishes an info level status message
|
||||
func (s *service) Info(message string) {
|
||||
s.publish(LevelInfo, message)
|
||||
}
|
||||
|
||||
// Warn publishes a warning level status message
|
||||
func (s *service) Warn(message string) {
|
||||
s.publish(LevelWarn, message)
|
||||
}
|
||||
|
||||
// Error publishes an error level status message
|
||||
func (s *service) Error(message string) {
|
||||
s.publish(LevelError, message)
|
||||
}
|
||||
|
||||
// Debug publishes a debug level status message
|
||||
func (s *service) Debug(message string) {
|
||||
s.publish(LevelDebug, message)
|
||||
}
|
||||
|
||||
// publish creates and publishes a status message with the given level and message
|
||||
func (s *service) publish(level Level, message string) {
|
||||
func (s *service) publish(level Level, messageText string) {
|
||||
statusMsg := StatusMessage{
|
||||
Level: level,
|
||||
Message: message,
|
||||
Message: messageText,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
s.Publish(pubsub.CreatedEvent, statusMsg)
|
||||
s.broker.Publish(EventStatusPublished, statusMsg)
|
||||
}
|
||||
|
||||
// NewService creates a new status service
|
||||
func NewService() Service {
|
||||
broker := pubsub.NewBroker[StatusMessage]()
|
||||
return &service{
|
||||
Broker: broker,
|
||||
}
|
||||
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
|
||||
return s.broker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func Info(message string) {
|
||||
GetService().Info(message)
|
||||
}
|
||||
|
||||
func Warn(message string) {
|
||||
GetService().Warn(message)
|
||||
}
|
||||
|
||||
func Error(message string) {
|
||||
GetService().Error(message)
|
||||
}
|
||||
|
||||
func Debug(message string) {
|
||||
GetService().Debug(message)
|
||||
}
|
||||
|
||||
func Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
|
||||
return GetService().Subscribe(ctx)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue