feat: Implement draft for claude and refactor

This commit is contained in:
Sahaj Jain 2024-11-05 15:50:19 +05:30
parent 633c0b58d9
commit 25ceba2c97
2 changed files with 62 additions and 56 deletions

View file

@ -1,15 +1,9 @@
use super::AIProvider;
use super::{AIProvider, ProviderError};
use crate::{ai_prompt::AIPrompt, git_entity::GitEntity};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;
pub struct ClaudeProvider {
client: reqwest::Client,
api_key: String,
model: String,
}
#[derive(Deserialize)]
struct ClaudeResponse {
content: Vec<ClaudeContent>,
@ -20,66 +14,78 @@ struct ClaudeContent {
text: String,
}
impl ClaudeProvider {
pub fn new(client: reqwest::Client, api_key: String, model: Option<String>) -> Self {
ClaudeProvider {
client,
// Configuration type to match OpenAI pattern
#[derive(Clone)]
pub struct ClaudeConfig {
api_key: String,
model: String,
api_base_url: String,
}
impl ClaudeConfig {
pub fn new(api_key: String, model: Option<String>) -> Self {
Self {
api_key,
model: model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string()),
api_base_url: "https://api.anthropic.com/v1/messages".to_string(),
}
}
}
async fn get_completion_result(
client: &reqwest::Client,
api_key: &str,
payload: serde_json::Value,
) -> Result<String, Box<dyn std::error::Error>> {
let response = client
.post("https://api.anthropic.com/v1/messages")
.header("x-api-key", api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.json(&payload)
.send()
.await?;
pub struct ClaudeProvider {
client: reqwest::Client,
config: ClaudeConfig,
}
let claude_response: ClaudeResponse = response.json().await?;
Ok(claude_response
.content
.first()
.map(|content| content.text.clone())
.unwrap_or_default())
impl ClaudeProvider {
pub fn new(client: reqwest::Client, config: ClaudeConfig) -> Self {
Self { client, config }
}
async fn complete(&self, prompt: AIPrompt) -> Result<String, ProviderError> {
let payload = json!({
"model": self.config.model,
"max_tokens": 4096,
"messages": [
{
"role": "system",
"content": prompt.system_prompt
},
{
"role": "user",
"content": prompt.user_prompt,
}
]
});
let response = self
.client
.post(&self.config.api_base_url)
.header("x-api-key", &self.config.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.json(&payload)
.send()
.await?;
let claude_response: ClaudeResponse = response.json().await?;
claude_response
.content
.first()
.map(|content| content.text.clone())
.ok_or(ProviderError::NoCompletionChoice)
}
}
#[async_trait]
impl AIProvider for ClaudeProvider {
async fn explain(&self, git_entity: GitEntity) -> Result<String, Box<dyn std::error::Error>> {
let AIPrompt {
system_prompt,
user_prompt,
} = AIPrompt::build_explain_prompt(&git_entity);
let payload = json!({
"model": self.model,
"max_tokens": 4096,
"messages": [
{
"role": "system",
"content": system_prompt
},
{
"role": "user",
"content": user_prompt,
}
]
});
let res = get_completion_result(&self.client, &self.api_key, payload).await?;
Ok(res)
let prompt = AIPrompt::build_explain_prompt(&git_entity);
Ok(self.complete(prompt).await?)
}
async fn draft(&self, git_entity: GitEntity) -> Result<String, Box<dyn std::error::Error>> {
todo!()
let prompt = AIPrompt::build_draft_prompt(&git_entity)?;
Ok(self.complete(prompt).await?)
}
}

View file

@ -1,5 +1,5 @@
use async_trait::async_trait;
use claude::ClaudeProvider;
use claude::{ClaudeConfig, ClaudeProvider};
use groq::GroqProvider;
use openai::{OpenAIConfig, OpenAIProvider};
use phind::PhindProvider;
@ -59,8 +59,8 @@ impl LumenProvider {
}
ProviderType::Claude => {
let api_key = api_key.ok_or(LumenError::MissingApiKey("Claude".to_string()))?;
let provider =
LumenProvider::Claude(Box::new(ClaudeProvider::new(client, api_key, model)));
let config = ClaudeConfig::new(api_key, model);
let provider = LumenProvider::Claude(Box::new(ClaudeProvider::new(client, config)));
Ok(provider)
}
}