shader-rt: cleanup codegen with common sym struct

This commit is contained in:
firestar99 2025-08-21 11:53:53 +02:00
parent 402623f246
commit a2a4102d58

View file

@ -20,59 +20,90 @@ impl Parse for PerPixelAdjust {
impl ShaderCodegen for PerPixelAdjust {
fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> {
let (shader_entry_point, entry_point_name) = self.codegen_shader_entry_point(parsed)?;
let gpu_node = self.codegen_gpu_node(parsed, node_cfg, &entry_point_name)?;
Ok(ShaderTokens { shader_entry_point, gpu_node })
let fn_name = &parsed.fn_name;
// categorize params and assign image bindings
// bindings for images start at 1
let params = {
let mut binding_cnt = 0;
parsed
.fields
.iter()
.map(|f| {
let ident = &f.pat_ident;
match &f.ty {
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, ty, .. }) => Ok(Param {
ident: Cow::Borrowed(&ident.ident),
ty: ty.to_token_stream(),
param_type: ParamType::Uniform,
}),
ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => {
binding_cnt += 1;
Ok(Param {
ident: Cow::Owned(format_ident!("image_{}", &ident.ident)),
ty: quote!(Image2d),
param_type: ParamType::Image { binding: binding_cnt },
})
}
}
})
.collect::<syn::Result<Vec<_>>>()?
};
let entry_point_mod = format_ident!("{}_gpu_entry_point", fn_name);
let entry_point_name_ident = format_ident!("ENTRY_POINT_NAME");
let entry_point_name = quote!(#entry_point_mod::#entry_point_name_ident);
let gpu_node_mod = format_ident!("{}_gpu", fn_name);
let codegen = PerPixelAdjustCodegen {
parsed,
node_cfg,
params,
entry_point_mod,
entry_point_name_ident,
entry_point_name,
gpu_node_mod,
};
Ok(ShaderTokens {
shader_entry_point: codegen.codegen_shader_entry_point()?,
gpu_node: codegen.codegen_gpu_node()?,
})
}
}
impl PerPixelAdjust {
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<(TokenStream, TokenStream)> {
let fn_name = &parsed.fn_name;
let gpu_mod = format_ident!("{}_gpu_entry_point", fn_name);
let spirv_image_ty = quote!(Image2d);
pub struct PerPixelAdjustCodegen<'a> {
parsed: &'a ParsedNodeFn,
node_cfg: &'a TokenStream,
params: Vec<Param<'a>>,
entry_point_mod: Ident,
entry_point_name_ident: Ident,
entry_point_name: TokenStream,
gpu_node_mod: Ident,
}
// bindings for images start at 1
let mut binding_cnt = 0;
let params = parsed
.fields
.iter()
.map(|f| {
let ident = &f.pat_ident;
match &f.ty {
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, ty, .. }) => Ok(Param {
ident: Cow::Borrowed(&ident.ident),
ty: Cow::Owned(ty.to_token_stream()),
param_type: ParamType::Uniform,
}),
ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => {
binding_cnt += 1;
Ok(Param {
ident: Cow::Owned(format_ident!("image_{}", &ident.ident)),
ty: Cow::Borrowed(&spirv_image_ty),
param_type: ParamType::Image { binding: binding_cnt },
})
}
}
})
.collect::<syn::Result<Vec<_>>>()?;
let uniform_members = params
impl PerPixelAdjustCodegen<'_> {
fn codegen_shader_entry_point(&self) -> syn::Result<TokenStream> {
let fn_name = &self.parsed.fn_name;
let uniform_members = self
.params
.iter()
.filter_map(|Param { ident, ty, param_type }| match param_type {
ParamType::Image { .. } => None,
ParamType::Uniform => Some(quote! {#ident: #ty}),
})
.collect::<Vec<_>>();
let image_params = params
let image_params = self
.params
.iter()
.filter_map(|Param { ident, ty, param_type }| match param_type {
ParamType::Image { binding } => Some(quote! {#[spirv(descriptor_set = 0, binding = #binding)] #ident: &#ty}),
ParamType::Uniform => None,
})
.collect::<Vec<_>>();
let call_args = params
let call_args = self
.params
.iter()
.map(|Param { ident, param_type, .. }| match param_type {
ParamType::Image { .. } => quote!(Color::from_vec4(#ident.fetch_with(texel_coord, lod(0)))),
@ -81,11 +112,10 @@ impl PerPixelAdjust {
.collect::<Vec<_>>();
let context = quote!(());
let entry_point_name = format_ident!("ENTRY_POINT_NAME");
let entry_point_sym = quote!(#gpu_mod::#entry_point_name);
let shader_entry_point = quote! {
pub mod #gpu_mod {
let entry_point_mod = &self.entry_point_mod;
let entry_point_name = &self.entry_point_name_ident;
Ok(quote! {
pub mod #entry_point_mod {
use super::*;
use graphene_core_shaders::color::Color;
use spirv_std::spirv;
@ -111,23 +141,19 @@ impl PerPixelAdjust {
*color_out = color.to_vec4();
}
}
};
Ok((shader_entry_point, entry_point_sym))
})
}
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream, entry_point_name: &TokenStream) -> syn::Result<TokenStream> {
let fn_name = format_ident!("{}_gpu", parsed.fn_name);
let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
let mod_name = fn_name.clone();
let gcore = match &parsed.crate_name {
fn codegen_gpu_node(&self) -> syn::Result<TokenStream> {
let gcore = match &self.parsed.crate_name {
FoundCrate::Itself => format_ident!("crate"),
FoundCrate::Name(name) => format_ident!("{name}"),
};
let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>);
// adapt fields for gpu node
let mut fields = parsed
let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>);
let mut fields = self
.parsed
.fields
.iter()
.map(|f| match &f.ty {
@ -144,7 +170,7 @@ impl PerPixelAdjust {
})
.collect::<syn::Result<Vec<_>>>()?;
// wgpu_executor field
// insert wgpu_executor field
let wgpu_executor = format_ident!("__wgpu_executor");
fields.push(ParsedField {
pat_ident: PatIdent {
@ -174,17 +200,19 @@ impl PerPixelAdjust {
unit: None,
});
// exactly one gpu_image field, may be expanded later
// find exactly one gpu_image field, runtime doesn't support more than 1 atm
let gpu_image_field = {
let mut iter = fields.iter().filter(|f| matches!(f.ty, ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. })));
match (iter.next(), iter.next()) {
(Some(v), None) => Ok(v),
(Some(_), Some(more)) => Err(syn::Error::new_spanned(&more.pat_ident, "No more than one parameter must be annotated with `#[gpu_image]`")),
(None, _) => Err(syn::Error::new_spanned(&parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")),
(None, _) => Err(syn::Error::new_spanned(&self.parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")),
}?
};
let gpu_image = &gpu_image_field.pat_ident.ident;
// node function body
let entry_point_name = &self.entry_point_name;
let body = quote! {
{
#wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::Shaders {
@ -194,19 +222,20 @@ impl PerPixelAdjust {
}
};
// call node codegen
let mut parsed_node_fn = ParsedNodeFn {
vis: parsed.vis.clone(),
vis: self.parsed.vis.clone(),
attributes: NodeFnAttributes {
shader_node: Some(ShaderNodeType::GpuNode),
..parsed.attributes.clone()
..self.parsed.attributes.clone()
},
fn_name,
struct_name,
mod_name: mod_name.clone(),
fn_name: self.gpu_node_mod.clone(),
struct_name: format_ident!("{}", self.gpu_node_mod.to_string().to_case(Case::Pascal)),
mod_name: self.gpu_node_mod.clone(),
fn_generics: vec![parse_quote!('a: 'n)],
where_clause: None,
input: Input {
pat_ident: parsed.input.pat_ident.clone(),
pat_ident: self.parsed.input.pat_ident.clone(),
ty: parse_quote!(impl #gcore::context::Ctx),
implementations: Default::default(),
},
@ -214,19 +243,22 @@ impl PerPixelAdjust {
is_async: true,
fields,
body,
crate_name: parsed.crate_name.clone(),
crate_name: self.parsed.crate_name.clone(),
description: "".to_string(),
};
parsed_node_fn.replace_impl_trait_in_input();
let gpu_node = crate::codegen::generate_node_code(&parsed_node_fn)?;
let gpu_node_impl = crate::codegen::generate_node_code(&parsed_node_fn)?;
// wrap node in `mod #gpu_node_mod`
let node_cfg = self.node_cfg;
let gpu_node_mod = &self.gpu_node_mod;
Ok(quote! {
#node_cfg
mod #mod_name {
mod #gpu_node_mod {
use super::*;
use wgpu_executor::WgpuExecutor;
#gpu_node
#gpu_node_impl
}
})
}
@ -234,7 +266,7 @@ impl PerPixelAdjust {
struct Param<'a> {
ident: Cow<'a, Ident>,
ty: Cow<'a, TokenStream>,
ty: TokenStream,
param_type: ParamType,
}