zen: fix byok

This commit is contained in:
Frank 2025-12-04 21:53:31 -05:00
parent 392d46933b
commit 4380727727
2 changed files with 27 additions and 12 deletions

View file

@ -73,8 +73,16 @@ export async function handler(
const stickyProvider = await stickyTracker?.get()
const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry, stickyProvider)
const authInfo = await authenticate(modelInfo, providerInfo)
const authInfo = await authenticate(modelInfo)
const providerInfo = selectProvider(
zenData,
authInfo,
modelInfo,
sessionId,
isTrial ?? false,
retry,
stickyProvider,
)
validateBilling(authInfo, modelInfo)
validateModelSettings(authInfo)
updateProviderKey(authInfo, providerInfo)
@ -291,6 +299,7 @@ export async function handler(
function selectProvider(
zenData: ZenData,
authInfo: AuthInfo,
modelInfo: ModelInfo,
sessionId: string,
isTrial: boolean,
@ -298,6 +307,10 @@ export async function handler(
stickyProvider: string | undefined,
) {
const provider = (() => {
if (authInfo?.provider?.credentials) {
return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider)
}
if (isTrial) {
return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider)
}
@ -342,15 +355,15 @@ export async function handler(
}
}
async function authenticate(modelInfo: ModelInfo, providerInfo: ProviderInfo) {
async function authenticate(modelInfo: ModelInfo) {
const apiKey = opts.parseApiKey(input.request.headers)
if (!apiKey || apiKey === "public") {
if (modelInfo.allowAnonymous) return
throw new AuthError("Missing API key.")
}
const data = await Database.use((tx) =>
tx
const data = await Database.use((tx) => {
const query = tx
.select({
apiKey: KeyTable.id,
workspaceID: KeyTable.workspaceID,
@ -378,13 +391,15 @@ export async function handler(
.innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
.innerJoin(UserTable, and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID)))
.leftJoin(ModelTable, and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, modelInfo.id)))
.leftJoin(
if (modelInfo.byokProvider) {
query.leftJoin(
ProviderTable,
and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, providerInfo.id)),
and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, modelInfo.byokProvider)),
)
.where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
.then((rows) => rows[0]),
)
}
return query.where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted))).then((rows) => rows[0])
})
if (!data) throw new AuthError("Invalid API key.")
logger.metric({
@ -457,8 +472,7 @@ export async function handler(
}
function updateProviderKey(authInfo: AuthInfo, providerInfo: ProviderInfo) {
if (!authInfo) return
if (!authInfo.provider?.credentials) return
if (!authInfo?.provider?.credentials) return
providerInfo.apiKey = authInfo.provider.credentials
}

View file

@ -24,6 +24,7 @@ export namespace ZenData {
cost: ModelCostSchema,
cost200K: ModelCostSchema.optional(),
allowAnonymous: z.boolean().optional(),
byokProvider: z.enum(["openai", "anthropic", "google"]).optional(),
stickyProvider: z.boolean().optional(),
trial: z
.object({