Reenable more nodes

This commit is contained in:
Dennis Kobert 2023-05-26 21:10:58 +02:00 committed by Keavon Chambers
parent 9ab8ba18a4
commit 699d9add7f
7 changed files with 225 additions and 223 deletions

View file

@ -8,7 +8,7 @@ use syn::{
#[proc_macro_attribute]
pub fn node_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut imp = node_impl_impl(attr.clone(), item.clone());
let mut imp = node_impl_proxy(attr.clone(), item.clone());
let new = node_new_impl(attr, item);
imp.extend(new);
imp
@ -18,6 +18,11 @@ pub fn node_new(attr: TokenStream, item: TokenStream) -> TokenStream {
node_new_impl(attr, item)
}
#[proc_macro_attribute]
pub fn node_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
node_impl_proxy(attr, item)
}
fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let node = parse_macro_input!(attr as syn::PathSegment);
@ -78,12 +83,23 @@ fn args(node: &syn::PathSegment) -> Vec<Type> {
}
}
#[proc_macro_attribute]
pub fn node_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
node_impl_impl(attr, item)
fn node_impl_proxy(attr: TokenStream, item: TokenStream) -> TokenStream {
let fn_item = item.clone();
let function = parse_macro_input!(fn_item as ItemFn);
let mut sync_input = if function.sig.asyncness.is_some() {
node_impl_impl(attr, item, Asyncness::AllAsync)
} else {
node_impl_impl(attr, item, Asyncness::Sync)
};
sync_input
}
enum Asyncness {
Sync,
AsyncOut,
AllAsync,
}
fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> TokenStream {
//let node_name = parse_macro_input!(attr as Ident);
let node = parse_macro_input!(attr as syn::PathSegment);
@ -93,7 +109,12 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let node_name = &node.ident;
let mut args = args(node);
let asyncness = function.sig.asyncness.is_some();
let async_out = match asyncness {
Asyncness::Sync => false,
Asyncness::AsyncOut | Asyncness::AllAsync => true,
};
let async_in = matches!(asyncness, Asyncness::AllAsync);
let body = &function.block;
let mut type_generics = function.sig.generics.params.clone();
let mut where_clause = function.sig.generics.where_clause.clone().unwrap_or(WhereClause {
@ -101,6 +122,12 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
predicates: Default::default(),
});
type_generics.iter_mut().for_each(|x| {
if let GenericParam::Type(t) = x {
t.bounds.insert(0, TypeParamBound::Lifetime(Lifetime::new("'input", Span::call_site())));
}
});
let (primary_input, parameter_inputs, parameter_pat_ident_patterns) = parse_inputs(&function);
let primary_input_ty = &primary_input.ty;
let Pat::Ident(PatIdent{ident: primary_input_ident, mutability: primary_input_mutability,..} ) =&*primary_input.pat else {
@ -116,70 +143,70 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
quote::quote!(())
};
let struct_generics = (0..parameter_pat_ident_patterns.len())
.map(|x| {
let ident = format_ident!("S{x}");
ident
})
.collect::<Punctuated<_, Comma>>();
let struct_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("S{x}")).collect::<Vec<_>>();
let future_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("F{x}")).collect::<Vec<_>>();
let future_types = future_generics.iter().map(|x| Type::Verbatim(x.to_token_stream())).collect::<Vec<_>>();
let parameter_types = parameter_inputs.iter().map(|x| *x.ty.clone()).collect::<Vec<Type>>();
for ident in struct_generics.iter() {
args.push(Type::Verbatim(quote::quote!(#ident)));
}
// Generics are simply `S0` through to `Sn-1` where n is the number of secondary inputs
let node_generics = node_generics(&struct_generics);
type_generics.iter_mut().for_each(|x| {
if let GenericParam::Type(t) = x {
t.bounds.insert(0, TypeParamBound::Lifetime(Lifetime::new("'input", Span::call_site())));
}
});
let generics = type_generics.into_iter().chain(node_generics.iter().cloned()).collect::<Punctuated<_, Comma>>();
// Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function
let extra_where_clause = input_node_bounds(parameter_inputs, node_generics);
where_clause.predicates.extend(extra_where_clause);
let node_generics = construct_node_generics(&struct_generics);
let future_generic_params = construct_node_generics(&future_generics);
let node_impl = if asyncness {
quote::quote! {
#[automatically_derived]
impl <'input, #generics> Node<'input, #primary_input_ty> for #node_name<#(#args),*>
#where_clause
{
type Output = core::pin::Pin<Box<dyn core::future::Future< Output = #output> + 'input>>;
#[inline]
fn eval(&'input self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
#(
let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(());
)*
Box::pin(async move {#body})
}
}
}
let generics = if async_in {
type_generics
.into_iter()
.chain(node_generics.iter().cloned())
.chain(future_generic_params.iter().cloned())
.collect::<Punctuated<_, Comma>>()
} else {
let token_stream = quote::quote! {
#[automatically_derived]
impl <'input, #generics> Node<'input, #primary_input_ty> for #node_name<#(#args),*>
#where_clause
{
type Output = #output;
#[inline]
fn eval(&'input self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
#(
let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(());
)*
#body
}
}
};
token_stream
type_generics.into_iter().chain(node_generics.iter().cloned()).collect::<Punctuated<_, Comma>>()
};
// Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function
let node_bounds = if async_in {
let mut node_bounds = input_node_bounds(future_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>});
let future_bounds = input_node_bounds(parameter_types, future_generic_params, |ty| quote! { core::future::Future<Output = #ty>});
node_bounds.extend(future_bounds);
node_bounds
} else {
input_node_bounds(parameter_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>})
};
where_clause.predicates.extend(node_bounds);
let output = if async_out {
quote::quote!(core::pin::Pin<Box<dyn core::future::Future< Output = #output> + 'input>>)
} else {
quote::quote!(#output)
};
let parameters = if matches!(asyncness, Asyncness::AllAsync) {
quote::quote!(#(let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(()).await;)*)
} else {
quote::quote!(#(let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(());)*)
};
let mut body_with_inputs = quote::quote!(
#parameters
{#body}
);
if async_out {
body_with_inputs = quote::quote!(Box::pin(async move { #body_with_inputs }));
}
quote::quote! {
#node_impl
#[automatically_derived]
impl <'input, #generics> Node<'input, #primary_input_ty> for #node_name<#(#args),*>
#where_clause
{
type Output = #output;
#[inline]
fn eval(&'input self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
#body_with_inputs
}
}
}
.into()
}
@ -202,8 +229,8 @@ fn parse_inputs(function: &ItemFn) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&P
(primary_input, parameter_inputs, parameter_pat_ident_patterns)
}
fn node_generics(struct_generics: &Punctuated<Ident, Comma>) -> Punctuated<GenericParam, Comma> {
let node_generics = struct_generics
fn construct_node_generics(struct_generics: &[Ident]) -> Vec<GenericParam> {
struct_generics
.iter()
.cloned()
.map(|ident| {
@ -216,18 +243,17 @@ fn node_generics(struct_generics: &Punctuated<Ident, Comma>) -> Punctuated<Gener
default: None,
})
})
.collect::<Punctuated<_, Comma>>();
node_generics
.collect()
}
fn input_node_bounds(parameter_inputs: Vec<&syn::PatType>, node_generics: Punctuated<GenericParam, Comma>) -> Vec<WherePredicate> {
let extra_where_clause = parameter_inputs
fn input_node_bounds(parameter_inputs: Vec<Type>, node_generics: Vec<GenericParam>, trait_bound: impl Fn(Type) -> proc_macro2::TokenStream) -> Vec<WherePredicate> {
parameter_inputs
.iter()
.zip(&node_generics)
.map(|(ty, name)| {
let ty = &ty.ty;
let GenericParam::Type(generic_ty) = name else { panic!("Expected type generic."); };
let ident = &generic_ty.ident;
let bound = trait_bound(ty.clone());
WherePredicate::Type(PredicateType {
lifetimes: None,
bounded_ty: Type::Verbatim(ident.to_token_stream()),
@ -236,10 +262,9 @@ fn input_node_bounds(parameter_inputs: Vec<&syn::PatType>, node_generics: Punctu
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None, //syn::parse_quote!(for<'any_input>),
path: syn::parse_quote!(Node<'input, (), Output = #ty>),
path: syn::parse_quote!(#bound),
})]),
})
})
.collect::<Vec<_>>();
extra_where_clause
.collect()
}