node-macro: cleanup ParsedField struct (#3064)
Some checks are pending
Editor: Dev & CI / build (push) Waiting to run
Editor: Dev & CI / cargo-deny (push) Waiting to run

* node-macro: cleanup `ParsedField` struct

* node-macro: fixup tests
This commit is contained in:
Firestar99 2025-08-19 11:25:58 +02:00 committed by GitHub
parent 36a1453d03
commit b44a4fba1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 244 additions and 253 deletions

View file

@ -41,19 +41,12 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect(); let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
let input_ident = &input.pat_ident; let input_ident = &input.pat_ident;
let field_idents: Vec<_> = fields let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect();
.iter()
.map(|field| match field {
ParsedField::Regular { pat_ident, .. } | ParsedField::Node { pat_ident, .. } => pat_ident,
})
.collect();
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect(); let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
let input_names: Vec<_> = fields let input_names: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|f| &f.name)
ParsedField::Regular { name, .. } | ParsedField::Node { name, .. } => name,
})
.zip(field_names.iter()) .zip(field_names.iter())
.map(|zipped| match zipped { .map(|zipped| match zipped {
(Some(name), _) => name.value(), (Some(name), _) => name.value(),
@ -61,12 +54,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
}) })
.collect(); .collect();
let input_descriptions: Vec<_> = fields let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect();
.iter()
.map(|field| match field {
ParsedField::Regular { description, .. } | ParsedField::Node { description, .. } => description,
})
.collect();
let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| { let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| {
quote! { pub(super) #name: #r#gen } quote! { pub(super) #name: #r#gen }
@ -84,9 +72,9 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let field_types: Vec<_> = fields let field_types: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| match &field.ty {
ParsedField::Regular { ty, .. } => ty.clone(), ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
ParsedField::Node { output_type, input_type, .. } => match parsed.is_async { ParsedFieldType::Node(NodeParsedField { output_type, input_type, .. }) => match parsed.is_async {
true => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = impl core::future::Future<Output=#output_type>>), true => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = impl core::future::Future<Output=#output_type>>),
false => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = #output_type>), false => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = #output_type>),
}, },
@ -95,24 +83,18 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let widget_override: Vec<_> = fields let widget_override: Vec<_> = fields
.iter() .iter()
.map(|field| { .map(|field| match &field.widget_override {
let parsed_widget_override = match field {
ParsedField::Regular { widget_override, .. } => widget_override,
ParsedField::Node { widget_override, .. } => widget_override,
};
match parsed_widget_override {
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None), ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
ParsedWidgetOverride::Hidden => quote!(RegistryWidgetOverride::Hidden), ParsedWidgetOverride::Hidden => quote!(RegistryWidgetOverride::Hidden),
ParsedWidgetOverride::String(lit_str) => quote!(RegistryWidgetOverride::String(#lit_str)), ParsedWidgetOverride::String(lit_str) => quote!(RegistryWidgetOverride::String(#lit_str)),
ParsedWidgetOverride::Custom(lit_str) => quote!(RegistryWidgetOverride::Custom(#lit_str)), ParsedWidgetOverride::Custom(lit_str) => quote!(RegistryWidgetOverride::Custom(#lit_str)),
}
}) })
.collect(); .collect();
let value_sources: Vec<_> = fields let value_sources: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| match &field.ty {
ParsedField::Regular { value_source, .. } => match value_source { ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
ParsedValueSource::Default(data) => quote!(RegistryValueSource::Default(stringify!(#data))), ParsedValueSource::Default(data) => quote!(RegistryValueSource::Default(stringify!(#data))),
ParsedValueSource::Scope(data) => quote!(RegistryValueSource::Scope(#data)), ParsedValueSource::Scope(data) => quote!(RegistryValueSource::Scope(#data)),
_ => quote!(RegistryValueSource::None), _ => quote!(RegistryValueSource::None),
@ -123,8 +105,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let default_types: Vec<_> = fields let default_types: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| match &field.ty {
ParsedField::Regular { implementations, .. } => match implementations.first() { ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
Some(ty) => quote!(Some(concrete!(#ty))), Some(ty) => quote!(Some(concrete!(#ty))),
_ => quote!(None), _ => quote!(None),
}, },
@ -134,8 +116,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let number_min_values: Vec<_> = fields let number_min_values: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| match &field.ty {
ParsedField::Regular { number_soft_min, number_hard_min, .. } => match (number_soft_min, number_hard_min) { ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
(Some(soft_min), _) => quote!(Some(#soft_min)), (Some(soft_min), _) => quote!(Some(#soft_min)),
(None, Some(hard_min)) => quote!(Some(#hard_min)), (None, Some(hard_min)) => quote!(Some(#hard_min)),
(None, None) => quote!(None), (None, None) => quote!(None),
@ -145,8 +127,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
.collect(); .collect();
let number_max_values: Vec<_> = fields let number_max_values: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| match &field.ty {
ParsedField::Regular { number_soft_max, number_hard_max, .. } => match (number_soft_max, number_hard_max) { ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
(Some(soft_max), _) => quote!(Some(#soft_max)), (Some(soft_max), _) => quote!(Some(#soft_max)),
(None, Some(hard_max)) => quote!(Some(#hard_max)), (None, Some(hard_max)) => quote!(Some(#hard_max)),
(None, None) => quote!(None), (None, None) => quote!(None),
@ -156,77 +138,45 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
.collect(); .collect();
let number_mode_range_values: Vec<_> = fields let number_mode_range_values: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| match &field.ty {
ParsedField::Regular { ParsedFieldType::Regular(RegularParsedField {
number_mode_range: Some(number_mode_range), number_mode_range: Some(number_mode_range),
.. ..
} => quote!(Some(#number_mode_range)), }) => quote!(Some(#number_mode_range)),
_ => quote!(None), _ => quote!(None),
}) })
.collect(); .collect();
let number_display_decimal_places: Vec<_> = fields let number_display_decimal_places: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
ParsedField::Regular {
number_display_decimal_places: Some(decimal_places),
..
}
| ParsedField::Node {
number_display_decimal_places: Some(decimal_places),
..
} => {
quote!(Some(#decimal_places))
}
_ => quote!(None),
})
.collect();
let number_step: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { number_step: Some(step), .. } | ParsedField::Node { number_step: Some(step), .. } => {
quote!(Some(#step))
}
_ => quote!(None),
})
.collect(); .collect();
let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
let unit_suffix: Vec<_> = fields let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
.iter()
.map(|field| match field {
ParsedField::Regular { unit: Some(unit), .. } | ParsedField::Node { unit: Some(unit), .. } => {
quote!(Some(#unit))
}
_ => quote!(None),
})
.collect();
let exposed: Vec<_> = fields let exposed: Vec<_> = fields
.iter() .iter()
.map(|field| match field { .map(|field| match &field.ty {
ParsedField::Regular { exposed, .. } => quote!(#exposed), ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
_ => quote!(true), _ => quote!(true),
}) })
.collect(); .collect();
let eval_args = fields.iter().map(|field| match field { let eval_args = fields.iter().map(|field| {
ParsedField::Regular { pat_ident, .. } => { let name = &field.pat_ident.ident;
let name = &pat_ident.ident; match &field.ty {
ParsedFieldType::Regular { .. } => {
quote! { let #name = self.#name.eval(__input.clone()).await; } quote! { let #name = self.#name.eval(__input.clone()).await; }
} }
ParsedField::Node { pat_ident, .. } => { ParsedFieldType::Node { .. } => {
let name = &pat_ident.ident;
quote! { let #name = &self.#name; } quote! { let #name = &self.#name; }
} }
}
}); });
let min_max_args = fields.iter().map(|field| match field { let min_max_args = fields.iter().map(|field| match &field.ty {
ParsedField::Regular { ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
pat_ident, let name = &field.pat_ident.ident;
number_hard_min,
number_hard_max,
..
} => {
let name = &pat_ident.ident;
let mut tokens = quote!(); let mut tokens = quote!();
if let Some(min) = number_hard_min { if let Some(min) = number_hard_min {
tokens.extend(quote_spanned! {min.span()=> tokens.extend(quote_spanned! {min.span()=>
@ -241,15 +191,13 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
} }
tokens tokens
} }
ParsedField::Node { .. } => { ParsedFieldType::Node { .. } => quote!(),
quote!()
}
}); });
let all_implementation_types = fields.iter().flat_map(|field| match field { let all_implementation_types = fields.iter().flat_map(|field| match &field.ty {
ParsedField::Regular { implementations, .. } => implementations.into_iter().cloned().collect::<Vec<_>>(), ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => implementations.iter().cloned().collect::<Vec<_>>(),
ParsedField::Node { implementations, .. } => implementations ParsedFieldType::Node(NodeParsedField { implementations, .. }) => implementations
.into_iter() .iter()
.flat_map(|implementation| [implementation.input.clone(), implementation.output.clone()]) .flat_map(|implementation| [implementation.input.clone(), implementation.output.clone()])
.collect(), .collect(),
}); });
@ -260,11 +208,11 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let mut clampable_clauses = Vec::new(); let mut clampable_clauses = Vec::new();
for (field, name) in fields.iter().zip(struct_generics.iter()) { for (field, name) in fields.iter().zip(struct_generics.iter()) {
clauses.push(match (field, *is_async) { clauses.push(match (&field.ty, *is_async) {
( (
ParsedField::Regular { ParsedFieldType::Regular(RegularParsedField {
ty, number_hard_min, number_hard_max, .. ty, number_hard_min, number_hard_max, ..
}, }),
_, _,
) => { ) => {
let all_lifetime_ty = substitute_lifetimes(ty.clone(), "all"); let all_lifetime_ty = substitute_lifetimes(ty.clone(), "all");
@ -284,7 +232,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
#name: #graphene_core::Node<'n, #input_type, Output = #fut_ident> + #graphene_core::WasmNotSync #name: #graphene_core::Node<'n, #input_type, Output = #fut_ident> + #graphene_core::WasmNotSync
) )
} }
(ParsedField::Node { input_type, output_type, .. }, true) => { (ParsedFieldType::Node(NodeParsedField { input_type, output_type, .. }), true) => {
let id = future_idents.len(); let id = future_idents.len();
let fut_ident = format_ident!("F{}", id); let fut_ident = format_ident!("F{}", id);
future_idents.push(fut_ident.clone()); future_idents.push(fut_ident.clone());
@ -294,7 +242,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
#name: #graphene_core::Node<'n, #input_type, Output = #fut_ident > + #graphene_core::WasmNotSync #name: #graphene_core::Node<'n, #input_type, Output = #fut_ident > + #graphene_core::WasmNotSync
) )
} }
(ParsedField::Node { .. }, false) => unreachable!(), (ParsedFieldType::Node { .. }, false) => unreachable!(),
}); });
} }
let where_clause = where_clause.clone().unwrap_or(WhereClause { let where_clause = where_clause.clone().unwrap_or(WhereClause {
@ -454,9 +402,9 @@ fn generate_node_input_references(
let (mut modified, mut generic_collector) = FilterUsedGenerics::new(fn_generics); let (mut modified, mut generic_collector) = FilterUsedGenerics::new(fn_generics);
for (input_index, (parsed_input, input_ident)) in parsed.fields.iter().zip(field_idents).enumerate() { for (input_index, (parsed_input, input_ident)) in parsed.fields.iter().zip(field_idents).enumerate() {
let mut ty = match parsed_input { let mut ty = match &parsed_input.ty {
ParsedField::Regular { ty, .. } => ty, ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty,
ParsedField::Node { output_type, .. } => output_type, ParsedFieldType::Node(NodeParsedField { output_type, .. }) => output_type,
} }
.clone(); .clone();
@ -540,20 +488,20 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
.fields .fields
.iter() .iter()
.map(|field| { .map(|field| {
match field { match &field.ty {
ParsedField::Regular { implementations, ty, .. } => { ParsedFieldType::Regular(RegularParsedField { implementations, ty, .. }) => {
if !implementations.is_empty() { if !implementations.is_empty() {
implementations.iter().map(|ty| (&unit, ty)).collect() implementations.iter().map(|ty| (&unit, ty)).collect()
} else { } else {
vec![(&unit, ty)] vec![(&unit, ty)]
} }
} }
ParsedField::Node { ParsedFieldType::Node(NodeParsedField {
implementations, implementations,
input_type, input_type,
output_type, output_type,
.. ..
} => { }) => {
if !implementations.is_empty() { if !implementations.is_empty() {
implementations.iter().map(|impl_| (&impl_.input, &impl_.output)).collect() implementations.iter().map(|impl_| (&impl_.input, &impl_.output)).collect()
} else { } else {
@ -578,7 +526,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
let field_name = field_names[j]; let field_name = field_names[j];
let (input_type, output_type) = &types[i.min(types.len() - 1)]; let (input_type, output_type) = &types[i.min(types.len() - 1)];
let node = matches!(parsed.fields[j], ParsedField::Node { .. }); let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. });
let downcast_node = quote!( let downcast_node = quote!(
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone()); let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());

View file

@ -14,7 +14,7 @@ use syn::{
use crate::codegen::generate_node_code; use crate::codegen::generate_node_code;
use crate::shader_nodes::ShaderNodeType; use crate::shader_nodes::ShaderNodeType;
#[derive(Debug)] #[derive(Clone, Debug)]
pub(crate) struct Implementation { pub(crate) struct Implementation {
pub(crate) input: Type, pub(crate) input: Type,
pub(crate) _arrow: RArrow, pub(crate) _arrow: RArrow,
@ -53,7 +53,7 @@ pub(crate) struct NodeFnAttributes {
// Add more attributes as needed // Add more attributes as needed
} }
#[derive(Debug, Default)] #[derive(Clone, Debug, Default)]
pub enum ParsedValueSource { pub enum ParsedValueSource {
#[default] #[default]
None, None,
@ -64,7 +64,7 @@ pub enum ParsedValueSource {
// #[widget(ParsedWidgetOverride::Hidden)] // #[widget(ParsedWidgetOverride::Hidden)]
// #[widget(ParsedWidgetOverride::String = "Some string")] // #[widget(ParsedWidgetOverride::String = "Some string")]
// #[widget(ParsedWidgetOverride::Custom = "Custom string")] // #[widget(ParsedWidgetOverride::Custom = "Custom string")]
#[derive(Debug, Default)] #[derive(Clone, Debug, Default)]
pub enum ParsedWidgetOverride { pub enum ParsedWidgetOverride {
#[default] #[default]
None, None,
@ -102,39 +102,44 @@ impl Parse for ParsedWidgetOverride {
} }
} }
#[derive(Debug)] #[derive(Clone, Debug)]
pub(crate) enum ParsedField { pub struct ParsedField {
Regular { pub pat_ident: PatIdent,
pat_ident: PatIdent, pub name: Option<LitStr>,
name: Option<LitStr>, pub description: String,
description: String, pub widget_override: ParsedWidgetOverride,
widget_override: ParsedWidgetOverride, pub ty: ParsedFieldType,
ty: Type, pub number_display_decimal_places: Option<LitInt>,
exposed: bool, pub number_step: Option<LitFloat>,
value_source: ParsedValueSource, pub unit: Option<LitStr>,
number_soft_min: Option<LitFloat>,
number_soft_max: Option<LitFloat>,
number_hard_min: Option<LitFloat>,
number_hard_max: Option<LitFloat>,
number_mode_range: Option<ExprTuple>,
number_display_decimal_places: Option<LitInt>,
number_step: Option<LitFloat>,
implementations: Punctuated<Type, Comma>,
unit: Option<LitStr>,
},
Node {
pat_ident: PatIdent,
name: Option<LitStr>,
description: String,
widget_override: ParsedWidgetOverride,
input_type: Type,
output_type: Type,
number_display_decimal_places: Option<LitInt>,
number_step: Option<LitFloat>,
implementations: Punctuated<Implementation, Comma>,
unit: Option<LitStr>,
},
} }
#[derive(Clone, Debug)]
pub enum ParsedFieldType {
Regular(RegularParsedField),
Node(NodeParsedField),
}
#[derive(Clone, Debug)]
pub struct RegularParsedField {
pub ty: Type,
pub exposed: bool,
pub value_source: ParsedValueSource,
pub number_soft_min: Option<LitFloat>,
pub number_soft_max: Option<LitFloat>,
pub number_hard_min: Option<LitFloat>,
pub number_hard_max: Option<LitFloat>,
pub number_mode_range: Option<ExprTuple>,
pub implementations: Punctuated<Type, Comma>,
}
#[derive(Clone, Debug)]
pub struct NodeParsedField {
pub input_type: Type,
pub output_type: Type,
pub implementations: Punctuated<Implementation, Comma>,
}
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Input { pub(crate) struct Input {
pub(crate) pat_ident: PatIdent, pub(crate) pat_ident: PatIdent,
@ -563,16 +568,18 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
.transpose()? .transpose()?
.unwrap_or_default(); .unwrap_or_default();
Ok(ParsedField::Node { Ok(ParsedField {
pat_ident, pat_ident,
ty: ParsedFieldType::Node(NodeParsedField {
input_type,
output_type,
implementations,
}),
name, name,
description, description,
widget_override, widget_override,
input_type,
output_type,
number_display_decimal_places, number_display_decimal_places,
number_step, number_step,
implementations,
unit, unit,
}) })
} else { } else {
@ -580,22 +587,24 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
.map(|attr| parse_implementations(attr, ident)) .map(|attr| parse_implementations(attr, ident))
.transpose()? .transpose()?
.unwrap_or_default(); .unwrap_or_default();
Ok(ParsedField::Regular { Ok(ParsedField {
pat_ident, pat_ident,
name, ty: ParsedFieldType::Regular(RegularParsedField {
description,
widget_override,
exposed, exposed,
number_soft_min, number_soft_min,
number_soft_max, number_soft_max,
number_hard_min, number_hard_min,
number_hard_max, number_hard_max,
number_mode_range, number_mode_range,
number_display_decimal_places,
number_step,
ty, ty,
value_source, value_source,
implementations, implementations,
}),
name,
description,
widget_override,
number_display_decimal_places,
number_step,
unit, unit,
}) })
} }
@ -715,19 +724,25 @@ mod tests {
for (parsed_field, expected_field) in parsed.fields.iter().zip(expected.fields.iter()) { for (parsed_field, expected_field) in parsed.fields.iter().zip(expected.fields.iter()) {
match (parsed_field, expected_field) { match (parsed_field, expected_field) {
( (
ParsedField::Regular { ParsedField {
pat_ident: p_name, pat_ident: p_name,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: p_ty, ty: p_ty,
exposed: p_exp, exposed: p_exp,
value_source: p_default, value_source: p_default,
.. ..
}),
..
}, },
ParsedField::Regular { ParsedField {
pat_ident: e_name, pat_ident: e_name,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: e_ty, ty: e_ty,
exposed: e_exp, exposed: e_exp,
value_source: e_default, value_source: e_default,
.. ..
}),
..
}, },
) => { ) => {
assert_eq!(p_name, e_name); assert_eq!(p_name, e_name);
@ -745,17 +760,23 @@ mod tests {
assert_eq!(format!("{:?}", p_ty), format!("{:?}", e_ty)); assert_eq!(format!("{:?}", p_ty), format!("{:?}", e_ty));
} }
( (
ParsedField::Node { ParsedField {
pat_ident: p_name, pat_ident: p_name,
ty: ParsedFieldType::Node(NodeParsedField {
input_type: p_input, input_type: p_input,
output_type: p_output, output_type: p_output,
.. ..
}),
..
}, },
ParsedField::Node { ParsedField {
pat_ident: e_name, pat_ident: e_name,
ty: ParsedFieldType::Node(NodeParsedField {
input_type: e_input, input_type: e_input,
output_type: e_output, output_type: e_output,
.. ..
}),
..
}, },
) => { ) => {
assert_eq!(p_name, e_name); assert_eq!(p_name, e_name);
@ -802,11 +823,12 @@ mod tests {
}, },
output_type: parse_quote!(f64), output_type: parse_quote!(f64),
is_async: false, is_async: false,
fields: vec![ParsedField::Regular { fields: vec![ParsedField {
pat_ident: pat_ident("b"), pat_ident: pat_ident("b"),
name: None, name: None,
description: String::new(), description: String::new(),
widget_override: ParsedWidgetOverride::None, widget_override: ParsedWidgetOverride::None,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: parse_quote!(f64), ty: parse_quote!(f64),
exposed: false, exposed: false,
value_source: ParsedValueSource::None, value_source: ParsedValueSource::None,
@ -815,9 +837,10 @@ mod tests {
number_hard_min: None, number_hard_min: None,
number_hard_max: None, number_hard_max: None,
number_mode_range: None, number_mode_range: None,
implementations: Punctuated::new(),
}),
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
implementations: Punctuated::new(),
unit: None, unit: None,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
@ -866,23 +889,26 @@ mod tests {
output_type: parse_quote!(T), output_type: parse_quote!(T),
is_async: false, is_async: false,
fields: vec![ fields: vec![
ParsedField::Node { ParsedField {
pat_ident: pat_ident("transform_target"), pat_ident: pat_ident("transform_target"),
name: None, name: None,
description: String::new(), description: String::new(),
widget_override: ParsedWidgetOverride::None, widget_override: ParsedWidgetOverride::None,
ty: ParsedFieldType::Node(NodeParsedField {
input_type: parse_quote!(Footprint), input_type: parse_quote!(Footprint),
output_type: parse_quote!(T), output_type: parse_quote!(T),
implementations: Punctuated::new(),
}),
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
implementations: Punctuated::new(),
unit: None, unit: None,
}, },
ParsedField::Regular { ParsedField {
pat_ident: pat_ident("translate"), pat_ident: pat_ident("translate"),
name: None, name: None,
description: String::new(), description: String::new(),
widget_override: ParsedWidgetOverride::None, widget_override: ParsedWidgetOverride::None,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: parse_quote!(DVec2), ty: parse_quote!(DVec2),
exposed: false, exposed: false,
value_source: ParsedValueSource::None, value_source: ParsedValueSource::None,
@ -891,9 +917,10 @@ mod tests {
number_hard_min: None, number_hard_min: None,
number_hard_max: None, number_hard_max: None,
number_mode_range: None, number_mode_range: None,
implementations: Punctuated::new(),
}),
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
implementations: Punctuated::new(),
unit: None, unit: None,
}, },
], ],
@ -939,11 +966,12 @@ mod tests {
}, },
output_type: parse_quote!(Vector), output_type: parse_quote!(Vector),
is_async: false, is_async: false,
fields: vec![ParsedField::Regular { fields: vec![ParsedField {
pat_ident: pat_ident("radius"), pat_ident: pat_ident("radius"),
name: None, name: None,
description: String::new(), description: String::new(),
widget_override: ParsedWidgetOverride::None, widget_override: ParsedWidgetOverride::None,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: parse_quote!(f64), ty: parse_quote!(f64),
exposed: false, exposed: false,
value_source: ParsedValueSource::Default(quote!(50.)), value_source: ParsedValueSource::Default(quote!(50.)),
@ -952,9 +980,10 @@ mod tests {
number_hard_min: None, number_hard_min: None,
number_hard_max: None, number_hard_max: None,
number_mode_range: None, number_mode_range: None,
implementations: Punctuated::new(),
}),
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
implementations: Punctuated::new(),
unit: None, unit: None,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
@ -998,11 +1027,12 @@ mod tests {
}, },
output_type: parse_quote!(Table<Raster<P>>), output_type: parse_quote!(Table<Raster<P>>),
is_async: false, is_async: false,
fields: vec![ParsedField::Regular { fields: vec![ParsedField {
pat_ident: pat_ident("shadows"), pat_ident: pat_ident("shadows"),
name: None, name: None,
description: String::new(), description: String::new(),
widget_override: ParsedWidgetOverride::None, widget_override: ParsedWidgetOverride::None,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: parse_quote!(f64), ty: parse_quote!(f64),
exposed: false, exposed: false,
value_source: ParsedValueSource::None, value_source: ParsedValueSource::None,
@ -1011,14 +1041,15 @@ mod tests {
number_hard_min: None, number_hard_min: None,
number_hard_max: None, number_hard_max: None,
number_mode_range: None, number_mode_range: None,
number_display_decimal_places: None,
number_step: None,
implementations: { implementations: {
let mut p = Punctuated::new(); let mut p = Punctuated::new();
p.push(parse_quote!(f32)); p.push(parse_quote!(f32));
p.push(parse_quote!(f64)); p.push(parse_quote!(f64));
p p
}, },
}),
number_display_decimal_places: None,
number_step: None,
unit: None, unit: None,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
@ -1069,11 +1100,12 @@ mod tests {
}, },
output_type: parse_quote!(f64), output_type: parse_quote!(f64),
is_async: false, is_async: false,
fields: vec![ParsedField::Regular { fields: vec![ParsedField {
pat_ident: pat_ident("b"), pat_ident: pat_ident("b"),
name: None, name: None,
description: String::from("b"), description: String::from("b"),
widget_override: ParsedWidgetOverride::None, widget_override: ParsedWidgetOverride::None,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: parse_quote!(f64), ty: parse_quote!(f64),
exposed: false, exposed: false,
value_source: ParsedValueSource::None, value_source: ParsedValueSource::None,
@ -1082,9 +1114,10 @@ mod tests {
number_hard_min: None, number_hard_min: None,
number_hard_max: None, number_hard_max: None,
number_mode_range: Some(parse_quote!((0., 100.))), number_mode_range: Some(parse_quote!((0., 100.))),
implementations: Punctuated::new(),
}),
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
implementations: Punctuated::new(),
unit: None, unit: None,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
@ -1128,12 +1161,13 @@ mod tests {
}, },
output_type: parse_quote!(Table<Raster<CPU>>), output_type: parse_quote!(Table<Raster<CPU>>),
is_async: true, is_async: true,
fields: vec![ParsedField::Regular { fields: vec![ParsedField {
pat_ident: pat_ident("path"), pat_ident: pat_ident("path"),
name: None, name: None,
ty: parse_quote!(String),
description: String::new(), description: String::new(),
widget_override: ParsedWidgetOverride::None, widget_override: ParsedWidgetOverride::None,
ty: ParsedFieldType::Regular(RegularParsedField {
ty: parse_quote!(String),
exposed: true, exposed: true,
value_source: ParsedValueSource::None, value_source: ParsedValueSource::None,
number_soft_min: None, number_soft_min: None,
@ -1141,9 +1175,10 @@ mod tests {
number_hard_min: None, number_hard_min: None,
number_hard_max: None, number_hard_max: None,
number_mode_range: None, number_mode_range: None,
implementations: Punctuated::new(),
}),
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
implementations: Punctuated::new(),
unit: None, unit: None,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),

View file

@ -1,4 +1,4 @@
use crate::parsing::{Implementation, ParsedField, ParsedNodeFn}; use crate::parsing::{Implementation, NodeParsedField, ParsedField, ParsedFieldType, ParsedNodeFn, RegularParsedField};
use proc_macro_error2::emit_error; use proc_macro_error2::emit_error;
use quote::quote; use quote::quote;
use syn::spanned::Spanned; use syn::spanned::Spanned;
@ -21,11 +21,14 @@ pub fn validate_node_fn(parsed: &ParsedNodeFn) -> syn::Result<()> {
fn validate_min_max(parsed: &ParsedNodeFn) { fn validate_min_max(parsed: &ParsedNodeFn) {
for field in &parsed.fields { for field in &parsed.fields {
if let ParsedField::Regular { if let ParsedField {
ty: ParsedFieldType::Regular(RegularParsedField {
number_hard_max, number_hard_max,
number_hard_min, number_hard_min,
number_soft_max, number_soft_max,
number_soft_min, number_soft_min,
..
}),
pat_ident, pat_ident,
.. ..
} = field } = field
@ -78,7 +81,12 @@ fn validate_min_max(parsed: &ParsedNodeFn) {
} }
fn validate_primary_input_expose(parsed: &ParsedNodeFn) { fn validate_primary_input_expose(parsed: &ParsedNodeFn) {
if let Some(ParsedField::Regular { exposed: true, pat_ident, .. }) = parsed.fields.first() { if let Some(ParsedField {
ty: ParsedFieldType::Regular(RegularParsedField { exposed: true, .. }),
pat_ident,
..
}) = parsed.fields.first()
{
emit_error!( emit_error!(
pat_ident.span(), pat_ident.span(),
"Unnecessary #[expose] attribute on primary input `{}`. Primary inputs are always exposed.", "Unnecessary #[expose] attribute on primary input `{}`. Primary inputs are always exposed.",
@ -94,8 +102,9 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) {
if !has_skip_impl && !parsed.fn_generics.is_empty() { if !has_skip_impl && !parsed.fn_generics.is_empty() {
for field in &parsed.fields { for field in &parsed.fields {
match field { let pat_ident = &field.pat_ident;
ParsedField::Regular { ty, implementations, pat_ident, .. } => { match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => {
if contains_generic_param(ty, &parsed.fn_generics) && implementations.is_empty() { if contains_generic_param(ty, &parsed.fn_generics) && implementations.is_empty() {
emit_error!( emit_error!(
ty.span(), ty.span(),
@ -107,13 +116,12 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) {
); );
} }
} }
ParsedField::Node { ParsedFieldType::Node(NodeParsedField {
input_type, input_type,
output_type, output_type,
implementations, implementations,
pat_ident,
.. ..
} => { }) => {
if (contains_generic_param(input_type, &parsed.fn_generics) || contains_generic_param(output_type, &parsed.fn_generics)) && implementations.is_empty() { if (contains_generic_param(input_type, &parsed.fn_generics) || contains_generic_param(output_type, &parsed.fn_generics)) && implementations.is_empty() {
emit_error!( emit_error!(
pat_ident.span(), pat_ident.span(),