Deprecate automatic composition (#3088)
Some checks failed
Website / build (push) Has been cancelled
Editor: Dev & CI / build (push) Waiting to run
Editor: Dev & CI / cargo-deny (push) Waiting to run

* Make manual_compositon non optional and rename to call_argument

* Fix clippy warnings

* Remove automatic composition compiler infrastructure

* Implement document migration

* Fix tests

* Fix compilation on web

* Fix doble number test

* Remove extra parens

* Cleanup

* Update demo artwork

* Remove last compose node mention

* Remove last mention of manual composition
This commit is contained in:
Dennis Kobert 2025-08-24 10:34:59 +02:00 committed by GitHub
parent bb364c92ad
commit d9cbf975ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 184 additions and 676 deletions

View file

@ -1,13 +1,13 @@
pub mod value;
use crate::document::value::TaggedValue;
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode};
use dyn_any::DynAny;
use glam::IVec2;
use graphene_core::memo::MemoHashGuard;
pub use graphene_core::uuid::NodeId;
pub use graphene_core::uuid::generate_uuid;
use graphene_core::{Cow, MemoHash, ProtoNodeIdentifier, Type};
use graphene_core::{Context, Cow, MemoHash, ProtoNodeIdentifier, Type};
use log::Metadata;
use rustc_hash::FxHashMap;
use std::collections::HashMap;
@ -44,101 +44,9 @@ pub struct DocumentNode {
/// by using network.update_click_target(node_id).
#[cfg_attr(target_family = "wasm", serde(alias = "outputs"))]
pub inputs: Vec<NodeInput>,
/// Manual composition is the methodology by which most nodes are implemented, involving a call argument and upstream inputs.
/// By contrast, automatic composition is an alternative way to handle the composition of nodes as they execute in the graph.
/// Normally, the program (the compiled graph) builds up its call stack, with each node calling its upstream predecessor to acquire its input data.
/// When the document graph becomes the proto graph, that conceptual model changes into a model that's unique to the proto graph.
/// Automatic composition allows a document node to be translated into its place in the proto graph differently, such that
/// the node doesn't participate in that process of being called with a call argument and calling its upstream predecessor.
/// Instead, it is called directly with its input data from the upstream node, skipping the call stack building process.
/// The abstraction is provided by the compiler for nodes which opt for automatic composition. It works by inserting a `ComposeNode`
/// into the proto graph, which does the job of calling the upstream node and feeding its output into the downstream node's first input.
/// That first input is typically used by manual composition nodes as the call argument, but for automatic composition nodes,
/// that first input becomes the input data from the upstream node passed in by the `ComposeNode`.
///
/// Through automatic composition, the upstream node providing the first input for a proto node is evaluated before the proto node itself is run.
/// (That first input is usually the call argument when manual composition is used.)
/// - Abstract example: upstream node `G` is evaluated and its data feeds into the first input of downstream node `F`,
/// just like function composition where function `G` is evaluated and its result is fed into function `F`.
/// - Concrete example: a node that takes an image as its first input will get that image data from an upstream node that produces image output data and is evaluated first before being fed downstream.
///
/// This is achieved by automatically inserting `ComposeNode`s, which run the first node with the overall input and then feed the resulting output into the second node.
/// The `ComposeNode` is basically a function composition operator: the parentheses in `F(G(x))` or circle math operator in `(F ∘ G)(x)`.
/// For flexibility, instead of being a language construct, Graphene splits out composition itself as its own low-level node so that behavior can be overridden.
/// The `ComposeNode`s are then inserted during the graph rewriting step for nodes that don't opt out with `manual_composition`.
/// Instead of node `G` feeding into node `F` feeding as the result back to the caller,
/// the graph is rewritten so nodes `G` and `F` both feed as lambdas into the inputs of a `ComposeNode` which calls `F(G(input))` and returns the result to the caller.
///
/// A node's manual composition input represents an input that is not resolved through graph rewriting with a `ComposeNode`,
/// and is instead just passed in when evaluating this node within the borrow tree.
/// This is similar to having the first input be a `NodeInput::Network` after the graph flattening.
///
/// ## Example Use Case: CacheNode
///
/// The `CacheNode` is a pass-through node on cache miss, but on cache hit it needs to avoid evaluating the upstream node and instead just return the cached value.
///
/// First, let's consider what that would look like using the default composition flow if the `CacheNode` instead just always acted as a pass-through (akin to a cache that always misses):
///
/// ```text
/// ┌───────────────┐ ┌───────────────┐ ┌───────────────┐
/// │ │◄───┤ │◄───┤ │◄─── EVAL (START)
/// │ G │ │PassThroughNode│ │ F │
/// │ ├───►│ ├───►│ │───► RESULT (END)
/// └───────────────┘ └───────────────┘ └───────────────┘
/// ```
///
/// This acts like the function call `F(PassThroughNode(G(input)))` when evaluating `F` with some `input`: `F.eval(input)`.
/// - The diagram's upper track of arrows represents the flow of building up the call stack:
/// since `F` is the output it is encountered first but deferred to its upstream caller `PassThroughNode` and that is once again deferred to its upstream caller `G`.
/// - The diagram's lower track of arrows represents the flow of evaluating the call stack:
/// `G` is evaluated first, then `PassThroughNode` is evaluated with the result of `G`, and finally `F` is evaluated with the result of `PassThroughNode`.
///
/// With the default composition flow (no manual composition), `ComposeNode`s would be automatically inserted during the graph rewriting step like this:
///
/// ```text
/// ┌───────────────┐
/// │ │◄─── EVAL (START)
/// │ ComposeNode │
/// ┌───────────────┐ │ ├───► RESULT (END)
/// │ │◄─┐ ├───────────────┤
/// │ G │ └─┤ │
/// │ ├─┐ │ First │
/// └───────────────┘ └─►│ │
/// ┌───────────────┐ ├───────────────┤
/// │ │◄───┤ │
/// │ ComposeNode │ │ Second │
/// ┌───────────────┐ │ ├───►│ │
/// │ │◄─┐ ├───────────────┤ └───────────────┘
/// │PassThroughNode│ └─┤ │
/// │ ├─┐ │ First │
/// └───────────────┘ └─►│ │
/// ┌───────────────┐ ├───────────────┤
/// | │◄───┤ │
/// │ F │ │ Second │
/// │ ├───►│ │
/// └───────────────┘ └───────────────┘
/// ```
///
/// Now let's swap back from the `PassThroughNode` to the `CacheNode` to make caching actually work.
/// It needs to override the default composition flow so that `G` is not automatically evaluated when the cache is hit.
/// We need to give the `CacheNode` more manual control over the order of execution.
/// So the `CacheNode` opts into manual composition and, instead of deferring to its upstream caller, it consumes the input directly:
///
/// ```text
/// ┌───────────────┐ ┌───────────────┐
/// │ │◄───┤ │◄─── EVAL (START)
/// │ CacheNode │ │ F │
/// │ ├───►│ │───► RESULT (END)
/// ┌───────────────┐ ├───────────────┤ └───────────────┘
/// │ │◄───┤ │
/// │ G │ │ Cached Data │
/// │ ├───►│ │
/// └───────────────┘ └───────────────┘
/// ```
///
/// Now, the call from `F` directly reaches the `CacheNode` and the `CacheNode` can decide whether to call `G.eval(input_from_f)`
/// in the event of a cache miss or just return the cached data in the event of a cache hit.
pub manual_composition: Option<Type>,
/// Type of the argument which this node can be evaluated with.
#[serde(alias = "manual_composition", default)]
pub call_argument: Type,
// A nested document network or a proto-node identifier.
pub implementation: DocumentNodeImplementation,
/// Represents the eye icon for hiding/showing the node in the graph UI. When hidden, a node gets replaced with an identity node during the graph flattening step.
@ -173,15 +81,13 @@ pub struct OriginalLocation {
pub dependants: Vec<Vec<NodeId>>,
/// A list of flags indicating whether the input is exposed in the UI
pub inputs_exposed: Vec<bool>,
/// Skipping inputs is useful for the manual composition thing - whereby a hidden `Footprint` input is added as the first input.
pub skip_inputs: usize,
}
impl Default for DocumentNode {
fn default() -> Self {
Self {
inputs: Default::default(),
manual_composition: Default::default(),
call_argument: concrete!(Context),
implementation: Default::default(),
visible: true,
skip_deduplication: Default::default(),
@ -195,14 +101,13 @@ impl Hash for OriginalLocation {
self.path.hash(state);
self.inputs_source.iter().for_each(|val| val.hash(state));
self.inputs_exposed.hash(state);
self.skip_inputs.hash(state);
}
}
impl OriginalLocation {
pub fn inputs(&self, index: usize) -> impl Iterator<Item = Source> + '_ {
[(index >= self.skip_inputs).then(|| Source {
[(index >= 1).then(|| Source {
node: self.path.clone().unwrap_or_default(),
index: self.inputs_exposed.iter().take(index - self.skip_inputs).filter(|&&exposed| exposed).count(),
index: self.inputs_exposed.iter().take(index - 1).filter(|&&exposed| exposed).count(),
})]
.into_iter()
.flatten()
@ -222,48 +127,26 @@ impl DocumentNode {
self.inputs[index] = NodeInput::Node { node_id, output_index };
let input_source = &mut self.original_location.inputs_source;
for source in source {
input_source.insert(source, (index + self.original_location.skip_inputs).saturating_sub(skip));
input_source.insert(source, (index + 1).saturating_sub(skip));
}
}
fn resolve_proto_node(mut self) -> ProtoNode {
assert!(!self.inputs.is_empty() || self.manual_composition.is_some(), "Resolving document node {self:#?} with no inputs");
fn resolve_proto_node(self) -> ProtoNode {
let DocumentNodeImplementation::ProtoNode(identifier) = self.implementation else {
unreachable!("tried to resolve not flattened node on resolved node {self:?}");
};
let (input, mut args) = if let Some(ty) = self.manual_composition {
(ProtoNodeInput::ManualComposition(ty), ConstructionArgs::Nodes(vec![]))
} else {
let first = self.inputs.remove(0);
match first {
NodeInput::Value { tagged_value, .. } => {
assert_eq!(self.inputs.len(), 0, "A value node cannot have any inputs. Current inputs: {:?}", self.inputs);
(ProtoNodeInput::ManualComposition(concrete!(graphene_core::Context<'static>)), ConstructionArgs::Value(tagged_value))
}
NodeInput::Node { node_id, output_index } => {
assert_eq!(output_index, 0, "Outputs should be flattened before converting to proto node");
let node = ProtoNodeInput::Node(node_id);
(node, ConstructionArgs::Nodes(vec![]))
}
NodeInput::Network { import_type, .. } => (ProtoNodeInput::ManualComposition(import_type), ConstructionArgs::Nodes(vec![])),
NodeInput::Inline(inline) => (ProtoNodeInput::None, ConstructionArgs::Inline(inline)),
NodeInput::Scope(_) => unreachable!("Scope input was not resolved"),
NodeInput::Reflection(_) => unreachable!("Reflection input was not resolved"),
}
};
let (input, mut args) = (self.call_argument, ConstructionArgs::Nodes(vec![]));
assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::Network { .. })), "received non-resolved input");
assert!(
!self.inputs.iter().any(|input| matches!(input, NodeInput::Value { .. })),
"received value as input. inputs: {:#?}, construction_args: {:#?}",
self.inputs,
args
);
// If we have one input of the type inline, set it as the construction args
if let &[NodeInput::Inline(ref inline)] = self.inputs.as_slice() {
args = ConstructionArgs::Inline(inline.clone());
}
// If we have one input of the type inline, set it as the construction args
if let &[NodeInput::Value { ref tagged_value, .. }] = self.inputs.as_slice() {
args = ConstructionArgs::Value(tagged_value.clone());
}
if let ConstructionArgs::Nodes(nodes) = &mut args {
nodes.extend(self.inputs.iter().map(|input| match input {
NodeInput::Node { node_id, .. } => *node_id,
@ -272,7 +155,7 @@ impl DocumentNode {
}
ProtoNode {
identifier,
input,
call_argument: input,
construction_args: args,
original_location: self.original_location,
skip_deduplication: self.skip_deduplication,
@ -764,7 +647,6 @@ impl NodeNetwork {
node.original_location = OriginalLocation {
path: Some(new_path),
inputs_exposed: node.inputs.iter().map(|input| input.is_exposed()).collect(),
skip_inputs: if node.manual_composition.is_some() { 1 } else { 0 },
dependants: (0..node.implementation.output_count()).map(|_| Vec::new()).collect(),
..Default::default()
};
@ -891,7 +773,7 @@ impl NodeNetwork {
// Connect layer node to the group below
node.inputs.drain(1..);
node.manual_composition = None;
node.call_argument = concrete!(());
self.nodes.insert(id, node);
return;
}
@ -945,8 +827,7 @@ impl NodeNetwork {
match *parent_input {
// If the input to self is a node, connect the corresponding output of the inner network to it
NodeInput::Node { node_id, output_index } => {
let skip = node.original_location.skip_inputs;
nested_node.populate_first_network_input(node_id, output_index, nested_input_index, node.original_location.inputs(*import_index), skip);
nested_node.populate_first_network_input(node_id, output_index, nested_input_index, node.original_location.inputs(*import_index), 1);
let input_node = self.nodes.get_mut(&node_id).unwrap_or_else(|| panic!("unable find input node {node_id:?}"));
input_node.original_location.dependants[output_index].push(nested_node_id);
}
@ -1048,15 +929,6 @@ impl NodeNetwork {
}
}
// /// Locate the export that is a [`NodeInput::Network`] at index `offset` and replace it with a [`NodeInput::Node`].
// fn populate_first_network_export(&mut self, node: &mut DocumentNode, node_id: NodeId, output_index: usize, lambda: bool, export_index: usize, source: impl Iterator<Item = Source>, skip: usize) {
// self.exports[export_index] = NodeInput::Node { node_id, output_index, lambda };
// let input_source = &mut node.original_location.inputs_source;
// for source in source {
// input_source.insert(source, output_index + node.original_location.skip_inputs - skip);
// }
// }
fn remove_id_node(&mut self, id: NodeId) -> Result<(), String> {
let node = self.nodes.get(&id).ok_or_else(|| format!("Node with id {id} does not exist"))?.clone();
if let DocumentNodeImplementation::ProtoNode(ident) = &node.implementation {
@ -1086,7 +958,7 @@ impl NodeNetwork {
let input_source = &mut output.original_location.inputs_source;
for source in node.original_location.inputs(index) {
input_source.insert(source, index + output.original_location.skip_inputs - node.original_location.skip_inputs);
input_source.insert(source, index);
}
}
}
@ -1231,7 +1103,7 @@ impl<'a> Iterator for RecursiveNodeIter<'a> {
#[cfg(test)]
mod test {
use super::*;
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode};
use std::sync::atomic::AtomicU64;
fn gen_node_id() -> NodeId {
@ -1356,7 +1228,8 @@ mod test {
#[test]
fn resolve_proto_node_add() {
let document_node = DocumentNode {
inputs: vec![NodeInput::network(concrete!(u32), 0), NodeInput::node(NodeId(0), 0)],
inputs: vec![NodeInput::node(NodeId(0), 0)],
call_argument: concrete!(u32),
implementation: DocumentNodeImplementation::ProtoNode("graphene_core::structural::ConsNode".into()),
..Default::default()
};
@ -1364,7 +1237,7 @@ mod test {
let proto_node = document_node.resolve_proto_node();
let reference = ProtoNode {
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
call_argument: concrete!(u32),
construction_args: ConstructionArgs::Nodes(vec![NodeId(0)]),
..Default::default()
};
@ -1381,13 +1254,12 @@ mod test {
NodeId(10),
ProtoNode {
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
call_argument: concrete!(u32),
construction_args: ConstructionArgs::Nodes(vec![NodeId(14)]),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(0)]),
inputs_source: [(Source { node: vec![NodeId(1)], index: 1 }, 1)].into(),
inputs_exposed: vec![true, true],
skip_inputs: 0,
..Default::default()
},
@ -1398,13 +1270,12 @@ mod test {
NodeId(11),
ProtoNode {
identifier: "graphene_core::ops::AddPairNode".into(),
input: ProtoNodeInput::Node(NodeId(10)),
construction_args: ConstructionArgs::Nodes(vec![]),
call_argument: concrete!(Context),
construction_args: ConstructionArgs::Nodes(vec![NodeId(10)]),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(1)]),
inputs_source: HashMap::new(),
inputs_exposed: vec![true],
skip_inputs: 0,
..Default::default()
},
..Default::default()
@ -1414,13 +1285,12 @@ mod test {
NodeId(14),
ProtoNode {
identifier: "graphene_core::value::ClonedNode".into(),
input: ProtoNodeInput::ManualComposition(concrete!(graphene_core::Context)),
call_argument: concrete!(graphene_core::Context),
construction_args: ConstructionArgs::Value(TaggedValue::U32(2).into()),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(4)]),
inputs_source: HashMap::new(),
inputs_exposed: vec![true, false],
skip_inputs: 0,
..Default::default()
},
..Default::default()
@ -1446,13 +1316,13 @@ mod test {
(
NodeId(10),
DocumentNode {
inputs: vec![NodeInput::network(concrete!(u32), 0), NodeInput::node(NodeId(14), 0)],
inputs: vec![NodeInput::node(NodeId(14), 0)],
call_argument: concrete!(u32),
implementation: DocumentNodeImplementation::ProtoNode("graphene_core::structural::ConsNode".into()),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(0)]),
inputs_source: [(Source { node: vec![NodeId(1)], index: 1 }, 1)].into(),
inputs_exposed: vec![true, true],
skip_inputs: 0,
..Default::default()
},
..Default::default()
@ -1467,7 +1337,6 @@ mod test {
path: Some(vec![NodeId(1), NodeId(4)]),
inputs_source: HashMap::new(),
inputs_exposed: vec![true, false],
skip_inputs: 0,
..Default::default()
},
..Default::default()
@ -1482,7 +1351,6 @@ mod test {
path: Some(vec![NodeId(1), NodeId(1)]),
inputs_source: HashMap::new(),
inputs_exposed: vec![true],
skip_inputs: 0,
..Default::default()
},
..Default::default()
@ -1567,49 +1435,4 @@ mod test {
}
// TODO: Write more tests
// #[test]
// fn out_of_order_duplicate() {
// let result = output_duplicate(vec![NodeInput::node(NodeId(10), 1), NodeInput::node(NodeId(10), 0)], NodeInput::node(NodeId(10), 0);
// assert_eq!(
// result.outputs[0],
// NodeInput::node(NodeId(101), 0),
// "The first network output should be from a duplicated nested network"
// );
// assert_eq!(
// result.outputs[1],
// NodeInput::node(NodeId(10), 0),
// "The second network output should be from the original nested network"
// );
// assert!(
// result.nodes.contains_key(&NodeId(10)) && result.nodes.contains_key(&NodeId(101)) && result.nodes.len() == 2,
// "Network should contain two duplicated nodes"
// );
// for (node_id, input_value, inner_id) in [(10, 1., 1), (101, 2., 2)] {
// let nested_network_node = result.nodes.get(&NodeId(node_id)).unwrap();
// assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
// assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(input_value), false)], "Input should be stable");
// let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
// assert_eq!(inner_network.inputs, vec![inner_id], "The input should be sent to the second node");
// assert_eq!(inner_network.outputs, vec![NodeInput::node(NodeId(inner_id), 0)], "The output should be node id");
// assert_eq!(inner_network.nodes.get(&NodeId(inner_id)).unwrap().name, format!("Identity {inner_id}"), "The node should be identity");
// }
// }
// #[test]
// fn using_other_node_duplicate() {
// let result = output_duplicate(vec![NodeInput::node(NodeId(11), 0)], NodeInput::node(NodeId(10), 1);
// assert_eq!(result.outputs, vec![NodeInput::node(NodeId(11), 0)], "The network output should be the result node");
// assert!(
// result.nodes.contains_key(&NodeId(11)) && result.nodes.contains_key(&NodeId(101)) && result.nodes.len() == 2,
// "Network should contain a duplicated node and a result node"
// );
// let result_node = result.nodes.get(&NodeId(11)).unwrap();
// assert_eq!(result_node.inputs, vec![NodeInput::node(NodeId(101), 0)], "Result node should refer to duplicate node as input");
// let nested_network_node = result.nodes.get(&NodeId(101)).unwrap();
// assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
// assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
// let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
// assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
// assert_eq!(inner_network.outputs, vec![NodeInput::node(NodeId(2), 0)], "The output should be node id 2");
// assert_eq!(inner_network.nodes.get(&NodeId(2)).unwrap().name, "Identity 2", "The node should be identity 2");
// }
}

View file

@ -37,11 +37,7 @@ impl core::fmt::Display for ProtoNetwork {
f.write_str(&"\t".repeat(indent + 1))?;
f.write_str("Input: ")?;
match &node.input {
ProtoNodeInput::None => f.write_str("None")?,
ProtoNodeInput::ManualComposition(ty) => f.write_fmt(format_args!("Manual Composition (type = {ty:?})"))?,
ProtoNodeInput::Node(_) => f.write_str("Node")?,
}
f.write_fmt(format_args!("Call Argument (type = {:?})", node.call_argument))?;
f.write_str("\n")?;
match &node.construction_args {
@ -132,7 +128,7 @@ impl ConstructionArgs {
/// At different stages in the compilation process, this struct will be transformed into a reduced (more restricted) form acting as a subset of its original form, but that restricted form is still valid in the earlier stage in the compilation process before it was transformed.
pub struct ProtoNode {
pub construction_args: ConstructionArgs,
pub input: ProtoNodeInput,
pub call_argument: Type,
pub identifier: ProtoNodeIdentifier,
pub original_location: OriginalLocation,
pub skip_deduplication: bool,
@ -143,37 +139,13 @@ impl Default for ProtoNode {
Self {
identifier: ProtoNodeIdentifier::new("graphene_core::ops::IdentityNode"),
construction_args: ConstructionArgs::Value(value::TaggedValue::U32(0).into()),
input: ProtoNodeInput::None,
call_argument: concrete!(()),
original_location: OriginalLocation::default(),
skip_deduplication: false,
}
}
}
/// Similar to the document node's [`crate::document::NodeInput`].
#[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)]
pub enum ProtoNodeInput {
/// This input will be converted to `()` as the call argument.
None,
/// A ManualComposition input represents an input that opts out of being resolved through the `ComposeNode`, which first runs the previous (upstream) node, then passes that evaluated
/// result to this node. Instead, ManualComposition lets this node actually consume the provided input instead of passing it to its predecessor.
///
/// Say we have the network `a -> b -> c` where `c` is the output node and `a` is the input node.
/// We would expect `a` to get input from the network, `b` to get input from `a`, and `c` to get input from `b`.
/// This could be represented as `f(x) = c(b(a(x)))`. `a` is run with input `x` from the network. `b` is run with input from `a`. `c` is run with input from `b`.
///
/// However if `b`'s input is using manual composition, this means it would instead be `f(x) = c(b(x))`. This means that `b` actually gets input from the network, and `a` is not automatically
/// executed as it would be using the default ComposeNode flow. Now `b` can use its own logic to decide when or if it wants to run `a` and how to use its output. For example, the CacheNode can
/// look up `x` in its cache and return the result, or otherwise call `a`, cache the result, and return it.
ManualComposition(Type),
/// The previous node where automatic (not manual) composition occurs when compiled. The entire network, of which the node is the output, is fed as input.
///
/// Grayscale example:
///
/// We're interested in receiving an input of the desaturated image data which has been fed through a grayscale filter.
Node(NodeId),
}
impl ProtoNode {
/// A stable node ID is a hash of a node that should stay constant. This is used in order to remove duplicates from the graph.
/// In the case of `skip_deduplication`, the `document_node_path` is also hashed in order to avoid duplicate monitor nodes from being removed (which would make it impossible to load thumbnails).
@ -187,14 +159,8 @@ impl ProtoNode {
self.original_location.path.hash(&mut hasher);
}
std::mem::discriminant(&self.input).hash(&mut hasher);
match self.input {
ProtoNodeInput::None => (),
ProtoNodeInput::ManualComposition(ref ty) => {
ty.hash(&mut hasher);
}
ProtoNodeInput::Node(id) => id.hash(&mut hasher),
};
std::mem::discriminant(&self.call_argument).hash(&mut hasher);
self.call_argument.hash(&mut hasher);
Some(NodeId(hasher.finish()))
}
@ -208,7 +174,7 @@ impl ProtoNode {
Self {
identifier: ProtoNodeIdentifier::new("graphene_core::value::ClonedNode"),
construction_args: value,
input: ProtoNodeInput::ManualComposition(concrete!(Context)),
call_argument: concrete!(Context),
original_location: OriginalLocation {
path: Some(path),
inputs_exposed: vec![false; inputs_exposed],
@ -221,10 +187,6 @@ impl ProtoNode {
/// Converts all references to other node IDs into new IDs by running the specified function on them.
/// This can be used when changing the IDs of the nodes, for example in the case of generating stable IDs.
pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
if let ProtoNodeInput::Node(id) = self.input {
self.input = ProtoNodeInput::Node(f(id))
}
if let ConstructionArgs::Nodes(ids) = &mut self.construction_args {
ids.iter_mut().for_each(|id| *id = f(*id));
}
@ -269,11 +231,6 @@ impl ProtoNetwork {
pub fn collect_outwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> {
let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
for (id, node) in &self.nodes {
if let ProtoNodeInput::Node(ref_id) = &node.input {
self.check_ref(ref_id, id);
edges.entry(*ref_id).or_default().push(*id)
}
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args {
for ref_id in ref_nodes {
self.check_ref(ref_id, id);
@ -304,11 +261,6 @@ impl ProtoNetwork {
pub fn collect_inwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> {
let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
for (id, node) in &self.nodes {
if let ProtoNodeInput::Node(ref_id) = &node.input {
self.check_ref(ref_id, id);
edges.entry(*id).or_default().push(*ref_id)
}
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args {
for ref_id in ref_nodes {
self.check_ref(ref_id, id);
@ -326,10 +278,6 @@ impl ProtoNetwork {
let mut inwards_edges = vec![Vec::new(); self.nodes.len()];
for (node_id, node) in &self.nodes {
let node_index = id_map[node_id];
if let ProtoNodeInput::Node(ref_id) = &node.input {
self.check_ref(ref_id, &NodeId(node_index as u64));
inwards_edges[node_index].push(id_map[ref_id]);
}
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args {
for ref_id in ref_nodes {
@ -342,70 +290,31 @@ impl ProtoNetwork {
(inwards_edges, id_map)
}
/// Inserts a [`structural::ComposeNode`] for each node that has a [`ProtoNodeInput::Node`]. The compose node evaluates the first node, and then sends the result into the second node.
/// Performs topological sort and reorders ids.
pub fn resolve_inputs(&mut self) -> Result<(), String> {
// Perform topological sort once
self.reorder_ids()?;
let max_id = self.nodes.len() as u64 - 1;
// Collect outward edges once
let outwards_edges = self.collect_outwards_edges();
// Iterate over nodes in topological order
for node_id in 0..=max_id {
let node_id = NodeId(node_id);
let (_, node) = &mut self.nodes[node_id.0 as usize];
if let ProtoNodeInput::Node(input_node_id) = node.input {
// Create a new node that composes the current node and its input node
let compose_node_id = NodeId(self.nodes.len() as u64);
let (_, input_node_id_proto) = &self.nodes[input_node_id.0 as usize];
let input = input_node_id_proto.input.clone();
let mut path = input_node_id_proto.original_location.path.clone();
if let Some(path) = &mut path {
path.push(node_id);
}
self.nodes.push((
compose_node_id,
ProtoNode {
identifier: ProtoNodeIdentifier::new("graphene_core::structural::ComposeNode"),
construction_args: ConstructionArgs::Nodes(vec![input_node_id, node_id]),
input,
original_location: OriginalLocation { path, ..Default::default() },
skip_deduplication: false,
},
));
self.replace_node_id(&outwards_edges, node_id, compose_node_id);
}
}
self.reorder_ids()?;
Ok(())
}
/// Update all of the references to a node ID in the graph with a new ID named `compose_node_id`.
fn replace_node_id(&mut self, outwards_edges: &HashMap<NodeId, Vec<NodeId>>, node_id: NodeId, compose_node_id: NodeId) {
// Update references in other nodes to use the new compose node
/// Update all of the references to a node ID in the graph with a new ID named `replacement_node_id`.
fn replace_node_id(&mut self, outwards_edges: &HashMap<NodeId, Vec<NodeId>>, node_id: NodeId, replacement_node_id: NodeId) {
// Update references in other nodes to use the new node
if let Some(referring_nodes) = outwards_edges.get(&node_id) {
for &referring_node_id in referring_nodes {
let (_, referring_node) = &mut self.nodes[referring_node_id.0 as usize];
referring_node.map_ids(|id| if id == node_id { compose_node_id } else { id })
referring_node.map_ids(|id| if id == node_id { replacement_node_id } else { id })
}
}
if self.output == node_id {
self.output = compose_node_id;
self.output = replacement_node_id;
}
self.inputs.iter_mut().for_each(|id| {
if *id == node_id {
*id = compose_node_id;
*id = replacement_node_id;
}
});
}
@ -636,7 +545,6 @@ impl TypingContext {
let inputs = match node.construction_args {
// If the node has a value input we can infer the return type from it
ConstructionArgs::Value(ref v) => {
assert!(matches!(node.input, ProtoNodeInput::None) || matches!(node.input, ProtoNodeInput::ManualComposition(ref x) if x == &concrete!(Context)));
// TODO: This should return a reference to the value
let types = NodeIOTypes::new(concrete!(Context), Type::Future(Box::new(v.ty())), vec![]);
self.inferred.insert(node_id, types.clone());
@ -656,16 +564,7 @@ impl TypingContext {
};
// Get the node input type from the proto node declaration
// TODO: When removing automatic composition, rename this to just `call_argument`
let primary_input_or_call_argument = match node.input {
ProtoNodeInput::None => concrete!(()),
ProtoNodeInput::ManualComposition(ref ty) => ty.clone(),
ProtoNodeInput::Node(id) => {
let input = self.inferred.get(&id).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::InputNodeNotFound(id))])?;
input.return_value.clone()
}
};
let using_manual_composition = matches!(node.input, ProtoNodeInput::ManualComposition(_) | ProtoNodeInput::None);
let call_argument = &node.call_argument;
let impls = self.lookup.get(&node.identifier).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NoImplementations)])?;
if let Some(index) = inputs.iter().position(|p| {
@ -707,7 +606,7 @@ impl TypingContext {
// List of all implementations that match the input types
let valid_output_types = impls
.keys()
.filter(|node_io| valid_type(&node_io.call_argument, &primary_input_or_call_argument) && inputs.iter().zip(node_io.inputs.iter()).all(|(p1, p2)| valid_type(p1, p2)))
.filter(|node_io| valid_type(&node_io.call_argument, call_argument) && inputs.iter().zip(node_io.inputs.iter()).all(|(p1, p2)| valid_type(p1, p2)))
.collect::<Vec<_>>();
// Attempt to substitute generic types with concrete types and save the list of results
@ -716,7 +615,7 @@ impl TypingContext {
.map(|node_io| {
let generics_lookup: Result<HashMap<_, _>, _> = collect_generics(node_io)
.iter()
.map(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|x| (generic.to_string(), x)))
.map(|generic| check_generic(node_io, call_argument, &inputs, generic).map(|x| (generic.to_string(), x)))
.collect();
generics_lookup.map(|generics_lookup| {
@ -736,7 +635,7 @@ impl TypingContext {
let mut best_errors = usize::MAX;
let mut error_inputs = Vec::new();
for node_io in impls.keys() {
let current_errors = [&primary_input_or_call_argument]
let current_errors = [call_argument]
.into_iter()
.chain(&inputs)
.cloned()
@ -745,7 +644,6 @@ impl TypingContext {
.filter(|(_, (p1, p2))| !valid_type(p1, p2))
.map(|(index, ty)| {
let i = node.original_location.inputs(index).min_by_key(|s| s.node.len()).map(|s| s.index).unwrap_or(index);
let i = if using_manual_composition { i } else { i + 1 };
(i, ty)
})
.collect::<Vec<_>>();
@ -757,15 +655,11 @@ impl TypingContext {
error_inputs.push(current_errors);
}
}
let inputs = [&primary_input_or_call_argument]
let inputs = [call_argument]
.into_iter()
.chain(&inputs)
.enumerate()
// TODO: Make the following line's if statement conditional on being a call argument or primary input
.filter_map(|(i, t)| {
let i = if using_manual_composition { i } else { i + 1 };
if i == 0 { None } else { Some(format!("• Input {i}: {t}")) }
})
.filter_map(|(i, t)| if i == 0 { None } else { Some(format!("• Input {i}: {t}")) })
.collect::<Vec<_>>()
.join("\n");
Err(vec![GraphError::new(node, GraphErrorType::InvalidImplementations { inputs, error_inputs })])
@ -792,13 +686,13 @@ impl TypingContext {
return Ok(node_io.clone());
}
}
let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
let inputs = [call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
let valid = valid_output_types.into_iter().cloned().collect();
Err(vec![GraphError::new(node, GraphErrorType::MultipleImplementations { inputs, valid })])
}
_ => {
let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
let inputs = [call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
let valid = valid_output_types.into_iter().cloned().collect();
Err(vec![GraphError::new(node, GraphErrorType::MultipleImplementations { inputs, valid })])
}
@ -857,7 +751,7 @@ fn replace_generics(types: &mut NodeIOTypes, lookup: &HashMap<String, Type>) {
#[cfg(test)]
mod test {
use super::*;
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode};
#[test]
fn topological_sort() {
@ -904,16 +798,6 @@ mod test {
assert_eq!(ids, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
}
#[test]
fn input_resolution() {
let mut construction_network = test_network();
construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
println!("{construction_network:#?}");
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
assert_eq!(construction_network.nodes.len(), 6);
assert_eq!(construction_network.nodes[5].1.construction_args, ConstructionArgs::Nodes(vec![(NodeId(3)), (NodeId(4))]));
}
#[test]
fn stable_node_id_generation() {
let mut construction_network = test_network();
@ -923,14 +807,7 @@ mod test {
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
assert_eq!(
ids,
vec![
NodeId(16997244687192517417),
NodeId(7064939117677356327),
NodeId(10605314923684175783),
NodeId(6550828352538976747),
NodeId(277515424782779520),
NodeId(8855802688584342558)
]
vec![NodeId(13743208144182721472), NodeId(4607569396187877965), NodeId(16950305885390329527), NodeId(15151181027373658932)]
);
}
@ -943,8 +820,8 @@ mod test {
NodeId(7),
ProtoNode {
identifier: "id".into(),
input: ProtoNodeInput::Node(NodeId(11)),
construction_args: ConstructionArgs::Nodes(vec![]),
call_argument: concrete!(()),
construction_args: ConstructionArgs::Nodes(vec![NodeId(11)]),
..Default::default()
},
),
@ -952,8 +829,8 @@ mod test {
NodeId(1),
ProtoNode {
identifier: "id".into(),
input: ProtoNodeInput::Node(NodeId(11)),
construction_args: ConstructionArgs::Nodes(vec![]),
call_argument: concrete!(()),
construction_args: ConstructionArgs::Nodes(vec![NodeId(11)]),
..Default::default()
},
),
@ -961,7 +838,7 @@ mod test {
NodeId(10),
ProtoNode {
identifier: "cons".into(),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
call_argument: concrete!(u32),
construction_args: ConstructionArgs::Nodes(vec![NodeId(14)]),
..Default::default()
},
@ -970,8 +847,8 @@ mod test {
NodeId(11),
ProtoNode {
identifier: "add".into(),
input: ProtoNodeInput::Node(NodeId(10)),
construction_args: ConstructionArgs::Nodes(vec![]),
call_argument: concrete!(()),
construction_args: ConstructionArgs::Nodes(vec![NodeId(10)]),
..Default::default()
},
),
@ -979,7 +856,7 @@ mod test {
NodeId(14),
ProtoNode {
identifier: "value".into(),
input: ProtoNodeInput::None,
call_argument: concrete!(()),
construction_args: ConstructionArgs::Value(value::TaggedValue::U32(2).into()),
..Default::default()
},
@ -999,8 +876,8 @@ mod test {
NodeId(1),
ProtoNode {
identifier: "id".into(),
input: ProtoNodeInput::Node(NodeId(2)),
construction_args: ConstructionArgs::Nodes(vec![]),
call_argument: concrete!(()),
construction_args: ConstructionArgs::Nodes(vec![NodeId(2)]),
..Default::default()
},
),
@ -1008,8 +885,8 @@ mod test {
NodeId(2),
ProtoNode {
identifier: "id".into(),
input: ProtoNodeInput::Node(NodeId(1)),
construction_args: ConstructionArgs::Nodes(vec![]),
call_argument: concrete!(()),
construction_args: ConstructionArgs::Nodes(vec![NodeId(1)]),
..Default::default()
},
),