diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8ae4f55..2314c93 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,28 +12,32 @@ concurrency: env: CARGO_TERM_COLOR: always + FORCE_COLOR: "1" + PYTHONUNBUFFERED: "1" + UV_VERSION: "0.4.x" jobs: test: runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: - macos-latest - ubuntu-latest - - windows-latest - toolchain: - - stable - - beta - - nightly steps: - uses: actions/checkout@v4 - - name: Install Rust toolchain - run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + version: ${{ env.UV_VERSION }} - - name: Build - run: cargo build --verbose + - name: Install dependencies and build + run: | + uv sync --frozen + cargo build --verbose - name: Run tests run: cargo test --verbose diff --git a/Cargo.toml b/Cargo.toml index 4b13ee7..24cb331 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ resolver = "2" djls = { path = "crates/djls" } djls-ast = { path = "crates/djls-ast" } djls-django = { path = "crates/djls-django" } +djls-ipc = { path = "crates/djls-ipc" } djls-python = { path = "crates/djls-python" } djls-worker = { path = "crates/djls-worker" } diff --git a/crates/djls-ipc/Cargo.toml b/crates/djls-ipc/Cargo.toml new file mode 100644 index 0000000..d9bbeda --- /dev/null +++ b/crates/djls-ipc/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "djls-ipc" +version = "0.0.0" +edition = "2021" + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } + +tempfile = "3.14.0" diff --git a/crates/djls-ipc/src/client.rs b/crates/djls-ipc/src/client.rs new file mode 100644 index 0000000..9b1bcf1 --- /dev/null +++ b/crates/djls-ipc/src/client.rs @@ -0,0 +1,486 @@ +use anyhow::{Context, Result}; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::{path::Path, time::Duration}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; + +#[derive(Clone, Debug)] +pub(crate) struct ConnectionConfig { + max_retries: u32, + initial_delay_ms: u64, + max_delay_ms: u64, + backoff_factor: f64, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + max_retries: 5, + initial_delay_ms: 100, + max_delay_ms: 5000, + backoff_factor: 2.0, + } + } +} + +#[async_trait] +pub trait ConnectionTrait: Send { + async fn write_all(&mut self, buf: &[u8]) -> Result<()>; + async fn read_line(&mut self, buf: &mut String) -> Result; +} + +pub struct Connection { + inner: UnixConnection, +} + +#[cfg(unix)] +pub struct UnixConnection { + stream: tokio::net::UnixStream, +} + +impl Connection { + pub async fn connect(path: &Path) -> Result> { + Self::connect_with_config(path, ConnectionConfig::default()).await + } + + pub(crate) async fn connect_with_config( + path: &Path, + config: ConnectionConfig, + ) -> Result> { + let mut current_delay = config.initial_delay_ms; + let mut last_error = None; + + for attempt in 0..config.max_retries { + let result = { + let stream = tokio::net::UnixStream::connect(path).await; + stream + .map(|s| Box::new(UnixConnection { stream: s }) as Box) + .context("Failed to connect to Unix socket") + }; + + match result { + Ok(connection) => return Ok(connection), + Err(e) => { + last_error = Some(e); + + if attempt < config.max_retries - 1 { + tokio::time::sleep(Duration::from_millis(current_delay)).await; + + current_delay = ((current_delay as f64 * config.backoff_factor) as u64) + .min(config.max_delay_ms); + } + } + } + } + + Err(last_error.unwrap_or_else(|| { + anyhow::anyhow!("Failed to connect after {} attempts", config.max_retries) + })) + } +} + +#[cfg(unix)] +#[async_trait] +impl ConnectionTrait for UnixConnection { + async fn write_all(&mut self, buf: &[u8]) -> Result<()> { + self.stream.write_all(buf).await?; + Ok(()) + } + + async fn read_line(&mut self, buf: &mut String) -> Result { + let mut reader = BufReader::new(&mut self.stream); + let bytes_read = reader.read_line(buf).await?; + Ok(bytes_read) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Message { + pub id: u64, + pub content: T, +} + +pub struct Client { + connection: Box, + message_id: u64, +} + +impl Client { + pub async fn connect(path: &Path) -> Result { + let connection = Connection::connect(path).await?; + Ok(Self { + connection, + message_id: 0, + }) + } + + pub async fn send(&mut self, content: T) -> Result + where + T: Serialize, + R: for<'de> Deserialize<'de>, + { + self.message_id += 1; + let message = Message { + id: self.message_id, + content, + }; + + let msg = serde_json::to_string(&message)? + "\n"; + self.connection.write_all(msg.as_bytes()).await?; + + let mut buffer = String::new(); + self.connection.read_line(&mut buffer).await?; + + let response: Message = serde_json::from_str(&buffer)?; + + if response.id != self.message_id { + return Err(anyhow::anyhow!( + "Message ID mismatch. Expected {}, got {}", + self.message_id, + response.id + )); + } + + Ok(response.content) + } +} + +#[cfg(test)] +mod conn_tests { + use super::*; + use tempfile::NamedTempFile; + use tokio::net::UnixListener; + use tokio::sync::oneshot; + + fn test_config() -> ConnectionConfig { + ConnectionConfig { + max_retries: 5, + initial_delay_ms: 10, + max_delay_ms: 100, + backoff_factor: 2.0, + } + } + + #[tokio::test] + async fn test_unix_connection() -> Result<()> { + let temp_file = NamedTempFile::new()?; + let socket_path = temp_file.path().to_owned(); + temp_file.close()?; + + // Channel to signal when server is ready + let (tx, rx) = oneshot::channel(); + + let listener = UnixListener::bind(&socket_path)?; + + tokio::spawn(async move { + tx.send(()).unwrap(); + + let (stream, _) = listener.accept().await.unwrap(); + + loop { + let mut buf = [0; 1024]; + match stream.try_read(&mut buf) { + Ok(0) => break, // EOF + Ok(n) => { + stream.try_write(&buf[..n]).unwrap(); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + tokio::task::yield_now().await; + continue; + } + Err(e) => panic!("Error reading from socket: {}", e), + } + } + }); + + rx.await?; + + let mut connection = Connection::connect(&socket_path).await?; + + // single message + connection.write_all(b"hello\n").await?; + let mut response = String::new(); + let n = connection.read_line(&mut response).await?; + assert_eq!(n, 6); + assert_eq!(response, "hello\n"); + + // multiple messages + for i in 0..3 { + let msg = format!("message{}\n", i); + connection.write_all(msg.as_bytes()).await?; + let mut response = String::new(); + let n = connection.read_line(&mut response).await?; + assert_eq!(n, msg.len()); + assert_eq!(response, msg); + } + + // large message + let large_msg = "a".repeat(1000) + "\n"; + connection.write_all(large_msg.as_bytes()).await?; + let mut response = String::new(); + let n = connection.read_line(&mut response).await?; + assert_eq!(n, large_msg.len()); + assert_eq!(response, large_msg); + + Ok(()) + } + + #[tokio::test] + async fn test_unix_connection_nonexistent_path() -> Result<()> { + let temp_file = NamedTempFile::new()?; + let socket_path = temp_file.path().to_owned(); + temp_file.close()?; + + let result = Connection::connect(&socket_path).await; + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_unix_connection_server_disconnect() -> Result<()> { + let temp_file = NamedTempFile::new()?; + let socket_path = temp_file.path().to_owned(); + temp_file.close()?; + + let (tx, rx) = oneshot::channel(); + let listener = UnixListener::bind(&socket_path)?; + + let server_handle = tokio::spawn(async move { + tx.send(()).unwrap(); + let (stream, _) = listener.accept().await.unwrap(); + // Server immediately drops the connection + drop(stream); + }); + + rx.await?; + let mut connection = Connection::connect(&socket_path).await?; + + // Write should fail after server disconnects + connection.write_all(b"hello\n").await?; + let mut response = String::new(); + let result = connection.read_line(&mut response).await; + assert!(result.is_err() || result.unwrap() == 0); + + server_handle.await?; + Ok(()) + } + + #[tokio::test] + async fn test_connection_retry() -> Result<()> { + let temp_file = NamedTempFile::new()?; + let socket_path = temp_file.path().to_owned(); + temp_file.close()?; + + let socket_path_clone = socket_path.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(25)).await; + + let listener = tokio::net::UnixListener::bind(&socket_path_clone).unwrap(); + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let start = std::time::Instant::now(); + let _connection = Connection::connect_with_config(&socket_path, test_config()).await?; + let elapsed = start.elapsed(); + + assert!( + elapsed >= Duration::from_millis(30), + "Connection succeeded too quickly ({:?}), should have retried", + elapsed + ); + + assert!( + elapsed < Duration::from_millis(100), + "Connection took too long ({:?}), too many retries", + elapsed + ); + + Ok(()) + } + + #[tokio::test] + async fn test_connection_max_retries() -> Result<()> { + let temp_file = NamedTempFile::new()?; + let socket_path = temp_file.path().to_owned(); + temp_file.close()?; + + let start = std::time::Instant::now(); + let result = Connection::connect_with_config(&socket_path, test_config()).await; + let elapsed = start.elapsed(); + + assert!(result.is_err()); + + // Should have waited approximately + // 0 + 10 + 20 + 40 + 80 ~= 150ms + assert!( + elapsed >= Duration::from_millis(150), + "Didn't retry enough times ({:?})", + elapsed + ); + assert!( + elapsed < Duration::from_millis(200), + "Retried for too long ({:?})", + elapsed + ); + + Ok(()) + } +} + +#[cfg(test)] +mod client_tests { + use super::*; + use std::sync::{Arc, Mutex}; + + struct MockConnection { + written: Arc>>, + responses: Vec>, + response_index: usize, + } + + impl MockConnection { + fn new(responses: Vec>) -> Self { + Self { + written: Arc::new(Mutex::new(Vec::new())), + responses, + response_index: 0, + } + } + } + + #[async_trait::async_trait] + impl crate::client::ConnectionTrait for MockConnection { + async fn write_all(&mut self, buf: &[u8]) -> Result<()> { + if self.response_index >= self.responses.len() { + return Err(anyhow::anyhow!("Connection closed")); + } + self.written.lock().unwrap().extend_from_slice(buf); + Ok(()) + } + + async fn read_line(&mut self, buf: &mut String) -> Result { + match self.responses.get(self.response_index) { + Some(Ok(response)) => { + buf.push_str(response); + self.response_index += 1; + Ok(response.len()) + } + Some(Err(e)) => Err(anyhow::anyhow!(e.to_string())), + None => Ok(0), + } + } + } + + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct TestMessage { + value: String, + } + + #[tokio::test] + async fn test_successful_message_exchange() -> Result<()> { + let mock_conn = MockConnection::new(vec![Ok( + r#"{"id":1,"content":{"value":"response"}}"#.to_string() + )]); + + let mut client = Client { + connection: Box::new(mock_conn), + message_id: 0, + }; + + let request = TestMessage { + value: "test".to_string(), + }; + let response: TestMessage = client.send(request).await?; + assert_eq!(response.value, "response"); + assert_eq!(client.message_id, 1); + + Ok(()) + } + + #[tokio::test] + async fn test_connection_error() { + let mock_conn = MockConnection::new(vec![Err(anyhow::anyhow!("Connection error"))]); + + let mut client = Client { + connection: Box::new(mock_conn), + message_id: 0, + }; + + let request = TestMessage { + value: "test".to_string(), + }; + let result: Result = client.send(request).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Connection error")); + } + + #[tokio::test] + async fn test_id_mismatch() { + let mock_conn = MockConnection::new(vec![Ok( + r#"{"id":2,"content":{"value":"response"}}"#.to_string() + )]); + + let mut client = Client { + connection: Box::new(mock_conn), + message_id: 0, + }; + + let request = TestMessage { + value: "test".to_string(), + }; + let result: Result = client.send(request).await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Message ID mismatch")); + } + + #[tokio::test] + async fn test_invalid_json_response() { + let mock_conn = MockConnection::new(vec![Ok("invalid json".to_string())]); + + let mut client = Client { + connection: Box::new(mock_conn), + message_id: 0, + }; + + let request = TestMessage { + value: "test".to_string(), + }; + let result: Result = client.send(request).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_multiple_messages() -> Result<()> { + let mock_conn = MockConnection::new(vec![ + Ok(r#"{"id":1,"content":{"value":"response1"}}"#.to_string()), + Ok(r#"{"id":2,"content":{"value":"response2"}}"#.to_string()), + ]); + + let mut client = Client { + connection: Box::new(mock_conn), + message_id: 0, + }; + + let request1 = TestMessage { + value: "test1".to_string(), + }; + let response1: TestMessage = client.send(request1).await?; + assert_eq!(response1.value, "response1"); + assert_eq!(client.message_id, 1); + + let request2 = TestMessage { + value: "test2".to_string(), + }; + let response2: TestMessage = client.send(request2).await?; + assert_eq!(response2.value, "response2"); + assert_eq!(client.message_id, 2); + + Ok(()) + } +} diff --git a/crates/djls-ipc/src/lib.rs b/crates/djls-ipc/src/lib.rs new file mode 100644 index 0000000..9aa5b70 --- /dev/null +++ b/crates/djls-ipc/src/lib.rs @@ -0,0 +1,5 @@ +mod client; +mod server; + +pub use client::Client; +pub use server::Server; diff --git a/crates/djls-ipc/src/server.rs b/crates/djls-ipc/src/server.rs new file mode 100644 index 0000000..4656841 --- /dev/null +++ b/crates/djls-ipc/src/server.rs @@ -0,0 +1,162 @@ +use anyhow::{Context, Result}; +use std::path::{Path, PathBuf}; +use std::process::{Child, Command}; +use std::thread::sleep; +use std::time::Duration; +use tempfile::{tempdir, TempDir}; + +pub struct Server { + #[cfg(unix)] + socket_path: PathBuf, + process: Child, + _temp_dir: TempDir, +} + +impl Server { + pub fn start(python_module: &str, args: &[&str]) -> Result { + Self::start_with_options(python_module, args, true) + } + + pub fn start_script(python_script: &str, args: &[&str]) -> Result { + Self::start_with_options(python_script, args, false) + } + + fn start_with_options(python_path: &str, args: &[&str], use_module: bool) -> Result { + let temp_dir = tempdir()?; + + let path = { + let socket_path = temp_dir.path().join("ipc.sock"); + socket_path + }; + + let mut command = Command::new("python"); + if use_module { + command.arg("-m"); + } + command.arg(python_path); + command.args(args); + command.arg("--ipc-path").arg(&path); + + let process = command.spawn().context("Failed to start Python process")?; + + sleep(Duration::from_millis(100)); + + Ok(Self { + socket_path: path, + process, + _temp_dir: temp_dir, + }) + } + + pub fn get_path(&self) -> &Path { + &self.socket_path + } +} + +impl Drop for Server { + fn drop(&mut self) { + let _ = self.process.kill(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::Client; + use serde::{Deserialize, Serialize}; + + const FIXTURES_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures"); + + async fn setup_server_and_client() -> Result<(Server, crate::client::Client)> { + let path = format!("{}/echo_server.py", FIXTURES_PATH); + let server = Server::start_script(&path, &[])?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + let client = Client::connect(server.get_path()).await?; + Ok((server, client)) + } + + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] + struct ComplexMessage { + field1: String, + field2: i32, + field3: bool, + } + + #[tokio::test] + async fn test_basic_string_message() -> Result<()> { + let (_server, mut client) = setup_server_and_client().await?; + + let response: String = client.send("test".to_string()).await?; + assert_eq!(response, "test"); + + Ok(()) + } + + #[tokio::test] + async fn test_multiple_messages() -> Result<()> { + let (_server, mut client) = setup_server_and_client().await?; + + for i in 1..=3 { + let msg = format!("test{}", i); + let response: String = client.send(msg.clone()).await?; + assert_eq!(response, msg); + } + + Ok(()) + } + + #[tokio::test] + async fn test_complex_message() -> Result<()> { + let (_server, mut client) = setup_server_and_client().await?; + + let complex = ComplexMessage { + field1: "hello".to_string(), + field2: 42, + field3: true, + }; + + let response: ComplexMessage = client.send(complex.clone()).await?; + assert_eq!(response, complex); + + Ok(()) + } + + #[tokio::test] + async fn test_multiple_clients() -> Result<()> { + let (server, mut client1) = setup_server_and_client().await?; + let mut client2 = crate::client::Client::connect(server.get_path()).await?; + + let response1: String = client1.send("test1".to_string()).await?; + assert_eq!(response1, "test1"); + + let response2: String = client2.send("test2".to_string()).await?; + assert_eq!(response2, "test2"); + + Ok(()) + } + + #[tokio::test] + async fn test_concurrent_messages() -> Result<()> { + let (_server, mut client) = setup_server_and_client().await?; + + let mut handles = Vec::new(); + + for i in 1..=5 { + let msg = format!("test{}", i); + handles.push(tokio::spawn(async move { msg })); + } + + let mut results = Vec::new(); + for handle in handles { + let msg = handle.await?; + let response: String = client.send(msg.clone()).await?; + results.push((msg, response)); + } + + for (request, response) in results { + assert_eq!(request, response); + } + + Ok(()) + } +} diff --git a/crates/djls-ipc/tests/fixtures/echo_server.py b/crates/djls-ipc/tests/fixtures/echo_server.py new file mode 100644 index 0000000..8f4eb78 --- /dev/null +++ b/crates/djls-ipc/tests/fixtures/echo_server.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +from pathlib import Path + + +async def handle_client(reader, writer): + while True: + try: + data = await reader.readline() + if not data: + break + + # Parse the incoming message + message = json.loads(data) + # Echo back with same ID but just echo the content + response = {"id": message["id"], "content": message["content"]} + writer.write(json.dumps(response).encode() + b"\n") + await writer.drain() + except Exception: + break + writer.close() + await writer.wait_closed() + + +async def main(ipc_path): + try: + Path(ipc_path).unlink() + except FileNotFoundError: + pass + + server = await asyncio.start_unix_server( + handle_client, + path=ipc_path, + ) + + async with server: + await server.serve_forever() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ipc-path", required=True) + args = parser.parse_args() + asyncio.run(main(args.ipc_path)) diff --git a/crates/djls-ipc/tests/integration.rs b/crates/djls-ipc/tests/integration.rs new file mode 100644 index 0000000..0ea146a --- /dev/null +++ b/crates/djls-ipc/tests/integration.rs @@ -0,0 +1,141 @@ +use anyhow::Result; +use djls_ipc::{Client, Server}; +use serde::{Deserialize, Serialize}; + +const FIXTURES_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures"); + +async fn setup_echo_server() -> Result<(Server, Client)> { + let path = format!("{}/echo_server.py", FIXTURES_PATH); + let server = Server::start_script(&path, &[])?; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + let client = Client::connect(server.get_path()).await?; + Ok((server, client)) +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +struct TestMessage { + field1: String, + field2: i32, + vec_field: Vec, +} + +#[tokio::test] +async fn test_full_communication_cycle() -> Result<()> { + let (_server, mut client) = setup_echo_server().await?; + + let test_msg = TestMessage { + field1: "hello".to_string(), + field2: 42, + vec_field: vec!["a".to_string(), "b".to_string()], + }; + + let response: TestMessage = client.send(test_msg.clone()).await?; + assert_eq!(response, test_msg); + + Ok(()) +} + +#[tokio::test] +async fn test_long_running_session() -> Result<()> { + let (_server, mut client) = setup_echo_server().await?; + + for i in 0..10 { + let string_msg = format!("test message {}", i); + let response: String = client.send(string_msg.clone()).await?; + assert_eq!(response, string_msg); + + let complex_msg = TestMessage { + field1: format!("message {}", i), + field2: i, + vec_field: vec![format!("item {}", i)], + }; + let response: TestMessage = client.send(complex_msg.clone()).await?; + assert_eq!(response, complex_msg); + } + + Ok(()) +} + +#[tokio::test] +async fn test_multiple_clients_single_server() -> Result<()> { + let (server, mut client1) = setup_echo_server().await?; + let mut client2 = Client::connect(server.get_path()).await?; + let mut client3 = Client::connect(server.get_path()).await?; + + let msg1 = TestMessage { + field1: "client1".to_string(), + field2: 1, + vec_field: vec!["a".to_string()], + }; + let msg2 = TestMessage { + field1: "client2".to_string(), + field2: 2, + vec_field: vec!["b".to_string()], + }; + let msg3 = TestMessage { + field1: "client3".to_string(), + field2: 3, + vec_field: vec!["c".to_string()], + }; + + let response1: TestMessage = client1.send(msg1.clone()).await?; + let response2: TestMessage = client2.send(msg2.clone()).await?; + let response3: TestMessage = client3.send(msg3.clone()).await?; + + assert_eq!(response1, msg1); + assert_eq!(response2, msg2); + assert_eq!(response3, msg3); + + Ok(()) +} + +#[tokio::test] +async fn test_server_restart() -> Result<()> { + let (server1, mut client1) = setup_echo_server().await?; + + let msg = "test".to_string(); + let response: String = client1.send(msg.clone()).await?; + assert_eq!(response, msg); + + drop(client1); + drop(server1); + + let (_server2, mut client2) = setup_echo_server().await?; + + let msg = "test after restart".to_string(); + let response: String = client2.send(msg.clone()).await?; + assert_eq!(response, msg); + + Ok(()) +} + +#[tokio::test] +async fn test_large_messages() -> Result<()> { + let (_server, mut client) = setup_echo_server().await?; + + let large_vec: Vec = (0..1000).map(|i| format!("item {}", i)).collect(); + + let large_msg = TestMessage { + field1: "x".repeat(10000), + field2: 42, + vec_field: large_vec.clone(), + }; + + let response: TestMessage = client.send(large_msg.clone()).await?; + assert_eq!(response, large_msg); + + Ok(()) +} + +#[tokio::test] +async fn test_rapid_messages() -> Result<()> { + let (_server, mut client) = setup_echo_server().await?; + + for i in 0..100 { + let msg = format!("rapid message {}", i); + let response: String = client.send(msg.clone()).await?; + assert_eq!(response, msg); + } + + Ok(()) +}