Shaders: rust-gpu integration to compile shader nodes to WGSL (#3097)
Some checks failed
Editor: Dev & CI / build (push) Has been cancelled
Editor: Dev & CI / cargo-deny (push) Has been cancelled

* shaders: shader compilation setup

* nix: use rustc_codegen_spirv.so from nix

* shaders: codegen for per_pixel_adjust shader nodes

* shaders: disable nodes needing bool

* shaders: `#[repr(u32)]` some enums

* shaders: add lint ignores from rust-gpu

* shaders: fix node-macro tests

* gcore-shaders: toml cleanup

* shader-nodes feature: put rust-gpu to wgsl compile behind feature gate

* shaders: fix use TokenStream2

* shaders: allow providing shader externally

* Update iai runner in workflow

---------

Co-authored-by: Timon Schelling <me@timon.zip>
Co-authored-by: Dennis Kobert <dennis@kobert.dev>
This commit is contained in:
Firestar99 2025-09-02 16:10:32 +02:00 committed by GitHub
parent 083dfa5f49
commit a10103311e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 717 additions and 20 deletions

View file

@ -7,7 +7,23 @@ authors = ["Graphite Authors <contact@graphite.rs>"]
license = "MIT OR Apache-2.0"
[features]
std = ["dep:dyn-any", "dep:serde", "dep:specta", "dep:log", "glam/debug-glam-assert", "glam/std", "glam/serde", "half/std", "half/serde", "num-traits/std"]
# any feature that
# * must be usable in shaders
# * but requires std
# * and should be on by default
# should be in this list instead of `[workspace.dependency]`
std = [
"dep:dyn-any",
"dep:serde",
"dep:specta",
"dep:log",
"glam/debug-glam-assert",
"glam/std",
"glam/serde",
"half/std",
"half/serde",
"num-traits/std"
]
[dependencies]
# Local std dependencies

View file

@ -3,6 +3,7 @@ use super::discrete_srgb::{float_to_srgb_u8, srgb_u8_to_float};
use bytemuck::{Pod, Zeroable};
use core::fmt::Debug;
use core::hash::Hash;
use glam::Vec4;
use half::f16;
#[cfg(not(feature = "std"))]
use num_traits::Euclid;
@ -1075,6 +1076,21 @@ impl Color {
..*self
}
}
#[inline(always)]
pub const fn from_vec4(vec: Vec4) -> Self {
Self {
red: vec.x,
green: vec.y,
blue: vec.z,
alpha: vec.w,
}
}
#[inline(always)]
pub fn to_vec4(&self) -> Vec4 {
Vec4::new(self.red, self.green, self.blue, self.alpha)
}
}
#[cfg(test)]

View file

