remove channels with callbacks in proc-macro-srv

This commit is contained in:
bit-aloo 2025-12-23 11:59:56 +05:30
parent 336f025424
commit 1f64a69249
No known key found for this signature in database
7 changed files with 98 additions and 259 deletions

View file

@ -1,7 +1,6 @@
//! The main loop of the proc-macro server.
use std::io;
use crossbeam_channel::unbounded;
use proc_macro_api::{
Codec,
bidirectional_protocol::msg as bidirectional,
@ -82,6 +81,7 @@ fn run_new<C: Codec>() -> io::Result<()> {
}
bidirectional::Request::ApiVersionCheck {} => {
// bidirectional::Response::ApiVersionCheck(CURRENT_API_VERSION).write::<_, C>(stdout)
send_response::<_, C>(
&mut stdout,
bidirectional::Response::ApiVersionCheck(CURRENT_API_VERSION),
@ -160,6 +160,7 @@ fn handle_expand_id<W: std::io::Write, C: Codec>(
def_site,
call_site,
mixed_site,
None,
)
.map(|it| {
legacy::FlatTree::from_tokenstream_raw::<SpanTrans>(it, call_site, CURRENT_API_VERSION)
@ -169,7 +170,7 @@ fn handle_expand_id<W: std::io::Write, C: Codec>(
send_response::<_, C>(stdout, bidirectional::Response::ExpandMacro(res))
}
fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>(
fn handle_expand_ra<W: io::Write, R: io::BufRead, C: Codec>(
srv: &proc_macro_srv::ProcMacroSrv<'_>,
stdin: &mut R,
stdout: &mut W,
@ -185,74 +186,69 @@ fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>(
macro_body,
macro_name,
attributes,
has_global_spans:
bidirectional::ExpnGlobals { serialize: _, def_site, call_site, mixed_site },
has_global_spans: bidirectional::ExpnGlobals { def_site, call_site, mixed_site, .. },
span_data_table,
},
} = task;
let mut span_data_table = legacy::deserialize_span_data_index_map(&span_data_table);
let def_site_span = span_data_table[def_site];
let call_site_span = span_data_table[call_site];
let mixed_site_span = span_data_table[mixed_site];
let def_site = span_data_table[def_site];
let call_site = span_data_table[call_site];
let mixed_site = span_data_table[mixed_site];
let macro_body_ts =
let macro_body =
macro_body.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| {
srv.join_spans(a, b).unwrap_or(b)
});
let attributes_ts = attributes.map(|it| {
let attributes = attributes.map(|it| {
it.to_tokenstream_resolved(CURRENT_API_VERSION, &span_data_table, |a, b| {
srv.join_spans(a, b).unwrap_or(b)
})
});
let (subreq_tx, subreq_rx) = unbounded::<proc_macro_srv::SubRequest>();
let (subresp_tx, subresp_rx) = unbounded::<proc_macro_srv::SubResponse>();
let (subreq_tx, subreq_rx) = crossbeam_channel::unbounded();
let (subresp_tx, subresp_rx) = crossbeam_channel::unbounded();
let (result_tx, result_rx) = crossbeam_channel::bounded(1);
std::thread::scope(|scope| {
let srv_ref = &srv;
scope.spawn(|| {
let callback = Box::new(move |req: proc_macro_srv::SubRequest| {
subreq_tx.send(req).unwrap();
subresp_rx.recv().unwrap()
});
scope.spawn({
let lib = lib.clone();
let env = env.clone();
let current_dir = current_dir.clone();
let macro_name = macro_name.clone();
move || {
let res = srv_ref
.expand_with_channels(
lib,
&env,
current_dir,
&macro_name,
macro_body_ts,
attributes_ts,
def_site_span,
call_site_span,
mixed_site_span,
subresp_rx,
subreq_tx,
let res = srv
.expand(
lib,
&env,
current_dir,
&macro_name,
macro_body,
attributes,
def_site,
call_site,
mixed_site,
Some(callback),
)
.map(|it| {
(
legacy::FlatTree::from_tokenstream(
it,
CURRENT_API_VERSION,
call_site,
&mut span_data_table,
),
legacy::serialize_span_data_index_map(&span_data_table),
)
.map(|it| {
(
legacy::FlatTree::from_tokenstream(
it,
CURRENT_API_VERSION,
call_site_span,
&mut span_data_table,
),
legacy::serialize_span_data_index_map(&span_data_table),
)
})
.map(|(tree, span_data_table)| bidirectional::ExpandMacroExtended {
tree,
span_data_table,
})
.map_err(|e| e.into_string().unwrap_or_default())
.map_err(legacy::PanicMessage);
let _ = result_tx.send(res);
}
})
.map(|(tree, span_data_table)| bidirectional::ExpandMacroExtended {
tree,
span_data_table,
})
.map_err(|e| legacy::PanicMessage(e.into_string().unwrap_or_default()));
let _ = result_tx.send(res);
});
loop {
@ -264,31 +260,26 @@ fn handle_expand_ra<W: std::io::Write, R: std::io::BufRead, C: Codec>(
let subreq = match subreq_rx.recv() {
Ok(r) => r,
Err(_) => {
break;
}
Err(_) => break,
};
send_subrequest::<_, C>(stdout, from_srv_req(subreq)).unwrap();
let api_req = from_srv_req(subreq);
bidirectional::BidirectionalMessage::SubRequest(api_req).write::<_, C>(stdout).unwrap();
let resp_opt = bidirectional::BidirectionalMessage::read::<_, C>(stdin, buf).unwrap();
let resp = match resp_opt {
Some(env) => env,
None => {
break;
}
};
let resp = bidirectional::BidirectionalMessage::read::<_, C>(stdin, buf)
.unwrap()
.expect("client closed connection");
match resp {
bidirectional::BidirectionalMessage::SubResponse(subresp) => {
let _ = subresp_tx.send(from_client_res(subresp));
}
_ => {
break;
bidirectional::BidirectionalMessage::SubResponse(api_resp) => {
let srv_resp = from_client_res(api_resp);
subresp_tx.send(srv_resp).unwrap();
}
other => panic!("expected SubResponse, got {other:?}"),
}
}
});
Ok(())
}
@ -356,6 +347,7 @@ fn run_<C: Codec>() -> io::Result<()> {
def_site,
call_site,
mixed_site,
None,
)
.map(|it| {
legacy::FlatTree::from_tokenstream_raw::<SpanTrans>(
@ -397,6 +389,7 @@ fn run_<C: Codec>() -> io::Result<()> {
def_site,
call_site,
mixed_site,
None,
)
.map(|it| {
(
@ -455,11 +448,3 @@ fn send_response<W: std::io::Write, C: Codec>(
let resp = bidirectional::BidirectionalMessage::Response(resp);
resp.write::<W, C>(stdout)
}
fn send_subrequest<W: std::io::Write, C: Codec>(
stdout: &mut W,
resp: bidirectional::SubRequest,
) -> io::Result<()> {
let resp = bidirectional::BidirectionalMessage::SubRequest(resp);
resp.write::<W, C>(stdout)
}

View file

@ -12,7 +12,7 @@ use object::Object;
use paths::{Utf8Path, Utf8PathBuf};
use crate::{
PanicMessage, ProcMacroKind, ProcMacroSrvSpan, dylib::proc_macros::ProcMacros,
PanicMessage, ProcMacroKind, ProcMacroSrvSpan, SubCallback, dylib::proc_macros::ProcMacros,
token_stream::TokenStream,
};
@ -45,39 +45,14 @@ impl Expander {
def_site: S,
call_site: S,
mixed_site: S,
callback: Option<SubCallback>,
) -> Result<TokenStream<S>, PanicMessage>
where
<S::Server as bridge::server::Types>::TokenStream: Default,
{
self.inner
.proc_macros
.expand(macro_name, macro_body, attribute, def_site, call_site, mixed_site)
}
pub(crate) fn expand_with_channels<S: ProcMacroSrvSpan>(
&self,
macro_name: &str,
macro_body: crate::token_stream::TokenStream<S>,
attribute: Option<crate::token_stream::TokenStream<S>>,
def_site: S,
call_site: S,
mixed_site: S,
cli_to_server: crossbeam_channel::Receiver<crate::SubResponse>,
server_to_cli: crossbeam_channel::Sender<crate::SubRequest>,
) -> Result<crate::token_stream::TokenStream<S>, crate::PanicMessage>
where
<S::Server as proc_macro::bridge::server::Types>::TokenStream: Default,
{
self.inner.proc_macros.expand_with_channels(
macro_name,
macro_body,
attribute,
def_site,
call_site,
mixed_site,
cli_to_server,
server_to_cli,
)
.expand(macro_name, macro_body, attribute, def_site, call_site, mixed_site, callback)
}
pub(crate) fn list_macros(&self) -> impl Iterator<Item = (&str, ProcMacroKind)> {

View file

@ -1,8 +1,7 @@
//! Proc macro ABI
use crate::{ProcMacroKind, ProcMacroSrvSpan, SubCallback, token_stream::TokenStream};
use proc_macro::bridge;
use crate::{ProcMacroKind, ProcMacroSrvSpan, token_stream::TokenStream};
#[repr(transparent)]
pub(crate) struct ProcMacros([bridge::client::ProcMacro]);
@ -21,6 +20,7 @@ impl ProcMacros {
def_site: S,
call_site: S,
mixed_site: S,
callback: Option<SubCallback>,
) -> Result<TokenStream<S>, crate::PanicMessage> {
let parsed_attributes = attribute.unwrap_or_default();
@ -31,7 +31,7 @@ impl ProcMacros {
{
let res = client.run(
&bridge::server::SameThread,
S::make_server(call_site, def_site, mixed_site, None, None),
S::make_server(call_site, def_site, mixed_site, callback),
macro_body,
cfg!(debug_assertions),
);
@ -40,7 +40,7 @@ impl ProcMacros {
bridge::client::ProcMacro::Bang { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
S::make_server(call_site, def_site, mixed_site, None, None),
S::make_server(call_site, def_site, mixed_site, callback),
macro_body,
cfg!(debug_assertions),
);
@ -49,77 +49,7 @@ impl ProcMacros {
bridge::client::ProcMacro::Attr { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
S::make_server(call_site, def_site, mixed_site, None, None),
parsed_attributes,
macro_body,
cfg!(debug_assertions),
);
return res.map_err(crate::PanicMessage::from);
}
_ => continue,
}
}
Err(bridge::PanicMessage::String(format!("proc-macro `{macro_name}` is missing")).into())
}
pub(crate) fn expand_with_channels<S: ProcMacroSrvSpan>(
&self,
macro_name: &str,
macro_body: TokenStream<S>,
attribute: Option<TokenStream<S>>,
def_site: S,
call_site: S,
mixed_site: S,
cli_to_server: crossbeam_channel::Receiver<crate::SubResponse>,
server_to_cli: crossbeam_channel::Sender<crate::SubRequest>,
) -> Result<TokenStream<S>, crate::PanicMessage> {
let parsed_attributes = attribute.unwrap_or_default();
for proc_macro in &self.0 {
match proc_macro {
bridge::client::ProcMacro::CustomDerive { trait_name, client, .. }
if *trait_name == macro_name =>
{
let res = client.run(
&bridge::server::SameThread,
S::make_server(
call_site,
def_site,
mixed_site,
Some(cli_to_server),
Some(server_to_cli),
),
macro_body,
cfg!(debug_assertions),
);
return res.map_err(crate::PanicMessage::from);
}
bridge::client::ProcMacro::Bang { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
S::make_server(
call_site,
def_site,
mixed_site,
Some(cli_to_server),
Some(server_to_cli),
),
macro_body,
cfg!(debug_assertions),
);
return res.map_err(crate::PanicMessage::from);
}
bridge::client::ProcMacro::Attr { name, client } if *name == macro_name => {
let res = client.run(
&bridge::server::SameThread,
S::make_server(
call_site,
def_site,
mixed_site,
Some(cli_to_server),
Some(server_to_cli),
),
S::make_server(call_site, def_site, mixed_site, callback),
parsed_attributes,
macro_body,
cfg!(debug_assertions),

View file

@ -91,6 +91,8 @@ impl<'env> ProcMacroSrv<'env> {
}
}
pub type SubCallback = Box<dyn Fn(SubRequest) -> SubResponse + Send + Sync + 'static>;
pub enum SubRequest {
SourceText { file_id: EditionedFileId, start: u32, end: u32 },
}
@ -113,6 +115,7 @@ impl ProcMacroSrv<'_> {
def_site: S,
call_site: S,
mixed_site: S,
callback: Option<SubCallback>,
) -> Result<token_stream::TokenStream<S>, PanicMessage> {
let snapped_env = self.env;
let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
@ -128,54 +131,9 @@ impl ProcMacroSrv<'_> {
.stack_size(EXPANDER_STACK_SIZE)
.name(macro_name.to_owned())
.spawn_scoped(s, move || {
expander
.expand(macro_name, macro_body, attribute, def_site, call_site, mixed_site)
});
match thread.unwrap().join() {
Ok(res) => res,
Err(e) => std::panic::resume_unwind(e),
}
});
prev_env.rollback();
result
}
pub fn expand_with_channels<S: ProcMacroSrvSpan>(
&self,
lib: impl AsRef<Utf8Path>,
env: &[(String, String)],
current_dir: Option<impl AsRef<Path>>,
macro_name: &str,
macro_body: token_stream::TokenStream<S>,
attribute: Option<token_stream::TokenStream<S>>,
def_site: S,
call_site: S,
mixed_site: S,
cli_to_server: crossbeam_channel::Receiver<SubResponse>,
server_to_cli: crossbeam_channel::Sender<SubRequest>,
) -> Result<token_stream::TokenStream<S>, PanicMessage> {
let snapped_env = self.env;
let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
message: Some(format!("failed to load macro: {err}")),
})?;
let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref));
let result = thread::scope(|s| {
let thread = thread::Builder::new()
.stack_size(EXPANDER_STACK_SIZE)
.name(macro_name.to_owned())
.spawn_scoped(s, move || {
expander.expand_with_channels(
macro_name,
macro_body,
attribute,
def_site,
call_site,
mixed_site,
cli_to_server,
server_to_cli,
expander.expand(
macro_name, macro_body, attribute, def_site, call_site, mixed_site,
callback,
)
});
match thread.unwrap().join() {
@ -229,8 +187,7 @@ pub trait ProcMacroSrvSpan: Copy + Send + Sync {
call_site: Self,
def_site: Self,
mixed_site: Self,
cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
callback: Option<SubCallback>,
) -> Self::Server;
}
@ -241,15 +198,13 @@ impl ProcMacroSrvSpan for SpanId {
call_site: Self,
def_site: Self,
mixed_site: Self,
cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
callback: Option<SubCallback>,
) -> Self::Server {
Self::Server {
call_site,
def_site,
mixed_site,
cli_to_server,
server_to_cli,
callback,
tracked_env_vars: Default::default(),
tracked_paths: Default::default(),
}
@ -262,17 +217,15 @@ impl ProcMacroSrvSpan for Span {
call_site: Self,
def_site: Self,
mixed_site: Self,
cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
callback: Option<SubCallback>,
) -> Self::Server {
Self::Server {
call_site,
def_site,
mixed_site,
callback,
tracked_env_vars: Default::default(),
tracked_paths: Default::default(),
cli_to_server,
server_to_cli,
}
}
}

View file

@ -14,7 +14,7 @@ use proc_macro::bridge::server;
use span::{FIXUP_ERASED_FILE_AST_ID_MARKER, Span, TextRange, TextSize};
use crate::{
SubRequest, SubResponse,
SubCallback, SubRequest, SubResponse,
bridge::{Diagnostic, ExpnGlobals, Literal, TokenTree},
server_impl::literal_from_str,
};
@ -29,8 +29,7 @@ pub struct RaSpanServer {
pub call_site: Span,
pub def_site: Span,
pub mixed_site: Span,
pub cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
pub server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
pub callback: Option<SubCallback>,
}
impl server::Types for RaSpanServer {
@ -153,21 +152,17 @@ impl server::Span for RaSpanServer {
/// See PR:
/// https://github.com/rust-lang/rust/pull/55780
fn source_text(&mut self, span: Self::Span) -> Option<String> {
// FIXME requires db, needs special handling wrt fixup spans
if self.server_to_cli.is_some() && self.cli_to_server.is_some() {
let file_id = span.anchor.file_id;
let start: u32 = span.range.start().into();
let end: u32 = span.range.end().into();
let _ = self.server_to_cli.clone().unwrap().send(SubRequest::SourceText {
file_id,
start,
end,
});
match self.cli_to_server.as_ref()?.recv().ok()? {
SubResponse::SourceTextResult { text } => text,
}
} else {
None
let file_id = span.anchor.file_id;
let start: u32 = span.range.start().into();
let end: u32 = span.range.end().into();
let req = SubRequest::SourceText { file_id, start, end };
let cb = self.callback.as_mut()?;
let response = cb(req);
match response {
SubResponse::SourceTextResult { text } => text,
}
}

View file

@ -9,7 +9,7 @@ use intern::Symbol;
use proc_macro::bridge::server;
use crate::{
SubRequest, SubResponse,
SubCallback,
bridge::{Diagnostic, ExpnGlobals, Literal, TokenTree},
server_impl::literal_from_str,
};
@ -35,8 +35,7 @@ pub struct SpanIdServer {
pub call_site: Span,
pub def_site: Span,
pub mixed_site: Span,
pub cli_to_server: Option<crossbeam_channel::Receiver<SubResponse>>,
pub server_to_cli: Option<crossbeam_channel::Sender<SubRequest>>,
pub callback: Option<SubCallback>,
}
impl server::Types for SpanIdServer {

View file

@ -59,8 +59,9 @@ fn assert_expand_impl(
let input_ts_string = format!("{input_ts:?}");
let attr_ts_string = attr_ts.as_ref().map(|it| format!("{it:?}"));
let res =
expander.expand(macro_name, input_ts, attr_ts, def_site, call_site, mixed_site).unwrap();
let res = expander
.expand(macro_name, input_ts, attr_ts, def_site, call_site, mixed_site, None)
.unwrap();
expect.assert_eq(&format!(
"{input_ts_string}{}{}{}",
if attr_ts_string.is_some() { "\n\n" } else { "" },
@ -91,7 +92,8 @@ fn assert_expand_impl(
let fixture_string = format!("{fixture:?}");
let attr_string = attr.as_ref().map(|it| format!("{it:?}"));
let res = expander.expand(macro_name, fixture, attr, def_site, call_site, mixed_site).unwrap();
let res =
expander.expand(macro_name, fixture, attr, def_site, call_site, mixed_site, None).unwrap();
expect_spanned.assert_eq(&format!(
"{fixture_string}{}{}{}",
if attr_string.is_some() { "\n\n" } else { "" },