diff --git a/packages/console/app/src/routes/zen/util/handler.ts b/packages/console/app/src/routes/zen/util/handler.ts index 623c93127..7844a3ab0 100644 --- a/packages/console/app/src/routes/zen/util/handler.ts +++ b/packages/console/app/src/routes/zen/util/handler.ts @@ -21,6 +21,7 @@ import { oaCompatHelper } from "./provider/openai-compatible" import { createRateLimiter } from "./rateLimiter" import { createDataDumper } from "./dataDumper" import { createTrialLimiter } from "./trialLimiter" +import { createStickyTracker } from "./stickyProviderTracker" type ZenData = Awaited> type RetryOptions = { @@ -68,9 +69,11 @@ export async function handler( const isTrial = await trialLimiter?.isTrial() const rateLimiter = createRateLimiter(modelInfo.id, modelInfo.rateLimit, ip) await rateLimiter?.check() + const stickyTracker = createStickyTracker(modelInfo.stickyProvider ?? false, sessionId) + const stickyProvider = await stickyTracker?.get() const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => { - const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry) + const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry, stickyProvider) const authInfo = await authenticate(modelInfo, providerInfo) validateBilling(authInfo, modelInfo) validateModelSettings(authInfo) @@ -121,6 +124,9 @@ export async function handler( dataDumper?.provideModel(providerInfo.storeModel) dataDumper?.provideRequest(reqBody) + // Store sticky provider + await stickyTracker?.set(providerInfo.id) + // Scrub response headers const resHeaders = new Headers() const keepHeaders = ["content-type", "cache-control"] @@ -289,12 +295,18 @@ export async function handler( sessionId: string, isTrial: boolean, retry: RetryOptions, + stickyProvider: string | undefined, ) { const provider = (() => { if (isTrial) { return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider) } + if (stickyProvider) { + const provider = modelInfo.providers.find((provider) => provider.id === stickyProvider) + if (provider) return provider + } + if (retry.retryCount === MAX_RETRIES) { return modelInfo.providers.find((provider) => provider.id === modelInfo.fallbackProvider) } diff --git a/packages/console/app/src/routes/zen/util/stickyProviderTracker.ts b/packages/console/app/src/routes/zen/util/stickyProviderTracker.ts new file mode 100644 index 000000000..63cbb0a68 --- /dev/null +++ b/packages/console/app/src/routes/zen/util/stickyProviderTracker.ts @@ -0,0 +1,16 @@ +import { Resource } from "@opencode-ai/console-resource" + +export function createStickyTracker(stickyProvider: boolean, session: string) { + if (!stickyProvider) return + if (!session) return + const key = `sticky:${session}` + + return { + get: async () => { + return await Resource.GatewayKv.get(key) + }, + set: async (providerId: string) => { + await Resource.GatewayKv.put(key, providerId, { expirationTtl: 86400 }) + }, + } +} diff --git a/packages/console/core/src/model.ts b/packages/console/core/src/model.ts index 8cc181b7c..5a4a98fe9 100644 --- a/packages/console/core/src/model.ts +++ b/packages/console/core/src/model.ts @@ -24,6 +24,7 @@ export namespace ZenData { cost: ModelCostSchema, cost200K: ModelCostSchema.optional(), allowAnonymous: z.boolean().optional(), + stickyProvider: z.boolean().optional(), trial: z .object({ limit: z.number(),