From 25ceba2c97ff7aa2d0fd4f673235ee75cb6b76e4 Mon Sep 17 00:00:00 2001 From: Sahaj Jain Date: Tue, 5 Nov 2024 15:50:19 +0530 Subject: [PATCH] feat: Implement draft for claude and refactor --- src/provider/claude.rs | 112 ++++++++++++++++++++++------------------- src/provider/mod.rs | 6 +-- 2 files changed, 62 insertions(+), 56 deletions(-) diff --git a/src/provider/claude.rs b/src/provider/claude.rs index d26306c..ef3836c 100644 --- a/src/provider/claude.rs +++ b/src/provider/claude.rs @@ -1,15 +1,9 @@ -use super::AIProvider; +use super::{AIProvider, ProviderError}; use crate::{ai_prompt::AIPrompt, git_entity::GitEntity}; use async_trait::async_trait; use serde::Deserialize; use serde_json::json; -pub struct ClaudeProvider { - client: reqwest::Client, - api_key: String, - model: String, -} - #[derive(Deserialize)] struct ClaudeResponse { content: Vec, @@ -20,66 +14,78 @@ struct ClaudeContent { text: String, } -impl ClaudeProvider { - pub fn new(client: reqwest::Client, api_key: String, model: Option) -> Self { - ClaudeProvider { - client, +// Configuration type to match OpenAI pattern +#[derive(Clone)] +pub struct ClaudeConfig { + api_key: String, + model: String, + api_base_url: String, +} + +impl ClaudeConfig { + pub fn new(api_key: String, model: Option) -> Self { + Self { api_key, model: model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string()), + api_base_url: "https://api.anthropic.com/v1/messages".to_string(), } } } -async fn get_completion_result( - client: &reqwest::Client, - api_key: &str, - payload: serde_json::Value, -) -> Result> { - let response = client - .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", api_key) - .header("anthropic-version", "2023-06-01") - .header("Content-Type", "application/json") - .json(&payload) - .send() - .await?; +pub struct ClaudeProvider { + client: reqwest::Client, + config: ClaudeConfig, +} - let claude_response: ClaudeResponse = response.json().await?; - Ok(claude_response - .content - .first() - .map(|content| content.text.clone()) - .unwrap_or_default()) +impl ClaudeProvider { + pub fn new(client: reqwest::Client, config: ClaudeConfig) -> Self { + Self { client, config } + } + + async fn complete(&self, prompt: AIPrompt) -> Result { + let payload = json!({ + "model": self.config.model, + "max_tokens": 4096, + "messages": [ + { + "role": "system", + "content": prompt.system_prompt + }, + { + "role": "user", + "content": prompt.user_prompt, + } + ] + }); + + let response = self + .client + .post(&self.config.api_base_url) + .header("x-api-key", &self.config.api_key) + .header("anthropic-version", "2023-06-01") + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await?; + + let claude_response: ClaudeResponse = response.json().await?; + claude_response + .content + .first() + .map(|content| content.text.clone()) + .ok_or(ProviderError::NoCompletionChoice) + } } #[async_trait] impl AIProvider for ClaudeProvider { async fn explain(&self, git_entity: GitEntity) -> Result> { - let AIPrompt { - system_prompt, - user_prompt, - } = AIPrompt::build_explain_prompt(&git_entity); - - let payload = json!({ - "model": self.model, - "max_tokens": 4096, - "messages": [ - { - "role": "system", - "content": system_prompt - }, - { - "role": "user", - "content": user_prompt, - } - ] - }); - - let res = get_completion_result(&self.client, &self.api_key, payload).await?; - Ok(res) + let prompt = AIPrompt::build_explain_prompt(&git_entity); + Ok(self.complete(prompt).await?) } async fn draft(&self, git_entity: GitEntity) -> Result> { - todo!() + let prompt = AIPrompt::build_draft_prompt(&git_entity)?; + Ok(self.complete(prompt).await?) } } diff --git a/src/provider/mod.rs b/src/provider/mod.rs index 3bd87c7..f5384b0 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use claude::ClaudeProvider; +use claude::{ClaudeConfig, ClaudeProvider}; use groq::GroqProvider; use openai::{OpenAIConfig, OpenAIProvider}; use phind::PhindProvider; @@ -59,8 +59,8 @@ impl LumenProvider { } ProviderType::Claude => { let api_key = api_key.ok_or(LumenError::MissingApiKey("Claude".to_string()))?; - let provider = - LumenProvider::Claude(Box::new(ClaudeProvider::new(client, api_key, model))); + let config = ClaudeConfig::new(api_key, model); + let provider = LumenProvider::Claude(Box::new(ClaudeProvider::new(client, config))); Ok(provider) } }