mirror of https://github.com/sst/opencode.git
synced 2025-12-23 10:11:41 +00:00

core: extract session processor to handle streaming responses and tool execution

This commit is contained in:
parent 1f968d8692
commit d0277fae0d

3 changed files with 645 additions and 550 deletions
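The extracted processor's core job is a single loop over the streamed model response, switching on each event type (text, reasoning, tool calls, step boundaries) and persisting parts as they arrive. A minimal, self-contained sketch of that consumption pattern, using a simplified stand-in event union rather than the ai SDK's full part types:

// Illustrative only: a simplified event union standing in for the SDK's fullStream parts.
type StreamEvent =
  | { type: "text-delta"; text: string }
  | { type: "tool-call"; toolName: string; input: unknown }
  | { type: "finish" }

// A fake stream so the sketch runs on its own.
async function* fakeStream(): AsyncGenerator<StreamEvent> {
  yield { type: "text-delta", text: "hello " }
  yield { type: "tool-call", toolName: "read", input: { path: "README.md" } }
  yield { type: "text-delta", text: "world" }
  yield { type: "finish" }
}

async function consume(stream: AsyncIterable<StreamEvent>) {
  let text = ""
  const toolCalls: { toolName: string; input: unknown }[] = []
  for await (const value of stream) {
    switch (value.type) {
      case "text-delta":
        text += value.text // accumulate assistant text as deltas arrive
        break
      case "tool-call":
        toolCalls.push({ toolName: value.toolName, input: value.input }) // record tool invocations
        break
      case "finish":
        return { text, toolCalls }
    }
  }
  return { text, toolCalls }
}

consume(fakeStream()).then((result) => console.log(result))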

@@ -17,6 +17,7 @@ import { ProviderTransform } from "@/provider/transform"
import { SessionRetry } from "./retry"
import { Config } from "@/config/config"
import { Lock } from "../util/lock"
import { SessionProcessor } from "./processor"

export namespace SessionCompaction {
  const log = Log.create({ service: "session.compaction" })

@@ -36,7 +37,8 @@ export namespace SessionCompaction {
    if (context === 0) return false
    const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
    const output = Math.min(input.model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX) || SessionPrompt.OUTPUT_TOKEN_MAX
    const usable = context - output
    // const usable = context - output
    const usable = 20_000
    return count > usable
  }

@@ -87,6 +89,109 @@ export namespace SessionCompaction {
    }
  }

  export async function process(input: {
    parentID: string
    messages: MessageV2.WithParts[]
    sessionID: string
    model: {
      providerID: string
      modelID: string
    }
    abort: AbortSignal
  }) {
    const model = await Provider.getModel(input.model.providerID, input.model.modelID)
    const system = [
      ...SystemPrompt.summarize(model.providerID),
      ...(await SystemPrompt.environment()),
      ...(await SystemPrompt.custom()),
    ]
    const msg = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      parentID: input.parentID,
      sessionID: input.sessionID,
      mode: "build",
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      summary: true,
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.model.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant
    const stream = streamText({
      // set to 0, we handle loop
      maxRetries: 0,
      model: model.language,
      providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
      headers: model.info.headers,
      abortSignal: input.abort,
      tools: model.info.tool_call ? {} : undefined,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(input.messages),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ],
    })
    const processor = SessionProcessor.create({
      assistantMessage: msg,
      sessionID: input.sessionID,
      providerID: input.model.providerID,
      model: model.info,
      abort: input.abort,
    })
    const result = await processor.process(stream)
    const userMessage = await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
      model: {
        providerID: input.model.providerID,
        modelID: input.model.modelID,
      },
      agent: "build",
    })
    await Session.updatePart({
      type: "text",
      sessionID: input.sessionID,
      messageID: userMessage.id,
      id: Identifier.ascending("part"),
      text: "Use the above summary generated from your last session to resume from where you left off.",
      time: {
        start: Date.now(),
        end: Date.now(),
      },
      synthetic: true,
    })
    return result
  }

  export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
    const signal = input.signal ?? new AbortController().signal
    await using lock = input.signal === undefined ? await Lock.write(input.sessionID) : undefined

packages/opencode/src/session/processor.ts  350  Normal file
@@ -0,0 +1,350 @@
import type { ModelsDev } from "@/provider/models"
import { MessageV2 } from "./message-v2"
import type { StreamTextResult, Tool as AITool } from "ai"
import { Log } from "@/util/log"
import { Identifier } from "@/id/id"
import { Session } from "."
import { Agent } from "@/agent/agent"
import { Permission } from "@/permission"
import { Snapshot } from "@/snapshot"
import { SessionSummary } from "./summary"
import { Bus } from "@/bus"

export namespace SessionProcessor {
  const DOOM_LOOP_THRESHOLD = 3
  const log = Log.create({ service: "session.processor" })

  export type Info = Awaited<ReturnType<typeof create>>
  export type Result = Awaited<ReturnType<Info["process"]>>

  export function create(input: {
    assistantMessage: MessageV2.Assistant
    sessionID: string
    providerID: string
    model: ModelsDev.Model
    abort: AbortSignal
  }) {
    const toolcalls: Record<string, MessageV2.ToolPart> = {}
    let snapshot: string | undefined
    let blocked = false

    const result = {
      get message() {
        return input.assistantMessage
      },
      partFromToolCall(toolCallID: string) {
        return toolcalls[toolCallID]
      },
      async process(stream: StreamTextResult<Record<string, AITool>, never>) {
        log.info("process")
        try {
          let currentText: MessageV2.TextPart | undefined
          let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}

          for await (const value of stream.fullStream) {
            input.abort.throwIfAborted()
            switch (value.type) {
              case "start":
                break

              case "reasoning-start":
                if (value.id in reasoningMap) {
                  continue
                }
                reasoningMap[value.id] = {
                  id: Identifier.ascending("part"),
                  messageID: input.assistantMessage.id,
                  sessionID: input.assistantMessage.sessionID,
                  type: "reasoning",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                  metadata: value.providerMetadata,
                }
                break

              case "reasoning-delta":
                if (value.id in reasoningMap) {
                  const part = reasoningMap[value.id]
                  part.text += value.text
                  if (value.providerMetadata) part.metadata = value.providerMetadata
                  if (part.text) await Session.updatePart({ part, delta: value.text })
                }
                break

              case "reasoning-end":
                if (value.id in reasoningMap) {
                  const part = reasoningMap[value.id]
                  part.text = part.text.trimEnd()

                  part.time = {
                    ...part.time,
                    end: Date.now(),
                  }
                  if (value.providerMetadata) part.metadata = value.providerMetadata
                  await Session.updatePart(part)
                  delete reasoningMap[value.id]
                }
                break

              case "tool-input-start":
                const part = await Session.updatePart({
                  id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
                  messageID: input.assistantMessage.id,
                  sessionID: input.assistantMessage.sessionID,
                  type: "tool",
                  tool: value.toolName,
                  callID: value.id,
                  state: {
                    status: "pending",
                    input: {},
                    raw: "",
                  },
                })
                toolcalls[value.id] = part as MessageV2.ToolPart
                break

              case "tool-input-delta":
                break

              case "tool-input-end":
                break

              case "tool-call": {
                const match = toolcalls[value.toolCallId]
                if (match) {
                  const part = await Session.updatePart({
                    ...match,
                    tool: value.toolName,
                    state: {
                      status: "running",
                      input: value.input,
                      time: {
                        start: Date.now(),
                      },
                    },
                    metadata: value.providerMetadata,
                  })
                  toolcalls[value.toolCallId] = part as MessageV2.ToolPart

                  const parts = await MessageV2.parts(input.assistantMessage.id)
                  const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD)
                  if (
                    lastThree.length === DOOM_LOOP_THRESHOLD &&
                    lastThree.every(
                      (p) =>
                        p.type === "tool" &&
                        p.tool === value.toolName &&
                        p.state.status !== "pending" &&
                        JSON.stringify(p.state.input) === JSON.stringify(value.input),
                    )
                  ) {
                    const permission = await Agent.get(input.assistantMessage.mode).then((x) => x.permission)
                    if (permission.doom_loop === "ask") {
                      await Permission.ask({
                        type: "doom_loop",
                        pattern: value.toolName,
                        sessionID: input.assistantMessage.sessionID,
                        messageID: input.assistantMessage.id,
                        callID: value.toolCallId,
                        title: `Possible doom loop: "${value.toolName}" called ${DOOM_LOOP_THRESHOLD} times with identical arguments`,
                        metadata: {
                          tool: value.toolName,
                          input: value.input,
                        },
                      })
                    }
                  }
                }
                break
              }
              case "tool-result": {
                const match = toolcalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await Session.updatePart({
                    ...match,
                    state: {
                      status: "completed",
                      input: value.input,
                      output: value.output.output,
                      metadata: value.output.metadata,
                      title: value.output.title,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                      attachments: value.output.attachments,
                    },
                  })

                  delete toolcalls[value.toolCallId]
                }
                break
              }

              case "tool-error": {
                const match = toolcalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await Session.updatePart({
                    ...match,
                    state: {
                      status: "error",
                      input: value.input,
                      error: (value.error as any).toString(),
                      metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })

                  if (value.error instanceof Permission.RejectedError) {
                    blocked = true
                  }
                  delete toolcalls[value.toolCallId]
                }
                break
              }
              case "error":
                throw value.error

              case "start-step":
                snapshot = await Snapshot.track()
                await Session.updatePart({
                  id: Identifier.ascending("part"),
                  messageID: input.assistantMessage.id,
                  sessionID: input.sessionID,
                  snapshot,
                  type: "step-start",
                })
                break

              case "finish-step":
                const usage = Session.getUsage({
                  model: input.model,
                  usage: value.usage,
                  metadata: value.providerMetadata,
                })
                input.assistantMessage.finish = value.finishReason
                input.assistantMessage.cost += usage.cost
                input.assistantMessage.tokens = usage.tokens
                await Session.updatePart({
                  id: Identifier.ascending("part"),
                  reason: value.finishReason,
                  snapshot: await Snapshot.track(),
                  messageID: input.assistantMessage.id,
                  sessionID: input.assistantMessage.sessionID,
                  type: "step-finish",
                  tokens: usage.tokens,
                  cost: usage.cost,
                })
                await Session.updateMessage(input.assistantMessage)
                if (snapshot) {
                  const patch = await Snapshot.patch(snapshot)
                  if (patch.files.length) {
                    await Session.updatePart({
                      id: Identifier.ascending("part"),
                      messageID: input.assistantMessage.id,
                      sessionID: input.sessionID,
                      type: "patch",
                      hash: patch.hash,
                      files: patch.files,
                    })
                  }
                  snapshot = undefined
                }
                SessionSummary.summarize({
                  sessionID: input.sessionID,
                  messageID: input.assistantMessage.parentID,
                })
                break

              case "text-start":
                currentText = {
                  id: Identifier.ascending("part"),
                  messageID: input.assistantMessage.id,
                  sessionID: input.assistantMessage.sessionID,
                  type: "text",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                  metadata: value.providerMetadata,
                }
                break

              case "text-delta":
                if (currentText) {
                  currentText.text += value.text
                  if (value.providerMetadata) currentText.metadata = value.providerMetadata
                  if (currentText.text)
                    await Session.updatePart({
                      part: currentText,
                      delta: value.text,
                    })
                }
                break

              case "text-end":
                if (currentText) {
                  currentText.text = currentText.text.trimEnd()
                  currentText.time = {
                    start: Date.now(),
                    end: Date.now(),
                  }
                  if (value.providerMetadata) currentText.metadata = value.providerMetadata
                  await Session.updatePart(currentText)
                }
                currentText = undefined
                break

              case "finish":
                input.assistantMessage.time.completed = Date.now()
                await Session.updateMessage(input.assistantMessage)
                break

              default:
                log.info("unhandled", {
                  ...value,
                })
                continue
            }
          }
        } catch (e) {
          log.error("process", {
            error: e,
          })
          const error = MessageV2.fromError(e, { providerID: input.providerID })
          input.assistantMessage.error = error
          Bus.publish(Session.Event.Error, {
            sessionID: input.assistantMessage.sessionID,
            error: input.assistantMessage.error,
          })
        }
        const p = await MessageV2.parts(input.assistantMessage.id)
        for (const part of p) {
          if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
            await Session.updatePart({
              ...part,
              state: {
                ...part.state,
                status: "error",
                error: "Tool execution aborted",
                time: {
                  start: Date.now(),
                  end: Date.now(),
                },
              },
            })
          }
        }
        input.assistantMessage.time.completed = Date.now()
        await Session.updateMessage(input.assistantMessage)
        return { info: input.assistantMessage, parts: p, blocked }
      },
    }
    return result
  }
}
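The doom-loop guard in the new processor flags a tool call when the last DOOM_LOOP_THRESHOLD tool parts match the incoming call by name and serialized input, and asks for permission before continuing. A self-contained sketch of just that comparison, with simplified record shapes rather than the repository's MessageV2 types:

// Illustrative only: simplified records standing in for the tool parts tracked per message.
type ToolRecord = { tool: string; input: unknown; status: "pending" | "running" | "completed" | "error" }

const THRESHOLD = 3

// True when the most recent THRESHOLD tool parts all match the incoming call
// by tool name and JSON-serialized input, mirroring the processor's check.
function isDoomLoop(history: ToolRecord[], call: { tool: string; input: unknown }): boolean {
  const recent = history.slice(-THRESHOLD)
  return (
    recent.length === THRESHOLD &&
    recent.every(
      (p) => p.tool === call.tool && p.status !== "pending" && JSON.stringify(p.input) === JSON.stringify(call.input),
    )
  )
}

// Example: three identical "read" calls in a row trip the guard.
const history: ToolRecord[] = [
  { tool: "read", input: { path: "a.ts" }, status: "completed" },
  { tool: "read", input: { path: "a.ts" }, status: "completed" },
  { tool: "read", input: { path: "a.ts" }, status: "completed" },
]
console.log(isDoomLoop(history, { tool: "read", input: { path: "a.ts" } })) // true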

@@ -50,11 +50,12 @@ import { SessionSummary } from "./summary"
import { NamedError } from "@/util/error"
import { fn } from "@/util/fn"
import { SessionRetry } from "./retry"
import { SessionProcessor } from "./processor"
import { iife } from "@/util/iife"

export namespace SessionPrompt {
  const log = Log.create({ service: "session.prompt" })
  export const OUTPUT_TOKEN_MAX = 32_000
  const DOOM_LOOP_THRESHOLD = 3

  export const Status = z
    .union([

@@ -298,158 +299,197 @@ export namespace SessionPrompt {
      step++

      const model = await Provider.getModel(lastUser.model.providerID, lastUser.model.modelID)
      msgs = await checkOverflow({
        sessionID,
        model: model.info,
        abort,
        msgs,
      })
      const agent = await Agent.get(lastUser.agent)
      msgs = insertReminders({
        messages: msgs,
        agent,
      })
      const processor = await createProcessor({
        userMessage: lastUser,
        sessionID: sessionID,
        model: model.info,
        providerID: model.providerID,
        agent: agent.name,
        abort,
      })
      const system = await resolveSystemPrompt({
        providerID: model.providerID,
        modelID: model.info.id,
        agent,
        system: lastUser.system,
      })
      const tools = await resolveTools({
        agent,
        sessionID,
        model: lastUser.model,
        tools: lastUser.tools,
        processor,
      })

      const params = await Plugin.trigger(
        "chat.params",
        {
          sessionID: sessionID,
          agent: lastUser.agent,
          model: model.info,
          provider: await Provider.getProvider(model.providerID),
          message: lastUser,
        },
        {
          temperature: model.info.temperature
            ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
            : undefined,
          topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
          options: {
            ...ProviderTransform.options(model.providerID, model.modelID, sessionID),
            ...model.info.options,
            ...agent.options,
          },
        },
      )

      if (step === 1) {
        SessionSummary.summarize({
          sessionID: sessionID,
          messageID: lastUser.id,
        })
      }

      const stream = streamText({
        onError(error) {
          log.error("stream error", {
            error,
      const result = await iife(async () => {
        if (
          await checkOverflow({
            sessionID,
            model: model.info,
            abort,
            msgs,
          })
        },
        async experimental_repairToolCall(input) {
          const lower = input.toolCall.toolName.toLowerCase()
          if (lower !== input.toolCall.toolName && tools[lower]) {
            log.info("repairing tool call", {
              tool: input.toolCall.toolName,
              repaired: lower,
        ) {
          return await SessionCompaction.process({
            messages: msgs,
            parentID: lastUser.id,
            abort,
            model: {
              providerID: model.providerID,
              modelID: model.modelID,
            },
            sessionID,
          })
        }

        const agent = await Agent.get(lastUser.agent)
        msgs = insertReminders({
          messages: msgs,
          agent,
        })
        const processor = SessionProcessor.create({
          assistantMessage: (await Session.updateMessage({
            id: Identifier.ascending("message"),
            parentID: lastUser.id,
            role: "assistant",
            mode: agent.mode,
            path: {
              cwd: Instance.directory,
              root: Instance.worktree,
            },
            cost: 0,
            tokens: {
              input: 0,
              output: 0,
              reasoning: 0,
              cache: { read: 0, write: 0 },
            },
            modelID: model.modelID,
            providerID: model.providerID,
            time: {
              created: Date.now(),
            },
            sessionID,
          })) as MessageV2.Assistant,
          sessionID: sessionID,
          model: model.info,
          providerID: model.providerID,
          abort,
        })
        const system = await resolveSystemPrompt({
          providerID: model.providerID,
          modelID: model.info.id,
          agent,
          system: lastUser.system,
        })
        const tools = await resolveTools({
          agent,
          sessionID,
          model: lastUser.model,
          tools: lastUser.tools,
          processor,
        })
        const params = await Plugin.trigger(
          "chat.params",
          {
            sessionID: sessionID,
            agent: lastUser.agent,
            model: model.info,
            provider: await Provider.getProvider(model.providerID),
            message: lastUser,
          },
          {
            temperature: model.info.temperature
              ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
              : undefined,
            topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
            options: {
              ...ProviderTransform.options(model.providerID, model.modelID, sessionID),
              ...model.info.options,
              ...agent.options,
            },
          },
        )

        if (step === 1) {
          SessionSummary.summarize({
            sessionID: sessionID,
            messageID: lastUser.id,
          })
        }

        const stream = streamText({
          onError(error) {
            log.error("stream error", {
              error,
            })
          },
          async experimental_repairToolCall(input) {
            const lower = input.toolCall.toolName.toLowerCase()
            if (lower !== input.toolCall.toolName && tools[lower]) {
              log.info("repairing tool call", {
                tool: input.toolCall.toolName,
                repaired: lower,
              })
              return {
                ...input.toolCall,
                toolName: lower,
              }
            }
            return {
              ...input.toolCall,
              toolName: lower,
              input: JSON.stringify({
                tool: input.toolCall.toolName,
                error: input.error.message,
              }),
              toolName: "invalid",
            }
          }
            return {
              ...input.toolCall,
              input: JSON.stringify({
                tool: input.toolCall.toolName,
                error: input.error.message,
              }),
              toolName: "invalid",
            }
          },
          headers: {
            ...(model.providerID === "opencode"
              ? {
                  "x-opencode-session": sessionID,
                  "x-opencode-request": lastUser.id,
                }
              : undefined),
            ...model.info.headers,
          },
          // set to 0, we handle loop
          maxRetries: 0,
          activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
          maxOutputTokens: ProviderTransform.maxOutputTokens(
            model.providerID,
            params.options,
            model.info.limit.output,
            OUTPUT_TOKEN_MAX,
          ),
          abortSignal: abort,
          providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
          stopWhen: stepCountIs(1),
          temperature: params.temperature,
          topP: params.topP,
          messages: [
            ...system.map(
              (x): ModelMessage => ({
                role: "system",
                content: x,
              }),
            ),
            ...MessageV2.toModelMessage(
              msgs.filter((m) => {
                if (m.info.role !== "assistant" || m.info.error === undefined) {
                  return true
                }
                if (
                  MessageV2.AbortedError.isInstance(m.info.error) &&
                  m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
                ) {
                  return true
                }

                return false
              }),
            ),
          ],
          tools: model.info.tool_call === false ? undefined : tools,
          model: wrapLanguageModel({
            model: model.language,
            middleware: [
              {
                async transformParams(args) {
                  if (args.type === "stream") {
                    // @ts-expect-error
                    args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
                },
          headers: {
            ...(model.providerID === "opencode"
              ? {
                  "x-opencode-session": sessionID,
                  "x-opencode-request": lastUser.id,
                }
                  return args.params
                },
              },
              : undefined),
            ...model.info.headers,
          },
          // set to 0, we handle loop
          maxRetries: 0,
          activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
          maxOutputTokens: ProviderTransform.maxOutputTokens(
            model.providerID,
            params.options,
            model.info.limit.output,
            OUTPUT_TOKEN_MAX,
          ),
          abortSignal: abort,
          providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
          stopWhen: stepCountIs(1),
          temperature: params.temperature,
          topP: params.topP,
          messages: [
            ...system.map(
              (x): ModelMessage => ({
                role: "system",
                content: x,
              }),
            ),
            ...MessageV2.toModelMessage(
              msgs.filter((m) => {
                if (m.info.role !== "assistant" || m.info.error === undefined) {
                  return true
                }
                if (
                  MessageV2.AbortedError.isInstance(m.info.error) &&
                  m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
                ) {
                  return true
                }

                return false
              }),
            ),
          ],
          }),
          tools: model.info.tool_call === false ? undefined : tools,
          model: wrapLanguageModel({
            model: model.language,
            middleware: [
              {
                async transformParams(args) {
                  if (args.type === "stream") {
                    // @ts-expect-error
                    args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
                  }
                  return args.params
                },
              },
            ],
          }),
        })

        return await processor.process(stream)
      })
      const result = await processor.process(stream)

      if (result.blocked) break
      if (result.info.error?.name === "APIError" && result.info.error.data.isRetryable) {
        retries++

@@ -497,52 +537,11 @@ export namespace SessionPrompt {
  }) {
    const lastAssistant = input.msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.time.completed)
      ?.info as MessageV2.Assistant
    if (!lastAssistant) return input.msgs
    if (
      !SessionCompaction.isOverflow({
        tokens: lastAssistant.tokens,
        model: input.model,
      })
    )
      return input.msgs
    // TODO: make this more efficient
    const summaryMsg = await SessionCompaction.run({
      sessionID: input.sessionID,
      signal: input.abort,
      modelID: lastAssistant.modelID,
      providerID: lastAssistant.providerID,
    if (!lastAssistant) return false
    return SessionCompaction.isOverflow({
      tokens: lastAssistant.tokens,
      model: input.model,
    })
    const resumeMsgID = Identifier.ascending("message")
    const resumeMsg = {
      info: await Session.updateMessage({
        id: resumeMsgID,
        role: "user",
        sessionID: input.sessionID,
        time: {
          created: Date.now(),
        },
        model: {
          providerID: lastAssistant.providerID,
          modelID: lastAssistant.modelID,
        },
        agent: lastAssistant.mode,
      }),
      parts: [
        await Session.updatePart({
          type: "text",
          sessionID: input.sessionID,
          messageID: resumeMsgID,
          id: Identifier.ascending("part"),
          text: "Use the above summary generated from your last session to resume from where you left off.",
          time: {
            start: Date.now(),
            end: Date.now(),
          },
          synthetic: true,
        }),
      ],
    }
    return [summaryMsg, resumeMsg]
  }

  async function resolveModel(input: { model: PromptInput["model"]; agent: Agent.Info }) {

@@ -585,7 +584,7 @@ export namespace SessionPrompt {
    }
    sessionID: string
    tools?: Record<string, boolean>
    processor: Processor
    processor: SessionProcessor.Info
  }) {
    const tools: Record<string, AITool> = {}
    const enabledTools = pipe(

@@ -1017,365 +1016,6 @@ export namespace SessionPrompt {
    return input.messages
  }

  export type Processor = Awaited<ReturnType<typeof createProcessor>>
  async function createProcessor(input: {
    userMessage: MessageV2.User
    sessionID: string
    providerID: string
    model: ModelsDev.Model
    agent: string
    abort: AbortSignal
  }) {
    const toolcalls: Record<string, MessageV2.ToolPart> = {}
    let snapshot: string | undefined
    let blocked = false

    const assistantMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      parentID: input.userMessage.id,
      role: "assistant",
      mode: input.agent,
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.model.id,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await Session.updateMessage(assistantMsg)

    const result = {
      get message() {
        return assistantMsg
      },
      partFromToolCall(toolCallID: string) {
        return toolcalls[toolCallID]
      },
      async process(stream: StreamTextResult<Record<string, AITool>, never>) {
        log.info("process")
        if (!assistantMsg) throw new Error("call next() first before processing")
        try {
          let currentText: MessageV2.TextPart | undefined
          let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}

          for await (const value of stream.fullStream) {
            input.abort.throwIfAborted()
            switch (value.type) {
              case "start":
                break

              case "reasoning-start":
                if (value.id in reasoningMap) {
                  continue
                }
                reasoningMap[value.id] = {
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "reasoning",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                  metadata: value.providerMetadata,
                }
                break

              case "reasoning-delta":
                if (value.id in reasoningMap) {
                  const part = reasoningMap[value.id]
                  part.text += value.text
                  if (value.providerMetadata) part.metadata = value.providerMetadata
                  if (part.text) await Session.updatePart({ part, delta: value.text })
                }
                break

              case "reasoning-end":
                if (value.id in reasoningMap) {
                  const part = reasoningMap[value.id]
                  part.text = part.text.trimEnd()

                  part.time = {
                    ...part.time,
                    end: Date.now(),
                  }
                  if (value.providerMetadata) part.metadata = value.providerMetadata
                  await Session.updatePart(part)
                  delete reasoningMap[value.id]
                }
                break

              case "tool-input-start":
                const part = await Session.updatePart({
                  id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "tool",
                  tool: value.toolName,
                  callID: value.id,
                  state: {
                    status: "pending",
                    input: {},
                    raw: "",
                  },
                })
                toolcalls[value.id] = part as MessageV2.ToolPart
                break

              case "tool-input-delta":
                break

              case "tool-input-end":
                break

              case "tool-call": {
                const match = toolcalls[value.toolCallId]
                if (match) {
                  const part = await Session.updatePart({
                    ...match,
                    tool: value.toolName,
                    state: {
                      status: "running",
                      input: value.input,
                      time: {
                        start: Date.now(),
                      },
                    },
                    metadata: value.providerMetadata,
                  })
                  toolcalls[value.toolCallId] = part as MessageV2.ToolPart

                  const parts = await MessageV2.parts(assistantMsg.id)
                  const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD)
                  if (
                    lastThree.length === DOOM_LOOP_THRESHOLD &&
                    lastThree.every(
                      (p) =>
                        p.type === "tool" &&
                        p.tool === value.toolName &&
                        p.state.status !== "pending" &&
                        JSON.stringify(p.state.input) === JSON.stringify(value.input),
                    )
                  ) {
                    const permission = await Agent.get(input.agent).then((x) => x.permission)
                    if (permission.doom_loop === "ask") {
                      await Permission.ask({
                        type: "doom_loop",
                        pattern: value.toolName,
                        sessionID: assistantMsg.sessionID,
                        messageID: assistantMsg.id,
                        callID: value.toolCallId,
                        title: `Possible doom loop: "${value.toolName}" called ${DOOM_LOOP_THRESHOLD} times with identical arguments`,
                        metadata: {
                          tool: value.toolName,
                          input: value.input,
                        },
                      })
                    }
                  }
                }
                break
              }
              case "tool-result": {
                const match = toolcalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await Session.updatePart({
                    ...match,
                    state: {
                      status: "completed",
                      input: value.input,
                      output: value.output.output,
                      metadata: value.output.metadata,
                      title: value.output.title,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                      attachments: value.output.attachments,
                    },
                  })

                  delete toolcalls[value.toolCallId]
                }
                break
              }

              case "tool-error": {
                const match = toolcalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await Session.updatePart({
                    ...match,
                    state: {
                      status: "error",
                      input: value.input,
                      error: (value.error as any).toString(),
                      metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })

                  if (value.error instanceof Permission.RejectedError) {
                    blocked = true
                  }
                  delete toolcalls[value.toolCallId]
                }
                break
              }
              case "error":
                throw value.error

              case "start-step":
                snapshot = await Snapshot.track()
                await Session.updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  snapshot,
                  type: "step-start",
                })
                break

              case "finish-step":
                const usage = Session.getUsage({
                  model: input.model,
                  usage: value.usage,
                  metadata: value.providerMetadata,
                })
                assistantMsg.finish = value.finishReason
                assistantMsg.cost += usage.cost
                assistantMsg.tokens = usage.tokens
                await Session.updatePart({
                  id: Identifier.ascending("part"),
                  reason: value.finishReason,
                  snapshot: await Snapshot.track(),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-finish",
                  tokens: usage.tokens,
                  cost: usage.cost,
                })
                await Session.updateMessage(assistantMsg)
                if (snapshot) {
                  const patch = await Snapshot.patch(snapshot)
                  if (patch.files.length) {
                    await Session.updatePart({
                      id: Identifier.ascending("part"),
                      messageID: assistantMsg.id,
                      sessionID: assistantMsg.sessionID,
                      type: "patch",
                      hash: patch.hash,
                      files: patch.files,
                    })
                  }
                  snapshot = undefined
                }
                SessionSummary.summarize({
                  sessionID: input.sessionID,
                  messageID: assistantMsg.parentID,
                })
                break

              case "text-start":
                currentText = {
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "text",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                  metadata: value.providerMetadata,
                }
                break

              case "text-delta":
                if (currentText) {
                  currentText.text += value.text
                  if (value.providerMetadata) currentText.metadata = value.providerMetadata
                  if (currentText.text)
                    await Session.updatePart({
                      part: currentText,
                      delta: value.text,
                    })
                }
                break

              case "text-end":
                if (currentText) {
                  currentText.text = currentText.text.trimEnd()
                  currentText.time = {
                    start: Date.now(),
                    end: Date.now(),
                  }
                  if (value.providerMetadata) currentText.metadata = value.providerMetadata
                  await Session.updatePart(currentText)
                }
                currentText = undefined
                break

              case "finish":
                assistantMsg.time.completed = Date.now()
                await Session.updateMessage(assistantMsg)
                break

              default:
                log.info("unhandled", {
                  ...value,
                })
                continue
            }
          }
        } catch (e) {
          log.error("process", {
            error: e,
          })
          const error = MessageV2.fromError(e, { providerID: input.providerID })
          assistantMsg.error = error
          Bus.publish(Session.Event.Error, {
            sessionID: assistantMsg.sessionID,
            error: assistantMsg.error,
          })
        }
        const p = await MessageV2.parts(assistantMsg.id)
        for (const part of p) {
          if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
            await Session.updatePart({
              ...part,
              state: {
                ...part.state,
                status: "error",
                error: "Tool execution aborted",
                time: {
                  start: Date.now(),
                  end: Date.now(),
                },
              },
            })
          }
        }
        assistantMsg.time.completed = Date.now()
        await Session.updateMessage(assistantMsg)
        return { info: assistantMsg, parts: p, blocked }
      },
    }
    return result
  }

  export const ShellInput = z.object({
    sessionID: Identifier.schema("session"),
    agent: z.string(),