Web gpu execution MVP

Ready infrastructure for wgpu experimentation

Start implementing simple gpu test case

Fix Extract Node not working with nested networks

Convert inputs for extracted node to network inputs

Fix missing cors headers

Feature gate gcore to make it once again no-std compatible

Add skeleton structure gpu shader

Work on gpu node graph output saving

Fix Get and Set nodes

Fix storage nodes

Fix shader construction errors -> spirv errors

Add unsafe version

Add once cell node

Web gpu execution MVP
This commit is contained in:
Dennis Kobert 2023-05-18 19:47:57 +02:00 committed by Keavon Chambers
parent 259078c847
commit 7a254122c3
32 changed files with 1399 additions and 534 deletions

View file

@ -26,7 +26,7 @@ impl Metadata {
}
}
pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> {
pub fn create_files(metadata: &Metadata, networks: &[ProtoNetwork], compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> {
let src = compile_dir.join("src");
let cargo_file = compile_dir.join("Cargo.toml");
let cargo_toml = create_cargo_toml(metadata)?;
@ -46,7 +46,7 @@ pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &P
}
}
let lib = src.join("lib.rs");
let shader = serialize_gpu(network, io)?;
let shader = serialize_gpu(networks, io)?;
eprintln!("{}", shader);
std::fs::write(lib, shader)?;
Ok(())
@ -67,20 +67,21 @@ fn constant_attribute(constant: &GPUConstant) -> &'static str {
}
}
pub fn construct_argument(input: &ShaderInput<()>, position: u32) -> String {
match input {
ShaderInput::Constant(constant) => format!("#[spirv({})] i{}: {},", constant_attribute(constant), position, constant.ty()),
pub fn construct_argument(input: &ShaderInput<()>, position: u32, binding_offset: u32) -> String {
let line = match input {
ShaderInput::Constant(constant) => format!("#[spirv({})] i{}: {}", constant_attribute(constant), position, constant.ty()),
ShaderInput::UniformBuffer(_, ty) => {
format!("#[spirv(uniform, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,)
format!("#[spirv(uniform, descriptor_set = 0, binding = {})] i{}: &[{}]", position + binding_offset, position, ty,)
}
ShaderInput::StorageBuffer(_, ty) | ShaderInput::ReadBackBuffer(_, ty) => {
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,)
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &[{}]", position + binding_offset, position, ty,)
}
ShaderInput::OutputBuffer(_, ty) => {
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &mut[{}]", position, position, ty,)
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] o{}: &mut[{}]", position + binding_offset, position, ty,)
}
ShaderInput::WorkGroupMemory(_, ty) => format!("#[spirv(workgroup_memory] i{}: {}", position, ty,),
}
};
line.replace("glam::u32::uvec3::UVec3", "spirv_std::glam::UVec3")
}
struct GpuCompiler {
@ -88,10 +89,10 @@ struct GpuCompiler {
}
impl SpirVCompiler for GpuCompiler {
fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> anyhow::Result<gpu_executor::Shader> {
fn compile(&self, networks: &[ProtoNetwork], io: &ShaderIO) -> anyhow::Result<gpu_executor::Shader> {
let metadata = Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
create_files(&metadata, &network, &self.compile_dir, io)?;
create_files(&metadata, networks, &self.compile_dir, io)?;
let result = compile(&self.compile_dir)?;
let bytes = std::fs::read(result.module.unwrap_single())?;
@ -105,50 +106,80 @@ impl SpirVCompiler for GpuCompiler {
}
}
pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result<String> {
pub fn serialize_gpu(networks: &[ProtoNetwork], io: &ShaderIO) -> anyhow::Result<String> {
fn nid(id: &u64) -> String {
format!("n{id}")
}
dbg!(&network);
dbg!(&io);
let inputs = io.inputs.iter().enumerate().map(|(i, input)| construct_argument(input, i as u32)).collect::<Vec<_>>();
let mut inputs = io
.inputs
.iter()
.filter(|x| !x.is_output())
.enumerate()
.map(|(i, input)| construct_argument(input, i as u32, 0))
.collect::<Vec<_>>();
let offset = inputs.len() as u32;
inputs.extend(io.inputs.iter().filter(|x| x.is_output()).enumerate().map(|(i, input)| construct_argument(input, i as u32, offset)));
let mut nodes = Vec::new();
let mut input_nodes = Vec::new();
#[derive(serde::Serialize)]
struct Node {
id: String,
fqn: String,
args: Vec<String>,
}
for id in network.inputs.iter() {
let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else {
let mut output_nodes = Vec::new();
for network in networks {
dbg!(&network);
//assert_eq!(network.inputs.len(), io.inputs.iter().filter(|x| !x.is_output()).count());
#[derive(serde::Serialize, Debug)]
struct Node {
id: String,
index: usize,
fqn: String,
args: Vec<String>,
}
for (i, id) in network.inputs.iter().enumerate() {
let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else {
anyhow::bail!("Input node not found");
};
let fqn = &node.identifier.name;
let id = nid(id);
input_nodes.push(Node {
id,
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(),
});
}
for (ref id, node) in network.nodes.iter() {
if network.inputs.contains(id) {
continue;
let fqn = &node.identifier.name;
let id = nid(id);
let node = Node {
id: id.clone(),
index: i,
fqn: fqn.to_string().split('<').next().unwrap().to_owned(),
args: node.construction_args.new_function_args(),
};
dbg!(&node);
if !io.inputs[i].is_output() {
if input_nodes.iter().any(|x: &Node| x.id == id) {
continue;
}
input_nodes.push(node);
}
}
let fqn = &node.identifier.name;
let id = nid(id);
for (ref id, node) in network.nodes.iter() {
if network.inputs.contains(id) {
continue;
}
nodes.push(Node {
id,
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(),
});
let fqn = &node.identifier.name;
let id = nid(id);
if nodes.iter().any(|x: &Node| x.id == id) {
continue;
}
nodes.push(Node {
id,
index: 0,
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(),
});
}
let output = nid(&network.output);
output_nodes.push(output);
}
dbg!(&input_nodes);
let template = include_str!("templates/spirv-template.rs");
let mut tera = tera::Tera::default();
@ -156,8 +187,8 @@ pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result<St
let mut context = Context::new();
context.insert("inputs", &inputs);
context.insert("input_nodes", &input_nodes);
context.insert("output_nodes", &output_nodes);
context.insert("nodes", &nodes);
context.insert("last_node", &nid(&network.output));
context.insert("compute_threads", &64);
Ok(tera.render("spirv", &context)?)
}
@ -171,9 +202,9 @@ pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder
.preserve_bindings(true)
.release(true)
.spirv_metadata(SpirvMetadata::Full)
.extra_arg("no-early-report-zombies")
.extra_arg("no-infer-storage-classes")
.extra_arg("spirt-passes=qptr")
//.extra_arg("no-early-report-zombies")
//.extra_arg("no-infer-storage-classes")
//.extra_arg("spirt-passes=qptr")
.build()?;
Ok(result)