shader-rt: correct arg buffer handling

This commit is contained in:
firestar99 2025-08-21 12:38:23 +02:00
parent 9b1b575354
commit f838ba5a91
5 changed files with 152 additions and 62 deletions

View file

@ -66,7 +66,7 @@ impl AlphaBlending {
} }
#[repr(i32)] #[repr(i32)]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash)] #[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, bytemuck::NoUninit)]
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
pub enum BlendMode { pub enum BlendMode {
// Basic group // Basic group

View file

@ -30,7 +30,7 @@ use num_traits::float::Float;
// https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=%27clrL%27%20%3D%20Color%20Lookup // https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=%27clrL%27%20%3D%20Color%20Lookup
// https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=Color%20Lookup%20(Photoshop%20CS6 // https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=Color%20Lookup%20(Photoshop%20CS6
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, node_macro::ChoiceType)] #[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, node_macro::ChoiceType, bytemuck::NoUninit)]
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
#[widget(Dropdown)] #[widget(Dropdown)]
#[repr(u32)] #[repr(u32)]
@ -560,7 +560,7 @@ pub enum RedGreenBlue {
} }
/// Color Channel /// Color Channel
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, node_macro::ChoiceType)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, node_macro::ChoiceType, bytemuck::NoUninit)]
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
#[widget(Radio)] #[widget(Radio)]
#[repr(u32)] #[repr(u32)]

View file

