From 174fdbfd29354a85f23c4fc9c3768d19bec7a56a Mon Sep 17 00:00:00 2001 From: firestar99 Date: Thu, 21 Aug 2025 12:38:23 +0200 Subject: [PATCH] shader-rt: correct arg buffer handling --- node-graph/gcore-shaders/src/blending.rs | 2 +- node-graph/graster-nodes/src/adjustments.rs | 4 +- .../src/shader_nodes/per_pixel_adjust.rs | 91 +++++++++++--- .../wgpu-executor/src/shader_runtime/mod.rs | 5 - .../per_pixel_adjust_runtime.rs | 112 ++++++++++++------ 5 files changed, 152 insertions(+), 62 deletions(-) diff --git a/node-graph/gcore-shaders/src/blending.rs b/node-graph/gcore-shaders/src/blending.rs index c3701e2cc..b305dd091 100644 --- a/node-graph/gcore-shaders/src/blending.rs +++ b/node-graph/gcore-shaders/src/blending.rs @@ -66,7 +66,7 @@ impl AlphaBlending { } #[repr(i32)] -#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash)] +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, bytemuck::NoUninit)] #[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))] pub enum BlendMode { // Basic group diff --git a/node-graph/graster-nodes/src/adjustments.rs b/node-graph/graster-nodes/src/adjustments.rs index 3a41a1a98..210ed9db2 100644 --- a/node-graph/graster-nodes/src/adjustments.rs +++ b/node-graph/graster-nodes/src/adjustments.rs @@ -30,7 +30,7 @@ use num_traits::float::Float; // https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=%27clrL%27%20%3D%20Color%20Lookup // https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=Color%20Lookup%20(Photoshop%20CS6 -#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, node_macro::ChoiceType)] +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, node_macro::ChoiceType, bytemuck::NoUninit)] #[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))] #[widget(Dropdown)] #[repr(u32)] @@ -560,7 +560,7 @@ pub enum RedGreenBlue { } /// Color Channel -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, node_macro::ChoiceType)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, node_macro::ChoiceType, bytemuck::NoUninit)] #[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))] #[widget(Radio)] #[repr(u32)] diff --git a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs index 1096d88da..57e4fcefa 100644 --- a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs +++ b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs @@ -22,11 +22,11 @@ impl ShaderCodegen for PerPixelAdjust { fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result { let fn_name = &parsed.fn_name; - // categorize params and assign image bindings - // bindings for images start at 1 - let params = { - let mut binding_cnt = 0; - parsed + let mut params; + let has_uniform; + { + // categorize params + params = parsed .fields .iter() .map(|f| { @@ -39,30 +39,50 @@ impl ShaderCodegen for PerPixelAdjust { param_type: ParamType::Uniform, }), ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => { - binding_cnt += 1; - Ok(Param { + let param = Param { ident: Cow::Owned(format_ident!("image_{}", &ident.ident)), ty: quote!(Image2d), - param_type: ParamType::Image { binding: binding_cnt }, - }) + param_type: ParamType::Image { binding: 0 }, + }; + Ok(param) } } }) - .collect::>>()? - }; + .collect::>>()?; + + has_uniform = params.iter().any(|p| matches!(p.param_type, ParamType::Uniform)); + + // assign image bindings + // if an arg_buffer exists, bindings for images start at 1 to leave 0 for arg buffer + let mut binding_cnt = if has_uniform { 1 } else { 0 }; + for p in params.iter_mut() { + match &mut p.param_type { + ParamType::Image { binding } => { + *binding = binding_cnt; + binding_cnt += 1; + } + ParamType::Uniform => {} + } + } + } let entry_point_mod = format_ident!("{}_gpu_entry_point", fn_name); let entry_point_name_ident = format_ident!("ENTRY_POINT_NAME"); let entry_point_name = quote!(#entry_point_mod::#entry_point_name_ident); + let uniform_struct_ident = format_ident!("Uniform"); + let uniform_struct = quote!(#entry_point_mod::#uniform_struct_ident); let gpu_node_mod = format_ident!("{}_gpu", fn_name); let codegen = PerPixelAdjustCodegen { parsed, node_cfg, params, + has_uniform, entry_point_mod, entry_point_name_ident, entry_point_name, + uniform_struct_ident, + uniform_struct, gpu_node_mod, }; @@ -77,9 +97,12 @@ pub struct PerPixelAdjustCodegen<'a> { parsed: &'a ParsedNodeFn, node_cfg: &'a TokenStream, params: Vec>, + has_uniform: bool, entry_point_mod: Ident, entry_point_name_ident: Ident, entry_point_name: TokenStream, + uniform_struct_ident: Ident, + uniform_struct: TokenStream, gpu_node_mod: Ident, } @@ -114,6 +137,7 @@ impl PerPixelAdjustCodegen<'_> { let entry_point_mod = &self.entry_point_mod; let entry_point_name = &self.entry_point_name_ident; + let uniform_struct_ident = &self.uniform_struct_ident; Ok(quote! { pub mod #entry_point_mod { use super::*; @@ -125,8 +149,10 @@ impl PerPixelAdjustCodegen<'_> { pub const #entry_point_name: &str = core::concat!(core::module_path!(), "::entry_point"); - pub struct Uniform { - #(#uniform_members),* + #[repr(C)] + #[derive(Copy, Clone, bytemuck::NoUninit)] + pub struct #uniform_struct_ident { + #(pub #uniform_members),* } #[spirv(fragment)] @@ -158,6 +184,11 @@ impl PerPixelAdjustCodegen<'_> { .iter() .map(|f| match &f.ty { ParsedFieldType::Regular(reg @ RegularParsedField { gpu_image: true, .. }) => Ok(ParsedField { + pat_ident: PatIdent { + mutability: None, + by_ref: None, + ..f.pat_ident.clone() + }, ty: ParsedFieldType::Regular(RegularParsedField { ty: raster_gpu.clone(), implementations: Punctuated::default(), @@ -165,7 +196,14 @@ impl PerPixelAdjustCodegen<'_> { }), ..f.clone() }), - ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(f.clone()), + ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(ParsedField { + pat_ident: PatIdent { + mutability: None, + by_ref: None, + ..f.pat_ident.clone() + }, + ..f.clone() + }), ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")), }) .collect::>>()?; @@ -211,14 +249,35 @@ impl PerPixelAdjustCodegen<'_> { }; let gpu_image = &gpu_image_field.pat_ident.ident; + // uniform buffer struct construction + let has_uniform = self.has_uniform; + let uniform_buffer = if has_uniform { + let uniform_struct = &self.uniform_struct; + let uniform_members = self + .params + .iter() + .filter_map(|p| match p.param_type { + ParamType::Image { .. } => None, + ParamType::Uniform => Some(p.ident.as_ref()), + }) + .collect::>(); + quote!(Some(&super::#uniform_struct { + #(#uniform_members),* + })) + } else { + // explicit generics placed here cause it's easier than explicitly writing `run_per_pixel_adjust::<()>` + quote!(Option::<&()>::None) + }; + // node function body let entry_point_name = &self.entry_point_name; let body = quote! { { - #wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::Shaders { + #wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::per_pixel_adjust_runtime::Shaders { wgsl_shader: crate::WGSL_SHADER, fragment_shader_name: super::#entry_point_name, - }, #gpu_image, &1u32).await + has_uniform: #has_uniform, + }, #gpu_image, #uniform_buffer).await } }; diff --git a/node-graph/wgpu-executor/src/shader_runtime/mod.rs b/node-graph/wgpu-executor/src/shader_runtime/mod.rs index 2745d5bda..e7e0df8d9 100644 --- a/node-graph/wgpu-executor/src/shader_runtime/mod.rs +++ b/node-graph/wgpu-executor/src/shader_runtime/mod.rs @@ -18,8 +18,3 @@ impl ShaderRuntime { } } } - -pub struct Shaders<'a> { - pub wgsl_shader: &'a str, - pub fragment_shader_name: &'a str, -} diff --git a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs index 119d34695..d958e0650 100644 --- a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs +++ b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs @@ -1,5 +1,5 @@ use crate::Context; -use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime, Shaders}; +use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime}; use bytemuck::NoUninit; use futures::lock::Mutex; use graphene_core::raster_types::{GPU, Raster}; @@ -27,24 +27,33 @@ impl PerPixelAdjustShaderRuntime { } impl ShaderRuntime { - pub async fn run_per_pixel_adjust(&self, shaders: &Shaders<'_>, textures: Table>, args: &T) -> Table> { + pub async fn run_per_pixel_adjust(&self, shaders: &Shaders<'_>, textures: Table>, args: Option<&T>) -> Table> { let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await; let pipeline = cache .entry(shaders.fragment_shader_name.to_owned()) .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders)); - let device = &self.context.device; - let arg_buffer = device.create_buffer_init(&BufferInitDescriptor { - label: Some(&format!("{} arg buffer", pipeline.name.as_str())), - usage: BufferUsages::STORAGE, - contents: bytemuck::bytes_of(args), + let arg_buffer = args.map(|args| { + let device = &self.context.device; + device.create_buffer_init(&BufferInitDescriptor { + label: Some(&format!("{} arg buffer", pipeline.name.as_str())), + usage: BufferUsages::STORAGE, + contents: bytemuck::bytes_of(args), + }) }); - pipeline.dispatch(&self.context, textures, &arg_buffer) + pipeline.dispatch(&self.context, textures, arg_buffer) } } +pub struct Shaders<'a> { + pub wgsl_shader: &'a str, + pub fragment_shader_name: &'a str, + pub has_uniform: bool, +} + pub struct PerPixelAdjustGraphicsPipeline { name: String, + has_uniform: bool, pipeline: wgpu::RenderPipeline, } @@ -62,32 +71,46 @@ impl PerPixelAdjustGraphicsPipeline { source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)), }); + let entries: &[_] = if info.has_uniform { + &[ + BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::FRAGMENT, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + BindGroupLayoutEntry { + binding: 1, + visibility: ShaderStages::FRAGMENT, + ty: BindingType::Texture { + sample_type: TextureSampleType::Float { filterable: false }, + view_dimension: TextureViewDimension::D2, + multisampled: false, + }, + count: None, + }, + ] + } else { + &[BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::FRAGMENT, + ty: BindingType::Texture { + sample_type: TextureSampleType::Float { filterable: false }, + view_dimension: TextureViewDimension::D2, + multisampled: false, + }, + count: None, + }] + }; let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor { label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)), bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor { label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)), - entries: &[ - BindGroupLayoutEntry { - binding: 0, - visibility: ShaderStages::FRAGMENT, - ty: BindingType::Buffer { - ty: BufferBindingType::Storage { read_only: true }, - has_dynamic_offset: false, - min_binding_size: None, - }, - count: None, - }, - BindGroupLayoutEntry { - binding: 1, - visibility: ShaderStages::FRAGMENT, - ty: BindingType::Texture { - sample_type: TextureSampleType::Float { filterable: false }, - view_dimension: TextureViewDimension::D2, - multisampled: false, - }, - count: None, - }, - ], + entries, })], push_constant_ranges: &[], }); @@ -125,10 +148,15 @@ impl PerPixelAdjustGraphicsPipeline { multiview: None, cache: None, }); - Self { pipeline, name } + Self { + pipeline, + name, + has_uniform: info.has_uniform, + } } - pub fn dispatch(&self, context: &Context, textures: Table>, arg_buffer: &Buffer) -> Table> { + pub fn dispatch(&self, context: &Context, textures: Table>, arg_buffer: Option) -> Table> { + assert_eq!(self.has_uniform, arg_buffer.is_some()); let device = &context.device; let name = self.name.as_str(); @@ -140,11 +168,8 @@ impl PerPixelAdjustGraphicsPipeline { let view_in = tex_in.create_view(&TextureViewDescriptor::default()); let format = tex_in.format(); - let bind_group = device.create_bind_group(&BindGroupDescriptor { - label: Some(&format!("{name} bind group")), - // `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that - layout: &self.pipeline.get_bind_group_layout(0), - entries: &[ + let entries: &[_] = if let Some(arg_buffer) = arg_buffer.as_ref() { + &[ BindGroupEntry { binding: 0, resource: BindingResource::Buffer(BufferBinding { @@ -157,7 +182,18 @@ impl PerPixelAdjustGraphicsPipeline { binding: 1, resource: BindingResource::TextureView(&view_in), }, - ], + ] + } else { + &[BindGroupEntry { + binding: 0, + resource: BindingResource::TextureView(&view_in), + }] + }; + let bind_group = device.create_bind_group(&BindGroupDescriptor { + label: Some(&format!("{name} bind group")), + // `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that + layout: &self.pipeline.get_bind_group_layout(0), + entries, }); let tex_out = device.create_texture(&TextureDescriptor {