diff --git a/node-graph/node-macro/src/codegen.rs b/node-graph/node-macro/src/codegen.rs index da7f0b6a2..b3a1b28be 100644 --- a/node-graph/node-macro/src/codegen.rs +++ b/node-graph/node-macro/src/codegen.rs @@ -1,7 +1,7 @@ use crate::parsing::*; use convert_case::{Case, Casing}; use proc_macro_crate::FoundCrate; -use proc_macro2::{TokenStream as TokenStream2, TokenStream}; +use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, format_ident, quote, quote_spanned}; use std::sync::atomic::AtomicU64; use syn::punctuated::Punctuated; @@ -295,11 +295,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result, _>(|n| Ok((n.codegen_shader_entry_point(parsed)?, n.codegen_gpu_node(parsed)?))) - .unwrap_or(Ok((TokenStream::new(), TokenStream::new())))?; + let ShaderTokens { shader_entry_point, gpu_node } = attributes.shader_node.as_ref().map(|n| n.codegen(parsed, &cfg)).unwrap_or(Ok(ShaderTokens::default()))?; Ok(quote! { /// Underlying implementation for [#struct_name] @@ -393,7 +389,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result syn::Result; - fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result; +pub trait ShaderCodegen { + fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result; } -impl CodegenShaderEntryPoint for ShaderNodeType { - fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result { +impl ShaderCodegen for ShaderNodeType { + fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result { match self { ShaderNodeType::GpuNode => (), _ => { @@ -55,15 +54,14 @@ impl CodegenShaderEntryPoint for ShaderNodeType { } match self { - ShaderNodeType::GpuNode => Ok(TokenStream::new()), - ShaderNodeType::PerPixelAdjust(x) => x.codegen_shader_entry_point(parsed), - } - } - - fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result { - match self { - ShaderNodeType::GpuNode => Ok(TokenStream::new()), - ShaderNodeType::PerPixelAdjust(x) => x.codegen_gpu_node(parsed), + ShaderNodeType::GpuNode => Ok(ShaderTokens::default()), + ShaderNodeType::PerPixelAdjust(x) => x.codegen(parsed, node_cfg), } } } + +#[derive(Clone, Default)] +pub struct ShaderTokens { + pub shader_entry_point: TokenStream, + pub gpu_node: TokenStream, +} diff --git a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs index b263a2c15..e21e3cdfb 100644 --- a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs +++ b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs @@ -1,5 +1,5 @@ use crate::parsing::{Input, NodeFnAttributes, ParsedField, ParsedFieldType, ParsedNodeFn, RegularParsedField}; -use crate::shader_nodes::{CodegenShaderEntryPoint, ShaderNodeType}; +use crate::shader_nodes::{ShaderCodegen, ShaderNodeType, ShaderTokens}; use convert_case::{Case, Casing}; use proc_macro_crate::FoundCrate; use proc_macro2::{Ident, Span, TokenStream}; @@ -7,7 +7,7 @@ use quote::{ToTokens, format_ident, quote}; use std::borrow::Cow; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::{Path, Token, TraitBound, TraitBoundModifier, Type, TypeImplTrait, TypeParamBound}; +use syn::{Token, TraitBound, TraitBoundModifier, Type, TypeImplTrait, TypeParamBound}; #[derive(Debug, Clone)] pub struct PerPixelAdjust {} @@ -18,10 +18,19 @@ impl Parse for PerPixelAdjust { } } -impl CodegenShaderEntryPoint for PerPixelAdjust { +impl ShaderCodegen for PerPixelAdjust { + fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result { + Ok(ShaderTokens { + shader_entry_point: self.codegen_shader_entry_point(parsed)?, + gpu_node: self.codegen_gpu_node(parsed, node_cfg)?, + }) + } +} + +impl PerPixelAdjust { fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result { let fn_name = &parsed.fn_name; - let gpu_mod = format_ident!("{}_gpu_entry_point", parsed.fn_name); + let gpu_mod = format_ident!("{}_gpu_entry_point", fn_name); let spirv_image_ty = quote!(Image2d); // bindings for images start at 1 @@ -101,7 +110,7 @@ impl CodegenShaderEntryPoint for PerPixelAdjust { }) } - fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result { + fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result { let fn_name = format_ident!("{}_gpu", parsed.fn_name); let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal)); let mod_name = fn_name.clone(); @@ -127,13 +136,14 @@ impl CodegenShaderEntryPoint for PerPixelAdjust { ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")), }) .collect::>()?; + let body = quote! { { } }; - crate::codegen::generate_node_code(&ParsedNodeFn { + let gpu_node = crate::codegen::generate_node_code(&ParsedNodeFn { vis: parsed.vis.clone(), attributes: NodeFnAttributes { shader_node: Some(ShaderNodeType::GpuNode), @@ -141,7 +151,7 @@ impl CodegenShaderEntryPoint for PerPixelAdjust { }, fn_name, struct_name, - mod_name, + mod_name: mod_name.clone(), fn_generics: vec![], where_clause: None, input: Input { @@ -152,7 +162,7 @@ impl CodegenShaderEntryPoint for PerPixelAdjust { paren_token: None, modifier: TraitBoundModifier::None, lifetimes: None, - path: Path::from(format_ident!("Ctx")), + path: syn::parse2(quote!(#gcore::context::Ctx))?, })]), }), implementations: Default::default(), @@ -163,6 +173,15 @@ impl CodegenShaderEntryPoint for PerPixelAdjust { body, crate_name: parsed.crate_name.clone(), description: "".to_string(), + })?; + + Ok(quote! { + #node_cfg + mod #mod_name { + use super::*; + + #gpu_node + } }) } }