@ -6,8 +6,18 @@ description = "graphene raster data format"
authors = ["Graphite Authors <contact@graphite.rs>"]
license = "MIT OR Apache-2.0"
[lib]
crate-type = ["rlib", "dylib"]
[lints]
workspace = true
[features]
default = ["std"]
shader-nodes = [
"std",
"dep:graphene-raster-nodes-shaders",
]
std = [
"dep:graphene-core",
"dep:dyn-any",
@ -19,8 +29,6 @@ std = [
"dep:serde",
"dep:specta",
"dep:kurbo",
"glam/debug-glam-assert",
"glam/serde",
]
[dependencies]
@ -31,10 +39,12 @@ node-macro = { workspace = true }
# Local std dependencies
dyn-any = { workspace = true, optional = true }
graphene-core = { workspace = true, optional = true }
graphene-raster-nodes-shaders = { path = "./shaders", optional = true }
# Workspace dependencies
bytemuck = { workspace = true }
glam = { workspace = true }
spirv-std = { workspace = true }
num-traits = { workspace = true }
# Workspace std dependencies

View file

@ -0,0 +1,13 @@
[package]
name = "graphene-raster-nodes-shaders"
version = "0.1.0"
edition = "2024"
description = "graphene raster data format"
authors = ["Graphite Authors <contact@graphite.rs>"]
license = "MIT OR Apache-2.0"
[dependencies]
[build-dependencies]
cargo-gpu = { workspace = true }
env_logger = { workspace = true }

View file

@ -0,0 +1,55 @@
use cargo_gpu::InstalledBackend;
use cargo_gpu::spirv_builder::{MetadataPrintout, SpirvMetadata};
use std::path::PathBuf;
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::builder().init();
// Skip building the shader if they are provided externally
println!("cargo:rerun-if-env-changed=GRAPHENE_RASTER_NODES_SHADER_PATH");
if !std::env::var("GRAPHENE_RASTER_NODES_SHADER_PATH").unwrap_or_default().is_empty() {
return Ok(());
}
// Allows overriding the PATH to inject the rust-gpu rust toolchain when building the rest of the project with stable rustc.
// Used in nix shell. Do not remove without checking with developers using nix.
println!("cargo:rerun-if-env-changed=RUST_GPU_PATH_OVERRIDE");
if let Ok(path_override) = std::env::var("RUST_GPU_PATH_OVERRIDE") {
let current_path = std::env::var("PATH").unwrap_or_default();
let new_path = format!("{path_override}:{current_path}");
// SAFETY: Build script is single-threaded therefore this cannot lead to undefined behavior.
unsafe {
std::env::set_var("PATH", &new_path);
}
}
let shader_crate = PathBuf::from(concat!(env!("CARGO_MANIFEST_DIR"), "/.."));
println!("cargo:rerun-if-env-changed=RUSTC_CODEGEN_SPIRV_PATH");
let rustc_codegen_spirv_path = std::env::var("RUSTC_CODEGEN_SPIRV_PATH").unwrap_or_default();
let backend = if rustc_codegen_spirv_path.is_empty() {
// install the toolchain and build the `rustc_codegen_spirv` codegen backend with it
cargo_gpu::Install::from_shader_crate(shader_crate.clone()).run()?
} else {
// use the `RUSTC_CODEGEN_SPIRV` environment variable to find the codegen backend
let mut backend = InstalledBackend::default();
backend.rustc_codegen_spirv_location = PathBuf::from(rustc_codegen_spirv_path);
backend.toolchain_channel = "nightly".to_string();
backend.target_spec_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
backend
};
// build the shader crate
let mut builder = backend.to_spirv_builder(shader_crate, "spirv-unknown-naga-wgsl");
builder.print_metadata = MetadataPrintout::DependencyOnly;
builder.spirv_metadata = SpirvMetadata::Full;
builder.shader_crate_features.default_features = false;
let wgsl_result = builder.build()?;
let path_to_spv = wgsl_result.module.unwrap_single();
// needs to be fixed upstream
let path_to_wgsl = path_to_spv.with_extension("wgsl");
println!("cargo::rustc-env=GRAPHENE_RASTER_NODES_SHADER_PATH={}", path_to_wgsl.display());
Ok(())
}

View file

@ -0,0 +1,26 @@
{
"allows-weak-linkage": false,
"arch": "spirv",
"crt-objects-fallback": "false",
"crt-static-allows-dylibs": true,
"crt-static-respected": true,
"data-layout": "e-m:e-p:32:32:32-i64:64-n8:16:32:64",
"dll-prefix": "",
"dll-suffix": ".spv.json",
"dynamic-linking": true,
"emit-debug-gdb-scripts": false,
"env": "naga-wgsl",
"linker-flavor": "unix",
"linker-is-gnu": false,
"llvm-target": "spirv-unknown-naga-wgsl",
"main-needs-argc-argv": false,
"metadata": {
"description": null,
"host_tools": null,
"std": null,
"tier": null
},
"panic-strategy": "abort",
"simd-types-indirect": false,
"target-pointer-width": "32"
}

View file

