shader-rt: connect shader runtime

This commit is contained in:
firestar99 2025-08-21 11:08:14 +02:00 committed by Keavon Chambers
parent 0b8eeaf2de
commit fe18f11be4
6 changed files with 80 additions and 16 deletions

1
Cargo.lock generated
View file

@ -2176,6 +2176,7 @@ dependencies = [
"specta", "specta",
"spirv-std", "spirv-std",
"tokio", "tokio",
"wgpu-executor",
] ]
[[package]] [[package]]

View file

@ -17,6 +17,7 @@ default = ["std"]
std = [ std = [
"dep:graphene-core", "dep:graphene-core",
"dep:graphene-raster-nodes-shaders", "dep:graphene-raster-nodes-shaders",
"dep:wgpu-executor",
"dep:dyn-any", "dep:dyn-any",
"dep:image", "dep:image",
"dep:ndarray", "dep:ndarray",
@ -38,6 +39,7 @@ node-macro = { workspace = true }
# Local std dependencies # Local std dependencies
dyn-any = { workspace = true, optional = true } dyn-any = { workspace = true, optional = true }
graphene-core = { workspace = true, optional = true } graphene-core = { workspace = true, optional = true }
wgpu-executor = { workspace = true, optional = true }
graphene-raster-nodes-shaders = { path = "./shaders", optional = true } graphene-raster-nodes-shaders = { path = "./shaders", optional = true }
# Workspace dependencies # Workspace dependencies

View file

@ -6,6 +6,10 @@ pub mod blending_nodes;
pub mod cubic_spline; pub mod cubic_spline;
pub mod fullscreen_vertex; pub mod fullscreen_vertex;
/// required by shader macro
#[cfg(feature = "std")]
pub use graphene_raster_nodes_shaders::WGSL_SHADER;
#[cfg(feature = "std")] #[cfg(feature = "std")]
pub mod curve; pub mod curve;
#[cfg(feature = "std")] #[cfg(feature = "std")]

View file

@ -7,7 +7,7 @@ use quote::{ToTokens, format_ident, quote};
use std::borrow::Cow; use std::borrow::Cow;
use syn::parse::{Parse, ParseStream}; use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated; use syn::punctuated::Punctuated;
use syn::{Type, parse_quote}; use syn::{PatIdent, Type, parse_quote};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PerPixelAdjust {} pub struct PerPixelAdjust {}
@ -20,15 +20,14 @@ impl Parse for PerPixelAdjust {
impl ShaderCodegen for PerPixelAdjust { impl ShaderCodegen for PerPixelAdjust {
fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> { fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> {
Ok(ShaderTokens { let (shader_entry_point, entry_point_name) = self.codegen_shader_entry_point(parsed)?;
shader_entry_point: self.codegen_shader_entry_point(parsed)?, let gpu_node = self.codegen_gpu_node(parsed, node_cfg, &entry_point_name)?;
gpu_node: self.codegen_gpu_node(parsed, node_cfg)?, Ok(ShaderTokens { shader_entry_point, gpu_node })
})
} }
} }
impl PerPixelAdjust { impl PerPixelAdjust {
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream> { fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<(TokenStream, TokenStream)> {
let fn_name = &parsed.fn_name; let fn_name = &parsed.fn_name;
let gpu_mod = format_ident!("{}_gpu_entry_point", fn_name); let gpu_mod = format_ident!("{}_gpu_entry_point", fn_name);
let spirv_image_ty = quote!(Image2d); let spirv_image_ty = quote!(Image2d);
@ -82,7 +81,10 @@ impl PerPixelAdjust {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let context = quote!(()); let context = quote!(());
Ok(quote! { let entry_point_name = format_ident!("ENTRY_POINT_NAME");
let entry_point_sym = quote!(#gpu_mod::#entry_point_name);
let shader_entry_point = quote! {
pub mod #gpu_mod { pub mod #gpu_mod {
use super::*; use super::*;
use graphene_core_shaders::color::Color; use graphene_core_shaders::color::Color;
@ -91,6 +93,8 @@ impl PerPixelAdjust {
use spirv_std::image::{Image2d, ImageWithMethods}; use spirv_std::image::{Image2d, ImageWithMethods};
use spirv_std::image::sample_with::lod; use spirv_std::image::sample_with::lod;
pub const #entry_point_name: &str = core::concat!(core::module_path!(), "::entry_point");
pub struct Uniform { pub struct Uniform {
#(#uniform_members),* #(#uniform_members),*
} }
@ -107,10 +111,11 @@ impl PerPixelAdjust {
*color_out = color.to_vec4(); *color_out = color.to_vec4();
} }
} }
}) };
Ok((shader_entry_point, entry_point_sym))
} }
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<TokenStream> { fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream, entry_point_name: &TokenStream) -> syn::Result<TokenStream> {
let fn_name = format_ident!("{}_gpu", parsed.fn_name); let fn_name = format_ident!("{}_gpu", parsed.fn_name);
let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal)); let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
let mod_name = fn_name.clone(); let mod_name = fn_name.clone();
@ -121,7 +126,8 @@ impl PerPixelAdjust {
}; };
let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>); let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>);
let fields = parsed // adapt fields for gpu node
let mut fields = parsed
.fields .fields
.iter() .iter()
.map(|f| match &f.ty { .map(|f| match &f.ty {
@ -136,11 +142,55 @@ impl PerPixelAdjust {
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(f.clone()), ParsedFieldType::Regular(RegularParsedField { gpu_image: false, .. }) => Ok(f.clone()),
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")), ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
}) })
.collect::<syn::Result<_>>()?; .collect::<syn::Result<Vec<_>>>()?;
// wgpu_executor field
let wgpu_executor = format_ident!("__wgpu_executor");
fields.push(ParsedField {
pat_ident: PatIdent {
attrs: vec![],
by_ref: None,
mutability: None,
ident: parse_quote!(#wgpu_executor),
subpat: None,
},
name: None,
description: "".to_string(),
widget_override: Default::default(),
ty: ParsedFieldType::Regular(RegularParsedField {
ty: parse_quote!(WgpuExecutor),
exposed: false,
value_source: Default::default(),
number_soft_min: None,
number_soft_max: None,
number_hard_min: None,
number_hard_max: None,
number_mode_range: None,
implementations: Default::default(),
gpu_image: false,
}),
number_display_decimal_places: None,
number_step: None,
unit: None,
});
// exactly one gpu_image field, may be expanded later
let gpu_image_field = {
let mut iter = fields.iter().filter(|f| matches!(f.ty, ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. })));
match (iter.next(), iter.next()) {
(Some(v), None) => Ok(v),
(Some(_), Some(more)) => Err(syn::Error::new_spanned(&more.pat_ident, "No more than one parameter must be annotated with `#[gpu_image]`")),
(None, _) => Err(syn::Error::new_spanned(&parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")),
}?
};
let gpu_image = &gpu_image_field.pat_ident.ident;
let body = quote! { let body = quote! {
{ {
#wgpu_executor.shader_runtime.run_per_pixel_adjust(#gpu_image, &::wgpu_executor::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustInfo {
wgsl_shader: crate::WGSL_SHADER,
fragment_shader_name: super::#entry_point_name,
}).await
} }
}; };
@ -174,6 +224,7 @@ impl PerPixelAdjust {
#node_cfg #node_cfg
mod #mod_name { mod #mod_name {
use super::*; use super::*;
use wgpu_executor::WgpuExecutor;
#gpu_node #gpu_node
} }

