Restructure GPU execution to model GPU pipelines in the node graph (#1088)

* Start implementing GpuExecutor for wgpu

* Implement read_output_buffer function

* Implement extraction node in the compiler

* Generate type annotations during shader compilation

* Start adding node wrapprs for graph execution api

* Wrap more of the api in nodes

* Restructure Pipeline to accept arbitrary shader inputs

* Adapt nodes to new trait definitions

* Start implementing gpu-compiler trait

* Adapt shader generation

* Hardstuck on pointer casts

* Pass nodes as references in gpu code to avoid zsts

* Update gcore to compile on the gpu

* Fix color doc tests

* Impl Node for node refs
This commit is contained in:
Dennis Kobert 2023-04-23 10:18:31 +02:00 committed by Keavon Chambers
parent 161bbc62b4
commit bdc1ef926a
43 changed files with 1874 additions and 515 deletions

View file

@ -1,3 +1,4 @@
use crate::document::value::TaggedValue;
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
use graphene_core::{NodeIdentifier, Type};
@ -20,7 +21,7 @@ fn merge_ids(a: u64, b: u64) -> u64 {
hasher.finish()
}
#[derive(Clone, Debug, PartialEq, Default, specta::Type)]
#[derive(Clone, Debug, PartialEq, Default, specta::Type, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DocumentNodeMetadata {
pub position: IVec2,
@ -32,7 +33,7 @@ impl DocumentNodeMetadata {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DocumentNode {
pub name: String,
@ -156,7 +157,7 @@ impl DocumentNode {
///
/// In this case the Cache node actually consumes its input and then manually forwards it to its parameter Node.
/// This is necessary because the Cache Node needs to short-circut the actual node evaluation.
#[derive(Debug, Clone, PartialEq, Hash)]
#[derive(Debug, Clone, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum NodeInput {
Node {
@ -165,7 +166,7 @@ pub enum NodeInput {
lambda: bool,
},
Value {
tagged_value: crate::document::value::TaggedValue,
tagged_value: TaggedValue,
exposed: bool,
},
Network(Type),
@ -182,7 +183,7 @@ impl NodeInput {
pub const fn lambda(node_id: NodeId, output_index: usize) -> Self {
Self::Node { node_id, output_index, lambda: true }
}
pub const fn value(tagged_value: crate::document::value::TaggedValue, exposed: bool) -> Self {
pub const fn value(tagged_value: TaggedValue, exposed: bool) -> Self {
Self::Value { tagged_value, exposed }
}
fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
@ -212,11 +213,12 @@ impl NodeInput {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DocumentNodeImplementation {
Network(NodeNetwork),
Unresolved(NodeIdentifier),
Extract,
}
impl Default for DocumentNodeImplementation {
@ -227,23 +229,21 @@ impl Default for DocumentNodeImplementation {
impl DocumentNodeImplementation {
pub fn get_network(&self) -> Option<&NodeNetwork> {
if let DocumentNodeImplementation::Network(n) = self {
Some(n)
} else {
None
match self {
DocumentNodeImplementation::Network(n) => Some(n),
_ => None,
}
}
pub fn get_network_mut(&mut self) -> Option<&mut NodeNetwork> {
if let DocumentNodeImplementation::Network(n) = self {
Some(n)
} else {
None
match self {
DocumentNodeImplementation::Network(n) => Some(n),
_ => None,
}
}
}
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type)]
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeOutput {
pub node_id: NodeId,
@ -267,6 +267,21 @@ pub struct NodeNetwork {
pub previous_outputs: Option<Vec<NodeOutput>>,
}
impl std::hash::Hash for NodeNetwork {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.inputs.hash(state);
self.outputs.hash(state);
let mut nodes: Vec<_> = self.nodes.iter().collect();
nodes.sort_by_key(|(id, _)| *id);
for (id, node) in nodes {
id.hash(state);
node.hash(state);
}
self.disabled.hash(state);
self.previous_outputs.hash(state);
}
}
/// Graph modification functions
impl NodeNetwork {
/// Get the original output nodes of this network, ignoring any preview node
@ -701,12 +716,43 @@ impl NodeNetwork {
self.flatten_with_fns(node_id, map_ids, gen_id);
}
}
DocumentNodeImplementation::Unresolved(_) => {}
DocumentNodeImplementation::Unresolved(_) => (),
DocumentNodeImplementation::Extract => {
panic!("Extract nodes should have been removed before flattening");
}
}
assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict");
self.nodes.insert(id, node);
}
pub fn resolve_extract_nodes(&mut self) {
let mut extraction_nodes = self
.nodes
.iter()
.filter(|(_, node)| matches!(node.implementation, DocumentNodeImplementation::Extract))
.map(|(id, node)| (*id, node.clone()))
.collect::<Vec<_>>();
self.nodes.retain(|_, node| !matches!(node.implementation, DocumentNodeImplementation::Extract));
for (_, node) in &mut extraction_nodes {
match node.implementation {
DocumentNodeImplementation::Extract => {
assert_eq!(node.inputs.len(), 1);
let NodeInput::Node { node_id, output_index, lambda } = node.inputs.pop().unwrap() else {
panic!("Extract node has no input");
};
assert_eq!(output_index, 0);
assert!(lambda);
let input_node = self.nodes.get_mut(&node_id).unwrap();
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into());
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node.clone()), false)];
}
_ => (),
}
}
self.nodes.extend(extraction_nodes);
}
pub fn into_proto_networks(self) -> impl Iterator<Item = ProtoNetwork> {
let mut nodes: Vec<_> = self.nodes.into_iter().map(|(id, node)| (id, node.resolve_proto_node())).collect();
nodes.sort_unstable_by_key(|(i, _)| *i);
@ -798,6 +844,39 @@ mod test {
assert_eq!(network, maped_add);
}
#[test]
fn extract_node() {
let id_node = DocumentNode {
name: "Id".into(),
inputs: vec![],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
};
let mut extraction_network = NodeNetwork {
inputs: vec![],
outputs: vec![NodeOutput::new(1, 0)],
nodes: [
id_node.clone(),
DocumentNode {
name: "Extract".into(),
inputs: vec![NodeInput::lambda(0, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Extract,
},
]
.into_iter()
.enumerate()
.map(|(id, node)| (id as NodeId, node))
.collect(),
..Default::default()
};
extraction_network.resolve_extract_nodes();
assert_eq!(extraction_network.nodes.len(), 2);
let inputs = extraction_network.nodes.get(&1).unwrap().inputs.clone();
assert_eq!(inputs.len(), 1);
assert!(matches!(&inputs[0], &NodeInput::Value{ tagged_value: TaggedValue::DocumentNode(ref network), ..} if network == &id_node));
}
#[test]
fn flatten_add() {
let mut network = NodeNetwork {
@ -810,7 +889,7 @@ mod test {
inputs: vec![
NodeInput::Network(concrete!(u32)),
NodeInput::Value {
tagged_value: crate::document::value::TaggedValue::U32(2),
tagged_value: TaggedValue::U32(2),
exposed: false,
},
],
@ -876,7 +955,7 @@ mod test {
construction_args: ConstructionArgs::Nodes(vec![]),
},
),
(14, ProtoNode::value(ConstructionArgs::Value(crate::document::value::TaggedValue::U32(2)))),
(14, ProtoNode::value(ConstructionArgs::Value(TaggedValue::U32(2)))),
]
.into_iter()
.collect(),
@ -917,7 +996,7 @@ mod test {
DocumentNode {
name: "Value".into(),
inputs: vec![NodeInput::Value {
tagged_value: crate::document::value::TaggedValue::U32(2),
tagged_value: TaggedValue::U32(2),
exposed: false,
}],
metadata: DocumentNodeMetadata::default(),
@ -979,10 +1058,7 @@ mod test {
10,
DocumentNode {
name: "Nested network".into(),
inputs: vec![
NodeInput::value(crate::document::value::TaggedValue::F32(1.), false),
NodeInput::value(crate::document::value::TaggedValue::F32(2.), false),
],
inputs: vec![NodeInput::value(TaggedValue::F32(1.), false), NodeInput::value(TaggedValue::F32(2.), false)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Network(two_node_identity()),
},
@ -1015,11 +1091,7 @@ mod test {
assert_eq!(result.nodes.keys().copied().collect::<Vec<_>>(), vec![101], "Should just call nested network");
let nested_network_node = result.nodes.get(&101).unwrap();
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
assert_eq!(
nested_network_node.inputs,
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
"Input should be 2"
);
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
@ -1038,11 +1110,7 @@ mod test {
for (node_id, input_value, inner_id) in [(10, 1., 1), (101, 2., 2)] {
let nested_network_node = result.nodes.get(&node_id).unwrap();
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
assert_eq!(
nested_network_node.inputs,
vec![NodeInput::value(crate::document::value::TaggedValue::F32(input_value), false)],
"Input should be stable"
);
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(input_value), false)], "Input should be stable");
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
assert_eq!(inner_network.inputs, vec![inner_id], "The input should be sent to the second node");
assert_eq!(inner_network.outputs, vec![NodeOutput::new(inner_id, 0)], "The output should be node id");
@ -1061,11 +1129,7 @@ mod test {
assert_eq!(result_node.inputs, vec![NodeInput::node(101, 0)], "Result node should refer to duplicate node as input");
let nested_network_node = result.nodes.get(&101).unwrap();
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
assert_eq!(
nested_network_node.inputs,
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
"Input should be 2"
);
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");