From 5b0dfd7b36da7c516c1c0b584ff90e40cfc8765e Mon Sep 17 00:00:00 2001 From: ByteAtATime Date: Thu, 19 Jun 2025 19:26:07 -0700 Subject: [PATCH] feat: implement basic oauth api --- packages/protocol/src/index.ts | 47 +++++- sidecar/src/api/index.ts | 2 + sidecar/src/api/oauth.ts | 279 +++++++++++++++++++++++++++++++++ sidecar/src/index.ts | 21 +++ src-tauri/src/lib.rs | 6 +- src-tauri/src/oauth.rs | 87 ++++++++++ src/lib/sidecar.svelte.ts | 36 +++++ src/routes/+page.svelte | 28 +++- 8 files changed, 497 insertions(+), 9 deletions(-) create mode 100644 sidecar/src/api/oauth.ts create mode 100644 src-tauri/src/oauth.rs diff --git a/packages/protocol/src/index.ts b/packages/protocol/src/index.ts index 9bb585e..3ee1a78 100644 --- a/packages/protocol/src/index.ts +++ b/packages/protocol/src/index.ts @@ -110,7 +110,7 @@ export type SidecarMessage = z.infer; export const PreferenceSchema = z.object({ name: z.string(), - title: z.string(), + title: z.string().optional(), description: z.string().optional(), type: z.enum(['textfield', 'dropdown', 'checkbox', 'directory']), required: z.boolean().optional(), @@ -249,6 +249,45 @@ const ClipboardClearMessageSchema = z.object({ payload: ClipboardClearPayloadSchema }); +const OauthAuthorizePayloadSchema = z.object({ + url: z.string(), + providerName: z.string(), + providerIcon: z.string().optional(), + description: z.string().optional() +}); +const OauthAuthorizeMessageSchema = z.object({ + type: z.literal('oauth-authorize'), + payload: OauthAuthorizePayloadSchema +}); + +const OauthGetTokensPayloadSchema = z.object({ + requestId: z.string(), + providerId: z.string() +}); +const OauthGetTokensMessageSchema = z.object({ + type: z.literal('oauth-get-tokens'), + payload: OauthGetTokensPayloadSchema +}); + +const OauthSetTokensPayloadSchema = z.object({ + requestId: z.string(), + providerId: z.string(), + tokens: z.record(z.string(), z.unknown()) +}); +const OauthSetTokensMessageSchema = z.object({ + type: z.literal('oauth-set-tokens'), + payload: OauthSetTokensPayloadSchema +}); + +const OauthRemoveTokensPayloadSchema = z.object({ + requestId: z.string(), + providerId: z.string() +}); +const OauthRemoveTokensMessageSchema = z.object({ + type: z.literal('oauth-remove-tokens'), + payload: OauthRemoveTokensPayloadSchema +}); + export const SidecarMessageWithPluginsSchema = z.union([ BatchUpdateSchema, CommandSchema, @@ -264,6 +303,10 @@ export const SidecarMessageWithPluginsSchema = z.union([ ClipboardPasteMessageSchema, ClipboardReadMessageSchema, ClipboardReadTextMessageSchema, - ClipboardClearMessageSchema + ClipboardClearMessageSchema, + OauthAuthorizeMessageSchema, + OauthGetTokensMessageSchema, + OauthSetTokensMessageSchema, + OauthRemoveTokensMessageSchema ]); export type SidecarMessageWithPlugins = z.infer; diff --git a/sidecar/src/api/index.ts b/sidecar/src/api/index.ts index 9acbe93..381ba55 100644 --- a/sidecar/src/api/index.ts +++ b/sidecar/src/api/index.ts @@ -15,6 +15,7 @@ import { preferencesStore } from '../preferences'; import { showToast } from './toast'; import { BrowserExtensionAPI } from './browserExtension'; import { Clipboard } from './clipboard'; +import * as OAuth from './oauth'; let currentPluginName: string | null = null; let currentPluginPreferences: Array<{ @@ -45,6 +46,7 @@ export const getRaycastApi = () => { Icon, LaunchType, Toast, + OAuth, Action, ActionPanel, Detail, diff --git a/sidecar/src/api/oauth.ts b/sidecar/src/api/oauth.ts new file mode 100644 index 0000000..3ccfb19 --- /dev/null +++ b/sidecar/src/api/oauth.ts @@ -0,0 +1,279 @@ +import * as crypto from 'crypto'; +import { writeOutput, writeLog } from '../io'; + +export enum RedirectMethod { + Web = 'web', + App = 'app', + AppURI = 'app-uri' +} + +export interface PKCEClientOptions { + redirectMethod: RedirectMethod; + providerName: string; + providerIcon?: string; + description?: string; + providerId?: string; +} + +export interface AuthorizationRequestOptions { + endpoint: string; + clientId: string; + scope: string; + extraParameters?: { [key: string]: string }; +} + +export interface AuthorizationRequest { + url: string; + codeVerifier: string; + codeChallenge: string; + redirectURI: string; + state: string; + toURL: () => string; +} + +export interface AuthorizationOptions { + url: string; +} + +export interface AuthorizationResponse { + authorizationCode: string; +} + +export interface TokenResponse { + access_token: string; + refresh_token?: string; + expires_in?: number; + scope?: string; + id_token?: string; +} + +export interface TokenSetOptions { + accessToken: string; + refreshToken?: string; + expiresIn?: number; + scope?: string; + idToken?: string; +} + +export interface TokenSet { + accessToken: string; + refreshToken?: string; + expiresIn?: number; + scope?: string; + idToken?: string; + updatedAt: Date; + isExpired: () => boolean; +} + +const pendingAuthorizationRequests = new Map< + string, + { resolve: (value: AuthorizationResponse) => void; reject: (reason?: any) => void } +>(); + +const pendingTokenRequests = new Map< + string, + { resolve: (value: any) => void; reject: (reason?: any) => void } +>(); + +export function handleOAuthResponse( + _requestId: string, + code: string, + state: string, + error?: string +) { + const promise = pendingAuthorizationRequests.get(state); + if (promise) { + if (error) { + promise.reject(new Error(error)); + } else { + promise.resolve({ authorizationCode: code }); + } + pendingAuthorizationRequests.delete(state); + } else { + writeLog(`OAuth state mismatch. Request ID (state): ${state} not found in pending requests.`); + } +} + +export function handleTokenResponse(requestId: string, result: any, error?: string) { + const promise = pendingTokenRequests.get(requestId); + if (promise) { + if (error) { + promise.reject(new Error(error)); + } else { + promise.resolve(result); + } + pendingTokenRequests.delete(requestId); + } +} + +function sendTokenRequest(type: string, payload: object): Promise { + return new Promise((resolve, reject) => { + const requestId = crypto.randomUUID(); + pendingTokenRequests.set(requestId, { resolve, reject }); + writeOutput({ + type, + payload: { requestId, ...payload } + }); + setTimeout(() => { + if (pendingTokenRequests.has(requestId)) { + pendingTokenRequests.delete(requestId); + reject(new Error(`Token request for ${type} timed out`)); + } + }, 5000); + }); +} + +export class PKCEClient { + private options: PKCEClientOptions; + + constructor(options: PKCEClientOptions) { + this.options = options; + } + + private getProviderId(): string { + return this.options.providerId ?? this.options.providerName.toLowerCase().replace(/\s/g, '-'); + } + + async authorizationRequest(options: AuthorizationRequestOptions): Promise { + const codeVerifier = crypto.randomBytes(32).toString('base64url'); + const codeChallenge = crypto.createHash('sha256').update(codeVerifier).digest('base64url'); + const state = JSON.stringify({ + providerName: this.options.providerName, + id: crypto.randomUUID(), + flavor: 'release' + }); + + let redirectURI: string; + const packageName = 'Extension'; // TODO: what does this mean, and is it always the same? + switch (this.options.redirectMethod) { + case RedirectMethod.Web: + redirectURI = `https://raycast.com/redirect?packageName=${packageName}`; + break; + case RedirectMethod.App: + redirectURI = `raycast://oauth?package_name=${packageName}`; + break; + case RedirectMethod.AppURI: + redirectURI = `com.raycast:/oauth?package_name=${packageName}`; + break; + } + + const urlParams = new URLSearchParams({ + response_type: 'code', + client_id: options.clientId, + scope: options.scope, + redirect_uri: redirectURI, + state: state, + code_challenge: codeChallenge, + code_challenge_method: 'S256', + ...options.extraParameters + }); + + const authRequest: AuthorizationRequest = { + url: `${options.endpoint}?${urlParams.toString()}`, + codeVerifier, + codeChallenge, + redirectURI, + state, + toURL: () => authRequest.url + }; + + return authRequest; + } + + async authorize( + authRequest: AuthorizationRequest | AuthorizationOptions + ): Promise { + const state = + 'state' in authRequest + ? authRequest.state + : new URL(authRequest.url).searchParams.get('state'); + + if (!state) { + throw new Error('State parameter is missing from authorization request.'); + } + + return new Promise((resolve, reject) => { + pendingAuthorizationRequests.set(state, { resolve, reject }); + + writeOutput({ + type: 'oauth-authorize', + payload: { + url: authRequest.url, + providerName: this.options.providerName, + providerIcon: this.options.providerIcon, + description: this.options.description + } + }); + + setTimeout( + () => { + if (pendingAuthorizationRequests.has(state)) { + pendingAuthorizationRequests.delete(state); + reject(new Error('OAuth authorization timed out')); + } + }, + 5 * 60 * 1000 + ); + }); + } + + async getTokens(): Promise { + const tokenData = await sendTokenRequest('oauth-get-tokens', { + providerId: this.getProviderId() + }); + + if (!tokenData) { + return undefined; + } + + const updatedAt = new Date(tokenData.updatedAt); + const expiresIn = tokenData.expiresIn; + + const tokenSet: TokenSet = { + ...tokenData, + updatedAt: updatedAt, + isExpired: () => { + if (!expiresIn) { + return false; + } + const now = new Date(); + const expiryDate = new Date(updatedAt.getTime() + expiresIn * 1000); + return now.getTime() > expiryDate.getTime() - 60000; + } + }; + + return tokenSet; + } + + async setTokens(tokens: TokenSetOptions | TokenResponse): Promise { + let tokenSetOptions: TokenSetOptions; + + if ('access_token' in tokens) { + tokenSetOptions = { + accessToken: tokens.access_token, + refreshToken: tokens.refresh_token, + expiresIn: tokens.expires_in, + scope: tokens.scope, + idToken: tokens.id_token + }; + } else { + tokenSetOptions = tokens; + } + + const payload = { + ...tokenSetOptions, + updatedAt: new Date().toISOString() + }; + + await sendTokenRequest('oauth-set-tokens', { + providerId: this.getProviderId(), + tokens: payload + }); + } + + async removeTokens(): Promise { + await sendTokenRequest('oauth-remove-tokens', { + providerId: this.getProviderId() + }); + } +} diff --git a/sidecar/src/index.ts b/sidecar/src/index.ts index dc21eb4..0c613d9 100644 --- a/sidecar/src/index.ts +++ b/sidecar/src/index.ts @@ -12,6 +12,7 @@ import { } from './api/environment'; import { handleBrowserExtensionResponse } from './api/browserExtension'; import { handleClipboardResponse } from './api/clipboard'; +import { handleOAuthResponse, handleTokenResponse } from './api/oauth'; process.on('unhandledRejection', (reason: unknown) => { writeLog(`--- UNHANDLED PROMISE REJECTION ---`); @@ -151,6 +152,26 @@ rl.on('line', (line) => { browserExtensionState.isConnected = isConnected; break; } + case 'oauth-authorize-response': { + const { code, state, error } = command.payload as { + code: string; + state: string; + error?: string; + }; + handleOAuthResponse(state, code, state, error); + break; + } + case 'oauth-get-tokens-response': + case 'oauth-set-tokens-response': + case 'oauth-remove-tokens-response': { + const { requestId, result, error } = command.payload as { + requestId: string; + result?: any; + error?: string; + }; + handleTokenResponse(requestId, result, error); + break; + } case 'clipboard-read-text-response': case 'clipboard-read-response': case 'clipboard-copy-response': diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index caa5a95..591ab4f 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -6,6 +6,7 @@ mod desktop; mod error; mod extensions; mod filesystem; +mod oauth; use crate::{app::App, cache::AppCache}; use browser_extension::WsState; @@ -135,7 +136,10 @@ pub fn run() { clipboard::clipboard_read, clipboard::clipboard_copy, clipboard::clipboard_paste, - clipboard::clipboard_clear + clipboard::clipboard_clear, + oauth::oauth_set_tokens, + oauth::oauth_get_tokens, + oauth::oauth_remove_tokens ]) .setup(|app| { let app_handle = app.handle().clone(); diff --git a/src-tauri/src/oauth.rs b/src-tauri/src/oauth.rs new file mode 100644 index 0000000..abb54df --- /dev/null +++ b/src-tauri/src/oauth.rs @@ -0,0 +1,87 @@ +use serde::{Deserialize, Serialize}; +use serde_json; +use std::collections::HashMap; +use std::fs; +use std::path::{Path, PathBuf}; +use tauri::Manager; + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct StoredTokenSet { + access_token: String, + refresh_token: Option, + expires_in: Option, + scope: Option, + id_token: Option, + updated_at: String, +} + +type TokenStore = HashMap; + +fn get_storage_path(app: &tauri::AppHandle) -> Result { + let data_dir = app + .path() + .app_local_data_dir() + .map_err(|_| "Failed to get app local data dir".to_string())?; + + if !data_dir.exists() { + fs::create_dir_all(&data_dir).map_err(|e| e.to_string())?; + } + + Ok(data_dir.join("oauth_tokens.json")) +} + +fn read_store(path: &Path) -> Result { + if !path.exists() { + return Ok(HashMap::new()); + } + let content = fs::read_to_string(path).map_err(|e| e.to_string())?; + if content.trim().is_empty() { + return Ok(HashMap::new()); + } + serde_json::from_str(&content).map_err(|e| e.to_string()) +} + +fn write_store(path: &Path, store: &TokenStore) -> Result<(), String> { + let content = serde_json::to_string_pretty(store).map_err(|e| e.to_string())?; + fs::write(path, content).map_err(|e| e.to_string()) +} + +#[tauri::command] +pub fn oauth_set_tokens( + app: tauri::AppHandle, + provider_id: String, + tokens: serde_json::Value, +) -> Result<(), String> { + let path = get_storage_path(&app)?; + let mut store = read_store(&path)?; + + let token_set: StoredTokenSet = + serde_json::from_value(tokens).map_err(|e| e.to_string())?; + + store.insert(provider_id, token_set); + write_store(&path, &store) +} + +#[tauri::command] +pub fn oauth_get_tokens( + app: tauri::AppHandle, + provider_id: String, +) -> Result, String> { + let path = get_storage_path(&app)?; + let store = read_store(&path)?; + if let Some(token_set) = store.get(&provider_id) { + let value = serde_json::to_value(token_set).map_err(|e| e.to_string())?; + Ok(Some(value)) + } else { + Ok(None) + } +} + +#[tauri::command] +pub fn oauth_remove_tokens(app: tauri::AppHandle, provider_id: String) -> Result<(), String> { + let path = get_storage_path(&app)?; + let mut store = read_store(&path)?; + store.remove(&provider_id); + write_store(&path, &store) +} \ No newline at end of file diff --git a/src/lib/sidecar.svelte.ts b/src/lib/sidecar.svelte.ts index faf9ca7..43997d1 100644 --- a/src/lib/sidecar.svelte.ts +++ b/src/lib/sidecar.svelte.ts @@ -4,6 +4,7 @@ import { uiStore } from '$lib/ui.svelte'; import { SidecarMessageWithPluginsSchema } from '@raycast-linux/protocol'; import { invoke } from '@tauri-apps/api/core'; import { appCacheDir, appLocalDataDir } from '@tauri-apps/api/path'; +import { openUrl } from '@tauri-apps/plugin-opener'; class SidecarService { #sidecarChild: Child | null = $state(null); @@ -171,6 +172,41 @@ class SidecarService { return; } + if (typedMessage.type.startsWith('oauth-')) { + if (typedMessage.type === 'oauth-authorize') { + const { url } = typedMessage.payload; + openUrl(url).catch((err) => { + this.#log(`ERROR: Failed to open OAuth URL '${url}': ${err}`); + }); + return; + } + + const { requestId, ...params } = typedMessage.payload as { + requestId: string; + [key: string]: any; + }; + + const commandMap: Record = { + 'oauth-get-tokens': 'oauth_get_tokens', + 'oauth-set-tokens': 'oauth_set_tokens', + 'oauth-remove-tokens': 'oauth_remove_tokens' + }; + const command = commandMap[typedMessage.type]; + + if (command) { + const responseType = `${typedMessage.type}-response`; + try { + const result = await invoke(command, params); + this.dispatchEvent(responseType, { requestId, result }); + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error); + this.#log(`ERROR from ${command}: ${errorMessage}`); + this.dispatchEvent(responseType, { requestId, error: errorMessage }); + } + return; + } + } + if (typedMessage.type === 'plugin-list') { uiStore.setPluginList(typedMessage.payload); return; diff --git a/src/routes/+page.svelte b/src/routes/+page.svelte index 7f71a31..f645cfd 100644 --- a/src/routes/+page.svelte +++ b/src/routes/+page.svelte @@ -70,12 +70,28 @@ }); if (urlObj.protocol === 'raycast:') { - switch (urlObj.host) { - case 'extensions': - viewState = 'extensions-store'; - break; - default: - viewState = 'plugin-list'; + if (urlObj.host === 'oauth-callback' || urlObj.pathname.startsWith('/redirect')) { + const params = urlObj.searchParams; + const code = params.get('code'); + const state = params.get('state'); + if (code && state) { + sidecarService.dispatchEvent('oauth-authorize-response', { code, state }); + } else { + const error = params.get('error') || 'Unknown OAuth error'; + const errorDescription = params.get('error_description'); + sidecarService.dispatchEvent('oauth-authorize-response', { + state, + error: `${error}: ${errorDescription}` + }); + } + } else { + switch (urlObj.host) { + case 'extensions': + viewState = 'extensions-store'; + break; + default: + viewState = 'plugin-list'; + } } } } catch (error) {