From 43807277272467fcc9a8af0618d88845c3edad96 Mon Sep 17 00:00:00 2001 From: Frank Date: Thu, 4 Dec 2025 21:53:31 -0500 Subject: [PATCH] zen: fix byok --- .../app/src/routes/zen/util/handler.ts | 38 +++++++++++++------ packages/console/core/src/model.ts | 1 + 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/packages/console/app/src/routes/zen/util/handler.ts b/packages/console/app/src/routes/zen/util/handler.ts index 7844a3ab0..a7f025bab 100644 --- a/packages/console/app/src/routes/zen/util/handler.ts +++ b/packages/console/app/src/routes/zen/util/handler.ts @@ -73,8 +73,16 @@ export async function handler( const stickyProvider = await stickyTracker?.get() const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => { - const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry, stickyProvider) - const authInfo = await authenticate(modelInfo, providerInfo) + const authInfo = await authenticate(modelInfo) + const providerInfo = selectProvider( + zenData, + authInfo, + modelInfo, + sessionId, + isTrial ?? false, + retry, + stickyProvider, + ) validateBilling(authInfo, modelInfo) validateModelSettings(authInfo) updateProviderKey(authInfo, providerInfo) @@ -291,6 +299,7 @@ export async function handler( function selectProvider( zenData: ZenData, + authInfo: AuthInfo, modelInfo: ModelInfo, sessionId: string, isTrial: boolean, @@ -298,6 +307,10 @@ export async function handler( stickyProvider: string | undefined, ) { const provider = (() => { + if (authInfo?.provider?.credentials) { + return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider) + } + if (isTrial) { return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider) } @@ -342,15 +355,15 @@ export async function handler( } } - async function authenticate(modelInfo: ModelInfo, providerInfo: ProviderInfo) { + async function authenticate(modelInfo: ModelInfo) { const apiKey = opts.parseApiKey(input.request.headers) if (!apiKey || apiKey === "public") { if (modelInfo.allowAnonymous) return throw new AuthError("Missing API key.") } - const data = await Database.use((tx) => - tx + const data = await Database.use((tx) => { + const query = tx .select({ apiKey: KeyTable.id, workspaceID: KeyTable.workspaceID, @@ -378,13 +391,15 @@ export async function handler( .innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID)) .innerJoin(UserTable, and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID))) .leftJoin(ModelTable, and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, modelInfo.id))) - .leftJoin( + + if (modelInfo.byokProvider) { + query.leftJoin( ProviderTable, - and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, providerInfo.id)), + and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, modelInfo.byokProvider)), ) - .where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted))) - .then((rows) => rows[0]), - ) + } + return query.where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted))).then((rows) => rows[0]) + }) if (!data) throw new AuthError("Invalid API key.") logger.metric({ @@ -457,8 +472,7 @@ export async function handler( } function updateProviderKey(authInfo: AuthInfo, providerInfo: ProviderInfo) { - if (!authInfo) return - if (!authInfo.provider?.credentials) return + if (!authInfo?.provider?.credentials) return providerInfo.apiKey = authInfo.provider.credentials } diff --git a/packages/console/core/src/model.ts b/packages/console/core/src/model.ts index 5a4a98fe9..47ba3e9d8 100644 --- a/packages/console/core/src/model.ts +++ b/packages/console/core/src/model.ts @@ -24,6 +24,7 @@ export namespace ZenData { cost: ModelCostSchema, cost200K: ModelCostSchema.optional(), allowAnonymous: z.boolean().optional(), + byokProvider: z.enum(["openai", "anthropic", "google"]).optional(), stickyProvider: z.boolean().optional(), trial: z .object({