diff --git a/Cargo.lock b/Cargo.lock index 252d33d36..aa9faf365 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1940,9 +1940,9 @@ dependencies = [ "graphene-core", "log", "num-traits", + "rustc-hash", "serde", "specta", - "xxhash-rust", ] [[package]] @@ -2024,6 +2024,7 @@ dependencies = [ "log", "node-macro", "reqwest", + "rustc-hash", "serde", "serde_json", "tempfile", @@ -2037,7 +2038,6 @@ dependencies = [ "wgpu-executor", "wgpu-types", "winit", - "xxhash-rust", ] [[package]] @@ -6888,12 +6888,6 @@ version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a56c84a8ccd4258aed21c92f70c0f6dea75356b6892ae27c24139da456f9336" -[[package]] -name = "xxhash-rust" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "735a71d46c4d68d71d4b24d03fdc2b98e38cea81730595801db779c04fe80d70" - [[package]] name = "zbus" version = "3.14.1" diff --git a/Cargo.toml b/Cargo.toml index 02e1607b5..5a89bc825 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ exclude = ["node-graph/gpu-compiler"] specta = { git = "https://github.com/0HyperCube/specta.git", rev = "c47a22b4c0863d27bc47529f300de3969480c66d", features = [ "glam", ] } -xxhash-rust = { version = "0.8", features = ["xxh3"] } +rustc-hash = "1.1.0" [profile.dev.package.graphite-editor] opt-level = 1 diff --git a/editor/src/node_graph_executor.rs b/editor/src/node_graph_executor.rs index b13f7d96b..39b0a6ba4 100644 --- a/editor/src/node_graph_executor.rs +++ b/editor/src/node_graph_executor.rs @@ -173,7 +173,7 @@ impl NodeRuntime { // We assume only one output assert_eq!(scoped_network.outputs.len(), 1, "Graph with multiple outputs not yet handled"); let c = Compiler {}; - let proto_network = c.compile_single(scoped_network, true)?; + let proto_network = c.compile_single(scoped_network)?; assert_ne!(proto_network.nodes.len(), 0, "No protonodes exist?"); if let Err(e) = self.executor.update(proto_network).await { diff --git a/node-graph/compilation-client/src/main.rs b/node-graph/compilation-client/src/main.rs index 4adb6a81c..f8479c187 100644 --- a/node-graph/compilation-client/src/main.rs +++ b/node-graph/compilation-client/src/main.rs @@ -15,7 +15,7 @@ fn main() { let network = add_network(); let compiler = graph_craft::graphene_compiler::Compiler {}; - let proto_network = compiler.compile_single(network, true).unwrap(); + let proto_network = compiler.compile_single(network).unwrap(); let io = ShaderIO { inputs: vec![ diff --git a/node-graph/graph-craft/Cargo.toml b/node-graph/graph-craft/Cargo.toml index 1f2843a1c..a86fdd5be 100644 --- a/node-graph/graph-craft/Cargo.toml +++ b/node-graph/graph-craft/Cargo.toml @@ -25,4 +25,4 @@ specta.workspace = true bytemuck = {version = "1.8" } anyhow = "1.0.66" -xxhash-rust = {workspace = true} +rustc-hash = {workspace = true} diff --git a/node-graph/graph-craft/src/document.rs b/node-graph/graph-craft/src/document.rs index 7385b1558..bcece0f3b 100644 --- a/node-graph/graph-craft/src/document.rs +++ b/node-graph/graph-craft/src/document.rs @@ -674,7 +674,6 @@ impl NodeNetwork { self.nodes.insert(id, node); return; } - log::debug!("Flattening node {:?}", &node.name); // replace value inputs with value nodes for input in &mut node.inputs { @@ -843,10 +842,8 @@ impl NodeNetwork { self.nodes.retain(|_, node| !matches!(node.implementation, DocumentNodeImplementation::Extract)); for (_, node) in &mut extraction_nodes { - log::debug!("extraction network: {:#?}", &self); if let DocumentNodeImplementation::Extract = node.implementation { assert_eq!(node.inputs.len(), 1); - log::debug!("Resolving extract node {:?}", node); let NodeInput::Node { node_id, output_index, .. } = node.inputs.pop().unwrap() else { panic!("Extract node has no input, inputs: {:?}", node.inputs); }; @@ -866,7 +863,6 @@ impl NodeNetwork { *input = NodeInput::Network(generic!(T)) } } - log::debug!("Extract node {:?} resolved to {:?}", node, input_node); node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node), false)]; } } diff --git a/node-graph/graph-craft/src/graphene_compiler.rs b/node-graph/graph-craft/src/graphene_compiler.rs index 1b7e11564..daf263ad3 100644 --- a/node-graph/graph-craft/src/graphene_compiler.rs +++ b/node-graph/graph-craft/src/graphene_compiler.rs @@ -8,7 +8,7 @@ use crate::proto::{LocalFuture, ProtoNetwork}; pub struct Compiler {} impl Compiler { - pub fn compile(&self, mut network: NodeNetwork, resolve_inputs: bool) -> impl Iterator { + pub fn compile(&self, mut network: NodeNetwork) -> impl Iterator { println!("flattening"); let node_ids = network.nodes.keys().copied().collect::>(); for id in node_ids { @@ -18,18 +18,14 @@ impl Compiler { network.remove_dead_nodes(); let proto_networks = network.into_proto_networks(); proto_networks.map(move |mut proto_network| { - if resolve_inputs { - println!("resolving inputs"); - proto_network.resolve_inputs(); - } - proto_network.reorder_ids(); + proto_network.resolve_inputs(); proto_network.generate_stable_node_ids(); proto_network }) } - pub fn compile_single(&self, network: NodeNetwork, resolve_inputs: bool) -> Result { + pub fn compile_single(&self, network: NodeNetwork) -> Result { assert_eq!(network.outputs.len(), 1, "Graph with multiple outputs not yet handled"); - let Some(proto_network) = self.compile(network, resolve_inputs).next() else { + let Some(proto_network) = self.compile(network).next() else { return Err("Failed to convert graph into proto graph".to_string()); }; Ok(proto_network) diff --git a/node-graph/graph-craft/src/proto.rs b/node-graph/graph-craft/src/proto.rs index 37cd1baa6..533a59ff9 100644 --- a/node-graph/graph-craft/src/proto.rs +++ b/node-graph/graph-craft/src/proto.rs @@ -4,7 +4,6 @@ use std::collections::{HashMap, HashSet}; use std::ops::Deref; use std::hash::Hash; -use xxhash_rust::xxh3::Xxh3; use crate::document::NodeId; use crate::document::{value, InlineRust}; @@ -155,7 +154,7 @@ impl PartialEq for ConstructionArgs { _ => { use std::hash::Hasher; let hash = |input: &Self| { - let mut hasher = Xxh3::new(); + let mut hasher = rustc_hash::FxHasher::default(); input.hash(&mut hasher); hasher.finish() }; @@ -228,19 +227,18 @@ impl ProtoNodeInput { impl ProtoNode { pub fn stable_node_id(&self) -> Option { use std::hash::Hasher; - let mut hasher = Xxh3::new(); + let mut hasher = rustc_hash::FxHasher::default(); self.identifier.name.hash(&mut hasher); self.construction_args.hash(&mut hasher); self.document_node_path.hash(&mut hasher); + std::mem::discriminant(&self.input).hash(&mut hasher); match self.input { - ProtoNodeInput::None => "none".hash(&mut hasher), + ProtoNodeInput::None => (), ProtoNodeInput::ShortCircut(ref ty) => { - "lambda".hash(&mut hasher); ty.hash(&mut hasher); } ProtoNodeInput::Network(ref ty) => { - "network".hash(&mut hasher); ty.hash(&mut hasher); } ProtoNodeInput::Node(id, lambda) => (id, lambda).hash(&mut hasher), @@ -305,20 +303,15 @@ impl ProtoNetwork { } pub fn generate_stable_node_ids(&mut self) { - for i in 0..self.nodes.len() { - self.generate_stable_node_id(i); - } - } + debug_assert!(self.is_topologically_sorted()); + let outwards_edges = self.collect_outwards_edges(); - pub fn generate_stable_node_id(&mut self, index: usize) -> NodeId { - let mut lookup = self.nodes.iter().map(|(id, _)| (*id, *id)).collect::>(); - if let Some(sni) = self.nodes[index].1.stable_node_id() { - lookup.insert(self.nodes[index].0, sni); - self.replace_node_references(&lookup, false); - self.nodes[index].0 = sni; - sni - } else { - panic!("failed to generate stable node id for node {:#?}", self.nodes[index].1); + for index in 0..self.nodes.len() { + let Some(sni) = self.nodes[index].1.stable_node_id() else { + panic!("failed to generate stable node id for node {:#?}", self.nodes[index].1); + }; + self.replace_node_id(&outwards_edges, index as NodeId, sni, false); + self.nodes[index].0 = sni as NodeId; } } @@ -340,45 +333,59 @@ impl ProtoNetwork { } pub fn resolve_inputs(&mut self) { - let mut resolved = HashSet::new(); - while !self.resolve_inputs_impl(&mut resolved) {} - } - fn resolve_inputs_impl(&mut self, resolved: &mut HashSet) -> bool { + // Perform topological sort once self.reorder_ids(); - let mut lookup = self.nodes.iter().map(|(id, _)| (*id, *id)).collect::>(); - let compose_node_id = self.nodes.len() as NodeId; - let inputs = self.nodes.iter().map(|(_, node)| node.input.clone()).collect::>(); - let paths = self.nodes.iter().map(|(_, node)| node.document_node_path.clone()).collect::>(); + let max_id = self.nodes.len() as NodeId - 1; - let resolved_lookup = resolved.clone(); - if let Some((input_node, id, input, mut path)) = self.nodes.iter_mut().filter(|(id, _)| !resolved_lookup.contains(id)).find_map(|(id, node)| { - if let ProtoNodeInput::Node(input_node, false) = node.input { - resolved.insert(*id); - let pre_node_input = inputs.get(input_node as usize).expect("input node should exist"); - let pre_path = paths.get(input_node as usize).expect("input node should exist"); - Some((input_node, *id, pre_node_input.clone(), pre_path.clone())) - } else { - resolved.insert(*id); - None + // Collect outward edges once + let outwards_edges = self.collect_outwards_edges(); + + // Iterate over nodes in topological order + for node_id in 0..=max_id { + let node = &mut self.nodes[node_id as usize].1; + + if let ProtoNodeInput::Node(input_node_id, false) = node.input { + // Create a new node that composes the current node and its input node + let compose_node_id = self.nodes.len() as NodeId; + let input = self.nodes[input_node_id as usize].1.input.clone(); + let mut path = self.nodes[input_node_id as usize].1.document_node_path.clone(); + path.push(node_id); + + self.nodes.push(( + compose_node_id, + ProtoNode { + identifier: NodeIdentifier::new("graphene_core::structural::ComposeNode<_, _, _>"), + construction_args: ConstructionArgs::Nodes(vec![(input_node_id, false), (node_id, true)]), + input, + document_node_path: path, + }, + )); + + self.replace_node_id(&outwards_edges, node_id, compose_node_id, true); + } + } + self.reorder_ids(); + } + + fn replace_node_id(&mut self, outwards_edges: &HashMap>, node_id: u64, compose_node_id: u64, skip_lambdas: bool) { + // Update references in other nodes to use the new compose node + if let Some(referring_nodes) = outwards_edges.get(&node_id) { + for &referring_node_id in referring_nodes { + let referring_node = &mut self.nodes[referring_node_id as usize].1; + referring_node.map_ids(|id| if id == node_id { compose_node_id } else { id }, skip_lambdas) } - }) { - lookup.insert(id, compose_node_id); - self.replace_node_references(&lookup, true); - path.push(id); - self.nodes.push(( - compose_node_id, - ProtoNode { - identifier: NodeIdentifier::new("graphene_core::structural::ComposeNode<_, _, _>"), - construction_args: ConstructionArgs::Nodes(vec![(input_node, false), (id, true)]), - input, - document_node_path: path, - }, - )); - return false; } - true + if self.output == node_id { + self.output = compose_node_id; + } + + self.inputs.iter_mut().for_each(|id| { + if *id == node_id { + *id = compose_node_id; + } + }); } // Based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search @@ -409,6 +416,24 @@ impl ProtoNetwork { sorted } + fn is_topologically_sorted(&self) -> bool { + let mut visited = HashSet::new(); + + let inwards_edges = self.collect_inwards_edges(); + for (id, node) in &self.nodes { + for &dependency in inwards_edges.get(id).unwrap_or(&Vec::new()) { + if !visited.contains(&dependency) { + dbg!(id, dependency); + dbg!(&visited); + dbg!(&self.nodes); + return false; + } + } + visited.insert(*id); + } + true + } + /*// Based on https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm pub fn topological_sort(&self) -> Vec { let mut sorted = Vec::new(); @@ -435,28 +460,33 @@ impl ProtoNetwork { sorted }*/ - pub fn reorder_ids(&mut self) { + fn reorder_ids(&mut self) { let order = self.topological_sort(); - // Map of node ids to indexes (which become the node ids as they are inserted into the borrow stack) - let lookup: HashMap<_, _> = order.iter().enumerate().map(|(pos, id)| (*id, pos as NodeId)).collect(); - self.nodes = order - .iter() - .enumerate() - .map(|(pos, id)| { - let node = self.nodes.swap_remove(self.nodes.iter().position(|(test_id, _)| test_id == id).unwrap()).1; - (pos as NodeId, node) - }) - .collect(); - self.replace_node_references(&lookup, false); - assert_eq!(order.len(), self.nodes.len()); - } - fn replace_node_references(&mut self, lookup: &HashMap, skip_lambdas: bool) { - self.nodes.iter_mut().for_each(|(_, node)| { - node.map_ids(|id| *lookup.get(&id).expect("node not found in lookup table"), skip_lambdas); + // Map of node ids to their current index in the nodes vector + let current_positions: HashMap<_, _> = self.nodes.iter().enumerate().map(|(pos, (id, _))| (*id, pos)).collect(); + + // Map of node ids to their new index based on topological order + let new_positions: HashMap<_, _> = order.iter().enumerate().map(|(pos, id)| (*id, pos as NodeId)).collect(); + + // Create a new nodes vector based on the topological order + let mut new_nodes = Vec::with_capacity(order.len()); + for (index, &id) in order.iter().enumerate() { + let current_pos = *current_positions.get(&id).unwrap(); + new_nodes.push((index as NodeId, self.nodes[current_pos].1.clone())); + } + + // Update node references to reflect the new order + new_nodes.iter_mut().for_each(|(_, node)| { + node.map_ids(|id| *new_positions.get(&id).expect("node not found in lookup table"), false); }); - self.inputs = self.inputs.iter().filter_map(|id| lookup.get(id).copied()).collect(); - self.output = *lookup.get(&self.output).unwrap(); + + // Update the nodes vector and other references + self.nodes = new_nodes; + self.inputs = self.inputs.iter().filter_map(|id| new_positions.get(id).copied()).collect(); + self.output = *new_positions.get(&self.output).unwrap(); + + assert_eq!(order.len(), self.nodes.len()); } } @@ -698,8 +728,6 @@ mod test { #[test] fn stable_node_id_generation() { let mut construction_network = test_network(); - construction_network.reorder_ids(); - construction_network.generate_stable_node_ids(); construction_network.resolve_inputs(); construction_network.generate_stable_node_ids(); assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value"); @@ -707,12 +735,12 @@ mod test { assert_eq!( ids, vec![ - 4471348669260178714, - 12892313567093808068, - 6883586777044498729, - 13841339389284532934, - 4412916056300566478, - 15358108940336208665 + 16203111412429166836, + 8181436982058796771, + 10130798762907147404, + 1082623390433068677, + 4567264975997576294, + 8215587082195034469 ] ); } diff --git a/node-graph/graphene-cli/src/main.rs b/node-graph/graphene-cli/src/main.rs index 10f68a3ba..d9f8adef6 100644 --- a/node-graph/graphene-cli/src/main.rs +++ b/node-graph/graphene-cli/src/main.rs @@ -94,7 +94,7 @@ fn create_executor(document_string: String) -> Result { async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader { let compiler = graph_craft::graphene_compiler::Compiler {}; let DocumentNodeImplementation::Network(ref network) = node.implementation else { panic!() }; - let proto_networks: Vec<_> = compiler.compile(network.clone(), true).collect(); + let proto_networks: Vec<_> = compiler.compile(network.clone()).collect(); for network in proto_networks.iter() { typing_context.update(network).expect("Failed to type check network"); @@ -229,7 +229,7 @@ async fn create_compute_pass_descriptor( ..Default::default() }; log::debug!("compiling network"); - let proto_networks = compiler.compile(network.clone(), true).collect(); + let proto_networks = compiler.compile(network.clone()).collect(); log::debug!("compiling shader"); let shader = compilation_client::compile( proto_networks, @@ -442,7 +442,7 @@ async fn blend_gpu_image(foreground: ImageFrame, background: ImageFrame