Add type checking to the node graph (#1025)

* Implement type inference

Add type hints to node trait

Add type annotation infrastructure

Refactor type ascription infrastructure

Run cargo fix

Insert infer types stub

Remove types from node identifier

* Implement covariance

* Disable rejection of generic inputs + parameters

* Fix lints

* Extend type checking to cover Network inputs

* Implement generic specialization

* Relax covariance rules

* Fix type annotations for TypErasedComposeNode

* Fix type checking errors

* Keep connection information during node resolution
* Fix TypeDescriptor PartialEq implementation

* Apply review suggestions

* Add documentation to type inference

* Add Imaginate node to document node types

* Fix whitespace in macros

* Add types to imaginate node

* Fix type declaration for imaginate node + add console logging

* Use fully qualified type names as fallback during comparison

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2023-02-15 23:31:30 +01:00 committed by Keavon Chambers
parent a64c856ec4
commit 5dab7de68d
25 changed files with 1365 additions and 1008 deletions

View file

@ -88,16 +88,16 @@ pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str
use spirv_builder::{MetadataPrintout, SpirvBuilder, SpirvMetadata};
pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder::SpirvBuilderError> {
dbg!(&dir);
dbg!(&dir);
let result = SpirvBuilder::new(dir, "spirv-unknown-spv1.5")
.print_metadata(MetadataPrintout::DependencyOnly)
.multimodule(false)
.preserve_bindings(true)
.release(true)
//.relax_struct_store(true)
//.relax_block_layout(true)
.spirv_metadata(SpirvMetadata::Full)
.build()?;
.print_metadata(MetadataPrintout::DependencyOnly)
.multimodule(false)
.preserve_bindings(true)
.release(true)
//.relax_struct_store(true)
//.relax_block_layout(true)
.spirv_metadata(SpirvMetadata::Full)
.build()?;
Ok(result)
}

View file

@ -1,18 +1,18 @@
use graph_craft::document::NodeNetwork;
use gpu_compiler as compiler;
use graph_craft::document::NodeNetwork;
use std::io::Write;
fn main() -> anyhow::Result<()> {
println!("Starting Gpu Compiler!");
fn main() -> anyhow::Result<()> {
println!("Starting GPU Compiler!");
let mut stdin = std::io::stdin();
let mut stdout = std::io::stdout();
let input_type = std::env::args().nth(1).expect("input type arg missing");
let output_type = std::env::args().nth(2).expect("output type arg missing");
let compile_dir = std::env::args().nth(3).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path());
let network: NodeNetwork = serde_json::from_reader(&mut stdin)?;
let compiler = graph_craft::executor::Compiler{};
let compiler = graph_craft::executor::Compiler {};
let proto_network = compiler.compile(network, true);
dbg!(&compile_dir);
dbg!(&compile_dir);
let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);

View file

@ -7,32 +7,32 @@ extern crate spirv_std;
#[cfg(target_arch = "spirv")]
pub mod gpu {
use super::*;
use spirv_std::spirv;
use spirv_std::glam::UVec3;
use super::*;
use spirv_std::spirv;
use spirv_std::glam::UVec3;
#[allow(unused)]
#[spirv(compute(threads({{compute_threads}})))]
pub fn eval (
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}],
#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
) {
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]);
}
}
#[allow(unused)]
#[spirv(compute(threads({{compute_threads}})))]
pub fn eval (
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}],
#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
) {
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]);
}
}
fn node_graph(input: {{input_type}}) -> {{output_type}} {
use graphene_core::Node;
fn node_graph(input: {{input_type}}) -> {{output_type}} {
use graphene_core::Node;
{% for node in nodes %}
let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %});
{% endfor %}
{% for node in nodes %}
let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %});
{% endfor %}
{{last_node}}.eval(input)
}
}
}