Remove lambda node inputs since they are no longer used (#3084)
Some checks are pending
Editor: Dev & CI / build (push) Waiting to run
Editor: Dev & CI / cargo-deny (push) Waiting to run

* Remove lambda node inputs as they are now unused

* Fix warnings

* Fix tests

* Fix clippy warning
This commit is contained in:
Dennis Kobert 2025-08-23 12:16:49 +02:00 committed by GitHub
parent 7377871106
commit 469f0a6c30
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 60 additions and 99 deletions

View file

@ -211,7 +211,7 @@ impl OriginalLocation {
}
impl DocumentNode {
/// Locate the input that is a [`NodeInput::Network`] at index `offset` and replace it with a [`NodeInput::Node`].
pub fn populate_first_network_input(&mut self, node_id: NodeId, output_index: usize, offset: usize, lambda: bool, source: impl Iterator<Item = Source>, skip: usize) {
pub fn populate_first_network_input(&mut self, node_id: NodeId, output_index: usize, offset: usize, source: impl Iterator<Item = Source>, skip: usize) {
let (index, _) = self
.inputs
.iter()
@ -219,7 +219,7 @@ impl DocumentNode {
.nth(offset)
.unwrap_or_else(|| panic!("no network input found for {self:#?} and offset: {offset}"));
self.inputs[index] = NodeInput::Node { node_id, output_index, lambda };
self.inputs[index] = NodeInput::Node { node_id, output_index };
let input_source = &mut self.original_location.inputs_source;
for source in source {
input_source.insert(source, (index + self.original_location.skip_inputs).saturating_sub(skip));
@ -241,9 +241,9 @@ impl DocumentNode {
assert_eq!(self.inputs.len(), 0, "A value node cannot have any inputs. Current inputs: {:?}", self.inputs);
(ProtoNodeInput::ManualComposition(concrete!(graphene_core::Context<'static>)), ConstructionArgs::Value(tagged_value))
}
NodeInput::Node { node_id, output_index, lambda } => {
NodeInput::Node { node_id, output_index } => {
assert_eq!(output_index, 0, "Outputs should be flattened before converting to proto node");
let node = if lambda { ProtoNodeInput::NodeLambda(node_id) } else { ProtoNodeInput::Node(node_id) };
let node = ProtoNodeInput::Node(node_id);
(node, ConstructionArgs::Nodes(vec![]))
}
NodeInput::Network { import_type, .. } => (ProtoNodeInput::ManualComposition(import_type), ConstructionArgs::Nodes(vec![])),
@ -266,7 +266,7 @@ impl DocumentNode {
}
if let ConstructionArgs::Nodes(nodes) = &mut args {
nodes.extend(self.inputs.iter().map(|input| match input {
NodeInput::Node { node_id, lambda, .. } => (*node_id, *lambda),
NodeInput::Node { node_id, .. } => *node_id,
_ => unreachable!(),
}));
}
@ -284,7 +284,7 @@ impl DocumentNode {
#[derive(Debug, Clone, PartialEq, Hash, DynAny, serde::Serialize, serde::Deserialize)]
pub enum NodeInput {
/// A reference to another node in the same network from which this node can receive its input.
Node { node_id: NodeId, output_index: usize, lambda: bool },
Node { node_id: NodeId, output_index: usize },
/// A hardcoded value that can't change after the graph is compiled. Gets converted into a value node during graph compilation.
Value { tagged_value: MemoHash<TaggedValue>, exposed: bool },
@ -323,11 +323,7 @@ pub enum DocumentNodeMetadata {
impl NodeInput {
pub const fn node(node_id: NodeId, output_index: usize) -> Self {
Self::Node { node_id, output_index, lambda: false }
}
pub const fn lambda(node_id: NodeId, output_index: usize) -> Self {
Self::Node { node_id, output_index, lambda: true }
Self::Node { node_id, output_index }
}
pub fn value(tagged_value: TaggedValue, exposed: bool) -> Self {
@ -344,12 +340,8 @@ impl NodeInput {
}
fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
if let &mut NodeInput::Node { node_id, output_index, lambda } = self {
*self = NodeInput::Node {
node_id: f(node_id),
output_index,
lambda,
}
if let &mut NodeInput::Node { node_id, output_index } = self {
*self = NodeInput::Node { node_id: f(node_id), output_index }
}
}
@ -952,9 +944,9 @@ impl NodeNetwork {
let parent_input = node.inputs.get(*import_index).unwrap_or_else(|| panic!("Import index {import_index} should always exist"));
match *parent_input {
// If the input to self is a node, connect the corresponding output of the inner network to it
NodeInput::Node { node_id, output_index, lambda } => {
NodeInput::Node { node_id, output_index } => {
let skip = node.original_location.skip_inputs;
nested_node.populate_first_network_input(node_id, output_index, nested_input_index, lambda, node.original_location.inputs(*import_index), skip);
nested_node.populate_first_network_input(node_id, output_index, nested_input_index, node.original_location.inputs(*import_index), skip);
let input_node = self.nodes.get_mut(&node_id).unwrap_or_else(|| panic!("unable find input node {node_id:?}"));
input_node.original_location.dependants[output_index].push(nested_node_id);
}
@ -1052,7 +1044,6 @@ impl NodeNetwork {
*export = NodeInput::Node {
node_id: merged_node_id,
output_index: 0,
lambda: false,
};
}
}
@ -1319,7 +1310,7 @@ mod test {
nodes: [
id_node.clone(),
DocumentNode {
inputs: vec![NodeInput::lambda(NodeId(0), 0)],
inputs: vec![NodeInput::node(NodeId(0), 0)],
implementation: DocumentNodeImplementation::Extract,
..Default::default()
},
@ -1374,7 +1365,7 @@ mod test {
let reference = ProtoNode {
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![(NodeId(0), false)]),
construction_args: ConstructionArgs::Nodes(vec![NodeId(0)]),
..Default::default()
};
assert_eq!(proto_node, reference);
@ -1391,7 +1382,7 @@ mod test {
ProtoNode {
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![(NodeId(14), false)]),
construction_args: ConstructionArgs::Nodes(vec![NodeId(14)]),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(0)]),
inputs_source: [(Source { node: vec![NodeId(1)], index: 1 }, 1)].into(),

View file

@ -41,7 +41,6 @@ impl core::fmt::Display for ProtoNetwork {
ProtoNodeInput::None => f.write_str("None")?,
ProtoNodeInput::ManualComposition(ty) => f.write_fmt(format_args!("Manual Composition (type = {ty:?})"))?,
ProtoNodeInput::Node(_) => f.write_str("Node")?,
ProtoNodeInput::NodeLambda(_) => f.write_str("Lambda Node")?,
}
f.write_str("\n")?;
@ -52,7 +51,7 @@ impl core::fmt::Display for ProtoNetwork {
}
ConstructionArgs::Nodes(nodes) => {
for id in nodes {
write_node(f, network, id.0, indent + 1)?;
write_node(f, network, *id, indent + 1)?;
}
}
ConstructionArgs::Inline(inline) => {
@ -78,7 +77,7 @@ pub enum ConstructionArgs {
/// A list of nodes used as inputs to the constructor function in `node_registry.rs`.
/// The bool indicates whether to treat the node as lambda node.
// TODO: use a struct for clearer naming.
Nodes(Vec<(NodeId, bool)>),
Nodes(Vec<NodeId>),
/// Used for GPU computation to work around the limitations of rust-gpu.
Inline(InlineRust),
}
@ -121,7 +120,7 @@ impl Hash for ConstructionArgs {
impl ConstructionArgs {
pub fn new_function_args(&self) -> Vec<String> {
match self {
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|(n, _)| format!("n{:0x}", n.0)).collect(),
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("n{:0x}", n.0)).collect(),
ConstructionArgs::Value(value) => vec![value.to_primitive_string()],
ConstructionArgs::Inline(inline) => vec![inline.expr.clone()],
}
@ -172,15 +171,7 @@ pub enum ProtoNodeInput {
/// Grayscale example:
///
/// We're interested in receiving an input of the desaturated image data which has been fed through a grayscale filter.
/// (If we were interested in the grayscale filter itself, we would use the `NodeLambda` variant.)
Node(NodeId),
/// Unlike the `Node` variant, with `NodeLambda` we treat the connected node singularly as a lambda node while ignoring all nodes which feed into it from upstream.
///
/// Grayscale example:
///
/// We're interested in receiving an input of a particular image filter, such as a grayscale filter in the form of a grayscale node lambda.
/// (If we were interested in some image data that had been fed through a grayscale filter, we would use the `Node` variant.)
NodeLambda(NodeId),
}
impl ProtoNode {
@ -202,8 +193,7 @@ impl ProtoNode {
ProtoNodeInput::ManualComposition(ref ty) => {
ty.hash(&mut hasher);
}
ProtoNodeInput::Node(id) => (id, false).hash(&mut hasher),
ProtoNodeInput::NodeLambda(id) => (id, true).hash(&mut hasher),
ProtoNodeInput::Node(id) => id.hash(&mut hasher),
};
Some(NodeId(hasher.finish()))
@ -230,23 +220,17 @@ impl ProtoNode {
/// Converts all references to other node IDs into new IDs by running the specified function on them.
/// This can be used when changing the IDs of the nodes, for example in the case of generating stable IDs.
pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId, skip_lambdas: bool) {
match self.input {
ProtoNodeInput::Node(id) => self.input = ProtoNodeInput::Node(f(id)),
ProtoNodeInput::NodeLambda(id) => {
if !skip_lambdas {
self.input = ProtoNodeInput::NodeLambda(f(id))
}
}
_ => (),
pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
if let ProtoNodeInput::Node(id) = self.input {
self.input = ProtoNodeInput::Node(f(id))
}
if let ConstructionArgs::Nodes(ids) = &mut self.construction_args {
ids.iter_mut().filter(|(_, lambda)| !(skip_lambdas && *lambda)).for_each(|(id, _)| *id = f(*id));
ids.iter_mut().for_each(|id| *id = f(*id));
}
}
pub fn unwrap_construction_nodes(&self) -> Vec<(NodeId, bool)> {
pub fn unwrap_construction_nodes(&self) -> Vec<NodeId> {
match &self.construction_args {
ConstructionArgs::Nodes(nodes) => nodes.clone(),
_ => panic!("tried to unwrap nodes from non node construction args \n node: {self:#?}"),
@ -285,16 +269,13 @@ impl ProtoNetwork {
pub fn collect_outwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> {
let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
for (id, node) in &self.nodes {
match &node.input {
ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => {
self.check_ref(ref_id, id);
edges.entry(*ref_id).or_default().push(*id)
}
_ => (),
if let ProtoNodeInput::Node(ref_id) = &node.input {
self.check_ref(ref_id, id);
edges.entry(*ref_id).or_default().push(*id)
}
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args {
for (ref_id, _) in ref_nodes {
for ref_id in ref_nodes {
self.check_ref(ref_id, id);
edges.entry(*ref_id).or_default().push(*id)
}
@ -313,7 +294,7 @@ impl ProtoNetwork {
let Some(sni) = self.nodes[index].1.stable_node_id() else {
panic!("failed to generate stable node id for node {:#?}", self.nodes[index].1);
};
self.replace_node_id(&outwards_edges, NodeId(index as u64), sni, false);
self.replace_node_id(&outwards_edges, NodeId(index as u64), sni);
self.nodes[index].0 = sni;
}
}
@ -323,16 +304,13 @@ impl ProtoNetwork {
pub fn collect_inwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> {
let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
for (id, node) in &self.nodes {
match &node.input {
ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => {
self.check_ref(ref_id, id);
edges.entry(*id).or_default().push(*ref_id)
}
_ => (),
if let ProtoNodeInput::Node(ref_id) = &node.input {
self.check_ref(ref_id, id);
edges.entry(*id).or_default().push(*ref_id)
}
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args {
for (ref_id, _) in ref_nodes {
for ref_id in ref_nodes {
self.check_ref(ref_id, id);
edges.entry(*id).or_default().push(*ref_id)
}
@ -348,16 +326,13 @@ impl ProtoNetwork {
let mut inwards_edges = vec![Vec::new(); self.nodes.len()];
for (node_id, node) in &self.nodes {
let node_index = id_map[node_id];
match &node.input {
ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => {
self.check_ref(ref_id, &NodeId(node_index as u64));
inwards_edges[node_index].push(id_map[ref_id]);
}
_ => {}
if let ProtoNodeInput::Node(ref_id) = &node.input {
self.check_ref(ref_id, &NodeId(node_index as u64));
inwards_edges[node_index].push(id_map[ref_id]);
}
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args {
for (ref_id, _) in ref_nodes {
for ref_id in ref_nodes {
self.check_ref(ref_id, &NodeId(node_index as u64));
inwards_edges[node_index].push(id_map[ref_id]);
}
@ -400,14 +375,14 @@ impl ProtoNetwork {
compose_node_id,
ProtoNode {
identifier: ProtoNodeIdentifier::new("graphene_core::structural::ComposeNode"),
construction_args: ConstructionArgs::Nodes(vec![(input_node_id, false), (node_id, true)]),
construction_args: ConstructionArgs::Nodes(vec![input_node_id, node_id]),
input,
original_location: OriginalLocation { path, ..Default::default() },
skip_deduplication: false,
},
));
self.replace_node_id(&outwards_edges, node_id, compose_node_id, true);
self.replace_node_id(&outwards_edges, node_id, compose_node_id);
}
}
self.reorder_ids()?;
@ -415,12 +390,12 @@ impl ProtoNetwork {
}
/// Update all of the references to a node ID in the graph with a new ID named `compose_node_id`.
fn replace_node_id(&mut self, outwards_edges: &HashMap<NodeId, Vec<NodeId>>, node_id: NodeId, compose_node_id: NodeId, skip_lambdas: bool) {
fn replace_node_id(&mut self, outwards_edges: &HashMap<NodeId, Vec<NodeId>>, node_id: NodeId, compose_node_id: NodeId) {
// Update references in other nodes to use the new compose node
if let Some(referring_nodes) = outwards_edges.get(&node_id) {
for &referring_node_id in referring_nodes {
let (_, referring_node) = &mut self.nodes[referring_node_id.0 as usize];
referring_node.map_ids(|id| if id == node_id { compose_node_id } else { id }, skip_lambdas)
referring_node.map_ids(|id| if id == node_id { compose_node_id } else { id })
}
}
@ -508,7 +483,7 @@ impl ProtoNetwork {
for (index, &id) in order.iter().enumerate() {
let mut node = std::mem::take(&mut self.nodes[id.0 as usize].1);
// Update node references to reflect the new order
node.map_ids(|id| NodeId(*new_positions.get(&id).expect("node not found in lookup table") as u64), false);
node.map_ids(|id| NodeId(*new_positions.get(&id).expect("node not found in lookup table") as u64));
new_nodes.push((NodeId(index as u64), node));
}
@ -670,7 +645,7 @@ impl TypingContext {
// If the node has nodes as inputs we can infer the types from the node outputs
ConstructionArgs::Nodes(ref nodes) => nodes
.iter()
.map(|(id, _)| {
.map(|id| {
self.inferred
.get(id)
.ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NodeNotFound(*id))])
@ -685,7 +660,7 @@ impl TypingContext {
let primary_input_or_call_argument = match node.input {
ProtoNodeInput::None => concrete!(()),
ProtoNodeInput::ManualComposition(ref ty) => ty.clone(),
ProtoNodeInput::Node(id) | ProtoNodeInput::NodeLambda(id) => {
ProtoNodeInput::Node(id) => {
let input = self.inferred.get(&id).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::InputNodeNotFound(id))])?;
input.return_value.clone()
}
@ -936,7 +911,7 @@ mod test {
println!("{construction_network:#?}");
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
assert_eq!(construction_network.nodes.len(), 6);
assert_eq!(construction_network.nodes[5].1.construction_args, ConstructionArgs::Nodes(vec![(NodeId(3), false), (NodeId(4), true)]));
assert_eq!(construction_network.nodes[5].1.construction_args, ConstructionArgs::Nodes(vec![(NodeId(3)), (NodeId(4))]));
}
#[test]
@ -950,11 +925,11 @@ mod test {
ids,
vec![
NodeId(16997244687192517417),
NodeId(12226224850522777131),
NodeId(9162113827627229771),
NodeId(12793582657066318419),
NodeId(16945623684036608820),
NodeId(2640415155091892458)
NodeId(7064939117677356327),
NodeId(10605314923684175783),
NodeId(6550828352538976747),
NodeId(277515424782779520),
NodeId(8855802688584342558)
]
);
}
@ -987,7 +962,7 @@ mod test {
ProtoNode {
identifier: "cons".into(),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![(NodeId(14), false)]),
construction_args: ConstructionArgs::Nodes(vec![NodeId(14)]),
..Default::default()
},
),