Compile node graph description to GPU code

This commit is contained in:
Dennis 2022-06-08 09:52:58 +02:00 committed by Keavon Chambers
parent 998f37d1b0
commit e84b9bd5bd
10 changed files with 498 additions and 76 deletions

View file

@ -26,3 +26,7 @@ ide_db = { version = "*", package = "ra_ap_ide_db" , optional = true }
storage-map = { version = "*", optional = true }
lock_api = { version= "*", optional = true }
parking_lot = { version = "*", optional = true }
pretty-token-stream = {path = "../../pretty-token-stream"}
syn = {version = "1.0", default-features = false, features = ["parsing", "printing"]}
proc-macro2 = {version = "1.0", default-features = false, features = ["proc-macro"]}
quote = {version = "1.0", default-features = false }

View file

@ -16,3 +16,98 @@ pub trait DynamicInput<'n> {
fn set_kwarg_by_name(&mut self, name: &str, value: DynAnyNode<'n>);
fn set_arg_by_index(&mut self, index: usize, value: DynAnyNode<'n>);
}
use quote::quote;
use syn::{Expr, ExprPath, Type};
/// Given a Node call tree, construct a function
/// that takes an input tuple and evaluates the call graph
/// on the gpu an fn node is constructed that takes a value
/// node as input
pub struct NodeGraph {
/// Collection of nodes with their corresponding inputs.
/// The first node always always has to be an Input Node.
pub nodes: Vec<NodeKind>,
pub output: Type,
pub input: Type,
}
pub enum NodeKind {
Value(Expr),
Input,
Node(ExprPath, Vec<usize>),
}
impl NodeGraph {
pub fn serialize_function(&self) -> proc_macro2::TokenStream {
let output_type = &self.output;
let input_type = &self.input;
fn nid(id: &usize) -> syn::Ident {
let str = format!("n{id}");
syn::Ident::new(str.as_str(), proc_macro2::Span::call_site())
}
let mut nodes = Vec::new();
for (ref id, node) in self.nodes.iter().enumerate() {
let id = nid(id).clone();
let line = match node {
NodeKind::Value(val) => {
quote! {let #id = graphene_core::value::ValueNode::new(#val);}
}
NodeKind::Node(node, ids) => {
let ids = ids.iter().map(nid).collect::<Vec<_>>();
quote! {let #id = #node::new((#(&#ids),*));}
}
NodeKind::Input => {
quote! { let n0 = graphene_core::value::ValueNode::new(input);}
}
};
nodes.push(line)
}
let last_id = self.nodes.len() - 1;
let last_id = nid(&last_id);
let ret = quote! { #last_id.eval() };
let function = quote! {
fn node_graph(input: #input_type) -> #output_type {
#(#nodes)*
#ret
}
};
function
}
pub fn serialize_gpu(&self, name: &str) -> proc_macro2::TokenStream {
let function = self.serialize_function();
let output_type = &self.output;
let input_type = &self.input;
quote! {
#[cfg(target_arch = "spirv")]
pub mod gpu {
//#![deny(warnings)]
#[repr(C)]
pub struct PushConsts {
n: u32,
node: u32,
}
use super::*;
use spirv_std::glam::UVec3;
#[allow(unused)]
#[spirv(compute(threads(64)))]
pub fn #name(
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[#input_type],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [#output_type],
#[spirv(push_constant)] push_consts: &PushConsts,
) {
#function
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]);
}
}
}
}
}
}

View file

@ -91,41 +91,61 @@ impl<'n> NodeStore<'n> {
}
fn main() {
use dyn_any::{downcast_ref, DynAny, StaticType};
//let mut mul = mul::MulNode::new();
let mut stack: borrow_stack::FixedSizeStack<Box<dyn Node<'_, Output = &dyn DynAny>>> =
borrow_stack::FixedSizeStack::new(42);
unsafe { stack.push(Box::new(AnyValueNode::new(1f32))) };
//let node = unsafe { stack.get(0) };
//let boxed = Box::new(StorageNode::new(node));
//unsafe { stack.push(boxed) };
let result = unsafe { &stack.get()[0] }.eval();
dbg!(downcast_ref::<f32>(result));
/*unsafe {
stack
.push(Box::new(AnyRefNode::new(stack.get(0).as_ref()))
as Box<dyn Node<(), Output = &dyn DynAny>>)
};*/
let f = (3.2f32, 3.1f32);
let a = ValueNode::new(1.);
let id = std::any::TypeId::of::<&f32>();
let any_a = AnyRefNode::new(&a);
/*let _mul2 = mul::MulNodeInput {
a: None,
b: Some(&any_a),
};
let mut mul2 = mul::new!();
//let cached = memo::CacheNode::new(&mul1);
//let foo = value::AnyRefNode::new(&cached);
mul2.set_arg_by_index(0, &any_a);*/
let int = value::IntNode::<32>;
Node::eval(&int);
println!("{}", Node::eval(&int));
//let _add: u32 = ops::AddNode::<u32>::default().eval((int.exec(), int.exec()));
//let fnode = generic::FnNode::new(|(a, b): &(i32, i32)| a - b);
//let sub = fnode.any(&("a", 2));
//let cache = memo::CacheNode::new(&fnode);
//let cached_result = cache.eval(&(2, 3));
use graphene_std::*;
use quote::quote;
use syn::parse::Parse;
let nodes = vec![
NodeKind::Input,
NodeKind::Value(syn::parse_quote!(1u32)),
NodeKind::Node(syn::parse_quote!(graphene_core::ops::AddNode), vec![0, 0]),
];
//println!("{}", node_graph(1));
let nodegraph = NodeGraph {
nodes,
input: syn::Type::Verbatim(quote! {u32}),
output: syn::Type::Verbatim(quote! {u32}),
};
let pretty = pretty_token_stream::Pretty::new(nodegraph.serialize_gpu("add"));
pretty.print();
/*
use dyn_any::{downcast_ref, DynAny, StaticType};
//let mut mul = mul::MulNode::new();
let mut stack: borrow_stack::FixedSizeStack<Box<dyn Node<'_, Output = &dyn DynAny>>> =
borrow_stack::FixedSizeStack::new(42);
unsafe { stack.push(Box::new(AnyValueNode::new(1f32))) };
//let node = unsafe { stack.get(0) };
//let boxed = Box::new(StorageNode::new(node));
//unsafe { stack.push(boxed) };
let result = unsafe { &stack.get()[0] }.eval();
dbg!(downcast_ref::<f32>(result));
/*unsafe {
stack
.push(Box::new(AnyRefNode::new(stack.get(0).as_ref()))
as Box<dyn Node<(), Output = &dyn DynAny>>)
};*/
let f = (3.2f32, 3.1f32);
let a = ValueNode::new(1.);
let id = std::any::TypeId::of::<&f32>();
let any_a = AnyRefNode::new(&a);
/*let _mul2 = mul::MulNodeInput {
a: None,
b: Some(&any_a),
};
let mut mul2 = mul::new!();
//let cached = memo::CacheNode::new(&mul1);
//let foo = value::AnyRefNode::new(&cached);
mul2.set_arg_by_index(0, &any_a);*/
let int = value::IntNode::<32>;
Node::eval(&int);
println!("{}", Node::eval(&int));
//let _add: u32 = ops::AddNode::<u32>::default().eval((int.exec(), int.exec()));
//let fnode = generic::FnNode::new(|(a, b): &(i32, i32)| a - b);
//let sub = fnode.any(&("a", 2));
//let cache = memo::CacheNode::new(&fnode);
//let cached_result = cache.eval(&(2, 3));
*/
//println!("{}", cached_result)
}

View file

@ -8,7 +8,10 @@ pub struct CacheNode<'n, CachedNode: Node<'n>> {
cache: OnceCell<CachedNode::Output>,
_phantom: PhantomData<&'n ()>,
}
impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode> {
impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode>
where
CashedNode::Output: 'n,
{
type Output = &'n CashedNode::Output;
fn eval(&'n self) -> Self::Output {
self.cache.get_or_init(|| self.node.eval())

View file

@ -24,8 +24,7 @@ pub struct StorageNode<'n>(&'n dyn Node<'n, Output = &'n dyn DynAny<'n>>);
impl<'n> Node<'n> for StorageNode<'n> {
type Output = &'n (dyn DynAny<'n>);
fn eval(&'n self) -> Self::Output {
let value = self.0.eval();
value
self.0.eval()
}
}
impl<'n> StorageNode<'n> {