Node macro lambda parameters (#1309)

* Implement parsing of impl Node<I, …> syntax for the macro

* Extend node macro to allow specifying lambda nodes
This commit is contained in:
Dennis Kobert 2023-06-09 16:43:46 +02:00 committed by Keavon Chambers
parent a5f890289b
commit 40ec52b395
15 changed files with 109 additions and 106 deletions

View file

@ -95,8 +95,6 @@ pub enum FrontendMessage {
layer_path: Vec<LayerId>,
svg: String,
size: glam::DVec2,
#[serde(rename = "imaginateNodePath")]
imaginate_node_path: Option<Vec<NodeId>>,
},
TriggerRefreshBoundsOfViewports,
TriggerRevokeBlobUrl {

View file

@ -96,7 +96,6 @@ pub enum DocumentMessage {
},
ImaginateGenerate {
layer_path: Vec<LayerId>,
imaginate_node: Vec<NodeId>,
},
ImaginateRandom {
layer_path: Vec<LayerId>,

View file

@ -31,7 +31,7 @@ use document_legacy::layers::layer_layer::CachedOutputData;
use document_legacy::layers::style::{RenderData, ViewMode};
use document_legacy::{DocumentError, DocumentResponse, LayerId, Operation as DocumentOperation};
use graph_craft::document::value::TaggedValue;
use graph_craft::document::{NodeId, NodeInput, NodeNetwork};
use graph_craft::document::{NodeInput, NodeNetwork};
use graphene_core::raster::ImageFrame;
use graphene_core::text::Font;
@ -466,8 +466,8 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
});
}
ImaginateClear { layer_path } => responses.add(InputFrameRasterizeRegionBelowLayer { layer_path }),
ImaginateGenerate { layer_path, imaginate_node } => {
if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data, Some(imaginate_node)) {
ImaginateGenerate { layer_path } => {
if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data) {
responses.add(message);
}
}
@ -490,13 +490,13 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
// Generate the image
if then_generate {
responses.add(DocumentMessage::ImaginateGenerate { layer_path, imaginate_node });
responses.add(DocumentMessage::ImaginateGenerate { layer_path });
}
}
InputFrameRasterizeRegionBelowLayer { layer_path } => {
if layer_path.is_empty() {
responses.add(NodeGraphMessage::RunDocumentGraph);
} else if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data, None) {
} else if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data) {
responses.add(message);
}
}
@ -967,14 +967,7 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
}
impl DocumentMessageHandler {
pub fn rasterize_region_below_layer(
&mut self,
document_id: u64,
layer_path: Vec<LayerId>,
_preferences: &PreferencesMessageHandler,
persistent_data: &PersistentData,
imaginate_node_path: Option<Vec<NodeId>>,
) -> Option<Message> {
pub fn rasterize_region_below_layer(&mut self, document_id: u64, layer_path: Vec<LayerId>, _preferences: &PreferencesMessageHandler, persistent_data: &PersistentData) -> Option<Message> {
// Prepare the node graph input image
let Some(node_network) = self.document_legacy.layer(&layer_path).ok().and_then(|layer| layer.as_layer_network().ok()) else {
@ -998,14 +991,7 @@ impl DocumentMessageHandler {
self.restore_document_transform(old_transforms);
// Once JS asynchronously rasterizes the SVG, it will call the `PortfolioMessage::RenderGraphUsingRasterizedRegionBelowLayer` message with the rasterized image data
FrontendMessage::TriggerRasterizeRegionBelowLayer {
document_id,
layer_path,
svg,
size,
imaginate_node_path,
}
.into()
FrontendMessage::TriggerRasterizeRegionBelowLayer { document_id, layer_path, svg, size }.into()
}
// Skip taking a round trip through JS since there's nothing to rasterize, and instead directly call the message which would otherwise be called asynchronously from JS
else {
@ -1014,7 +1000,6 @@ impl DocumentMessageHandler {
layer_path,
input_image_data: vec![],
size: (0, 0),
imaginate_node_path,
}
.into()
};

View file

@ -714,7 +714,6 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
layer_path: Vec::new(),
input_image_data: vec![],
size: (0, 0),
imaginate_node_path: None,
}),
NodeGraphMessage::SelectNodes { nodes } => {
self.selected_nodes = nodes;

View file

@ -984,7 +984,6 @@ pub fn node_section_font(document_node: &DocumentNode, node_id: NodeId, _context
pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, context: &mut NodePropertiesContext) -> Vec<LayoutGroup> {
let imaginate_node = [context.nested_path, &[node_id]].concat();
let layer_path = context.layer_path.to_vec();
let resolve_input = |name: &str| {
super::IMAGINATE_NODE
@ -1140,16 +1139,11 @@ pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, conte
TextButton::new("Generate")
.tooltip("Fill layer frame by generating a new image")
.on_update({
let imaginate_node = imaginate_node.clone();
let layer_path = context.layer_path.to_vec();
let controller = controller.clone();
move |_| {
controller.trigger_regenerate();
DocumentMessage::ImaginateGenerate {
layer_path: layer_path.clone(),
imaginate_node: imaginate_node.clone(),
}
.into()
DocumentMessage::ImaginateGenerate { layer_path: layer_path.clone() }.into()
}
})
.widget_holder(),

View file

@ -100,7 +100,6 @@ pub enum PortfolioMessage {
layer_path: Vec<LayerId>,
input_image_data: Vec<u8>,
size: (u32, u32),
imaginate_node_path: Option<Vec<NodeId>>,
},
SelectDocument {
document_id: u64,

View file

@ -417,7 +417,6 @@ impl MessageHandler<PortfolioMessage, (&InputPreprocessorMessageHandler, &Prefer
layer_path,
input_image_data,
size,
imaginate_node_path,
} => {
let result = self.executor.submit_node_graph_evaluation(
(document_id, &mut self.documents),

View file

@ -1,6 +1,6 @@
/* eslint-disable max-classes-per-file */
import {writable} from "svelte/store";
import { writable } from "svelte/store";
import { downloadFileText, downloadFileBlob, upload, downloadFileURL } from "@graphite/utility-functions/files";
import { extractPixelData, imageToPNG, rasterizeSVG, rasterizeSVGCanvas } from "@graphite/utility-functions/rasterization";
@ -98,7 +98,7 @@ export function createPortfolioState(editor: Editor) {
});
});
editor.subscriptions.subscribeJsMessage(TriggerRasterizeRegionBelowLayer, async (triggerRasterizeRegionBelowLayer) => {
const { documentId, layerPath, svg, size, imaginateNodePath } = triggerRasterizeRegionBelowLayer;
const { documentId, layerPath, svg, size } = triggerRasterizeRegionBelowLayer;
// Rasterize the SVG to an image file
try {
@ -106,7 +106,7 @@ export function createPortfolioState(editor: Editor) {
const imageData = (await rasterizeSVGCanvas(svg, size[0], size[1])).getContext("2d")?.getImageData(0, 0, size[0], size[1]);
if (!imageData) return;
editor.instance.renderGraphUsingRasterizedRegionBelowLayer(documentId, layerPath, new Uint8Array(imageData.data), imageData.width, imageData.height, imaginateNodePath);
editor.instance.renderGraphUsingRasterizedRegionBelowLayer(documentId, layerPath, new Uint8Array(imageData.data), imageData.width, imageData.height);
}
}
// getImageData may throw an exception if the resolution is too high

View file

@ -545,8 +545,6 @@ export class TriggerRasterizeRegionBelowLayer extends JsMessage {
readonly svg!: string;
readonly size!: [number, number];
readonly imaginateNodePath!: BigUint64Array | undefined;
}
export class TriggerRefreshBoundsOfViewports extends JsMessage { }
@ -700,7 +698,7 @@ export class ImaginateImageData {
readonly imageData!: Uint8Array;
readonly transform!: Float64Array ;
readonly transform!: Float64Array;
}
export class DisplayDialogDismiss extends JsMessage { }

View file

@ -588,21 +588,12 @@ impl JsEditorHandle {
/// Sends the blob URL generated by JS to the Imaginate layer in the respective document
#[wasm_bindgen(js_name = renderGraphUsingRasterizedRegionBelowLayer)]
pub fn render_graph_using_rasterized_region_below_layer(
&self,
document_id: u64,
layer_path: Vec<LayerId>,
input_image_data: Vec<u8>,
width: u32,
height: u32,
imaginate_node_path: Option<Vec<NodeId>>,
) {
pub fn render_graph_using_rasterized_region_below_layer(&self, document_id: u64, layer_path: Vec<LayerId>, input_image_data: Vec<u8>, width: u32, height: u32) {
let message = PortfolioMessage::RenderGraphUsingRasterizedRegionBelowLayer {
document_id,
layer_path,
input_image_data,
size: (width, height),
imaginate_node_path,
};
self.dispatch(message);
}

View file

@ -675,7 +675,7 @@ impl NodeNetwork {
let mut dummy_input = NodeInput::ShortCircut(concrete!(()));
std::mem::swap(&mut dummy_input, input);
if let NodeInput::Value { mut tagged_value, exposed } = dummy_input {
if let NodeInput::Value { tagged_value, exposed } = dummy_input {
let value_node_id = gen_id();
let merged_node_id = map_ids(id, value_node_id);
let path = if let Some(mut new_path) = node.path.clone() {

View file

@ -11,19 +11,17 @@ pub struct DynAnyNode<I, O, Node> {
_o: PhantomData<O>,
}
impl<'input, _I: 'input + StaticType, _O: 'input + StaticType, N: 'input, S0: 'input> Node<'input, Any<'input>> for DynAnyNode<_I, _O, S0>
impl<'input, _I: 'input + StaticType, _O: 'input + StaticType, N: 'input> Node<'input, Any<'input>> for DynAnyNode<_I, _O, N>
where
N: for<'any_input> Node<'any_input, _I, Output = DynFuture<'any_input, _O>>,
S0: for<'any_input> Node<'any_input, (), Output = &'any_input N>,
N: Node<'input, _I, Output = DynFuture<'input, _O>>,
{
type Output = FutureAny<'input>;
#[inline]
fn eval(&'input self, input: Any<'input>) -> Self::Output {
let node = self.node.eval(());
let node_name = core::any::type_name::<N>();
let input: Box<_I> = dyn_any::downcast(input).unwrap_or_else(|e| panic!("DynAnyNode Input, {0} in:\n{1}", e, node_name));
let output = async move {
let result = node.eval(*input).await;
let result = self.node.eval(*input).await;
Box::new(result) as Any<'input>
};
Box::pin(output)
@ -34,14 +32,14 @@ where
}
fn serialize(&self) -> Option<std::sync::Arc<dyn core::any::Any>> {
self.node.eval(()).serialize()
self.node.serialize()
}
}
impl<'input, _I: StaticType, _O: StaticType, N, S0: 'input> DynAnyNode<_I, _O, S0>
impl<'input, _I: 'input + StaticType, _O: 'input + StaticType, N: 'input> DynAnyNode<_I, _O, N>
where
S0: for<'any_input> Node<'any_input, (), Output = &'any_input N>,
N: Node<'input, _I, Output = DynFuture<'input, _O>>,
{
pub const fn new(node: S0) -> Self {
pub const fn new(node: N) -> Self {
Self {
node,
_i: core::marker::PhantomData,
@ -271,7 +269,7 @@ mod test {
pub fn dyn_input_invalid_eval_panic() {
//let add = DynAnyNode::new(AddNode::new()).into_type_erased();
//add.eval(Box::new(&("32", 32u32)));
let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(ValueNode::new(FutureWrapperNode { node: AddNode::new() }));
let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(FutureWrapperNode { node: AddNode::new() });
let type_erased = Box::new(dyn_any) as TypeErasedBox;
let _ref_type_erased = type_erased.as_ref();
//let type_erased = Box::pin(dyn_any) as TypeErasedBox<'_>;
@ -282,11 +280,11 @@ mod test {
pub fn dyn_input_compose() {
//let add = DynAnyNode::new(AddNode::new()).into_type_erased();
//add.eval(Box::new(&("32", 32u32)));
let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(ValueNode::new(FutureWrapperNode { node: AddNode::new() }));
let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(FutureWrapperNode { node: AddNode::new() });
let type_erased = Box::new(dyn_any) as TypeErasedBox<'_>;
type_erased.eval(Box::new((4u32, 2u32)));
let id_node = FutureWrapperNode::new(IdNode::new());
let any_id = DynAnyNode::<u32, u32, _>::new(ValueNode::new(id_node));
let any_id = DynAnyNode::<u32, u32, _>::new(id_node);
let type_erased_id = Box::new(any_id) as TypeErasedBox;
let type_erased = ComposeTypeErased::new(NodeContainer::new(type_erased), NodeContainer::new(type_erased_id));
type_erased.eval(Box::new((4u32, 2u32)));

View file

@ -224,11 +224,12 @@ pub struct BlendImageNode<P, Background, MapFn> {
}
#[node_macro::node_fn(BlendImageNode<_P>)]
async fn blend_image_node<_P: Alpha + Pixel + Debug, MapFn, Forground: Sample<Pixel = _P> + Transform>(foreground: Forground, background: ImageFrame<_P>, map_fn: &'input MapFn) -> ImageFrame<_P>
where
MapFn: for<'any_input> Node<'any_input, (_P, _P), Output = _P> + 'input,
{
blend_new_image(foreground, background, map_fn)
async fn blend_image_node<_P: Alpha + Pixel + Debug, Forground: Sample<Pixel = _P> + Transform>(
foreground: Forground,
background: ImageFrame<_P>,
map_fn: impl Node<(_P, _P), Output = _P>,
) -> ImageFrame<_P> {
blend_new_image(foreground, background, &self.map_fn)
}
#[derive(Debug, Clone, Copy)]
@ -246,9 +247,9 @@ where
blend_new_image(background, foreground, map_fn)
}
fn blend_new_image<_P: Alpha + Pixel + Debug, MapFn, Frame: Sample<Pixel = _P> + Transform>(foreground: Frame, background: ImageFrame<_P>, map_fn: &MapFn) -> ImageFrame<_P>
fn blend_new_image<'input, _P: Alpha + Pixel + Debug, MapFn, Frame: Sample<Pixel = _P> + Transform>(foreground: Frame, background: ImageFrame<_P>, map_fn: &'input MapFn) -> ImageFrame<_P>
where
MapFn: for<'any_input> Node<'any_input, (_P, _P), Output = _P>,
MapFn: Node<'input, (_P, _P), Output = _P>,
{
let foreground_aabb = Bbox::unit().affine_transform(foreground.transform()).to_axis_aligned_bbox();
let background_aabb = Bbox::unit().affine_transform(background.transform()).to_axis_aligned_bbox();
@ -275,13 +276,13 @@ where
blend_image(foreground, new_background, map_fn)
}
fn blend_image<_P: Alpha + Pixel + Debug, MapFn, Frame: Sample<Pixel = _P> + Transform, Background: RasterMut<Pixel = _P> + Transform + Sample<Pixel = _P>>(
fn blend_image<'input, _P: Alpha + Pixel + Debug, MapFn, Frame: Sample<Pixel = _P> + Transform, Background: RasterMut<Pixel = _P> + Transform + Sample<Pixel = _P>>(
foreground: Frame,
background: Background,
map_fn: &MapFn,
map_fn: &'input MapFn,
) -> Background
where
MapFn: for<'any_input> Node<'any_input, (_P, _P), Output = _P>,
MapFn: Node<'input, (_P, _P), Output = _P>,
{
blend_image_closure(foreground, background, |a, b| map_fn.eval((a, b)))
}

View file

@ -55,7 +55,7 @@ macro_rules! register_node {
Box::pin(async move {
let node = construct_node!(args, $path, [$($type),*]).await;
let node = graphene_std::any::FutureWrapperNode::new(node);
let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(node);
Box::new(any) as TypeErasedBox
})
},
@ -83,7 +83,7 @@ macro_rules! async_node {
Box::pin(async move {
args.reverse();
let node = <$path>::new($(graphene_std::any::input_node::<$type>(args.pop().expect("Not enough arguments provided to construct node"))),*);
let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(node);
Box::new(any) as TypeErasedBox
})
},
@ -122,7 +122,7 @@ macro_rules! raster_node {
Box::pin(async move {
let node = construct_node!(args, $path, [$($type),*]).await;
let node = graphene_std::any::FutureWrapperNode::new(node);
let any: DynAnyNode<Color, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
let any: DynAnyNode<Color, _, _> = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
},
@ -138,7 +138,7 @@ macro_rules! raster_node {
let node = construct_node!(args, $path, [$($type),*]).await;
let map_node = graphene_std::raster::MapImageNode::new(graphene_core::value::ValueNode::new(node));
let map_node = graphene_std::any::FutureWrapperNode::new(map_node);
let any: DynAnyNode<Image<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(map_node));
let any: DynAnyNode<Image<Color>, _, _> = graphene_std::any::DynAnyNode::new(map_node);
any.into_type_erased()
})
},
@ -154,7 +154,7 @@ macro_rules! raster_node {
let node = construct_node!(args, $path, [$($type),*]).await;
let map_node = graphene_std::raster::MapImageNode::new(graphene_core::value::ValueNode::new(node));
let map_node = graphene_std::any::FutureWrapperNode::new(map_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(map_node));
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(map_node);
any.into_type_erased()
})
},
@ -238,7 +238,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
let final_image = ClonedNode::new(empty_image).then(complete_node);
let final_image = FutureWrapperNode::new(final_image);
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(final_image));
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(final_image);
any.into_type_erased()
})
},
@ -304,7 +304,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
let editor_api: DowncastBothNode<(), WasmEditorApi> = DowncastBothNode::new(args[1].clone());
//let document_node = ClonedNode::new(document_node.eval(()));
let node = graphene_std::gpu_nodes::MapGpuNode::new(document_node, editor_api);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
},
@ -323,7 +323,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
let blend_mode: DowncastBothNode<(), BlendMode> = DowncastBothNode::new(args[1].clone());
let opacity: DowncastBothNode<(), f32> = DowncastBothNode::new(args[2].clone());
let node = graphene_std::gpu_nodes::BlendGpuImageNode::new(background, blend_mode, opacity);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
@ -366,8 +366,8 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
let blend_mode: DowncastBothNode<(), BlendMode> = DowncastBothNode::new(args[1].clone());
let opacity: DowncastBothNode<(), f64> = DowncastBothNode::new(args[2].clone());
let blend_node = graphene_core::raster::BlendNode::new(CopiedNode::new(blend_mode.eval(()).await), CopiedNode::new(opacity.eval(()).await));
let node = graphene_std::raster::BlendImageNode::new(image, FutureWrapperNode::new(ValueNode::new(blend_node)));
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
let node = graphene_std::raster::BlendImageNode::new(image, blend_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
},
@ -406,13 +406,13 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
let generate_brightness_contrast_legacy_mapper_node = GenerateBrightnessContrastLegacyMapperNode::new(brightness, contrast);
let map_image_frame_node = graphene_std::raster::MapImageNode::new(ValueNode::new(generate_brightness_contrast_legacy_mapper_node.eval(())));
let map_image_frame_node = FutureWrapperNode::new(map_image_frame_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node));
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(map_image_frame_node);
any.into_type_erased()
} else {
let generate_brightness_contrast_mapper_node = GenerateBrightnessContrastMapperNode::new(brightness, contrast);
let map_image_frame_node = graphene_std::raster::MapImageNode::new(ValueNode::new(generate_brightness_contrast_mapper_node.eval(())));
let map_image_frame_node = FutureWrapperNode::new(map_image_frame_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node));
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(map_image_frame_node);
any.into_type_erased()
}
})
@ -452,7 +452,8 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
Box::pin(async move {
let node: DowncastBothNode<Option<WasmEditorApi>, WasmEditorApi> = graphene_std::any::DowncastBothNode::new(args[0].clone());
let node = <graphene_core::memo::RefNode<_, _>>::new(node);
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
},
@ -468,7 +469,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
($($i:expr,)*) => { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* cache.into_inner()) };
}
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instanciate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
let any = graphene_std::any::DynAnyNode::new(ValueNode::new(node));
let any = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
},

View file

@ -2,8 +2,8 @@ use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote, ToTokens};
use syn::{
parse_macro_input, punctuated::Punctuated, token::Comma, FnArg, GenericParam, Ident, ItemFn, Lifetime, Pat, PatIdent, PathArguments, PredicateType, ReturnType, Token, TraitBound, Type, TypeParam,
TypeParamBound, WhereClause, WherePredicate,
parse_macro_input, punctuated::Punctuated, token::Comma, AngleBracketedGenericArguments, Binding, FnArg, GenericArgument, GenericParam, Ident, ItemFn, Lifetime, Pat, PatIdent, PathArguments,
PredicateType, ReturnType, Token, TraitBound, Type, TypeImplTrait, TypeParam, TypeParamBound, TypeTuple, WhereClause, WherePredicate,
};
#[proc_macro_attribute]
@ -30,7 +30,7 @@ fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let node = &node;
let node_name = &node.ident;
let mut args = args(node);
let mut args = node_args(node);
let arg_idents = args
.iter()
@ -38,7 +38,7 @@ fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
.map(|arg| Ident::new(arg.to_token_stream().to_string().to_lowercase().as_str(), Span::call_site()))
.collect::<Vec<_>>();
let (_, _, parameter_pat_ident_patterns) = parse_inputs(&function);
let (_, _, parameter_pat_ident_patterns) = parse_inputs(&function, false);
let parameter_idents = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.ident).collect::<Vec<_>>();
// Extract the output type of the entire node - `()` by default
@ -69,7 +69,7 @@ fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
.into()
}
fn args(node: &syn::PathSegment) -> Vec<Type> {
fn node_args(node: &syn::PathSegment) -> Vec<Type> {
match node.arguments.clone() {
PathArguments::AngleBracketed(args) => args
.args
@ -105,7 +105,7 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) ->
let node = &node;
let node_name = &node.ident;
let mut args = args(node);
let mut args = node_args(node);
let async_out = match asyncness {
Asyncness::Sync => false,
@ -126,13 +126,11 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) ->
}
});
let (primary_input, parameter_inputs, parameter_pat_ident_patterns) = parse_inputs(&function);
let (primary_input, parameter_inputs, parameter_pat_ident_patterns) = parse_inputs(&function, true);
let primary_input_ty = &primary_input.ty;
let Pat::Ident(PatIdent{ident: primary_input_ident, mutability: primary_input_mutability,..} ) =&*primary_input.pat else {
panic!("Expected ident as primary input.");
};
let parameter_idents = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.ident).collect::<Vec<_>>();
let parameter_mutability = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.mutability);
// Extract the output type of the entire node - `()` by default
let output = if let ReturnType::Type(_, ty) = &function.sig.output {
@ -141,10 +139,18 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) ->
quote::quote!(())
};
let struct_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("S{x}")).collect::<Vec<_>>();
let future_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("F{x}")).collect::<Vec<_>>();
let future_types = future_generics.iter().map(|x| Type::Verbatim(x.to_token_stream())).collect::<Vec<_>>();
let num_inputs = parameter_inputs.len();
let struct_generics = (0..num_inputs).map(|x| format_ident!("S{x}")).collect::<Vec<_>>();
let future_generics = (0..num_inputs).map(|x| format_ident!("F{x}")).collect::<Vec<_>>();
let parameter_types = parameter_inputs.iter().map(|x| *x.ty.clone()).collect::<Vec<Type>>();
let future_types = future_generics
.iter()
.enumerate()
.map(|(i, x)| match parameter_types[i].clone() {
Type::ImplTrait(x) => Type::ImplTrait(x),
_ => Type::Verbatim(x.to_token_stream()),
})
.collect::<Vec<_>>();
for ident in struct_generics.iter() {
args.push(Type::Verbatim(quote::quote!(#ident)));
@ -153,6 +159,7 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) ->
// Generics are simply `S0` through to `Sn-1` where n is the number of secondary inputs
let node_generics = construct_node_generics(&struct_generics);
let future_generic_params = construct_node_generics(&future_generics);
let (future_parameter_types, future_generic_params): (Vec<_>, Vec<_>) = parameter_types.iter().cloned().zip(future_generic_params).filter(|(ty, _)| !matches!(ty, Type::ImplTrait(_))).unzip();
let generics = if async_in {
type_generics
@ -166,12 +173,12 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) ->
// Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function
let node_bounds = if async_in {
let mut node_bounds = input_node_bounds(future_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>});
let future_bounds = input_node_bounds(parameter_types, future_generic_params, |ty| quote! { core::future::Future<Output = #ty>});
let mut node_bounds = input_node_bounds(future_types, node_generics, |lifetime, in_ty, out_ty| quote! {Node<#lifetime, #in_ty, Output = #out_ty>});
let future_bounds = input_node_bounds(future_parameter_types, future_generic_params, |_, _, out_ty| quote! { core::future::Future<Output = #out_ty>});
node_bounds.extend(future_bounds);
node_bounds
} else {
input_node_bounds(parameter_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>})
input_node_bounds(parameter_types, node_generics, |lifetime, in_ty, out_ty| quote! {Node<#lifetime, #in_ty, Output = #out_ty>})
};
where_clause.predicates.extend(node_bounds);
@ -181,6 +188,9 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) ->
quote::quote!(#output)
};
let parameter_idents = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.ident).collect::<Vec<_>>();
let parameter_mutability = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.mutability);
let parameters = if matches!(asyncness, Asyncness::AllAsync) {
quote::quote!(#(let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(()).await;)*)
} else {
@ -209,7 +219,7 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) ->
.into()
}
fn parse_inputs(function: &ItemFn) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&PatIdent>) {
fn parse_inputs(function: &ItemFn, remove_impl_node: bool) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&PatIdent>) {
let mut function_inputs = function.sig.inputs.iter().filter_map(|arg| if let FnArg::Typed(typed_arg) = arg { Some(typed_arg) } else { None });
// Extract primary input as first argument
@ -217,8 +227,10 @@ fn parse_inputs(function: &ItemFn) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&P
// Extract secondary inputs as all other arguments
let parameter_inputs = function_inputs.collect::<Vec<_>>();
let parameter_pat_ident_patterns = parameter_inputs
.iter()
.filter(|input| !matches!(&*input.ty, Type::ImplTrait(_)) || !remove_impl_node)
.map(|input| {
let Pat::Ident(pat_ident) = &*input.pat else { panic!("Expected ident for secondary input."); };
pat_ident
@ -244,14 +256,43 @@ fn construct_node_generics(struct_generics: &[Ident]) -> Vec<GenericParam> {
.collect()
}
fn input_node_bounds(parameter_inputs: Vec<Type>, node_generics: Vec<GenericParam>, trait_bound: impl Fn(Type) -> proc_macro2::TokenStream) -> Vec<WherePredicate> {
fn input_node_bounds(parameter_inputs: Vec<Type>, node_generics: Vec<GenericParam>, trait_bound: impl Fn(Lifetime, Type, Type) -> proc_macro2::TokenStream) -> Vec<WherePredicate> {
parameter_inputs
.iter()
.zip(&node_generics)
.map(|(ty, name)| {
let GenericParam::Type(generic_ty) = name else { panic!("Expected type generic."); };
let ident = &generic_ty.ident;
let bound = trait_bound(ty.clone());
let (lifetime, in_ty, out_ty) = match ty.clone() {
Type::ImplTrait(TypeImplTrait { bounds, .. }) if bounds.len() == 1 => {
let TypeParamBound::Trait(TraitBound { ref path, .. }) = bounds[0] else {panic!("impl Traits other then Node are not supported")};
let node_segment = path.segments.last().expect("Found an empty path in the impl Trait arg");
assert_eq!(node_segment.ident.to_string(), "Node", "Only impl Node is supported as an argument");
let PathArguments::AngleBracketed(AngleBracketedGenericArguments {ref args, .. }) = node_segment.arguments else { panic!("Node must have generic arguments")};
let mut args_iter = args.iter();
let lifetime = if args.len() == 2 {
Lifetime::new("'input", Span::call_site())
} else if let Some(GenericArgument::Lifetime(node_lifetime)) = args_iter.next() {
node_lifetime.clone()
} else {
panic!("Invalid arguments for Node trait")
};
let Some(GenericArgument::Type(in_ty)) = args_iter.next() else { panic!("Expected type argument in Node<> declaration")};
let Some(GenericArgument::Binding(Binding {ty: out_ty, ..} )) = args_iter.next() else { panic!("Expected Output = in Node declaration")};
(lifetime, in_ty.clone(), out_ty.clone())
}
ty => (
Lifetime::new("'input", Span::call_site()),
Type::Tuple(TypeTuple {
paren_token: syn::token::Paren { span: Span::call_site() },
elems: Punctuated::new(),
}),
ty,
),
};
let bound = trait_bound(lifetime, in_ty, out_ty);
WherePredicate::Type(PredicateType {
lifetimes: None,
bounded_ty: Type::Verbatim(ident.to_token_stream()),