Refactor the node macro and simply most of the node implementations (#1942)

* Add support structure for new node macro to gcore

* Fix compile issues and code generation

* Implement new node_fn macro

* Implement property translation

* Fix NodeIO type generation

* Start translating math nodes

* Move node implementation to outer scope to allow usage of local imports

* Add expose attribute to allow controlling the parameter exposure

* Add rust analyzer support for #[implementations] attribute

* Migrate logic nodes

* Handle where clause properly

* Implement argument ident pattern preservation

* Implement adjustment layer mapping

* Fix node registry types

* Fix module paths

* Improve demo artwork comptibility

* Improve macro error reporting

* Fix handling of impl node implementations

* Fix nodeio type computation

* Fix opacity node and graph type resolution

* Fix loading of demo artworks

* Fix eslint

* Fix typo in macro test

* Remove node definitions for Adjustment Nodes

* Fix type alias property generation and make adjustments footprint aware

* Convert vector nodes

* Implement path overrides

* Fix stroke node

* Fix painted dreams

* Implement experimental type level specialization

* Fix poisson disk sampling -> all demo artworks should work again

* Port text node + make node macro more robust by implementing lifetime substitution

* Fix vector node tests

* Fix red dress demo + ci

* Fix clippy warnings

* Code review

* Fix primary input issues

* Improve math nodes and audit others

* Set no_properties when no automatic properties are derived

* Port vector generator nodes (could not derive all definitions yet)

* Various QA changes and add min/max/mode_range to number parameters

* Add min and max for f64 and u32

* Convert gpu nodes and clean up unused nodes

* Partially port transform node

* Allow implementations on call arg

* Port path modify node

* Start porting graphic element nodes

* Transform nodes in graphic_element.rs

* Port brush node

* Port nodes in wasm_executior

* Rename node macro

* Fix formatting

* Fix Mandelbrot node

* Formatting

* Fix Load Image and Load Resource nodes, add scope input to node macro

* Remove unnecessary underscores

* Begin attemping to make nodes resolution-aware

* Infer a generic manual compositon type on generic call arg

* Various fixes and work towards merging

* Final changes for merge!

* Fix tests, probably

* More free line removals!

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2024-09-20 12:50:30 +02:00 committed by GitHub
parent ca0d102296
commit e352c7fa71
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
92 changed files with 4255 additions and 7275 deletions

View file

@ -0,0 +1,872 @@
use convert_case::{Case, Casing};
use indoc::indoc;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, ToTokens};
use syn::parse::{Parse, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{Attribute, Error, ExprTuple, FnArg, GenericParam, Ident, ItemFn, LitFloat, LitStr, Meta, Pat, PatIdent, PatType, Path, ReturnType, Type, TypeTuple, WhereClause};
use crate::codegen::generate_node_code;
#[derive(Debug)]
pub(crate) struct ParsedNodeFn {
pub(crate) attributes: NodeFnAttributes,
pub(crate) fn_name: Ident,
pub(crate) struct_name: Ident,
pub(crate) mod_name: Ident,
pub(crate) fn_generics: Vec<GenericParam>,
pub(crate) where_clause: Option<WhereClause>,
pub(crate) input: Input,
pub(crate) output_type: Type,
pub(crate) is_async: bool,
pub(crate) fields: Vec<ParsedField>,
pub(crate) body: TokenStream2,
pub(crate) crate_name: proc_macro_crate::FoundCrate,
}
#[derive(Debug, Default)]
pub(crate) struct NodeFnAttributes {
pub(crate) category: Option<LitStr>,
pub(crate) display_name: Option<LitStr>,
pub(crate) path: Option<Path>,
pub(crate) skip_impl: bool,
// Add more attributes as needed
}
#[derive(Debug, Default)]
pub enum ValueSource {
#[default]
None,
Default(TokenStream2),
Scope(LitStr),
}
#[derive(Debug)]
pub(crate) enum ParsedField {
Regular {
pat_ident: PatIdent,
name: Option<LitStr>,
ty: Type,
exposed: bool,
value_source: ValueSource,
number_min: Option<LitFloat>,
number_max: Option<LitFloat>,
number_mode_range: Option<ExprTuple>,
implementations: Punctuated<Type, Comma>,
},
Node {
pat_ident: PatIdent,
name: Option<LitStr>,
input_type: Type,
output_type: Type,
implementations: Punctuated<TypeTuple, Comma>,
},
}
#[derive(Debug)]
pub(crate) struct Input {
pub(crate) pat_ident: PatIdent,
pub(crate) ty: Type,
pub(crate) implementations: Punctuated<Type, Comma>,
}
impl Parse for NodeFnAttributes {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut category = None;
let mut display_name = None;
let mut path = None;
let mut skip_impl = false;
let content = input;
// let content;
// syn::parenthesized!(content in input);
let nested = content.call(Punctuated::<Meta, Comma>::parse_terminated)?;
for meta in nested {
match meta {
Meta::List(meta) if meta.path.is_ident("category") => {
if category.is_some() {
return Err(Error::new_spanned(meta, "Multiple 'category' attributes are not allowed"));
}
let lit: LitStr = meta
.parse_args()
.map_err(|_| Error::new_spanned(meta, "Expected a string literal for 'category', e.g., category(\"Value\")"))?;
category = Some(lit);
}
Meta::List(meta) if meta.path.is_ident("name") => {
if display_name.is_some() {
return Err(Error::new_spanned(meta, "Multiple 'name' attributes are not allowed"));
}
let parsed_name: LitStr = meta.parse_args().map_err(|_| Error::new_spanned(meta, "Expected a string for 'name', e.g., name(\"Memoize\")"))?;
display_name = Some(parsed_name);
}
Meta::List(meta) if meta.path.is_ident("path") => {
if path.is_some() {
return Err(Error::new_spanned(meta, "Multiple 'path' attributes are not allowed"));
}
let parsed_path: Path = meta
.parse_args()
.map_err(|_| Error::new_spanned(meta, "Expected a valid path for 'path', e.g., path(crate::MemoizeNode)"))?;
path = Some(parsed_path);
}
Meta::Path(path) if path.is_ident("skip_impl") => {
if skip_impl {
return Err(Error::new_spanned(path, "Multiple 'skip_impl' attributes are not allowed"));
}
skip_impl = true;
}
_ => {
return Err(Error::new_spanned(
meta,
indoc!(
r#"
Unsupported attribute in `node`.
Supported attributes are 'category', 'path' and 'name'.
Example usage:
#[node_macro::node(category("Value"), name("Test Node"))]
"#
),
));
}
}
}
Ok(NodeFnAttributes {
category,
display_name,
path,
skip_impl,
})
}
}
fn parse_node_fn(attr: TokenStream2, item: TokenStream2) -> syn::Result<ParsedNodeFn> {
let attributes = syn::parse2::<NodeFnAttributes>(attr.clone()).map_err(|e| Error::new(e.span(), format!("Failed to parse node_fn attributes: {}", e)))?;
let input_fn = syn::parse2::<ItemFn>(item.clone()).map_err(|e| Error::new(e.span(), format!("Failed to parse function: {}. Make sure it's a valid Rust function.", e)))?;
let fn_name = input_fn.sig.ident.clone();
let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
let mod_name = fn_name.clone();
let fn_generics = input_fn.sig.generics.params.into_iter().collect();
let is_async = input_fn.sig.asyncness.is_some();
let (input, fields) = parse_inputs(&input_fn.sig.inputs)?;
let output_type = parse_output(&input_fn.sig.output)?;
let where_clause = input_fn.sig.generics.where_clause;
let body = input_fn.block.to_token_stream();
let crate_name = proc_macro_crate::crate_name("graphene-core").map_err(|e| {
Error::new(
proc_macro2::Span::call_site(),
format!("Failed to find location of graphene_core. Make sure it is imported as a dependency: {}", e),
)
})?;
Ok(ParsedNodeFn {
attributes,
fn_name,
struct_name,
mod_name,
fn_generics,
input,
output_type,
is_async,
fields,
where_clause,
body,
crate_name,
})
}
fn parse_inputs(inputs: &Punctuated<FnArg, Comma>) -> syn::Result<(Input, Vec<ParsedField>)> {
let mut fields = Vec::new();
let mut input = None;
for (index, arg) in inputs.iter().enumerate() {
if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = arg {
// Call argument
if index == 0 {
if extract_attribute(attrs, "default").is_some() {
return Err(Error::new_spanned(&attrs[0], "Call argument cannot be given a default value".to_string()));
}
if extract_attribute(attrs, "expose").is_some() {
return Err(Error::new_spanned(&attrs[0], "Call argument cannot be exposed".to_string()));
}
let pat_ident = match (**pat).clone() {
Pat::Ident(pat_ident) => pat_ident,
Pat::Wild(wild) => PatIdent {
attrs: wild.attrs,
by_ref: None,
mutability: None,
ident: wild.underscore_token.into(),
subpat: None,
},
_ => continue,
};
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, &pat_ident.ident))
.transpose()?
.unwrap_or_default();
input = Some(Input {
pat_ident,
ty: (**ty).clone(),
implementations,
});
} else if let Pat::Ident(pat_ident) = &**pat {
let field = parse_field(pat_ident.clone(), (**ty).clone(), attrs).map_err(|e| Error::new_spanned(pat_ident, format!("Failed to parse argument '{}': {}", pat_ident.ident, e)))?;
fields.push(field);
} else {
return Err(Error::new_spanned(pat, "Expected a simple identifier for the field name"));
}
} else {
return Err(Error::new_spanned(arg, "Expected a typed argument (e.g., `x: i32`)"));
}
}
let input = input.ok_or_else(|| Error::new_spanned(inputs, "Expected at least one input argument. The first argument should be the node input type."))?;
Ok((input, fields))
}
fn parse_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::Result<Punctuated<T, Comma>> {
let content: TokenStream2 = attr
.parse_args()
.map_err(|e| Error::new_spanned(attr, format!("Invalid implementations for argument '{}': {}", name, e)))?;
let parser = Punctuated::<T, Comma>::parse_terminated;
parser
.parse2(content)
.map_err(|e| Error::new_spanned(attr, format!("Failed to parse implementations for argument '{}': {}", name, e)))
}
fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> {
let ident = &pat_ident.ident;
let default_value = extract_attribute(attrs, "default")
.map(|attr| {
attr.parse_args()
.map_err(|e| Error::new_spanned(attr, format!("Invalid `default` value for argument '{}': {}", ident, e)))
})
.transpose()?;
let scope = extract_attribute(attrs, "scope")
.map(|attr| {
attr.parse_args()
.map_err(|e| Error::new_spanned(attr, format!("Invalid `scope` value for argument '{}': {}", ident, e)))
})
.transpose()?;
let name = extract_attribute(attrs, "name")
.map(|attr| attr.parse_args().map_err(|e| Error::new_spanned(attr, format!("Invalid `name` value for argument '{}': {}", ident, e))))
.transpose()?;
let exposed = extract_attribute(attrs, "expose").is_some();
let value_source = match (default_value, scope) {
(Some(_), Some(_)) => return Err(Error::new_spanned(&pat_ident, "Cannot have both `default` and `scope` attributes")),
(Some(default_value), _) => ValueSource::Default(default_value),
(_, Some(scope)) => ValueSource::Scope(scope),
_ => ValueSource::None,
};
let number_min = extract_attribute(attrs, "min")
.map(|attr| {
attr.parse_args()
.map_err(|e| Error::new_spanned(attr, format!("Invalid numerical `min` value for argument '{}': {}", ident, e)))
})
.transpose()?;
let number_max = extract_attribute(attrs, "max")
.map(|attr| {
attr.parse_args()
.map_err(|e| Error::new_spanned(attr, format!("Invalid numerical `max` value for argument '{}': {}", ident, e)))
})
.transpose()?;
let number_mode_range = extract_attribute(attrs, "range")
.map(|attr| {
attr.parse_args::<ExprTuple>().map_err(|e| {
Error::new_spanned(
attr,
format!(
"Invalid `range` tuple of min and max range slider values for argument '{}': {}\nUSAGE EXAMPLE: #[range((0., 100.))]",
ident, e
),
)
})
})
.transpose()?;
if let Some(range) = &number_mode_range {
if range.elems.len() != 2 {
return Err(Error::new_spanned(range, "Expected a tuple of two values for `range` for the min and max, respectively"));
}
}
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
let (is_node, node_input_type, node_output_type) = parse_node_type(&ty);
if is_node {
let (input_type, output_type) = node_input_type
.zip(node_output_type)
.ok_or_else(|| Error::new_spanned(&ty, "Invalid Node type. Expected `impl Node<Input, Output = OutputType>`"))?;
if !matches!(&value_source, ValueSource::None) {
return Err(Error::new_spanned(&ty, "No default values for `impl Node` allowed"));
}
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
Ok(ParsedField::Node {
pat_ident,
name,
input_type,
output_type,
implementations,
})
} else {
Ok(ParsedField::Regular {
pat_ident,
name,
exposed,
number_min,
number_max,
number_mode_range,
ty,
value_source,
implementations,
})
}
}
fn parse_node_type(ty: &Type) -> (bool, Option<Type>, Option<Type>) {
if let Type::ImplTrait(impl_trait) = ty {
for bound in &impl_trait.bounds {
if let syn::TypeParamBound::Trait(trait_bound) = bound {
if trait_bound.path.segments.last().map_or(false, |seg| seg.ident == "Node") {
if let syn::PathArguments::AngleBracketed(args) = &trait_bound.path.segments.last().unwrap().arguments {
let input_type = args.args.iter().find_map(|arg| if let syn::GenericArgument::Type(ty) = arg { Some(ty.clone()) } else { None });
let output_type = args.args.iter().find_map(|arg| {
if let syn::GenericArgument::AssocType(assoc_type) = arg {
if assoc_type.ident == "Output" {
Some(assoc_type.ty.clone())
} else {
None
}
} else {
None
}
});
return (true, input_type, output_type);
}
}
}
}
}
(false, None, None)
}
fn parse_output(output: &ReturnType) -> syn::Result<Type> {
match output {
ReturnType::Default => Ok(syn::parse_quote!(())),
ReturnType::Type(_, ty) => Ok((**ty).clone()),
}
}
fn extract_attribute<'a>(attrs: &'a [Attribute], name: &str) -> Option<&'a Attribute> {
attrs.iter().find(|attr| attr.path().is_ident(name))
}
// Modify the new_node_fn function to use the code generation
pub fn new_node_fn(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
match parse_node_fn(attr, item.clone()).and_then(|x| generate_node_code(&x)) {
Ok(parsed) => {
/*let generated_code = generate_node_code(&parsed);
// panic!("{}", generated_code.to_string());
quote! {
// #item
#generated_code
}*/
parsed
}
Err(e) => {
// Return the error as a compile error
Error::new(e.span(), format!("Failed to parse node function: {}", e)).to_compile_error()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::Span;
use proc_macro_crate::FoundCrate;
use quote::quote;
use syn::parse_quote;
fn pat_ident(name: &str) -> PatIdent {
PatIdent {
attrs: Vec::new(),
by_ref: None,
mutability: None,
ident: Ident::new(name, Span::call_site()),
subpat: None,
}
}
fn assert_parsed_node_fn(parsed: &ParsedNodeFn, expected: &ParsedNodeFn) {
assert_eq!(parsed.fn_name, expected.fn_name);
assert_eq!(parsed.struct_name, expected.struct_name);
assert_eq!(parsed.mod_name, expected.mod_name);
assert_eq!(parsed.is_async, expected.is_async);
assert_eq!(format!("{:?}", parsed.input), format!("{:?}", expected.input));
assert_eq!(format!("{:?}", parsed.output_type), format!("{:?}", expected.output_type));
assert_eq!(parsed.attributes.category, expected.attributes.category);
assert_eq!(parsed.attributes.display_name, expected.attributes.display_name);
assert_eq!(parsed.attributes.path, expected.attributes.path);
assert_eq!(parsed.attributes.skip_impl, expected.attributes.skip_impl);
assert_eq!(parsed.fields.len(), expected.fields.len());
for (parsed_field, expected_field) in parsed.fields.iter().zip(expected.fields.iter()) {
match (parsed_field, expected_field) {
(
ParsedField::Regular {
pat_ident: p_name,
ty: p_ty,
exposed: p_exp,
value_source: p_default,
..
},
ParsedField::Regular {
pat_ident: e_name,
ty: e_ty,
exposed: e_exp,
value_source: e_default,
..
},
) => {
assert_eq!(p_name, e_name);
assert_eq!(p_exp, e_exp);
match (p_default, e_default) {
(ValueSource::None, ValueSource::None) => {}
(ValueSource::Default(p), ValueSource::Default(e)) => {
assert_eq!(p.to_token_stream().to_string(), e.to_token_stream().to_string());
}
(ValueSource::Scope(p), ValueSource::Scope(e)) => {
assert_eq!(p.value(), e.value());
}
_ => panic!("Mismatched default values"),
}
assert_eq!(format!("{:?}", p_ty), format!("{:?}", e_ty));
}
(
ParsedField::Node {
pat_ident: p_name,
input_type: p_input,
output_type: p_output,
..
},
ParsedField::Node {
pat_ident: e_name,
input_type: e_input,
output_type: e_output,
..
},
) => {
assert_eq!(p_name, e_name);
assert_eq!(format!("{:?}", p_input), format!("{:?}", e_input));
assert_eq!(format!("{:?}", p_output), format!("{:?}", e_output));
}
_ => panic!("Mismatched field types"),
}
}
}
#[test]
fn test_basic_node() {
let attr = quote!(category("Math: Arithmetic"), path(graphene_core::TestNode), skip_impl);
let input = quote!(
fn add(a: f64, b: f64) -> f64 {
a + b
}
);
let parsed = parse_node_fn(attr, input).unwrap();
let expected = ParsedNodeFn {
attributes: NodeFnAttributes {
category: Some(parse_quote!("Math: Arithmetic")),
display_name: None,
path: Some(parse_quote!(graphene_core::TestNode)),
skip_impl: true,
},
fn_name: Ident::new("add", Span::call_site()),
struct_name: Ident::new("Add", Span::call_site()),
mod_name: Ident::new("add", Span::call_site()),
fn_generics: vec![],
where_clause: None,
input: Input {
pat_ident: pat_ident("a"),
ty: parse_quote!(f64),
implementations: Punctuated::new(),
},
output_type: parse_quote!(f64),
is_async: false,
fields: vec![ParsedField::Regular {
pat_ident: pat_ident("b"),
name: None,
ty: parse_quote!(f64),
exposed: false,
value_source: ValueSource::None,
number_min: None,
number_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
}],
body: TokenStream2::new(),
crate_name: FoundCrate::Itself,
};
assert_parsed_node_fn(&parsed, &expected);
}
#[test]
fn test_node_with_impl_node() {
let attr = quote!(category("General"));
let input = quote!(
fn transform<T: 'static>(footprint: Footprint, transform_target: impl Node<Footprint, Output = T>, translate: DVec2) -> T {
// Implementation details...
}
);
let parsed = parse_node_fn(attr, input).unwrap();
let expected = ParsedNodeFn {
attributes: NodeFnAttributes {
category: Some(parse_quote!("General")),
display_name: None,
path: None,
skip_impl: false,
},
fn_name: Ident::new("transform", Span::call_site()),
struct_name: Ident::new("Transform", Span::call_site()),
mod_name: Ident::new("transform", Span::call_site()),
fn_generics: vec![parse_quote!(T: 'static)],
where_clause: None,
input: Input {
pat_ident: pat_ident("footprint"),
ty: parse_quote!(Footprint),
implementations: Punctuated::new(),
},
output_type: parse_quote!(T),
is_async: false,
fields: vec![
ParsedField::Node {
pat_ident: pat_ident("transform_target"),
name: None,
input_type: parse_quote!(Footprint),
output_type: parse_quote!(T),
implementations: Punctuated::new(),
},
ParsedField::Regular {
pat_ident: pat_ident("translate"),
name: None,
ty: parse_quote!(DVec2),
exposed: false,
value_source: ValueSource::None,
number_min: None,
number_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
},
],
body: TokenStream2::new(),
crate_name: FoundCrate::Itself,
};
assert_parsed_node_fn(&parsed, &expected);
}
#[test]
fn test_node_with_default_values() {
let attr = quote!(category("Vector: Shape"));
let input = quote!(
fn circle(_: (), #[default(50.)] radius: f64) -> VectorData {
// Implementation details...
}
);
let parsed = parse_node_fn(attr, input).unwrap();
let expected = ParsedNodeFn {
attributes: NodeFnAttributes {
category: Some(parse_quote!("Vector: Shape")),
display_name: None,
path: None,
skip_impl: false,
},
fn_name: Ident::new("circle", Span::call_site()),
struct_name: Ident::new("Circle", Span::call_site()),
mod_name: Ident::new("circle", Span::call_site()),
fn_generics: vec![],
where_clause: None,
input: Input {
pat_ident: pat_ident("_"),
ty: parse_quote!(()),
implementations: Punctuated::new(),
},
output_type: parse_quote!(VectorData),
is_async: false,
fields: vec![ParsedField::Regular {
pat_ident: pat_ident("radius"),
name: None,
ty: parse_quote!(f64),
exposed: false,
value_source: ValueSource::Default(quote!(50.)),
number_min: None,
number_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
}],
body: TokenStream2::new(),
crate_name: FoundCrate::Itself,
};
assert_parsed_node_fn(&parsed, &expected);
}
#[test]
fn test_node_with_implementations() {
let attr = quote!(category("Raster: Adjustment"));
let input = quote!(
fn levels<P: Pixel>(image: ImageFrame<P>, #[implementations(f32, f64)] shadows: f64) -> ImageFrame<P> {
// Implementation details...
}
);
let parsed = parse_node_fn(attr, input).unwrap();
let expected = ParsedNodeFn {
attributes: NodeFnAttributes {
category: Some(parse_quote!("Raster: Adjustment")),
display_name: None,
path: None,
skip_impl: false,
},
fn_name: Ident::new("levels", Span::call_site()),
struct_name: Ident::new("Levels", Span::call_site()),
mod_name: Ident::new("levels", Span::call_site()),
fn_generics: vec![parse_quote!(P: Pixel)],
where_clause: None,
input: Input {
pat_ident: pat_ident("image"),
ty: parse_quote!(ImageFrame<P>),
implementations: Punctuated::new(),
},
output_type: parse_quote!(ImageFrame<P>),
is_async: false,
fields: vec![ParsedField::Regular {
pat_ident: pat_ident("shadows"),
name: None,
ty: parse_quote!(f64),
exposed: false,
value_source: ValueSource::None,
number_min: None,
number_max: None,
number_mode_range: None,
implementations: {
let mut p = Punctuated::new();
p.push(parse_quote!(f32));
p.push(parse_quote!(f64));
p
},
}],
body: TokenStream2::new(),
crate_name: FoundCrate::Itself,
};
assert_parsed_node_fn(&parsed, &expected);
}
#[test]
fn test_number_min_max_range_mode() {
let attr = quote!(category("Math: Arithmetic"), path(graphene_core::TestNode));
let input = quote!(
fn add(
a: f64,
#[range((0., 100.))]
#[min(-500.)]
#[max(500.)]
b: f64,
) -> f64 {
a + b
}
);
let parsed = parse_node_fn(attr, input).unwrap();
let expected = ParsedNodeFn {
attributes: NodeFnAttributes {
category: Some(parse_quote!("Math: Arithmetic")),
display_name: None,
path: Some(parse_quote!(graphene_core::TestNode)),
skip_impl: false,
},
fn_name: Ident::new("add", Span::call_site()),
struct_name: Ident::new("Add", Span::call_site()),
mod_name: Ident::new("add", Span::call_site()),
fn_generics: vec![],
where_clause: None,
input: Input {
pat_ident: pat_ident("a"),
ty: parse_quote!(f64),
implementations: Punctuated::new(),
},
output_type: parse_quote!(f64),
is_async: false,
fields: vec![ParsedField::Regular {
pat_ident: pat_ident("b"),
name: None,
ty: parse_quote!(f64),
exposed: false,
value_source: ValueSource::None,
number_min: Some(parse_quote!(-500.)),
number_max: Some(parse_quote!(500.)),
number_mode_range: Some(parse_quote!((0., 100.))),
implementations: Punctuated::new(),
}],
body: TokenStream2::new(),
crate_name: FoundCrate::Itself,
};
assert_parsed_node_fn(&parsed, &expected);
}
#[test]
fn test_async_node() {
let attr = quote!(category("IO"));
let input = quote!(
async fn load_image(api: &WasmEditorApi, #[expose] path: String) -> ImageFrame<Color> {
// Implementation details...
}
);
let parsed = parse_node_fn(attr, input).unwrap();
let expected = ParsedNodeFn {
attributes: NodeFnAttributes {
category: Some(parse_quote!("IO")),
display_name: None,
path: None,
skip_impl: false,
},
fn_name: Ident::new("load_image", Span::call_site()),
struct_name: Ident::new("LoadImage", Span::call_site()),
mod_name: Ident::new("load_image", Span::call_site()),
fn_generics: vec![],
where_clause: None,
input: Input {
pat_ident: pat_ident("api"),
ty: parse_quote!(&WasmEditorApi),
implementations: Punctuated::new(),
},
output_type: parse_quote!(ImageFrame<Color>),
is_async: true,
fields: vec![ParsedField::Regular {
pat_ident: pat_ident("path"),
name: None,
ty: parse_quote!(String),
exposed: true,
value_source: ValueSource::None,
number_min: None,
number_max: None,
number_mode_range: None,
implementations: Punctuated::new(),
}],
body: TokenStream2::new(),
crate_name: FoundCrate::Itself,
};
assert_parsed_node_fn(&parsed, &expected);
}
#[test]
fn test_node_with_custom_name() {
let attr = quote!(category("Custom"), name("CustomNode2"));
let input = quote!(
fn custom_node(input: i32) -> i32 {
input * 2
}
);
let parsed = parse_node_fn(attr, input).unwrap();
let expected = ParsedNodeFn {
attributes: NodeFnAttributes {
category: Some(parse_quote!("Custom")),
display_name: Some(parse_quote!("CustomNode2")),
path: None,
skip_impl: false,
},
fn_name: Ident::new("custom_node", Span::call_site()),
struct_name: Ident::new("CustomNode", Span::call_site()),
mod_name: Ident::new("custom_node", Span::call_site()),
fn_generics: vec![],
where_clause: None,
input: Input {
pat_ident: pat_ident("input"),
ty: parse_quote!(i32),
implementations: Punctuated::new(),
},
output_type: parse_quote!(i32),
is_async: false,
fields: vec![],
body: TokenStream2::new(),
crate_name: FoundCrate::Itself,
};
assert_parsed_node_fn(&parsed, &expected);
}
#[test]
#[should_panic(expected = "Multiple 'category' attributes are not allowed")]
fn test_multiple_categories() {
let attr = quote!(category("Math: Arithmetic"), category("General"));
let input = quote!(
fn add(a: i32, b: i32) -> i32 {
a + b
}
);
parse_node_fn(attr, input).unwrap();
}
#[test]
#[should_panic(expected = "Call argument cannot be given a default value")]
fn test_default_value_for_first_arg() {
let attr = quote!(category("Invalid"));
let input = quote!(
fn invalid_node(#[default(())] node: impl Node<(), Output = i32>) -> i32 {
node.eval(())
}
);
parse_node_fn(attr, input).unwrap();
}
#[test]
#[should_panic(expected = "No default values for `impl Node` allowed")]
fn test_default_value_for_impl_node() {
let attr = quote!(category("Invalid"));
let input = quote!(
fn invalid_node(_: (), #[default(())] node: impl Node<(), Output = i32>) -> i32 {
node.eval(())
}
);
parse_node_fn(attr, input).unwrap();
}
#[test]
#[should_panic(expected = "Unsupported attribute in `node`")]
fn test_unsupported_attribute() {
let attr = quote!(unsupported("Value"));
let input = quote!(
fn test_node(input: i32) -> i32 {
input
}
);
parse_node_fn(attr, input).unwrap();
}
}