@ -0,0 +1 @@
pub const WGSL_SHADER: &str = include_str!(env!("GRAPHENE_RASTER_NODES_SHADER_PATH"));

View file

@ -33,6 +33,7 @@ use num_traits::float::Float;
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash, node_macro::ChoiceType)]
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
#[widget(Dropdown)]
#[repr(u32)]
pub enum LuminanceCalculation {
#[default]
#[label("sRGB")]
@ -52,6 +53,7 @@ fn luminance<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
luminance_calc: LuminanceCalculation,
) -> T {
@ -68,7 +70,7 @@ fn luminance<T: Adjust<Color>>(
input
}
#[node_macro::node(category("Raster"), shader_node(PerPixelAdjust))]
#[node_macro::node(category("Raster"), cfg(feature = "std"))]
fn gamma_correction<T: Adjust<Color>>(
_: impl Ctx,
#[implementations(
@ -77,6 +79,7 @@ fn gamma_correction<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
#[default(2.2)]
#[range((0.01, 10.))]
@ -98,6 +101,7 @@ fn extract_channel<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
channel: RedGreenBlueAlpha,
) -> T {
@ -122,6 +126,7 @@ fn make_opaque<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
) -> T {
input.adjust(|color| {
@ -139,7 +144,7 @@ fn make_opaque<T: Adjust<Color>>(
//
// Some further analysis available at:
// https://geraldbakker.nl/psnumbers/brightness-contrast.html
#[node_macro::node(name("Brightness/Contrast"), category("Raster: Adjustment"), properties("brightness_contrast_properties"), shader_node(PerPixelAdjust))]
#[node_macro::node(name("Brightness/Contrast"), category("Raster: Adjustment"), properties("brightness_contrast_properties"), cfg(feature = "std"))]
fn brightness_contrast<T: Adjust<Color>>(
_: impl Ctx,
#[implementations(
@ -148,6 +153,7 @@ fn brightness_contrast<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
brightness: SignedPercentageF32,
contrast: SignedPercentageF32,
@ -238,6 +244,7 @@ fn levels<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut image: T,
#[default(0.)] shadows: PercentageF32,
#[default(50.)] midtones: PercentageF32,
@ -306,6 +313,7 @@ fn black_and_white<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut image: T,
#[default(Color::BLACK)] tint: Color,
#[default(40.)]
@ -379,6 +387,7 @@ fn hue_saturation<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
hue_shift: AngleF32,
saturation_shift: SignedPercentageF32,
@ -414,6 +423,7 @@ fn invert<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
) -> T {
input.adjust(|color| {
@ -437,6 +447,7 @@ fn threshold<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut image: T,
#[default(50.)] min_luminance: PercentageF32,
#[default(100.)] max_luminance: PercentageF32,
@ -483,6 +494,7 @@ fn vibrance<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut image: T,
vibrance: SignedPercentageF32,
) -> T {
@ -551,6 +563,7 @@ pub enum RedGreenBlue {
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, node_macro::ChoiceType)]
#[cfg_attr(feature = "std", derive(dyn_any::DynAny, specta::Type, serde::Serialize, serde::Deserialize))]
#[widget(Radio)]
#[repr(u32)]
pub enum RedGreenBlueAlpha {
#[default]
Red,
@ -640,7 +653,7 @@ pub enum DomainWarpType {
// Aims for interoperable compatibility with:
// https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=%27mixr%27%20%3D%20Channel%20Mixer
// https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#:~:text=Lab%20color%20only-,Channel%20Mixer,-Key%20is%20%27mixr
#[node_macro::node(category("Raster: Adjustment"), properties("channel_mixer_properties"), shader_node(PerPixelAdjust))]
#[node_macro::node(category("Raster: Adjustment"), properties("channel_mixer_properties"), cfg(feature = "std"))]
fn channel_mixer<T: Adjust<Color>>(
_: impl Ctx,
#[implementations(
@ -649,6 +662,7 @@ fn channel_mixer<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut image: T,
monochrome: bool,
@ -769,7 +783,7 @@ pub enum SelectiveColorChoice {
//
// Algorithm based on:
// https://blog.pkh.me/p/22-understanding-selective-coloring-in-adobe-photoshop.html
#[node_macro::node(category("Raster: Adjustment"), properties("selective_color_properties"), shader_node(PerPixelAdjust))]
#[node_macro::node(category("Raster: Adjustment"), properties("selective_color_properties"), cfg(feature = "std"))]
fn selective_color<T: Adjust<Color>>(
_: impl Ctx,
#[implementations(
@ -778,6 +792,7 @@ fn selective_color<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut image: T,
mode: RelativeAbsolute,
@ -921,6 +936,7 @@ fn posterize<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
#[default(4)]
#[hard_min(2.)]
@ -955,6 +971,7 @@ fn exposure<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut input: T,
exposure: f32,
offset: f32,

View file

@ -141,6 +141,7 @@ fn blend<T: Blend<Color> + Send>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
over: T,
#[expose]
#[implementations(
@ -149,6 +150,7 @@ fn blend<T: Blend<Color> + Send>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
under: T,
blend_mode: BlendMode,
#[default(100.)] opacity: PercentageF32,
@ -165,6 +167,7 @@ fn color_overlay<T: Adjust<Color>>(
Table<GradientStops>,
GradientStops,
)]
#[gpu_image]
mut image: T,
#[default(Color::BLACK)] color: Color,
blend_mode: BlendMode,

View file

@ -21,6 +21,7 @@ image-compare = []
vello = ["dep:vello", "gpu"]
resvg = []
wayland = ["graph-craft/wayland"]
shader-nodes = ["graphene-raster-nodes/shader-nodes"]
[dependencies]
# Local dependencies

View file

@ -2,8 +2,9 @@ mod benchmark_util;
use benchmark_util::setup_network;
use graphene_std::application_io::RenderConfig;
use iai_callgrind::{black_box, library_benchmark, library_benchmark_group, main};
use iai_callgrind::{library_benchmark, library_benchmark_group, main};
use interpreted_executor::dynamic_executor::DynamicExecutor;
use std::hint::black_box;
fn setup_run_cached(name: &str) -> DynamicExecutor {
let (executor, _) = setup_network(name);

View file

@ -2,8 +2,9 @@ mod benchmark_util;
use benchmark_util::setup_network;
use graphene_std::application_io;
use iai_callgrind::{black_box, library_benchmark, library_benchmark_group, main};
use iai_callgrind::{library_benchmark, library_benchmark_group, main};
use interpreted_executor::dynamic_executor::DynamicExecutor;
use std::hint::black_box;
fn setup_run_once(name: &str) -> DynamicExecutor {
let (executor, _) = setup_network(name);

View file

@ -2,8 +2,9 @@ mod benchmark_util;
use benchmark_util::setup_network;
use graph_craft::proto::ProtoNetwork;
use iai_callgrind::{black_box, library_benchmark, library_benchmark_group, main};
use iai_callgrind::{library_benchmark, library_benchmark_group, main};
use interpreted_executor::dynamic_executor::DynamicExecutor;
use std::hint::black_box;
fn setup_update_executor(name: &str) -> (DynamicExecutor, ProtoNetwork) {
let (_, proto_network) = setup_network(name);

View file

@ -295,6 +295,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let cfg = crate::shader_nodes::modify_cfg(attributes);
let node_input_accessor = generate_node_input_references(parsed, fn_generics, &field_idents, &graphene_core, &identifier, &cfg);
let shader_entry_point = attributes.shader_node.as_ref().map(|n| n.codegen_shader_entry_point(parsed)).unwrap_or(Ok(TokenStream2::new()))?;
Ok(quote! {
/// Underlying implementation for [#struct_name]
#[inline]
@ -384,6 +385,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
NODE_METADATA.lock().unwrap().insert(#identifier(), metadata);
}
}
#shader_entry_point
})
}
@ -586,6 +589,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
})
}
use crate::shader_nodes::CodegenShaderEntryPoint;
use syn::visit_mut::VisitMut;
use syn::{GenericArgument, Lifetime, Type};

View file

@ -120,6 +120,8 @@ pub enum ParsedFieldType {
Node(NodeParsedField),
}
/// a param of any kind, either a concrete type or a generic type with a set of possible types specified via
/// `#[implementation(type)]`
#[derive(Clone, Debug)]
pub struct RegularParsedField {
pub ty: Type,
@ -131,8 +133,10 @@ pub struct RegularParsedField {
pub number_hard_max: Option<LitFloat>,
pub number_mode_range: Option<ExprTuple>,
pub implementations: Punctuated<Type, Comma>,
pub gpu_image: bool,
}
/// a param of `impl Node` with `#[implementation(in -> out)]`
#[derive(Clone, Debug)]
pub struct NodeParsedField {
pub input_type: Type,
@ -529,6 +533,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
.map_err(|e| Error::new_spanned(attr, format!("Invalid `step` for argument '{ident}': {e}\nUSAGE EXAMPLE: #[step(2.)]")))
})
.transpose()?;
let gpu_image = extract_attribute(attrs, "gpu_image").is_some();
let (is_node, node_input_type, node_output_type) = parse_node_type(&ty);
let description = attrs
@ -590,6 +595,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
ty,
value_source,
implementations,
gpu_image,
}),
name,
description,
@ -829,6 +835,7 @@ mod tests {
number_hard_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
gpu_image: false,
}),
number_display_decimal_places: None,
number_step: None,
@ -909,6 +916,7 @@ mod tests {
number_hard_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
gpu_image: false,
}),
number_display_decimal_places: None,
number_step: None,
@ -972,6 +980,7 @@ mod tests {
number_hard_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
gpu_image: false,
}),
number_display_decimal_places: None,
number_step: None,
@ -1038,6 +1047,7 @@ mod tests {
p.push(parse_quote!(f64));
p
},
gpu_image: false,
}),
number_display_decimal_places: None,
number_step: None,
@ -1106,6 +1116,7 @@ mod tests {
number_hard_max: None,
number_mode_range: Some(parse_quote!((0., 100.))),
implementations: Punctuated::new(),
gpu_image: false,
}),
number_display_decimal_places: None,
number_step: None,
@ -1167,6 +1178,7 @@ mod tests {
number_hard_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
gpu_image: false,
}),
number_display_decimal_places: None,
number_step: None,

