add protobuf transport and refactor to support (#24)

This commit is contained in:
Josh Thomas 2024-12-11 20:28:57 -06:00 committed by GitHub
parent b3e0ee7b6e
commit 643a47953e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 406 additions and 74 deletions

View file

@ -7,3 +7,5 @@ pub use transport::parse_raw_response;
pub use transport::JsonResponse;
pub use transport::Transport;
pub use transport::TransportError;
pub use transport::TransportMessage;
pub use transport::TransportResponse;

View file

@ -1,4 +1,6 @@
use crate::transport::{Transport, TransportError, TransportProtocol};
use crate::transport::{
Transport, TransportError, TransportMessage, TransportProtocol, TransportResponse,
};
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
@ -72,9 +74,9 @@ impl PythonProcess {
pub fn send(
&mut self,
message: &str,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<String, TransportError> {
) -> Result<TransportResponse, TransportError> {
let mut transport = self.transport.lock().unwrap();
transport.send(message, args)
}

View file

@ -1,6 +1,9 @@
use djls_types::proto::*;
use prost::Message;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Debug;
use std::io::Read;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::process::{ChildStdin, ChildStdout};
use std::sync::{Arc, Mutex};
@ -19,6 +22,7 @@ pub enum TransportError {
pub enum Transport {
Raw,
Json,
Protobuf,
}
impl Transport {
@ -30,6 +34,7 @@ impl Transport {
let transport_type = match self {
Transport::Raw => "raw",
Transport::Json => "json",
Transport::Protobuf => "protobuf",
};
writeln!(stdin, "{}", transport_type).map_err(TransportError::Io)?;
@ -48,10 +53,25 @@ impl Transport {
match self {
Transport::Raw => Ok(Box::new(RawTransport::new(stdin, stdout)?)),
Transport::Json => Ok(Box::new(JsonTransport::new(stdin, stdout)?)),
Transport::Protobuf => Ok(Box::new(ProtobufTransport::new(stdin, stdout)?)),
}
}
}
#[derive(Debug)]
pub enum TransportMessage {
Raw(String),
Json(String),
Protobuf(ToAgent),
}
#[derive(Debug)]
pub enum TransportResponse {
Raw(String),
Json(String),
Protobuf(FromAgent),
}
pub trait TransportProtocol: Debug + Send {
fn new(stdin: ChildStdin, stdout: ChildStdout) -> Result<Self, TransportError>
where
@ -60,11 +80,15 @@ pub trait TransportProtocol: Debug + Send {
fn clone_box(&self) -> Box<dyn TransportProtocol>;
fn send_impl(
&mut self,
message: &str,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<String, TransportError>;
) -> Result<TransportResponse, TransportError>;
fn send(&mut self, message: &str, args: Option<Vec<String>>) -> Result<String, TransportError> {
fn send(
&mut self,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError> {
self.health_check()?;
self.send_impl(message, args)
}
@ -91,13 +115,16 @@ impl TransportProtocol for RawTransport {
}
fn health_check(&mut self) -> Result<(), TransportError> {
self.send_impl("health", None)
.and_then(|response| match response.as_str() {
"ok" => Ok(()),
other => Err(TransportError::Process(format!(
self.send_impl(TransportMessage::Raw("health".to_string()), None)
.and_then(|response| match response {
TransportResponse::Raw(s) if s == "ok" => Ok(()),
TransportResponse::Raw(other) => Err(TransportError::Process(format!(
"Health check failed: {}",
other
))),
_ => Err(TransportError::Process(
"Unexpected response type".to_string(),
)),
})
}
@ -110,16 +137,24 @@ impl TransportProtocol for RawTransport {
fn send_impl(
&mut self,
message: &str,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<String, TransportError> {
) -> Result<TransportResponse, TransportError> {
let mut writer = self.writer.lock().unwrap();
if let Some(args) = args {
// Join command and args with spaces
writeln!(writer, "{} {}", message, args.join(" ")).map_err(TransportError::Io)?;
} else {
writeln!(writer, "{}", message).map_err(TransportError::Io)?;
match message {
TransportMessage::Raw(msg) => {
if let Some(args) = args {
writeln!(writer, "{} {}", msg, args.join(" ")).map_err(TransportError::Io)?;
} else {
writeln!(writer, "{}", msg).map_err(TransportError::Io)?;
}
}
_ => {
return Err(TransportError::Process(
"Raw transport only accepts raw messages".to_string(),
))
}
}
writer.flush().map_err(TransportError::Io)?;
@ -127,7 +162,7 @@ impl TransportProtocol for RawTransport {
let mut reader = self.reader.lock().unwrap();
let mut line = String::new();
reader.read_line(&mut line).map_err(TransportError::Io)?;
Ok(line.trim().to_string())
Ok(TransportResponse::Raw(line.trim().to_string()))
}
}
@ -165,15 +200,21 @@ impl TransportProtocol for JsonTransport {
}
fn health_check(&mut self) -> Result<(), TransportError> {
self.send_impl("health", None).and_then(|response| {
let json: JsonResponse = serde_json::from_str(&response)?;
match json.status.as_str() {
"ok" => Ok(()),
self.send_impl(TransportMessage::Json("health".to_string()), None)
.and_then(|response| match response {
TransportResponse::Json(json) => {
let resp: JsonResponse = serde_json::from_str(&json)?;
match resp.status.as_str() {
"ok" => Ok(()),
_ => Err(TransportError::Process(
resp.error.unwrap_or_else(|| "Unknown error".to_string()),
)),
}
}
_ => Err(TransportError::Process(
json.error.unwrap_or_else(|| "Unknown error".to_string()),
"Unexpected response type".to_string(),
)),
}
})
})
}
fn clone_box(&self) -> Box<dyn TransportProtocol> {
@ -185,23 +226,110 @@ impl TransportProtocol for JsonTransport {
fn send_impl(
&mut self,
message: &str,
message: TransportMessage,
args: Option<Vec<String>>,
) -> Result<String, TransportError> {
let command = JsonCommand {
command: message.to_string(),
args,
};
) -> Result<TransportResponse, TransportError> {
let mut writer = self.writer.lock().unwrap();
serde_json::to_writer(&mut *writer, &command)?;
writeln!(writer).map_err(TransportError::Io)?;
match message {
TransportMessage::Json(msg) => {
let command = JsonCommand { command: msg, args };
serde_json::to_writer(&mut *writer, &command)?;
writeln!(writer).map_err(TransportError::Io)?;
}
_ => {
return Err(TransportError::Process(
"JSON transport only accepts JSON messages".to_string(),
))
}
}
writer.flush().map_err(TransportError::Io)?;
let mut reader = self.reader.lock().unwrap();
let mut line = String::new();
reader.read_line(&mut line).map_err(TransportError::Io)?;
Ok(line.trim().to_string())
Ok(TransportResponse::Json(line.trim().to_string()))
}
}
#[derive(Debug)]
pub struct ProtobufTransport {
reader: Arc<Mutex<BufReader<ChildStdout>>>,
writer: Arc<Mutex<BufWriter<ChildStdin>>>,
}
impl TransportProtocol for ProtobufTransport {
fn new(stdin: ChildStdin, stdout: ChildStdout) -> Result<Self, TransportError> {
Ok(Self {
reader: Arc::new(Mutex::new(BufReader::new(stdout))),
writer: Arc::new(Mutex::new(BufWriter::new(stdin))),
})
}
fn health_check(&mut self) -> Result<(), TransportError> {
let request = ToAgent {
command: Some(to_agent::Command::HealthCheck(HealthCheck {})),
};
match self.send_impl(TransportMessage::Protobuf(request), None)? {
TransportResponse::Protobuf(FromAgent {
message: Some(from_agent::Message::Error(e)),
}) => Err(TransportError::Process(e.message)),
TransportResponse::Protobuf(FromAgent {
message: Some(from_agent::Message::HealthCheck(_)),
}) => Ok(()),
_ => Err(TransportError::Process("Unexpected response".to_string())),
}
}
fn clone_box(&self) -> Box<dyn TransportProtocol> {
Box::new(ProtobufTransport {
reader: self.reader.clone(),
writer: self.writer.clone(),
})
}
fn send_impl(
&mut self,
message: TransportMessage,
_args: Option<Vec<String>>,
) -> Result<TransportResponse, TransportError> {
let mut writer = self.writer.lock().unwrap();
match message {
TransportMessage::Protobuf(msg) => {
let buf = msg.encode_to_vec();
writer
.write_all(&(buf.len() as u32).to_be_bytes())
.map_err(TransportError::Io)?;
writer.write_all(&buf).map_err(TransportError::Io)?;
}
_ => {
return Err(TransportError::Process(
"Protobuf transport only accepts protobuf messages".to_string(),
))
}
}
writer.flush().map_err(TransportError::Io)?;
let mut reader = self.reader.lock().unwrap();
let mut length_bytes = [0u8; 4];
reader
.read_exact(&mut length_bytes)
.map_err(TransportError::Io)?;
let length = u32::from_be_bytes(length_bytes);
let mut message_bytes = vec![0u8; length as usize];
reader
.read_exact(&mut message_bytes)
.map_err(TransportError::Io)?;
let response = FromAgent::decode(message_bytes.as_slice())
.map_err(|e| TransportError::Process(e.to_string()))?;
Ok(TransportResponse::Protobuf(response))
}
}