Simple cross-process message protocol

This commit is contained in:
Edwin Cheng 2020-03-28 18:12:51 +08:00
parent 7155d5df89
commit 0aacacd4a2
3 changed files with 84 additions and 223 deletions

View file

@ -1,218 +1,93 @@
//! A simplified version of lsp base protocol for rpc //! Defines messages for cross-process message based on `ndjson` wire protocol
use std::{ use std::{
fmt, convert::TryFrom,
io::{self, BufRead, Write}, io::{self, BufRead, Write},
}; };
use crate::{
rpc::{ListMacrosResult, ListMacrosTask},
ExpansionResult, ExpansionTask,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)] pub enum Request {
pub enum Message { ListMacro(ListMacrosTask),
Request(Request), ExpansionMacro(ExpansionTask),
Response(Response),
} }
impl From<Request> for Message { #[derive(Debug, Serialize, Deserialize, Clone)]
fn from(request: Request) -> Message { pub enum Response {
Message::Request(request) Error(ResponseError),
} ListMacro(ListMacrosResult),
ExpansionMacro(ExpansionResult),
} }
impl From<Response> for Message { macro_rules! impl_try_from_response {
fn from(response: Response) -> Message { ($ty:ty, $tag:ident) => {
Message::Response(response) impl TryFrom<Response> for $ty {
} type Error = &'static str;
} fn try_from(value: Response) -> Result<Self, Self::Error> {
match value {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] Response::$tag(res) => Ok(res),
#[serde(transparent)] _ => Err("Fail to convert from response"),
pub struct RequestId(IdRepr); }
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[serde(untagged)]
enum IdRepr {
U64(u64),
String(String),
}
impl From<u64> for RequestId {
fn from(id: u64) -> RequestId {
RequestId(IdRepr::U64(id))
}
}
impl From<String> for RequestId {
fn from(id: String) -> RequestId {
RequestId(IdRepr::String(id))
}
}
impl fmt::Display for RequestId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0 {
IdRepr::U64(it) => fmt::Display::fmt(it, f),
IdRepr::String(it) => fmt::Display::fmt(it, f),
} }
} };
} }
#[derive(Debug, Serialize, Deserialize, Clone)] impl_try_from_response!(ListMacrosResult, ListMacro);
pub struct Request { impl_try_from_response!(ExpansionResult, ExpansionMacro);
pub id: RequestId,
pub method: String,
pub params: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Response {
// JSON RPC allows this to be null if it was impossible
// to decode the request's id. Ignore this special case
// and just die horribly.
pub id: RequestId,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<ResponseError>,
}
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseError { pub struct ResponseError {
pub code: i32, pub code: ErrorCode,
pub message: String, pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
#[derive(Clone, Copy, Debug)]
#[allow(unused)]
pub enum ErrorCode {
// Defined by JSON RPC
ParseError = -32700,
InvalidRequest = -32600,
MethodNotFound = -32601,
InvalidParams = -32602,
InternalError = -32603,
ServerErrorStart = -32099,
ServerErrorEnd = -32000,
ServerNotInitialized = -32002,
UnknownErrorCode = -32001,
// Defined by protocol
ExpansionError = -32900,
} }
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Notification { pub enum ErrorCode {
pub method: String, ServerErrorEnd,
pub params: serde_json::Value, ExpansionError,
} }
impl Message { pub trait Message: Sized + Serialize + DeserializeOwned {
pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> { fn read(r: &mut impl BufRead) -> io::Result<Option<Self>> {
let text = match read_msg_text(r)? { let text = match read_json(r)? {
None => return Ok(None), None => return Ok(None),
Some(text) => text, Some(text) => text,
}; };
let msg = serde_json::from_str(&text)?; let msg = serde_json::from_str(&text)?;
Ok(Some(msg)) Ok(Some(msg))
} }
pub fn write(self, w: &mut impl Write) -> io::Result<()> { fn write(self, w: &mut impl Write) -> io::Result<()> {
#[derive(Serialize)] let text = serde_json::to_string(&self)?;
struct JsonRpc { write_json(w, &text)
jsonrpc: &'static str,
#[serde(flatten)]
msg: Message,
}
let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?;
write_msg_text(w, &text)
} }
} }
impl Response { impl Message for Request {}
pub fn new_ok<R: Serialize>(id: RequestId, result: R) -> Response { impl Message for Response {}
Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None }
}
pub fn new_err(id: RequestId, code: i32, message: String) -> Response {
let error = ResponseError { code, message, data: None };
Response { id, result: None, error: Some(error) }
}
}
impl Request { fn read_json(inp: &mut impl BufRead) -> io::Result<Option<String>> {
pub fn new<P: Serialize>(id: RequestId, method: String, params: P) -> Request {
Request { id, method, params: serde_json::to_value(params).unwrap() }
}
pub fn extract<P: DeserializeOwned>(self, method: &str) -> Result<(RequestId, P), Request> {
if self.method == method {
let params = serde_json::from_value(self.params).unwrap_or_else(|err| {
panic!("Invalid request\nMethod: {}\n error: {}", method, err)
});
Ok((self.id, params))
} else {
Err(self)
}
}
}
impl Notification {
pub fn new(method: String, params: impl Serialize) -> Notification {
Notification { method, params: serde_json::to_value(params).unwrap() }
}
pub fn extract<P: DeserializeOwned>(self, method: &str) -> Result<P, Notification> {
if self.method == method {
let params = serde_json::from_value(self.params).unwrap();
Ok(params)
} else {
Err(self)
}
}
}
fn read_msg_text(inp: &mut impl BufRead) -> io::Result<Option<String>> {
fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::InvalidData, error)
}
macro_rules! invalid_data {
($($tt:tt)*) => (invalid_data(format!($($tt)*)))
}
let mut size = None;
let mut buf = String::new(); let mut buf = String::new();
loop { if inp.read_line(&mut buf)? == 0 {
buf.clear(); return Ok(None);
if inp.read_line(&mut buf)? == 0 {
return Ok(None);
}
if !buf.ends_with("\r\n") {
return Err(invalid_data!("malformed header: {:?}", buf));
}
let buf = &buf[..buf.len() - 2];
if buf.is_empty() {
break;
}
let mut parts = buf.splitn(2, ": ");
let header_name = parts.next().unwrap();
let header_value =
parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
if header_name == "Content-Length" {
size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
}
} }
let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; // Remove ending '\n'
let mut buf = buf.into_bytes(); let buf = &buf[..buf.len() - 1];
buf.resize(size, 0); if buf.is_empty() {
inp.read_exact(&mut buf)?; return Ok(None);
let buf = String::from_utf8(buf).map_err(invalid_data)?; }
log::debug!("< {}", buf); Ok(Some(buf.to_string()))
Ok(Some(buf))
} }
fn write_msg_text(out: &mut impl Write, msg: &str) -> io::Result<()> { fn write_json(out: &mut impl Write, msg: &str) -> io::Result<()> {
log::debug!("> {}", msg); log::debug!("> {}", msg);
write!(out, "Content-Length: {}\r\n\r\n", msg.len())?;
out.write_all(msg.as_bytes())?; out.write_all(msg.as_bytes())?;
out.write_all(b"\n")?;
out.flush()?; out.flush()?;
Ok(()) Ok(())
} }

