fix(ai): refactor AI model settings and fallback logic

This commit moves default model mappings into the backend, making it the single source of truth. The `get_ai_settings` command now merges user-defined settings with these defaults, so the UI always receives a complete and correct configuration. The model selection logic in `ai_ask_stream` was also simplified to use the merged settings, which fixes a bug where the system would use a global fallback instead of the correct model-specific default. Additionally, the `set_ai_settings` command now only persists user overrides, keeping the configuration file clean.
ByteAtATime 2025-06-27 18:27:53 -07:00
parent 9b7b29d720
commit 5dafd76912
5 changed files with 248 additions and 81 deletions
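As context for the change described above, here is a minimal sketch of how a caller uses the API after this commit (the import path, prompt, and variable name are illustrative assumptions, not part of the diff): the sidecar now sends only the model key and creativity, and the Rust backend resolves the key to an OpenRouter model id from the merged settings.

import { AI } from '../api/ai'; // sidecar AI API as modified below (path illustrative)

// Model enum values are now plain keys such as 'Anthropic_Claude_4_Sonnet';
// no modelMappings object is attached to the request anymore.
const result = AI.ask('Summarize the selected text', {
  model: AI.Model.Anthropic_Claude_4_Sonnet,
  creativity: 'low'
});
// `result` is consumed exactly as before; only model resolution moved
// from the sidecar payload to the backend's merged settings.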

View file

@@ -6,8 +6,7 @@ export const AiAskStreamPayloadSchema = z.object({
options: z
.object({
model: z.string().optional(),
creativity: z.string().optional(),
modelMappings: z.record(z.string(), z.string()).optional()
creativity: z.string().optional()
})
.optional()
});

View file

@@ -1,54 +1,53 @@
import { EventEmitter } from 'events';
import { writeLog, writeOutput } from '../io';
import { inspect } from 'util';
export const Model = {
'OpenAI_GPT4.1': 'openai/gpt-4.1',
'OpenAI_GPT4.1-mini': 'openai/gpt-4.1-mini',
'OpenAI_GPT4.1-nano': 'openai/gpt-4.1-nano',
OpenAI_GPT4: 'openai/gpt-4',
'OpenAI_GPT4-turbo': 'openai/gpt-4-turbo',
OpenAI_GPT4o: 'openai/gpt-4o',
'OpenAI_GPT4o-mini': 'openai/gpt-4o-mini',
OpenAI_o3: 'openai/o3',
'OpenAI_o4-mini': 'openai/o4-mini',
OpenAI_o1: 'openai/o1',
'OpenAI_o3-mini': 'openai/o3-mini',
Anthropic_Claude_Haiku: 'anthropic/claude-3-haiku',
Anthropic_Claude_Sonnet: 'anthropic/claude-3-sonnet',
'Anthropic_Claude_Sonnet_3.7': 'anthropic/claude-3.7-sonnet',
Anthropic_Claude_Opus: 'anthropic/claude-3-opus',
Anthropic_Claude_4_Sonnet: 'anthropic/claude-sonnet-4',
Anthropic_Claude_4_Opus: 'anthropic/claude-opus-4',
Perplexity_Sonar: 'perplexity/sonar',
Perplexity_Sonar_Pro: 'perplexity/sonar-pro',
Perplexity_Sonar_Reasoning: 'perplexity/sonar-reasoning',
Perplexity_Sonar_Reasoning_Pro: 'perplexity/sonar-reasoning-pro',
Llama4_Scout: 'meta-llama/llama-4-scout',
'Llama3.3_70B': 'meta-llama/llama-3.3-70b-instruct',
'Llama3.1_8B': 'meta-llama/llama-3.1-8b-instruct',
'Llama3.1_405B': 'meta-llama/llama-3.1-405b-instruct',
Mistral_Nemo: 'mistralai/mistral-nemo',
Mistral_Large: 'mistralai/mistral-large',
Mistral_Medium: 'mistralai/mistral-medium-3',
Mistral_Small: 'mistralai/mistral-small',
Mistral_Codestral: 'mistralai/codestral-2501',
'DeepSeek_R1_Distill_Llama_3.3_70B': 'deepseek/deepseek-r1-distill-llama-70b',
DeepSeek_R1: 'deepseek/deepseek-r1',
DeepSeek_V3: 'deepseek/deepseek-chat',
'Google_Gemini_2.5_Pro': 'google/gemini-2.5-pro',
'Google_Gemini_2.5_Flash': 'google/gemini-2.5-flash',
'Google_Gemini_2.0_Flash': 'google/gemini-2.0-flash-001',
xAI_Grok_3: 'x-ai/grok-3',
xAI_Grok_3_Mini: 'x-ai/grok-3-mini',
xAI_Grok_2: 'x-ai/grok-2-1212'
} as const;
import { writeOutput } from '../io';
export type Creativity = 'none' | 'low' | 'medium' | 'high' | 'maximum' | number;
export enum Model {
'OpenAI_GPT4.1' = 'OpenAI_GPT4.1',
'OpenAI_GPT4.1-mini' = 'OpenAI_GPT4.1-mini',
'OpenAI_GPT4.1-nano' = 'OpenAI_GPT4.1-nano',
OpenAI_GPT4 = 'OpenAI_GPT4',
'OpenAI_GPT4-turbo' = 'OpenAI_GPT4-turbo',
OpenAI_GPT4o = 'OpenAI_GPT4o',
'OpenAI_GPT4o-mini' = 'OpenAI_GPT4o-mini',
OpenAI_o3 = 'OpenAI_o3',
'OpenAI_o4-mini' = 'OpenAI_o4-mini',
OpenAI_o1 = 'OpenAI_o1',
'OpenAI_o3-mini' = 'OpenAI_o3-mini',
Anthropic_Claude_Haiku = 'Anthropic_Claude_Haiku',
Anthropic_Claude_Sonnet = 'Anthropic_Claude_Sonnet',
'Anthropic_Claude_Sonnet_3.7' = 'Anthropic_Claude_Sonnet_3.7',
Anthropic_Claude_Opus = 'Anthropic_Claude_Opus',
Anthropic_Claude_4_Sonnet = 'Anthropic_Claude_4_Sonnet',
Anthropic_Claude_4_Opus = 'Anthropic_Claude_4_Opus',
Perplexity_Sonar = 'Perplexity_Sonar',
Perplexity_Sonar_Pro = 'Perplexity_Sonar_Pro',
Perplexity_Sonar_Reasoning = 'Perplexity_Sonar_Reasoning',
Perplexity_Sonar_Reasoning_Pro = 'Perplexity_Sonar_Reasoning_Pro',
Llama4_Scout = 'Llama4_Scout',
'Llama3.3_70B' = 'Llama3.3_70B',
'Llama3.1_8B' = 'Llama3.1_8B',
'Llama3.1_405B' = 'Llama3.1_405B',
Mistral_Nemo = 'Mistral_Nemo',
Mistral_Large = 'Mistral_Large',
Mistral_Medium = 'Mistral_Medium',
Mistral_Small = 'Mistral_Small',
Mistral_Codestral = 'Mistral_Codestral',
'DeepSeek_R1_Distill_Llama_3.3_70B' = 'DeepSeek_R1_Distill_Llama_3.3_70B',
DeepSeek_R1 = 'DeepSeek_R1',
DeepSeek_V3 = 'DeepSeek_V3',
'Google_Gemini_2.5_Pro' = 'Google_Gemini_2.5_Pro',
'Google_Gemini_2.5_Flash' = 'Google_Gemini_2.5_Flash',
'Google_Gemini_2.0_Flash' = 'Google_Gemini_2.0_Flash',
xAI_Grok_3 = 'xAI_Grok_3',
xAI_Grok_3_Mini = 'xAI_Grok_3_Mini',
xAI_Grok_2 = 'xAI_Grok_2'
}
export interface AskOptions {
creativity?: Creativity;
model?: keyof typeof Model;
model?: string;
signal?: AbortSignal;
}
@@ -90,11 +89,6 @@ export function ask(prompt: string, options: AskOptions = {}): AskResult {
const emitter = new EventEmitter();
const requestId = crypto.randomUUID();
const modelMappings: Record<string, string> = {};
if (options.model && Model[options.model]) {
modelMappings[options.model] = Model[options.model];
}
let fullText = '';
let isResolved = false;
@@ -142,8 +136,7 @@ export function ask(prompt: string, options: AskOptions = {}): AskResult {
prompt,
options: {
model: options.model,
creativity: options.creativity,
modelMappings
creativity: options.creativity
}
}
});
@@ -158,7 +151,6 @@ export function ask(prompt: string, options: AskOptions = {}): AskResult {
export const AI = {
ask,
Model,
Creativity: {
none: 'none' as const,
low: 'low' as const,

View file

@@ -1,24 +1,23 @@
use crate::error::AppError;
use futures_util::StreamExt;
use once_cell::sync::Lazy;
use rusqlite::{params, Connection, Result as RusqliteResult};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tauri::{AppHandle, Emitter, Manager, State};
const AI_KEYRING_SERVICE: &str = "dev.byteatatime.raycast.ai";
const AI_KEYRING_USERNAME: &str = "openrouter_api_key";
// --- Structs for API and Events ---
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AskOptions {
pub model: Option<String>,
pub creativity: Option<String>,
#[serde(default)]
model_mappings: HashMap<String, String>,
}
#[derive(Serialize, Clone)]
@@ -52,7 +51,140 @@ pub struct GenerationData {
pub total_cost: f64,
}
// --- Key Management Commands ---
static DEFAULT_AI_MODELS: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
let mut m = HashMap::new();
// OpenAI
m.insert("OpenAI_GPT4.1", "openai/gpt-4.1");
m.insert("OpenAI_GPT4.1-mini", "openai/gpt-4.1-mini");
m.insert("OpenAI_GPT4.1-nano", "openai/gpt-4.1-nano");
m.insert("OpenAI_GPT4", "openai/gpt-4");
m.insert("OpenAI_GPT4-turbo", "openai/gpt-4-turbo");
m.insert("OpenAI_GPT4o", "openai/gpt-4o");
m.insert("OpenAI_GPT4o-mini", "openai/gpt-4o-mini");
m.insert("OpenAI_o3", "openai/o3");
m.insert("OpenAI_o4-mini", "openai/o4-mini");
m.insert("OpenAI_o1", "openai/o1");
m.insert("OpenAI_o3-mini", "openai/o3-mini");
// Anthropic
m.insert("Anthropic_Claude_Haiku", "anthropic/claude-3-haiku");
m.insert("Anthropic_Claude_Sonnet", "anthropic/claude-3-sonnet");
m.insert("Anthropic_Claude_Sonnet_3.7", "anthropic/claude-3.7-sonnet");
m.insert("Anthropic_Claude_Opus", "anthropic/claude-3-opus");
m.insert("Anthropic_Claude_4_Sonnet", "anthropic/claude-sonnet-4");
m.insert("Anthropic_Claude_4_Opus", "anthropic/claude-opus-4");
// Perplexity
m.insert("Perplexity_Sonar", "perplexity/sonar");
m.insert("Perplexity_Sonar_Pro", "perplexity/sonar-pro");
m.insert("Perplexity_Sonar_Reasoning", "perplexity/sonar-reasoning");
m.insert(
"Perplexity_Sonar_Reasoning_Pro",
"perplexity/sonar-reasoning-pro",
);
// Meta
m.insert("Llama4_Scout", "meta-llama/llama-4-scout");
m.insert("Llama3.3_70B", "meta-llama/llama-3.3-70b-instruct");
m.insert("Llama3.1_8B", "meta-llama/llama-3.1-8b-instruct");
m.insert("Llama3.1_405B", "meta-llama/llama-3.1-405b-instruct");
// Mistral
m.insert("Mistral_Nemo", "mistralai/mistral-nemo");
m.insert("Mistral_Large", "mistralai/mistral-large");
m.insert("Mistral_Medium", "mistralai/mistral-medium-3");
m.insert("Mistral_Small", "mistralai/mistral-small");
m.insert("Mistral_Codestral", "mistralai/codestral-2501");
// DeepSeek
m.insert(
"DeepSeek_R1_Distill_Llama_3.3_70B",
"deepseek/deepseek-r1-distill-llama-70b",
);
m.insert("DeepSeek_R1", "deepseek/deepseek-r1");
m.insert("DeepSeek_V3", "deepseek/deepseek-chat");
// Google
m.insert("Google_Gemini_2.5_Pro", "google/gemini-2.5-pro");
m.insert("Google_Gemini_2.5_Flash", "google/gemini-2.5-flash");
m.insert("Google_Gemini_2.0_Flash", "google/gemini-2.0-flash-001");
// xAI
m.insert("xAI_Grok_3", "x-ai/grok-3");
m.insert("xAI_Grok_3_Mini", "x-ai/grok-3-mini");
m.insert("xAI_Grok_2", "x-ai/grok-2-1212");
m
});
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct AiSettings {
enabled: bool,
model_associations: HashMap<String, String>,
}
fn get_settings_path(app: &tauri::AppHandle) -> Result<PathBuf, String> {
let data_dir = app
.path()
.app_local_data_dir()
.map_err(|_| "Failed to get app local data dir".to_string())?;
if !data_dir.exists() {
fs::create_dir_all(&data_dir).map_err(|e| e.to_string())?;
}
Ok(data_dir.join("ai_settings.json"))
}
fn read_settings(path: &Path) -> Result<AiSettings, String> {
if !path.exists() {
return Ok(AiSettings::default());
}
let content = fs::read_to_string(path).map_err(|e| e.to_string())?;
if content.trim().is_empty() {
return Ok(AiSettings::default());
}
serde_json::from_str(&content).map_err(|e| e.to_string())
}
fn write_settings(path: &Path, settings: &AiSettings) -> Result<(), String> {
let content = serde_json::to_string_pretty(settings).map_err(|e| e.to_string())?;
fs::write(path, content).map_err(|e| e.to_string())
}
#[tauri::command]
pub fn get_ai_settings(app: tauri::AppHandle) -> Result<AiSettings, String> {
let path = get_settings_path(&app)?;
let mut user_settings = read_settings(&path)?;
for (key, &default_value) in DEFAULT_AI_MODELS.iter() {
let entry = user_settings
.model_associations
.entry(key.to_string())
.or_insert_with(|| default_value.to_string());
if entry.is_empty() {
*entry = default_value.to_string();
}
}
Ok(user_settings)
}
#[tauri::command]
pub fn set_ai_settings(app: tauri::AppHandle, settings: AiSettings) -> Result<(), String> {
let path = get_settings_path(&app)?;
let mut settings_to_save = AiSettings {
enabled: settings.enabled,
model_associations: HashMap::new(),
};
for (key, value) in settings.model_associations {
let is_different_from_default = DEFAULT_AI_MODELS
.get(key.as_str())
.map_or(true, |&default_val| default_val != value);
if is_different_from_default {
settings_to_save.model_associations.insert(key, value);
}
}
write_settings(&path, &settings_to_save)
}
fn get_keyring_entry() -> Result<keyring::Entry, AppError> {
keyring::Entry::new(AI_KEYRING_SERVICE, AI_KEYRING_USERNAME).map_err(AppError::from)
@@ -81,8 +213,6 @@ pub fn clear_ai_api_key() -> Result<(), String> {
.map_err(|e| e.to_string())
}
// --- Usage Tracking ---
pub struct AiUsageManager {
db: Mutex<Connection>,
}
@@ -204,8 +334,6 @@ async fn fetch_and_log_usage(
Ok(())
}
// --- Core Stream Command ---
#[tauri::command]
pub async fn ai_ask_stream(
app_handle: AppHandle,
@@ -214,6 +342,11 @@ pub async fn ai_ask_stream(
prompt: String,
options: AskOptions,
) -> Result<(), String> {
let settings = get_ai_settings(app_handle.clone())?;
if !settings.enabled {
return Err("AI features are not enabled.".to_string());
}
let api_key =
match get_keyring_entry().and_then(|entry| entry.get_password().map_err(AppError::from)) {
Ok(key) => key,
@@ -221,11 +354,12 @@ pub async fn ai_ask_stream(
};
let model_key = options.model.unwrap_or_else(|| "default".to_string());
// For testing, use a free model if "default" is chosen.
let model_id = options.model_mappings.get(&model_key).map_or_else(
|| "mistralai/mistral-7b-instruct:free".to_string(),
|id| id.clone(),
);
let model_id = settings
.model_associations
.get(&model_key)
.cloned()
.unwrap_or_else(|| "mistralai/mistral-7b-instruct:free".to_string());
let temperature = match options.creativity.as_deref() {
Some("none") => 0.0,
@@ -246,7 +380,7 @@ pub async fn ai_ask_stream(
let res = client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {}", api_key))
.header("HTTP-Referer", "http://localhost") // Required by OpenRouter
.header("HTTP-Referer", "http://localhost")
.json(&body)
.send()
.await

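Because set_ai_settings now drops entries that match DEFAULT_AI_MODELS, the persisted ai_settings.json stays minimal. As a hypothetical sketch (the override value is illustrative, not taken from this commit), a user who remaps only OpenAI_GPT4o would end up with roughly:

{
  "enabled": true,
  "modelAssociations": {
    "OpenAI_GPT4o": "openai/chatgpt-4o-latest"
  }
}

On the next read, get_ai_settings merges the untouched defaults back in, so ai_ask_stream and the settings UI always see a complete mapping.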
View file

@@ -255,7 +255,9 @@ pub fn run() {
ai::is_ai_api_key_set,
ai::clear_ai_api_key,
ai::ai_ask_stream,
ai::get_ai_usage_history
ai::get_ai_usage_history,
ai::get_ai_settings,
ai::set_ai_settings
])
.setup(|app| {
let app_handle = app.handle().clone();

View file

@@ -5,7 +5,12 @@
import { invoke } from '@tauri-apps/api/core';
import { onMount } from 'svelte';
import PasswordInput from './PasswordInput.svelte';
import { Model } from '../../../sidecar/src/api/ai';
import { uiStore } from '$lib/ui.svelte';
type AiSettings = {
enabled: boolean;
modelAssociations: Record<string, string>;
};
let aiEnabled = $state(false);
let apiKey = $state('');
@@ -13,16 +18,52 @@
let isApiKeySet = $state(false);
async function loadSettings() {
isApiKeySet = await invoke('is_ai_api_key_set');
// TODO: Load aiEnabled and modelAssociations from storage
try {
isApiKeySet = await invoke('is_ai_api_key_set');
const settings = await invoke<AiSettings>('get_ai_settings');
aiEnabled = settings.enabled;
modelAssociations = settings.modelAssociations ?? {};
} catch (error) {
console.error('Failed to load AI settings:', error);
uiStore.toasts.set(Date.now(), {
id: Date.now(),
title: 'Failed to load AI settings',
message: String(error),
style: 'FAILURE'
});
}
}
async function saveSettings() {
if (apiKey) {
await invoke('set_ai_api_key', { key: apiKey });
try {
if (apiKey) {
await invoke('set_ai_api_key', { key: apiKey });
apiKey = '';
}
const settingsToSave: AiSettings = {
enabled: aiEnabled,
modelAssociations: modelAssociations
};
await invoke('set_ai_settings', { settings: settingsToSave });
uiStore.toasts.set(Date.now(), {
id: Date.now(),
title: 'AI Settings Saved',
style: 'SUCCESS'
});
await loadSettings();
} catch (error) {
console.error('Failed to save AI settings:', error);
uiStore.toasts.set(Date.now(), {
id: Date.now(),
title: 'Failed to save AI settings',
message: String(error),
style: 'FAILURE'
});
}
// TODO: Save aiEnabled and modelAssociations to storage
await loadSettings();
}
async function clearApiKey() {
@@ -66,14 +107,13 @@
Associate Raycast AI models with specific models available through OpenRouter.
</p>
<div class="grid grid-cols-[auto_1fr] items-center gap-4">
{#each Object.entries(Model) as [raycastModel, openRouterModel] (raycastModel)}
{#each Object.entries(modelAssociations) as [raycastModel, openRouterModel] (raycastModel)}
<span class="text-sm font-medium">{raycastModel}</span>
<Input
value={modelAssociations[raycastModel] ?? openRouterModel}
value={openRouterModel}
onchange={(e) => {
modelAssociations[raycastModel] = (e.target as HTMLInputElement)?.value;
}}
placeholder={openRouterModel}
class="w-full"
/>
{/each}