diff --git a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs index 8dc97eba2..06a450801 100644 --- a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs +++ b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs @@ -187,10 +187,10 @@ impl PerPixelAdjust { let body = quote! { { - #wgpu_executor.shader_runtime.run_per_pixel_adjust(#gpu_image, &::wgpu_executor::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustInfo { + #wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::Shaders { wgsl_shader: crate::WGSL_SHADER, fragment_shader_name: super::#entry_point_name, - }).await + }, #gpu_image, &()).await } }; diff --git a/node-graph/wgpu-executor/src/shader_runtime/mod.rs b/node-graph/wgpu-executor/src/shader_runtime/mod.rs index e7e0df8d9..2745d5bda 100644 --- a/node-graph/wgpu-executor/src/shader_runtime/mod.rs +++ b/node-graph/wgpu-executor/src/shader_runtime/mod.rs @@ -18,3 +18,8 @@ impl ShaderRuntime { } } } + +pub struct Shaders<'a> { + pub wgsl_shader: &'a str, + pub fragment_shader_name: &'a str, +} diff --git a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs index 2b01bfc70..352adb7e9 100644 --- a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs +++ b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs @@ -1,13 +1,15 @@ use crate::Context; -use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime}; +use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime, Shaders}; +use bytemuck::NoUninit; use futures::lock::Mutex; use graphene_core::raster_types::{GPU, Raster}; use graphene_core::table::{Table, TableRow}; use std::borrow::Cow; use std::collections::HashMap; +use wgpu::util::{BufferInitDescriptor, DeviceExt}; use wgpu::{ - BindGroupDescriptor, BindGroupEntry, BindingResource, ColorTargetState, Face, FragmentState, FrontFace, LoadOp, Operations, PolygonMode, PrimitiveState, PrimitiveTopology, - RenderPassColorAttachment, RenderPassDescriptor, RenderPipelineDescriptor, ShaderModuleDescriptor, ShaderSource, StoreOp, TextureDescriptor, TextureDimension, TextureFormat, + BindGroupDescriptor, BindGroupEntry, BindingResource, Buffer, BufferBinding, BufferUsages, ColorTargetState, Face, FragmentState, FrontFace, LoadOp, Operations, PolygonMode, PrimitiveState, + PrimitiveTopology, RenderPassColorAttachment, RenderPassDescriptor, RenderPipelineDescriptor, ShaderModuleDescriptor, ShaderSource, StoreOp, TextureDescriptor, TextureDimension, TextureFormat, TextureViewDescriptor, VertexState, }; @@ -25,18 +27,20 @@ impl PerPixelAdjustShaderRuntime { } impl ShaderRuntime { - pub async fn run_per_pixel_adjust(&self, input: Table>, info: &PerPixelAdjustInfo<'_>) -> Table> { + pub async fn run_per_pixel_adjust(&self, shaders: &Shaders<'_>, textures: Table>, args: &T) -> Table> { let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await; let pipeline = cache - .entry(info.fragment_shader_name.to_owned()) - .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &info)); - pipeline.run(&self.context, input) - } -} + .entry(shaders.fragment_shader_name.to_owned()) + .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders)); -pub struct PerPixelAdjustInfo<'a> { - pub wgsl_shader: &'a str, - pub fragment_shader_name: &'a str, + let device = &self.context.device; + let arg_buffer = device.create_buffer_init(&BufferInitDescriptor { + label: Some(&format!("{} arg buffer", pipeline.name.as_str())), + usage: BufferUsages::STORAGE, + contents: bytemuck::bytes_of(args), + }); + pipeline.dispatch(&self.context, textures, &arg_buffer) + } } pub struct PerPixelAdjustGraphicsPipeline { @@ -45,11 +49,14 @@ pub struct PerPixelAdjustGraphicsPipeline { } impl PerPixelAdjustGraphicsPipeline { - pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self { + pub fn new(context: &Context, info: &Shaders) -> Self { let device = &context.device; let name = info.fragment_shader_name.to_owned(); + // TODO workaround to naga removing `:` - let fragment_name = name.replace(":", ""); + let fragment_name = &name; + let fragment_name = &fragment_name[(fragment_name.find("::").unwrap() + 2)..]; + let fragment_name = fragment_name.replace(":", ""); let shader_module = device.create_shader_module(ShaderModuleDescriptor { label: Some(&format!("PerPixelAdjust {} wgsl shader", name)), @@ -91,12 +98,12 @@ impl PerPixelAdjustGraphicsPipeline { Self { pipeline, name } } - pub fn run(&self, context: &Context, input: Table>) -> Table> { + pub fn dispatch(&self, context: &Context, textures: Table>, arg_buffer: &Buffer) -> Table> { let device = &context.device; let name = self.name.as_str(); let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("gpu_invert") }); - let out = input + let out = textures .iter() .map(|instance| { let tex_in = &instance.element.texture; @@ -107,10 +114,20 @@ impl PerPixelAdjustGraphicsPipeline { label: Some(&format!("{name} bind group")), // `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that layout: &self.pipeline.get_bind_group_layout(0), - entries: &[BindGroupEntry { - binding: 0, - resource: BindingResource::TextureView(&view_in), - }], + entries: &[ + BindGroupEntry { + binding: 0, + resource: BindingResource::Buffer(BufferBinding { + buffer: arg_buffer, + offset: 0, + size: None, + }), + }, + BindGroupEntry { + binding: 1, + resource: BindingResource::TextureView(&view_in), + }, + ], }); let tex_out = device.create_texture(&TextureDescriptor {