View file

@ -2,6 +2,7 @@ mod context;
pub mod shader_runtime; pub mod shader_runtime;
pub mod texture_upload; pub mod texture_upload;
use crate::shader_runtime::ShaderRuntime;
use anyhow::Result; use anyhow::Result;
pub use context::Context; pub use context::Context;
use dyn_any::StaticType; use dyn_any::StaticType;
@ -19,6 +20,7 @@ use wgpu::{Origin3d, SurfaceConfiguration, TextureAspect};
pub struct WgpuExecutor { pub struct WgpuExecutor {
pub context: Context, pub context: Context,
vello_renderer: Mutex<Renderer>, vello_renderer: Mutex<Renderer>,
pub shader_runtime: ShaderRuntime,
} }
impl std::fmt::Debug for WgpuExecutor { impl std::fmt::Debug for WgpuExecutor {
@ -196,6 +198,7 @@ impl WgpuExecutor {
.ok()?; .ok()?;
Some(Self { Some(Self {
shader_runtime: ShaderRuntime::new(&context),
context, context,
vello_renderer: vello_renderer.into(), vello_renderer: vello_renderer.into(),
}) })

View file

@ -35,8 +35,8 @@ impl ShaderRuntime {
} }
pub struct PerPixelAdjustInfo<'a> { pub struct PerPixelAdjustInfo<'a> {
shader_wgsl: &'a str, pub wgsl_shader: &'a str,
fragment_shader_name: &'a str, pub fragment_shader_name: &'a str,
} }
pub struct PerPixelAdjustGraphicsPipeline { pub struct PerPixelAdjustGraphicsPipeline {
@ -48,9 +48,12 @@ impl PerPixelAdjustGraphicsPipeline {
pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self { pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self {
let device = &context.device; let device = &context.device;
let name = info.fragment_shader_name.to_owned(); let name = info.fragment_shader_name.to_owned();
// TODO workaround to naga removing `:`
let fragment_name = name.replace(":", "");
let shader_module = device.create_shader_module(ShaderModuleDescriptor { let shader_module = device.create_shader_module(ShaderModuleDescriptor {
label: Some(&format!("PerPixelAdjust {} wgsl shader", name)), label: Some(&format!("PerPixelAdjust {} wgsl shader", name)),
source: ShaderSource::Wgsl(Cow::Borrowed(info.shader_wgsl)), source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)),
}); });
let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor { let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
label: Some(&format!("PerPixelAdjust {} Pipeline", name)), label: Some(&format!("PerPixelAdjust {} Pipeline", name)),
@ -74,7 +77,7 @@ impl PerPixelAdjustGraphicsPipeline {
multisample: Default::default(), multisample: Default::default(),
fragment: Some(FragmentState { fragment: Some(FragmentState {
module: &shader_module, module: &shader_module,
entry_point: Some(&name), entry_point: Some(&fragment_name),
compilation_options: Default::default(), compilation_options: Default::default(),
targets: &[Some(ColorTargetState { targets: &[Some(ColorTargetState {
format: TextureFormat::Rgba32Float, format: TextureFormat::Rgba32Float,