View file

@ -1,10 +1,13 @@
use crate::parsing::NodeFnAttributes;
use crate::parsing::{NodeFnAttributes, ParsedNodeFn};
use crate::shader_nodes::per_pixel_adjust::PerPixelAdjust;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use strum::{EnumString, VariantNames};
use strum::VariantNames;
use syn::Error;
use syn::parse::{Parse, ParseStream};
pub mod per_pixel_adjust;
pub const STD_FEATURE_GATE: &str = "std";
pub fn modify_cfg(attributes: &NodeFnAttributes) -> TokenStream {
@ -16,17 +19,33 @@ pub fn modify_cfg(attributes: &NodeFnAttributes) -> TokenStream {
}
}
#[derive(Debug, EnumString, VariantNames)]
#[derive(Debug, VariantNames)]
pub(crate) enum ShaderNodeType {
PerPixelAdjust,
PerPixelAdjust(PerPixelAdjust),
}
impl Parse for ShaderNodeType {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ident: Ident = input.parse()?;
Ok(match ident.to_string().as_str() {
"PerPixelAdjust" => ShaderNodeType::PerPixelAdjust,
"PerPixelAdjust" => ShaderNodeType::PerPixelAdjust(PerPixelAdjust::parse(input)?),
_ => return Err(Error::new_spanned(&ident, format!("attr 'shader_node' must be one of {:?}", Self::VARIANTS))),
})
}
}
pub trait CodegenShaderEntryPoint {
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream>;
}
impl CodegenShaderEntryPoint for ShaderNodeType {
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream> {
if parsed.is_async {
return Err(Error::new_spanned(&parsed.fn_name, "Shader nodes must not be async"));
}
match self {
ShaderNodeType::PerPixelAdjust(x) => x.codegen_shader_entry_point(parsed),
}
}
}

