diff --git a/node-graph/node-macro/src/codegen.rs b/node-graph/node-macro/src/codegen.rs index 8039fdf27..ced865e17 100644 --- a/node-graph/node-macro/src/codegen.rs +++ b/node-graph/node-macro/src/codegen.rs @@ -41,19 +41,12 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect(); let input_ident = &input.pat_ident; - let field_idents: Vec<_> = fields - .iter() - .map(|field| match field { - ParsedField::Regular { pat_ident, .. } | ParsedField::Node { pat_ident, .. } => pat_ident, - }) - .collect(); + let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect(); let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect(); let input_names: Vec<_> = fields .iter() - .map(|field| match field { - ParsedField::Regular { name, .. } | ParsedField::Node { name, .. } => name, - }) + .map(|f| &f.name) .zip(field_names.iter()) .map(|zipped| match zipped { (Some(name), _) => name.value(), @@ -61,12 +54,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields - .iter() - .map(|field| match field { - ParsedField::Regular { description, .. } | ParsedField::Node { description, .. } => description, - }) - .collect(); + let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect(); let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| { quote! { pub(super) #name: #r#gen } @@ -84,9 +72,9 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields .iter() - .map(|field| match field { - ParsedField::Regular { ty, .. } => ty.clone(), - ParsedField::Node { output_type, input_type, .. } => match parsed.is_async { + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(), + ParsedFieldType::Node(NodeParsedField { output_type, input_type, .. }) => match parsed.is_async { true => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = impl core::future::Future>), false => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = #output_type>), }, @@ -95,24 +83,18 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields .iter() - .map(|field| { - let parsed_widget_override = match field { - ParsedField::Regular { widget_override, .. } => widget_override, - ParsedField::Node { widget_override, .. } => widget_override, - }; - match parsed_widget_override { - ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None), - ParsedWidgetOverride::Hidden => quote!(RegistryWidgetOverride::Hidden), - ParsedWidgetOverride::String(lit_str) => quote!(RegistryWidgetOverride::String(#lit_str)), - ParsedWidgetOverride::Custom(lit_str) => quote!(RegistryWidgetOverride::Custom(#lit_str)), - } + .map(|field| match &field.widget_override { + ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None), + ParsedWidgetOverride::Hidden => quote!(RegistryWidgetOverride::Hidden), + ParsedWidgetOverride::String(lit_str) => quote!(RegistryWidgetOverride::String(#lit_str)), + ParsedWidgetOverride::Custom(lit_str) => quote!(RegistryWidgetOverride::Custom(#lit_str)), }) .collect(); let value_sources: Vec<_> = fields .iter() - .map(|field| match field { - ParsedField::Regular { value_source, .. } => match value_source { + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source { ParsedValueSource::Default(data) => quote!(RegistryValueSource::Default(stringify!(#data))), ParsedValueSource::Scope(data) => quote!(RegistryValueSource::Scope(#data)), _ => quote!(RegistryValueSource::None), @@ -123,8 +105,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields .iter() - .map(|field| match field { - ParsedField::Regular { implementations, .. } => match implementations.first() { + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() { Some(ty) => quote!(Some(concrete!(#ty))), _ => quote!(None), }, @@ -134,8 +116,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields .iter() - .map(|field| match field { - ParsedField::Regular { number_soft_min, number_hard_min, .. } => match (number_soft_min, number_hard_min) { + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) { (Some(soft_min), _) => quote!(Some(#soft_min)), (None, Some(hard_min)) => quote!(Some(#hard_min)), (None, None) => quote!(None), @@ -145,8 +127,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields .iter() - .map(|field| match field { - ParsedField::Regular { number_soft_max, number_hard_max, .. } => match (number_soft_max, number_hard_max) { + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) { (Some(soft_max), _) => quote!(Some(#soft_max)), (None, Some(hard_max)) => quote!(Some(#hard_max)), (None, None) => quote!(None), @@ -156,77 +138,45 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result = fields .iter() - .map(|field| match field { - ParsedField::Regular { + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { number_mode_range: Some(number_mode_range), .. - } => quote!(Some(#number_mode_range)), + }) => quote!(Some(#number_mode_range)), _ => quote!(None), }) .collect(); let number_display_decimal_places: Vec<_> = fields .iter() - .map(|field| match field { - ParsedField::Regular { - number_display_decimal_places: Some(decimal_places), - .. - } - | ParsedField::Node { - number_display_decimal_places: Some(decimal_places), - .. - } => { - quote!(Some(#decimal_places)) - } - _ => quote!(None), - }) - .collect(); - let number_step: Vec<_> = fields - .iter() - .map(|field| match field { - ParsedField::Regular { number_step: Some(step), .. } | ParsedField::Node { number_step: Some(step), .. } => { - quote!(Some(#step)) - } - _ => quote!(None), - }) + .map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))) .collect(); + let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); - let unit_suffix: Vec<_> = fields - .iter() - .map(|field| match field { - ParsedField::Regular { unit: Some(unit), .. } | ParsedField::Node { unit: Some(unit), .. } => { - quote!(Some(#unit)) - } - _ => quote!(None), - }) - .collect(); + let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); let exposed: Vec<_> = fields .iter() - .map(|field| match field { - ParsedField::Regular { exposed, .. } => quote!(#exposed), + .map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed), _ => quote!(true), }) .collect(); - let eval_args = fields.iter().map(|field| match field { - ParsedField::Regular { pat_ident, .. } => { - let name = &pat_ident.ident; - quote! { let #name = self.#name.eval(__input.clone()).await; } - } - ParsedField::Node { pat_ident, .. } => { - let name = &pat_ident.ident; - quote! { let #name = &self.#name; } + let eval_args = fields.iter().map(|field| { + let name = &field.pat_ident.ident; + match &field.ty { + ParsedFieldType::Regular { .. } => { + quote! { let #name = self.#name.eval(__input.clone()).await; } + } + ParsedFieldType::Node { .. } => { + quote! { let #name = &self.#name; } + } } }); - let min_max_args = fields.iter().map(|field| match field { - ParsedField::Regular { - pat_ident, - number_hard_min, - number_hard_max, - .. - } => { - let name = &pat_ident.ident; + let min_max_args = fields.iter().map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => { + let name = &field.pat_ident.ident; let mut tokens = quote!(); if let Some(min) = number_hard_min { tokens.extend(quote_spanned! {min.span()=> @@ -241,15 +191,13 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result { - quote!() - } + ParsedFieldType::Node { .. } => quote!(), }); - let all_implementation_types = fields.iter().flat_map(|field| match field { - ParsedField::Regular { implementations, .. } => implementations.into_iter().cloned().collect::>(), - ParsedField::Node { implementations, .. } => implementations - .into_iter() + let all_implementation_types = fields.iter().flat_map(|field| match &field.ty { + ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => implementations.iter().cloned().collect::>(), + ParsedFieldType::Node(NodeParsedField { implementations, .. }) => implementations + .iter() .flat_map(|implementation| [implementation.input.clone(), implementation.output.clone()]) .collect(), }); @@ -260,11 +208,11 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result { let all_lifetime_ty = substitute_lifetimes(ty.clone(), "all"); @@ -284,7 +232,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result + #graphene_core::WasmNotSync ) } - (ParsedField::Node { input_type, output_type, .. }, true) => { + (ParsedFieldType::Node(NodeParsedField { input_type, output_type, .. }), true) => { let id = future_idents.len(); let fut_ident = format_ident!("F{}", id); future_idents.push(fut_ident.clone()); @@ -294,7 +242,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result + #graphene_core::WasmNotSync ) } - (ParsedField::Node { .. }, false) => unreachable!(), + (ParsedFieldType::Node { .. }, false) => unreachable!(), }); } let where_clause = where_clause.clone().unwrap_or(WhereClause { @@ -454,9 +402,9 @@ fn generate_node_input_references( let (mut modified, mut generic_collector) = FilterUsedGenerics::new(fn_generics); for (input_index, (parsed_input, input_ident)) in parsed.fields.iter().zip(field_idents).enumerate() { - let mut ty = match parsed_input { - ParsedField::Regular { ty, .. } => ty, - ParsedField::Node { output_type, .. } => output_type, + let mut ty = match &parsed_input.ty { + ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty, + ParsedFieldType::Node(NodeParsedField { output_type, .. }) => output_type, } .clone(); @@ -540,20 +488,20 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st .fields .iter() .map(|field| { - match field { - ParsedField::Regular { implementations, ty, .. } => { + match &field.ty { + ParsedFieldType::Regular(RegularParsedField { implementations, ty, .. }) => { if !implementations.is_empty() { implementations.iter().map(|ty| (&unit, ty)).collect() } else { vec![(&unit, ty)] } } - ParsedField::Node { + ParsedFieldType::Node(NodeParsedField { implementations, input_type, output_type, .. - } => { + }) => { if !implementations.is_empty() { implementations.iter().map(|impl_| (&impl_.input, &impl_.output)).collect() } else { @@ -578,7 +526,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st let field_name = field_names[j]; let (input_type, output_type) = &types[i.min(types.len() - 1)]; - let node = matches!(parsed.fields[j], ParsedField::Node { .. }); + let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. }); let downcast_node = quote!( let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone()); diff --git a/node-graph/node-macro/src/parsing.rs b/node-graph/node-macro/src/parsing.rs index 9d5122bb2..f01494263 100644 --- a/node-graph/node-macro/src/parsing.rs +++ b/node-graph/node-macro/src/parsing.rs @@ -14,7 +14,7 @@ use syn::{ use crate::codegen::generate_node_code; use crate::shader_nodes::ShaderNodeType; -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) struct Implementation { pub(crate) input: Type, pub(crate) _arrow: RArrow, @@ -53,7 +53,7 @@ pub(crate) struct NodeFnAttributes { // Add more attributes as needed } -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub enum ParsedValueSource { #[default] None, @@ -64,7 +64,7 @@ pub enum ParsedValueSource { // #[widget(ParsedWidgetOverride::Hidden)] // #[widget(ParsedWidgetOverride::String = "Some string")] // #[widget(ParsedWidgetOverride::Custom = "Custom string")] -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub enum ParsedWidgetOverride { #[default] None, @@ -102,39 +102,44 @@ impl Parse for ParsedWidgetOverride { } } -#[derive(Debug)] -pub(crate) enum ParsedField { - Regular { - pat_ident: PatIdent, - name: Option, - description: String, - widget_override: ParsedWidgetOverride, - ty: Type, - exposed: bool, - value_source: ParsedValueSource, - number_soft_min: Option, - number_soft_max: Option, - number_hard_min: Option, - number_hard_max: Option, - number_mode_range: Option, - number_display_decimal_places: Option, - number_step: Option, - implementations: Punctuated, - unit: Option, - }, - Node { - pat_ident: PatIdent, - name: Option, - description: String, - widget_override: ParsedWidgetOverride, - input_type: Type, - output_type: Type, - number_display_decimal_places: Option, - number_step: Option, - implementations: Punctuated, - unit: Option, - }, +#[derive(Clone, Debug)] +pub struct ParsedField { + pub pat_ident: PatIdent, + pub name: Option, + pub description: String, + pub widget_override: ParsedWidgetOverride, + pub ty: ParsedFieldType, + pub number_display_decimal_places: Option, + pub number_step: Option, + pub unit: Option, } + +#[derive(Clone, Debug)] +pub enum ParsedFieldType { + Regular(RegularParsedField), + Node(NodeParsedField), +} + +#[derive(Clone, Debug)] +pub struct RegularParsedField { + pub ty: Type, + pub exposed: bool, + pub value_source: ParsedValueSource, + pub number_soft_min: Option, + pub number_soft_max: Option, + pub number_hard_min: Option, + pub number_hard_max: Option, + pub number_mode_range: Option, + pub implementations: Punctuated, +} + +#[derive(Clone, Debug)] +pub struct NodeParsedField { + pub input_type: Type, + pub output_type: Type, + pub implementations: Punctuated, +} + #[derive(Debug)] pub(crate) struct Input { pub(crate) pat_ident: PatIdent, @@ -563,16 +568,18 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul .transpose()? .unwrap_or_default(); - Ok(ParsedField::Node { + Ok(ParsedField { pat_ident, + ty: ParsedFieldType::Node(NodeParsedField { + input_type, + output_type, + implementations, + }), name, description, widget_override, - input_type, - output_type, number_display_decimal_places, number_step, - implementations, unit, }) } else { @@ -580,22 +587,24 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul .map(|attr| parse_implementations(attr, ident)) .transpose()? .unwrap_or_default(); - Ok(ParsedField::Regular { + Ok(ParsedField { pat_ident, + ty: ParsedFieldType::Regular(RegularParsedField { + exposed, + number_soft_min, + number_soft_max, + number_hard_min, + number_hard_max, + number_mode_range, + ty, + value_source, + implementations, + }), name, description, widget_override, - exposed, - number_soft_min, - number_soft_max, - number_hard_min, - number_hard_max, - number_mode_range, number_display_decimal_places, number_step, - ty, - value_source, - implementations, unit, }) } @@ -715,18 +724,24 @@ mod tests { for (parsed_field, expected_field) in parsed.fields.iter().zip(expected.fields.iter()) { match (parsed_field, expected_field) { ( - ParsedField::Regular { + ParsedField { pat_ident: p_name, - ty: p_ty, - exposed: p_exp, - value_source: p_default, + ty: ParsedFieldType::Regular(RegularParsedField { + ty: p_ty, + exposed: p_exp, + value_source: p_default, + .. + }), .. }, - ParsedField::Regular { + ParsedField { pat_ident: e_name, - ty: e_ty, - exposed: e_exp, - value_source: e_default, + ty: ParsedFieldType::Regular(RegularParsedField { + ty: e_ty, + exposed: e_exp, + value_source: e_default, + .. + }), .. }, ) => { @@ -745,16 +760,22 @@ mod tests { assert_eq!(format!("{:?}", p_ty), format!("{:?}", e_ty)); } ( - ParsedField::Node { + ParsedField { pat_ident: p_name, - input_type: p_input, - output_type: p_output, + ty: ParsedFieldType::Node(NodeParsedField { + input_type: p_input, + output_type: p_output, + .. + }), .. }, - ParsedField::Node { + ParsedField { pat_ident: e_name, - input_type: e_input, - output_type: e_output, + ty: ParsedFieldType::Node(NodeParsedField { + input_type: e_input, + output_type: e_output, + .. + }), .. }, ) => { @@ -802,22 +823,24 @@ mod tests { }, output_type: parse_quote!(f64), is_async: false, - fields: vec![ParsedField::Regular { + fields: vec![ParsedField { pat_ident: pat_ident("b"), name: None, description: String::new(), widget_override: ParsedWidgetOverride::None, - ty: parse_quote!(f64), - exposed: false, - value_source: ParsedValueSource::None, - number_soft_min: None, - number_soft_max: None, - number_hard_min: None, - number_hard_max: None, - number_mode_range: None, + ty: ParsedFieldType::Regular(RegularParsedField { + ty: parse_quote!(f64), + exposed: false, + value_source: ParsedValueSource::None, + number_soft_min: None, + number_soft_max: None, + number_hard_min: None, + number_hard_max: None, + number_mode_range: None, + implementations: Punctuated::new(), + }), number_display_decimal_places: None, number_step: None, - implementations: Punctuated::new(), unit: None, }], body: TokenStream2::new(), @@ -866,34 +889,38 @@ mod tests { output_type: parse_quote!(T), is_async: false, fields: vec![ - ParsedField::Node { + ParsedField { pat_ident: pat_ident("transform_target"), name: None, description: String::new(), widget_override: ParsedWidgetOverride::None, - input_type: parse_quote!(Footprint), - output_type: parse_quote!(T), + ty: ParsedFieldType::Node(NodeParsedField { + input_type: parse_quote!(Footprint), + output_type: parse_quote!(T), + implementations: Punctuated::new(), + }), number_display_decimal_places: None, number_step: None, - implementations: Punctuated::new(), unit: None, }, - ParsedField::Regular { + ParsedField { pat_ident: pat_ident("translate"), name: None, description: String::new(), widget_override: ParsedWidgetOverride::None, - ty: parse_quote!(DVec2), - exposed: false, - value_source: ParsedValueSource::None, - number_soft_min: None, - number_soft_max: None, - number_hard_min: None, - number_hard_max: None, - number_mode_range: None, + ty: ParsedFieldType::Regular(RegularParsedField { + ty: parse_quote!(DVec2), + exposed: false, + value_source: ParsedValueSource::None, + number_soft_min: None, + number_soft_max: None, + number_hard_min: None, + number_hard_max: None, + number_mode_range: None, + implementations: Punctuated::new(), + }), number_display_decimal_places: None, number_step: None, - implementations: Punctuated::new(), unit: None, }, ], @@ -939,22 +966,24 @@ mod tests { }, output_type: parse_quote!(Vector), is_async: false, - fields: vec![ParsedField::Regular { + fields: vec![ParsedField { pat_ident: pat_ident("radius"), name: None, description: String::new(), widget_override: ParsedWidgetOverride::None, - ty: parse_quote!(f64), - exposed: false, - value_source: ParsedValueSource::Default(quote!(50.)), - number_soft_min: None, - number_soft_max: None, - number_hard_min: None, - number_hard_max: None, - number_mode_range: None, + ty: ParsedFieldType::Regular(RegularParsedField { + ty: parse_quote!(f64), + exposed: false, + value_source: ParsedValueSource::Default(quote!(50.)), + number_soft_min: None, + number_soft_max: None, + number_hard_min: None, + number_hard_max: None, + number_mode_range: None, + implementations: Punctuated::new(), + }), number_display_decimal_places: None, number_step: None, - implementations: Punctuated::new(), unit: None, }], body: TokenStream2::new(), @@ -998,27 +1027,29 @@ mod tests { }, output_type: parse_quote!(Table>), is_async: false, - fields: vec![ParsedField::Regular { + fields: vec![ParsedField { pat_ident: pat_ident("shadows"), name: None, description: String::new(), widget_override: ParsedWidgetOverride::None, - ty: parse_quote!(f64), - exposed: false, - value_source: ParsedValueSource::None, - number_soft_min: None, - number_soft_max: None, - number_hard_min: None, - number_hard_max: None, - number_mode_range: None, + ty: ParsedFieldType::Regular(RegularParsedField { + ty: parse_quote!(f64), + exposed: false, + value_source: ParsedValueSource::None, + number_soft_min: None, + number_soft_max: None, + number_hard_min: None, + number_hard_max: None, + number_mode_range: None, + implementations: { + let mut p = Punctuated::new(); + p.push(parse_quote!(f32)); + p.push(parse_quote!(f64)); + p + }, + }), number_display_decimal_places: None, number_step: None, - implementations: { - let mut p = Punctuated::new(); - p.push(parse_quote!(f32)); - p.push(parse_quote!(f64)); - p - }, unit: None, }], body: TokenStream2::new(), @@ -1069,22 +1100,24 @@ mod tests { }, output_type: parse_quote!(f64), is_async: false, - fields: vec![ParsedField::Regular { + fields: vec![ParsedField { pat_ident: pat_ident("b"), name: None, description: String::from("b"), widget_override: ParsedWidgetOverride::None, - ty: parse_quote!(f64), - exposed: false, - value_source: ParsedValueSource::None, - number_soft_min: Some(parse_quote!(-500.)), - number_soft_max: Some(parse_quote!(500.)), - number_hard_min: None, - number_hard_max: None, - number_mode_range: Some(parse_quote!((0., 100.))), + ty: ParsedFieldType::Regular(RegularParsedField { + ty: parse_quote!(f64), + exposed: false, + value_source: ParsedValueSource::None, + number_soft_min: Some(parse_quote!(-500.)), + number_soft_max: Some(parse_quote!(500.)), + number_hard_min: None, + number_hard_max: None, + number_mode_range: Some(parse_quote!((0., 100.))), + implementations: Punctuated::new(), + }), number_display_decimal_places: None, number_step: None, - implementations: Punctuated::new(), unit: None, }], body: TokenStream2::new(), @@ -1128,22 +1161,24 @@ mod tests { }, output_type: parse_quote!(Table>), is_async: true, - fields: vec![ParsedField::Regular { + fields: vec![ParsedField { pat_ident: pat_ident("path"), name: None, - ty: parse_quote!(String), description: String::new(), widget_override: ParsedWidgetOverride::None, - exposed: true, - value_source: ParsedValueSource::None, - number_soft_min: None, - number_soft_max: None, - number_hard_min: None, - number_hard_max: None, - number_mode_range: None, + ty: ParsedFieldType::Regular(RegularParsedField { + ty: parse_quote!(String), + exposed: true, + value_source: ParsedValueSource::None, + number_soft_min: None, + number_soft_max: None, + number_hard_min: None, + number_hard_max: None, + number_mode_range: None, + implementations: Punctuated::new(), + }), number_display_decimal_places: None, number_step: None, - implementations: Punctuated::new(), unit: None, }], body: TokenStream2::new(), diff --git a/node-graph/node-macro/src/validation.rs b/node-graph/node-macro/src/validation.rs index 9ce43a650..ae6066309 100644 --- a/node-graph/node-macro/src/validation.rs +++ b/node-graph/node-macro/src/validation.rs @@ -1,4 +1,4 @@ -use crate::parsing::{Implementation, ParsedField, ParsedNodeFn}; +use crate::parsing::{Implementation, NodeParsedField, ParsedField, ParsedFieldType, ParsedNodeFn, RegularParsedField}; use proc_macro_error2::emit_error; use quote::quote; use syn::spanned::Spanned; @@ -21,11 +21,14 @@ pub fn validate_node_fn(parsed: &ParsedNodeFn) -> syn::Result<()> { fn validate_min_max(parsed: &ParsedNodeFn) { for field in &parsed.fields { - if let ParsedField::Regular { - number_hard_max, - number_hard_min, - number_soft_max, - number_soft_min, + if let ParsedField { + ty: ParsedFieldType::Regular(RegularParsedField { + number_hard_max, + number_hard_min, + number_soft_max, + number_soft_min, + .. + }), pat_ident, .. } = field @@ -78,7 +81,12 @@ fn validate_min_max(parsed: &ParsedNodeFn) { } fn validate_primary_input_expose(parsed: &ParsedNodeFn) { - if let Some(ParsedField::Regular { exposed: true, pat_ident, .. }) = parsed.fields.first() { + if let Some(ParsedField { + ty: ParsedFieldType::Regular(RegularParsedField { exposed: true, .. }), + pat_ident, + .. + }) = parsed.fields.first() + { emit_error!( pat_ident.span(), "Unnecessary #[expose] attribute on primary input `{}`. Primary inputs are always exposed.", @@ -94,8 +102,9 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) { if !has_skip_impl && !parsed.fn_generics.is_empty() { for field in &parsed.fields { - match field { - ParsedField::Regular { ty, implementations, pat_ident, .. } => { + let pat_ident = &field.pat_ident; + match &field.ty { + ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => { if contains_generic_param(ty, &parsed.fn_generics) && implementations.is_empty() { emit_error!( ty.span(), @@ -107,13 +116,12 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) { ); } } - ParsedField::Node { + ParsedFieldType::Node(NodeParsedField { input_type, output_type, implementations, - pat_ident, .. - } => { + }) => { if (contains_generic_param(input_type, &parsed.fn_generics) || contains_generic_param(output_type, &parsed.fn_generics)) && implementations.is_empty() { emit_error!( pat_ident.span(),