remove vestiges of windows traits (#19)

* remove vestiges of windows traits

* remove unused
This commit is contained in:
Josh Thomas 2024-12-10 22:16:09 -06:00 committed by GitHub
parent 7573415597
commit 6f27c5ba9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,14 +1,12 @@
use anyhow::{Context, Result};
use async_trait::async_trait;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::{
path::{Path, PathBuf},
time::Duration,
};
use std::path::Path;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
#[derive(Clone, Copy, Debug)]
pub(crate) struct ConnectionConfig {
struct ConnectionConfig {
max_retries: u16,
initial_delay_ms: u16,
max_delay_ms: u32,
@ -26,43 +24,23 @@ impl Default for ConnectionConfig {
}
}
#[async_trait]
pub trait ConnectionTrait: Send {
async fn write_all(&mut self, buf: &[u8]) -> Result<()>;
async fn read_line(&mut self, buf: &mut String) -> Result<usize>;
}
pub struct Connection {
inner: UnixConnection,
}
#[cfg(unix)]
pub struct UnixConnection {
stream: tokio::net::UnixStream,
#[derive(Debug)]
struct Connection {
stream: UnixStream,
}
impl Connection {
pub async fn connect(path: &Path) -> Result<Box<dyn ConnectionTrait>> {
async fn connect(path: &Path) -> Result<Self> {
Self::connect_with_config(path, ConnectionConfig::default()).await
}
pub(crate) async fn connect_with_config(
path: &Path,
config: ConnectionConfig,
) -> Result<Box<dyn ConnectionTrait>> {
async fn connect_with_config(path: &Path, config: ConnectionConfig) -> Result<Self> {
let mut current_delay = u64::from(config.initial_delay_ms);
let mut last_error = None;
for attempt in 0..config.max_retries {
let result = {
let stream = tokio::net::UnixStream::connect(path).await;
stream
.map(|s| Box::new(UnixConnection { stream: s }) as Box<dyn ConnectionTrait>)
.context("Failed to connect to Unix socket")
};
match result {
Ok(connection) => return Ok(connection),
match UnixStream::connect(path).await {
Ok(stream) => return Ok(Self { stream }),
Err(e) => {
last_error = Some(e);
@ -76,15 +54,11 @@ impl Connection {
}
}
Err(last_error.unwrap_or_else(|| {
anyhow::anyhow!("Failed to connect after {} attempts", config.max_retries)
}))
Err(last_error
.unwrap_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "Unknown error"))
.into())
}
}
#[cfg(unix)]
#[async_trait]
impl ConnectionTrait for UnixConnection {
async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
self.stream.write_all(buf).await?;
Ok(())
@ -103,9 +77,11 @@ pub struct Message<T> {
pub content: T,
}
#[derive(Debug)]
pub struct Client {
connection: Box<dyn ConnectionTrait>,
connection: Connection,
message_id: u64,
#[cfg(test)]
socket_path: PathBuf,
}
@ -115,6 +91,7 @@ impl Client {
Ok(Self {
connection,
message_id: 0,
#[cfg(test)]
socket_path: path.to_owned(),
})
}
@ -339,6 +316,7 @@ mod client_tests {
use super::*;
use std::sync::{Arc, Mutex};
#[derive(Debug)]
struct MockConnection {
written: Arc<Mutex<Vec<u8>>>,
responses: Vec<Result<String>>,
@ -355,8 +333,7 @@ mod client_tests {
}
}
#[async_trait::async_trait]
impl crate::client::ConnectionTrait for MockConnection {
impl MockConnection {
async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
if self.response_index >= self.responses.len() {
return Err(anyhow::anyhow!("Connection closed"));
@ -378,26 +355,65 @@ mod client_tests {
}
}
#[derive(Debug)]
struct TestClient {
connection: MockConnection,
message_id: u64,
socket_path: PathBuf,
}
impl TestClient {
fn new(mock_conn: MockConnection) -> Self {
Self {
connection: mock_conn,
message_id: 0,
socket_path: PathBuf::from("/test/socket"),
}
}
async fn send<T, R>(&mut self, content: T) -> Result<R>
where
T: Serialize,
R: for<'de> Deserialize<'de>,
{
self.message_id += 1;
let message = Message {
id: self.message_id,
content,
};
let msg = serde_json::to_string(&message)? + "\n";
self.connection.write_all(msg.as_bytes()).await?;
let mut buffer = String::new();
self.connection.read_line(&mut buffer).await?;
let response: Message<R> = serde_json::from_str(&buffer)?;
if response.id != self.message_id {
return Err(anyhow::anyhow!(
"Message ID mismatch. Expected {}, got {}",
self.message_id,
response.id
));
}
Ok(response.content)
}
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct TestMessage {
value: String,
}
fn create_test_client(mock_conn: MockConnection) -> Client {
Client {
connection: Box::new(mock_conn),
message_id: 0,
socket_path: PathBuf::from("/test/socket"),
}
}
#[tokio::test]
async fn test_successful_message_exchange() -> Result<()> {
let mock_conn = MockConnection::new(vec![Ok(
r#"{"id":1,"content":{"value":"response"}}"#.to_string()
)]);
let mut client = create_test_client(mock_conn);
let mut client = TestClient::new(mock_conn);
let request = TestMessage {
value: "test".to_string(),
@ -413,7 +429,7 @@ mod client_tests {
async fn test_connection_error() {
let mock_conn = MockConnection::new(vec![Err(anyhow::anyhow!("Connection error"))]);
let mut client = create_test_client(mock_conn);
let mut client = TestClient::new(mock_conn);
let request = TestMessage {
value: "test".to_string(),
@ -429,7 +445,7 @@ mod client_tests {
r#"{"id":2,"content":{"value":"response"}}"#.to_string()
)]);
let mut client = create_test_client(mock_conn);
let mut client = TestClient::new(mock_conn);
let request = TestMessage {
value: "test".to_string(),
@ -447,7 +463,7 @@ mod client_tests {
async fn test_invalid_json_response() {
let mock_conn = MockConnection::new(vec![Ok("invalid json".to_string())]);
let mut client = create_test_client(mock_conn);
let mut client = TestClient::new(mock_conn);
let request = TestMessage {
value: "test".to_string(),
@ -463,7 +479,7 @@ mod client_tests {
Ok(r#"{"id":2,"content":{"value":"response2"}}"#.to_string()),
]);
let mut client = create_test_client(mock_conn);
let mut client = TestClient::new(mock_conn);
let request1 = TestMessage {
value: "test1".to_string(),