Add more details to Graphene concept documentation (#1437)

* Start improving node system docs

* Add note on debugging

* Explain testing protonodes

* Code review comments

* Review pass

* Further improve explanation of manual_compostion

* Fix explanation of ComposeNode graph rewriting

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
0HyperCube 2023-10-28 03:21:15 +01:00 committed by GitHub
parent bfb6df3b74
commit ceb2f4c13f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 278 additions and 111 deletions

View file

@ -11,6 +11,8 @@ pub mod value;
pub type NodeId = u64;
/// Hash two IDs together, returning a new ID that is always consistant for two input IDs in a specific order.
/// This is used during [`NodeNetwork::flatten`] in order to ensure consistant yet non-conflicting IDs for inner networks.
fn merge_ids(a: u64, b: u64) -> u64 {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
@ -21,6 +23,7 @@ fn merge_ids(a: u64, b: u64) -> u64 {
#[derive(Clone, Debug, PartialEq, Default, specta::Type, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Metadata about the node including its position in the graph UI
pub struct DocumentNodeMetadata {
pub position: IVec2,
}
@ -40,18 +43,114 @@ fn return_true() -> bool {
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DocumentNode {
/// An identifier used to display in the UI and to display the appropriate properties.
pub name: String,
/// The inputs to a node, which are either:
/// - From other nodes within this graph [`NodeInput::Node`],
/// - A constant value [`NodeInput::Value`],
/// - A [`NodeInput::Network`] which specifies that this input is from outside the graph, which is resolved in the graph flattening step in the case of nested networks.
/// In the root network, it is resolved when evaluating the borrow tree.
pub inputs: Vec<NodeInput>,
/// Manual composition is a way to override the default composition flow of one node into another.
///
/// Through the usual node composition flow, the upstream node providing the primary input for a node is evaluated before the node itself is run.
/// - Abstract example: upstream node `G` is evaluated and its data feeds into the primary input of downstream node `F`,
/// just like function composition where function `F` is evaluated and its result is fed into function `F`.
/// - Concrete example: a node that takes an image as primary input will get that image data from an upstream node that produces image output data and is evaluated first before being fed downstream.
///
/// This is achieved by automatically inserting `ComposeNode`s, which run the first node with the overall input and then feed the resulting output into the second node.
/// The `ComposeNode` is basically a function composition operator: the parentheses in `F(G(x))` or circle math operator in `(G ∘ F)(x)`.
/// For flexability, instead of being a language construct, Graphene splits out composition itself as its own low-level node so that behavior can be overridden.
/// The `ComposeNode`s are then inserted during the graph rewriting step for nodes that don't opt out with `manual_composition`.
/// Instead of node `G` feeding into node `F` feeding as the result back to the caller,
/// the graph is rewritten so nodes `G` and `F` both feed as lambdas into the parameters of a `ComposeNode` which calls `F(G(input))` and returns the result to the caller.
///
/// A node's manual composition input represents an input that is not resolved through graph rewriting with a `ComposeNode`,
/// and is instead just passed in when evaluating this node within the borrow tree.
/// This is similar to having the first input be a `NodeInput::Network` after the graph flattening.
///
/// ## Example Use Case: CacheNode
///
/// The `CacheNode` is a pass-through node on cache miss, but on cache hit it needs to avoid evaluating the upstream node and instead just return the cached value.
///
/// First, let's consider what that would look like using the default composition flow if the `CacheNode` instead just always acted as a pass-through (akin to a cache that always misses):
///
/// ```text
/// ┌───────────────┐ ┌───────────────┐ ┌───────────────┐
/// │ │◄───┤ │◄───┤ │◄─── EVAL (START)
/// │ G │ │PassThroughNode│ │ F │
/// │ ├───►│ ├───►│ │───► RESULT (END)
/// └───────────────┘ └───────────────┘ └───────────────┘
/// ```
///
/// This acts like the function call `F(PassThroughNode(G(input)))` when evaluating `F` with some `input`: `F.eval(input)`.
/// - The diagram's upper track of arrows represents the flow of building up the call stack:
/// since `F` is the output it is encountered first but deferred to its upstream caller `PassThroughNode` and that is once again deferred to its upstream caller `G`.
/// - The diagram's lower track of arrows represents the flow of evaluating the call stack:
/// `G` is evaluated first, then `PassThroughNode` is evaluated with the result of `G`, and finally `F` is evaluated with the result of `PassThroughNode`.
///
/// With the default composition flow (no manual composition), `ComposeNode`s would be automatically inserted during the graph rewriting step like this:
///
/// ```text
/// ┌───────────────┐ ┌───────────────┐
/// │ │◄───┤ │◄─── EVAL (START)
/// │ ComposeNode │ │ F │
/// ┌───────────────┐ │ ├───►│ │───► RESULT (END)
/// │ │◄─┐ ├───────────────┤ └───────────────┘
/// │ F │ └─┤ │
/// │ ├─┐ │ First │
/// └───────────────┘ └─►│ │
/// ┌───────────────┐ ├───────────────┤
/// │ │◄───┤ │
/// │ ComposeNode │ │ Second │
/// ┌───────────────┐ │ ├───►│ │
/// │ │◄─┐ ├───────────────┤ └───────────────┘
/// │ G │ └─┤ │
/// │ ├─┐ │ First │
/// └───────────────┘ └─►│ │
/// ┌───────────────┐ ├───────────────┤
/// | │◄───┤ │
/// │PassThroughNode│ │ Second │
/// │ ├───►│ │
/// └───────────────┘ └───────────────┘
/// ```
///
/// Now let's swap back from the `PassThroughNode` to the `CacheNode` to make caching actually work.
/// It needs to override the default composition flow so that `G` is not automatically evaluated when the cache is hit.
/// We need to give the `CacheNode` more manual control over the order of execution.
/// So the `CacheNode` opts into manual composition and, instead of deferring to its upstream caller, it consumes the input directly:
///
/// ```text
/// ┌───────────────┐ ┌───────────────┐
/// │ │◄───┤ │◄─── EVAL (START)
/// │ CacheNode │ │ F │
/// │ ├───►│ │───► RESULT (END)
/// ┌───────────────┐ ├───────────────┤ └───────────────┘
/// │ │◄───┤ │
/// │ G │ │ Cached Data │
/// │ ├───►│ │
/// └───────────────┘ └───────────────┘
/// ```
///
/// Now, the call from `F` directly reaches the `CacheNode` and the `CacheNode` can decide whether to call `G.eval(input_from_f)`
/// in the event of a cache miss or just return the cached data in the event of a cache hit.
pub manual_composition: Option<Type>,
#[serde(default = "return_true")]
pub has_primary_output: bool,
// A nested document network or a proto-node identifier.
pub implementation: DocumentNodeImplementation,
/// Metadata about the node including its position in the graph UI.
pub metadata: DocumentNodeMetadata,
/// When two different protonodes hash to the same value (e.g. two value nodes each containing `2_u32` or two multiply nodes that have the same node IDs as input), the duplicates are removed.
/// See [`crate::proto::ProtoNetwork::generate_stable_node_ids`] for details.
/// However sometimes this is not desirable, for example in the case of a [`graphene_core::memo::MonitorNode`] that needs to be accessed outside of the graph.
#[serde(default)]
pub skip_deduplication: bool,
/// Used as a hash of the graph input where applicable. This ensures that protonodes that depend on the graph's input are always regenerated.
#[serde(default)]
pub hash: u64,
/// The path to this node as of when [`NodeNetwork::generate_node_paths`] was called.
/// For example if this node was ID 6 inside a node with ID 4 and with a [`DocumentNodeImplementation::Network`], the path would be [4, 6].
pub path: Option<Vec<NodeId>>,
}
@ -72,6 +171,7 @@ impl Default for DocumentNode {
}
impl DocumentNode {
/// Locate the input that is a [`NodeInput::Network`] at index `offset` and replace it with a [`NodeInput::Node`].
pub fn populate_first_network_input(&mut self, node_id: NodeId, output_index: usize, offset: usize, lambda: bool) {
let (index, _) = self
.inputs
@ -90,7 +190,7 @@ impl DocumentNode {
unreachable!("tried to resolve not flattened node on resolved node {self:?}");
};
let (input, mut args) = if let Some(ty) = self.manual_composition {
(ProtoNodeInput::ShortCircut(ty), ConstructionArgs::Nodes(vec![]))
(ProtoNodeInput::ManualComposition(ty), ConstructionArgs::Nodes(vec![]))
} else {
let first = self.inputs.remove(0);
match first {
@ -102,7 +202,7 @@ impl DocumentNode {
assert_eq!(output_index, 0, "Outputs should be flattened before converting to protonode. {:#?}", self.name);
(ProtoNodeInput::Node(node_id, lambda), ConstructionArgs::Nodes(vec![]))
}
NodeInput::Network(ty) => (ProtoNodeInput::Network(ty), ConstructionArgs::Nodes(vec![])),
NodeInput::Network(ty) => (ProtoNodeInput::ManualComposition(ty), ConstructionArgs::Nodes(vec![])),
NodeInput::Inline(inline) => (ProtoNodeInput::None, ConstructionArgs::Inline(inline)),
}
};
@ -177,41 +277,6 @@ pub enum NodeInput {
/// Input that is provided by the parent network to this document node, instead of from a hardcoded value or another node within the same network.
Network(Type),
/// A short circuting input represents an input that is not resolved through function composition
/// but rather by actually consuming the provided input instead of passing it to its predecessor.
///
/// In Graphite nodes are functions, and by default these are composed into a single function
/// by automatic insertion of inserting Compose nodes.
///
/// ```text
/// ┌───────────────┐ ┌───────────────┐ ┌───────────────┐
/// │ │◄───┤ │◄───┤ │
/// │ A │ │ B │ │ C │
/// │ ├───►│ ├───►│ │
/// └───────────────┘ └───────────────┘ └───────────────┘
/// ```
///
/// This is equivalent to calling c(b(a(input))) when evaluating c with input ( `c.eval(input)`).
/// But sometimes we might want to have a little more control over the order of execution.
/// This is why we allow nodes to opt out of the input forwarding by consuming the input directly.
///
/// ```text
/// ┌───────────────┐ ┌───────────────┐
/// │ │◄───┤ │
/// │ Cache Node │ │ C │
/// │ ├───►│ │
/// ┌───────────────┐ ├───────────────┤ └───────────────┘
/// │ │◄───┤ │
/// │ A │ │ * Cached Node │
/// │ ├───►│ │
/// └───────────────┘ └───────────────┘
/// ```
///
/// In this case the Cache node actually consumes its input and then manually forwards it to its parameter Node.
/// This is necessary because the Cache Node needs to short-circut the actual node evaluation.
// TODO: Update
// ShortCircut(Type),
/// A Rust source code string. Allows us to insert literal Rust code. Only used for GPU compilation.
/// We can use this whenever we spin up Rustc. Sort of like inline assembly, but because our language is Rust, it acts as inline Rust.
Inline(InlineRust),
@ -283,9 +348,14 @@ impl NodeInput {
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Represents the implementation of a node, which can be a nested [`NodeNetwork`], a proto [`NodeIdentifier`], or extract.
pub enum DocumentNodeImplementation {
/// A nested [`NodeNetwork`] that is flattened by the [`NodeNetwork::flatten`] function.
Network(NodeNetwork),
/// A protonode identifier which can be found in `node_registry.rs`.
Unresolved(NodeIdentifier),
/// `DocumentNode`s with a `DocumentNodeImplementation::Extract` are converted into a `ClonedNode` that returns the `DocumentNode` specified by the single `NodeInput::Node`.
/// The referenced node (specified by the single `NodeInput::Node`) is removed from the network, and any `NodeInput::Node`s used by the referenced node are replaced with a generically typed network input.
Extract,
}
@ -317,6 +387,7 @@ impl DocumentNodeImplementation {
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Defines a particular output port, specifying the node ID and output index.
pub struct NodeOutput {
pub node_id: NodeId,
pub node_output_index: usize,
@ -329,6 +400,7 @@ impl NodeOutput {
#[derive(Clone, Debug, Default, PartialEq, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// A network of nodes containing each [`DocumentNode`] and its ID, as well as a list of input nodes and [`NodeOutput`]s
pub struct NodeNetwork {
pub inputs: Vec<NodeId>,
pub outputs: Vec<NodeOutput>,
@ -554,6 +626,7 @@ impl NodeNetwork {
}
}
/// Check there are no cycles in the graph (this should never happen).
pub fn is_acyclic(&self) -> bool {
let mut dependencies: HashMap<u64, Vec<u64>> = HashMap::new();
for (node_id, node) in &self.nodes {
@ -587,6 +660,7 @@ impl NodeNetwork {
}
}
/// Iterate over the primary inputs of nodes, so in the case of `a -> b -> c`, this would yield `c, b, a` if we started from `c`.
struct FlowIter<'a> {
stack: Vec<NodeId>,
network: &'a NodeNetwork,
@ -612,6 +686,7 @@ impl<'a> Iterator for FlowIter<'a> {
/// Functions for compiling the network
impl NodeNetwork {
/// Replace all references in the graph of a node ID with a new node ID defined by the function `f`.
pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId + Copy) {
self.inputs.iter_mut().for_each(|id| *id = f(*id));
self.outputs.iter_mut().for_each(|output| output.node_id = f(output.node_id));
@ -642,6 +717,7 @@ impl NodeNetwork {
outwards_links
}
/// Populate the [`DocumentNode::path`], which stores the location of the document node to allow for matching the resulting protonodes to the document node for the purposes of typing and finding monitor nodes.
pub fn generate_node_paths(&mut self, prefix: &[NodeId]) {
for (node_id, node) in &mut self.nodes {
let mut new_path = prefix.to_vec();
@ -657,6 +733,7 @@ impl NodeNetwork {
}
}
/// Replace all references in any node of `old_input` with `new_input`
fn replace_node_inputs(&mut self, old_input: NodeInput, new_input: NodeInput) {
for node in self.nodes.values_mut() {
node.inputs.iter_mut().for_each(|input| {
@ -667,6 +744,7 @@ impl NodeNetwork {
}
}
/// Replace all references in any node of `old_output` with `new_output`
fn replace_network_outputs(&mut self, old_output: NodeOutput, new_output: NodeOutput) {
for output in self.outputs.iter_mut() {
if *output == old_output {
@ -675,7 +753,7 @@ impl NodeNetwork {
}
}
/// Removes unused nodes from the graph. Returns a list of bools which represent if each of the inputs have been retained
/// Removes unused nodes from the graph. Returns a list of booleans which represent if each of the inputs have been retained.
pub fn remove_dead_nodes(&mut self) -> Vec<bool> {
// Take all the nodes out of the nodes list
let mut old_nodes = std::mem::take(&mut self.nodes);
@ -709,11 +787,12 @@ impl NodeNetwork {
are_inputs_used
}
/// Remove all nodes that contain [`DocumentNodeImplementation::Network`] by moving the nested nodes into the parent network.
pub fn flatten(&mut self, node: NodeId) {
self.flatten_with_fns(node, merge_ids, generate_uuid)
}
/// Recursively dissolve non-primitive document nodes and return a single flattened network of nodes.
/// Remove all nodes that contain [`DocumentNodeImplementation::Network`] by moving the nested nodes into the parent network.
pub fn flatten_with_fns(&mut self, node: NodeId, map_ids: impl Fn(NodeId, NodeId) -> NodeId + Copy, gen_id: impl Fn() -> NodeId + Copy) {
self.resolve_extract_nodes();
let Some((id, mut node)) = self.nodes.remove_entry(&node) else {
@ -870,6 +949,7 @@ impl NodeNetwork {
Ok(())
}
/// Strips out any [`graphene_core::ops::IdNode`]s that are unnecessary.
pub fn remove_redundant_id_nodes(&mut self) {
let id_nodes = self
.nodes
@ -888,6 +968,9 @@ impl NodeNetwork {
}
}
/// Converts the `DocumentNode`s with a `DocumentNodeImplementation::Extract` into a `ClonedNode` that returns
/// the `DocumentNode` specified by the single `NodeInput::Node`.
/// The referenced node is removed from the network, and any `NodeInput::Node`s used by the referenced node are replaced with a generically typed network input.
pub fn resolve_extract_nodes(&mut self) {
let mut extraction_nodes = self
.nodes
@ -898,33 +981,32 @@ impl NodeNetwork {
self.nodes.retain(|_, node| !matches!(node.implementation, DocumentNodeImplementation::Extract));
for (_, node) in &mut extraction_nodes {
if let DocumentNodeImplementation::Extract = node.implementation {
assert_eq!(node.inputs.len(), 1);
let NodeInput::Node { node_id, output_index, .. } = node.inputs.pop().unwrap() else {
panic!("Extract node has no input, inputs: {:?}", node.inputs);
assert_eq!(node.inputs.len(), 1);
let NodeInput::Node { node_id, output_index, .. } = node.inputs.pop().unwrap() else {
panic!("Extract node has no input, inputs: {:?}", node.inputs);
};
assert_eq!(output_index, 0);
// TODO: check if we can read lambda checking?
let mut input_node = self.nodes.remove(&node_id).unwrap();
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ClonedNode".into());
if let Some(input) = input_node.inputs.get_mut(0) {
*input = match &input {
NodeInput::Node { .. } => NodeInput::Network(generic!(T)),
ni => NodeInput::Network(ni.ty()),
};
assert_eq!(output_index, 0);
// TODO: check if we can readd lambda checking
let mut input_node = self.nodes.remove(&node_id).unwrap();
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ClonedNode".into());
if let Some(input) = input_node.inputs.get_mut(0) {
*input = match &input {
NodeInput::Node { .. } => NodeInput::Network(generic!(T)),
ni => NodeInput::Network(ni.ty()),
};
}
for input in input_node.inputs.iter_mut() {
if let NodeInput::Node { .. } = input {
*input = NodeInput::Network(generic!(T))
}
}
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node), false)];
}
for input in input_node.inputs.iter_mut() {
if let NodeInput::Node { .. } = input {
*input = NodeInput::Network(generic!(T))
}
}
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node), false)];
}
self.nodes.extend(extraction_nodes);
}
/// Creates a proto network for evaluating each output of this network.
pub fn into_proto_networks(self) -> impl Iterator<Item = ProtoNetwork> {
let mut nodes: Vec<_> = self.nodes.into_iter().map(|(id, node)| (id, node.resolve_proto_node())).collect();
nodes.sort_unstable_by_key(|(i, _)| *i);
@ -1116,7 +1198,7 @@ mod test {
let proto_node = document_node.resolve_proto_node();
let reference = ProtoNode {
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::Network(concrete!(u32)),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![(0, false)]),
document_node_path: vec![],
skip_deduplication: false,
@ -1134,7 +1216,7 @@ mod test {
10,
ProtoNode {
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::Network(concrete!(u32)),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![(14, false)]),
document_node_path: vec![1, 0],
skip_deduplication: false,