diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts index 6939aa5aa..38863cc9f 100644 --- a/packages/opencode/src/session/index.ts +++ b/packages/opencode/src/session/index.ts @@ -40,6 +40,7 @@ import { MessageV2 } from "./message-v2" import { Mode } from "./mode" import { LSP } from "../lsp" import { ReadTool } from "../tool/read" +import { splitWhen } from "remeda" export namespace Session { const log = Log.create({ service: "session" }) @@ -61,6 +62,13 @@ export namespace Session { created: z.number(), updated: z.number(), }), + revert: z + .object({ + messageID: z.string(), + partID: z.string().optional(), + snapshot: z.string().optional(), + }) + .optional(), }) .openapi({ ref: "Session", @@ -355,6 +363,24 @@ export namespace Session { const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX + if (session.revert) { + const messageID = session.revert.messageID + const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID) + msgs = preserve + for (const msg of remove) { + await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`) + } + const last = preserve.at(-1) + if (session.revert.partID && last) { + const partID = session.revert.partID + const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID) + last.parts = preserveParts + for (const part of removeParts) { + await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`) + } + } + } + // auto summarize if too long if (previous && previous.tokens) { const tokens = @@ -946,60 +972,40 @@ export namespace Session { } } - export async function revert(input: { sessionID: string; messageID: string; partID: string }) { + export async function revert(input: { sessionID: string; messageID: string; partID?: string }) { const all = await messages(input.sessionID) - let snapshot: MessageV2.SnapshotPart | undefined - for (let i = 0; i < all.length; i++) { - const msg = all[i] - + let lastUser: MessageV2.User | undefined + let lastSnapshot: MessageV2.SnapshotPart | undefined + for (const msg of all) { + if (msg.info.role === "user") lastUser = msg.info + const remaining = [] for (const part of msg.parts) { - if (part.id > input.partID) { - // delete part - continue + if (part.type === "snapshot") lastSnapshot = part + if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) { + // if no useful parts left in message, same as reverting whole message + const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined + const snapshot = await Snapshot.create(input.sessionID) + if (lastSnapshot) await Snapshot.restore(input.sessionID, lastSnapshot.snapshot) + const session = await update(input.sessionID, (draft) => { + draft.revert = { + // if not part id jump to the last user message + messageID: !partID && lastUser ? lastUser.id : msg.info.id, + partID, + snapshot, + } + }) + return session } - if (part.type === "snapshot") { - snapshot = part - } - if (part.id === input.partID && snapshot) { - await Snapshot.restore(input.sessionID, snapshot.snapshot) - } - } - - if (msg.info.id > input.messageID) { - // delete message + remaining.push(part) } } + } - // TODO - /* - const message = await getMessage(input.sessionID, input.messageID) - if (!message) return - const part = message.parts[input.part] - if (!part) return - const session = await get(input.sessionID) - const snapshot = - session.revert?.snapshot ?? (await Snapshot.create(input.sessionID)) - const old = (() => { - if (message.role === "assistant") { - const lastTool = message.parts.findLast( - (part, index) => - part.type === "tool-invocation" && index < input.part, - ) - if (lastTool && lastTool.type === "tool-invocation") - return message.metadata.tool[lastTool.toolInvocation.toolCallId] - .snapshot - } - return message.metadata.snapshot - })() - if (old) await Snapshot.restore(input.sessionID, old) - await update(input.sessionID, (draft) => { - draft.revert = { - messageID: input.messageID, - part: input.part, - snapshot, - } + export async function unrevert(input: { sessionID: string }) { + const session = await update(input.sessionID, (draft) => { + draft.revert = undefined }) - */ + return session } export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) { diff --git a/packages/opencode/src/snapshot/index.ts b/packages/opencode/src/snapshot/index.ts index d3b055e3b..6534da86d 100644 --- a/packages/opencode/src/snapshot/index.ts +++ b/packages/opencode/src/snapshot/index.ts @@ -54,7 +54,7 @@ export namespace Snapshot { log.info("restore", { commit: snapshot }) const app = App.info() const git = gitdir(sessionID) - await $`git --git-dir=${git} checkout ${snapshot} --force`.quiet().cwd(app.path.root) + await $`git --git-dir=${git} reset --hard ${snapshot}`.quiet().cwd(app.path.root) } export async function diff(sessionID: string, commit: string) {