View file

@ -0,0 +1,110 @@
use crate::parsing::{ParsedFieldType, ParsedNodeFn, RegularParsedField};
use crate::shader_nodes::CodegenShaderEntryPoint;
use proc_macro2::{Ident, TokenStream};
use quote::{ToTokens, format_ident, quote};
use std::borrow::Cow;
use syn::parse::{Parse, ParseStream};
#[derive(Debug)]
pub struct PerPixelAdjust {}
impl Parse for PerPixelAdjust {
fn parse(_input: ParseStream) -> syn::Result<Self> {
Ok(Self {})
}
}
impl CodegenShaderEntryPoint for PerPixelAdjust {
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream> {
let fn_name = &parsed.fn_name;
let gpu_mod = format_ident!("{}_gpu", parsed.fn_name);
let spirv_image_ty = quote!(Image2d);
// bindings for images start at 1
let mut binding_cnt = 0;
let params = parsed
.fields
.iter()
.map(|f| {
let ident = &f.pat_ident;
match &f.ty {
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, ty, .. }) => Ok(Param {
ident: Cow::Borrowed(&ident.ident),
ty: Cow::Owned(ty.to_token_stream()),
param_type: ParamType::Uniform,
}),
ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => {
binding_cnt += 1;
Ok(Param {
ident: Cow::Owned(format_ident!("image_{}", &ident.ident)),
ty: Cow::Borrowed(&spirv_image_ty),
param_type: ParamType::Image { binding: binding_cnt },
})
}
}
})
.collect::<syn::Result<Vec<_>>>()?;
let uniform_members = params
.iter()
.filter_map(|Param { ident, ty, param_type }| match param_type {
ParamType::Image { .. } => None,
ParamType::Uniform => Some(quote! {#ident: #ty}),
})
.collect::<Vec<_>>();
let image_params = params
.iter()
.filter_map(|Param { ident, ty, param_type }| match param_type {
ParamType::Image { binding } => Some(quote! {#[spirv(descriptor_set = 0, binding = #binding)] #ident: &#ty}),
ParamType::Uniform => None,
})
.collect::<Vec<_>>();
let call_args = params
.iter()
.map(|Param { ident, param_type, .. }| match param_type {
ParamType::Image { .. } => quote!(Color::from_vec4(#ident.fetch_with(texel_coord, lod(0)))),
ParamType::Uniform => quote!(uniform.#ident),
})
.collect::<Vec<_>>();
let context = quote!(());
Ok(quote! {
pub mod #gpu_mod {
use super::*;
use graphene_core_shaders::color::Color;
use spirv_std::spirv;
use spirv_std::glam::{Vec4, Vec4Swizzles};
use spirv_std::image::{Image2d, ImageWithMethods};
use spirv_std::image::sample_with::lod;
pub struct Uniform {
#(#uniform_members),*
}
#[spirv(fragment)]
pub fn entry_point(
#[spirv(frag_coord)] frag_coord: Vec4,
color_out: &mut Vec4,
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] uniform: &Uniform,
#(#image_params),*
) {
let texel_coord = frag_coord.xy().as_uvec2();
let color: Color = #fn_name(#context, #(#call_args),*);
*color_out = color.to_vec4();
}
}
})
}
}
struct Param<'a> {
ident: Cow<'a, Ident>,
ty: Cow<'a, TokenStream>,
param_type: ParamType,
}
enum ParamType {
Image { binding: u32 },
Uniform,
}