mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-08-08 07:18:01 +00:00
Restructure GPU execution to model GPU pipelines in the node graph (#1088)
* Start implementing GpuExecutor for wgpu * Implement read_output_buffer function * Implement extraction node in the compiler * Generate type annotations during shader compilation * Start adding node wrapprs for graph execution api * Wrap more of the api in nodes * Restructure Pipeline to accept arbitrary shader inputs * Adapt nodes to new trait definitions * Start implementing gpu-compiler trait * Adapt shader generation * Hardstuck on pointer casts * Pass nodes as references in gpu code to avoid zsts * Update gcore to compile on the gpu * Fix color doc tests * Impl Node for node refs
This commit is contained in:
parent
161bbc62b4
commit
bdc1ef926a
43 changed files with 1874 additions and 515 deletions
|
@ -1,3 +1,4 @@
|
|||
use crate::document::value::TaggedValue;
|
||||
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
|
||||
use graphene_core::{NodeIdentifier, Type};
|
||||
|
||||
|
@ -20,7 +21,7 @@ fn merge_ids(a: u64, b: u64) -> u64 {
|
|||
hasher.finish()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Default, specta::Type)]
|
||||
#[derive(Clone, Debug, PartialEq, Default, specta::Type, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct DocumentNodeMetadata {
|
||||
pub position: IVec2,
|
||||
|
@ -32,7 +33,7 @@ impl DocumentNodeMetadata {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct DocumentNode {
|
||||
pub name: String,
|
||||
|
@ -156,7 +157,7 @@ impl DocumentNode {
|
|||
///
|
||||
/// In this case the Cache node actually consumes its input and then manually forwards it to its parameter Node.
|
||||
/// This is necessary because the Cache Node needs to short-circut the actual node evaluation.
|
||||
#[derive(Debug, Clone, PartialEq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum NodeInput {
|
||||
Node {
|
||||
|
@ -165,7 +166,7 @@ pub enum NodeInput {
|
|||
lambda: bool,
|
||||
},
|
||||
Value {
|
||||
tagged_value: crate::document::value::TaggedValue,
|
||||
tagged_value: TaggedValue,
|
||||
exposed: bool,
|
||||
},
|
||||
Network(Type),
|
||||
|
@ -182,7 +183,7 @@ impl NodeInput {
|
|||
pub const fn lambda(node_id: NodeId, output_index: usize) -> Self {
|
||||
Self::Node { node_id, output_index, lambda: true }
|
||||
}
|
||||
pub const fn value(tagged_value: crate::document::value::TaggedValue, exposed: bool) -> Self {
|
||||
pub const fn value(tagged_value: TaggedValue, exposed: bool) -> Self {
|
||||
Self::Value { tagged_value, exposed }
|
||||
}
|
||||
fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
|
||||
|
@ -212,11 +213,12 @@ impl NodeInput {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum DocumentNodeImplementation {
|
||||
Network(NodeNetwork),
|
||||
Unresolved(NodeIdentifier),
|
||||
Extract,
|
||||
}
|
||||
|
||||
impl Default for DocumentNodeImplementation {
|
||||
|
@ -227,23 +229,21 @@ impl Default for DocumentNodeImplementation {
|
|||
|
||||
impl DocumentNodeImplementation {
|
||||
pub fn get_network(&self) -> Option<&NodeNetwork> {
|
||||
if let DocumentNodeImplementation::Network(n) = self {
|
||||
Some(n)
|
||||
} else {
|
||||
None
|
||||
match self {
|
||||
DocumentNodeImplementation::Network(n) => Some(n),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_network_mut(&mut self) -> Option<&mut NodeNetwork> {
|
||||
if let DocumentNodeImplementation::Network(n) = self {
|
||||
Some(n)
|
||||
} else {
|
||||
None
|
||||
match self {
|
||||
DocumentNodeImplementation::Network(n) => Some(n),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct NodeOutput {
|
||||
pub node_id: NodeId,
|
||||
|
@ -267,6 +267,21 @@ pub struct NodeNetwork {
|
|||
pub previous_outputs: Option<Vec<NodeOutput>>,
|
||||
}
|
||||
|
||||
impl std::hash::Hash for NodeNetwork {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.inputs.hash(state);
|
||||
self.outputs.hash(state);
|
||||
let mut nodes: Vec<_> = self.nodes.iter().collect();
|
||||
nodes.sort_by_key(|(id, _)| *id);
|
||||
for (id, node) in nodes {
|
||||
id.hash(state);
|
||||
node.hash(state);
|
||||
}
|
||||
self.disabled.hash(state);
|
||||
self.previous_outputs.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
/// Graph modification functions
|
||||
impl NodeNetwork {
|
||||
/// Get the original output nodes of this network, ignoring any preview node
|
||||
|
@ -701,12 +716,43 @@ impl NodeNetwork {
|
|||
self.flatten_with_fns(node_id, map_ids, gen_id);
|
||||
}
|
||||
}
|
||||
DocumentNodeImplementation::Unresolved(_) => {}
|
||||
DocumentNodeImplementation::Unresolved(_) => (),
|
||||
DocumentNodeImplementation::Extract => {
|
||||
panic!("Extract nodes should have been removed before flattening");
|
||||
}
|
||||
}
|
||||
assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict");
|
||||
self.nodes.insert(id, node);
|
||||
}
|
||||
|
||||
pub fn resolve_extract_nodes(&mut self) {
|
||||
let mut extraction_nodes = self
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|(_, node)| matches!(node.implementation, DocumentNodeImplementation::Extract))
|
||||
.map(|(id, node)| (*id, node.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
self.nodes.retain(|_, node| !matches!(node.implementation, DocumentNodeImplementation::Extract));
|
||||
|
||||
for (_, node) in &mut extraction_nodes {
|
||||
match node.implementation {
|
||||
DocumentNodeImplementation::Extract => {
|
||||
assert_eq!(node.inputs.len(), 1);
|
||||
let NodeInput::Node { node_id, output_index, lambda } = node.inputs.pop().unwrap() else {
|
||||
panic!("Extract node has no input");
|
||||
};
|
||||
assert_eq!(output_index, 0);
|
||||
assert!(lambda);
|
||||
let input_node = self.nodes.get_mut(&node_id).unwrap();
|
||||
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into());
|
||||
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node.clone()), false)];
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
self.nodes.extend(extraction_nodes);
|
||||
}
|
||||
|
||||
pub fn into_proto_networks(self) -> impl Iterator<Item = ProtoNetwork> {
|
||||
let mut nodes: Vec<_> = self.nodes.into_iter().map(|(id, node)| (id, node.resolve_proto_node())).collect();
|
||||
nodes.sort_unstable_by_key(|(i, _)| *i);
|
||||
|
@ -798,6 +844,39 @@ mod test {
|
|||
assert_eq!(network, maped_add);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_node() {
|
||||
let id_node = DocumentNode {
|
||||
name: "Id".into(),
|
||||
inputs: vec![],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
|
||||
};
|
||||
let mut extraction_network = NodeNetwork {
|
||||
inputs: vec![],
|
||||
outputs: vec![NodeOutput::new(1, 0)],
|
||||
nodes: [
|
||||
id_node.clone(),
|
||||
DocumentNode {
|
||||
name: "Extract".into(),
|
||||
inputs: vec![NodeInput::lambda(0, 0)],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Extract,
|
||||
},
|
||||
]
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(id, node)| (id as NodeId, node))
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
extraction_network.resolve_extract_nodes();
|
||||
assert_eq!(extraction_network.nodes.len(), 2);
|
||||
let inputs = extraction_network.nodes.get(&1).unwrap().inputs.clone();
|
||||
assert_eq!(inputs.len(), 1);
|
||||
assert!(matches!(&inputs[0], &NodeInput::Value{ tagged_value: TaggedValue::DocumentNode(ref network), ..} if network == &id_node));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flatten_add() {
|
||||
let mut network = NodeNetwork {
|
||||
|
@ -810,7 +889,7 @@ mod test {
|
|||
inputs: vec![
|
||||
NodeInput::Network(concrete!(u32)),
|
||||
NodeInput::Value {
|
||||
tagged_value: crate::document::value::TaggedValue::U32(2),
|
||||
tagged_value: TaggedValue::U32(2),
|
||||
exposed: false,
|
||||
},
|
||||
],
|
||||
|
@ -876,7 +955,7 @@ mod test {
|
|||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
},
|
||||
),
|
||||
(14, ProtoNode::value(ConstructionArgs::Value(crate::document::value::TaggedValue::U32(2)))),
|
||||
(14, ProtoNode::value(ConstructionArgs::Value(TaggedValue::U32(2)))),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
@ -917,7 +996,7 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Value".into(),
|
||||
inputs: vec![NodeInput::Value {
|
||||
tagged_value: crate::document::value::TaggedValue::U32(2),
|
||||
tagged_value: TaggedValue::U32(2),
|
||||
exposed: false,
|
||||
}],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
|
@ -979,10 +1058,7 @@ mod test {
|
|||
10,
|
||||
DocumentNode {
|
||||
name: "Nested network".into(),
|
||||
inputs: vec![
|
||||
NodeInput::value(crate::document::value::TaggedValue::F32(1.), false),
|
||||
NodeInput::value(crate::document::value::TaggedValue::F32(2.), false),
|
||||
],
|
||||
inputs: vec![NodeInput::value(TaggedValue::F32(1.), false), NodeInput::value(TaggedValue::F32(2.), false)],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Network(two_node_identity()),
|
||||
},
|
||||
|
@ -1015,11 +1091,7 @@ mod test {
|
|||
assert_eq!(result.nodes.keys().copied().collect::<Vec<_>>(), vec![101], "Should just call nested network");
|
||||
let nested_network_node = result.nodes.get(&101).unwrap();
|
||||
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
|
||||
assert_eq!(
|
||||
nested_network_node.inputs,
|
||||
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
|
||||
"Input should be 2"
|
||||
);
|
||||
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
|
||||
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
|
||||
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
|
||||
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
|
||||
|
@ -1038,11 +1110,7 @@ mod test {
|
|||
for (node_id, input_value, inner_id) in [(10, 1., 1), (101, 2., 2)] {
|
||||
let nested_network_node = result.nodes.get(&node_id).unwrap();
|
||||
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
|
||||
assert_eq!(
|
||||
nested_network_node.inputs,
|
||||
vec![NodeInput::value(crate::document::value::TaggedValue::F32(input_value), false)],
|
||||
"Input should be stable"
|
||||
);
|
||||
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(input_value), false)], "Input should be stable");
|
||||
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
|
||||
assert_eq!(inner_network.inputs, vec![inner_id], "The input should be sent to the second node");
|
||||
assert_eq!(inner_network.outputs, vec![NodeOutput::new(inner_id, 0)], "The output should be node id");
|
||||
|
@ -1061,11 +1129,7 @@ mod test {
|
|||
assert_eq!(result_node.inputs, vec![NodeInput::node(101, 0)], "Result node should refer to duplicate node as input");
|
||||
let nested_network_node = result.nodes.get(&101).unwrap();
|
||||
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
|
||||
assert_eq!(
|
||||
nested_network_node.inputs,
|
||||
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
|
||||
"Input should be 2"
|
||||
);
|
||||
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
|
||||
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
|
||||
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
|
||||
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue