Restructure GPU execution to model GPU pipelines in the node graph (#1088)

* Start implementing GpuExecutor for wgpu

* Implement read_output_buffer function

* Implement extraction node in the compiler

* Generate type annotations during shader compilation

* Start adding node wrapprs for graph execution api

* Wrap more of the api in nodes

* Restructure Pipeline to accept arbitrary shader inputs

* Adapt nodes to new trait definitions

* Start implementing gpu-compiler trait

* Adapt shader generation

* Hardstuck on pointer casts

* Pass nodes as references in gpu code to avoid zsts

* Update gcore to compile on the gpu

* Fix color doc tests

* Impl Node for node refs
This commit is contained in:
Dennis Kobert 2023-04-23 10:18:31 +02:00 committed by Keavon Chambers
parent 161bbc62b4
commit bdc1ef926a
43 changed files with 1874 additions and 515 deletions

View file

@ -1,8 +1,15 @@
use gpu_executor::ShaderIO;
use graph_craft::{proto::ProtoNetwork, Type};
use serde::{Deserialize, Serialize};
use std::io::Write;
pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &str, output_type: &str, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
let serialized_graph = serde_json::to_string(&network)?;
pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
let serialized_graph = serde_json::to_string(&gpu_executor::CompileRequest {
network: request.network.clone(),
io: request.shader_io.clone(),
})?;
let features = "";
#[cfg(feature = "profiling")]
let features = "profiling";
@ -19,9 +26,6 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
.envs(non_cargo_env_vars)
.arg("--features")
.arg(features)
.arg("--")
.arg(input_type)
.arg(output_type)
// TODO: handle None case properly
.arg(compile_dir.unwrap())
.stdin(std::process::Stdio::piped())
@ -38,16 +42,27 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct CompileRequest {
network: graph_craft::document::NodeNetwork,
input_type: String,
output_type: String,
network: graph_craft::proto::ProtoNetwork,
input_types: Vec<Type>,
output_type: Type,
shader_io: ShaderIO,
}
impl CompileRequest {
pub fn new(network: graph_craft::document::NodeNetwork, input_type: String, output_type: String) -> Self {
Self { network, input_type, output_type }
pub fn new(network: ProtoNetwork, input_types: Vec<Type>, output_type: Type, io: ShaderIO) -> Self {
// TODO: add type checking
// for (input, buffer) in input_types.iter().zip(io.inputs.iter()) {
// assert_eq!(input, &buffer.ty());
// }
// assert_eq!(output_type, io.output.ty());
Self {
network,
input_types,
output_type,
shader_io: io,
}
}
pub fn compile(&self, compile_dir: &str, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
compile_spirv(&self.network, &self.input_type, &self.output_type, Some(compile_dir), manifest_path)
compile_spirv(self, Some(compile_dir), manifest_path)
}
}