This commit is contained in:
Dax Raad 2025-07-15 18:05:52 -04:00
parent 0036eb3a09
commit 99c8bf704b
2 changed files with 54 additions and 48 deletions

View file

@ -40,6 +40,7 @@ import { MessageV2 } from "./message-v2"
import { Mode } from "./mode"
import { LSP } from "../lsp"
import { ReadTool } from "../tool/read"
import { splitWhen } from "remeda"
export namespace Session {
const log = Log.create({ service: "session" })
@ -61,6 +62,13 @@ export namespace Session {
created: z.number(),
updated: z.number(),
}),
revert: z
.object({
messageID: z.string(),
partID: z.string().optional(),
snapshot: z.string().optional(),
})
.optional(),
})
.openapi({
ref: "Session",
@ -355,6 +363,24 @@ export namespace Session {
const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
if (session.revert) {
const messageID = session.revert.messageID
const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID)
msgs = preserve
for (const msg of remove) {
await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`)
}
const last = preserve.at(-1)
if (session.revert.partID && last) {
const partID = session.revert.partID
const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
last.parts = preserveParts
for (const part of removeParts) {
await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`)
}
}
}
// auto summarize if too long
if (previous && previous.tokens) {
const tokens =
@ -946,60 +972,40 @@ export namespace Session {
}
}
export async function revert(input: { sessionID: string; messageID: string; partID: string }) {
export async function revert(input: { sessionID: string; messageID: string; partID?: string }) {
const all = await messages(input.sessionID)
let snapshot: MessageV2.SnapshotPart | undefined
for (let i = 0; i < all.length; i++) {
const msg = all[i]
let lastUser: MessageV2.User | undefined
let lastSnapshot: MessageV2.SnapshotPart | undefined
for (const msg of all) {
if (msg.info.role === "user") lastUser = msg.info
const remaining = []
for (const part of msg.parts) {
if (part.id > input.partID) {
// delete part
continue
if (part.type === "snapshot") lastSnapshot = part
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
// if no useful parts left in message, same as reverting whole message
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
const snapshot = await Snapshot.create(input.sessionID)
if (lastSnapshot) await Snapshot.restore(input.sessionID, lastSnapshot.snapshot)
const session = await update(input.sessionID, (draft) => {
draft.revert = {
// if not part id jump to the last user message
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
partID,
snapshot,
}
})
return session
}
if (part.type === "snapshot") {
snapshot = part
}
if (part.id === input.partID && snapshot) {
await Snapshot.restore(input.sessionID, snapshot.snapshot)
}
}
if (msg.info.id > input.messageID) {
// delete message
remaining.push(part)
}
}
}
// TODO
/*
const message = await getMessage(input.sessionID, input.messageID)
if (!message) return
const part = message.parts[input.part]
if (!part) return
const session = await get(input.sessionID)
const snapshot =
session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
const old = (() => {
if (message.role === "assistant") {
const lastTool = message.parts.findLast(
(part, index) =>
part.type === "tool-invocation" && index < input.part,
)
if (lastTool && lastTool.type === "tool-invocation")
return message.metadata.tool[lastTool.toolInvocation.toolCallId]
.snapshot
}
return message.metadata.snapshot
})()
if (old) await Snapshot.restore(input.sessionID, old)
await update(input.sessionID, (draft) => {
draft.revert = {
messageID: input.messageID,
part: input.part,
snapshot,
}
export async function unrevert(input: { sessionID: string }) {
const session = await update(input.sessionID, (draft) => {
draft.revert = undefined
})
*/
return session
}
export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {

View file

@ -54,7 +54,7 @@ export namespace Snapshot {
log.info("restore", { commit: snapshot })
const app = App.info()
const git = gitdir(sessionID)
await $`git --git-dir=${git} checkout ${snapshot} --force`.quiet().cwd(app.path.root)
await $`git --git-dir=${git} reset --hard ${snapshot}`.quiet().cwd(app.path.root)
}
export async function diff(sessionID: string, commit: string) {