diff --git a/node-graph/gcore/src/raster_types.rs b/node-graph/gcore/src/raster_types.rs index 97dd13814..7efae73fb 100644 --- a/node-graph/gcore/src/raster_types.rs +++ b/node-graph/gcore/src/raster_types.rs @@ -137,7 +137,7 @@ mod gpu { #[derive(Clone, Debug, PartialEq, Hash)] pub struct GPU { - texture: wgpu::Texture, + pub texture: wgpu::Texture, } impl Sealed for Raster {} diff --git a/node-graph/graster-nodes/src/fullscreen_vertex.rs b/node-graph/graster-nodes/src/fullscreen_vertex.rs new file mode 100644 index 000000000..b8ef775b9 --- /dev/null +++ b/node-graph/graster-nodes/src/fullscreen_vertex.rs @@ -0,0 +1,14 @@ +use glam::{Vec2, Vec4}; +use spirv_std::spirv; + +/// webgpu NDC is like OpenGL: (-1.0 .. 1.0, -1.0 .. 1.0, 0.0 .. 1.0) +/// https://www.w3.org/TR/webgpu/#coordinate-systems +const FULLSCREEN_VERTICES: [Vec2; 3] = [Vec2::new(-1., -1.), Vec2::new(-1., 3.), Vec2::new(3., -1.)]; + +#[spirv(vertex)] +pub fn fullscreen_vertex(#[spirv(vertex_index)] vertex_index: u32, #[spirv(position)] gl_position: &mut Vec4) { + // broken on edition 2024 branch + // let vertex = unsafe { *FULLSCREEN_VERTICES.index_unchecked(vertex_index as usize) }; + let vertex = FULLSCREEN_VERTICES[vertex_index as usize]; + *gl_position = Vec4::from((vertex, 0., 1.)); +} diff --git a/node-graph/graster-nodes/src/lib.rs b/node-graph/graster-nodes/src/lib.rs index 8dc169cea..793b04120 100644 --- a/node-graph/graster-nodes/src/lib.rs +++ b/node-graph/graster-nodes/src/lib.rs @@ -4,6 +4,7 @@ pub mod adjust; pub mod adjustments; pub mod blending_nodes; pub mod cubic_spline; +pub mod fullscreen_vertex; #[cfg(feature = "std")] pub mod curve; diff --git a/node-graph/node-macro/src/codegen.rs b/node-graph/node-macro/src/codegen.rs index 1c0c4f651..da7f0b6a2 100644 --- a/node-graph/node-macro/src/codegen.rs +++ b/node-graph/node-macro/src/codegen.rs @@ -295,7 +295,12 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result, _>(|n| Ok((n.codegen_shader_entry_point(parsed)?, n.codegen_gpu_node(parsed)?))) + .unwrap_or(Ok((TokenStream::new(), TokenStream::new())))?; + Ok(quote! { /// Underlying implementation for [#struct_name] #[inline] @@ -387,6 +392,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result, pub(crate) display_name: Option, diff --git a/node-graph/node-macro/src/shader_nodes/mod.rs b/node-graph/node-macro/src/shader_nodes/mod.rs index 0720869d0..3eff7fed1 100644 --- a/node-graph/node-macro/src/shader_nodes/mod.rs +++ b/node-graph/node-macro/src/shader_nodes/mod.rs @@ -19,7 +19,7 @@ pub fn modify_cfg(attributes: &NodeFnAttributes) -> TokenStream { } } -#[derive(Debug, VariantNames)] +#[derive(Debug, Clone, VariantNames)] pub(crate) enum ShaderNodeType { PerPixelAdjust(PerPixelAdjust), } @@ -36,6 +36,7 @@ impl Parse for ShaderNodeType { pub trait CodegenShaderEntryPoint { fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result; + fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result; } impl CodegenShaderEntryPoint for ShaderNodeType { @@ -48,4 +49,10 @@ impl CodegenShaderEntryPoint for ShaderNodeType { ShaderNodeType::PerPixelAdjust(x) => x.codegen_shader_entry_point(parsed), } } + + fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result { + match self { + ShaderNodeType::PerPixelAdjust(x) => x.codegen_gpu_node(parsed), + } + } } diff --git a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs index 0e220c8aa..7e3a64417 100644 --- a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs +++ b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs @@ -1,11 +1,13 @@ -use crate::parsing::{ParsedFieldType, ParsedNodeFn, RegularParsedField}; +use crate::parsing::{Input, NodeFnAttributes, ParsedField, ParsedFieldType, ParsedNodeFn, RegularParsedField}; use crate::shader_nodes::CodegenShaderEntryPoint; +use convert_case::{Case, Casing}; use proc_macro2::{Ident, TokenStream}; use quote::{ToTokens, format_ident, quote}; use std::borrow::Cow; use syn::parse::{Parse, ParseStream}; +use syn::{Path, Type, TypePath}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PerPixelAdjust {} impl Parse for PerPixelAdjust { @@ -17,7 +19,7 @@ impl Parse for PerPixelAdjust { impl CodegenShaderEntryPoint for PerPixelAdjust { fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result { let fn_name = &parsed.fn_name; - let gpu_mod = format_ident!("{}_gpu", parsed.fn_name); + let gpu_mod = format_ident!("{}_gpu_entry_point", parsed.fn_name); let spirv_image_ty = quote!(Image2d); // bindings for images start at 1 @@ -96,6 +98,52 @@ impl CodegenShaderEntryPoint for PerPixelAdjust { } }) } + + fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result { + let fn_name = format_ident!("{}_gpu", parsed.fn_name); + let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal)); + let mod_name = fn_name.clone(); + + let fields = parsed + .fields + .iter() + .map(|f| match &f.ty { + ParsedFieldType::Regular(reg) => Ok(ParsedField { + ty: ParsedFieldType::Regular(RegularParsedField { gpu_image: false, ..reg.clone() }), + ..f.clone() + }), + ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")), + }) + .collect::>()?; + let body = quote! {}; + + crate::codegen::generate_node_code(&ParsedNodeFn { + vis: parsed.vis.clone(), + attributes: NodeFnAttributes { + shader_node: None, + ..parsed.attributes.clone() + }, + fn_name, + struct_name, + mod_name, + fn_generics: vec![], + where_clause: None, + input: Input { + pat_ident: parsed.input.pat_ident.clone(), + ty: Type::Path(TypePath { + path: Path::from(format_ident!("Ctx")), + qself: None, + }), + implementations: Default::default(), + }, + output_type: parsed.output_type.clone(), + is_async: true, + fields, + body, + crate_name: parsed.crate_name.clone(), + description: "".to_string(), + }) + } } struct Param<'a> { diff --git a/node-graph/wgpu-executor/src/lib.rs b/node-graph/wgpu-executor/src/lib.rs index 920a002c4..b45db35a3 100644 --- a/node-graph/wgpu-executor/src/lib.rs +++ b/node-graph/wgpu-executor/src/lib.rs @@ -1,4 +1,5 @@ mod context; +pub mod shader_runtime; pub mod texture_upload; use anyhow::Result; diff --git a/node-graph/wgpu-executor/src/shader_runtime/mod.rs b/node-graph/wgpu-executor/src/shader_runtime/mod.rs new file mode 100644 index 000000000..e7e0df8d9 --- /dev/null +++ b/node-graph/wgpu-executor/src/shader_runtime/mod.rs @@ -0,0 +1,20 @@ +use crate::Context; +use crate::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustShaderRuntime; + +pub mod per_pixel_adjust_runtime; + +pub const FULLSCREEN_VERTEX_SHADER_NAME: &str = "fullscreen_vertexfullscreen_vertex"; + +pub struct ShaderRuntime { + context: Context, + per_pixel_adjust: PerPixelAdjustShaderRuntime, +} + +impl ShaderRuntime { + pub fn new(context: &Context) -> Self { + Self { + context: context.clone(), + per_pixel_adjust: PerPixelAdjustShaderRuntime::new(), + } + } +} diff --git a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs new file mode 100644 index 000000000..604a2c5bf --- /dev/null +++ b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs @@ -0,0 +1,155 @@ +use crate::Context; +use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime}; +use futures::lock::Mutex; +use graphene_core::raster_types::{GPU, Raster}; +use graphene_core::table::{Table, TableRow}; +use std::borrow::Cow; +use std::collections::HashMap; +use wgpu::{ + BindGroupDescriptor, BindGroupEntry, BindingResource, ColorTargetState, Face, FragmentState, FrontFace, LoadOp, Operations, PolygonMode, PrimitiveState, PrimitiveTopology, + RenderPassColorAttachment, RenderPassDescriptor, RenderPipelineDescriptor, ShaderModuleDescriptor, ShaderSource, StoreOp, TextureDescriptor, TextureDimension, TextureFormat, + TextureViewDescriptor, VertexState, +}; + +pub struct PerPixelAdjustShaderRuntime { + // TODO: PerPixelAdjustGraphicsPipeline already contains the key as `name` + pipeline_cache: Mutex>, +} + +impl PerPixelAdjustShaderRuntime { + pub fn new() -> Self { + Self { + pipeline_cache: Mutex::new(HashMap::new()), + } + } +} + +impl ShaderRuntime { + pub async fn run_per_pixel_adjust(&self, input: Table>, info: &PerPixelAdjustInfo<'_>) -> Table> { + let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await; + let pipeline = cache + .entry(info.fragment_shader_name.to_owned()) + .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &info)); + pipeline.run(&self.context, input) + } +} + +pub struct PerPixelAdjustInfo<'a> { + shader_wgsl: &'a str, + fragment_shader_name: &'a str, +} + +pub struct PerPixelAdjustGraphicsPipeline { + name: String, + pipeline: wgpu::RenderPipeline, +} + +impl PerPixelAdjustGraphicsPipeline { + pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self { + let device = &context.device; + let name = info.fragment_shader_name.to_owned(); + let shader_module = device.create_shader_module(ShaderModuleDescriptor { + label: Some(&format!("PerPixelAdjust {} wgsl shader", name)), + source: ShaderSource::Wgsl(Cow::Borrowed(info.shader_wgsl)), + }); + let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor { + label: Some(&format!("PerPixelAdjust {} Pipeline", name)), + layout: None, + vertex: VertexState { + module: &shader_module, + entry_point: Some(FULLSCREEN_VERTEX_SHADER_NAME), + compilation_options: Default::default(), + buffers: &[], + }, + primitive: PrimitiveState { + topology: PrimitiveTopology::TriangleList, + strip_index_format: None, + front_face: FrontFace::Ccw, + cull_mode: Some(Face::Back), + unclipped_depth: false, + polygon_mode: PolygonMode::Fill, + conservative: false, + }, + depth_stencil: None, + multisample: Default::default(), + fragment: Some(FragmentState { + module: &shader_module, + entry_point: Some(&name), + compilation_options: Default::default(), + targets: &[Some(ColorTargetState { + format: TextureFormat::Rgba32Float, + blend: None, + write_mask: Default::default(), + })], + }), + multiview: None, + cache: None, + }); + Self { pipeline, name } + } + + pub fn run(&self, context: &Context, input: Table>) -> Table> { + let device = &context.device; + let name = self.name.as_str(); + + let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("gpu_invert") }); + let out = input + .iter() + .map(|instance| { + let tex_in = &instance.element.texture; + let view_in = tex_in.create_view(&TextureViewDescriptor::default()); + let format = tex_in.format(); + + let bind_group = device.create_bind_group(&BindGroupDescriptor { + label: Some(&format!("{name} bind group")), + // `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that + layout: &self.pipeline.get_bind_group_layout(0), + entries: &[BindGroupEntry { + binding: 0, + resource: BindingResource::TextureView(&view_in), + }], + }); + + let tex_out = device.create_texture(&TextureDescriptor { + label: Some(&format!("{name} texture out")), + size: tex_in.size(), + mip_level_count: 1, + sample_count: 1, + dimension: TextureDimension::D2, + format, + usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST | wgpu::TextureUsages::COPY_SRC | wgpu::TextureUsages::RENDER_ATTACHMENT, + view_formats: &[format], + }); + + let view_out = tex_out.create_view(&TextureViewDescriptor::default()); + let mut rp = cmd.begin_render_pass(&RenderPassDescriptor { + label: Some(&format!("{name} render pipeline")), + color_attachments: &[Some(RenderPassColorAttachment { + view: &view_out, + resolve_target: None, + ops: Operations { + // should be dont_care but wgpu doesn't expose that + load: LoadOp::Clear(wgpu::Color::BLACK), + store: StoreOp::Store, + }, + })], + depth_stencil_attachment: None, + timestamp_writes: None, + occlusion_query_set: None, + }); + rp.set_pipeline(&self.pipeline); + rp.set_bind_group(0, Some(&bind_group), &[]); + rp.draw(0..3, 0..1); + + TableRow { + element: Raster::new(GPU { texture: tex_out }), + transform: *instance.transform, + alpha_blending: *instance.alpha_blending, + source_node_id: *instance.source_node_id, + } + }) + .collect::>(); + context.queue.submit([cmd.finish()]); + out + } +}