Improve node macro and add more diagnostics (#1999)

* Improve node macro ergonomics

* Fix type error in stub import

* Fix wasm nodes

* Code review

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2024-09-21 21:57:45 +02:00 committed by GitHub
parent 3eb98c6d6d
commit cd4124a596
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 358 additions and 122 deletions

View file

@ -1,14 +1,21 @@
use convert_case::{Case, Casing};
use indoc::indoc;
use indoc::{formatdoc, indoc};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, ToTokens};
use syn::parse::{Parse, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{Attribute, Error, ExprTuple, FnArg, GenericParam, Ident, ItemFn, LitFloat, LitStr, Meta, Pat, PatIdent, PatType, Path, ReturnType, Type, TypeTuple, WhereClause};
use syn::token::{Comma, RArrow};
use syn::{Attribute, Error, ExprTuple, FnArg, GenericParam, Ident, ItemFn, LitFloat, LitStr, Meta, Pat, PatIdent, PatType, Path, ReturnType, Type, WhereClause};
use crate::codegen::generate_node_code;
#[derive(Debug)]
pub(crate) struct Implementation {
pub(crate) input: Type,
pub(crate) _arrow: RArrow,
pub(crate) output: Type,
}
#[derive(Debug)]
pub(crate) struct ParsedNodeFn {
pub(crate) attributes: NodeFnAttributes,
@ -60,7 +67,7 @@ pub(crate) enum ParsedField {
name: Option<LitStr>,
input_type: Type,
output_type: Type,
implementations: Punctuated<TypeTuple, Comma>,
implementations: Punctuated<Implementation, Comma>,
},
}
#[derive(Debug)]
@ -70,6 +77,46 @@ pub(crate) struct Input {
pub(crate) implementations: Punctuated<Type, Comma>,
}
impl Parse for Implementation {
fn parse(input: ParseStream) -> syn::Result<Self> {
let input_type: Type = input.parse().map_err(|e| {
Error::new(
input.span(),
formatdoc!(
"Failed to parse input type for #[implementation(...)]. Expected a valid Rust type.
Error: {}",
e,
),
)
})?;
let arrow: RArrow = input.parse().map_err(|_| {
Error::new(
input.span(),
indoc!(
"Expected `->` arrow after input type in #[implementations(...)] on a field of type `impl Node`.
The correct syntax is `InputType -> OutputType`."
),
)
})?;
let output_type: Type = input.parse().map_err(|e| {
Error::new(
input.span(),
formatdoc!(
"Failed to parse output type for #[implementation(...)]. Expected a valid Rust type after `->`.
Error: {}",
e
),
)
})?;
Ok(Implementation {
input: input_type,
_arrow: arrow,
output: output_type,
})
}
}
impl Parse for NodeFnAttributes {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut category = None;
@ -228,14 +275,31 @@ fn parse_inputs(inputs: &Punctuated<FnArg, Comma>) -> syn::Result<(Input, Vec<Pa
Ok((input, fields))
}
fn parse_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::Result<Punctuated<T, Comma>> {
let content: TokenStream2 = attr
.parse_args()
.map_err(|e| Error::new_spanned(attr, format!("Invalid implementations for argument '{}': {}", name, e)))?;
fn parse_implementations(attr: &Attribute, name: &Ident) -> syn::Result<Punctuated<Type, Comma>> {
let content: TokenStream2 = attr.parse_args()?;
let parser = Punctuated::<Type, Comma>::parse_terminated;
parser.parse2(content.clone()).map_err(|e| {
let span = e.span(); // Get the span of the error
Error::new(span, format!("Failed to parse implementations for argument '{}': {}", name, e))
})
}
fn parse_node_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::Result<Punctuated<T, Comma>> {
let content: TokenStream2 = attr.parse_args()?;
let parser = Punctuated::<T, Comma>::parse_terminated;
parser
.parse2(content)
.map_err(|e| Error::new_spanned(attr, format!("Failed to parse implementations for argument '{}': {}", name, e)))
parser.parse2(content.clone()).map_err(|e| {
Error::new(
e.span(),
formatdoc!(
"Invalid #[implementations(...)] for argument `{}`.
Expected a comma-separated list of `InputType -> OutputType` pairs.
Example: #[implementations(i32 -> f64, String -> Vec<u8>)]
Error: {}",
name,
e
),
)
})
}
fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> {
@ -300,11 +364,6 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
}
}
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
let (is_node, node_input_type, node_output_type) = parse_node_type(&ty);
if is_node {
@ -315,7 +374,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
return Err(Error::new_spanned(&ty, "No default values for `impl Node` allowed"));
}
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.map(|attr| parse_node_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
@ -327,6 +386,10 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
implementations,
})
} else {
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
Ok(ParsedField::Regular {
pat_ident,
name,
@ -381,16 +444,16 @@ fn extract_attribute<'a>(attrs: &'a [Attribute], name: &str) -> Option<&'a Attri
// Modify the new_node_fn function to use the code generation
pub fn new_node_fn(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
match parse_node_fn(attr, item.clone()).and_then(|x| generate_node_code(&x)) {
Ok(parsed) => {
/*let generated_code = generate_node_code(&parsed);
// panic!("{}", generated_code.to_string());
quote! {
// #item
#generated_code
}*/
parsed
}
let parse_result = parse_node_fn(attr, item.clone());
let Ok(parsed_node) = parse_result else {
let e = parse_result.unwrap_err();
return Error::new(e.span(), format!("Failed to parse node function: {e}")).to_compile_error();
};
if let Err(e) = crate::validation::validate_node_fn(&parsed_node) {
return Error::new(e.span(), format!("Validation Error:\n{e}")).to_compile_error();
}
match generate_node_code(&parsed_node) {
Ok(parsed) => parsed,
Err(e) => {
// Return the error as a compile error
Error::new(e.span(), format!("Failed to parse node function: {}", e)).to_compile_error()
@ -403,7 +466,7 @@ mod tests {
use super::*;
use proc_macro2::Span;
use proc_macro_crate::FoundCrate;
use quote::quote;
use quote::{quote, quote_spanned};
use syn::parse_quote;
fn pat_ident(name: &str) -> PatIdent {
PatIdent {
@ -869,4 +932,59 @@ mod tests {
);
parse_node_fn(attr, input).unwrap();
}
#[test]
fn test_invalid_implementation_syntax() {
let attr = quote!(category("Test"));
let input = quote!(
fn test_node(_: (), #[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>) -> T {
// Implementation details...
}
);
let result = parse_node_fn(attr, input);
assert!(result.is_err());
let error = result.unwrap_err();
let error_message = error.to_string();
assert!(error_message.contains("Invalid #[implementations(...)] for argument `input`"));
assert!(error_message.contains("Expected a comma-separated list of `InputType -> OutputType` pairs"));
assert!(error_message.contains("Expected `->` arrow after input type in #[implementations(...)] on a field of type `impl Node`"));
}
#[test]
fn test_implementation_on_first_arg() {
let attr = quote!(category("Test"));
// Use quote_spanned! to attach a specific span to the problematic part
let problem_span = proc_macro2::Span::call_site(); // You could create a custom span here if needed
let tuples = quote_spanned!(problem_span=> () ());
let input = quote! {
fn test_node(
#[implementations((), #tuples, Footprint)] footprint: F,
#[implementations(
() -> Color,
() -> ImageFrame<Color>,
() -> GradientStops,
Footprint -> Color,
Footprint -> ImageFrame<Color>,
Footprint -> GradientStops,
)]
image: impl Node<F, Output = T>,
) -> T {
// Implementation details...
}
};
let result = parse_node_fn(attr, input);
assert!(result.is_err(), "Expected an error, but parsing succeeded");
let error = result.unwrap_err();
let error_string = error.to_string();
assert!(error_string.contains("Failed to parse implementations for argument 'footprint'"));
assert!(error_string.contains("expected `,`"));
// Instead of checking for exact line and column,
// verify that the error span is the one we specified
assert_eq!(error.span().start(), problem_span.start());
}
}