create djls-ipc crate for communicating with Django process (#17)

This commit is contained in:
Josh Thomas 2024-12-10 18:49:41 -06:00 committed by GitHub
parent f4e473677c
commit a2ebd0dc8f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 868 additions and 9 deletions

View file

@ -12,28 +12,32 @@ concurrency:
env:
CARGO_TERM_COLOR: always
FORCE_COLOR: "1"
PYTHONUNBUFFERED: "1"
UV_VERSION: "0.4.x"
jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os:
- macos-latest
- ubuntu-latest
- windows-latest
toolchain:
- stable
- beta
- nightly
steps:
- uses: actions/checkout@v4
- name: Install Rust toolchain
run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }}
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
- name: Build
run: cargo build --verbose
- name: Install dependencies and build
run: |
uv sync --frozen
cargo build --verbose
- name: Run tests
run: cargo test --verbose

View file

@ -6,6 +6,7 @@ resolver = "2"
djls = { path = "crates/djls" }
djls-ast = { path = "crates/djls-ast" }
djls-django = { path = "crates/djls-django" }
djls-ipc = { path = "crates/djls-ipc" }
djls-python = { path = "crates/djls-python" }
djls-worker = { path = "crates/djls-worker" }

View file

@ -0,0 +1,13 @@
[package]
name = "djls-ipc"
version = "0.0.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tempfile = "3.14.0"

View file

@ -0,0 +1,486 @@
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::{path::Path, time::Duration};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
#[derive(Clone, Debug)]
pub(crate) struct ConnectionConfig {
max_retries: u32,
initial_delay_ms: u64,
max_delay_ms: u64,
backoff_factor: f64,
}
impl Default for ConnectionConfig {
fn default() -> Self {
Self {
max_retries: 5,
initial_delay_ms: 100,
max_delay_ms: 5000,
backoff_factor: 2.0,
}
}
}
#[async_trait]
pub trait ConnectionTrait: Send {
async fn write_all(&mut self, buf: &[u8]) -> Result<()>;
async fn read_line(&mut self, buf: &mut String) -> Result<usize>;
}
pub struct Connection {
inner: UnixConnection,
}
#[cfg(unix)]
pub struct UnixConnection {
stream: tokio::net::UnixStream,
}
impl Connection {
pub async fn connect(path: &Path) -> Result<Box<dyn ConnectionTrait>> {
Self::connect_with_config(path, ConnectionConfig::default()).await
}
pub(crate) async fn connect_with_config(
path: &Path,
config: ConnectionConfig,
) -> Result<Box<dyn ConnectionTrait>> {
let mut current_delay = config.initial_delay_ms;
let mut last_error = None;
for attempt in 0..config.max_retries {
let result = {
let stream = tokio::net::UnixStream::connect(path).await;
stream
.map(|s| Box::new(UnixConnection { stream: s }) as Box<dyn ConnectionTrait>)
.context("Failed to connect to Unix socket")
};
match result {
Ok(connection) => return Ok(connection),
Err(e) => {
last_error = Some(e);
if attempt < config.max_retries - 1 {
tokio::time::sleep(Duration::from_millis(current_delay)).await;
current_delay = ((current_delay as f64 * config.backoff_factor) as u64)
.min(config.max_delay_ms);
}
}
}
}
Err(last_error.unwrap_or_else(|| {
anyhow::anyhow!("Failed to connect after {} attempts", config.max_retries)
}))
}
}
#[cfg(unix)]
#[async_trait]
impl ConnectionTrait for UnixConnection {
async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
self.stream.write_all(buf).await?;
Ok(())
}
async fn read_line(&mut self, buf: &mut String) -> Result<usize> {
let mut reader = BufReader::new(&mut self.stream);
let bytes_read = reader.read_line(buf).await?;
Ok(bytes_read)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Message<T> {
pub id: u64,
pub content: T,
}
pub struct Client {
connection: Box<dyn ConnectionTrait>,
message_id: u64,
}
impl Client {
pub async fn connect(path: &Path) -> Result<Self> {
let connection = Connection::connect(path).await?;
Ok(Self {
connection,
message_id: 0,
})
}
pub async fn send<T, R>(&mut self, content: T) -> Result<R>
where
T: Serialize,
R: for<'de> Deserialize<'de>,
{
self.message_id += 1;
let message = Message {
id: self.message_id,
content,
};
let msg = serde_json::to_string(&message)? + "\n";
self.connection.write_all(msg.as_bytes()).await?;
let mut buffer = String::new();
self.connection.read_line(&mut buffer).await?;
let response: Message<R> = serde_json::from_str(&buffer)?;
if response.id != self.message_id {
return Err(anyhow::anyhow!(
"Message ID mismatch. Expected {}, got {}",
self.message_id,
response.id
));
}
Ok(response.content)
}
}
#[cfg(test)]
mod conn_tests {
use super::*;
use tempfile::NamedTempFile;
use tokio::net::UnixListener;
use tokio::sync::oneshot;
fn test_config() -> ConnectionConfig {
ConnectionConfig {
max_retries: 5,
initial_delay_ms: 10,
max_delay_ms: 100,
backoff_factor: 2.0,
}
}
#[tokio::test]
async fn test_unix_connection() -> Result<()> {
let temp_file = NamedTempFile::new()?;
let socket_path = temp_file.path().to_owned();
temp_file.close()?;
// Channel to signal when server is ready
let (tx, rx) = oneshot::channel();
let listener = UnixListener::bind(&socket_path)?;
tokio::spawn(async move {
tx.send(()).unwrap();
let (stream, _) = listener.accept().await.unwrap();
loop {
let mut buf = [0; 1024];
match stream.try_read(&mut buf) {
Ok(0) => break, // EOF
Ok(n) => {
stream.try_write(&buf[..n]).unwrap();
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
tokio::task::yield_now().await;
continue;
}
Err(e) => panic!("Error reading from socket: {}", e),
}
}
});
rx.await?;
let mut connection = Connection::connect(&socket_path).await?;
// single message
connection.write_all(b"hello\n").await?;
let mut response = String::new();
let n = connection.read_line(&mut response).await?;
assert_eq!(n, 6);
assert_eq!(response, "hello\n");
// multiple messages
for i in 0..3 {
let msg = format!("message{}\n", i);
connection.write_all(msg.as_bytes()).await?;
let mut response = String::new();
let n = connection.read_line(&mut response).await?;
assert_eq!(n, msg.len());
assert_eq!(response, msg);
}
// large message
let large_msg = "a".repeat(1000) + "\n";
connection.write_all(large_msg.as_bytes()).await?;
let mut response = String::new();
let n = connection.read_line(&mut response).await?;
assert_eq!(n, large_msg.len());
assert_eq!(response, large_msg);
Ok(())
}
#[tokio::test]
async fn test_unix_connection_nonexistent_path() -> Result<()> {
let temp_file = NamedTempFile::new()?;
let socket_path = temp_file.path().to_owned();
temp_file.close()?;
let result = Connection::connect(&socket_path).await;
assert!(result.is_err());
Ok(())
}
#[tokio::test]
async fn test_unix_connection_server_disconnect() -> Result<()> {
let temp_file = NamedTempFile::new()?;
let socket_path = temp_file.path().to_owned();
temp_file.close()?;
let (tx, rx) = oneshot::channel();
let listener = UnixListener::bind(&socket_path)?;
let server_handle = tokio::spawn(async move {
tx.send(()).unwrap();
let (stream, _) = listener.accept().await.unwrap();
// Server immediately drops the connection
drop(stream);
});
rx.await?;
let mut connection = Connection::connect(&socket_path).await?;
// Write should fail after server disconnects
connection.write_all(b"hello\n").await?;
let mut response = String::new();
let result = connection.read_line(&mut response).await;
assert!(result.is_err() || result.unwrap() == 0);
server_handle.await?;
Ok(())
}
#[tokio::test]
async fn test_connection_retry() -> Result<()> {
let temp_file = NamedTempFile::new()?;
let socket_path = temp_file.path().to_owned();
temp_file.close()?;
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(25)).await;
let listener = tokio::net::UnixListener::bind(&socket_path_clone).unwrap();
let (stream, _) = listener.accept().await.unwrap();
drop(stream);
});
let start = std::time::Instant::now();
let _connection = Connection::connect_with_config(&socket_path, test_config()).await?;
let elapsed = start.elapsed();
assert!(
elapsed >= Duration::from_millis(30),
"Connection succeeded too quickly ({:?}), should have retried",
elapsed
);
assert!(
elapsed < Duration::from_millis(100),
"Connection took too long ({:?}), too many retries",
elapsed
);
Ok(())
}
#[tokio::test]
async fn test_connection_max_retries() -> Result<()> {
let temp_file = NamedTempFile::new()?;
let socket_path = temp_file.path().to_owned();
temp_file.close()?;
let start = std::time::Instant::now();
let result = Connection::connect_with_config(&socket_path, test_config()).await;
let elapsed = start.elapsed();
assert!(result.is_err());
// Should have waited approximately
// 0 + 10 + 20 + 40 + 80 ~= 150ms
assert!(
elapsed >= Duration::from_millis(150),
"Didn't retry enough times ({:?})",
elapsed
);
assert!(
elapsed < Duration::from_millis(200),
"Retried for too long ({:?})",
elapsed
);
Ok(())
}
}
#[cfg(test)]
mod client_tests {
use super::*;
use std::sync::{Arc, Mutex};
struct MockConnection {
written: Arc<Mutex<Vec<u8>>>,
responses: Vec<Result<String>>,
response_index: usize,
}
impl MockConnection {
fn new(responses: Vec<Result<String>>) -> Self {
Self {
written: Arc::new(Mutex::new(Vec::new())),
responses,
response_index: 0,
}
}
}
#[async_trait::async_trait]
impl crate::client::ConnectionTrait for MockConnection {
async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
if self.response_index >= self.responses.len() {
return Err(anyhow::anyhow!("Connection closed"));
}
self.written.lock().unwrap().extend_from_slice(buf);
Ok(())
}
async fn read_line(&mut self, buf: &mut String) -> Result<usize> {
match self.responses.get(self.response_index) {
Some(Ok(response)) => {
buf.push_str(response);
self.response_index += 1;
Ok(response.len())
}
Some(Err(e)) => Err(anyhow::anyhow!(e.to_string())),
None => Ok(0),
}
}
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct TestMessage {
value: String,
}
#[tokio::test]
async fn test_successful_message_exchange() -> Result<()> {
let mock_conn = MockConnection::new(vec![Ok(
r#"{"id":1,"content":{"value":"response"}}"#.to_string()
)]);
let mut client = Client {
connection: Box::new(mock_conn),
message_id: 0,
};
let request = TestMessage {
value: "test".to_string(),
};
let response: TestMessage = client.send(request).await?;
assert_eq!(response.value, "response");
assert_eq!(client.message_id, 1);
Ok(())
}
#[tokio::test]
async fn test_connection_error() {
let mock_conn = MockConnection::new(vec![Err(anyhow::anyhow!("Connection error"))]);
let mut client = Client {
connection: Box::new(mock_conn),
message_id: 0,
};
let request = TestMessage {
value: "test".to_string(),
};
let result: Result<TestMessage> = client.send(request).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Connection error"));
}
#[tokio::test]
async fn test_id_mismatch() {
let mock_conn = MockConnection::new(vec![Ok(
r#"{"id":2,"content":{"value":"response"}}"#.to_string()
)]);
let mut client = Client {
connection: Box::new(mock_conn),
message_id: 0,
};
let request = TestMessage {
value: "test".to_string(),
};
let result: Result<TestMessage> = client.send(request).await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Message ID mismatch"));
}
#[tokio::test]
async fn test_invalid_json_response() {
let mock_conn = MockConnection::new(vec![Ok("invalid json".to_string())]);
let mut client = Client {
connection: Box::new(mock_conn),
message_id: 0,
};
let request = TestMessage {
value: "test".to_string(),
};
let result: Result<TestMessage> = client.send(request).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_multiple_messages() -> Result<()> {
let mock_conn = MockConnection::new(vec![
Ok(r#"{"id":1,"content":{"value":"response1"}}"#.to_string()),
Ok(r#"{"id":2,"content":{"value":"response2"}}"#.to_string()),
]);
let mut client = Client {
connection: Box::new(mock_conn),
message_id: 0,
};
let request1 = TestMessage {
value: "test1".to_string(),
};
let response1: TestMessage = client.send(request1).await?;
assert_eq!(response1.value, "response1");
assert_eq!(client.message_id, 1);
let request2 = TestMessage {
value: "test2".to_string(),
};
let response2: TestMessage = client.send(request2).await?;
assert_eq!(response2.value, "response2");
assert_eq!(client.message_id, 2);
Ok(())
}
}

View file

@ -0,0 +1,5 @@
mod client;
mod server;
pub use client::Client;
pub use server::Server;

View file

@ -0,0 +1,162 @@
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use std::process::{Child, Command};
use std::thread::sleep;
use std::time::Duration;
use tempfile::{tempdir, TempDir};
pub struct Server {
#[cfg(unix)]
socket_path: PathBuf,
process: Child,
_temp_dir: TempDir,
}
impl Server {
pub fn start(python_module: &str, args: &[&str]) -> Result<Self> {
Self::start_with_options(python_module, args, true)
}
pub fn start_script(python_script: &str, args: &[&str]) -> Result<Self> {
Self::start_with_options(python_script, args, false)
}
fn start_with_options(python_path: &str, args: &[&str], use_module: bool) -> Result<Self> {
let temp_dir = tempdir()?;
let path = {
let socket_path = temp_dir.path().join("ipc.sock");
socket_path
};
let mut command = Command::new("python");
if use_module {
command.arg("-m");
}
command.arg(python_path);
command.args(args);
command.arg("--ipc-path").arg(&path);
let process = command.spawn().context("Failed to start Python process")?;
sleep(Duration::from_millis(100));
Ok(Self {
socket_path: path,
process,
_temp_dir: temp_dir,
})
}
pub fn get_path(&self) -> &Path {
&self.socket_path
}
}
impl Drop for Server {
fn drop(&mut self) {
let _ = self.process.kill();
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::Client;
use serde::{Deserialize, Serialize};
const FIXTURES_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures");
async fn setup_server_and_client() -> Result<(Server, crate::client::Client)> {
let path = format!("{}/echo_server.py", FIXTURES_PATH);
let server = Server::start_script(&path, &[])?;
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
let client = Client::connect(server.get_path()).await?;
Ok((server, client))
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
struct ComplexMessage {
field1: String,
field2: i32,
field3: bool,
}
#[tokio::test]
async fn test_basic_string_message() -> Result<()> {
let (_server, mut client) = setup_server_and_client().await?;
let response: String = client.send("test".to_string()).await?;
assert_eq!(response, "test");
Ok(())
}
#[tokio::test]
async fn test_multiple_messages() -> Result<()> {
let (_server, mut client) = setup_server_and_client().await?;
for i in 1..=3 {
let msg = format!("test{}", i);
let response: String = client.send(msg.clone()).await?;
assert_eq!(response, msg);
}
Ok(())
}
#[tokio::test]
async fn test_complex_message() -> Result<()> {
let (_server, mut client) = setup_server_and_client().await?;
let complex = ComplexMessage {
field1: "hello".to_string(),
field2: 42,
field3: true,
};
let response: ComplexMessage = client.send(complex.clone()).await?;
assert_eq!(response, complex);
Ok(())
}
#[tokio::test]
async fn test_multiple_clients() -> Result<()> {
let (server, mut client1) = setup_server_and_client().await?;
let mut client2 = crate::client::Client::connect(server.get_path()).await?;
let response1: String = client1.send("test1".to_string()).await?;
assert_eq!(response1, "test1");
let response2: String = client2.send("test2".to_string()).await?;
assert_eq!(response2, "test2");
Ok(())
}
#[tokio::test]
async fn test_concurrent_messages() -> Result<()> {
let (_server, mut client) = setup_server_and_client().await?;
let mut handles = Vec::new();
for i in 1..=5 {
let msg = format!("test{}", i);
handles.push(tokio::spawn(async move { msg }));
}
let mut results = Vec::new();
for handle in handles {
let msg = handle.await?;
let response: String = client.send(msg.clone()).await?;
results.push((msg, response));
}
for (request, response) in results {
assert_eq!(request, response);
}
Ok(())
}
}

View file

@ -0,0 +1,47 @@
from __future__ import annotations
import argparse
import asyncio
import json
from pathlib import Path
async def handle_client(reader, writer):
while True:
try:
data = await reader.readline()
if not data:
break
# Parse the incoming message
message = json.loads(data)
# Echo back with same ID but just echo the content
response = {"id": message["id"], "content": message["content"]}
writer.write(json.dumps(response).encode() + b"\n")
await writer.drain()
except Exception:
break
writer.close()
await writer.wait_closed()
async def main(ipc_path):
try:
Path(ipc_path).unlink()
except FileNotFoundError:
pass
server = await asyncio.start_unix_server(
handle_client,
path=ipc_path,
)
async with server:
await server.serve_forever()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ipc-path", required=True)
args = parser.parse_args()
asyncio.run(main(args.ipc_path))

View file

@ -0,0 +1,141 @@
use anyhow::Result;
use djls_ipc::{Client, Server};
use serde::{Deserialize, Serialize};
const FIXTURES_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/fixtures");
async fn setup_echo_server() -> Result<(Server, Client)> {
let path = format!("{}/echo_server.py", FIXTURES_PATH);
let server = Server::start_script(&path, &[])?;
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
let client = Client::connect(server.get_path()).await?;
Ok((server, client))
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
struct TestMessage {
field1: String,
field2: i32,
vec_field: Vec<String>,
}
#[tokio::test]
async fn test_full_communication_cycle() -> Result<()> {
let (_server, mut client) = setup_echo_server().await?;
let test_msg = TestMessage {
field1: "hello".to_string(),
field2: 42,
vec_field: vec!["a".to_string(), "b".to_string()],
};
let response: TestMessage = client.send(test_msg.clone()).await?;
assert_eq!(response, test_msg);
Ok(())
}
#[tokio::test]
async fn test_long_running_session() -> Result<()> {
let (_server, mut client) = setup_echo_server().await?;
for i in 0..10 {
let string_msg = format!("test message {}", i);
let response: String = client.send(string_msg.clone()).await?;
assert_eq!(response, string_msg);
let complex_msg = TestMessage {
field1: format!("message {}", i),
field2: i,
vec_field: vec![format!("item {}", i)],
};
let response: TestMessage = client.send(complex_msg.clone()).await?;
assert_eq!(response, complex_msg);
}
Ok(())
}
#[tokio::test]
async fn test_multiple_clients_single_server() -> Result<()> {
let (server, mut client1) = setup_echo_server().await?;
let mut client2 = Client::connect(server.get_path()).await?;
let mut client3 = Client::connect(server.get_path()).await?;
let msg1 = TestMessage {
field1: "client1".to_string(),
field2: 1,
vec_field: vec!["a".to_string()],
};
let msg2 = TestMessage {
field1: "client2".to_string(),
field2: 2,
vec_field: vec!["b".to_string()],
};
let msg3 = TestMessage {
field1: "client3".to_string(),
field2: 3,
vec_field: vec!["c".to_string()],
};
let response1: TestMessage = client1.send(msg1.clone()).await?;
let response2: TestMessage = client2.send(msg2.clone()).await?;
let response3: TestMessage = client3.send(msg3.clone()).await?;
assert_eq!(response1, msg1);
assert_eq!(response2, msg2);
assert_eq!(response3, msg3);
Ok(())
}
#[tokio::test]
async fn test_server_restart() -> Result<()> {
let (server1, mut client1) = setup_echo_server().await?;
let msg = "test".to_string();
let response: String = client1.send(msg.clone()).await?;
assert_eq!(response, msg);
drop(client1);
drop(server1);
let (_server2, mut client2) = setup_echo_server().await?;
let msg = "test after restart".to_string();
let response: String = client2.send(msg.clone()).await?;
assert_eq!(response, msg);
Ok(())
}
#[tokio::test]
async fn test_large_messages() -> Result<()> {
let (_server, mut client) = setup_echo_server().await?;
let large_vec: Vec<String> = (0..1000).map(|i| format!("item {}", i)).collect();
let large_msg = TestMessage {
field1: "x".repeat(10000),
field2: 42,
vec_field: large_vec.clone(),
};
let response: TestMessage = client.send(large_msg.clone()).await?;
assert_eq!(response, large_msg);
Ok(())
}
#[tokio::test]
async fn test_rapid_messages() -> Result<()> {
let (_server, mut client) = setup_echo_server().await?;
for i in 0..100 {
let msg = format!("rapid message {}", i);
let response: String = client.send(msg.clone()).await?;
assert_eq!(response, msg);
}
Ok(())
}