From a40a760f27e1c247614badd3f18f5f04198334dc Mon Sep 17 00:00:00 2001 From: Dennis Kobert Date: Fri, 27 Jun 2025 01:10:14 +0200 Subject: [PATCH] Add automatic type conversion and the node graph preprocessor (#2478) * Prototype document network level into node insertion * Implement Convert trait / node for places we can't use Into * Add isize/usize and i128/u128 implementations for Convert trait * Factor out substitutions into preprocessor crate * Simplify layer node further * Code review * Mark preprocessed networks as generated * Revert changes to layer node definition * Skip generated flag for serialization * Don't expand for tests * Code review --------- Co-authored-by: Keavon Chambers --- Cargo.lock | 19 ++ Cargo.toml | 3 +- editor/Cargo.toml | 6 +- .../node_graph/document_node_definitions.rs | 114 +----------- .../document_node_derive.rs | 94 ++++++++++ .../utility_types/network_interface.rs | 6 + editor/src/node_graph_executor/runtime.rs | 11 +- node-graph/gcore/src/ops.rs | 70 ++++++++ node-graph/graph-craft/src/document.rs | 6 +- node-graph/graphene-cli/Cargo.toml | 7 +- node-graph/graphene-cli/src/main.rs | 4 + .../interpreted-executor/src/node_registry.rs | 77 +++++++- node-graph/interpreted-executor/src/util.rs | 2 + node-graph/preprocessor/Cargo.toml | 29 ++++ node-graph/preprocessor/src/lib.rs | 164 ++++++++++++++++++ 15 files changed, 484 insertions(+), 128 deletions(-) create mode 100644 editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs create mode 100644 node-graph/preprocessor/Cargo.toml create mode 100644 node-graph/preprocessor/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index e98a3270c..cc89fe8b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2175,6 +2175,7 @@ dependencies = [ "graphene-std", "interpreted-executor", "log", + "preprocessor", "tokio", "wgpu", "wgpu-executor", @@ -2282,6 +2283,7 @@ dependencies = [ "log", "num_enum", "once_cell", + "preprocessor", "ron", "serde", "serde_json", @@ -4475,6 +4477,23 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "preprocessor" +version = "0.1.0" +dependencies = [ + "base64 0.22.1", + "dyn-any", + "futures", + "glam", + "graph-craft", + "graphene-std", + "interpreted-executor", + "log", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "presser" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index 5d3730d0a..cb81f0d14 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ members = [ "node-graph/graphene-cli", "node-graph/interpreted-executor", "node-graph/node-macro", - "node-graph/wgpu-executor", + "node-graph/preprocessor", "libraries/dyn-any", "libraries/path-bool", "libraries/bezier-rs", @@ -34,6 +34,7 @@ resolver = "2" # Local dependencies bezier-rs = { path = "libraries/bezier-rs", features = ["dyn-any", "serde"] } dyn-any = { path = "libraries/dyn-any", features = ["derive", "glam", "reqwest", "log-bad-types", "rc"] } +preprocessor = { path = "node-graph/preprocessor"} math-parser = { path = "libraries/math-parser" } path-bool = { path = "libraries/path-bool" } graphene-application-io = { path = "node-graph/gapplication-io" } diff --git a/editor/Cargo.toml b/editor/Cargo.toml index 31e54a7a1..ffc95a927 100644 --- a/editor/Cargo.toml +++ b/editor/Cargo.toml @@ -13,10 +13,7 @@ license = "Apache-2.0" [features] default = ["wasm"] wasm = ["wasm-bindgen", "graphene-std/wasm", "wasm-bindgen-futures"] -gpu = [ - "interpreted-executor/gpu", - "wgpu-executor", -] +gpu = ["interpreted-executor/gpu", "wgpu-executor"] tauri = ["ron", "decouple-execution"] decouple-execution = [] resvg = ["graphene-std/resvg"] @@ -29,6 +26,7 @@ graphite-proc-macros = { workspace = true } graph-craft = { workspace = true } interpreted-executor = { workspace = true } graphene-std = { workspace = true } +preprocessor = { workspace = true } # Workspace dependencies js-sys = { workspace = true } diff --git a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs index d0be22688..ea7fefff2 100644 --- a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs +++ b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs @@ -1,3 +1,5 @@ +mod document_node_derive; + use super::node_properties::choice::enum_choice; use super::node_properties::{self, ParameterWidgetsInfo}; use super::utility_types::FrontendNodeType; @@ -91,7 +93,7 @@ static DOCUMENT_NODE_TYPES: once_cell::sync::Lazy> = /// Defines the "signature" or "header file"-like metadata for the document nodes, but not the implementation (which is defined in the node registry). /// The [`DocumentNode`] is the instance while these [`DocumentNodeDefinition`]s are the "classes" or "blueprints" from which the instances are built. fn static_nodes() -> Vec { - let mut custom = vec![ + let custom = vec![ // TODO: Auto-generate this from its proto node macro DocumentNodeDefinition { identifier: "Identity", @@ -241,21 +243,21 @@ fn static_nodes() -> Vec { DocumentNode { inputs: vec![NodeInput::network(generic!(T), 1)], implementation: DocumentNodeImplementation::proto("graphene_core::graphic_element::ToElementNode"), - manual_composition: Some(generic!(T)), + manual_composition: Some(concrete!(Context)), ..Default::default() }, // Primary (bottom) input type coercion DocumentNode { inputs: vec![NodeInput::network(generic!(T), 0)], implementation: DocumentNodeImplementation::proto("graphene_core::graphic_element::ToGroupNode"), - manual_composition: Some(generic!(T)), + manual_composition: Some(concrete!(Context)), ..Default::default() }, // The monitor node is used to display a thumbnail in the UI DocumentNode { inputs: vec![NodeInput::node(NodeId(0), 0)], implementation: DocumentNodeImplementation::proto("graphene_core::memo::MonitorNode"), - manual_composition: Some(generic!(T)), + manual_composition: Some(concrete!(Context)), skip_deduplication: true, ..Default::default() }, @@ -2114,109 +2116,7 @@ fn static_nodes() -> Vec { }, ]; - // Remove struct generics - for DocumentNodeDefinition { node_template, .. } in custom.iter_mut() { - let NodeTemplate { - document_node: DocumentNode { implementation, .. }, - .. - } = node_template; - if let DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) = implementation { - if let Some((new_name, _suffix)) = name.rsplit_once("<") { - *name = Cow::Owned(new_name.to_string()) - } - }; - } - let node_registry = graphene_std::registry::NODE_REGISTRY.lock().unwrap(); - 'outer: for (id, metadata) in graphene_std::registry::NODE_METADATA.lock().unwrap().iter() { - use graphene_std::registry::*; - let id = id.clone(); - - for node in custom.iter() { - let DocumentNodeDefinition { - node_template: NodeTemplate { - document_node: DocumentNode { implementation, .. }, - .. - }, - .. - } = node; - match implementation { - DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == &id => continue 'outer, - _ => (), - } - } - - let NodeMetadata { - display_name, - category, - fields, - description, - properties, - } = metadata; - let Some(implementations) = &node_registry.get(&id) else { continue }; - let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); - let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); - let mut input_type = &first_node_io.call_argument; - if valid_inputs.len() > 1 { - input_type = &const { generic!(D) }; - } - let output_type = &first_node_io.return_value; - - let inputs = fields - .iter() - .zip(first_node_io.inputs.iter()) - .enumerate() - .map(|(index, (field, node_io_ty))| { - let ty = field.default_type.as_ref().unwrap_or(node_io_ty); - let exposed = if index == 0 { *ty != fn_type_fut!(Context, ()) } else { field.exposed }; - - match field.value_source { - RegistryValueSource::None => {} - RegistryValueSource::Default(data) => return NodeInput::value(TaggedValue::from_primitive_string(data, ty).unwrap_or(TaggedValue::None), exposed), - RegistryValueSource::Scope(data) => return NodeInput::scope(Cow::Borrowed(data)), - }; - - if let Some(type_default) = TaggedValue::from_type(ty) { - return NodeInput::value(type_default, exposed); - } - NodeInput::value(TaggedValue::None, true) - }) - .collect(); - - let node = DocumentNodeDefinition { - identifier: display_name, - node_template: NodeTemplate { - document_node: DocumentNode { - inputs, - manual_composition: Some(input_type.clone()), - implementation: DocumentNodeImplementation::ProtoNode(id.clone().into()), - visible: true, - skip_deduplication: false, - ..Default::default() - }, - persistent_node_metadata: DocumentNodePersistentMetadata { - // TODO: Store information for input overrides in the node macro - input_properties: fields - .iter() - .map(|f| match f.widget_override { - RegistryWidgetOverride::None => (f.name, f.description).into(), - RegistryWidgetOverride::Hidden => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Hidden), - RegistryWidgetOverride::String(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::String(str.to_string())), - RegistryWidgetOverride::Custom(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Custom(str.to_string())), - }) - .collect(), - output_names: vec![output_type.to_string()], - has_primary_output: true, - locked: false, - ..Default::default() - }, - }, - category: category.unwrap_or("UNCATEGORIZED"), - description: Cow::Borrowed(description), - properties: *properties, - }; - custom.push(node); - } - custom + document_node_derive::post_process_nodes(custom) } // pub static IMAGINATE_NODE: Lazy = Lazy::new(|| DocumentNodeDefinition { diff --git a/editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs b/editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs new file mode 100644 index 000000000..1339621dc --- /dev/null +++ b/editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs @@ -0,0 +1,94 @@ +use super::DocumentNodeDefinition; +use crate::messages::portfolio::document::utility_types::network_interface::{DocumentNodePersistentMetadata, NodeTemplate, PropertiesRow, WidgetOverride}; +use graph_craft::ProtoNodeIdentifier; +use graph_craft::document::*; +use graphene_std::registry::*; +use graphene_std::*; +use std::collections::HashSet; + +pub(super) fn post_process_nodes(mut custom: Vec) -> Vec { + // Remove struct generics + for DocumentNodeDefinition { node_template, .. } in custom.iter_mut() { + let NodeTemplate { + document_node: DocumentNode { implementation, .. }, + .. + } = node_template; + + if let DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) = implementation { + if let Some((new_name, _suffix)) = name.rsplit_once("<") { + *name = Cow::Owned(new_name.to_string()) + } + }; + } + + let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap(); + 'outer: for (id, metadata) in NODE_METADATA.lock().unwrap().iter() { + for node in custom.iter() { + let DocumentNodeDefinition { + node_template: NodeTemplate { + document_node: DocumentNode { implementation, .. }, + .. + }, + .. + } = node; + match implementation { + DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == id => continue 'outer, + _ => (), + } + } + + let NodeMetadata { + display_name, + category, + fields, + description, + properties, + } = metadata; + + let Some(implementations) = &node_registry.get(id) else { continue }; + + let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); + let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); + + let input_type = if valid_inputs.len() > 1 { &const { generic!(D) } } else { &first_node_io.call_argument }; + let output_type = &first_node_io.return_value; + + let inputs = preprocessor::node_inputs(fields, first_node_io); + let node = DocumentNodeDefinition { + identifier: display_name, + node_template: NodeTemplate { + document_node: DocumentNode { + inputs, + manual_composition: Some(input_type.clone()), + implementation: DocumentNodeImplementation::ProtoNode(id.clone().into()), + visible: true, + skip_deduplication: false, + ..Default::default() + }, + persistent_node_metadata: DocumentNodePersistentMetadata { + // TODO: Store information for input overrides in the node macro + input_properties: fields + .iter() + .map(|f| match f.widget_override { + RegistryWidgetOverride::None => (f.name, f.description).into(), + RegistryWidgetOverride::Hidden => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Hidden), + RegistryWidgetOverride::String(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::String(str.to_string())), + RegistryWidgetOverride::Custom(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Custom(str.to_string())), + }) + .collect(), + output_names: vec![output_type.to_string()], + has_primary_output: true, + locked: false, + ..Default::default() + }, + }, + category: category.unwrap_or("UNCATEGORIZED"), + description: Cow::Borrowed(description), + properties: *properties, + }; + + custom.push(node); + } + + custom +} diff --git a/editor/src/messages/portfolio/document/utility_types/network_interface.rs b/editor/src/messages/portfolio/document/utility_types/network_interface.rs index 1052300bd..640c7e64a 100644 --- a/editor/src/messages/portfolio/document/utility_types/network_interface.rs +++ b/editor/src/messages/portfolio/document/utility_types/network_interface.rs @@ -6515,6 +6515,12 @@ pub struct NodePersistentMetadata { position: NodePosition, } +impl NodePersistentMetadata { + pub fn new(position: NodePosition) -> Self { + Self { position } + } +} + /// A layer can either be position as Absolute or in a Stack #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub enum LayerPosition { diff --git a/editor/src/node_graph_executor/runtime.rs b/editor/src/node_graph_executor/runtime.rs index 1bd207aa1..c949e32dd 100644 --- a/editor/src/node_graph_executor/runtime.rs +++ b/editor/src/node_graph_executor/runtime.rs @@ -45,6 +45,9 @@ pub struct NodeRuntime { /// Which node is inspected and which monitor node is used (if any) for the current execution inspect_state: Option, + /// Mapping of the fully-qualified node paths to their preprocessor substitutions. + substitutions: HashMap, + // TODO: Remove, it doesn't need to be persisted anymore /// The current renders of the thumbnails for layer nodes. thumbnail_renders: HashMap>, @@ -120,6 +123,8 @@ impl NodeRuntime { node_graph_errors: Vec::new(), monitor_nodes: Vec::new(), + substitutions: preprocessor::generate_node_substitutions(), + thumbnail_renders: Default::default(), vector_modify: Default::default(), inspect_state: None, @@ -221,11 +226,15 @@ impl NodeRuntime { } } - async fn update_network(&mut self, graph: NodeNetwork) -> Result { + async fn update_network(&mut self, mut graph: NodeNetwork) -> Result { + #[cfg(not(test))] + preprocessor::expand_network(&mut graph, &self.substitutions); + let scoped_network = wrap_network_in_scope(graph, self.editor_api.clone()); // We assume only one output assert_eq!(scoped_network.exports.len(), 1, "Graph with multiple outputs not yet handled"); + let c = Compiler {}; let proto_network = match c.compile_single(scoped_network) { Ok(network) => network, diff --git a/node-graph/gcore/src/ops.rs b/node-graph/gcore/src/ops.rs index b5a06abfc..6b968ab24 100644 --- a/node-graph/gcore/src/ops.rs +++ b/node-graph/gcore/src/ops.rs @@ -571,6 +571,76 @@ where } } +/// The [`Convert`] trait allows for conversion between Rust primitive numeric types. +/// Because number casting is lossy, we cannot use the normal [`Into`] trait like we do for other types. +pub trait Convert: Sized { + /// Converts this type into the (usually inferred) output type. + #[must_use] + fn convert(self) -> T; +} + +/// Implements the [`Convert`] trait for conversion between the cartesian product of Rust's primitive numeric types. +macro_rules! impl_convert { + ($from:ty,$to:ty) => { + impl Convert<$to> for $from { + fn convert(self) -> $to { + self as $to + } + } + }; + ($to:ty) => { + impl_convert!(f32, $to); + impl_convert!(f64, $to); + impl_convert!(i8, $to); + impl_convert!(u8, $to); + impl_convert!(u16, $to); + impl_convert!(i16, $to); + impl_convert!(i32, $to); + impl_convert!(u32, $to); + impl_convert!(i64, $to); + impl_convert!(u64, $to); + impl_convert!(i128, $to); + impl_convert!(u128, $to); + impl_convert!(isize, $to); + impl_convert!(usize, $to); + }; +} +impl_convert!(f32); +impl_convert!(f64); +impl_convert!(i8); +impl_convert!(u8); +impl_convert!(u16); +impl_convert!(i16); +impl_convert!(i32); +impl_convert!(u32); +impl_convert!(i64); +impl_convert!(u64); +impl_convert!(i128); +impl_convert!(u128); +impl_convert!(isize); +impl_convert!(usize); + +// Convert +pub struct ConvertNode(PhantomData); +impl<_O> ConvertNode<_O> { + pub const fn new() -> Self { + Self(core::marker::PhantomData) + } +} +impl<_O> Default for ConvertNode<_O> { + fn default() -> Self { + Self::new() + } +} +impl<'input, I: 'input + Convert<_O> + Sync + Send, _O: 'input> Node<'input, I> for ConvertNode<_O> { + type Output = ::dyn_any::DynFuture<'input, _O>; + + #[inline] + fn eval(&'input self, input: I) -> Self::Output { + Box::pin(async move { input.convert() }) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/node-graph/graph-craft/src/document.rs b/node-graph/graph-craft/src/document.rs index 1ff4b4e35..85c15343d 100644 --- a/node-graph/graph-craft/src/document.rs +++ b/node-graph/graph-craft/src/document.rs @@ -683,6 +683,8 @@ pub struct NodeNetwork { #[serde(default)] #[serde(serialize_with = "graphene_core::vector::serialize_hashmap", deserialize_with = "graphene_core::vector::deserialize_hashmap")] pub scope_injections: FxHashMap, + #[serde(skip)] + pub generated: bool, } impl Hash for NodeNetwork { @@ -797,7 +799,9 @@ impl NodeNetwork { pub fn generate_node_paths(&mut self, prefix: &[NodeId]) { for (node_id, node) in &mut self.nodes { let mut new_path = prefix.to_vec(); - new_path.push(*node_id); + if !self.generated { + new_path.push(*node_id); + } if let DocumentNodeImplementation::Network(network) = &mut node.implementation { network.generate_node_paths(new_path.as_slice()); } diff --git a/node-graph/graphene-cli/Cargo.toml b/node-graph/graphene-cli/Cargo.toml index f3c30fe12..84bf96e7c 100644 --- a/node-graph/graphene-cli/Cargo.toml +++ b/node-graph/graphene-cli/Cargo.toml @@ -12,11 +12,7 @@ wgpu = ["wgpu-executor", "gpu", "graphene-std/wgpu"] wayland = ["graphene-std/wayland"] profiling = ["wgpu-executor/profiling"] passthrough = ["wgpu-executor/passthrough"] -gpu = [ - "interpreted-executor/gpu", - "graphene-std/gpu", - "wgpu-executor", -] +gpu = ["interpreted-executor/gpu", "graphene-std/gpu", "wgpu-executor"] [dependencies] # Local dependencies @@ -24,6 +20,7 @@ graphene-core = { workspace = true } graphene-std = { workspace = true } interpreted-executor = { workspace = true } graph-craft = { workspace = true, features = ["loading"] } +preprocessor = { workspace = true } # Workspace dependencies log = { workspace = true } diff --git a/node-graph/graphene-cli/src/main.rs b/node-graph/graphene-cli/src/main.rs index aab5ab85c..af535b736 100644 --- a/node-graph/graphene-cli/src/main.rs +++ b/node-graph/graphene-cli/src/main.rs @@ -184,7 +184,11 @@ fn compile_graph(document_string: String, editor_api: Arc) -> Res let mut network = load_network(&document_string); fix_nodes(&mut network); + let substitutions = preprocessor::generate_node_substitutions(); + preprocessor::expand_network(&mut network, &substitutions); + let wrapped_network = wrap_network_in_scope(network.clone(), editor_api); + let compiler = Compiler {}; compiler.compile_single(wrapped_network).map_err(|x| x.into()) } diff --git a/node-graph/interpreted-executor/src/node_registry.rs b/node-graph/interpreted-executor/src/node_registry.rs index 3f151c95b..4f45992c2 100644 --- a/node-graph/interpreted-executor/src/node_registry.rs +++ b/node-graph/interpreted-executor/src/node_registry.rs @@ -15,7 +15,7 @@ use graphene_std::GraphicElement; use graphene_std::any::{ComposeTypeErased, DowncastBothNode, DynAnyNode, IntoTypeErasedNode}; use graphene_std::application_io::{ImageTexture, SurfaceFrame}; use graphene_std::wasm_application_io::*; -use node_registry_macros::{async_node, into_node}; +use node_registry_macros::{async_node, convert_node, into_node}; use once_cell::sync::Lazy; use std::collections::HashMap; use std::sync::Arc; @@ -23,10 +23,7 @@ use wgpu_executor::{WgpuExecutor, WgpuSurface, WindowHandle}; // TODO: turn into hashmap fn node_registry() -> HashMap> { - let node_types: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![ - into_node!(from: f64, to: f64), - into_node!(from: u32, to: f64), - into_node!(from: u8, to: u32), + let mut node_types: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![ into_node!(from: VectorDataTable, to: VectorDataTable), into_node!(from: VectorDataTable, to: GraphicElement), into_node!(from: VectorDataTable, to: GraphicGroupTable), @@ -35,6 +32,7 @@ fn node_registry() -> HashMap, to: RasterDataTable), // into_node!(from: RasterDataTable, to: RasterDataTable), into_node!(from: RasterDataTable, to: GraphicElement), + into_node!(from: RasterDataTable, to: GraphicElement), into_node!(from: RasterDataTable, to: GraphicGroupTable), async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => RasterDataTable]), async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => ImageTexture]), @@ -137,6 +135,26 @@ fn node_registry() -> HashMap> = HashMap::new(); @@ -151,12 +169,14 @@ fn node_registry() -> HashMap { ( ProtoNodeIdentifier::new(concat!["graphene_core::ops::IntoNode<", stringify!($to), ">"]), - |mut args| { + |_| { Box::pin(async move { - args.reverse(); let node = graphene_core::ops::IntoNode::<$to>::new(); let any: DynAnyNode<$from, _, _> = graphene_std::any::DynAnyNode::new(node); Box::new(any) as TypeErasedBox @@ -220,7 +239,47 @@ mod node_registry_macros { ) }; } + macro_rules! convert_node { + (from: $from:ty, to: numbers) => {{ + let x: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![ + convert_node!(from: $from, to: f32), + convert_node!(from: $from, to: f64), + convert_node!(from: $from, to: i8), + convert_node!(from: $from, to: u8), + convert_node!(from: $from, to: u16), + convert_node!(from: $from, to: i16), + convert_node!(from: $from, to: i32), + convert_node!(from: $from, to: u32), + convert_node!(from: $from, to: i64), + convert_node!(from: $from, to: u64), + convert_node!(from: $from, to: i128), + convert_node!(from: $from, to: u128), + convert_node!(from: $from, to: isize), + convert_node!(from: $from, to: usize), + ]; + x + }}; + (from: $from:ty, to: $to:ty) => { + ( + ProtoNodeIdentifier::new(concat!["graphene_core::ops::ConvertNode<", stringify!($to), ">"]), + |_| { + Box::pin(async move { + let node = graphene_core::ops::ConvertNode::<$to>::new(); + let any: DynAnyNode<$from, _, _> = graphene_std::any::DynAnyNode::new(node); + Box::new(any) as TypeErasedBox + }) + }, + { + let node = graphene_core::ops::ConvertNode::<$to>::new(); + let mut node_io = NodeIO::<'_, $from>::to_async_node_io(&node, vec![]); + node_io.call_argument = future!(<$from as StaticType>::Static); + node_io + }, + ) + }; + } pub(crate) use async_node; + pub(crate) use convert_node; pub(crate) use into_node; } diff --git a/node-graph/interpreted-executor/src/util.rs b/node-graph/interpreted-executor/src/util.rs index 5d4e624a4..ab4c744e3 100644 --- a/node-graph/interpreted-executor/src/util.rs +++ b/node-graph/interpreted-executor/src/util.rs @@ -78,5 +78,7 @@ pub fn wrap_network_in_scope(mut network: NodeNetwork, editor_api: Arc) { + if network.generated { + return; + } + + for node in network.nodes.values_mut() { + match &mut node.implementation { + DocumentNodeImplementation::Network(node_network) => expand_network(node_network, substitutions), + DocumentNodeImplementation::ProtoNode(proto_node_identifier) => { + if let Some(new_node) = substitutions.get(proto_node_identifier.name.as_ref()) { + node.implementation = new_node.implementation.clone(); + } + } + DocumentNodeImplementation::Extract => (), + } + } +} + +pub fn generate_node_substitutions() -> HashMap { + let mut custom = HashMap::new(); + let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap(); + for (id, metadata) in graphene_core::registry::NODE_METADATA.lock().unwrap().iter() { + let id = id.clone(); + + let NodeMetadata { fields, .. } = metadata; + let Some(implementations) = &node_registry.get(&id) else { continue }; + let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); + let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); + let mut node_io_types = vec![HashSet::new(); fields.len()]; + for (_, node_io) in implementations.iter() { + for (i, ty) in node_io.inputs.iter().enumerate() { + node_io_types[i].insert(ty.clone()); + } + } + let mut input_type = &first_node_io.call_argument; + if valid_inputs.len() > 1 { + input_type = &const { generic!(D) }; + } + + let inputs: Vec<_> = node_inputs(fields, first_node_io); + let input_count = inputs.len(); + let network_inputs = (0..input_count).map(|i| NodeInput::node(NodeId(i as u64), 0)).collect(); + + let identity_node = ProtoNodeIdentifier::new("graphene_core::ops::IdentityNode"); + + let into_node_registry = &interpreted_executor::node_registry::NODE_REGISTRY; + + let mut generated_nodes = 0; + let mut nodes: HashMap<_, _, _> = node_io_types + .iter() + .enumerate() + .map(|(i, inputs)| { + ( + NodeId(i as u64), + match inputs.len() { + 1 => { + let input = inputs.iter().next().unwrap(); + let input_ty = input.nested_type(); + + let into_node_identifier = ProtoNodeIdentifier { + name: format!("graphene_core::ops::IntoNode<{}>", input_ty.clone()).into(), + }; + let convert_node_identifier = ProtoNodeIdentifier { + name: format!("graphene_core::ops::ConvertNode<{}>", input_ty.clone()).into(), + }; + + let proto_node = if into_node_registry.keys().any(|ident: &ProtoNodeIdentifier| ident.name.as_ref() == into_node_identifier.name.as_ref()) { + generated_nodes += 1; + into_node_identifier + } else if into_node_registry.keys().any(|ident| ident.name.as_ref() == convert_node_identifier.name.as_ref()) { + generated_nodes += 1; + convert_node_identifier + } else { + identity_node.clone() + }; + + DocumentNode { + inputs: vec![NodeInput::network(input.clone(), i)], + // manual_composition: Some(fn_input.clone()), + implementation: DocumentNodeImplementation::ProtoNode(proto_node), + visible: true, + ..Default::default() + } + } + _ => DocumentNode { + inputs: vec![NodeInput::network(generic!(X), i)], + implementation: DocumentNodeImplementation::ProtoNode(identity_node.clone()), + visible: false, + ..Default::default() + }, + }, + ) + }) + .collect(); + + if generated_nodes == 0 { + continue; + } + + let document_node = DocumentNode { + inputs: network_inputs, + manual_composition: Some(input_type.clone()), + implementation: DocumentNodeImplementation::ProtoNode(id.clone().into()), + visible: true, + skip_deduplication: false, + ..Default::default() + }; + + nodes.insert(NodeId(input_count as u64), document_node); + + let node = DocumentNode { + inputs, + manual_composition: Some(input_type.clone()), + implementation: DocumentNodeImplementation::Network(NodeNetwork { + exports: vec![NodeInput::Node { + node_id: NodeId(input_count as u64), + output_index: 0, + lambda: false, + }], + nodes, + scope_injections: Default::default(), + generated: true, + }), + visible: true, + skip_deduplication: false, + ..Default::default() + }; + + custom.insert(id.clone(), node); + } + + custom +} + +pub fn node_inputs(fields: &[registry::FieldMetadata], first_node_io: &NodeIOTypes) -> Vec { + fields + .iter() + .zip(first_node_io.inputs.iter()) + .enumerate() + .map(|(index, (field, node_io_ty))| { + let ty = field.default_type.as_ref().unwrap_or(node_io_ty); + let exposed = if index == 0 { *ty != fn_type_fut!(Context, ()) } else { field.exposed }; + + match field.value_source { + RegistryValueSource::None => {} + RegistryValueSource::Default(data) => return NodeInput::value(TaggedValue::from_primitive_string(data, ty).unwrap_or(TaggedValue::None), exposed), + RegistryValueSource::Scope(data) => return NodeInput::scope(Cow::Borrowed(data)), + }; + + if let Some(type_default) = TaggedValue::from_type(ty) { + return NodeInput::value(type_default, exposed); + } + NodeInput::value(TaggedValue::None, true) + }) + .collect() +}