Reenable more nodes

This commit is contained in:
Dennis Kobert 2023-05-26 21:10:58 +02:00 committed by Keavon Chambers
parent 9ab8ba18a4
commit 699d9add7f
7 changed files with 225 additions and 223 deletions

View file

@ -121,6 +121,13 @@ impl<'i, I: 'i, O: 'i> Node<'i, I> for Pin<Box<dyn for<'a> Node<'a, I, Output =
(**self).eval(input)
}
}
impl<'i, I: 'i, O: 'i> Node<'i, I> for Pin<&'i (dyn NodeIO<'i, I, Output = O> + 'i)> {
type Output = O;
fn eval(&'i self, input: I) -> Self::Output {
(**self).eval(input)
}
}
#[cfg(feature = "alloc")]
pub use crate::raster::image::{EditorApi, ExtractImageFrame};

View file

@ -22,7 +22,7 @@ pub struct AddParameterNode<Second> {
second: Second,
}
#[node_macro::node_fn(AddParameterNode)]
#[node_macro::node_new(AddParameterNode)]
fn add_parameter<U, T>(first: U, second: T) -> <U as Add<T>>::Output
where
U: Add<T>,
@ -30,6 +30,24 @@ where
first + second
}
#[automatically_derived]
impl<'input, U: 'input, T: 'input, S0: 'input> Node<'input, U> for AddParameterNode<S0>
where
U: Add<T>,
S0: Node<'input, (), Output = T>,
{
type Output = <U as Add<T>>::Output;
#[inline]
fn eval(&'input self, first: U) -> Self::Output {
let second = self.second.eval(());
{
{
first + second
}
}
}
}
pub struct MulParameterNode<Second> {
second: Second,
}

View file

@ -196,37 +196,6 @@ impl<'n, I, O> DowncastBothNode<'n, I, O> {
/// Boxes the input and downcasts the output.
/// Wraps around a node taking Box<dyn DynAny> and returning Box<dyn DynAny>
#[derive(Clone, Copy)]
pub struct DowncastBothSyncNode<'a, I, O> {
node: TypeErasedPinnedRef<'a>,
_i: PhantomData<I>,
_o: PhantomData<O>,
}
impl<'n: 'input, 'input, O: 'input + StaticType, I: 'input + StaticType> Node<'input, I> for DowncastBothSyncNode<'n, I, O> {
type Output = O;
#[inline]
fn eval(&'input self, input: I) -> Self::Output {
{
let input = Box::new(input);
let future = self.node.eval(input);
let value = EvalSyncNode::new().eval(future);
let out = dyn_any::downcast(value).unwrap_or_else(|e| panic!("DowncastBothNode Input {e}"));
*out
}
}
}
impl<'n, I, O> DowncastBothSyncNode<'n, I, O> {
pub const fn new(node: TypeErasedPinnedRef<'n>) -> Self {
Self {
node,
_i: core::marker::PhantomData,
_o: core::marker::PhantomData,
}
}
}
/// Boxes the input and downcasts the output.
/// Wraps around a node taking Box<dyn DynAny> and returning Box<dyn DynAny>
#[derive(Clone, Copy)]
pub struct DowncastBothRefNode<'a, I, O> {
node: TypeErasedPinnedRef<'a>,
_i: PhantomData<(I, O)>,

View file