@ -22,11 +22,11 @@ impl ShaderCodegen for PerPixelAdjust {
fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> { fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> {
let fn_name = &parsed.fn_name; let fn_name = &parsed.fn_name;
// categorize params and assign image bindings let mut params;
// bindings for images start at 1 let has_uniform;
let params = { {
let mut binding_cnt = 0; // categorize params
parsed params = parsed
.fields .fields
.iter() .iter()
.map(|f| { .map(|f| {
@ -39,30 +39,50 @@ impl ShaderCodegen for PerPixelAdjust {
param_type: ParamType::Uniform, param_type: ParamType::Uniform,
}), }),
ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => { ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => {
binding_cnt += 1; let param = Param {
Ok(Param {
ident: Cow::Owned(format_ident!("image_{}", &ident.ident)), ident: Cow::Owned(format_ident!("image_{}", &ident.ident)),
ty: quote!(Image2d), ty: quote!(Image2d),
param_type: ParamType::Image { binding: binding_cnt }, param_type: ParamType::Image { binding: 0 },
}) };
Ok(param)
} }
} }
}) })
.collect::<syn::Result<Vec<_>>>()? .collect::<syn::Result<Vec<_>>>()?;
};
has_uniform = params.iter().any(|p| matches!(p.param_type, ParamType::Uniform));
// assign image bindings
// if an arg_buffer exists, bindings for images start at 1 to leave 0 for arg buffer
let mut binding_cnt = if has_uniform { 1 } else { 0 };
for p in params.iter_mut() {
match &mut p.param_type {
ParamType::Image { binding } => {
*binding = binding_cnt;
binding_cnt += 1;
}
ParamType::Uniform => {}
}
}
}
let entry_point_mod = format_ident!("{}_gpu_entry_point", fn_name); let entry_point_mod = format_ident!("{}_gpu_entry_point", fn_name);
let entry_point_name_ident = format_ident!("ENTRY_POINT_NAME"); let entry_point_name_ident = format_ident!("ENTRY_POINT_NAME");
let entry_point_name = quote!(#entry_point_mod::#entry_point_name_ident); let entry_point_name = quote!(#entry_point_mod::#entry_point_name_ident);
let uniform_struct_ident = format_ident!("Uniform");
let uniform_struct = quote!(#entry_point_mod::#uniform_struct_ident);
let gpu_node_mod = format_ident!("{}_gpu", fn_name); let gpu_node_mod = format_ident!("{}_gpu", fn_name);
let codegen = PerPixelAdjustCodegen { let codegen = PerPixelAdjustCodegen {
parsed, parsed,
node_cfg, node_cfg,
params, params,
has_uniform,
entry_point_mod, entry_point_mod,
entry_point_name_ident, entry_point_name_ident,
entry_point_name, entry_point_name,
uniform_struct_ident,
uniform_struct,
gpu_node_mod, gpu_node_mod,
}; };
@ -77,9 +97,12 @@ pub struct PerPixelAdjustCodegen<'a> {
parsed: &'a ParsedNodeFn, parsed: &'a ParsedNodeFn,
node_cfg: &'a TokenStream, node_cfg: &'a TokenStream,
params: Vec<Param<'a>>, params: Vec<Param<'a>>,
has_uniform: bool,
entry_point_mod: Ident, entry_point_mod: Ident,
entry_point_name_ident: Ident, entry_point_name_ident: Ident,
entry_point_name: TokenStream, entry_point_name: TokenStream,
uniform_struct_ident: Ident,
uniform_struct: TokenStream,
gpu_node_mod: Ident, gpu_node_mod: Ident,
} }
@ -114,6 +137,7 @@ impl PerPixelAdjustCodegen<'_> {
let entry_point_mod = &self.entry_point_mod; let entry_point_mod = &self.entry_point_mod;
let entry_point_name = &self.entry_point_name_ident; let entry_point_name = &self.entry_point_name_ident;
let uniform_struct_ident = &self.uniform_struct_ident;
Ok(quote! { Ok(quote! {
pub mod #entry_point_mod { pub mod #entry_point_mod {
use super::*; use super::*;
@ -125,8 +149,10 @@ impl PerPixelAdjustCodegen<'_> {
pub const #entry_point_name: &str = core::concat!(core::module_path!(), "::entry_point"); pub const #entry_point_name: &str = core::concat!(core::module_path!(), "::entry_point");
pub struct Uniform { #[repr(C)]
#(#uniform_members),* #[derive(Copy, Clone, bytemuck::NoUninit)]
pub struct #uniform_struct_ident {
#(pub #uniform_members),*
} }
#[spirv(fragment)] #[spirv(fragment)]
@ -158,6 +184,11 @@ impl PerPixelAdjustCodegen<'_> {
.iter() .iter()
.map(|f| match &f.ty { .map(|f| match &f.ty {
ParsedFieldType::Regular(reg @ RegularParsedField { gpu_image: true, .. }) => Ok(ParsedField { ParsedFieldType::Regular(reg @ RegularParsedField { gpu_image: true, .. }) => Ok(ParsedField {
pat_ident: PatIdent {
mutability: None,
by_ref: None,
..f.pat_ident.clone()
},
ty: ParsedFieldType::Regular(RegularParsedField { ty: ParsedFieldType::Regular(RegularParsedField {
ty: raster_gpu.clone(), ty: raster_gpu.clone(),
implementations: Punctuated::default(), implementations: Punctuated::default(),
@ -165,7 +196,14 @@ impl PerPixelAdjustCodegen<'_> {
}), }),
..f.clone() ..f.clone()
}), }),
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(f.clone()), ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(ParsedField {
pat_ident: PatIdent {
mutability: None,
by_ref: None,
..f.pat_ident.clone()
},
..f.clone()
}),
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")), ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
}) })
.collect::<syn::Result<Vec<_>>>()?; .collect::<syn::Result<Vec<_>>>()?;
@ -211,14 +249,35 @@ impl PerPixelAdjustCodegen<'_> {
}; };
let gpu_image = &gpu_image_field.pat_ident.ident; let gpu_image = &gpu_image_field.pat_ident.ident;
// uniform buffer struct construction
let has_uniform = self.has_uniform;
let uniform_buffer = if has_uniform {
let uniform_struct = &self.uniform_struct;
let uniform_members = self
.params
.iter()
.filter_map(|p| match p.param_type {
ParamType::Image { .. } => None,
ParamType::Uniform => Some(p.ident.as_ref()),
})
.collect::<Vec<_>>();
quote!(Some(&super::#uniform_struct {
#(#uniform_members),*
}))
} else {
// explicit generics placed here cause it's easier than explicitly writing `run_per_pixel_adjust::<()>`
quote!(Option::<&()>::None)
};
// node function body // node function body
let entry_point_name = &self.entry_point_name; let entry_point_name = &self.entry_point_name;
let body = quote! { let body = quote! {
{ {
#wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::Shaders { #wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::per_pixel_adjust_runtime::Shaders {
wgsl_shader: crate::WGSL_SHADER, wgsl_shader: crate::WGSL_SHADER,
fragment_shader_name: super::#entry_point_name, fragment_shader_name: super::#entry_point_name,
}, #gpu_image, &1u32).await has_uniform: #has_uniform,
}, #gpu_image, #uniform_buffer).await
} }
}; };

View file

@ -18,8 +18,3 @@ impl ShaderRuntime {
} }
} }
} }
pub struct Shaders<'a> {
pub wgsl_shader: &'a str,
pub fragment_shader_name: &'a str,
}

View file

@ -1,5 +1,5 @@
use crate::Context; use crate::Context;
use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime, Shaders}; use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime};
use bytemuck::NoUninit; use bytemuck::NoUninit;
use futures::lock::Mutex; use futures::lock::Mutex;
use graphene_core::raster_types::{GPU, Raster}; use graphene_core::raster_types::{GPU, Raster};
@ -27,24 +27,33 @@ impl PerPixelAdjustShaderRuntime {
} }
impl ShaderRuntime { impl ShaderRuntime {
pub async fn run_per_pixel_adjust<T: NoUninit>(&self, shaders: &Shaders<'_>, textures: Table<Raster<GPU>>, args: &T) -> Table<Raster<GPU>> { pub async fn run_per_pixel_adjust<T: NoUninit>(&self, shaders: &Shaders<'_>, textures: Table<Raster<GPU>>, args: Option<&T>) -> Table<Raster<GPU>> {
let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await; let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await;
let pipeline = cache let pipeline = cache
.entry(shaders.fragment_shader_name.to_owned()) .entry(shaders.fragment_shader_name.to_owned())
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders)); .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders));
let device = &self.context.device; let arg_buffer = args.map(|args| {
let arg_buffer = device.create_buffer_init(&BufferInitDescriptor { let device = &self.context.device;
label: Some(&format!("{} arg buffer", pipeline.name.as_str())), device.create_buffer_init(&BufferInitDescriptor {
usage: BufferUsages::STORAGE, label: Some(&format!("{} arg buffer", pipeline.name.as_str())),
contents: bytemuck::bytes_of(args), usage: BufferUsages::STORAGE,
contents: bytemuck::bytes_of(args),
})
}); });
pipeline.dispatch(&self.context, textures, &arg_buffer) pipeline.dispatch(&self.context, textures, arg_buffer)
} }
} }
pub struct Shaders<'a> {
pub wgsl_shader: &'a str,
pub fragment_shader_name: &'a str,
pub has_uniform: bool,
}
pub struct PerPixelAdjustGraphicsPipeline { pub struct PerPixelAdjustGraphicsPipeline {
name: String, name: String,
has_uniform: bool,
pipeline: wgpu::RenderPipeline, pipeline: wgpu::RenderPipeline,
} }
@ -62,32 +71,46 @@ impl PerPixelAdjustGraphicsPipeline {
source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)), source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)),
}); });
let entries: &[_] = if info.has_uniform {
&[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Texture {
sample_type: TextureSampleType::Float { filterable: false },
view_dimension: TextureViewDimension::D2,
multisampled: false,
},
count: None,
},
]
} else {
&[BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Texture {
sample_type: TextureSampleType::Float { filterable: false },
view_dimension: TextureViewDimension::D2,
multisampled: false,
},
count: None,
}]
};
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor { let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)), label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)),
bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor { bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)), label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)),
entries: &[ entries,
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Texture {
sample_type: TextureSampleType::Float { filterable: false },
view_dimension: TextureViewDimension::D2,
multisampled: false,
},
count: None,
},
],
})], })],
push_constant_ranges: &[], push_constant_ranges: &[],
}); });
@ -125,10 +148,15 @@ impl PerPixelAdjustGraphicsPipeline {
multiview: None, multiview: None,
cache: None, cache: None,
}); });
Self { pipeline, name } Self {
pipeline,
name,
has_uniform: info.has_uniform,
}
} }
pub fn dispatch(&self, context: &Context, textures: Table<Raster<GPU>>, arg_buffer: &Buffer) -> Table<Raster<GPU>> { pub fn dispatch(&self, context: &Context, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
assert_eq!(self.has_uniform, arg_buffer.is_some());
let device = &context.device; let device = &context.device;
let name = self.name.as_str(); let name = self.name.as_str();
@ -140,11 +168,8 @@ impl PerPixelAdjustGraphicsPipeline {
let view_in = tex_in.create_view(&TextureViewDescriptor::default()); let view_in = tex_in.create_view(&TextureViewDescriptor::default());
let format = tex_in.format(); let format = tex_in.format();
let bind_group = device.create_bind_group(&BindGroupDescriptor { let entries: &[_] = if let Some(arg_buffer) = arg_buffer.as_ref() {
label: Some(&format!("{name} bind group")), &[
// `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that
layout: &self.pipeline.get_bind_group_layout(0),
entries: &[
BindGroupEntry { BindGroupEntry {
binding: 0, binding: 0,
resource: BindingResource::Buffer(BufferBinding { resource: BindingResource::Buffer(BufferBinding {
@ -157,7 +182,18 @@ impl PerPixelAdjustGraphicsPipeline {
binding: 1, binding: 1,
resource: BindingResource::TextureView(&view_in), resource: BindingResource::TextureView(&view_in),
}, },
], ]
} else {
&[BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(&view_in),
}]
};
let bind_group = device.create_bind_group(&BindGroupDescriptor {
label: Some(&format!("{name} bind group")),
// `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that
layout: &self.pipeline.get_bind_group_layout(0),
entries,
}); });
let tex_out = device.create_texture(&TextureDescriptor { let tex_out = device.create_texture(&TextureDescriptor {