diff --git a/editor/src/messages/frontend/frontend_message.rs b/editor/src/messages/frontend/frontend_message.rs index 9e5428486..542fcd843 100644 --- a/editor/src/messages/frontend/frontend_message.rs +++ b/editor/src/messages/frontend/frontend_message.rs @@ -95,8 +95,6 @@ pub enum FrontendMessage { layer_path: Vec, svg: String, size: glam::DVec2, - #[serde(rename = "imaginateNodePath")] - imaginate_node_path: Option>, }, TriggerRefreshBoundsOfViewports, TriggerRevokeBlobUrl { diff --git a/editor/src/messages/portfolio/document/document_message.rs b/editor/src/messages/portfolio/document/document_message.rs index 35ac00512..a609f7d8c 100644 --- a/editor/src/messages/portfolio/document/document_message.rs +++ b/editor/src/messages/portfolio/document/document_message.rs @@ -96,7 +96,6 @@ pub enum DocumentMessage { }, ImaginateGenerate { layer_path: Vec, - imaginate_node: Vec, }, ImaginateRandom { layer_path: Vec, diff --git a/editor/src/messages/portfolio/document/document_message_handler.rs b/editor/src/messages/portfolio/document/document_message_handler.rs index 82dff9037..a5762dbf9 100644 --- a/editor/src/messages/portfolio/document/document_message_handler.rs +++ b/editor/src/messages/portfolio/document/document_message_handler.rs @@ -31,7 +31,7 @@ use document_legacy::layers::layer_layer::CachedOutputData; use document_legacy::layers::style::{RenderData, ViewMode}; use document_legacy::{DocumentError, DocumentResponse, LayerId, Operation as DocumentOperation}; use graph_craft::document::value::TaggedValue; -use graph_craft::document::{NodeId, NodeInput, NodeNetwork}; +use graph_craft::document::{NodeInput, NodeNetwork}; use graphene_core::raster::ImageFrame; use graphene_core::text::Font; @@ -466,8 +466,8 @@ impl MessageHandler responses.add(InputFrameRasterizeRegionBelowLayer { layer_path }), - ImaginateGenerate { layer_path, imaginate_node } => { - if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data, Some(imaginate_node)) { + ImaginateGenerate { layer_path } => { + if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data) { responses.add(message); } } @@ -490,13 +490,13 @@ impl MessageHandler { if layer_path.is_empty() { responses.add(NodeGraphMessage::RunDocumentGraph); - } else if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data, None) { + } else if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data) { responses.add(message); } } @@ -967,14 +967,7 @@ impl MessageHandler, - _preferences: &PreferencesMessageHandler, - persistent_data: &PersistentData, - imaginate_node_path: Option>, - ) -> Option { + pub fn rasterize_region_below_layer(&mut self, document_id: u64, layer_path: Vec, _preferences: &PreferencesMessageHandler, persistent_data: &PersistentData) -> Option { // Prepare the node graph input image let Some(node_network) = self.document_legacy.layer(&layer_path).ok().and_then(|layer| layer.as_layer_network().ok()) else { @@ -998,14 +991,7 @@ impl DocumentMessageHandler { self.restore_document_transform(old_transforms); // Once JS asynchronously rasterizes the SVG, it will call the `PortfolioMessage::RenderGraphUsingRasterizedRegionBelowLayer` message with the rasterized image data - FrontendMessage::TriggerRasterizeRegionBelowLayer { - document_id, - layer_path, - svg, - size, - imaginate_node_path, - } - .into() + FrontendMessage::TriggerRasterizeRegionBelowLayer { document_id, layer_path, svg, size }.into() } // Skip taking a round trip through JS since there's nothing to rasterize, and instead directly call the message which would otherwise be called asynchronously from JS else { @@ -1014,7 +1000,6 @@ impl DocumentMessageHandler { layer_path, input_image_data: vec![], size: (0, 0), - imaginate_node_path, } .into() }; diff --git a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs index 6000d5ac1..ae4e5c6e5 100644 --- a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs +++ b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs @@ -714,7 +714,6 @@ impl MessageHandler layer_path: Vec::new(), input_image_data: vec![], size: (0, 0), - imaginate_node_path: None, }), NodeGraphMessage::SelectNodes { nodes } => { self.selected_nodes = nodes; diff --git a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs index 20ca218f4..f9dd10d39 100644 --- a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs +++ b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs @@ -984,7 +984,6 @@ pub fn node_section_font(document_node: &DocumentNode, node_id: NodeId, _context pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, context: &mut NodePropertiesContext) -> Vec { let imaginate_node = [context.nested_path, &[node_id]].concat(); - let layer_path = context.layer_path.to_vec(); let resolve_input = |name: &str| { super::IMAGINATE_NODE @@ -1140,16 +1139,11 @@ pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, conte TextButton::new("Generate") .tooltip("Fill layer frame by generating a new image") .on_update({ - let imaginate_node = imaginate_node.clone(); let layer_path = context.layer_path.to_vec(); let controller = controller.clone(); move |_| { controller.trigger_regenerate(); - DocumentMessage::ImaginateGenerate { - layer_path: layer_path.clone(), - imaginate_node: imaginate_node.clone(), - } - .into() + DocumentMessage::ImaginateGenerate { layer_path: layer_path.clone() }.into() } }) .widget_holder(), diff --git a/editor/src/messages/portfolio/portfolio_message.rs b/editor/src/messages/portfolio/portfolio_message.rs index ee6bd7532..ed7840c96 100644 --- a/editor/src/messages/portfolio/portfolio_message.rs +++ b/editor/src/messages/portfolio/portfolio_message.rs @@ -100,7 +100,6 @@ pub enum PortfolioMessage { layer_path: Vec, input_image_data: Vec, size: (u32, u32), - imaginate_node_path: Option>, }, SelectDocument { document_id: u64, diff --git a/editor/src/messages/portfolio/portfolio_message_handler.rs b/editor/src/messages/portfolio/portfolio_message_handler.rs index b90f25847..b62c67747 100644 --- a/editor/src/messages/portfolio/portfolio_message_handler.rs +++ b/editor/src/messages/portfolio/portfolio_message_handler.rs @@ -417,7 +417,6 @@ impl MessageHandler { let result = self.executor.submit_node_graph_evaluation( (document_id, &mut self.documents), diff --git a/frontend/src/state-providers/portfolio.ts b/frontend/src/state-providers/portfolio.ts index bf9470444..4986e3a54 100644 --- a/frontend/src/state-providers/portfolio.ts +++ b/frontend/src/state-providers/portfolio.ts @@ -1,6 +1,6 @@ /* eslint-disable max-classes-per-file */ -import {writable} from "svelte/store"; +import { writable } from "svelte/store"; import { downloadFileText, downloadFileBlob, upload, downloadFileURL } from "@graphite/utility-functions/files"; import { extractPixelData, imageToPNG, rasterizeSVG, rasterizeSVGCanvas } from "@graphite/utility-functions/rasterization"; @@ -98,7 +98,7 @@ export function createPortfolioState(editor: Editor) { }); }); editor.subscriptions.subscribeJsMessage(TriggerRasterizeRegionBelowLayer, async (triggerRasterizeRegionBelowLayer) => { - const { documentId, layerPath, svg, size, imaginateNodePath } = triggerRasterizeRegionBelowLayer; + const { documentId, layerPath, svg, size } = triggerRasterizeRegionBelowLayer; // Rasterize the SVG to an image file try { @@ -106,7 +106,7 @@ export function createPortfolioState(editor: Editor) { const imageData = (await rasterizeSVGCanvas(svg, size[0], size[1])).getContext("2d")?.getImageData(0, 0, size[0], size[1]); if (!imageData) return; - editor.instance.renderGraphUsingRasterizedRegionBelowLayer(documentId, layerPath, new Uint8Array(imageData.data), imageData.width, imageData.height, imaginateNodePath); + editor.instance.renderGraphUsingRasterizedRegionBelowLayer(documentId, layerPath, new Uint8Array(imageData.data), imageData.width, imageData.height); } } // getImageData may throw an exception if the resolution is too high diff --git a/frontend/src/wasm-communication/messages.ts b/frontend/src/wasm-communication/messages.ts index 419284334..90a0dda52 100644 --- a/frontend/src/wasm-communication/messages.ts +++ b/frontend/src/wasm-communication/messages.ts @@ -545,8 +545,6 @@ export class TriggerRasterizeRegionBelowLayer extends JsMessage { readonly svg!: string; readonly size!: [number, number]; - - readonly imaginateNodePath!: BigUint64Array | undefined; } export class TriggerRefreshBoundsOfViewports extends JsMessage { } @@ -700,7 +698,7 @@ export class ImaginateImageData { readonly imageData!: Uint8Array; - readonly transform!: Float64Array ; + readonly transform!: Float64Array; } export class DisplayDialogDismiss extends JsMessage { } diff --git a/frontend/wasm/src/editor_api.rs b/frontend/wasm/src/editor_api.rs index be2d87a34..1adc3bc1e 100644 --- a/frontend/wasm/src/editor_api.rs +++ b/frontend/wasm/src/editor_api.rs @@ -588,21 +588,12 @@ impl JsEditorHandle { /// Sends the blob URL generated by JS to the Imaginate layer in the respective document #[wasm_bindgen(js_name = renderGraphUsingRasterizedRegionBelowLayer)] - pub fn render_graph_using_rasterized_region_below_layer( - &self, - document_id: u64, - layer_path: Vec, - input_image_data: Vec, - width: u32, - height: u32, - imaginate_node_path: Option>, - ) { + pub fn render_graph_using_rasterized_region_below_layer(&self, document_id: u64, layer_path: Vec, input_image_data: Vec, width: u32, height: u32) { let message = PortfolioMessage::RenderGraphUsingRasterizedRegionBelowLayer { document_id, layer_path, input_image_data, size: (width, height), - imaginate_node_path, }; self.dispatch(message); } diff --git a/node-graph/graph-craft/src/document.rs b/node-graph/graph-craft/src/document.rs index 4a320a2e2..bb0d32982 100644 --- a/node-graph/graph-craft/src/document.rs +++ b/node-graph/graph-craft/src/document.rs @@ -675,7 +675,7 @@ impl NodeNetwork { let mut dummy_input = NodeInput::ShortCircut(concrete!(())); std::mem::swap(&mut dummy_input, input); - if let NodeInput::Value { mut tagged_value, exposed } = dummy_input { + if let NodeInput::Value { tagged_value, exposed } = dummy_input { let value_node_id = gen_id(); let merged_node_id = map_ids(id, value_node_id); let path = if let Some(mut new_path) = node.path.clone() { diff --git a/node-graph/gstd/src/any.rs b/node-graph/gstd/src/any.rs index e0f0c9e6f..8e2b776da 100644 --- a/node-graph/gstd/src/any.rs +++ b/node-graph/gstd/src/any.rs @@ -11,19 +11,17 @@ pub struct DynAnyNode { _o: PhantomData, } -impl<'input, _I: 'input + StaticType, _O: 'input + StaticType, N: 'input, S0: 'input> Node<'input, Any<'input>> for DynAnyNode<_I, _O, S0> +impl<'input, _I: 'input + StaticType, _O: 'input + StaticType, N: 'input> Node<'input, Any<'input>> for DynAnyNode<_I, _O, N> where - N: for<'any_input> Node<'any_input, _I, Output = DynFuture<'any_input, _O>>, - S0: for<'any_input> Node<'any_input, (), Output = &'any_input N>, + N: Node<'input, _I, Output = DynFuture<'input, _O>>, { type Output = FutureAny<'input>; #[inline] fn eval(&'input self, input: Any<'input>) -> Self::Output { - let node = self.node.eval(()); let node_name = core::any::type_name::(); let input: Box<_I> = dyn_any::downcast(input).unwrap_or_else(|e| panic!("DynAnyNode Input, {0} in:\n{1}", e, node_name)); let output = async move { - let result = node.eval(*input).await; + let result = self.node.eval(*input).await; Box::new(result) as Any<'input> }; Box::pin(output) @@ -34,14 +32,14 @@ where } fn serialize(&self) -> Option> { - self.node.eval(()).serialize() + self.node.serialize() } } -impl<'input, _I: StaticType, _O: StaticType, N, S0: 'input> DynAnyNode<_I, _O, S0> +impl<'input, _I: 'input + StaticType, _O: 'input + StaticType, N: 'input> DynAnyNode<_I, _O, N> where - S0: for<'any_input> Node<'any_input, (), Output = &'any_input N>, + N: Node<'input, _I, Output = DynFuture<'input, _O>>, { - pub const fn new(node: S0) -> Self { + pub const fn new(node: N) -> Self { Self { node, _i: core::marker::PhantomData, @@ -271,7 +269,7 @@ mod test { pub fn dyn_input_invalid_eval_panic() { //let add = DynAnyNode::new(AddNode::new()).into_type_erased(); //add.eval(Box::new(&("32", 32u32))); - let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(ValueNode::new(FutureWrapperNode { node: AddNode::new() })); + let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(FutureWrapperNode { node: AddNode::new() }); let type_erased = Box::new(dyn_any) as TypeErasedBox; let _ref_type_erased = type_erased.as_ref(); //let type_erased = Box::pin(dyn_any) as TypeErasedBox<'_>; @@ -282,11 +280,11 @@ mod test { pub fn dyn_input_compose() { //let add = DynAnyNode::new(AddNode::new()).into_type_erased(); //add.eval(Box::new(&("32", 32u32))); - let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(ValueNode::new(FutureWrapperNode { node: AddNode::new() })); + let dyn_any = DynAnyNode::<(u32, u32), u32, _>::new(FutureWrapperNode { node: AddNode::new() }); let type_erased = Box::new(dyn_any) as TypeErasedBox<'_>; type_erased.eval(Box::new((4u32, 2u32))); let id_node = FutureWrapperNode::new(IdNode::new()); - let any_id = DynAnyNode::::new(ValueNode::new(id_node)); + let any_id = DynAnyNode::::new(id_node); let type_erased_id = Box::new(any_id) as TypeErasedBox; let type_erased = ComposeTypeErased::new(NodeContainer::new(type_erased), NodeContainer::new(type_erased_id)); type_erased.eval(Box::new((4u32, 2u32))); diff --git a/node-graph/gstd/src/raster.rs b/node-graph/gstd/src/raster.rs index 7a9ac2b57..cdc482701 100644 --- a/node-graph/gstd/src/raster.rs +++ b/node-graph/gstd/src/raster.rs @@ -224,11 +224,12 @@ pub struct BlendImageNode { } #[node_macro::node_fn(BlendImageNode<_P>)] -async fn blend_image_node<_P: Alpha + Pixel + Debug, MapFn, Forground: Sample + Transform>(foreground: Forground, background: ImageFrame<_P>, map_fn: &'input MapFn) -> ImageFrame<_P> -where - MapFn: for<'any_input> Node<'any_input, (_P, _P), Output = _P> + 'input, -{ - blend_new_image(foreground, background, map_fn) +async fn blend_image_node<_P: Alpha + Pixel + Debug, Forground: Sample + Transform>( + foreground: Forground, + background: ImageFrame<_P>, + map_fn: impl Node<(_P, _P), Output = _P>, +) -> ImageFrame<_P> { + blend_new_image(foreground, background, &self.map_fn) } #[derive(Debug, Clone, Copy)] @@ -246,9 +247,9 @@ where blend_new_image(background, foreground, map_fn) } -fn blend_new_image<_P: Alpha + Pixel + Debug, MapFn, Frame: Sample + Transform>(foreground: Frame, background: ImageFrame<_P>, map_fn: &MapFn) -> ImageFrame<_P> +fn blend_new_image<'input, _P: Alpha + Pixel + Debug, MapFn, Frame: Sample + Transform>(foreground: Frame, background: ImageFrame<_P>, map_fn: &'input MapFn) -> ImageFrame<_P> where - MapFn: for<'any_input> Node<'any_input, (_P, _P), Output = _P>, + MapFn: Node<'input, (_P, _P), Output = _P>, { let foreground_aabb = Bbox::unit().affine_transform(foreground.transform()).to_axis_aligned_bbox(); let background_aabb = Bbox::unit().affine_transform(background.transform()).to_axis_aligned_bbox(); @@ -275,13 +276,13 @@ where blend_image(foreground, new_background, map_fn) } -fn blend_image<_P: Alpha + Pixel + Debug, MapFn, Frame: Sample + Transform, Background: RasterMut + Transform + Sample>( +fn blend_image<'input, _P: Alpha + Pixel + Debug, MapFn, Frame: Sample + Transform, Background: RasterMut + Transform + Sample>( foreground: Frame, background: Background, - map_fn: &MapFn, + map_fn: &'input MapFn, ) -> Background where - MapFn: for<'any_input> Node<'any_input, (_P, _P), Output = _P>, + MapFn: Node<'input, (_P, _P), Output = _P>, { blend_image_closure(foreground, background, |a, b| map_fn.eval((a, b))) } diff --git a/node-graph/interpreted-executor/src/node_registry.rs b/node-graph/interpreted-executor/src/node_registry.rs index b39a62362..d59e8d128 100644 --- a/node-graph/interpreted-executor/src/node_registry.rs +++ b/node-graph/interpreted-executor/src/node_registry.rs @@ -55,7 +55,7 @@ macro_rules! register_node { Box::pin(async move { let node = construct_node!(args, $path, [$($type),*]).await; let node = graphene_std::any::FutureWrapperNode::new(node); - let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); + let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(node); Box::new(any) as TypeErasedBox }) }, @@ -83,7 +83,7 @@ macro_rules! async_node { Box::pin(async move { args.reverse(); let node = <$path>::new($(graphene_std::any::input_node::<$type>(args.pop().expect("Not enough arguments provided to construct node"))),*); - let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); + let any: DynAnyNode<$input, _, _> = graphene_std::any::DynAnyNode::new(node); Box::new(any) as TypeErasedBox }) }, @@ -122,7 +122,7 @@ macro_rules! raster_node { Box::pin(async move { let node = construct_node!(args, $path, [$($type),*]).await; let node = graphene_std::any::FutureWrapperNode::new(node); - let any: DynAnyNode = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); + let any: DynAnyNode = graphene_std::any::DynAnyNode::new(node); any.into_type_erased() }) }, @@ -138,7 +138,7 @@ macro_rules! raster_node { let node = construct_node!(args, $path, [$($type),*]).await; let map_node = graphene_std::raster::MapImageNode::new(graphene_core::value::ValueNode::new(node)); let map_node = graphene_std::any::FutureWrapperNode::new(map_node); - let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(map_node)); + let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(map_node); any.into_type_erased() }) }, @@ -154,7 +154,7 @@ macro_rules! raster_node { let node = construct_node!(args, $path, [$($type),*]).await; let map_node = graphene_std::raster::MapImageNode::new(graphene_core::value::ValueNode::new(node)); let map_node = graphene_std::any::FutureWrapperNode::new(map_node); - let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(map_node)); + let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(map_node); any.into_type_erased() }) }, @@ -238,7 +238,7 @@ fn node_registry() -> HashMap = graphene_std::any::DynAnyNode::new(ValueNode::new(final_image)); + let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(final_image); any.into_type_erased() }) }, @@ -304,7 +304,7 @@ fn node_registry() -> HashMap = DowncastBothNode::new(args[1].clone()); //let document_node = ClonedNode::new(document_node.eval(())); let node = graphene_std::gpu_nodes::MapGpuNode::new(document_node, editor_api); - let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); + let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(node); any.into_type_erased() }) }, @@ -323,7 +323,7 @@ fn node_registry() -> HashMap = DowncastBothNode::new(args[1].clone()); let opacity: DowncastBothNode<(), f32> = DowncastBothNode::new(args[2].clone()); let node = graphene_std::gpu_nodes::BlendGpuImageNode::new(background, blend_mode, opacity); - let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); + let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(node); any.into_type_erased() }) @@ -366,8 +366,8 @@ fn node_registry() -> HashMap = DowncastBothNode::new(args[1].clone()); let opacity: DowncastBothNode<(), f64> = DowncastBothNode::new(args[2].clone()); let blend_node = graphene_core::raster::BlendNode::new(CopiedNode::new(blend_mode.eval(()).await), CopiedNode::new(opacity.eval(()).await)); - let node = graphene_std::raster::BlendImageNode::new(image, FutureWrapperNode::new(ValueNode::new(blend_node))); - let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); + let node = graphene_std::raster::BlendImageNode::new(image, blend_node); + let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(node); any.into_type_erased() }) }, @@ -406,13 +406,13 @@ fn node_registry() -> HashMap, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node)); + let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(map_image_frame_node); any.into_type_erased() } else { let generate_brightness_contrast_mapper_node = GenerateBrightnessContrastMapperNode::new(brightness, contrast); let map_image_frame_node = graphene_std::raster::MapImageNode::new(ValueNode::new(generate_brightness_contrast_mapper_node.eval(()))); let map_image_frame_node = FutureWrapperNode::new(map_image_frame_node); - let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(ValueNode::new(map_image_frame_node)); + let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(map_image_frame_node); any.into_type_erased() } }) @@ -452,7 +452,8 @@ fn node_registry() -> HashMap, WasmEditorApi> = graphene_std::any::DowncastBothNode::new(args[0].clone()); let node = >::new(node); - let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); + let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(node); + any.into_type_erased() }) }, @@ -468,7 +469,7 @@ fn node_registry() -> HashMap { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* cache.into_inner()) }; } let node: ImaginateNode = instanciate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,); - let any = graphene_std::any::DynAnyNode::new(ValueNode::new(node)); + let any = graphene_std::any::DynAnyNode::new(node); any.into_type_erased() }) }, diff --git a/node-graph/node-macro/src/lib.rs b/node-graph/node-macro/src/lib.rs index 6403d6928..855e4e87f 100644 --- a/node-graph/node-macro/src/lib.rs +++ b/node-graph/node-macro/src/lib.rs @@ -2,8 +2,8 @@ use proc_macro::TokenStream; use proc_macro2::Span; use quote::{format_ident, quote, ToTokens}; use syn::{ - parse_macro_input, punctuated::Punctuated, token::Comma, FnArg, GenericParam, Ident, ItemFn, Lifetime, Pat, PatIdent, PathArguments, PredicateType, ReturnType, Token, TraitBound, Type, TypeParam, - TypeParamBound, WhereClause, WherePredicate, + parse_macro_input, punctuated::Punctuated, token::Comma, AngleBracketedGenericArguments, Binding, FnArg, GenericArgument, GenericParam, Ident, ItemFn, Lifetime, Pat, PatIdent, PathArguments, + PredicateType, ReturnType, Token, TraitBound, Type, TypeImplTrait, TypeParam, TypeParamBound, TypeTuple, WhereClause, WherePredicate, }; #[proc_macro_attribute] @@ -30,7 +30,7 @@ fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let node = &node; let node_name = &node.ident; - let mut args = args(node); + let mut args = node_args(node); let arg_idents = args .iter() @@ -38,7 +38,7 @@ fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream { .map(|arg| Ident::new(arg.to_token_stream().to_string().to_lowercase().as_str(), Span::call_site())) .collect::>(); - let (_, _, parameter_pat_ident_patterns) = parse_inputs(&function); + let (_, _, parameter_pat_ident_patterns) = parse_inputs(&function, false); let parameter_idents = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.ident).collect::>(); // Extract the output type of the entire node - `()` by default @@ -69,7 +69,7 @@ fn node_new_impl(attr: TokenStream, item: TokenStream) -> TokenStream { .into() } -fn args(node: &syn::PathSegment) -> Vec { +fn node_args(node: &syn::PathSegment) -> Vec { match node.arguments.clone() { PathArguments::AngleBracketed(args) => args .args @@ -105,7 +105,7 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> let node = &node; let node_name = &node.ident; - let mut args = args(node); + let mut args = node_args(node); let async_out = match asyncness { Asyncness::Sync => false, @@ -126,13 +126,11 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> } }); - let (primary_input, parameter_inputs, parameter_pat_ident_patterns) = parse_inputs(&function); + let (primary_input, parameter_inputs, parameter_pat_ident_patterns) = parse_inputs(&function, true); let primary_input_ty = &primary_input.ty; let Pat::Ident(PatIdent{ident: primary_input_ident, mutability: primary_input_mutability,..} ) =&*primary_input.pat else { panic!("Expected ident as primary input."); }; - let parameter_idents = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.ident).collect::>(); - let parameter_mutability = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.mutability); // Extract the output type of the entire node - `()` by default let output = if let ReturnType::Type(_, ty) = &function.sig.output { @@ -141,10 +139,18 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> quote::quote!(()) }; - let struct_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("S{x}")).collect::>(); - let future_generics = (0..parameter_pat_ident_patterns.len()).map(|x| format_ident!("F{x}")).collect::>(); - let future_types = future_generics.iter().map(|x| Type::Verbatim(x.to_token_stream())).collect::>(); + let num_inputs = parameter_inputs.len(); + let struct_generics = (0..num_inputs).map(|x| format_ident!("S{x}")).collect::>(); + let future_generics = (0..num_inputs).map(|x| format_ident!("F{x}")).collect::>(); let parameter_types = parameter_inputs.iter().map(|x| *x.ty.clone()).collect::>(); + let future_types = future_generics + .iter() + .enumerate() + .map(|(i, x)| match parameter_types[i].clone() { + Type::ImplTrait(x) => Type::ImplTrait(x), + _ => Type::Verbatim(x.to_token_stream()), + }) + .collect::>(); for ident in struct_generics.iter() { args.push(Type::Verbatim(quote::quote!(#ident))); @@ -153,6 +159,7 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> // Generics are simply `S0` through to `Sn-1` where n is the number of secondary inputs let node_generics = construct_node_generics(&struct_generics); let future_generic_params = construct_node_generics(&future_generics); + let (future_parameter_types, future_generic_params): (Vec<_>, Vec<_>) = parameter_types.iter().cloned().zip(future_generic_params).filter(|(ty, _)| !matches!(ty, Type::ImplTrait(_))).unzip(); let generics = if async_in { type_generics @@ -166,12 +173,12 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> // Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function let node_bounds = if async_in { - let mut node_bounds = input_node_bounds(future_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>}); - let future_bounds = input_node_bounds(parameter_types, future_generic_params, |ty| quote! { core::future::Future}); + let mut node_bounds = input_node_bounds(future_types, node_generics, |lifetime, in_ty, out_ty| quote! {Node<#lifetime, #in_ty, Output = #out_ty>}); + let future_bounds = input_node_bounds(future_parameter_types, future_generic_params, |_, _, out_ty| quote! { core::future::Future}); node_bounds.extend(future_bounds); node_bounds } else { - input_node_bounds(parameter_types, node_generics, |ty| quote! {Node<'input, (), Output = #ty>}) + input_node_bounds(parameter_types, node_generics, |lifetime, in_ty, out_ty| quote! {Node<#lifetime, #in_ty, Output = #out_ty>}) }; where_clause.predicates.extend(node_bounds); @@ -181,6 +188,9 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> quote::quote!(#output) }; + let parameter_idents = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.ident).collect::>(); + let parameter_mutability = parameter_pat_ident_patterns.iter().map(|pat_ident| &pat_ident.mutability); + let parameters = if matches!(asyncness, Asyncness::AllAsync) { quote::quote!(#(let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(()).await;)*) } else { @@ -209,7 +219,7 @@ fn node_impl_impl(attr: TokenStream, item: TokenStream, asyncness: Asyncness) -> .into() } -fn parse_inputs(function: &ItemFn) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&PatIdent>) { +fn parse_inputs(function: &ItemFn, remove_impl_node: bool) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&PatIdent>) { let mut function_inputs = function.sig.inputs.iter().filter_map(|arg| if let FnArg::Typed(typed_arg) = arg { Some(typed_arg) } else { None }); // Extract primary input as first argument @@ -217,8 +227,10 @@ fn parse_inputs(function: &ItemFn) -> (&syn::PatType, Vec<&syn::PatType>, Vec<&P // Extract secondary inputs as all other arguments let parameter_inputs = function_inputs.collect::>(); + let parameter_pat_ident_patterns = parameter_inputs .iter() + .filter(|input| !matches!(&*input.ty, Type::ImplTrait(_)) || !remove_impl_node) .map(|input| { let Pat::Ident(pat_ident) = &*input.pat else { panic!("Expected ident for secondary input."); }; pat_ident @@ -244,14 +256,43 @@ fn construct_node_generics(struct_generics: &[Ident]) -> Vec { .collect() } -fn input_node_bounds(parameter_inputs: Vec, node_generics: Vec, trait_bound: impl Fn(Type) -> proc_macro2::TokenStream) -> Vec { +fn input_node_bounds(parameter_inputs: Vec, node_generics: Vec, trait_bound: impl Fn(Lifetime, Type, Type) -> proc_macro2::TokenStream) -> Vec { parameter_inputs .iter() .zip(&node_generics) .map(|(ty, name)| { let GenericParam::Type(generic_ty) = name else { panic!("Expected type generic."); }; let ident = &generic_ty.ident; - let bound = trait_bound(ty.clone()); + let (lifetime, in_ty, out_ty) = match ty.clone() { + Type::ImplTrait(TypeImplTrait { bounds, .. }) if bounds.len() == 1 => { + let TypeParamBound::Trait(TraitBound { ref path, .. }) = bounds[0] else {panic!("impl Traits other then Node are not supported")}; + let node_segment = path.segments.last().expect("Found an empty path in the impl Trait arg"); + assert_eq!(node_segment.ident.to_string(), "Node", "Only impl Node is supported as an argument"); + let PathArguments::AngleBracketed(AngleBracketedGenericArguments {ref args, .. }) = node_segment.arguments else { panic!("Node must have generic arguments")}; + let mut args_iter = args.iter(); + let lifetime = if args.len() == 2 { + Lifetime::new("'input", Span::call_site()) + } else if let Some(GenericArgument::Lifetime(node_lifetime)) = args_iter.next() { + node_lifetime.clone() + } else { + panic!("Invalid arguments for Node trait") + }; + + let Some(GenericArgument::Type(in_ty)) = args_iter.next() else { panic!("Expected type argument in Node<> declaration")}; + let Some(GenericArgument::Binding(Binding {ty: out_ty, ..} )) = args_iter.next() else { panic!("Expected Output = in Node declaration")}; + (lifetime, in_ty.clone(), out_ty.clone()) + } + ty => ( + Lifetime::new("'input", Span::call_site()), + Type::Tuple(TypeTuple { + paren_token: syn::token::Paren { span: Span::call_site() }, + elems: Punctuated::new(), + }), + ty, + ), + }; + + let bound = trait_bound(lifetime, in_ty, out_ty); WherePredicate::Type(PredicateType { lifetimes: None, bounded_ty: Type::Verbatim(ident.to_token_stream()),