diff --git a/editor/src/messages/portfolio/document/graph_operation/utility_types.rs b/editor/src/messages/portfolio/document/graph_operation/utility_types.rs index aec230d88..f54f40837 100644 --- a/editor/src/messages/portfolio/document/graph_operation/utility_types.rs +++ b/editor/src/messages/portfolio/document/graph_operation/utility_types.rs @@ -292,7 +292,7 @@ impl<'a> ModifyInputsContext<'a> { // If inserting a path node, insert a flatten vector elements if the type is a graphic group. // TODO: Allow the path node to operate on Graphic Group data by utilizing the reference for each vector data in a group. if node_definition.identifier == "Path" { - let layer_input_type = self.network_interface.input_type(&InputConnector::node(output_layer.to_node(), 1), &[]).0.nested_type(); + let layer_input_type = self.network_interface.input_type(&InputConnector::node(output_layer.to_node(), 1), &[]).0.nested_type().clone(); if layer_input_type == concrete!(GraphicGroupTable) { let Some(flatten_vector_elements_definition) = resolve_document_node_type("Flatten Vector Elements") else { log::error!("Flatten Vector Elements does not exist in ModifyInputsContext::existing_node_id"); diff --git a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs index d56ac6f12..9d4cecf4b 100644 --- a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs +++ b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs @@ -2,8 +2,8 @@ use super::node_properties; use super::utility_types::FrontendNodeType; use crate::messages::layout::utility_types::widget_prelude::*; use crate::messages::portfolio::document::utility_types::network_interface::{ - DocumentNodeMetadata, DocumentNodePersistentMetadata, NodeNetworkInterface, NodeNetworkMetadata, NodeNetworkPersistentMetadata, NodeTemplate, NodeTypePersistentMetadata, NumberInputSettings, - PropertiesRow, Vec2InputSettings, WidgetOverride, + DocumentNodeMetadata, DocumentNodePersistentMetadata, NodeNetworkInterface, NodeNetworkMetadata, NodeNetworkPersistentMetadata, NodePersistentMetadata, NodePosition, NodeTemplate, + NodeTypePersistentMetadata, NumberInputSettings, PropertiesRow, Vec2InputSettings, WidgetOverride, }; use crate::messages::portfolio::utility_types::PersistentData; use crate::messages::prelude::Message; @@ -2663,6 +2663,7 @@ fn static_nodes() -> Vec { let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap(); 'outer: for (id, metadata) in graphene_core::registry::NODE_METADATA.lock().unwrap().iter() { use graphene_core::registry::*; + let id = id.clone(); for node in custom.iter() { let DocumentNodeDefinition { @@ -2673,7 +2674,7 @@ fn static_nodes() -> Vec { .. } = node; match implementation { - DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == id => continue 'outer, + DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == &id => continue 'outer, _ => (), } } @@ -2685,12 +2686,12 @@ fn static_nodes() -> Vec { description, properties, } = metadata; - let Some(implementations) = &node_registry.get(id) else { continue }; + let Some(implementations) = &node_registry.get(&id) else { continue }; let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); let mut input_type = &first_node_io.call_argument; if valid_inputs.len() > 1 { - input_type = &const { generic!(T) }; + input_type = &const { generic!(D) }; } let output_type = &first_node_io.return_value; @@ -2740,6 +2741,7 @@ fn static_nodes() -> Vec { output_names: vec![output_type.to_string()], has_primary_output: true, locked: false, + ..Default::default() }, }, diff --git a/editor/src/messages/portfolio/document/utility_types/network_interface.rs b/editor/src/messages/portfolio/document/utility_types/network_interface.rs index ee9d469eb..aebe55f5e 100644 --- a/editor/src/messages/portfolio/document/utility_types/network_interface.rs +++ b/editor/src/messages/portfolio/document/utility_types/network_interface.rs @@ -654,7 +654,7 @@ impl NodeNetworkInterface { let input_type = self.input_type(&InputConnector::node(*node_id, iterator_index), network_path).0; // Value inputs are stored as concrete, so they are compared to the nested type. Node inputs are stored as fn, so they are compared to the entire type. // For example a node input of (Footprint) -> VectorData would not be compatible with () -> VectorData - node_io.inputs[iterator_index].clone().nested_type() == input_type || node_io.inputs[iterator_index] == input_type + node_io.inputs[iterator_index].clone().nested_type() == &input_type || node_io.inputs[iterator_index] == input_type }); if valid_implementation { node_io.inputs.get(*input_index).cloned() } else { None } }) diff --git a/editor/src/messages/tool/common_functionality/graph_modification_utils.rs b/editor/src/messages/tool/common_functionality/graph_modification_utils.rs index 3220eed0d..f3404bcd8 100644 --- a/editor/src/messages/tool/common_functionality/graph_modification_utils.rs +++ b/editor/src/messages/tool/common_functionality/graph_modification_utils.rs @@ -420,7 +420,7 @@ impl<'a> NodeGraphLayer<'a> { /// Check if a layer is a raster layer pub fn is_raster_layer(layer: LayerNodeIdentifier, network_interface: &mut NodeNetworkInterface) -> bool { - let layer_input_type = network_interface.input_type(&InputConnector::node(layer.to_node(), 1), &[]).0.nested_type(); + let layer_input_type = network_interface.input_type(&InputConnector::node(layer.to_node(), 1), &[]).0.nested_type().clone(); if layer_input_type == concrete!(graphene_core::raster::image::ImageFrameTable) || layer_input_type == concrete!(graphene_core::application_io::TextureFrameTable) || layer_input_type == concrete!(graphene_std::RasterFrame) diff --git a/node-graph/gcore/src/types.rs b/node-graph/gcore/src/types.rs index 2a3004f95..0fc03e4d9 100644 --- a/node-graph/gcore/src/types.rs +++ b/node-graph/gcore/src/types.rs @@ -293,7 +293,7 @@ impl Type { } } - pub fn nested_type(self) -> Type { + pub fn nested_type(&self) -> &Type { match self { Self::Generic(_) => self, Self::Concrete(_) => self, @@ -301,6 +301,18 @@ impl Type { Self::Future(output) => output.nested_type(), } } + + pub fn replace_nested(&mut self, f: impl Fn(&Type) -> Option) -> Option { + if let Some(replacement) = f(self) { + return Some(std::mem::replace(self, replacement)); + } + match self { + Self::Generic(_) => None, + Self::Concrete(_) => None, + Self::Fn(_, output) => output.replace_nested(f), + Self::Future(output) => output.replace_nested(f), + } + } } fn format_type(ty: &str) -> String { diff --git a/node-graph/graph-craft/src/document.rs b/node-graph/graph-craft/src/document.rs index d81458617..49479f9bf 100644 --- a/node-graph/graph-craft/src/document.rs +++ b/node-graph/graph-craft/src/document.rs @@ -281,21 +281,16 @@ impl DocumentNode { self.inputs[index] = NodeInput::Node { node_id, output_index, lambda }; let input_source = &mut self.original_location.inputs_source; for source in source { - input_source.insert(source, index + self.original_location.skip_inputs - skip); + input_source.insert(source, (index + self.original_location.skip_inputs).saturating_sub(skip)); } } fn resolve_proto_node(mut self) -> ProtoNode { assert!(!self.inputs.is_empty() || self.manual_composition.is_some(), "Resolving document node {self:#?} with no inputs"); - let DocumentNodeImplementation::ProtoNode(fqn) = self.implementation else { + let DocumentNodeImplementation::ProtoNode(identifier) = self.implementation else { unreachable!("tried to resolve not flattened node on resolved node {self:?}"); }; - // TODO replace with proper generics removal - let identifier = match fqn.name.clone().split_once('<') { - Some((path, _generics)) => ProtoNodeIdentifier { name: Cow::Owned(path.to_string()) }, - _ => ProtoNodeIdentifier { name: fqn.name }, - }; let (input, mut args) = if let Some(ty) = self.manual_composition { (ProtoNodeInput::ManualComposition(ty), ConstructionArgs::Nodes(vec![])) } else { diff --git a/node-graph/graph-craft/src/proto.rs b/node-graph/graph-craft/src/proto.rs index 0d9dead17..a3c592f7b 100644 --- a/node-graph/graph-craft/src/proto.rs +++ b/node-graph/graph-craft/src/proto.rs @@ -696,7 +696,7 @@ impl TypingContext { // Direct comparison of two concrete types. (Type::Concrete(type1), Type::Concrete(type2)) => type1 == type2, // Check inner type for futures - (Type::Future(type1), Type::Future(type2)) => type1 == type2, + (Type::Future(type1), Type::Future(type2)) => valid_type(type1, type2), // Direct comparison of two function types. // Note: in the presence of subtyping, functions are considered on a "greater than or equal to" basis of its function type's generality. // That means we compare their types with a contravariant relationship, which means that a more general type signature may be substituted for a more specific type signature. @@ -728,16 +728,17 @@ impl TypingContext { let substitution_results = valid_output_types .iter() .map(|node_io| { - collect_generics(node_io) + let generics_lookup: Result, _> = collect_generics(node_io) .iter() - .try_for_each(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|_| ())) - .map(|_| { - if let Type::Generic(out) = &node_io.return_value { - ((*node_io).clone(), check_generic(node_io, &primary_input_or_call_argument, &inputs, out).unwrap()) - } else { - ((*node_io).clone(), node_io.return_value.clone()) - } - }) + .map(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|x| (generic.to_string(), x))) + .collect(); + + generics_lookup.map(|generics_lookup| { + let orig_node_io = (*node_io).clone(); + let mut new_node_io = orig_node_io.clone(); + replace_generics(&mut new_node_io, &generics_lookup); + (new_node_io, orig_node_io) + }) }) .collect::>(); @@ -783,8 +784,8 @@ impl TypingContext { .join("\n"); Err(vec![GraphError::new(node, GraphErrorType::InvalidImplementations { inputs, error_inputs })]) } - [(org_nio, _)] => { - let node_io = org_nio.clone(); + [(node_io, org_nio)] => { + let node_io = node_io.clone(); // Save the inferred type self.inferred.insert(node_id, node_io.clone()); @@ -794,15 +795,15 @@ impl TypingContext { // If two types are available and one of them accepts () an input, always choose that one [first, second] => { if first.0.call_argument != second.0.call_argument { - for (org_nio, _) in [first, second] { - if org_nio.call_argument != concrete!(()) { + for (node_io, orig_nio) in [first, second] { + if node_io.call_argument != concrete!(()) { continue; } // Save the inferred type - self.inferred.insert(node_id, org_nio.clone()); - self.constructor.insert(node_id, impls[org_nio]); - return Ok(org_nio.clone()); + self.inferred.insert(node_id, node_io.clone()); + self.constructor.insert(node_id, impls[orig_nio]); + return Ok(node_io.clone()); } } let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::>().join(", "); @@ -821,7 +822,7 @@ impl TypingContext { /// Returns a list of all generic types used in the node fn collect_generics(types: &NodeIOTypes) -> Vec> { - let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().flat_map(|x| x.fn_output())); + let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().map(|x| x.nested_type())); let mut generics = inputs .filter_map(|t| match t { Type::Generic(out) => Some(out.clone()), @@ -839,6 +840,7 @@ fn collect_generics(types: &NodeIOTypes) -> Vec> { fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic: &str) -> Result { let inputs = [(Some(&types.call_argument), Some(input))] .into_iter() + .chain(types.inputs.iter().map(|x| x.fn_input()).zip(parameters.iter().map(|x| x.fn_input()))) .chain(types.inputs.iter().map(|x| x.fn_output()).zip(parameters.iter().map(|x| x.fn_output()))); let concrete_inputs = inputs.filter(|(ni, _)| matches!(ni, Some(Type::Generic(input)) if generic == input)); let mut outputs = concrete_inputs.flat_map(|(_, out)| out); @@ -851,6 +853,21 @@ fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic Ok(out_ty.clone()) } +/// Returns a list of all generic types used in the node +fn replace_generics(types: &mut NodeIOTypes, lookup: &HashMap) { + let replace = |ty: &Type| { + let Type::Generic(ident) = ty else { + return None; + }; + lookup.get(ident.as_ref()).cloned() + }; + types.call_argument.replace_nested(replace); + types.return_value.replace_nested(replace); + for input in &mut types.inputs { + input.replace_nested(replace); + } +} + #[cfg(test)] mod test { use super::*; diff --git a/node-graph/interpreted-executor/src/node_registry.rs b/node-graph/interpreted-executor/src/node_registry.rs index d356cd5f0..36d72c397 100644 --- a/node-graph/interpreted-executor/src/node_registry.rs +++ b/node-graph/interpreted-executor/src/node_registry.rs @@ -47,8 +47,6 @@ macro_rules! async_node { let node = <$path>::new($( graphene_std::any::PanicNode::<$arg, core::pin::Pin + Send>>>::new() ),*); - // TODO: Propagate the future type through the node graph - // let params = vec![$(Type::Fn(Box::new(concrete!(())), Box::new(Type::Future(Box::new(concrete!($type)))))),*]; let params = vec![$(fn_type_fut!($arg, $type)),*]; let mut node_io = NodeIO::<'_, $input>::to_async_node_io(&node, params); node_io.call_argument = concrete!(<$input as StaticType>::Static); @@ -58,6 +56,28 @@ macro_rules! async_node { }; } +macro_rules! into_node { + (from: $from:ty, to: $to:ty) => { + ( + ProtoNodeIdentifier::new(concat!["graphene_core::ops::IntoNode<", stringify!($to), ">"]), + |mut args| { + Box::pin(async move { + args.reverse(); + let node = graphene_core::ops::IntoNode::<$to>::new(); + let any: DynAnyNode<$from, _, _> = graphene_std::any::DynAnyNode::new(node); + Box::new(any) as TypeErasedBox + }) + }, + { + let node = graphene_core::ops::IntoNode::<$to>::new(); + let mut node_io = NodeIO::<'_, $from>::to_async_node_io(&node, vec![]); + node_io.call_argument = future!(<$from as StaticType>::Static); + node_io + }, + ) + }; +} + // TODO: turn into hashmap fn node_registry() -> HashMap> { let node_types: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![ @@ -68,15 +88,20 @@ fn node_registry() -> HashMap>, input: ImageFrameTable, params: []), // async_node!(graphene_core::ops::IntoNode>, input: ImageFrameTable, params: []), - async_node!(graphene_core::ops::IntoNode, input: ImageFrameTable, params: []), - async_node!(graphene_core::ops::IntoNode, input: VectorDataTable, params: []), + into_node!(from: f64, to: f64), + into_node!(from: ImageFrameTable, to: GraphicGroupTable), + into_node!(from: f64,to: f64), + into_node!(from: u32,to: f64), + into_node!(from: u8,to: u32), + into_node!(from: ImageFrameTable,to: GraphicGroupTable), + into_node!(from: VectorDataTable,to: GraphicGroupTable), #[cfg(feature = "gpu")] - async_node!(graphene_core::ops::IntoNode<&WgpuExecutor>, input: &WasmEditorApi, params: []), - async_node!(graphene_core::ops::IntoNode, input: VectorDataTable, params: []), - async_node!(graphene_core::ops::IntoNode, input: ImageFrameTable, params: []), - async_node!(graphene_core::ops::IntoNode, input: GraphicGroupTable, params: []), - async_node!(graphene_core::ops::IntoNode, input: VectorDataTable, params: []), - async_node!(graphene_core::ops::IntoNode, input: ImageFrameTable, params: []), + into_node!(from: &WasmEditorApi,to: &WgpuExecutor), + into_node!(from: VectorDataTable,to: GraphicElement), + into_node!(from: ImageFrameTable,to: GraphicElement), + into_node!(from: GraphicGroupTable,to: GraphicElement), + into_node!(from: VectorDataTable,to: GraphicGroupTable), + into_node!(from: ImageFrameTable,to: GraphicGroupTable), async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => ImageFrameTable]), async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => ImageTexture]), async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => VectorDataTable]), @@ -304,9 +329,11 @@ fn node_registry() -> HashMap