mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-08-31 18:27:20 +00:00
Implement node graph gpu execution via vulkano and rust gpu (#870)
* Add Executor abstraction * Resolve inputs for proto nodes by adding compose nodes * Add infrastructure for compiling gpu code * Integrate nodegraph gpu execution into graph-crafter * Extract graphene core path from env vars * Make Color struct usable for gpu code
This commit is contained in:
parent
33d5db76c0
commit
57a1f653e1
26 changed files with 2140 additions and 620 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
|||
target/
|
||||
*.spv
|
||||
*.exrc
|
||||
rust-toolchain
|
||||
|
|
1360
Cargo.lock
generated
1360
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -2,7 +2,10 @@ accepted = [
|
|||
"Apache-2.0",
|
||||
"MIT",
|
||||
"BSD-3-Clause",
|
||||
"BSD-2-Clause",
|
||||
"Zlib",
|
||||
"Unicode-DFS-2016",
|
||||
"ISC",
|
||||
]
|
||||
ignore-build-dependencies = true
|
||||
ignore-dev-dependencies = true
|
||||
|
|
|
@ -73,7 +73,11 @@ allow = [
|
|||
"MIT",
|
||||
"Apache-2.0",
|
||||
"BSD-3-Clause",
|
||||
"BSD-2-Clause",
|
||||
"Zlib",
|
||||
"Zlib",
|
||||
"Unicode-DFS-2016",
|
||||
"ISC",
|
||||
#"Apache-2.0 WITH LLVM-exception",
|
||||
]
|
||||
# List of explicitly disallowed licenses
|
||||
|
@ -173,8 +177,8 @@ skip = [
|
|||
#{ name = "ansi_term", version = "=0.11.0" },
|
||||
{ name = "cfg-if", version = "=0.1.10" },
|
||||
]
|
||||
# Similarly to `skip` allows you to skip certain crates during duplicate
|
||||
# detection. Unlike skip, it also includes the entire tree of transitive
|
||||
# Similarly to `skip` allows you to skip certain crates during duplicate
|
||||
# detection. Unlike skip, it also includes the entire tree of transitive
|
||||
# dependencies starting at the specified crate, up to a certain depth, which is
|
||||
# by default infinite
|
||||
skip-tree = [
|
||||
|
|
|
@ -9,6 +9,7 @@ use glam::{DAffine2, DMat2, DVec2};
|
|||
use graph_craft::proto::Type;
|
||||
use kurbo::{Affine, BezPath, Shape as KurboShape};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Write;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
|
||||
|
@ -131,7 +132,7 @@ impl Default for NodeGraphFrameLayer {
|
|||
DocumentNode {
|
||||
name: "Input".into(),
|
||||
inputs: vec![NodeInput::Network],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic])),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic(Cow::Borrowed("T"))])),
|
||||
metadata: DocumentNodeMetadata { position: (8, 4) },
|
||||
},
|
||||
),
|
||||
|
@ -140,7 +141,7 @@ impl Default for NodeGraphFrameLayer {
|
|||
DocumentNode {
|
||||
name: "Output".into(),
|
||||
inputs: vec![NodeInput::Node(0)],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic])),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic(Cow::Borrowed("T"))])),
|
||||
metadata: DocumentNodeMetadata { position: (20, 4) },
|
||||
},
|
||||
),
|
||||
|
|
|
@ -10,14 +10,16 @@ license = "MIT OR Apache-2.0"
|
|||
|
||||
[features]
|
||||
std = ["dyn-any"]
|
||||
default = ["async"]
|
||||
gpu = ["spirv-std"]
|
||||
default = ["async", "serde"]
|
||||
gpu = ["spirv-std", "bytemuck"]
|
||||
async = ["async-trait"]
|
||||
nightly = []
|
||||
serde = ["dep:serde"]
|
||||
|
||||
[dependencies]
|
||||
dyn-any = {path = "../../libraries/dyn-any", features = ["derive"], optional = true}
|
||||
|
||||
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu", features = ["glam"] , optional = true}
|
||||
bytemuck = {version = "1.8", features = ["derive"], optional = true}
|
||||
async-trait = {version = "0.1", optional = true}
|
||||
serde = {version = "1.0", features = ["derive"]}
|
||||
serde = {version = "1.0", features = ["derive"], optional = true}
|
||||
|
|
8
node-graph/gcore/src/gpu.rs
Normal file
8
node-graph/gcore/src/gpu.rs
Normal file
|
@ -0,0 +1,8 @@
|
|||
use bytemuck::{Pod, Zeroable};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Pod, Zeroable)]
|
||||
pub struct PushConstants {
|
||||
pub n: u32,
|
||||
pub node: u32,
|
||||
}
|
|
@ -1,6 +1,4 @@
|
|||
#![no_std]
|
||||
#![cfg_attr(target_arch = "spirv", feature(register_attr), register_attr(spirv))]
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
extern crate alloc;
|
||||
|
||||
|
@ -11,20 +9,18 @@ use async_trait::async_trait;
|
|||
|
||||
pub mod generic;
|
||||
pub mod ops;
|
||||
pub mod raster;
|
||||
pub mod structural;
|
||||
pub mod value;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub mod gpu;
|
||||
|
||||
pub mod raster;
|
||||
|
||||
pub trait Node<T> {
|
||||
type Output;
|
||||
|
||||
fn eval(self, input: T) -> Self::Output;
|
||||
fn input(&self) -> &str {
|
||||
core::any::type_name::<T>()
|
||||
}
|
||||
fn output(&self) -> &str {
|
||||
core::any::type_name::<Self::Output>()
|
||||
}
|
||||
}
|
||||
|
||||
trait Input<I> {
|
||||
|
|
|
@ -3,7 +3,7 @@ use core::ops::Add;
|
|||
|
||||
use crate::{Node, RefNode};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct AddNode;
|
||||
impl<'n, L: Add<R, Output = O> + 'n, R, O: 'n> Node<(L, R)> for AddNode {
|
||||
type Output = <L as Add<R>>::Output;
|
||||
|
@ -30,6 +30,12 @@ impl<'n, L: Add<R, Output = O> + 'n + Copy, R: Copy, O: 'n> Node<&'n (L, R)> for
|
|||
}
|
||||
}
|
||||
|
||||
impl AddNode {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
pub mod dynamic {
|
||||
use super::*;
|
||||
|
@ -156,7 +162,7 @@ impl<'n, T: Clone + 'n> Node<T> for DupNode {
|
|||
}
|
||||
|
||||
/// Return the Input Argument
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct IdNode;
|
||||
impl<T> Node<T> for IdNode {
|
||||
type Output = T;
|
||||
|
@ -177,6 +183,12 @@ impl<T> RefNode<T> for IdNode {
|
|||
}
|
||||
}
|
||||
|
||||
impl IdNode {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MapResultNode<MN, I, E>(pub MN, pub PhantomData<(I, E)>);
|
||||
|
||||
impl<MN: Node<I>, I, E> Node<Result<I, E>> for MapResultNode<MN, I, E> {
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use crate::Node;
|
||||
|
||||
use self::color::Color;
|
||||
|
||||
pub mod color;
|
||||
use self::color::Color;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GrayscaleColorNode;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#[cfg(feature = "std")]
|
||||
use dyn_any::{DynAny, StaticType};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Structure that represents a color.
|
||||
|
@ -8,7 +9,7 @@ use serde::{Deserialize, Serialize};
|
|||
/// the values encode the brightness of each channel proportional to the light intensity in cd/m² (nits) in HDR, and `0.0` (black) to `1.0` (white) in SDR color.
|
||||
#[repr(C)]
|
||||
#[cfg_attr(feature = "std", derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize, DynAny))]
|
||||
#[cfg_attr(not(feature = "std"), derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize))]
|
||||
#[cfg_attr(not(feature = "std"), derive(Debug, Clone, Copy, PartialEq, Default))]
|
||||
pub struct Color {
|
||||
red: f32,
|
||||
green: f32,
|
||||
|
@ -35,11 +36,13 @@ impl Color {
|
|||
/// let color = Color::from_rgbaf32(1.0, 1.0, 1.0, f32::NAN);
|
||||
/// assert!(color == None);
|
||||
/// ```
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
pub fn from_rgbaf32(red: f32, green: f32, blue: f32, alpha: f32) -> Option<Color> {
|
||||
if alpha > 1. || [red, green, blue, alpha].iter().any(|c| c.is_sign_negative() || !c.is_finite()) {
|
||||
return None;
|
||||
}
|
||||
Some(Color { red, green, blue, alpha })
|
||||
let color = Color { red, green, blue, alpha };
|
||||
Some(color)
|
||||
}
|
||||
|
||||
/// Return an opaque `Color` from given `f32` RGB channels.
|
||||
|
@ -230,6 +233,7 @@ impl Color {
|
|||
/// use graphene_core::raster::color::Color;
|
||||
/// let color = Color::from_rgba_str("7C67FA61").unwrap();
|
||||
/// ```
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
pub fn from_rgba_str(color_str: &str) -> Option<Color> {
|
||||
if color_str.len() != 8 {
|
||||
return None;
|
||||
|
@ -247,6 +251,7 @@ impl Color {
|
|||
/// use graphene_core::raster::color::Color;
|
||||
/// let color = Color::from_rgb_str("7C67FA").unwrap();
|
||||
/// ```
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
pub fn from_rgb_str(color_str: &str) -> Option<Color> {
|
||||
if color_str.len() != 6 {
|
||||
return None;
|
||||
|
|
|
@ -4,10 +4,16 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
profiling = ["nvtx", "gpu"]
|
||||
gpu = ["serde", "vulkano", "spirv-builder", "tera", "graphene-core/gpu"]
|
||||
serde = ["dep:serde", "graphene-std/serde", "glam/serde"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
graphene-core = { path = "../gcore", features = ["async", "std"] }
|
||||
graphene-core = { path = "../gcore", features = ["async", "std" ] }
|
||||
graphene-std = { path = "../gstd" }
|
||||
dyn-any = { path = "../../libraries/dyn-any", features = ["log-bad-types", "rc", "glam"] }
|
||||
num-traits = "0.2"
|
||||
|
@ -18,5 +24,10 @@ log = "0.4"
|
|||
serde = { version = "1", features = ["derive", "rc"], optional = true }
|
||||
glam = { version = "0.17" }
|
||||
|
||||
[features]
|
||||
serde = ["dep:serde", "graphene-std/serde", "glam/serde"]
|
||||
vulkano = {git = "https://github.com/GraphiteEditor/vulkano", branch = "fix_rust_gpu", optional = true}
|
||||
bytemuck = {version = "1.8" }
|
||||
nvtx = {version = "1.1.1", optional = true}
|
||||
tempfile = "3"
|
||||
spirv-builder = {git = "https://github.com/EmbarkStudios/rust-gpu" , branch = "main", optional = true, default-features = false, features=["use-installed-tools"]}
|
||||
tera = {version = "1.17.1", optional = true}
|
||||
anyhow = "1.0.66"
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::generic;
|
||||
use crate::proto::{ConstructionArgs, NodeIdentifier, ProtoNetwork, ProtoNode, ProtoNodeInput, Type};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
@ -193,7 +194,7 @@ impl NodeNetwork {
|
|||
let value_node = DocumentNode {
|
||||
name: name.clone(),
|
||||
inputs: vec![NodeInput::Value { tagged_value, exposed }],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::value::ValueNode", &[Type::Generic])),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::value::ValueNode", &[generic!("T")])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
};
|
||||
assert!(!self.nodes.contains_key(&new_id));
|
||||
|
@ -209,7 +210,7 @@ impl NodeNetwork {
|
|||
}
|
||||
}
|
||||
}
|
||||
node.implementation = DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic]));
|
||||
node.implementation = DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")]));
|
||||
node.inputs = vec![NodeInput::Node(inner_network.output)];
|
||||
for node_id in new_nodes {
|
||||
self.flatten_with_fns(node_id, map_ids, gen_id);
|
||||
|
@ -256,8 +257,8 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Cons".into(),
|
||||
inputs: vec![NodeInput::Network, NodeInput::Network],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[Type::Generic, Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
|
||||
},
|
||||
),
|
||||
(
|
||||
|
@ -265,8 +266,8 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Add".into(),
|
||||
inputs: vec![NodeInput::Node(0)],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[Type::Generic, Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")])),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@ -288,8 +289,8 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Cons".into(),
|
||||
inputs: vec![NodeInput::Network, NodeInput::Network],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[Type::Generic, Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
|
||||
},
|
||||
),
|
||||
(
|
||||
|
@ -297,8 +298,8 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Add".into(),
|
||||
inputs: vec![NodeInput::Node(1)],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[Type::Generic, Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")])),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@ -344,13 +345,13 @@ mod test {
|
|||
let document_node = DocumentNode {
|
||||
name: "Cons".into(),
|
||||
inputs: vec![NodeInput::Network, NodeInput::Node(0)],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[Type::Generic, Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
|
||||
};
|
||||
|
||||
let proto_node = document_node.resolve_proto_node();
|
||||
let reference = ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[Type::Generic, Type::Generic]),
|
||||
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")]),
|
||||
input: ProtoNodeInput::Network,
|
||||
construction_args: ConstructionArgs::Nodes(vec![0]),
|
||||
};
|
||||
|
@ -366,7 +367,7 @@ mod test {
|
|||
(
|
||||
1,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic]),
|
||||
identifier: NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")]),
|
||||
input: ProtoNodeInput::Node(11),
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
},
|
||||
|
@ -374,7 +375,7 @@ mod test {
|
|||
(
|
||||
10,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[Type::Generic, Type::Generic]),
|
||||
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")]),
|
||||
input: ProtoNodeInput::Network,
|
||||
construction_args: ConstructionArgs::Nodes(vec![14]),
|
||||
},
|
||||
|
@ -382,7 +383,7 @@ mod test {
|
|||
(
|
||||
11,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::ops::AddNode", &[Type::Generic, Type::Generic]),
|
||||
identifier: NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")]),
|
||||
input: ProtoNodeInput::Node(10),
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
},
|
||||
|
@ -410,8 +411,8 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Inc".into(),
|
||||
inputs: vec![NodeInput::Node(11)],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")])),
|
||||
},
|
||||
),
|
||||
(
|
||||
|
@ -419,8 +420,8 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Cons".into(),
|
||||
inputs: vec![NodeInput::Network, NodeInput::Node(14)],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[Type::Generic, Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
|
||||
},
|
||||
),
|
||||
(
|
||||
|
@ -431,8 +432,8 @@ mod test {
|
|||
tagged_value: value::TaggedValue::U32(2),
|
||||
exposed: false,
|
||||
}],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::value::ValueNode", &[Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::value::ValueNode", &[generic!("T")])),
|
||||
},
|
||||
),
|
||||
(
|
||||
|
@ -440,8 +441,8 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Add".into(),
|
||||
inputs: vec![NodeInput::Node(10)],
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[Type::Generic, Type::Generic])),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")])),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
54
node-graph/graph-craft/src/executor.rs
Normal file
54
node-graph/graph-craft/src/executor.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
use std::error::Error;
|
||||
|
||||
use borrow_stack::{BorrowStack, FixedSizeStack};
|
||||
use graphene_core::Node;
|
||||
use graphene_std::any::{Any, TypeErasedNode};
|
||||
|
||||
use crate::{document::NodeNetwork, node_registry::push_node, proto::ProtoNetwork};
|
||||
|
||||
pub struct Compiler {}
|
||||
|
||||
impl Compiler {
|
||||
pub fn compile(&self, mut network: NodeNetwork, resolve_inputs: bool) -> ProtoNetwork {
|
||||
let node_count = network.nodes.len();
|
||||
println!("flattening");
|
||||
for id in 0..node_count {
|
||||
network.flatten(id as u64);
|
||||
}
|
||||
let mut proto_network = network.into_proto_network();
|
||||
if resolve_inputs {
|
||||
println!("resolving inputs");
|
||||
proto_network.resolve_inputs();
|
||||
}
|
||||
println!("reordering ids");
|
||||
proto_network.reorder_ids();
|
||||
proto_network
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Executor {
|
||||
fn execute(&self, input: Any<'static>) -> Result<Any<'static>, Box<dyn Error>>;
|
||||
}
|
||||
|
||||
pub struct DynamicExecutor {
|
||||
stack: FixedSizeStack<TypeErasedNode<'static>>,
|
||||
}
|
||||
|
||||
impl DynamicExecutor {
|
||||
pub fn new(proto_network: ProtoNetwork) -> Self {
|
||||
assert_eq!(proto_network.inputs.len(), 1);
|
||||
let node_count = proto_network.nodes.len();
|
||||
let stack = FixedSizeStack::new(node_count);
|
||||
for (_id, node) in proto_network.nodes {
|
||||
push_node(node, &stack);
|
||||
}
|
||||
Self { stack }
|
||||
}
|
||||
}
|
||||
|
||||
impl Executor for DynamicExecutor {
|
||||
fn execute(&self, input: Any<'static>) -> Result<Any<'static>, Box<dyn Error>> {
|
||||
let result = unsafe { self.stack.get().last().unwrap().eval(input) };
|
||||
Ok(result)
|
||||
}
|
||||
}
|
3
node-graph/graph-craft/src/gpu.rs
Normal file
3
node-graph/graph-craft/src/gpu.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
pub mod compiler;
|
||||
pub mod context;
|
||||
pub mod executor;
|
137
node-graph/graph-craft/src/gpu/compiler.rs
Normal file
137
node-graph/graph-craft/src/gpu/compiler.rs
Normal file
|
@ -0,0 +1,137 @@
|
|||
use std::path::Path;
|
||||
|
||||
use crate::proto::*;
|
||||
use tera::Context;
|
||||
|
||||
fn create_cargo_toml(metadata: &Metadata) -> Result<String, tera::Error> {
|
||||
let mut tera = tera::Tera::default();
|
||||
tera.add_raw_template("cargo_toml", include_str!("templates/Cargo-template.toml"))?;
|
||||
let mut context = Context::new();
|
||||
context.insert("name", &metadata.name);
|
||||
context.insert("authors", &metadata.authors);
|
||||
context.insert("gcore_path", &format!("{}{}", env!("CARGO_MANIFEST_DIR"), "/../gcore"));
|
||||
tera.render("cargo_toml", &context)
|
||||
}
|
||||
|
||||
pub struct Metadata {
|
||||
name: String,
|
||||
authors: Vec<String>,
|
||||
}
|
||||
|
||||
impl Metadata {
|
||||
pub fn new(name: String, authors: Vec<String>) -> Self {
|
||||
Self { name, authors }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_files(matadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, input_type: &str, output_type: &str) -> anyhow::Result<()> {
|
||||
let src = compile_dir.join("src");
|
||||
let cargo_file = compile_dir.join("Cargo.toml");
|
||||
let cargo_toml = create_cargo_toml(matadata)?;
|
||||
std::fs::write(cargo_file, cargo_toml)?;
|
||||
|
||||
// create src dir
|
||||
match std::fs::create_dir(&src) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
if e.kind() != std::io::ErrorKind::AlreadyExists {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
let lib = src.join("lib.rs");
|
||||
let shader = serialize_gpu(network, input_type, output_type)?;
|
||||
println!("{}", shader);
|
||||
std::fs::write(lib, shader)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str) -> anyhow::Result<String> {
|
||||
assert_eq!(network.inputs.len(), 1);
|
||||
/*let input = &network.nodes[network.inputs[0] as usize].1;
|
||||
let output = &network.nodes[network.output as usize].1;
|
||||
let input_type = format!("{}::Input", input.identifier.fully_qualified_name());
|
||||
let output_type = format!("{}::Output", output.identifier.fully_qualified_name());
|
||||
*/
|
||||
|
||||
fn nid(id: &u64) -> String {
|
||||
format!("n{id}")
|
||||
}
|
||||
|
||||
let mut nodes = Vec::new();
|
||||
#[derive(serde::Serialize)]
|
||||
struct Node {
|
||||
id: String,
|
||||
fqn: String,
|
||||
args: Vec<String>,
|
||||
}
|
||||
for (ref id, node) in network.nodes.iter() {
|
||||
let fqn = &node.identifier.name;
|
||||
let id = nid(id);
|
||||
|
||||
nodes.push(Node {
|
||||
id,
|
||||
fqn: fqn.to_string(),
|
||||
args: node.construction_args.new_function_args(),
|
||||
});
|
||||
}
|
||||
|
||||
let template = include_str!("templates/spirv-template.rs");
|
||||
let mut tera = tera::Tera::default();
|
||||
tera.add_raw_template("spirv", template)?;
|
||||
let mut context = Context::new();
|
||||
context.insert("input_type", &input_type);
|
||||
context.insert("output_type", &output_type);
|
||||
context.insert("nodes", &nodes);
|
||||
context.insert("last_node", &nid(&network.output));
|
||||
context.insert("compute_threads", &64);
|
||||
Ok(tera.render("spirv", &context)?)
|
||||
}
|
||||
|
||||
use spirv_builder::{MetadataPrintout, SpirvBuilder, SpirvMetadata};
|
||||
pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder::SpirvBuilderError> {
|
||||
let result = SpirvBuilder::new(dir, "spirv-unknown-spv1.5")
|
||||
.print_metadata(MetadataPrintout::DependencyOnly)
|
||||
.multimodule(false)
|
||||
.preserve_bindings(true)
|
||||
.release(true)
|
||||
//.relax_struct_store(true)
|
||||
//.relax_block_layout(true)
|
||||
.spirv_metadata(SpirvMetadata::Full)
|
||||
.build()?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
#[test]
|
||||
fn test_create_cargo_toml() {
|
||||
let cargo_toml = super::create_cargo_toml(&super::Metadata {
|
||||
name: "project".to_owned(),
|
||||
authors: vec!["Example <john.smith@example.com>".to_owned(), "smith.john@example.com".to_owned()],
|
||||
});
|
||||
let cargo_toml = cargo_toml.expect("failed to build carog toml template");
|
||||
let lines = cargo_toml.split('\n').collect::<Vec<_>>();
|
||||
let cargo_toml = lines[..lines.len() - 2].join("\n");
|
||||
let reference = r#"[package]
|
||||
name = "project-node"
|
||||
version = "0.1.0"
|
||||
authors = ["Example <john.smith@example.com>", "smith.john@example.com", ]
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
crate-type = ["dylib", "lib"]
|
||||
|
||||
[patch.crates-io]
|
||||
libm = { git = "https://github.com/rust-lang/libm", tag = "0.2.5" }
|
||||
|
||||
[dependencies]
|
||||
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu" , features= ["glam"]}"#;
|
||||
|
||||
assert_eq!(cargo_toml, reference);
|
||||
}
|
||||
}
|
68
node-graph/graph-craft/src/gpu/context.rs
Normal file
68
node-graph/graph-craft/src/gpu/context.rs
Normal file
|
@ -0,0 +1,68 @@
|
|||
use std::sync::Arc;
|
||||
use vulkano::{
|
||||
command_buffer::allocator::StandardCommandBufferAllocator,
|
||||
descriptor_set::allocator::StandardDescriptorSetAllocator,
|
||||
device::{Device, DeviceCreateInfo, Queue, QueueCreateInfo},
|
||||
instance::{Instance, InstanceCreateInfo},
|
||||
memory::allocator::StandardMemoryAllocator,
|
||||
VulkanLibrary,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Context {
|
||||
pub instance: Arc<Instance>,
|
||||
pub device: Arc<Device>,
|
||||
pub queue: Arc<Queue>,
|
||||
pub allocator: StandardMemoryAllocator,
|
||||
pub command_buffer_allocator: StandardCommandBufferAllocator,
|
||||
pub descriptor_set_allocator: StandardDescriptorSetAllocator,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub fn new() -> Self {
|
||||
let library = VulkanLibrary::new().unwrap();
|
||||
let instance = Instance::new(library, InstanceCreateInfo::default()).expect("failed to create instance");
|
||||
let physical = instance.enumerate_physical_devices().expect("could not enumerate devices").next().expect("no device available");
|
||||
for family in physical.queue_family_properties() {
|
||||
println!("Found a queue family with {:?} queue(s)", family.queue_count);
|
||||
}
|
||||
let queue_family_index = physical
|
||||
.queue_family_properties()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.position(|(_, q)| q.queue_flags.graphics)
|
||||
.expect("couldn't find a graphical queue family") as u32;
|
||||
|
||||
let (device, mut queues) = Device::new(
|
||||
physical,
|
||||
DeviceCreateInfo {
|
||||
// here we pass the desired queue family to use by index
|
||||
queue_create_infos: vec![QueueCreateInfo {
|
||||
queue_family_index,
|
||||
..Default::default()
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.expect("failed to create device");
|
||||
let queue = queues.next().unwrap();
|
||||
let alloc = StandardMemoryAllocator::new_default(device.clone());
|
||||
let calloc = StandardCommandBufferAllocator::new(device.clone());
|
||||
let dalloc = StandardDescriptorSetAllocator::new(device.clone());
|
||||
|
||||
Self {
|
||||
instance,
|
||||
device,
|
||||
queue,
|
||||
allocator: alloc,
|
||||
command_buffer_allocator: calloc,
|
||||
descriptor_set_allocator: dalloc,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Context {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
188
node-graph/graph-craft/src/gpu/executor.rs
Normal file
188
node-graph/graph-craft/src/gpu/executor.rs
Normal file
|
@ -0,0 +1,188 @@
|
|||
use std::path::Path;
|
||||
|
||||
use super::{compiler::Metadata, context::Context};
|
||||
use crate::gpu::compiler;
|
||||
use bytemuck::Pod;
|
||||
use dyn_any::StaticTypeSized;
|
||||
use vulkano::{
|
||||
buffer::{self, BufferUsage, CpuAccessibleBuffer},
|
||||
command_buffer::{allocator::StandardCommandBufferAllocator, AutoCommandBufferBuilder, CommandBufferUsage},
|
||||
descriptor_set::{allocator::StandardDescriptorSetAllocator, PersistentDescriptorSet, WriteDescriptorSet},
|
||||
device::Device,
|
||||
memory::allocator::StandardMemoryAllocator,
|
||||
pipeline::{ComputePipeline, Pipeline, PipelineBindPoint},
|
||||
sync::GpuFuture,
|
||||
};
|
||||
|
||||
use crate::proto::*;
|
||||
use graphene_core::gpu::PushConstants;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GpuExecutor<I: StaticTypeSized, O> {
|
||||
context: Context,
|
||||
entry_point: String,
|
||||
shader: std::sync::Arc<vulkano::shader::ShaderModule>,
|
||||
_phantom: std::marker::PhantomData<(I, O)>,
|
||||
}
|
||||
|
||||
impl<I: StaticTypeSized, O> GpuExecutor<I, O> {
|
||||
pub fn new(context: Context, network: ProtoNetwork, metadata: Metadata, compile_dir: &Path) -> anyhow::Result<Self> {
|
||||
compiler::create_files(&metadata, &network, compile_dir, std::any::type_name::<I>(), std::any::type_name::<O>())?;
|
||||
let result = compiler::compile(compile_dir)?;
|
||||
|
||||
let bytes = std::fs::read(result.module.unwrap_single())?;
|
||||
let shader = unsafe { vulkano::shader::ShaderModule::from_bytes(context.device.clone(), &bytes)? };
|
||||
let entry_point = result.entry_points.first().expect("No entry points").clone();
|
||||
|
||||
Ok(Self {
|
||||
context,
|
||||
entry_point,
|
||||
shader,
|
||||
_phantom: std::marker::PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync + Pod> crate::executor::Executor for GpuExecutor<I, O> {
|
||||
fn execute(&self, input: graphene_std::any::Any<'static>) -> Result<graphene_std::any::Any<'static>, Box<dyn std::error::Error>> {
|
||||
let input = dyn_any::downcast::<Vec<I>>(input).expect("Wrong input type");
|
||||
let context = &self.context;
|
||||
let result: Vec<O> = execute_shader(
|
||||
context.device.clone(),
|
||||
context.queue.clone(),
|
||||
self.shader.entry_point(&self.entry_point).expect("Entry point not found in shader"),
|
||||
&context.allocator,
|
||||
&context.command_buffer_allocator,
|
||||
*input,
|
||||
);
|
||||
Ok(Box::new(result))
|
||||
}
|
||||
}
|
||||
|
||||
fn execute_shader<I: Pod + Send + Sync, O: Pod + Send + Sync>(
|
||||
device: std::sync::Arc<Device>,
|
||||
queue: std::sync::Arc<vulkano::device::Queue>,
|
||||
entry_point: vulkano::shader::EntryPoint,
|
||||
alloc: &StandardMemoryAllocator,
|
||||
calloc: &StandardCommandBufferAllocator,
|
||||
data: Vec<I>,
|
||||
) -> Vec<O> {
|
||||
let constants = PushConstants { n: data.len() as u32, node: 0 };
|
||||
|
||||
let dest_data: Vec<_> = (0..constants.n).map(|_| O::zeroed()).collect();
|
||||
let source_buffer = create_buffer(data, alloc).expect("failed to create buffer");
|
||||
let dest_buffer = create_buffer(dest_data, alloc).expect("failed to create buffer");
|
||||
|
||||
let compute_pipeline = ComputePipeline::new(device.clone(), entry_point, &(), None, |_| {}).expect("failed to create compute pipeline");
|
||||
let layout = compute_pipeline.layout().set_layouts().get(0).unwrap();
|
||||
let dalloc = StandardDescriptorSetAllocator::new(device.clone());
|
||||
let set = PersistentDescriptorSet::new(
|
||||
&dalloc,
|
||||
layout.clone(),
|
||||
[
|
||||
WriteDescriptorSet::buffer(0, source_buffer), // 0 is the binding
|
||||
WriteDescriptorSet::buffer(1, dest_buffer.clone()),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let mut builder = AutoCommandBufferBuilder::primary(calloc, queue.queue_family_index(), CommandBufferUsage::OneTimeSubmit).unwrap();
|
||||
|
||||
builder
|
||||
.bind_pipeline_compute(compute_pipeline.clone())
|
||||
.bind_descriptor_sets(PipelineBindPoint::Compute, compute_pipeline.layout().clone(), 0, set)
|
||||
.push_constants(compute_pipeline.layout().clone(), 0, constants)
|
||||
.dispatch([1024, 1, 1])
|
||||
.unwrap();
|
||||
let command_buffer = builder.build().unwrap();
|
||||
|
||||
let future = vulkano::sync::now(device).then_execute(queue, command_buffer).unwrap().then_signal_fence_and_flush().unwrap();
|
||||
#[cfg(feature = "profiling")]
|
||||
nvtx::range_push!("compute");
|
||||
future.wait(None).unwrap();
|
||||
#[cfg(feature = "profiling")]
|
||||
nvtx::range_pop!();
|
||||
let content = dest_buffer.read().unwrap();
|
||||
content.to_vec()
|
||||
}
|
||||
|
||||
fn create_buffer<T: Pod + Send + Sync>(data: Vec<T>, alloc: &StandardMemoryAllocator) -> Result<std::sync::Arc<CpuAccessibleBuffer<[T]>>, vulkano::memory::allocator::AllocationCreationError> {
|
||||
let buffer_usage = BufferUsage {
|
||||
storage_buffer: true,
|
||||
transfer_src: true,
|
||||
transfer_dst: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
buffer::CpuAccessibleBuffer::from_iter(alloc, buffer_usage, false, data.into_iter())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::concrete;
|
||||
use crate::generic;
|
||||
use crate::gpu::compiler;
|
||||
|
||||
fn inc_network() -> ProtoNetwork {
|
||||
let mut construction_network = ProtoNetwork {
|
||||
inputs: vec![10],
|
||||
output: 1,
|
||||
nodes: [
|
||||
(
|
||||
1,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("u32")]),
|
||||
input: ProtoNodeInput::Node(11),
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
},
|
||||
),
|
||||
(
|
||||
10,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("&ValueNode<u32>"), generic!("()")]),
|
||||
input: ProtoNodeInput::Network,
|
||||
construction_args: ConstructionArgs::Nodes(vec![14]),
|
||||
},
|
||||
),
|
||||
(
|
||||
11,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("u32"), generic!("u32")]),
|
||||
input: ProtoNodeInput::Node(10),
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
},
|
||||
),
|
||||
(
|
||||
14,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::value::ValueNode", &[concrete!("u32")]),
|
||||
input: ProtoNodeInput::None,
|
||||
construction_args: ConstructionArgs::Value(Box::new(3_u32)),
|
||||
},
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
construction_network.resolve_inputs();
|
||||
construction_network.reorder_ids();
|
||||
construction_network
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_on_gpu() {
|
||||
use crate::executor::Executor;
|
||||
let m = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
|
||||
let network = inc_network();
|
||||
let temp_dir = tempfile::tempdir().expect("failed to create tempdir");
|
||||
|
||||
let executor: GpuExecutor<u32, u32> = GpuExecutor::new(Context::new(), network, m, temp_dir.path()).unwrap();
|
||||
|
||||
let data: Vec<_> = (0..1024).map(|x| x as u32).collect();
|
||||
let result = executor.execute(Box::new(data)).unwrap();
|
||||
let result = dyn_any::downcast::<Vec<u32>>(result).unwrap();
|
||||
for (i, r) in result.iter().enumerate() {
|
||||
assert_eq!(*r, i as u32 + 3);
|
||||
}
|
||||
}
|
||||
}
|
17
node-graph/graph-craft/src/gpu/templates/Cargo-template.toml
Normal file
17
node-graph/graph-craft/src/gpu/templates/Cargo-template.toml
Normal file
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "{{name}}-node"
|
||||
version = "0.1.0"
|
||||
authors = [{%for author in authors%}"{{author}}", {%endfor%}]
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
crate-type = ["dylib", "lib"]
|
||||
|
||||
[patch.crates-io]
|
||||
libm = { git = "https://github.com/rust-lang/libm", tag = "0.2.5" }
|
||||
|
||||
[dependencies]
|
||||
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu" , features= ["glam"]}
|
||||
graphene-core = {path = "{{gcore_path}}", default-features = false, features = ["gpu"]}
|
38
node-graph/graph-craft/src/gpu/templates/spirv-template.rs
Normal file
38
node-graph/graph-craft/src/gpu/templates/spirv-template.rs
Normal file
|
@ -0,0 +1,38 @@
|
|||
#![no_std]
|
||||
#![feature(unchecked_math)]
|
||||
#![deny(warnings)]
|
||||
|
||||
#[cfg(target_arch = "spirv")]
|
||||
extern crate spirv_std;
|
||||
|
||||
#[cfg(target_arch = "spirv")]
|
||||
pub mod gpu {
|
||||
use super::*;
|
||||
use spirv_std::spirv;
|
||||
use spirv_std::glam::UVec3;
|
||||
|
||||
#[allow(unused)]
|
||||
#[spirv(compute(threads({{compute_threads}})))]
|
||||
pub fn eval (
|
||||
#[spirv(global_invocation_id)] global_id: UVec3,
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}],
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}],
|
||||
#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
|
||||
) {
|
||||
let gid = global_id.x as usize;
|
||||
// Only process up to n, which is the length of the buffers.
|
||||
if global_id.x < push_consts.n {
|
||||
y[gid] = node_graph(a[gid]);
|
||||
}
|
||||
}
|
||||
|
||||
fn node_graph(input: {{input_type}}) -> {{output_type}} {
|
||||
use graphene_core::Node;
|
||||
|
||||
{% for node in nodes %}
|
||||
let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %});
|
||||
{% endfor %}
|
||||
{{last_node}}.eval(input)
|
||||
}
|
||||
|
||||
}
|
|
@ -6,6 +6,11 @@ pub mod node_registry;
|
|||
pub mod document;
|
||||
pub mod proto;
|
||||
|
||||
pub mod executor;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub mod gpu;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
@ -15,7 +20,6 @@ mod tests {
|
|||
use graphene_core::{structural::*, RefNode};
|
||||
|
||||
use borrow_stack::BorrowStack;
|
||||
use borrow_stack::FixedSizeStack;
|
||||
use dyn_any::{downcast, IntoDynAny};
|
||||
use graphene_std::any::{Any, DowncastNode, DynAnyNode, TypeErasedNode};
|
||||
use graphene_std::ops::AddNode;
|
||||
|
@ -56,9 +60,7 @@ mod tests {
|
|||
#[test]
|
||||
fn execute_add() {
|
||||
use crate::document::*;
|
||||
use crate::node_registry::push_node;
|
||||
use crate::proto::*;
|
||||
use graphene_core::Node;
|
||||
|
||||
fn add_network() -> NodeNetwork {
|
||||
NodeNetwork {
|
||||
|
@ -95,7 +97,7 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
let mut network = NodeNetwork {
|
||||
let network = NodeNetwork {
|
||||
inputs: vec![0],
|
||||
output: 0,
|
||||
nodes: [(
|
||||
|
@ -117,18 +119,14 @@ mod tests {
|
|||
.collect(),
|
||||
};
|
||||
|
||||
let stack = FixedSizeStack::new(256);
|
||||
println!("flattening");
|
||||
network.flatten(0);
|
||||
//println!("flat_network: {:#?}", network);
|
||||
let mut proto_network = network.into_proto_network();
|
||||
proto_network.reorder_ids();
|
||||
//println!("reordered_ides: {:#?}", proto_network);
|
||||
for (_id, node) in proto_network.nodes {
|
||||
push_node(node, &stack);
|
||||
}
|
||||
use crate::executor::{Compiler, DynamicExecutor, Executor};
|
||||
|
||||
let result = unsafe { stack.get().last().unwrap().eval(32_u32.into_dyn()) };
|
||||
let compiler = Compiler {};
|
||||
let protograph = compiler.compile(network, false);
|
||||
|
||||
let exec = DynamicExecutor::new(protograph);
|
||||
|
||||
let result = exec.execute(32_u32.into_dyn()).unwrap();
|
||||
let val = *dyn_any::downcast::<u32>(result).unwrap();
|
||||
assert_eq!(val, 33_u32);
|
||||
}
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
use std::borrow::Cow;
|
||||
|
||||
use borrow_stack::FixedSizeStack;
|
||||
use glam::DVec2;
|
||||
use graphene_core::generic::FnNode;
|
||||
use graphene_core::ops::AddNode;
|
||||
use graphene_core::ops::{AddNode, IdNode};
|
||||
use graphene_core::raster::color::Color;
|
||||
use graphene_core::structural::{ConsNode, Then};
|
||||
use graphene_core::Node;
|
||||
|
@ -13,105 +11,78 @@ use graphene_std::raster::Image;
|
|||
use graphene_std::vector::subpath::Subpath;
|
||||
|
||||
use crate::proto::Type;
|
||||
use crate::proto::{ConstructionArgs, NodeIdentifier, ProtoNode, ProtoNodeInput, Type::Concrete};
|
||||
use crate::proto::{ConstructionArgs, NodeIdentifier, ProtoNode, ProtoNodeInput};
|
||||
|
||||
type NodeConstructor = fn(ProtoNode, &FixedSizeStack<TypeErasedNode<'static>>);
|
||||
|
||||
use crate::{concrete, generic};
|
||||
|
||||
//TODO: turn into hasmap
|
||||
static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
||||
(
|
||||
NodeIdentifier::new("graphene_core::ops::IdNode", &[Concrete(std::borrow::Cow::Borrowed("Any<'_>"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
if let ProtoNodeInput::Node(pre_id) = proto_node.input {
|
||||
let pre_node = nodes.get(pre_id as usize).unwrap();
|
||||
let node = pre_node.then(graphene_core::ops::IdNode);
|
||||
node.into_type_erased()
|
||||
} else {
|
||||
graphene_core::ops::IdNode.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(NodeIdentifier::new("graphene_core::ops::IdNode", &[Type::Generic]), |proto_node, stack| {
|
||||
(NodeIdentifier::new("graphene_core::ops::IdNode", &[concrete!("Any<'_>")]), |proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
if let ProtoNodeInput::Node(pre_id) = proto_node.input {
|
||||
let pre_node = nodes.get(pre_id as usize).unwrap();
|
||||
let node = pre_node.then(graphene_core::ops::IdNode);
|
||||
node.into_type_erased()
|
||||
pre_node.into_type_erased()
|
||||
} else {
|
||||
graphene_core::ops::IdNode.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(
|
||||
NodeIdentifier::new("graphene_core::ops::AddNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Add Node constructed with out rhs input node") };
|
||||
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
|
||||
let input_node: DowncastBothNode<_, (), f64> = DowncastBothNode::new(value_node);
|
||||
let node: DynAnyNode<_, f64, _, _> = DynAnyNode::new(ConsNode::new(input_node).then(graphene_core::ops::AddNode));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new(
|
||||
"graphene_core::ops::AddNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("u32")), Concrete(std::borrow::Cow::Borrowed("u32"))],
|
||||
),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node: DynAnyNode<AddNode, (u32, u32), _, _> = DynAnyNode::new(graphene_core::ops::AddNode);
|
||||
let node = (pre_node).then(node);
|
||||
(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")]), |proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
if let ProtoNodeInput::Node(pre_id) = proto_node.input {
|
||||
let pre_node = nodes.get(pre_id as usize).unwrap();
|
||||
pre_node.into_type_erased()
|
||||
} else {
|
||||
graphene_core::ops::IdNode.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::ops::AddNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Add Node constructed with out rhs input node") };
|
||||
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
|
||||
let input_node: DowncastBothNode<_, (), f32> = DowncastBothNode::new(value_node);
|
||||
let node: DynAnyNode<_, f32, _, _> = DynAnyNode::new(ConsNode::new(input_node).then(graphene_core::ops::AddNode));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new(
|
||||
"graphene_core::ops::AddNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("&u32")), Concrete(std::borrow::Cow::Borrowed("&u32"))],
|
||||
),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node: DynAnyNode<AddNode, (&u32, &u32), _, _> = DynAnyNode::new(graphene_core::ops::AddNode);
|
||||
let node = (pre_node).then(node);
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::ops::AddNode", &[concrete!("u32"), concrete!("u32")]), |proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node: DynAnyNode<AddNode, (u32, u32), _, _> = DynAnyNode::new(graphene_core::ops::AddNode);
|
||||
let node = (pre_node).then(node);
|
||||
|
||||
node.into_type_erased()
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new(
|
||||
"graphene_core::ops::AddNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("&u32")), Concrete(std::borrow::Cow::Borrowed("u32"))],
|
||||
),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node: DynAnyNode<AddNode, (&u32, u32), _, _> = DynAnyNode::new(graphene_core::ops::AddNode);
|
||||
let node = (pre_node).then(node);
|
||||
node.into_type_erased()
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::ops::AddNode", &[concrete!("&u32"), concrete!("&u32")]), |proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node: DynAnyNode<AddNode, (&u32, &u32), _, _> = DynAnyNode::new(graphene_core::ops::AddNode);
|
||||
let node = (pre_node).then(node);
|
||||
|
||||
node.into_type_erased()
|
||||
})
|
||||
},
|
||||
),
|
||||
node.into_type_erased()
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::ops::AddNode", &[concrete!("&u32"), concrete!("u32")]), |proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node: DynAnyNode<AddNode, (&u32, u32), _, _> = DynAnyNode::new(graphene_core::ops::AddNode);
|
||||
let node = (pre_node).then(node);
|
||||
|
||||
node.into_type_erased()
|
||||
})
|
||||
}),
|
||||
(
|
||||
NodeIdentifier::new(
|
||||
"graphene_core::structural::ConsNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("&u32")), Concrete(std::borrow::Cow::Borrowed("u32"))],
|
||||
),
|
||||
NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("&u32"), concrete!("u32")]),
|
||||
|proto_node, stack| {
|
||||
if let ConstructionArgs::Nodes(cons_node_arg) = proto_node.construction_args {
|
||||
stack.push_fn(move |nodes| {
|
||||
|
@ -135,10 +106,7 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new(
|
||||
"graphene_core::structural::ConsNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("u32")), Concrete(std::borrow::Cow::Borrowed("u32"))],
|
||||
),
|
||||
NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("u32"), concrete!("u32")]),
|
||||
|proto_node, stack| {
|
||||
if let ConstructionArgs::Nodes(cons_node_arg) = proto_node.construction_args {
|
||||
stack.push_fn(move |nodes| {
|
||||
|
@ -163,10 +131,7 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
),
|
||||
// TODO: create macro to impl for all types
|
||||
(
|
||||
NodeIdentifier::new(
|
||||
"graphene_core::structural::ConsNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("&u32")), Concrete(std::borrow::Cow::Borrowed("&u32"))],
|
||||
),
|
||||
NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("&u32"), concrete!("&u32")]),
|
||||
|proto_node, stack| {
|
||||
let node_id = proto_node.input.unwrap_node() as usize;
|
||||
if let ConstructionArgs::Nodes(cons_node_arg) = proto_node.construction_args {
|
||||
|
@ -184,31 +149,25 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_core::any::DowncastNode", &[Concrete(std::borrow::Cow::Borrowed("&u32"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node = pre_node.then(graphene_core::ops::IdNode);
|
||||
node.into_type_erased()
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_core::value::ValueNode", &[Concrete(std::borrow::Cow::Borrowed("Any<'_>"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(|_nodes| {
|
||||
if let ConstructionArgs::Value(value) = proto_node.construction_args {
|
||||
let node = FnNode::new(move |_| value.clone().up_box() as Any<'static>);
|
||||
(NodeIdentifier::new("graphene_core::any::DowncastNode", &[concrete!("&u32")]), |proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
let node = pre_node.then(graphene_core::ops::IdNode);
|
||||
node.into_type_erased()
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::value::ValueNode", &[concrete!("Any<'>")]), |proto_node, stack| {
|
||||
stack.push_fn(|_nodes| {
|
||||
if let ConstructionArgs::Value(value) = proto_node.construction_args {
|
||||
let node = FnNode::new(move |_| value.clone().up_box() as Any<'static>);
|
||||
|
||||
node.into_type_erased()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(NodeIdentifier::new("graphene_core::value::ValueNode", &[Type::Generic]), |proto_node, stack| {
|
||||
node.into_type_erased()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::value::ValueNode", &[generic!("T")]), |proto_node, stack| {
|
||||
stack.push_fn(|_nodes| {
|
||||
if let ConstructionArgs::Value(value) = proto_node.construction_args {
|
||||
let node = FnNode::new(move |_| value.clone().up_box() as Any<'static>);
|
||||
|
@ -230,44 +189,38 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
}
|
||||
})
|
||||
}),
|
||||
(
|
||||
NodeIdentifier::new("graphene_core::raster::BrightenColorNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
info!("proto node {:?}", proto_node);
|
||||
stack.push_fn(|nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Brighten Color Node constructed with out brightness input node") };
|
||||
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
|
||||
let input_node: DowncastBothNode<_, (), f32> = DowncastBothNode::new(value_node);
|
||||
let node = DynAnyNode::new(graphene_core::raster::BrightenColorNode::new(input_node));
|
||||
(NodeIdentifier::new("graphene_core::raster::BrightenColorNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
info!("proto node {:?}", proto_node);
|
||||
stack.push_fn(|nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Brighten Color Node constructed with out brightness input node") };
|
||||
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
|
||||
let input_node: DowncastBothNode<_, (), f32> = DowncastBothNode::new(value_node);
|
||||
let node = DynAnyNode::new(graphene_core::raster::BrightenColorNode::new(input_node));
|
||||
|
||||
if let ProtoNodeInput::Node(pre_id) = proto_node.input {
|
||||
let pre_node = nodes.get(pre_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_core::raster::HueShiftColorNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
info!("proto node {:?}", proto_node);
|
||||
stack.push_fn(|nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Hue Shift Color Node constructed with out shift input node") };
|
||||
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
|
||||
let input_node: DowncastBothNode<_, (), f32> = DowncastBothNode::new(value_node);
|
||||
let node = DynAnyNode::new(graphene_core::raster::HueShiftColorNode::new(input_node));
|
||||
if let ProtoNodeInput::Node(pre_id) = proto_node.input {
|
||||
let pre_node = nodes.get(pre_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::raster::HueShiftColorNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
info!("proto node {:?}", proto_node);
|
||||
stack.push_fn(|nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Hue Shift Color Node constructed with out shift input node") };
|
||||
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
|
||||
let input_node: DowncastBothNode<_, (), f32> = DowncastBothNode::new(value_node);
|
||||
let node = DynAnyNode::new(graphene_core::raster::HueShiftColorNode::new(input_node));
|
||||
|
||||
if let ProtoNodeInput::Node(pre_id) = proto_node.input {
|
||||
let pre_node = nodes.get(pre_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
if let ProtoNodeInput::Node(pre_id) = proto_node.input {
|
||||
let pre_node = nodes.get(pre_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::MapImageNode", &[]), |proto_node, stack| {
|
||||
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
|
||||
stack.push_fn(move |nodes| {
|
||||
|
@ -311,28 +264,24 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
}
|
||||
})
|
||||
}),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::HueSaturationNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("HueSaturationNode Node constructed without inputs") };
|
||||
(NodeIdentifier::new("graphene_std::raster::HueSaturationNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("HueSaturationNode Node constructed without inputs") };
|
||||
|
||||
let hue: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let saturation: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[1] as usize).unwrap());
|
||||
let lightness: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[2] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::HueSaturationNode::new(hue, saturation, lightness));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
let hue: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let saturation: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[1] as usize).unwrap());
|
||||
let lightness: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[2] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::HueSaturationNode::new(hue, saturation, lightness));
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::BrightnessContrastNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
NodeIdentifier::new("graphene_std::raster::BrightnessContrastNode", &[concrete!("&TypeErasedNode")]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("BrightnessContrastNode Node constructed without inputs") };
|
||||
|
@ -350,102 +299,81 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::GammaNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("GammaNode Node constructed without inputs") };
|
||||
let gamma: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::GammaNode::new(gamma));
|
||||
(NodeIdentifier::new("graphene_std::raster::GammaNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("GammaNode Node constructed without inputs") };
|
||||
let gamma: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::GammaNode::new(gamma));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::OpacityNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("OpacityNode Node constructed without inputs") };
|
||||
let opacity: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::OpacityNode::new(opacity));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::PosterizeNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Posterize node constructed without inputs") };
|
||||
let value: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::PosterizeNode::new(value));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::ExposureNode", &[Type::Concrete(Cow::Borrowed("&TypeErasedNode"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("ExposureNode constructed without inputs") };
|
||||
let value: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::ExposureNode::new(value));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::ImageNode", &[Concrete(std::borrow::Cow::Borrowed("&str"))]),
|
||||
|_proto_node, stack| {
|
||||
stack.push_fn(|_nodes| {
|
||||
let image = FnNode::new(|s: &str| graphene_std::raster::image_node::<&str>().eval(s).unwrap());
|
||||
let node: DynAnyNode<_, &str, _, _> = DynAnyNode::new(image);
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
})
|
||||
},
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::ExportImageNode", &[Concrete(std::borrow::Cow::Borrowed("&str"))]),
|
||||
|proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::OpacityNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("OpacityNode Node constructed without inputs") };
|
||||
let opacity: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::OpacityNode::new(opacity));
|
||||
|
||||
let image = FnNode::new(|input: (Image, &str)| graphene_std::raster::export_image_node().eval(input).unwrap());
|
||||
let node: DynAnyNode<_, (Image, &str), _, _> = DynAnyNode::new(image);
|
||||
let node = (pre_node).then(node);
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
})
|
||||
},
|
||||
),
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::PosterizeNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Posterize node constructed without inputs") };
|
||||
let value: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::PosterizeNode::new(value));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::ExposureNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("ExposureNode constructed without inputs") };
|
||||
let value: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
|
||||
let node = DynAnyNode::new(graphene_std::raster::ExposureNode::new(value));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]), |_proto_node, stack| {
|
||||
stack.push_fn(|_nodes| {
|
||||
let image = FnNode::new(|s: &str| graphene_std::raster::image_node::<&str>().eval(s).unwrap());
|
||||
let node: DynAnyNode<_, &str, _, _> = DynAnyNode::new(image);
|
||||
node.into_type_erased()
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::ExportImageNode", &[concrete!("&str")]), |proto_node, stack| {
|
||||
stack.push_fn(|nodes| {
|
||||
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
|
||||
|
||||
let image = FnNode::new(|input: (Image, &str)| graphene_std::raster::export_image_node().eval(input).unwrap());
|
||||
let node: DynAnyNode<_, (Image, &str), _, _> = DynAnyNode::new(image);
|
||||
let node = (pre_node).then(node);
|
||||
node.into_type_erased()
|
||||
})
|
||||
}),
|
||||
(
|
||||
NodeIdentifier::new(
|
||||
"graphene_core::structural::ConsNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("Image")), Concrete(std::borrow::Cow::Borrowed("&str"))],
|
||||
),
|
||||
NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("Image"), concrete!("&str")]),
|
||||
|proto_node, stack| {
|
||||
let node_id = proto_node.input.unwrap_node() as usize;
|
||||
if let ConstructionArgs::Nodes(cons_node_arg) = proto_node.construction_args {
|
||||
|
@ -504,11 +432,19 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
}),
|
||||
];
|
||||
|
||||
pub fn push_node(proto_node: ProtoNode, stack: &FixedSizeStack<TypeErasedNode<'static>>) {
|
||||
pub fn push_node<'a>(proto_node: ProtoNode, stack: &'a FixedSizeStack<TypeErasedNode<'static>>) {
|
||||
if let Some((_id, f)) = NODE_REGISTRY.iter().find(|(id, _)| *id == proto_node.identifier) {
|
||||
f(proto_node, stack);
|
||||
} else {
|
||||
panic!("NodeImplementation: {:?} not found in Registry", proto_node.identifier);
|
||||
let other_types = NODE_REGISTRY
|
||||
.iter()
|
||||
.map(|(id, _)| id)
|
||||
.filter(|id| id.name.as_ref() == proto_node.identifier.name.as_ref())
|
||||
.collect::<Vec<_>>();
|
||||
panic!(
|
||||
"NodeImplementation: {:?} not found in Registry types for which the node is implemented:\n {:#?}",
|
||||
proto_node.identifier, other_types
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -530,20 +466,14 @@ mod protograph_testing {
|
|||
let cons_protonode = ProtoNode {
|
||||
construction_args: ConstructionArgs::Nodes(vec![1]),
|
||||
input: ProtoNodeInput::Node(0),
|
||||
identifier: NodeIdentifier::new(
|
||||
"graphene_core::structural::ConsNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("u32")), Concrete(std::borrow::Cow::Borrowed("u32"))],
|
||||
),
|
||||
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("u32"), concrete!("u32")]),
|
||||
};
|
||||
push_node(cons_protonode, &stack);
|
||||
|
||||
let add_protonode = ProtoNode {
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
input: ProtoNodeInput::Node(2),
|
||||
identifier: NodeIdentifier::new(
|
||||
"graphene_core::ops::AddNode",
|
||||
&[Concrete(std::borrow::Cow::Borrowed("u32")), Concrete(std::borrow::Cow::Borrowed("u32"))],
|
||||
),
|
||||
identifier: NodeIdentifier::new("graphene_core::ops::AddNode", &[concrete!("u32"), concrete!("u32")]),
|
||||
};
|
||||
push_node(add_protonode, &stack);
|
||||
|
||||
|
@ -576,7 +506,7 @@ mod protograph_testing {
|
|||
let image_protonode = ProtoNode {
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
input: ProtoNodeInput::None,
|
||||
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[Concrete(std::borrow::Cow::Borrowed("&str"))]),
|
||||
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]),
|
||||
};
|
||||
push_node(image_protonode, &stack);
|
||||
|
||||
|
@ -591,7 +521,7 @@ mod protograph_testing {
|
|||
let image_protonode = ProtoNode {
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
input: ProtoNodeInput::None,
|
||||
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[Concrete(std::borrow::Cow::Borrowed("&str"))]),
|
||||
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]),
|
||||
};
|
||||
push_node(image_protonode, &stack);
|
||||
|
||||
|
|
|
@ -1,8 +1,22 @@
|
|||
use std::borrow::Cow;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use crate::document::value;
|
||||
use crate::document::NodeId;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! concrete {
|
||||
($type:expr) => {
|
||||
Type::Concrete(std::borrow::Cow::Borrowed($type))
|
||||
};
|
||||
}
|
||||
#[macro_export]
|
||||
macro_rules! generic {
|
||||
($type:expr) => {
|
||||
Type::Generic(std::borrow::Cow::Borrowed($type))
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct NodeIdentifier {
|
||||
|
@ -10,10 +24,26 @@ pub struct NodeIdentifier {
|
|||
pub types: std::borrow::Cow<'static, [Type]>,
|
||||
}
|
||||
|
||||
impl NodeIdentifier {
|
||||
pub fn fully_qualified_name(&self) -> String {
|
||||
let mut name = String::new();
|
||||
name.push_str(self.name.as_ref());
|
||||
name.push('<');
|
||||
for t in self.types.as_ref() {
|
||||
name.push_str(t.to_string().as_str());
|
||||
name.push_str(", ");
|
||||
}
|
||||
name.pop();
|
||||
name.pop();
|
||||
name.push('>');
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum Type {
|
||||
Generic,
|
||||
Generic(std::borrow::Cow<'static, str>),
|
||||
Concrete(std::borrow::Cow<'static, str>),
|
||||
}
|
||||
|
||||
|
@ -22,6 +52,14 @@ impl From<&'static str> for Type {
|
|||
Type::Concrete(std::borrow::Cow::Borrowed(s))
|
||||
}
|
||||
}
|
||||
impl std::fmt::Display for Type {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Type::Generic(name) => write!(f, "{}", name),
|
||||
Type::Concrete(name) => write!(f, "{}", name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub const fn from_str(concrete: &'static str) -> Self {
|
||||
|
@ -70,6 +108,15 @@ impl PartialEq for ConstructionArgs {
|
|||
}
|
||||
}
|
||||
|
||||
impl ConstructionArgs {
|
||||
pub fn new_function_args(&self) -> Vec<String> {
|
||||
match self {
|
||||
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("n{}", n)).collect(),
|
||||
ConstructionArgs::Value(value) => vec![format!("{:?}", value)],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct ProtoNode {
|
||||
pub construction_args: ConstructionArgs,
|
||||
|
@ -77,7 +124,7 @@ pub struct ProtoNode {
|
|||
pub identifier: NodeIdentifier,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum ProtoNodeInput {
|
||||
None,
|
||||
#[default]
|
||||
|
@ -97,7 +144,7 @@ impl ProtoNodeInput {
|
|||
impl ProtoNode {
|
||||
pub fn value(value: ConstructionArgs) -> Self {
|
||||
Self {
|
||||
identifier: NodeIdentifier::new("graphene_core::value::ValueNode", &[Type::Generic]),
|
||||
identifier: NodeIdentifier::new("graphene_core::value::ValueNode", &[Type::Generic(Cow::Borrowed("T"))]),
|
||||
construction_args: value,
|
||||
input: ProtoNodeInput::None,
|
||||
}
|
||||
|
@ -165,6 +212,41 @@ impl ProtoNetwork {
|
|||
edges
|
||||
}
|
||||
|
||||
pub fn resolve_inputs(&mut self) {
|
||||
while !self.resolve_inputs_impl() {}
|
||||
}
|
||||
fn resolve_inputs_impl(&mut self) -> bool {
|
||||
self.reorder_ids();
|
||||
|
||||
let mut lookup = self.nodes.iter().map(|(id, _)| (*id, *id)).collect::<HashMap<_, _>>();
|
||||
let compose_node_id = self.nodes.len() as NodeId;
|
||||
let inputs = self.nodes.iter().map(|(_, node)| node.input).collect::<Vec<_>>();
|
||||
|
||||
if let Some((input_node, id, input)) = self.nodes.iter_mut().find_map(|(id, node)| {
|
||||
if let ProtoNodeInput::Node(input_node) = node.input {
|
||||
node.input = ProtoNodeInput::None;
|
||||
let pre_node_input = inputs.get(input_node as usize).expect("input node should exist");
|
||||
Some((input_node, *id, *pre_node_input))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
lookup.insert(id, compose_node_id);
|
||||
self.replace_node_references(&lookup);
|
||||
self.nodes.push((
|
||||
compose_node_id,
|
||||
ProtoNode {
|
||||
identifier: NodeIdentifier::new("graphene_core::structural::ComposeNode", &[generic!("T"), Type::Generic(Cow::Borrowed("U"))]),
|
||||
construction_args: ConstructionArgs::Nodes(vec![input_node, id]),
|
||||
input,
|
||||
},
|
||||
));
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// Based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
|
||||
// This approach excludes nodes that are not connected
|
||||
pub fn topological_sort(&self) -> Vec<NodeId> {
|
||||
|
@ -228,14 +310,28 @@ impl ProtoNetwork {
|
|||
info!("Order {order:?}");
|
||||
self.nodes = order
|
||||
.iter()
|
||||
.map(|id| {
|
||||
let mut node = self.nodes.swap_remove(self.nodes.iter().position(|(test_id, _)| test_id == id).unwrap()).1;
|
||||
node.map_ids(|id| *lookup.get(&id).unwrap());
|
||||
(*lookup.get(id).unwrap(), node)
|
||||
.enumerate()
|
||||
.map(|(pos, id)| {
|
||||
let node = self.nodes.swap_remove(self.nodes.iter().position(|(test_id, _)| test_id == id).unwrap()).1;
|
||||
(pos as NodeId, node)
|
||||
})
|
||||
.collect();
|
||||
self.replace_node_references(&lookup);
|
||||
assert_eq!(order.len(), self.nodes.len());
|
||||
}
|
||||
|
||||
fn replace_node_references(&mut self, lookup: &HashMap<u64, u64>) {
|
||||
self.nodes.iter_mut().for_each(|(sid, node)| {
|
||||
node.map_ids(|id| *lookup.get(&id).expect("node not found in lookup table"));
|
||||
});
|
||||
self.inputs = self.inputs.iter().map(|id| *lookup.get(id).unwrap()).collect();
|
||||
self.output = *lookup.get(&self.output).unwrap();
|
||||
}
|
||||
|
||||
fn replace_ids_with_lookup(&mut self, lookup: HashMap<u64, u64>) {
|
||||
self.nodes.iter_mut().for_each(|(id, _)| *id = *lookup.get(id).unwrap());
|
||||
self.replace_node_references(&lookup);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -246,6 +342,51 @@ mod test {
|
|||
|
||||
#[test]
|
||||
fn topological_sort() {
|
||||
let construction_network = test_network();
|
||||
let sorted = construction_network.topological_sort();
|
||||
|
||||
println!("{:#?}", sorted);
|
||||
assert_eq!(sorted, vec![14, 10, 11, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn id_reordering() {
|
||||
let mut construction_network = test_network();
|
||||
construction_network.reorder_ids();
|
||||
let sorted = construction_network.topological_sort();
|
||||
println!("nodes: {:#?}", construction_network.nodes);
|
||||
assert_eq!(sorted, vec![0, 1, 2, 3]);
|
||||
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
|
||||
println!("{:#?}", ids);
|
||||
println!("nodes: {:#?}", construction_network.nodes);
|
||||
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
|
||||
assert_eq!(ids, vec![0, 1, 2, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn id_reordering_idempotent() {
|
||||
let mut construction_network = test_network();
|
||||
construction_network.reorder_ids();
|
||||
construction_network.reorder_ids();
|
||||
let sorted = construction_network.topological_sort();
|
||||
assert_eq!(sorted, vec![0, 1, 2, 3]);
|
||||
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
|
||||
println!("{:#?}", ids);
|
||||
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
|
||||
assert_eq!(ids, vec![0, 1, 2, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn input_resolution() {
|
||||
let mut construction_network = test_network();
|
||||
construction_network.resolve_inputs();
|
||||
println!("{:#?}", construction_network);
|
||||
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
|
||||
assert_eq!(construction_network.nodes.len(), 6);
|
||||
assert_eq!(construction_network.nodes[5].1.construction_args, ConstructionArgs::Nodes(vec![3, 4]));
|
||||
}
|
||||
|
||||
fn test_network() -> ProtoNetwork {
|
||||
let construction_network = ProtoNetwork {
|
||||
inputs: vec![10],
|
||||
output: 1,
|
||||
|
@ -294,9 +435,6 @@ mod test {
|
|||
.into_iter()
|
||||
.collect(),
|
||||
};
|
||||
let sorted = construction_network.topological_sort();
|
||||
|
||||
println!("{:#?}", sorted);
|
||||
assert_eq!(sorted, vec![14, 10, 11, 1]);
|
||||
construction_network
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ default = ["derive", "memoization"]
|
|||
|
||||
|
||||
[dependencies]
|
||||
graphene-core = {path = "../gcore", features = ["async", "std"]}
|
||||
graphene-core = {path = "../gcore", features = ["async", "std"], default-features = false}
|
||||
borrow_stack = {path = "../borrow_stack"}
|
||||
dyn-any = {path = "../../libraries/dyn-any", features = ["derive"]}
|
||||
graph-proc-macros = {path = "../proc-macro", optional = true}
|
||||
|
|
|
@ -14,99 +14,3 @@ pub mod vector;
|
|||
pub mod any;
|
||||
|
||||
pub use graphene_core::*;
|
||||
|
||||
use quote::quote;
|
||||
use syn::{Expr, ExprPath, Type};
|
||||
|
||||
/// Given a Node call tree, construct a function
|
||||
/// that takes an input tuple and evaluates the call graph
|
||||
/// on the gpu an fn node is constructed that takes a value
|
||||
/// node as input
|
||||
pub struct NodeGraph {
|
||||
/// Collection of nodes with their corresponding inputs.
|
||||
/// The first node always always has to be an Input Node.
|
||||
pub nodes: Vec<NodeKind>,
|
||||
pub output: Type,
|
||||
pub input: Type,
|
||||
}
|
||||
|
||||
pub enum NodeKind {
|
||||
Value(Expr),
|
||||
Input,
|
||||
Node(ExprPath, Vec<usize>),
|
||||
}
|
||||
|
||||
impl NodeGraph {
|
||||
pub fn serialize_function(&self) -> proc_macro2::TokenStream {
|
||||
let output_type = &self.output;
|
||||
let input_type = &self.input;
|
||||
|
||||
fn nid(id: &usize) -> syn::Ident {
|
||||
let str = format!("n{id}");
|
||||
syn::Ident::new(str.as_str(), proc_macro2::Span::call_site())
|
||||
}
|
||||
let mut nodes = Vec::new();
|
||||
for (ref id, node) in self.nodes.iter().enumerate() {
|
||||
let id = nid(id).clone();
|
||||
let line = match node {
|
||||
NodeKind::Value(val) => {
|
||||
quote! {let #id = graphene_core::value::ValueNode::new(#val);}
|
||||
}
|
||||
NodeKind::Node(node, ids) => {
|
||||
let ids = ids.iter().map(nid).collect::<Vec<_>>();
|
||||
quote! {let #id = #node::new((#(&#ids),*));}
|
||||
}
|
||||
NodeKind::Input => {
|
||||
quote! { let n0 = graphene_core::value::ValueNode::new(input);}
|
||||
}
|
||||
};
|
||||
nodes.push(line)
|
||||
}
|
||||
let last_id = self.nodes.len() - 1;
|
||||
let last_id = nid(&last_id);
|
||||
let ret = quote! { #last_id.eval() };
|
||||
let function = quote! {
|
||||
fn node_graph(input: #input_type) -> #output_type {
|
||||
#(#nodes)*
|
||||
#ret
|
||||
}
|
||||
};
|
||||
function
|
||||
}
|
||||
pub fn serialize_gpu(&self, name: &str) -> proc_macro2::TokenStream {
|
||||
let function = self.serialize_function();
|
||||
let output_type = &self.output;
|
||||
let input_type = &self.input;
|
||||
|
||||
quote! {
|
||||
#[cfg(target_arch = "spirv")]
|
||||
pub mod gpu {
|
||||
//#![deny(warnings)]
|
||||
#[repr(C)]
|
||||
pub struct PushConsts {
|
||||
n: u32,
|
||||
node: u32,
|
||||
}
|
||||
use super::*;
|
||||
|
||||
use spirv_std::glam::UVec3;
|
||||
|
||||
#[allow(unused)]
|
||||
#[spirv(compute(threads(64)))]
|
||||
pub fn #name(
|
||||
#[spirv(global_invocation_id)] global_id: UVec3,
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[#input_type],
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [#output_type],
|
||||
#[spirv(push_constant)] push_consts: &PushConsts,
|
||||
) {
|
||||
#function
|
||||
let gid = global_id.x as usize;
|
||||
// Only process up to n, which is the length of the buffers.
|
||||
if global_id.x < push_consts.n {
|
||||
y[gid] = node_graph(a[gid]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -89,10 +89,8 @@ mod mul {
|
|||
// }
|
||||
|
||||
fn main() {
|
||||
use graphene_std::*;
|
||||
use quote::quote;
|
||||
// use syn::parse::Parse;
|
||||
let nodes = vec![
|
||||
/*let nodes = vec![
|
||||
NodeKind::Input,
|
||||
NodeKind::Value(syn::parse_quote!(1u32)),
|
||||
NodeKind::Node(syn::parse_quote!(graphene_core::ops::AddNode), vec![0, 0]),
|
||||
|
@ -105,7 +103,7 @@ fn main() {
|
|||
nodes,
|
||||
input: syn::Type::Verbatim(quote! {u32}),
|
||||
output: syn::Type::Verbatim(quote! {u32}),
|
||||
};
|
||||
};*/
|
||||
|
||||
//let pretty = pretty_token_stream::Pretty::new(nodegraph.serialize_gpu("add"));
|
||||
//pretty.print();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue