From fe18f11be41d735d6f6529d8844b9f0a76f01dd6 Mon Sep 17 00:00:00 2001 From: firestar99 Date: Thu, 21 Aug 2025 11:08:14 +0200 Subject: [PATCH] shader-rt: connect shader runtime --- Cargo.lock | 1 + node-graph/graster-nodes/Cargo.toml | 2 + node-graph/graster-nodes/src/lib.rs | 4 + .../src/shader_nodes/per_pixel_adjust.rs | 75 ++++++++++++++++--- node-graph/wgpu-executor/src/lib.rs | 3 + .../per_pixel_adjust_runtime.rs | 11 ++- 6 files changed, 80 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 207cca9a7..9e50dfe77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2176,6 +2176,7 @@ dependencies = [ "specta", "spirv-std", "tokio", + "wgpu-executor", ] [[package]] diff --git a/node-graph/graster-nodes/Cargo.toml b/node-graph/graster-nodes/Cargo.toml index 9ebe9ec47..f9dd8e776 100644 --- a/node-graph/graster-nodes/Cargo.toml +++ b/node-graph/graster-nodes/Cargo.toml @@ -17,6 +17,7 @@ default = ["std"] std = [ "dep:graphene-core", "dep:graphene-raster-nodes-shaders", + "dep:wgpu-executor", "dep:dyn-any", "dep:image", "dep:ndarray", @@ -38,6 +39,7 @@ node-macro = { workspace = true } # Local std dependencies dyn-any = { workspace = true, optional = true } graphene-core = { workspace = true, optional = true } +wgpu-executor = { workspace = true, optional = true } graphene-raster-nodes-shaders = { path = "./shaders", optional = true } # Workspace dependencies diff --git a/node-graph/graster-nodes/src/lib.rs b/node-graph/graster-nodes/src/lib.rs index 793b04120..d5383df03 100644 --- a/node-graph/graster-nodes/src/lib.rs +++ b/node-graph/graster-nodes/src/lib.rs @@ -6,6 +6,10 @@ pub mod blending_nodes; pub mod cubic_spline; pub mod fullscreen_vertex; +/// required by shader macro +#[cfg(feature = "std")] +pub use graphene_raster_nodes_shaders::WGSL_SHADER; + #[cfg(feature = "std")] pub mod curve; #[cfg(feature = "std")] diff --git a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs index d1f4c99a1..84780538d 100644 --- a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs +++ b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs @@ -7,7 +7,7 @@ use quote::{ToTokens, format_ident, quote}; use std::borrow::Cow; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::{Type, parse_quote}; +use syn::{PatIdent, Type, parse_quote}; #[derive(Debug, Clone)] pub struct PerPixelAdjust {} @@ -20,15 +20,14 @@ impl Parse for PerPixelAdjust { impl ShaderCodegen for PerPixelAdjust { fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result { - Ok(ShaderTokens { - shader_entry_point: self.codegen_shader_entry_point(parsed)?, - gpu_node: self.codegen_gpu_node(parsed, node_cfg)?, - }) + let (shader_entry_point, entry_point_name) = self.codegen_shader_entry_point(parsed)?; + let gpu_node = self.codegen_gpu_node(parsed, node_cfg, &entry_point_name)?; + Ok(ShaderTokens { shader_entry_point, gpu_node }) } } impl PerPixelAdjust { - fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result { + fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<(TokenStream, TokenStream)> { let fn_name = &parsed.fn_name; let gpu_mod = format_ident!("{}_gpu_entry_point", fn_name); let spirv_image_ty = quote!(Image2d); @@ -82,7 +81,10 @@ impl PerPixelAdjust { .collect::>(); let context = quote!(()); - Ok(quote! { + let entry_point_name = format_ident!("ENTRY_POINT_NAME"); + let entry_point_sym = quote!(#gpu_mod::#entry_point_name); + + let shader_entry_point = quote! { pub mod #gpu_mod { use super::*; use graphene_core_shaders::color::Color; @@ -91,6 +93,8 @@ impl PerPixelAdjust { use spirv_std::image::{Image2d, ImageWithMethods}; use spirv_std::image::sample_with::lod; + pub const #entry_point_name: &str = core::concat!(core::module_path!(), "::entry_point"); + pub struct Uniform { #(#uniform_members),* } @@ -107,10 +111,11 @@ impl PerPixelAdjust { *color_out = color.to_vec4(); } } - }) + }; + Ok((shader_entry_point, entry_point_sym)) } - fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result { + fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream, entry_point_name: &TokenStream) -> syn::Result { let fn_name = format_ident!("{}_gpu", parsed.fn_name); let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal)); let mod_name = fn_name.clone(); @@ -121,7 +126,8 @@ impl PerPixelAdjust { }; let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>); - let fields = parsed + // adapt fields for gpu node + let mut fields = parsed .fields .iter() .map(|f| match &f.ty { @@ -136,11 +142,55 @@ impl PerPixelAdjust { ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(f.clone()), ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")), }) - .collect::>()?; + .collect::>>()?; + + // wgpu_executor field + let wgpu_executor = format_ident!("__wgpu_executor"); + fields.push(ParsedField { + pat_ident: PatIdent { + attrs: vec![], + by_ref: None, + mutability: None, + ident: parse_quote!(#wgpu_executor), + subpat: None, + }, + name: None, + description: "".to_string(), + widget_override: Default::default(), + ty: ParsedFieldType::Regular(RegularParsedField { + ty: parse_quote!(WgpuExecutor), + exposed: false, + value_source: Default::default(), + number_soft_min: None, + number_soft_max: None, + number_hard_min: None, + number_hard_max: None, + number_mode_range: None, + implementations: Default::default(), + gpu_image: false, + }), + number_display_decimal_places: None, + number_step: None, + unit: None, + }); + + // exactly one gpu_image field, may be expanded later + let gpu_image_field = { + let mut iter = fields.iter().filter(|f| matches!(f.ty, ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }))); + match (iter.next(), iter.next()) { + (Some(v), None) => Ok(v), + (Some(_), Some(more)) => Err(syn::Error::new_spanned(&more.pat_ident, "No more than one parameter must be annotated with `#[gpu_image]`")), + (None, _) => Err(syn::Error::new_spanned(&parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")), + }? + }; + let gpu_image = &gpu_image_field.pat_ident.ident; let body = quote! { { - + #wgpu_executor.shader_runtime.run_per_pixel_adjust(#gpu_image, &::wgpu_executor::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustInfo { + wgsl_shader: crate::WGSL_SHADER, + fragment_shader_name: super::#entry_point_name, + }).await } }; @@ -174,6 +224,7 @@ impl PerPixelAdjust { #node_cfg mod #mod_name { use super::*; + use wgpu_executor::WgpuExecutor; #gpu_node } diff --git a/node-graph/wgpu-executor/src/lib.rs b/node-graph/wgpu-executor/src/lib.rs index b45db35a3..0b42dd631 100644 --- a/node-graph/wgpu-executor/src/lib.rs +++ b/node-graph/wgpu-executor/src/lib.rs @@ -2,6 +2,7 @@ mod context; pub mod shader_runtime; pub mod texture_upload; +use crate::shader_runtime::ShaderRuntime; use anyhow::Result; pub use context::Context; use dyn_any::StaticType; @@ -19,6 +20,7 @@ use wgpu::{Origin3d, SurfaceConfiguration, TextureAspect}; pub struct WgpuExecutor { pub context: Context, vello_renderer: Mutex, + pub shader_runtime: ShaderRuntime, } impl std::fmt::Debug for WgpuExecutor { @@ -196,6 +198,7 @@ impl WgpuExecutor { .ok()?; Some(Self { + shader_runtime: ShaderRuntime::new(&context), context, vello_renderer: vello_renderer.into(), }) diff --git a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs index 604a2c5bf..2b01bfc70 100644 --- a/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs +++ b/node-graph/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs @@ -35,8 +35,8 @@ impl ShaderRuntime { } pub struct PerPixelAdjustInfo<'a> { - shader_wgsl: &'a str, - fragment_shader_name: &'a str, + pub wgsl_shader: &'a str, + pub fragment_shader_name: &'a str, } pub struct PerPixelAdjustGraphicsPipeline { @@ -48,9 +48,12 @@ impl PerPixelAdjustGraphicsPipeline { pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self { let device = &context.device; let name = info.fragment_shader_name.to_owned(); + // TODO workaround to naga removing `:` + let fragment_name = name.replace(":", ""); + let shader_module = device.create_shader_module(ShaderModuleDescriptor { label: Some(&format!("PerPixelAdjust {} wgsl shader", name)), - source: ShaderSource::Wgsl(Cow::Borrowed(info.shader_wgsl)), + source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)), }); let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor { label: Some(&format!("PerPixelAdjust {} Pipeline", name)), @@ -74,7 +77,7 @@ impl PerPixelAdjustGraphicsPipeline { multisample: Default::default(), fragment: Some(FragmentState { module: &shader_module, - entry_point: Some(&name), + entry_point: Some(&fragment_name), compilation_options: Default::default(), targets: &[Some(ColorTargetState { format: TextureFormat::Rgba32Float,