View file

@ -3,11 +3,12 @@
use crossbeam_channel::{bounded, Receiver, Sender}; use crossbeam_channel::{bounded, Receiver, Sender};
use ra_tt::Subtree; use ra_tt::Subtree;
use crate::msg::{ErrorCode, Message, Request, Response, ResponseError}; use crate::msg::{ErrorCode, Request, Response, ResponseError, Message};
use crate::rpc::{ExpansionResult, ExpansionTask, ListMacrosResult, ListMacrosTask, ProcMacroKind}; use crate::rpc::{ExpansionResult, ExpansionTask, ListMacrosResult, ListMacrosTask, ProcMacroKind};
use io::{BufRead, BufReader}; use io::{BufRead, BufReader};
use std::{ use std::{
convert::{TryFrom, TryInto},
io::{self, Write}, io::{self, Write},
path::{Path, PathBuf}, path::{Path, PathBuf},
process::{Child, Command, Stdio}, process::{Child, Command, Stdio},
@ -26,7 +27,7 @@ pub(crate) struct ProcMacroProcessThread {
} }
enum Task { enum Task {
Request { req: Message, result_tx: Sender<Message> }, Request { req: Request, result_tx: Sender<Response> },
Close, Close,
} }
@ -96,7 +97,7 @@ impl ProcMacroProcessSrv {
) -> Result<Vec<(String, ProcMacroKind)>, ra_tt::ExpansionError> { ) -> Result<Vec<(String, ProcMacroKind)>, ra_tt::ExpansionError> {
let task = ListMacrosTask { lib: dylib_path.to_path_buf() }; let task = ListMacrosTask { lib: dylib_path.to_path_buf() };
let result: ListMacrosResult = self.send_task("list_macros", task)?; let result: ListMacrosResult = self.send_task(Request::ListMacro(task))?;
Ok(result.macros) Ok(result.macros)
} }
@ -113,26 +114,19 @@ impl ProcMacroProcessSrv {
lib: dylib_path.to_path_buf(), lib: dylib_path.to_path_buf(),
}; };
let result: ExpansionResult = self.send_task("custom_derive", task)?; let result: ExpansionResult = self.send_task(Request::ExpansionMacro(task))?;
Ok(result.expansion) Ok(result.expansion)
} }
pub fn send_task<'a, T, R>(&self, method: &str, task: T) -> Result<R, ra_tt::ExpansionError> pub fn send_task<R>(&self, req: Request) -> Result<R, ra_tt::ExpansionError>
where where
T: serde::Serialize, R: TryFrom<Response, Error = &'static str>,
R: serde::de::DeserializeOwned + Default,
{ {
let sender = match &self.inner { let sender = match &self.inner {
None => return Err(ra_tt::ExpansionError::Unknown("No sender is found.".to_string())), None => return Err(ra_tt::ExpansionError::Unknown("No sender is found.".to_string())),
Some(it) => it, Some(it) => it,
}; };
let msg = serde_json::to_value(task).unwrap();
// FIXME: use a proper request id
let id = 0;
let req = Request { id: id.into(), method: method.into(), params: msg };
let (result_tx, result_rx) = bounded(0); let (result_tx, result_rx) = bounded(0);
sender.send(Task::Request { req: req.into(), result_tx }).map_err(|err| { sender.send(Task::Request { req: req.into(), result_tx }).map_err(|err| {
@ -141,27 +135,18 @@ impl ProcMacroProcessSrv {
err err
)) ))
})?; })?;
let response = result_rx.recv().unwrap();
match response { let res = result_rx.recv().unwrap();
Message::Request(_) => { match res {
return Err(ra_tt::ExpansionError::Unknown( Response::Error(err) => {
"Return request from ra_proc_srv".into(), return Err(ra_tt::ExpansionError::ExpansionError(err.message));
}
_ => Ok(res.try_into().map_err(|err| {
ra_tt::ExpansionError::Unknown(format!(
"Fail to get response, reason : {:#?} ",
err
)) ))
} })?),
Message::Response(res) => {
if let Some(err) = res.error {
return Err(ra_tt::ExpansionError::ExpansionError(err.message));
}
match res.result {
None => Ok(R::default()),
Some(res) => {
let result: R = serde_json::from_value(res)
.map_err(|err| ra_tt::ExpansionError::JsonError(err.to_string()))?;
Ok(result)
}
}
}
} }
} }
} }
@ -183,18 +168,13 @@ fn client_loop(task_rx: Receiver<Task>, mut process: Process) {
Task::Close => break, Task::Close => break,
}; };
let res = match send_message(&mut stdin, &mut stdout, req) { let res = match send_request(&mut stdin, &mut stdout, req) {
Ok(res) => res, Ok(res) => res,
Err(_err) => { Err(_err) => {
let res = Response { let res = Response::Error(ResponseError {
id: 0.into(), code: ErrorCode::ServerErrorEnd,
result: None, message: "Server closed".into(),
error: Some(ResponseError { });
code: ErrorCode::ServerErrorEnd as i32,
message: "Server closed".into(),
data: None,
}),
};
if result_tx.send(res.into()).is_err() { if result_tx.send(res.into()).is_err() {
break; break;
} }
@ -222,11 +202,11 @@ fn client_loop(task_rx: Receiver<Task>, mut process: Process) {
let _ = process.child.kill(); let _ = process.child.kill();
} }
fn send_message( fn send_request(
mut writer: &mut impl Write, mut writer: &mut impl Write,
mut reader: &mut impl BufRead, mut reader: &mut impl BufRead,
msg: Message, req: Request,
) -> Result<Option<Message>, io::Error> { ) -> Result<Option<Response>, io::Error> {
msg.write(&mut writer)?; req.write(&mut writer)?;
Ok(Message::read(&mut reader)?) Ok(Response::read(&mut reader)?)
} }

View file

@ -1,4 +1,10 @@
//! Data struture serialization related stuffs for RPC //! Data struture serialization related stuffs for RPC
//!
//! Define all necessary rpc serialization data structure,
//! which include ra_tt related data and some task messages.
//! Although adding Serialize and Deserialize trait to ra_tt directly seem to be much easier,
//! we deliberately duplicate the ra_tt struct with #[serde(with = "XXDef")]
//! for separation of code responsibility.
use ra_tt::{ use ra_tt::{
Delimiter, DelimiterKind, Ident, Leaf, Literal, Punct, SmolStr, Spacing, Subtree, TokenId, Delimiter, DelimiterKind, Ident, Leaf, Literal, Punct, SmolStr, Spacing, Subtree, TokenId,