mirror of
https://github.com/jnsahaj/lumen.git
synced 2025-12-23 05:36:48 +00:00
feat: Implement draft for claude and refactor
This commit is contained in:
parent
633c0b58d9
commit
25ceba2c97
2 changed files with 62 additions and 56 deletions
|
|
@ -1,15 +1,9 @@
|
|||
use super::AIProvider;
|
||||
use super::{AIProvider, ProviderError};
|
||||
use crate::{ai_prompt::AIPrompt, git_entity::GitEntity};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
pub struct ClaudeProvider {
|
||||
client: reqwest::Client,
|
||||
api_key: String,
|
||||
model: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeResponse {
|
||||
content: Vec<ClaudeContent>,
|
||||
|
|
@ -20,66 +14,78 @@ struct ClaudeContent {
|
|||
text: String,
|
||||
}
|
||||
|
||||
impl ClaudeProvider {
|
||||
pub fn new(client: reqwest::Client, api_key: String, model: Option<String>) -> Self {
|
||||
ClaudeProvider {
|
||||
client,
|
||||
// Configuration type to match OpenAI pattern
|
||||
#[derive(Clone)]
|
||||
pub struct ClaudeConfig {
|
||||
api_key: String,
|
||||
model: String,
|
||||
api_base_url: String,
|
||||
}
|
||||
|
||||
impl ClaudeConfig {
|
||||
pub fn new(api_key: String, model: Option<String>) -> Self {
|
||||
Self {
|
||||
api_key,
|
||||
model: model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string()),
|
||||
api_base_url: "https://api.anthropic.com/v1/messages".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_completion_result(
|
||||
client: &reqwest::Client,
|
||||
api_key: &str,
|
||||
payload: serde_json::Value,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let response = client
|
||||
.post("https://api.anthropic.com/v1/messages")
|
||||
.header("x-api-key", api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
pub struct ClaudeProvider {
|
||||
client: reqwest::Client,
|
||||
config: ClaudeConfig,
|
||||
}
|
||||
|
||||
let claude_response: ClaudeResponse = response.json().await?;
|
||||
Ok(claude_response
|
||||
.content
|
||||
.first()
|
||||
.map(|content| content.text.clone())
|
||||
.unwrap_or_default())
|
||||
impl ClaudeProvider {
|
||||
pub fn new(client: reqwest::Client, config: ClaudeConfig) -> Self {
|
||||
Self { client, config }
|
||||
}
|
||||
|
||||
async fn complete(&self, prompt: AIPrompt) -> Result<String, ProviderError> {
|
||||
let payload = json!({
|
||||
"model": self.config.model,
|
||||
"max_tokens": 4096,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": prompt.system_prompt
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt.user_prompt,
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.config.api_base_url)
|
||||
.header("x-api-key", &self.config.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let claude_response: ClaudeResponse = response.json().await?;
|
||||
claude_response
|
||||
.content
|
||||
.first()
|
||||
.map(|content| content.text.clone())
|
||||
.ok_or(ProviderError::NoCompletionChoice)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AIProvider for ClaudeProvider {
|
||||
async fn explain(&self, git_entity: GitEntity) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let AIPrompt {
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
} = AIPrompt::build_explain_prompt(&git_entity);
|
||||
|
||||
let payload = json!({
|
||||
"model": self.model,
|
||||
"max_tokens": 4096,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let res = get_completion_result(&self.client, &self.api_key, payload).await?;
|
||||
Ok(res)
|
||||
let prompt = AIPrompt::build_explain_prompt(&git_entity);
|
||||
Ok(self.complete(prompt).await?)
|
||||
}
|
||||
|
||||
async fn draft(&self, git_entity: GitEntity) -> Result<String, Box<dyn std::error::Error>> {
|
||||
todo!()
|
||||
let prompt = AIPrompt::build_draft_prompt(&git_entity)?;
|
||||
Ok(self.complete(prompt).await?)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use async_trait::async_trait;
|
||||
use claude::ClaudeProvider;
|
||||
use claude::{ClaudeConfig, ClaudeProvider};
|
||||
use groq::GroqProvider;
|
||||
use openai::{OpenAIConfig, OpenAIProvider};
|
||||
use phind::PhindProvider;
|
||||
|
|
@ -59,8 +59,8 @@ impl LumenProvider {
|
|||
}
|
||||
ProviderType::Claude => {
|
||||
let api_key = api_key.ok_or(LumenError::MissingApiKey("Claude".to_string()))?;
|
||||
let provider =
|
||||
LumenProvider::Claude(Box::new(ClaudeProvider::new(client, api_key, model)));
|
||||
let config = ClaudeConfig::new(api_key, model);
|
||||
let provider = LumenProvider::Claude(Box::new(ClaudeProvider::new(client, config)));
|
||||
Ok(provider)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue