Get rid of all transport types and settle on Protobuf (#25)

* Get rid of all transport types and settle on Protobuf

hope i don't regret this

* Update Cargo.toml

* Update agent.py
This commit is contained in:
Josh Thomas 2024-12-12 16:53:49 -06:00 committed by GitHub
parent 643a47953e
commit 0a6e975ca5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 1484 additions and 685 deletions

View file

@ -1,11 +1,9 @@
mod process;
mod proto;
mod transport;
pub use process::ProcessError;
pub use process::PythonProcess;
pub use transport::parse_json_response;
pub use transport::parse_raw_response;
pub use transport::JsonResponse;
pub use proto::v1;
pub use transport::Transport;
pub use transport::TransportError;
pub use transport::TransportMessage;
pub use transport::TransportResponse;

View file

@ -1,6 +1,6 @@
use crate::transport::{
Transport, TransportError, TransportMessage, TransportProtocol, TransportResponse,
};
use crate::proto::v1::*;
use crate::transport::{Transport, TransportError};
use std::ffi::OsStr;
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
@ -9,75 +9,130 @@ use tokio::time;
#[derive(Debug)]
pub struct PythonProcess {
transport: Arc<Mutex<Box<dyn TransportProtocol>>>,
transport: Arc<Mutex<Transport>>,
_child: Child,
healthy: Arc<AtomicBool>,
}
impl PythonProcess {
pub fn new(
pub fn new<I, S>(
module: &str,
transport: Transport,
args: Option<I>,
health_check_interval: Option<Duration>,
) -> Result<Self, TransportError> {
let mut child = Command::new("python")
.arg("-m")
.arg(module)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()?;
) -> Result<Self, ProcessError>
where
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
let mut command = Command::new("python");
command.arg("-m").arg(module);
if let Some(args) = args {
command.args(args);
}
command.stdin(Stdio::piped()).stdout(Stdio::piped());
let mut child = command.spawn().map_err(TransportError::Io)?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let transport = Transport::new(stdin, stdout)?;
let process = Self {
transport: Arc::new(Mutex::new(transport.create(stdin, stdout)?)),
transport: Arc::new(Mutex::new(transport)),
_child: child,
healthy: Arc::new(AtomicBool::new(true)),
};
if let Some(interval) = health_check_interval {
process.start_health_check_task(interval)?;
let transport = process.transport.clone();
let healthy = process.healthy.clone();
tokio::spawn(async move {
let mut interval = time::interval(interval);
loop {
interval.tick().await;
let _ = PythonProcess::check_health(transport.clone(), healthy.clone()).await;
}
});
}
Ok(process)
}
fn start_health_check_task(&self, interval: Duration) -> Result<(), TransportError> {
let healthy = self.healthy.clone();
let transport = self.transport.clone();
tokio::spawn(async move {
let mut interval = time::interval(interval);
loop {
interval.tick().await;
if let Ok(mut transport) = transport.lock() {
match transport.health_check() {
Ok(()) => {
healthy.store(true, Ordering::SeqCst);
}
Err(_) => {
healthy.store(false, Ordering::SeqCst);
}
}
}
}
});
Ok(())
}
pub fn is_healthy(&self) -> bool {
self.healthy.load(Ordering::SeqCst)
}
pub fn send(
&mut self,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError> {
request: messages::Request,
) -> Result<messages::Response, TransportError> {
let mut transport = self.transport.lock().unwrap();
transport.send(message, args)
transport.send(request)
}
async fn check_health(
transport: Arc<Mutex<Transport>>,
healthy: Arc<AtomicBool>,
) -> Result<(), ProcessError> {
let request = messages::Request {
command: Some(messages::request::Command::CheckHealth(
check::HealthRequest {},
)),
};
let response = tokio::time::timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || {
let mut transport = transport.lock().unwrap();
transport.send(request)
}),
)
.await
.map_err(|_| ProcessError::Timeout(5))?
.map_err(TransportError::Task)?
.map_err(ProcessError::Transport)?;
let result = match response.result {
Some(messages::response::Result::CheckHealth(health)) => {
if !health.passed {
let error_msg = health.error.unwrap_or_else(|| "Unknown error".to_string());
Err(ProcessError::Health(error_msg))
} else {
Ok(())
}
}
Some(messages::response::Result::Error(e)) => Err(ProcessError::Health(e.message)),
_ => Err(ProcessError::Response),
};
healthy.store(result.is_ok(), Ordering::SeqCst);
result
}
}
impl Drop for PythonProcess {
fn drop(&mut self) {
if let Ok(()) = self._child.kill() {
let _ = self._child.wait();
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum ProcessError {
#[error("Health check failed: {0}")]
Health(String),
#[error("Operation timed out after {0} seconds")]
Timeout(u64),
#[error("Unexpected response type")]
Response,
#[error("Failed to acquire lock: {0}")]
Lock(String),
#[error("Process not ready: {0}")]
Ready(String),
#[error("Transport error: {0}")]
Transport(#[from] TransportError),
}

View file

@ -0,0 +1,17 @@
pub mod v1 {
pub mod messages {
include!(concat!(env!("OUT_DIR"), "/djls.v1.messages.rs"));
}
pub mod check {
include!(concat!(env!("OUT_DIR"), "/djls.v1.check.rs"));
}
pub mod django {
include!(concat!(env!("OUT_DIR"), "/djls.v1.django.rs"));
}
pub mod python {
include!(concat!(env!("OUT_DIR"), "/djls.v1.python.rs"));
}
}

View file

@ -1,320 +1,59 @@
use djls_types::proto::*;
use crate::process::ProcessError;
use crate::proto::v1::*;
use prost::Message;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Debug;
use std::io::Read;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::io::{BufRead, BufReader, BufWriter, Read, Write};
use std::process::{ChildStdin, ChildStdout};
use std::sync::{Arc, Mutex};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum TransportError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
#[error("Process error: {0}")]
Process(String),
}
pub enum Transport {
Raw,
Json,
Protobuf,
#[derive(Debug, Clone)]
pub struct Transport {
reader: Arc<Mutex<BufReader<ChildStdout>>>,
writer: Arc<Mutex<BufWriter<ChildStdin>>>,
}
impl Transport {
pub fn create(
&self,
mut stdin: ChildStdin,
mut stdout: ChildStdout,
) -> Result<Box<dyn TransportProtocol>, TransportError> {
let transport_type = match self {
Transport::Raw => "raw",
Transport::Json => "json",
Transport::Protobuf => "protobuf",
};
writeln!(stdin, "{}", transport_type).map_err(TransportError::Io)?;
pub fn new(mut stdin: ChildStdin, mut stdout: ChildStdout) -> Result<Self, ProcessError> {
stdin.flush().map_err(TransportError::Io)?;
let mut ready_line = String::new();
BufReader::new(&mut stdout)
.read_line(&mut ready_line)
.map_err(TransportError::Io)?;
if ready_line.trim() != "ready" {
return Err(TransportError::Process(
"Python process not ready".to_string(),
));
return Err(ProcessError::Ready("Python process not ready".to_string()));
}
match self {
Transport::Raw => Ok(Box::new(RawTransport::new(stdin, stdout)?)),
Transport::Json => Ok(Box::new(JsonTransport::new(stdin, stdout)?)),
Transport::Protobuf => Ok(Box::new(ProtobufTransport::new(stdin, stdout)?)),
}
}
}
#[derive(Debug)]
pub enum TransportMessage {
Raw(String),
Json(String),
Protobuf(ToAgent),
}
#[derive(Debug)]
pub enum TransportResponse {
Raw(String),
Json(String),
Protobuf(FromAgent),
}
pub trait TransportProtocol: Debug + Send {
fn new(stdin: ChildStdin, stdout: ChildStdout) -> Result<Self, TransportError>
where
Self: Sized;
fn health_check(&mut self) -> Result<(), TransportError>;
fn clone_box(&self) -> Box<dyn TransportProtocol>;
fn send_impl(
&mut self,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError>;
fn send(
&mut self,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError> {
self.health_check()?;
self.send_impl(message, args)
}
}
impl Clone for Box<dyn TransportProtocol> {
fn clone(&self) -> Self {
self.clone_box()
}
}
#[derive(Debug)]
pub struct RawTransport {
reader: Arc<Mutex<BufReader<ChildStdout>>>,
writer: Arc<Mutex<BufWriter<ChildStdin>>>,
}
impl TransportProtocol for RawTransport {
fn new(stdin: ChildStdin, stdout: ChildStdout) -> Result<Self, TransportError> {
Ok(Self {
reader: Arc::new(Mutex::new(BufReader::new(stdout))),
writer: Arc::new(Mutex::new(BufWriter::new(stdin))),
})
}
fn health_check(&mut self) -> Result<(), TransportError> {
self.send_impl(TransportMessage::Raw("health".to_string()), None)
.and_then(|response| match response {
TransportResponse::Raw(s) if s == "ok" => Ok(()),
TransportResponse::Raw(other) => Err(TransportError::Process(format!(
"Health check failed: {}",
other
))),
_ => Err(TransportError::Process(
"Unexpected response type".to_string(),
)),
})
}
fn clone_box(&self) -> Box<dyn TransportProtocol> {
Box::new(RawTransport {
reader: self.reader.clone(),
writer: self.writer.clone(),
})
}
fn send_impl(
pub fn send(
&mut self,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError> {
let mut writer = self.writer.lock().unwrap();
match message {
TransportMessage::Raw(msg) => {
if let Some(args) = args {
writeln!(writer, "{} {}", msg, args.join(" ")).map_err(TransportError::Io)?;
} else {
writeln!(writer, "{}", msg).map_err(TransportError::Io)?;
}
}
_ => {
return Err(TransportError::Process(
"Raw transport only accepts raw messages".to_string(),
))
}
}
message: messages::Request,
) -> Result<messages::Response, TransportError> {
let buf = message.encode_to_vec();
let mut writer = self.writer.lock().map_err(|_| {
TransportError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to acquire writer lock",
))
})?;
writer
.write_all(&(buf.len() as u32).to_be_bytes())
.map_err(TransportError::Io)?;
writer.write_all(&buf).map_err(TransportError::Io)?;
writer.flush().map_err(TransportError::Io)?;
let mut reader = self.reader.lock().unwrap();
let mut line = String::new();
reader.read_line(&mut line).map_err(TransportError::Io)?;
Ok(TransportResponse::Raw(line.trim().to_string()))
}
}
#[derive(Debug, Serialize, Deserialize)]
struct JsonCommand {
command: String,
args: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct JsonResponse {
status: String,
data: Option<Value>,
error: Option<String>,
}
impl JsonResponse {
pub fn data(&self) -> &Option<Value> {
&self.data
}
}
#[derive(Debug)]
pub struct JsonTransport {
reader: Arc<Mutex<BufReader<ChildStdout>>>,
writer: Arc<Mutex<BufWriter<ChildStdin>>>,
}
impl TransportProtocol for JsonTransport {
fn new(stdin: ChildStdin, stdout: ChildStdout) -> Result<Self, TransportError> {
Ok(Self {
reader: Arc::new(Mutex::new(BufReader::new(stdout))),
writer: Arc::new(Mutex::new(BufWriter::new(stdin))),
})
}
fn health_check(&mut self) -> Result<(), TransportError> {
self.send_impl(TransportMessage::Json("health".to_string()), None)
.and_then(|response| match response {
TransportResponse::Json(json) => {
let resp: JsonResponse = serde_json::from_str(&json)?;
match resp.status.as_str() {
"ok" => Ok(()),
_ => Err(TransportError::Process(
resp.error.unwrap_or_else(|| "Unknown error".to_string()),
)),
}
}
_ => Err(TransportError::Process(
"Unexpected response type".to_string(),
)),
})
}
fn clone_box(&self) -> Box<dyn TransportProtocol> {
Box::new(JsonTransport {
reader: self.reader.clone(),
writer: self.writer.clone(),
})
}
fn send_impl(
&mut self,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError> {
let mut writer = self.writer.lock().unwrap();
match message {
TransportMessage::Json(msg) => {
let command = JsonCommand { command: msg, args };
serde_json::to_writer(&mut *writer, &command)?;
writeln!(writer).map_err(TransportError::Io)?;
}
_ => {
return Err(TransportError::Process(
"JSON transport only accepts JSON messages".to_string(),
))
}
}
writer.flush().map_err(TransportError::Io)?;
let mut reader = self.reader.lock().unwrap();
let mut line = String::new();
reader.read_line(&mut line).map_err(TransportError::Io)?;
Ok(TransportResponse::Json(line.trim().to_string()))
}
}
#[derive(Debug)]
pub struct ProtobufTransport {
reader: Arc<Mutex<BufReader<ChildStdout>>>,
writer: Arc<Mutex<BufWriter<ChildStdin>>>,
}
impl TransportProtocol for ProtobufTransport {
fn new(stdin: ChildStdin, stdout: ChildStdout) -> Result<Self, TransportError> {
Ok(Self {
reader: Arc::new(Mutex::new(BufReader::new(stdout))),
writer: Arc::new(Mutex::new(BufWriter::new(stdin))),
})
}
fn health_check(&mut self) -> Result<(), TransportError> {
let request = ToAgent {
command: Some(to_agent::Command::HealthCheck(HealthCheck {})),
};
match self.send_impl(TransportMessage::Protobuf(request), None)? {
TransportResponse::Protobuf(FromAgent {
message: Some(from_agent::Message::Error(e)),
}) => Err(TransportError::Process(e.message)),
TransportResponse::Protobuf(FromAgent {
message: Some(from_agent::Message::HealthCheck(_)),
}) => Ok(()),
_ => Err(TransportError::Process("Unexpected response".to_string())),
}
}
fn clone_box(&self) -> Box<dyn TransportProtocol> {
Box::new(ProtobufTransport {
reader: self.reader.clone(),
writer: self.writer.clone(),
})
}
fn send_impl(
&mut self,
message: TransportMessage,
_args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError> {
let mut writer = self.writer.lock().unwrap();
match message {
TransportMessage::Protobuf(msg) => {
let buf = msg.encode_to_vec();
writer
.write_all(&(buf.len() as u32).to_be_bytes())
.map_err(TransportError::Io)?;
writer.write_all(&buf).map_err(TransportError::Io)?;
}
_ => {
return Err(TransportError::Process(
"Protobuf transport only accepts protobuf messages".to_string(),
))
}
}
writer.flush().map_err(TransportError::Io)?;
let mut reader = self.reader.lock().unwrap();
let mut reader = self.reader.lock().map_err(|_| {
TransportError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to acquire reader lock",
))
})?;
let mut length_bytes = [0u8; 4];
reader
.read_exact(&mut length_bytes)
@ -326,17 +65,17 @@ impl TransportProtocol for ProtobufTransport {
.read_exact(&mut message_bytes)
.map_err(TransportError::Io)?;
let response = FromAgent::decode(message_bytes.as_slice())
.map_err(|e| TransportError::Process(e.to_string()))?;
Ok(TransportResponse::Protobuf(response))
messages::Response::decode(message_bytes.as_slice())
.map_err(|e| TransportError::Decode(e.to_string()))
}
}
pub fn parse_raw_response(response: String) -> Result<String, TransportError> {
Ok(response)
}
pub fn parse_json_response(response: String) -> Result<JsonResponse, TransportError> {
serde_json::from_str(&response).map_err(TransportError::Json)
#[derive(thiserror::Error, Debug)]
pub enum TransportError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Task error: {0}")]
Task(#[from] tokio::task::JoinError),
#[error("Failed to decode message: {0}")]
Decode(String),
}