@ -1,3 +1,5 @@
use futures::Future;
use graph_craft::proto::DynFuture;
use graphene_core::Node;
use std::hash::{Hash, Hasher};
@ -15,25 +17,30 @@ pub struct CacheNode<T, CachedNode> {
cache: boxcar::Vec<(u64, T, AtomicBool)>,
node: CachedNode,
}
impl<'i, T: 'i, I: 'i + Hash, CachedNode: 'i> Node<'i, I> for CacheNode<T, CachedNode>
impl<'i, T: 'i + Clone, I: 'i + Hash, CachedNode: 'i> Node<'i, I> for CacheNode<T, CachedNode>
where
CachedNode: for<'any_input> Node<'any_input, I, Output = T>,
CachedNode: for<'any_input> Node<'any_input, I>,
for<'a> <CachedNode as Node<'a, I>>::Output: core::future::Future<Output = T> + 'a,
{
type Output = &'i T;
// TODO: This should return a reference to the cached cached_value
// but that requires a lot of lifetime magic <- This was suggested by copilot but is pretty acurate xD
type Output = Pin<Box<dyn Future<Output = T> + 'i>>;
fn eval(&'i self, input: I) -> Self::Output {
let mut hasher = Xxh3::new();
input.hash(&mut hasher);
let hash = hasher.finish();
Box::pin(async move {
let mut hasher = Xxh3::new();
input.hash(&mut hasher);
let hash = hasher.finish();
if let Some((_, cached_value, keep)) = self.cache.iter().find(|(h, _, _)| *h == hash) {
keep.store(true, std::sync::atomic::Ordering::Relaxed);
cached_value
} else {
trace!("Cache miss");
let output = self.node.eval(input);
let index = self.cache.push((hash, output, AtomicBool::new(true)));
&self.cache[index].1
}
if let Some((_, cached_value, keep)) = self.cache.iter().find(|(h, _, _)| *h == hash) {
keep.store(true, std::sync::atomic::Ordering::Relaxed);
cached_value.clone()
} else {
trace!("Cache miss");
let output = self.node.eval(input).await;
let index = self.cache.push((hash, output, AtomicBool::new(true)));
self.cache[index].1.clone()
}
})
}
fn reset(mut self: Pin<&mut Self>) {
@ -151,11 +158,12 @@ pub struct RefNode<T, Let> {
let_node: Let,
_t: PhantomData<T>,
}
impl<'i, T: 'i, Let> Node<'i, ()> for RefNode<T, Let>
where
Let: for<'a> Node<'a, Option<T>, Output = &'a T>,
Let: for<'a> Node<'a, Option<T>>,
{
type Output = &'i T;
type Output = <Let as Node<'i, Option<T>>>::Output;
fn eval(&'i self, _: ()) -> Self::Output {
self.let_node.eval(None)
}

View file

@ -299,7 +299,7 @@ pub struct BlendImageNode<P, Background, MapFn> {
}
#[node_macro::node_fn(BlendImageNode<_P>)]
fn blend_image_node<_P: Alpha + Pixel + Debug, MapFn, Forground: Sample<Pixel = _P> + Transform>(foreground: Forground, background: ImageFrame<_P>, map_fn: &'input MapFn) -> ImageFrame<_P>
async fn blend_image_node<_P: Alpha + Pixel + Debug, MapFn, Forground: Sample<Pixel = _P> + Transform>(foreground: Forground, background: ImageFrame<_P>, map_fn: &'input MapFn) -> ImageFrame<_P>
where
MapFn: for<'any_input> Node<'any_input, (_P, _P), Output = _P> + 'input,
{

View file

@ -7,22 +7,21 @@ use once_cell::sync::Lazy;
use std::collections::HashMap;
use graphene_core::raster::color::Color;
use graphene_core::raster::*;
use graphene_core::structural::Then;
use graphene_core::value::{ClonedNode, ValueNode};
use graphene_core::value::{ClonedNode, CopiedNode, ValueNode};
use graphene_core::{fn_type, raster::*};
use graphene_core::{Node, NodeIO, NodeIOTypes};
use graphene_std::brush::*;
use graphene_std::raster::*;
use graphene_std::any::{ComposeTypeErased, DowncastBothNode, DowncastBothRefNode, DynAnyInRefNode, DynAnyNode, FutureWrapperNode, IntoTypeErasedNode, TypeErasedPinnedRef};
use graphene_std::any::{ComposeTypeErased, DowncastBothNode, DowncastBothRefNode, DynAnyInRefNode, DynAnyNode, DynAnyRefNode, FutureWrapperNode, IntoTypeErasedNode, TypeErasedPinnedRef};
use graphene_core::{Cow, NodeIdentifier, Type, TypeDescriptor};
use graph_craft::proto::{NodeConstructor, TypeErasedPinned};
use graphene_core::{concrete, generic, value_fn};
use graphene_std::http::EvalSyncNode;
use graphene_std::memo::LetNode;
use graphene_std::memo::{CacheNode, LetNode};
use graphene_std::raster::BlendImageTupleNode;
use dyn_any::StaticType;
@ -335,47 +334,25 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
raster_node!(graphene_core::raster::LevelsNode<_, _, _, _, _>, params: [f64, f64, f64, f64, f64]),
register_node!(graphene_std::image_segmentation::ImageSegmentationNode<_>, input: ImageFrame<Color>, params: [ImageFrame<Color>]),
register_node!(graphene_core::raster::IndexNode<_>, input: Vec<ImageFrame<Color>>, params: [u32]),
/*
vec![
(
NodeIdentifier::new("graphene_core::raster::BlendNode<_, _, _, _>"),
|args| {
Box::pin(async move {
let image: DowncastBothNode<(), ImageFrame<Color>> = DowncastBothNode::new(args[0]);
let blend_mode: DowncastBothNode<(), BlendMode> = DowncastBothNode::new(args[1]);
let opacity: DowncastBothNode<(), f64> = DowncastBothNode::new(args[2]);
let blend_node = graphene_core::raster::BlendNode::new(CopiedNode::new(blend_mode.eval(()).await), CopiedNode::new(opacity.eval(()).await));
let node = graphene_std::raster::BlendImageNode::new(image, ValueNode::new(blend_node));
let _ = &node as &dyn for<'i> Node<'i, ImageFrame<Color>, Output = ImageFrame<Color>>;
let node = FutureWrapperNode::new(node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
any.into_type_erased()
})
},
NodeIOTypes::new(
concrete!(ImageFrame<Color>),
concrete!(ImageFrame<Color>),
vec![value_fn!(ImageFrame<Color>), value_fn!(BlendMode), value_fn!(f64)],
),
vec![(
NodeIdentifier::new("graphene_core::raster::BlendNode<_, _, _, _>"),
|args| {
Box::pin(async move {
let image: DowncastBothNode<(), ImageFrame<Color>> = DowncastBothNode::new(args[0]);
let blend_mode: DowncastBothNode<(), BlendMode> = DowncastBothNode::new(args[1]);
let opacity: DowncastBothNode<(), f64> = DowncastBothNode::new(args[2]);
let blend_node = graphene_core::raster::BlendNode::new(CopiedNode::new(blend_mode.eval(()).await), CopiedNode::new(opacity.eval(()).await));
let node = graphene_std::raster::BlendImageNode::new(image, FutureWrapperNode::new(ValueNode::new(blend_node)));
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(
concrete!(ImageFrame<Color>),
concrete!(ImageFrame<Color>),
vec![value_fn!(ImageFrame<Color>), value_fn!(BlendMode), value_fn!(f64)],
),
(
NodeIdentifier::new("graphene_core::raster::EraseNode<_, _>"),
|args| {
Box::pin(async move {
let image: DowncastBothNode<(), ImageFrame<Color>> = DowncastBothNode::new(args[0]);
let opacity: DowncastBothNode<(), f64> = DowncastBothNode::new(args[1]);
let blend_node = graphene_std::brush::EraseNode::new(ClonedNode::new(opacity.eval(()).await));
let node = graphene_std::raster::BlendImageNode::new(image, ValueNode::new(blend_node));
let _ = &node as &dyn for<'i> Node<'i, ImageFrame<Color>, Output = ImageFrame<Color>>;
let node = FutureWrapperNode::new(node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
any.into_type_erased()
})
},
NodeIOTypes::new(concrete!(ImageFrame<Color>), concrete!(ImageFrame<Color>), vec![value_fn!(ImageFrame<Color>), value_fn!(f64)]),
),
],
*/
)],
raster_node!(graphene_core::raster::GrayscaleNode<_, _, _, _, _, _, _>, params: [Color, f64, f64, f64, f64, f64, f64]),
raster_node!(graphene_core::raster::HueSaturationNode<_, _, _>, params: [f64, f64, f64]),
raster_node!(graphene_core::raster::InvertRGBNode, params: []),
@ -389,35 +366,35 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
graphene_core::raster::SelectiveColorNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>,
params: [RelativeAbsolute, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64]
),
/*
vec![(
NodeIdentifier::new("graphene_core::raster::BrightnessContrastNode<_, _, _>"),
|args| {
use graphene_core::raster::brightness_contrast::*;
Box::pin(async move {
use graphene_core::raster::brightness_contrast::*;
let brightness: DowncastBothNode<(), f64> = DowncastBothNode::new(args[0]);
let brightness = ClonedNode::new(brightness.eval(()) as f32);
let contrast: DowncastBothNode<(), f64> = DowncastBothNode::new(args[1]);
let contrast = ClonedNode::new(contrast.eval(()) as f32);
let use_legacy: DowncastBothNode<(), bool> = DowncastBothNode::new(args[2]);
let brightness: DowncastBothNode<(), f64> = DowncastBothNode::new(args[0]);
let brightness = ClonedNode::new(brightness.eval(()).await as f32);
let contrast: DowncastBothNode<(), f64> = DowncastBothNode::new(args[1]);
let contrast = ClonedNode::new(contrast.eval(()).await as f32);
let use_legacy: DowncastBothNode<(), bool> = DowncastBothNode::new(args[2]);
if use_legacy.eval(()) {
let generate_brightness_contrast_legacy_mapper_node = GenerateBrightnessContrastLegacyMapperNode::new(brightness, contrast);
let map_image_frame_node = graphene_std::raster::MapImageNode::new(ValueNode::new(generate_brightness_contrast_legacy_mapper_node.eval(())));
let map_image_frame_node = FutureWrapperNode::new(map_image_frame_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node));
Box::pin(any)
} else {
let generate_brightness_contrast_mapper_node = GenerateBrightnessContrastMapperNode::new(brightness, contrast);
let map_image_frame_node = graphene_std::raster::MapImageNode::new(ValueNode::new(generate_brightness_contrast_mapper_node.eval(())));
let map_image_frame_node = FutureWrapperNode::new(map_image_frame_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node));
Box::pin(any)
}
if use_legacy.eval(()).await {
let generate_brightness_contrast_legacy_mapper_node = GenerateBrightnessContrastLegacyMapperNode::new(brightness, contrast);
let map_image_frame_node = graphene_std::raster::MapImageNode::new(ValueNode::new(generate_brightness_contrast_legacy_mapper_node.eval(())));
let map_image_frame_node = FutureWrapperNode::new(map_image_frame_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node));
Box::pin(any) as TypeErasedPinned
} else {
let generate_brightness_contrast_mapper_node = GenerateBrightnessContrastMapperNode::new(brightness, contrast);
let map_image_frame_node = graphene_std::raster::MapImageNode::new(ValueNode::new(generate_brightness_contrast_mapper_node.eval(())));
let map_image_frame_node = FutureWrapperNode::new(map_image_frame_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node));
Box::pin(any) as TypeErasedPinned
}
})
},
NodeIOTypes::new(concrete!(ImageFrame<Color>), concrete!(ImageFrame<Color>), vec![value_fn!(f64), value_fn!(f64), value_fn!(bool)]),
)],
*/
raster_node!(graphene_core::raster::OpacityNode<_>, params: [f64]),
raster_node!(graphene_core::raster::PosterizeNode<_>, params: [f64]),
raster_node!(graphene_core::raster::ExposureNode<_, _, _>, params: [f64, f64, f64]),
@ -496,30 +473,16 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
NodeIdentifier::new("graphene_std::memo::RefNode<_, _>"),
|args| {
Box::pin(async move {
let map_fn: DowncastBothRefNode<Option<graphene_core::EditorApi>, graphene_core::EditorApi> = DowncastBothRefNode::new(args[0]);
let map_fn = map_fn.then(EvalSyncNode::new());
let map_fn: DowncastBothNode<Option<graphene_core::EditorApi>, graphene_core::EditorApi> = DowncastBothNode::new(args[0]);
//let map_fn = map_fn.then(EvalSyncNode::new());
let node = graphene_std::memo::RefNode::new(map_fn);
let any = graphene_std::any::DynAnyRefNode::new(node);
let any = graphene_std::any::DynAnyNode::new(ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(concrete!(()), concrete!(&graphene_core::EditorApi), vec![]),
),
/*
(
NodeIdentifier::new("graphene_core::structural::MapImageNode"),
|args| {
Box::pin(async move {
let map_fn: DowncastBothNode<Color, Color> = DowncastBothNode::new(args[0]);
let node = graphene_std::raster::MapImageNode::new(ValueNode::new(map_fn));
let node = FutureWrapperNode::new(node);
let any: DynAnyNode<Image<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(concrete!(Image<Color>), concrete!(Image<Color>), vec![]),
),
*/
(
NodeIdentifier::new("graphene_std::raster::ImaginateNode<_>"),
|args| {
@ -557,6 +520,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
],
),
),
*/
/*
(
NodeIdentifier::new("graphene_core::raster::BlurNode"),
@ -598,56 +562,67 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
NodeIOTypes::new(concrete!(Image<Color>), concrete!(Image<Color>), vec![value_fn!(u32), value_fn!(f64)]),
),
//register_node!(graphene_std::memo::CacheNode<_>, input: Image<Color>, params: []),
*/
(
NodeIdentifier::new("graphene_std::memo::CacheNode"),
|args| {
let input: DowncastBothNode<(), Image<Color>> = DowncastBothNode::new(args[0]);
let node: CacheNode<Image<Color>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyRefNode::new(node);
Box::pin(any)
Box::pin(async move {
let input: DowncastBothNode<(), Image<Color>> = DowncastBothNode::new(args[0]);
let node: CacheNode<Image<Color>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyNode::new(ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(concrete!(()), concrete!(&Image<Color>), vec![value_fn!(Image<Color>)]),
),
(
NodeIdentifier::new("graphene_std::memo::CacheNode"),
|args| {
let input: DowncastBothNode<(), ImageFrame<Color>> = DowncastBothNode::new(args[0]);
let node: CacheNode<ImageFrame<Color>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyRefNode::new(node);
Box::pin(any)
Box::pin(async move {
let input: DowncastBothNode<(), ImageFrame<Color>> = DowncastBothNode::new(args[0]);
let node: CacheNode<ImageFrame<Color>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyNode::new(ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(concrete!(()), concrete!(&ImageFrame<Color>), vec![value_fn!(ImageFrame<Color>)]),
),
(
NodeIdentifier::new("graphene_std::memo::CacheNode"),
|args| {
let input: DowncastBothNode<ImageFrame<Color>, ImageFrame<Color>> = DowncastBothNode::new(args[0]);
let node: CacheNode<ImageFrame<Color>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyRefNode::new(node);
Box::pin(any)
Box::pin(async move {
let input: DowncastBothNode<ImageFrame<Color>, ImageFrame<Color>> = DowncastBothNode::new(args[0]);
let node: CacheNode<ImageFrame<Color>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyNode::new(ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(concrete!(ImageFrame<Color>), concrete!(&ImageFrame<Color>), vec![fn_type!(ImageFrame<Color>, ImageFrame<Color>)]),
),
(
NodeIdentifier::new("graphene_std::memo::CacheNode"),
|args| {
let input: DowncastBothNode<(), QuantizationChannels> = DowncastBothNode::new(args[0]);
let node: CacheNode<QuantizationChannels, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyRefNode::new(node);
Box::pin(any)
Box::pin(async move {
let input: DowncastBothNode<(), QuantizationChannels> = DowncastBothNode::new(args[0]);
let node: CacheNode<QuantizationChannels, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyNode::new(ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(concrete!(()), concrete!(&QuantizationChannels), vec![value_fn!(QuantizationChannels)]),
),
(
NodeIdentifier::new("graphene_std::memo::CacheNode"),
|args| {
let input: DowncastBothNode<(), Vec<DVec2>> = DowncastBothNode::new(args[0]);
let node: CacheNode<Vec<DVec2>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyRefNode::new(node);
Box::pin(any)
Box::pin(async move {
let input: DowncastBothNode<(), Vec<DVec2>> = DowncastBothNode::new(args[0]);
let node: CacheNode<Vec<DVec2>, _> = graphene_std::memo::CacheNode::new(input);
let any = DynAnyNode::new(ValueNode::new(node));
Box::pin(any) as TypeErasedPinned
})
},
NodeIOTypes::new(concrete!(()), concrete!(&Vec<DVec2>), vec![value_fn!(Vec<DVec2>)]),
),*/
),
],
register_node!(graphene_core::structural::ConsNode<_, _>, input: Image<Color>, params: [&str]),
register_node!(graphene_std::raster::ImageFrameNode<_, _>, input: Image<Color>, params: [DAffine2]),

View file

@ -8,7 +8,7 @@ use syn::{
#[proc_macro_attribute]
pub fn node_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut imp = node_impl_impl(attr.clone(), item.clone());
let mut imp = node_impl_proxy(attr.clone(), item.clone());
let new = node_new_impl(attr, item);
imp.extend(new);
imp
@ -18,6 +18,11 @@ pub fn node_new(attr: TokenStream, item: TokenStream) -> TokenStream {
node_new_impl(attr, item)
}
#[proc_macro_attribute]
pub fn node_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
node_impl_proxy(attr, item)
}
fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let node = parse_macro_input!(attr as syn::PathSegment);
@ -78,12 +83,23 @@ fn args(node: &syn::PathSegment) -> Vec<Type> {
}
}
#[proc_macro_attribute]
pub fn node_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
node_impl_impl(attr, item)
fn node_impl_proxy(attr: TokenStream, item: TokenStream) -> TokenStream {
let fn_item = item.clone();
let function = parse_macro_input!(fn_item as ItemFn);
let mut sync_input = if function.sig.asyncness.is_some() {
node_impl_impl(attr, item, Asyncness::AllAsync)
} else {
node_impl_impl(attr, item, Asyncness::Sync)
};
sync_input
}
enum Asyncness {
Sync,
AsyncOut,
AllAsync,
}
fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> TokenStream {
//let node_name = parse_macro_input!(attr as Ident);
let node = parse_macro_input!(attr as syn::PathSegment);
@ -93,7 +109,12 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let node_name = &node.ident;
let mut args = args(node);
let asyncness = function.sig.asyncness.is_some();
let async_out = match asyncness {
Asyncness::Sync => false,
Asyncness::AsyncOut | Asyncness::AllAsync => true,
};
let async_in = matches!(asyncness, Asyncness::AllAsync);
let body = &function.block;
let mut type_generics = function.sig.generics.params.clone();
let mut where_clause = function.sig.generics.where_clause.clone().unwrap_or(WhereClause {
@ -101,6 +122,12 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
predicates: Default::default(),
});
type_generics.iter_mut().for_each(|x| {
if let GenericParam::Type(t) = x {
t.bounds.insert(0, TypeParamBound::Lifetime(Lifetime::new("'input", Span::call_site())));
}
});
let (primary_input, parameter_inputs, parameter_pat_ident_patterns) = parse_inputs(&function);
let primary_input_ty = &primary_input.ty;
let Pat::Ident(PatIdent{ident: primary_input_ident, mutability: primary_input_mutability,..} ) =&*primary_input.pat else {
@ -116,70 +143,70 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
quote::quote!(())
};
let struct_generics = (0..parameter_pat_ident_patterns.len())
.map(|x| {
let ident = format_ident!("S{x}");
ident
})
.collect::<Punctuated<_, Comma>>();
let struct_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("S{x}")).collect::<Vec<_>>();
let future_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("F{x}")).collect::<Vec<_>>();
let future_types = future_generics.iter().map(|x| Type::Verbatim(x.to_token_stream())).collect::<Vec<_>>();
let parameter_types = parameter_inputs.iter().map(|x| *x.ty.clone()).collect::<Vec<Type>>();
for ident in struct_generics.iter() {
args.push(Type::Verbatim(quote::quote!(#ident)));
}
// Generics are simply `S0` through to `Sn-1` where n is the number of secondary inputs
let node_generics = node_generics(&struct_generics);
type_generics.iter_mut().for_each(|x| {
if let GenericParam::Type(t) = x {
t.bounds.insert(0, TypeParamBound::Lifetime(Lifetime::new("'input", Span::call_site())));
}
});
let generics = type_generics.into_iter().chain(node_generics.iter().cloned()).collect::<Punctuated<_, Comma>>();
// Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function
let extra_where_clause = input_node_bounds(parameter_inputs, node_generics);
where_clause.predicates.extend(extra_where_clause);
let node_generics = construct_node_generics(&struct_generics);
let future_generic_params = construct_node_generics(&future_generics);
let node_impl = if asyncness {
quote::quote! {
#[automatically_derived]
impl <'input, #generics> Node<'input, #primary_input_ty> for #node_name<#(#args),*>
#where_clause
{
type Output = core::pin::Pin<Box<dyn core::future::Future< Output = #output> + 'input>>;
#[inline]
fn eval(&'input self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
#(
let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(());
)*
Box::pin(async move {#body})
}
}
}
let generics = if async_in {
type_generics
.into_iter()
.chain(node_generics.iter().cloned())
.chain(future_generic_params.iter().cloned())
.collect::<Punctuated<_, Comma>>()
} else {
let token_stream = quote::quote! {
#[automatically_derived]
impl <'input, #generics> Node<'input, #primary_input_ty> for #node_name<#(#args),*>
#where_clause
{
type Output = #output;
#[inline]
fn eval(&'input self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
#(
let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(());
)*
#body
}
}
};
token_stream
type_generics.into_iter().chain(node_generics.iter().cloned()).collect::<Punctuated<_, Comma>>()
};
// Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function
let node_bounds = if async_in {
let mut node_bounds = input_node_bounds(future_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>});
let future_bounds = input_node_bounds(parameter_types, future_generic_params, |ty| quote! { core::future::Future<Output = #ty>});
node_bounds.extend(future_bounds);
node_bounds
} else {
input_node_bounds(parameter_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>})
};
where_clause.predicates.extend(node_bounds);
let output = if async_out {
quote::quote!(core::pin::Pin<Box<dyn core::future::Future< Output = #output> + 'input>>)
} else {
quote::quote!(#output)
};
let parameters = if matches!(asyncness, Asyncness::AllAsync) {
quote::quote!(#(let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(()).await;)*)
} else {
quote::quote!(#(let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(());)*)
};
let mut body_with_inputs = quote::quote!(
#parameters
{#body}
);
if async_out {
body_with_inputs = quote::quote!(Box::pin(async move { #body_with_inputs }));
}
quote::quote! {
#node_impl
#[automatically_derived]
impl <'input, #generics> Node<'input, #primary_input_ty> for #node_name<#(#args),*>
#where_clause
{
type Output = #output;
#[inline]
fn eval(&'input self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
#body_with_inputs
}
}
}
.into()
}
@ -202,8 +229,8 @@ fn parse_inputs(function: &ItemFn) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&P
(primary_input, parameter_inputs, parameter_pat_ident_patterns)
}
fn node_generics(struct_generics: &Punctuated<Ident, Comma>) -> Punctuated<GenericParam, Comma> {
let node_generics = struct_generics
fn construct_node_generics(struct_generics: &[Ident]) -> Vec<GenericParam> {
struct_generics
.iter()
.cloned()
.map(|ident| {
@ -216,18 +243,17 @@ fn node_generics(struct_generics: &Punctuated<Ident, Comma>) -> Punctuated<Gener
default: None,
})
})
.collect::<Punctuated<_, Comma>>();
node_generics
.collect()
}
fn input_node_bounds(parameter_inputs: Vec<&syn::PatType>, node_generics: Punctuated<GenericParam, Comma>) -> Vec<WherePredicate> {
let extra_where_clause = parameter_inputs
fn input_node_bounds(parameter_inputs: Vec<Type>, node_generics: Vec<GenericParam>, trait_bound: impl Fn(Type) -> proc_macro2::TokenStream) -> Vec<WherePredicate> {
parameter_inputs
.iter()
.zip(&node_generics)
.map(|(ty, name)| {
let ty = &ty.ty;
let GenericParam::Type(generic_ty) = name else { panic!("Expected type generic."); };
let ident = &generic_ty.ident;
let bound = trait_bound(ty.clone());
WherePredicate::Type(PredicateType {
lifetimes: None,
bounded_ty: Type::Verbatim(ident.to_token_stream()),
@ -236,10 +262,9 @@ fn input_node_bounds(parameter_inputs: Vec<&syn::PatType>, node_generics: Punctu
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None, //syn::parse_quote!(for<'any_input>),
path: syn::parse_quote!(Node<'input, (), Output = #ty>),
path: syn::parse_quote!(#bound),
})]),
})
})
.collect::<Vec<_>>();
extra_where_clause
.collect()
}