Fix the Into nodes, which were broken but unused except in GPU nodes (#2480)

* Prototype document network level into node insertion

* Fix generic type resolution

* Cleanup

* Remove network nesting
This commit is contained in:
Dennis Kobert 2025-03-27 10:11:11 +01:00 committed by GitHub
parent 92132919d1
commit 41288d7642
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 100 additions and 47 deletions

View file

@ -292,7 +292,7 @@ impl<'a> ModifyInputsContext<'a> {
// If inserting a path node, insert a flatten vector elements if the type is a graphic group.
// TODO: Allow the path node to operate on Graphic Group data by utilizing the reference for each vector data in a group.
if node_definition.identifier == "Path" {
let layer_input_type = self.network_interface.input_type(&InputConnector::node(output_layer.to_node(), 1), &[]).0.nested_type();
let layer_input_type = self.network_interface.input_type(&InputConnector::node(output_layer.to_node(), 1), &[]).0.nested_type().clone();
if layer_input_type == concrete!(GraphicGroupTable) {
let Some(flatten_vector_elements_definition) = resolve_document_node_type("Flatten Vector Elements") else {
log::error!("Flatten Vector Elements does not exist in ModifyInputsContext::existing_node_id");

View file

@ -2,8 +2,8 @@ use super::node_properties;
use super::utility_types::FrontendNodeType;
use crate::messages::layout::utility_types::widget_prelude::*;
use crate::messages::portfolio::document::utility_types::network_interface::{
DocumentNodeMetadata, DocumentNodePersistentMetadata, NodeNetworkInterface, NodeNetworkMetadata, NodeNetworkPersistentMetadata, NodeTemplate, NodeTypePersistentMetadata, NumberInputSettings,
PropertiesRow, Vec2InputSettings, WidgetOverride,
DocumentNodeMetadata, DocumentNodePersistentMetadata, NodeNetworkInterface, NodeNetworkMetadata, NodeNetworkPersistentMetadata, NodePersistentMetadata, NodePosition, NodeTemplate,
NodeTypePersistentMetadata, NumberInputSettings, PropertiesRow, Vec2InputSettings, WidgetOverride,
};
use crate::messages::portfolio::utility_types::PersistentData;
use crate::messages::prelude::Message;
@ -2663,6 +2663,7 @@ fn static_nodes() -> Vec<DocumentNodeDefinition> {
let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap();
'outer: for (id, metadata) in graphene_core::registry::NODE_METADATA.lock().unwrap().iter() {
use graphene_core::registry::*;
let id = id.clone();
for node in custom.iter() {
let DocumentNodeDefinition {
@ -2673,7 +2674,7 @@ fn static_nodes() -> Vec<DocumentNodeDefinition> {
..
} = node;
match implementation {
DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == id => continue 'outer,
DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == &id => continue 'outer,
_ => (),
}
}
@ -2685,12 +2686,12 @@ fn static_nodes() -> Vec<DocumentNodeDefinition> {
description,
properties,
} = metadata;
let Some(implementations) = &node_registry.get(id) else { continue };
let Some(implementations) = &node_registry.get(&id) else { continue };
let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect();
let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() });
let mut input_type = &first_node_io.call_argument;
if valid_inputs.len() > 1 {
input_type = &const { generic!(T) };
input_type = &const { generic!(D) };
}
let output_type = &first_node_io.return_value;
@ -2740,6 +2741,7 @@ fn static_nodes() -> Vec<DocumentNodeDefinition> {
output_names: vec![output_type.to_string()],
has_primary_output: true,
locked: false,
..Default::default()
},
},

View file

@ -654,7 +654,7 @@ impl NodeNetworkInterface {
let input_type = self.input_type(&InputConnector::node(*node_id, iterator_index), network_path).0;
// Value inputs are stored as concrete, so they are compared to the nested type. Node inputs are stored as fn, so they are compared to the entire type.
// For example a node input of (Footprint) -> VectorData would not be compatible with () -> VectorData
node_io.inputs[iterator_index].clone().nested_type() == input_type || node_io.inputs[iterator_index] == input_type
node_io.inputs[iterator_index].clone().nested_type() == &input_type || node_io.inputs[iterator_index] == input_type
});
if valid_implementation { node_io.inputs.get(*input_index).cloned() } else { None }
})

View file

@ -420,7 +420,7 @@ impl<'a> NodeGraphLayer<'a> {
/// Check if a layer is a raster layer
pub fn is_raster_layer(layer: LayerNodeIdentifier, network_interface: &mut NodeNetworkInterface) -> bool {
let layer_input_type = network_interface.input_type(&InputConnector::node(layer.to_node(), 1), &[]).0.nested_type();
let layer_input_type = network_interface.input_type(&InputConnector::node(layer.to_node(), 1), &[]).0.nested_type().clone();
if layer_input_type == concrete!(graphene_core::raster::image::ImageFrameTable<graphene_core::Color>)
|| layer_input_type == concrete!(graphene_core::application_io::TextureFrameTable)
|| layer_input_type == concrete!(graphene_std::RasterFrame)

View file

@ -293,7 +293,7 @@ impl Type {
}
}
pub fn nested_type(self) -> Type {
pub fn nested_type(&self) -> &Type {
match self {
Self::Generic(_) => self,
Self::Concrete(_) => self,
@ -301,6 +301,18 @@ impl Type {
Self::Future(output) => output.nested_type(),
}
}
pub fn replace_nested(&mut self, f: impl Fn(&Type) -> Option<Type>) -> Option<Type> {
if let Some(replacement) = f(self) {
return Some(std::mem::replace(self, replacement));
}
match self {
Self::Generic(_) => None,
Self::Concrete(_) => None,
Self::Fn(_, output) => output.replace_nested(f),
Self::Future(output) => output.replace_nested(f),
}
}
}
fn format_type(ty: &str) -> String {

View file

@ -281,21 +281,16 @@ impl DocumentNode {
self.inputs[index] = NodeInput::Node { node_id, output_index, lambda };
let input_source = &mut self.original_location.inputs_source;
for source in source {
input_source.insert(source, index + self.original_location.skip_inputs - skip);
input_source.insert(source, (index + self.original_location.skip_inputs).saturating_sub(skip));
}
}
fn resolve_proto_node(mut self) -> ProtoNode {
assert!(!self.inputs.is_empty() || self.manual_composition.is_some(), "Resolving document node {self:#?} with no inputs");
let DocumentNodeImplementation::ProtoNode(fqn) = self.implementation else {
let DocumentNodeImplementation::ProtoNode(identifier) = self.implementation else {
unreachable!("tried to resolve not flattened node on resolved node {self:?}");
};
// TODO replace with proper generics removal
let identifier = match fqn.name.clone().split_once('<') {
Some((path, _generics)) => ProtoNodeIdentifier { name: Cow::Owned(path.to_string()) },
_ => ProtoNodeIdentifier { name: fqn.name },
};
let (input, mut args) = if let Some(ty) = self.manual_composition {
(ProtoNodeInput::ManualComposition(ty), ConstructionArgs::Nodes(vec![]))
} else {

View file

@ -696,7 +696,7 @@ impl TypingContext {
// Direct comparison of two concrete types.
(Type::Concrete(type1), Type::Concrete(type2)) => type1 == type2,
// Check inner type for futures
(Type::Future(type1), Type::Future(type2)) => type1 == type2,
(Type::Future(type1), Type::Future(type2)) => valid_type(type1, type2),
// Direct comparison of two function types.
// Note: in the presence of subtyping, functions are considered on a "greater than or equal to" basis of its function type's generality.
// That means we compare their types with a contravariant relationship, which means that a more general type signature may be substituted for a more specific type signature.
@ -728,16 +728,17 @@ impl TypingContext {
let substitution_results = valid_output_types
.iter()
.map(|node_io| {
collect_generics(node_io)
let generics_lookup: Result<HashMap<_, _>, _> = collect_generics(node_io)
.iter()
.try_for_each(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|_| ()))
.map(|_| {
if let Type::Generic(out) = &node_io.return_value {
((*node_io).clone(), check_generic(node_io, &primary_input_or_call_argument, &inputs, out).unwrap())
} else {
((*node_io).clone(), node_io.return_value.clone())
}
})
.map(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|x| (generic.to_string(), x)))
.collect();
generics_lookup.map(|generics_lookup| {
let orig_node_io = (*node_io).clone();
let mut new_node_io = orig_node_io.clone();
replace_generics(&mut new_node_io, &generics_lookup);
(new_node_io, orig_node_io)
})
})
.collect::<Vec<_>>();
@ -783,8 +784,8 @@ impl TypingContext {
.join("\n");
Err(vec![GraphError::new(node, GraphErrorType::InvalidImplementations { inputs, error_inputs })])
}
[(org_nio, _)] => {
let node_io = org_nio.clone();
[(node_io, org_nio)] => {
let node_io = node_io.clone();
// Save the inferred type
self.inferred.insert(node_id, node_io.clone());
@ -794,15 +795,15 @@ impl TypingContext {
// If two types are available and one of them accepts () an input, always choose that one
[first, second] => {
if first.0.call_argument != second.0.call_argument {
for (org_nio, _) in [first, second] {
if org_nio.call_argument != concrete!(()) {
for (node_io, orig_nio) in [first, second] {
if node_io.call_argument != concrete!(()) {
continue;
}
// Save the inferred type
self.inferred.insert(node_id, org_nio.clone());
self.constructor.insert(node_id, impls[org_nio]);
return Ok(org_nio.clone());
self.inferred.insert(node_id, node_io.clone());
self.constructor.insert(node_id, impls[orig_nio]);
return Ok(node_io.clone());
}
}
let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
@ -821,7 +822,7 @@ impl TypingContext {
/// Returns a list of all generic types used in the node
fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> {
let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().flat_map(|x| x.fn_output()));
let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().map(|x| x.nested_type()));
let mut generics = inputs
.filter_map(|t| match t {
Type::Generic(out) => Some(out.clone()),
@ -839,6 +840,7 @@ fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> {
fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic: &str) -> Result<Type, String> {
let inputs = [(Some(&types.call_argument), Some(input))]
.into_iter()
.chain(types.inputs.iter().map(|x| x.fn_input()).zip(parameters.iter().map(|x| x.fn_input())))
.chain(types.inputs.iter().map(|x| x.fn_output()).zip(parameters.iter().map(|x| x.fn_output())));
let concrete_inputs = inputs.filter(|(ni, _)| matches!(ni, Some(Type::Generic(input)) if generic == input));
let mut outputs = concrete_inputs.flat_map(|(_, out)| out);
@ -851,6 +853,21 @@ fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic
Ok(out_ty.clone())
}
/// Returns a list of all generic types used in the node
fn replace_generics(types: &mut NodeIOTypes, lookup: &HashMap<String, Type>) {
let replace = |ty: &Type| {
let Type::Generic(ident) = ty else {
return None;
};
lookup.get(ident.as_ref()).cloned()
};
types.call_argument.replace_nested(replace);
types.return_value.replace_nested(replace);
for input in &mut types.inputs {
input.replace_nested(replace);
}
}
#[cfg(test)]
mod test {
use super::*;

View file

@ -47,8 +47,6 @@ macro_rules! async_node {
let node = <$path>::new($(
graphene_std::any::PanicNode::<$arg, core::pin::Pin<Box<dyn core::future::Future<Output = $type> + Send>>>::new()
),*);
// TODO: Propagate the future type through the node graph
// let params = vec![$(Type::Fn(Box::new(concrete!(())), Box::new(Type::Future(Box::new(concrete!($type)))))),*];
let params = vec![$(fn_type_fut!($arg, $type)),*];
let mut node_io = NodeIO::<'_, $input>::to_async_node_io(&node, params);
node_io.call_argument = concrete!(<$input as StaticType>::Static);
@ -58,6 +56,28 @@ macro_rules! async_node {
};
}
macro_rules! into_node {
(from: $from:ty, to: $to:ty) => {
(
ProtoNodeIdentifier::new(concat!["graphene_core::ops::IntoNode<", stringify!($to), ">"]),
|mut args| {
Box::pin(async move {
args.reverse();
let node = graphene_core::ops::IntoNode::<$to>::new();
let any: DynAnyNode<$from, _, _> = graphene_std::any::DynAnyNode::new(node);
Box::new(any) as TypeErasedBox
})
},
{
let node = graphene_core::ops::IntoNode::<$to>::new();
let mut node_io = NodeIO::<'_, $from>::to_async_node_io(&node, vec![]);
node_io.call_argument = future!(<$from as StaticType>::Static);
node_io
},
)
};
}
// TODO: turn into hashmap
fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>> {
let node_types: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![
@ -68,15 +88,20 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
// ),
// async_node!(graphene_core::ops::IntoNode<ImageFrameTable<SRGBA8>>, input: ImageFrameTable<Color>, params: []),
// async_node!(graphene_core::ops::IntoNode<ImageFrameTable<Color>>, input: ImageFrameTable<SRGBA8>, params: []),
async_node!(graphene_core::ops::IntoNode<GraphicGroupTable>, input: ImageFrameTable<Color>, params: []),
async_node!(graphene_core::ops::IntoNode<GraphicGroupTable>, input: VectorDataTable, params: []),
into_node!(from: f64, to: f64),
into_node!(from: ImageFrameTable<Color>, to: GraphicGroupTable),
into_node!(from: f64,to: f64),
into_node!(from: u32,to: f64),
into_node!(from: u8,to: u32),
into_node!(from: ImageFrameTable<Color>,to: GraphicGroupTable),
into_node!(from: VectorDataTable,to: GraphicGroupTable),
#[cfg(feature = "gpu")]
async_node!(graphene_core::ops::IntoNode<&WgpuExecutor>, input: &WasmEditorApi, params: []),
async_node!(graphene_core::ops::IntoNode<GraphicElement>, input: VectorDataTable, params: []),
async_node!(graphene_core::ops::IntoNode<GraphicElement>, input: ImageFrameTable<Color>, params: []),
async_node!(graphene_core::ops::IntoNode<GraphicElement>, input: GraphicGroupTable, params: []),
async_node!(graphene_core::ops::IntoNode<GraphicGroupTable>, input: VectorDataTable, params: []),
async_node!(graphene_core::ops::IntoNode<GraphicGroupTable>, input: ImageFrameTable<Color>, params: []),
into_node!(from: &WasmEditorApi,to: &WgpuExecutor),
into_node!(from: VectorDataTable,to: GraphicElement),
into_node!(from: ImageFrameTable<Color>,to: GraphicElement),
into_node!(from: GraphicGroupTable,to: GraphicElement),
into_node!(from: VectorDataTable,to: GraphicGroupTable),
into_node!(from: ImageFrameTable<Color>,to: GraphicGroupTable),
async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => ImageFrameTable<Color>]),
async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => ImageTexture]),
async_node!(graphene_core::memo::MonitorNode<_, _, _>, input: Context, fn_params: [Context => VectorDataTable]),
@ -304,9 +329,11 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
// This occurs for the ChannelMixerNode presumably because of the long name.
// This might be caused by the stringify! macro
let mut new_name = id.name.replace('\n', " ");
// Remove struct generics
if let Some((path, _generics)) = new_name.split_once("<") {
new_name = path.to_string();
// Remove struct generics for all nodes except for the IntoNode
if !new_name.contains("IntoNode") {
if let Some((path, _generics)) = new_name.split_once("<") {
new_name = path.to_string();
}
}
let nid = ProtoNodeIdentifier { name: Cow::Owned(new_name) };
map.entry(nid).or_default().insert(types.clone(), c);