From cd4124a59627046f368443e37fa08b968eb50f1a Mon Sep 17 00:00:00 2001 From: Dennis Kobert Date: Sat, 21 Sep 2024 21:57:45 +0200 Subject: [PATCH] Improve node macro and add more diagnostics (#1999) * Improve node macro ergonomics * Fix type error in stub import * Fix wasm nodes * Code review --------- Co-authored-by: Keavon Chambers --- Cargo.lock | 1 + node-graph/gcore/src/graphic_element.rs | 44 ++--- node-graph/gcore/src/raster.rs | 4 +- node-graph/gcore/src/raster/adjustments.rs | 74 ++++---- node-graph/gcore/src/transform.rs | 12 +- .../src/vector/vector_data/modification.rs | 2 +- node-graph/gcore/src/vector/vector_nodes.rs | 8 +- node-graph/gstd/src/image_color_palette.rs | 2 +- node-graph/gstd/src/vector.rs | 2 +- node-graph/gstd/src/wasm_application_io.rs | 24 +-- node-graph/node-macro/Cargo.toml | 5 +- node-graph/node-macro/src/codegen.rs | 16 +- node-graph/node-macro/src/lib.rs | 3 + node-graph/node-macro/src/parsing.rs | 174 +++++++++++++++--- node-graph/node-macro/src/validation.rs | 109 +++++++++++ 15 files changed, 358 insertions(+), 122 deletions(-) create mode 100644 node-graph/node-macro/src/validation.rs diff --git a/Cargo.lock b/Cargo.lock index 6b8d50def..ea30b27d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3850,6 +3850,7 @@ dependencies = [ "graphene-core", "indoc", "proc-macro-crate 3.2.0", + "proc-macro-error", "proc-macro2", "quote", "syn 2.0.77", diff --git a/node-graph/gcore/src/graphic_element.rs b/node-graph/gcore/src/graphic_element.rs index 215a0fd1e..2304d71a9 100644 --- a/node-graph/gcore/src/graphic_element.rs +++ b/node-graph/gcore/src/graphic_element.rs @@ -234,8 +234,8 @@ impl ArtboardGroup { #[node_macro::node(category(""))] async fn layer( #[implementations((), Footprint)] footprint: F, - #[implementations(((), GraphicGroup), (Footprint, GraphicGroup))] stack: impl Node, - #[implementations(((), GraphicElement), (Footprint, GraphicElement))] graphic_element: impl Node, + #[implementations(() -> GraphicGroup, Footprint -> GraphicGroup)] stack: impl Node, + #[implementations(() -> GraphicElement, Footprint -> GraphicElement)] graphic_element: impl Node, node_path: Vec, ) -> GraphicGroup { let mut element = graphic_element.eval(footprint).await; @@ -257,15 +257,15 @@ async fn layer( async fn to_element + 'n>( #[implementations((), (), (), (), Footprint)] footprint: F, #[implementations( - ((), VectorData), - ((), ImageFrame), - ((), GraphicGroup), - ((), TextureFrame), - (Footprint, VectorData), - (Footprint, ImageFrame), - (Footprint, GraphicGroup), - (Footprint, TextureFrame), - )] + () -> VectorData, + () -> ImageFrame, + () -> GraphicGroup, + () -> TextureFrame, + Footprint -> VectorData, + Footprint -> ImageFrame, + Footprint -> GraphicGroup, + Footprint -> TextureFrame, + )] data: impl Node, ) -> GraphicElement { data.eval(footprint).await.into() @@ -275,14 +275,14 @@ async fn to_element + 'n>( async fn to_group + 'n>( #[implementations((), (), (), (), Footprint)] footprint: F, #[implementations( - ((), VectorData), - ((), ImageFrame), - ((), GraphicGroup), - ((), TextureFrame), - (Footprint, VectorData), - (Footprint, ImageFrame), - (Footprint, GraphicGroup), - (Footprint, TextureFrame), + () -> VectorData, + () -> ImageFrame, + () -> GraphicGroup, + () -> TextureFrame, + Footprint -> VectorData, + Footprint -> ImageFrame, + Footprint -> GraphicGroup, + Footprint -> TextureFrame, )] element: impl Node, ) -> GraphicGroup { @@ -292,7 +292,7 @@ async fn to_group + 'n>( #[node_macro::node(category(""))] async fn to_artboard( #[implementations((), Footprint)] mut footprint: F, - #[implementations(((), GraphicGroup), (Footprint, GraphicGroup))] contents: impl Node, + #[implementations(() -> GraphicGroup, Footprint -> GraphicGroup)] contents: impl Node, label: String, location: IVec2, dimensions: IVec2, @@ -314,8 +314,8 @@ async fn to_artboard( #[node_macro::node(category(""))] async fn append_artboard( #[implementations((), Footprint)] footprint: F, - #[implementations(((), ArtboardGroup), (Footprint, ArtboardGroup))] artboards: impl Node, - #[implementations(((), Artboard), (Footprint, Artboard))] artboard: impl Node, + #[implementations(() -> ArtboardGroup, Footprint -> ArtboardGroup)] artboards: impl Node, + #[implementations(() -> Artboard, Footprint -> Artboard)] artboard: impl Node, node_path: Vec, ) -> ArtboardGroup { let artboard = artboard.eval(footprint).await; diff --git a/node-graph/gcore/src/raster.rs b/node-graph/gcore/src/raster.rs index b90e508cb..4ba6adfd9 100644 --- a/node-graph/gcore/src/raster.rs +++ b/node-graph/gcore/src/raster.rs @@ -310,7 +310,7 @@ impl SetBlendMode for ImageFrame { #[node_macro::node(category("Style"))] async fn blend_mode( footprint: Footprint, - #[implementations((Footprint, crate::vector::VectorData), (Footprint, crate::GraphicGroup), (Footprint, ImageFrame))] value: impl Node, + #[implementations(Footprint -> crate::vector::VectorData, Footprint -> crate::GraphicGroup, Footprint -> ImageFrame)] value: impl Node, blend_mode: BlendMode, ) -> T { let mut value = value.eval(footprint).await; @@ -321,7 +321,7 @@ async fn blend_mode( #[node_macro::node(category("Style"))] async fn opacity( footprint: Footprint, - #[implementations((Footprint, crate::vector::VectorData), (Footprint, crate::GraphicGroup), (Footprint, ImageFrame))] value: impl Node, + #[implementations(Footprint -> crate::vector::VectorData, Footprint -> crate::GraphicGroup, Footprint -> ImageFrame)] value: impl Node, #[default(100.)] factor: Percentage, ) -> T { let mut value = value.eval(footprint).await; diff --git a/node-graph/gcore/src/raster/adjustments.rs b/node-graph/gcore/src/raster/adjustments.rs index 7a3c394d9..00eefba43 100644 --- a/node-graph/gcore/src/raster/adjustments.rs +++ b/node-graph/gcore/src/raster/adjustments.rs @@ -271,7 +271,7 @@ impl From for vello::peniko::Mix { #[node_macro::node(category("Raster: Adjustment"))] // Unique to Graphite async fn luminance>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] input: impl Node, luminance_calc: LuminanceCalculation, ) -> T { let mut input = input.eval(footprint).await; @@ -291,7 +291,7 @@ async fn luminance>( #[node_macro::node(category("Raster"))] async fn extract_channel>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] input: impl Node, channel: RedGreenBlueAlpha, ) -> T { let mut input = input.eval(footprint).await; @@ -308,7 +308,7 @@ async fn extract_channel>( } #[node_macro::node(category("Raster"))] -async fn make_opaque>(footprint: Footprint, #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node) -> T { +async fn make_opaque>(footprint: Footprint, #[implementations(Footprint -> Color, Footprint -> ImageFrame)] input: impl Node) -> T { let mut input = input.eval(footprint).await; input.adjust(|color| { if color.a() == 0. { @@ -323,7 +323,7 @@ async fn make_opaque>(footprint: Footprint, #[implementations(( #[node_macro::node(category("Raster: Adjustment"))] async fn levels>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] image: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] image: impl Node, #[default(0.)] shadows: Percentage, #[default(50.)] midtones: Percentage, #[default(100.)] highlights: Percentage, @@ -381,7 +381,7 @@ async fn levels>( #[node_macro::node(name("Black & White"), category("Raster: Adjustment"))] async fn black_and_white>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] image: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] image: impl Node, #[default(000000ff)] tint: Color, #[default(40.)] #[range((-200., 300.))] @@ -446,7 +446,7 @@ async fn black_and_white>( #[node_macro::node(name("Hue/Saturation"), category("Raster: Adjustment"))] async fn hue_saturation>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] input: impl Node, hue_shift: Angle, saturation_shift: SignedPercentage, lightness_shift: SignedPercentage, @@ -472,7 +472,7 @@ async fn hue_saturation>( } #[node_macro::node(category("Raster: Adjustment"))] -async fn invert>(footprint: Footprint, #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node) -> T { +async fn invert>(footprint: Footprint, #[implementations(Footprint -> Color, Footprint -> ImageFrame)] input: impl Node) -> T { let mut input = input.eval(footprint).await; input.adjust(|color| { let color = color.to_gamma_srgb(); @@ -487,7 +487,7 @@ async fn invert>(footprint: Footprint, #[implementations((Footp #[node_macro::node(category("Raster: Adjustment"))] async fn threshold>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] image: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] image: impl Node, #[default(50.)] min_luminance: Percentage, #[default(100.)] max_luminance: Percentage, luminance_calc: LuminanceCalculation, @@ -574,22 +574,22 @@ impl Blend for GradientStops { async fn blend + Send>( #[implementations((), (), (), Footprint)] footprint: F, #[implementations( - ((), Color), - ((), ImageFrame), - ((), GradientStops), - (Footprint, Color), - (Footprint, ImageFrame), - (Footprint, GradientStops), + () -> Color, + () -> ImageFrame, + () -> GradientStops, + Footprint -> Color, + Footprint -> ImageFrame, + Footprint -> GradientStops, )] over: impl Node, #[expose] #[implementations( - ((), Color), - ((), ImageFrame), - ((), GradientStops), - (Footprint, Color), - (Footprint, ImageFrame), - (Footprint, GradientStops), + () -> Color, + () -> ImageFrame, + () -> GradientStops, + Footprint -> Color, + Footprint -> ImageFrame, + Footprint -> GradientStops, )] under: impl Node, blend_mode: BlendMode, @@ -693,12 +693,12 @@ pub fn blend_colors(foreground: Color, background: Color, blend_mode: BlendMode, async fn gradient_map>( #[implementations((), (), (), Footprint)] footprint: F, #[implementations( - ((), Color), - ((), ImageFrame), - ((), GradientStops), - (Footprint, Color), - (Footprint, ImageFrame), - (Footprint, GradientStops), + () -> Color, + () -> ImageFrame, + () -> GradientStops, + Footprint -> Color, + Footprint -> ImageFrame, + Footprint -> GradientStops, )] image: impl Node, gradient: GradientStops, @@ -720,7 +720,7 @@ async fn gradient_map>( #[node_macro::node(category("Raster: Adjustment"))] async fn vibrance>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] image: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] image: impl Node, vibrance: SignedPercentage, ) -> T { let mut input = image.eval(footprint).await; @@ -1003,7 +1003,7 @@ impl DomainWarpType { #[node_macro::node(category("Raster: Adjustment"))] async fn channel_mixer>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] image: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] image: impl Node, monochrome: bool, #[default(40.)] @@ -1145,7 +1145,7 @@ impl core::fmt::Display for SelectiveColorChoice { #[node_macro::node(category("Raster: Adjustment"))] async fn selective_color>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] image: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] image: impl Node, mode: RelativeAbsolute, #[name("(Reds) Cyan")] r_c: f64, #[name("(Reds) Magenta")] r_m: f64, @@ -1293,7 +1293,7 @@ impl MultiplyAlpha for ImageFrame

{ #[node_macro::node(category("Raster: Adjustment"))] async fn posterize>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] input: impl Node, #[default(4)] #[min(2.)] levels: u32, @@ -1317,7 +1317,7 @@ async fn posterize>( #[node_macro::node(category("Raster: Adjustment"))] async fn exposure>( footprint: Footprint, - #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node, + #[implementations(Footprint -> Color, Footprint -> ImageFrame)] input: impl Node, exposure: f64, offset: f64, #[default(1.)] @@ -1387,12 +1387,12 @@ fn generate_curves(_: (), curve: Curve, #[implementa async fn color_overlay>( #[implementations((), (), (), Footprint)] footprint: F, #[implementations( - ((), Color), - ((), ImageFrame), - ((), GradientStops), - (Footprint, Color), - (Footprint, ImageFrame), - (Footprint, GradientStops), + () -> Color, + () -> ImageFrame, + () -> GradientStops, + Footprint -> Color, + Footprint -> ImageFrame, + Footprint -> GradientStops, )] image: impl Node, #[default(000000ff)] color: Color, diff --git a/node-graph/gcore/src/transform.rs b/node-graph/gcore/src/transform.rs index 7daadba7b..2c73da0b3 100644 --- a/node-graph/gcore/src/transform.rs +++ b/node-graph/gcore/src/transform.rs @@ -220,12 +220,12 @@ impl ApplyTransform for () { async fn transform + ApplyTransform + 'n + Clone + Send + Sync, T: TransformMut + 'n>( #[implementations(Footprint, Footprint, Footprint, (), (), ())] mut input: I, #[implementations( - (Footprint, VectorData), - (Footprint, GraphicGroup), - (Footprint, ImageFrame), - ((), VectorData), - ((), GraphicGroup), - ((), ImageFrame), + Footprint -> VectorData, + Footprint -> GraphicGroup, + Footprint -> ImageFrame, + () -> VectorData, + () -> GraphicGroup, + () -> ImageFrame, )] transform_target: impl Node, translate: DVec2, diff --git a/node-graph/gcore/src/vector/vector_data/modification.rs b/node-graph/gcore/src/vector/vector_data/modification.rs index d6c86721d..6d1f8eabc 100644 --- a/node-graph/gcore/src/vector/vector_data/modification.rs +++ b/node-graph/gcore/src/vector/vector_data/modification.rs @@ -426,7 +426,7 @@ use crate::transform::Footprint; #[node_macro::node(category(""))] async fn path_modify( #[implementations((), Footprint)] input: F, - #[implementations(((), VectorData), (Footprint, VectorData))] vector_data: impl Node, + #[implementations(() -> VectorData, Footprint -> VectorData)] vector_data: impl Node, modification: Box, ) -> VectorData { let mut vector_data = vector_data.eval(input).await; diff --git a/node-graph/gcore/src/vector/vector_nodes.rs b/node-graph/gcore/src/vector/vector_nodes.rs index b8fe5abd4..e917b8117 100644 --- a/node-graph/gcore/src/vector/vector_nodes.rs +++ b/node-graph/gcore/src/vector/vector_nodes.rs @@ -29,7 +29,7 @@ impl VectorIterMut for VectorData { #[node_macro::node(category("Vector: Style"), path(graphene_core::vector))] async fn assign_colors( footprint: Footprint, - #[implementations((Footprint, GraphicGroup), (Footprint, VectorData))] vector_group: impl Node, + #[implementations(Footprint -> GraphicGroup, Footprint -> VectorData)] vector_group: impl Node, #[default(true)] fill: bool, stroke: bool, gradient: GradientStops, @@ -177,7 +177,7 @@ async fn circular_repeat( #[node_macro::node(category("Vector"), path(graphene_core::vector))] async fn bounding_box( #[implementations((), Footprint)] footprint: F, - #[implementations(((), VectorData), (Footprint, VectorData))] vector_data: impl Node, + #[implementations(() -> VectorData, Footprint -> VectorData)] vector_data: impl Node, ) -> VectorData { let vector_data = vector_data.eval(footprint).await; @@ -253,7 +253,7 @@ async fn copy_to_points, #[expose] - #[implementations((Footprint, VectorData), (Footprint, GraphicGroup))] + #[implementations(Footprint -> VectorData, Footprint -> GraphicGroup)] instance: impl Node, #[default(1)] random_scale_min: f64, #[default(1)] random_scale_max: f64, @@ -387,7 +387,7 @@ async fn sample_points( #[node_macro::node(category(""), path(graphene_core::vector))] async fn poisson_disk_points( #[implementations((), Footprint)] footprint: F, - #[implementations(((), VectorData), (Footprint, VectorData))] vector_data: impl Node, + #[implementations(() -> VectorData, Footprint -> VectorData)] vector_data: impl Node, #[default(10.)] #[min(0.01)] separation_disk_diameter: f64, diff --git a/node-graph/gstd/src/image_color_palette.rs b/node-graph/gstd/src/image_color_palette.rs index 456215d59..5218e7807 100644 --- a/node-graph/gstd/src/image_color_palette.rs +++ b/node-graph/gstd/src/image_color_palette.rs @@ -5,7 +5,7 @@ use graphene_core::Color; #[node_macro::node(category("Raster"))] async fn image_color_palette( #[implementations((), Footprint)] footprint: F, - #[implementations(((), ImageFrame), (Footprint, ImageFrame))] image: impl Node>, + #[implementations(() -> ImageFrame, Footprint -> ImageFrame)] image: impl Node>, #[min(1.)] #[max(28.)] max_size: u32, diff --git a/node-graph/gstd/src/vector.rs b/node-graph/gstd/src/vector.rs index 4623b8c6b..fcb425395 100644 --- a/node-graph/gstd/src/vector.rs +++ b/node-graph/gstd/src/vector.rs @@ -13,7 +13,7 @@ use std::ops::{Div, Mul}; #[node_macro::node(category(""))] async fn boolean_operation( #[implementations((), Footprint)] footprint: F, - #[implementations(((), GraphicGroup), (Footprint, GraphicGroup))] group_of_paths: impl Node, + #[implementations(() -> GraphicGroup, Footprint -> GraphicGroup)] group_of_paths: impl Node, operation: BooleanOperation, ) -> VectorData { let group_of_paths = group_of_paths.eval(footprint).await; diff --git a/node-graph/gstd/src/wasm_application_io.rs b/node-graph/gstd/src/wasm_application_io.rs index 3ae4d8546..084975a91 100644 --- a/node-graph/gstd/src/wasm_application_io.rs +++ b/node-graph/gstd/src/wasm_application_io.rs @@ -133,7 +133,7 @@ async fn render_canvas(render_config: RenderConfig, data: impl GraphicElementRen #[cfg(target_arch = "wasm32")] async fn rasterize( _: (), - #[implementations((Footprint, VectorData), (Footprint, ImageFrame), (Footprint, GraphicGroup))] data: impl Node, + #[implementations(Footprint -> VectorData, Footprint -> ImageFrame, Footprint -> GraphicGroup)] data: impl Node, footprint: Footprint, surface_handle: Arc>, ) -> ImageFrame { @@ -190,17 +190,17 @@ async fn render<'a: 'n, T: 'n + GraphicElementRendered + WasmNotSend>( render_config: RenderConfig, editor_api: &'a WasmEditorApi, #[implementations( - (Footprint, VectorData), - (Footprint, ImageFrame), - (Footprint, GraphicGroup), - (Footprint, graphene_core::Artboard), - (Footprint, graphene_core::ArtboardGroup), - (Footprint, Option), - (Footprint, Vec), - (Footprint, bool), - (Footprint, f32), - (Footprint, f64), - (Footprint, String), + Footprint -> VectorData, + Footprint -> ImageFrame, + Footprint -> GraphicGroup, + Footprint -> graphene_core::Artboard, + Footprint -> graphene_core::ArtboardGroup, + Footprint -> Option, + Footprint -> Vec, + Footprint -> bool, + Footprint -> f32, + Footprint -> f64, + Footprint -> String, )] data: impl Node, _surface_handle: impl Node<(), Output = Option>, diff --git a/node-graph/node-macro/Cargo.toml b/node-graph/node-macro/Cargo.toml index 50acb5f7b..52ee23893 100644 --- a/node-graph/node-macro/Cargo.toml +++ b/node-graph/node-macro/Cargo.toml @@ -15,13 +15,14 @@ proc-macro = true [dependencies] # Workspace dependencies -syn = { workspace = true, features = ["extra-traits", "full", "printing", "parsing", "clone-impls", "proc-macro", "visit-mut"] } -proc-macro2 = { workspace = true } +syn = { workspace = true, features = [ "extra-traits", "full", "printing", "parsing", "clone-impls", "proc-macro", "visit-mut", "visit"] } +proc-macro2 = { workspace = true, features = [ "span-locations" ] } quote = { workspace = true } convert_case = { workspace = true } indoc = "2.0.5" proc-macro-crate = "3.1.0" +proc-macro-error = "1.0" [dev-dependencies] graphene-core = { workspace = true } diff --git a/node-graph/node-macro/src/codegen.rs b/node-graph/node-macro/src/codegen.rs index d418ac9ce..467ce2ff6 100644 --- a/node-graph/node-macro/src/codegen.rs +++ b/node-graph/node-macro/src/codegen.rs @@ -140,7 +140,10 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result implementations.into_iter().cloned().collect::>(), - ParsedField::Node { implementations, .. } => implementations.into_iter().map(|tuple| syn::Type::Tuple(tuple.clone())).collect(), + ParsedField::Node { implementations, .. } => implementations + .into_iter() + .flat_map(|implementation| [implementation.input.clone(), implementation.output.clone()]) + .collect(), }); let all_implementation_types = all_implementation_types.chain(input.implementations.iter().cloned()); @@ -200,6 +203,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result syn::Result = core::marker::PhantomData; + + static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData; #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct #struct_name<#(#struct_generics,)*> { @@ -286,19 +290,19 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st match field { ParsedField::Regular { implementations, ty, .. } => { if !implementations.is_empty() { - implementations.into_iter().map(|ty| (&unit, ty, false)).collect() + implementations.iter().map(|ty| (&unit, ty, false)).collect() } else { vec![(&unit, ty, false)] } } ParsedField::Node { implementations, - output_type, input_type, + output_type, .. } => { if !implementations.is_empty() { - implementations.into_iter().map(|tup| (&tup.elems[0], &tup.elems[1], true)).collect() + implementations.iter().map(|impl_| (&impl_.input, &impl_.output, true)).collect() } else { vec![(input_type, output_type, true)] } diff --git a/node-graph/node-macro/src/lib.rs b/node-graph/node-macro/src/lib.rs index c7d0d156e..75b87008b 100644 --- a/node-graph/node-macro/src/lib.rs +++ b/node-graph/node-macro/src/lib.rs @@ -1,5 +1,6 @@ use proc_macro::TokenStream; use proc_macro2::Span; +use proc_macro_error::proc_macro_error; use quote::{format_ident, quote, ToTokens}; use syn::{ parse_macro_input, punctuated::Punctuated, token::Comma, AngleBracketedGenericArguments, AssocType, FnArg, GenericArgument, GenericParam, Ident, ItemFn, Lifetime, Pat, PatIdent, PathArguments, @@ -8,6 +9,7 @@ use syn::{ mod codegen; mod parsing; +mod validation; /// A macro used to construct a proto node implementation from the given struct and the decorated function. /// @@ -102,6 +104,7 @@ pub fn old_node_fn(attr: TokenStream, item: TokenStream) -> TokenStream { new_constructor } +#[proc_macro_error] #[proc_macro_attribute] pub fn node(attr: TokenStream, item: TokenStream) -> TokenStream { // Performs the `node_impl` macro's functionality of attaching an `impl Node for TheGivenStruct` block to the node struct diff --git a/node-graph/node-macro/src/parsing.rs b/node-graph/node-macro/src/parsing.rs index 98594bb07..f6749c079 100644 --- a/node-graph/node-macro/src/parsing.rs +++ b/node-graph/node-macro/src/parsing.rs @@ -1,14 +1,21 @@ use convert_case::{Case, Casing}; -use indoc::indoc; +use indoc::{formatdoc, indoc}; use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, ToTokens}; use syn::parse::{Parse, ParseStream, Parser}; use syn::punctuated::Punctuated; -use syn::token::Comma; -use syn::{Attribute, Error, ExprTuple, FnArg, GenericParam, Ident, ItemFn, LitFloat, LitStr, Meta, Pat, PatIdent, PatType, Path, ReturnType, Type, TypeTuple, WhereClause}; +use syn::token::{Comma, RArrow}; +use syn::{Attribute, Error, ExprTuple, FnArg, GenericParam, Ident, ItemFn, LitFloat, LitStr, Meta, Pat, PatIdent, PatType, Path, ReturnType, Type, WhereClause}; use crate::codegen::generate_node_code; +#[derive(Debug)] +pub(crate) struct Implementation { + pub(crate) input: Type, + pub(crate) _arrow: RArrow, + pub(crate) output: Type, +} + #[derive(Debug)] pub(crate) struct ParsedNodeFn { pub(crate) attributes: NodeFnAttributes, @@ -60,7 +67,7 @@ pub(crate) enum ParsedField { name: Option, input_type: Type, output_type: Type, - implementations: Punctuated, + implementations: Punctuated, }, } #[derive(Debug)] @@ -70,6 +77,46 @@ pub(crate) struct Input { pub(crate) implementations: Punctuated, } +impl Parse for Implementation { + fn parse(input: ParseStream) -> syn::Result { + let input_type: Type = input.parse().map_err(|e| { + Error::new( + input.span(), + formatdoc!( + "Failed to parse input type for #[implementation(...)]. Expected a valid Rust type. + Error: {}", + e, + ), + ) + })?; + let arrow: RArrow = input.parse().map_err(|_| { + Error::new( + input.span(), + indoc!( + "Expected `->` arrow after input type in #[implementations(...)] on a field of type `impl Node`. + The correct syntax is `InputType -> OutputType`." + ), + ) + })?; + let output_type: Type = input.parse().map_err(|e| { + Error::new( + input.span(), + formatdoc!( + "Failed to parse output type for #[implementation(...)]. Expected a valid Rust type after `->`. + Error: {}", + e + ), + ) + })?; + + Ok(Implementation { + input: input_type, + _arrow: arrow, + output: output_type, + }) + } +} + impl Parse for NodeFnAttributes { fn parse(input: ParseStream) -> syn::Result { let mut category = None; @@ -228,14 +275,31 @@ fn parse_inputs(inputs: &Punctuated) -> syn::Result<(Input, Vec(attr: &Attribute, name: &Ident) -> syn::Result> { - let content: TokenStream2 = attr - .parse_args() - .map_err(|e| Error::new_spanned(attr, format!("Invalid implementations for argument '{}': {}", name, e)))?; +fn parse_implementations(attr: &Attribute, name: &Ident) -> syn::Result> { + let content: TokenStream2 = attr.parse_args()?; + let parser = Punctuated::::parse_terminated; + parser.parse2(content.clone()).map_err(|e| { + let span = e.span(); // Get the span of the error + Error::new(span, format!("Failed to parse implementations for argument '{}': {}", name, e)) + }) +} + +fn parse_node_implementations(attr: &Attribute, name: &Ident) -> syn::Result> { + let content: TokenStream2 = attr.parse_args()?; let parser = Punctuated::::parse_terminated; - parser - .parse2(content) - .map_err(|e| Error::new_spanned(attr, format!("Failed to parse implementations for argument '{}': {}", name, e))) + parser.parse2(content.clone()).map_err(|e| { + Error::new( + e.span(), + formatdoc!( + "Invalid #[implementations(...)] for argument `{}`. + Expected a comma-separated list of `InputType -> OutputType` pairs. + Example: #[implementations(i32 -> f64, String -> Vec)] + Error: {}", + name, + e + ), + ) + }) } fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result { @@ -300,11 +364,6 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul } } - let implementations = extract_attribute(attrs, "implementations") - .map(|attr| parse_implementations(attr, ident)) - .transpose()? - .unwrap_or_default(); - let (is_node, node_input_type, node_output_type) = parse_node_type(&ty); if is_node { @@ -315,7 +374,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul return Err(Error::new_spanned(&ty, "No default values for `impl Node` allowed")); } let implementations = extract_attribute(attrs, "implementations") - .map(|attr| parse_implementations(attr, ident)) + .map(|attr| parse_node_implementations(attr, ident)) .transpose()? .unwrap_or_default(); @@ -327,6 +386,10 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul implementations, }) } else { + let implementations = extract_attribute(attrs, "implementations") + .map(|attr| parse_implementations(attr, ident)) + .transpose()? + .unwrap_or_default(); Ok(ParsedField::Regular { pat_ident, name, @@ -381,16 +444,16 @@ fn extract_attribute<'a>(attrs: &'a [Attribute], name: &str) -> Option<&'a Attri // Modify the new_node_fn function to use the code generation pub fn new_node_fn(attr: TokenStream2, item: TokenStream2) -> TokenStream2 { - match parse_node_fn(attr, item.clone()).and_then(|x| generate_node_code(&x)) { - Ok(parsed) => { - /*let generated_code = generate_node_code(&parsed); - // panic!("{}", generated_code.to_string()); - quote! { - // #item - #generated_code - }*/ - parsed - } + let parse_result = parse_node_fn(attr, item.clone()); + let Ok(parsed_node) = parse_result else { + let e = parse_result.unwrap_err(); + return Error::new(e.span(), format!("Failed to parse node function: {e}")).to_compile_error(); + }; + if let Err(e) = crate::validation::validate_node_fn(&parsed_node) { + return Error::new(e.span(), format!("Validation Error:\n{e}")).to_compile_error(); + } + match generate_node_code(&parsed_node) { + Ok(parsed) => parsed, Err(e) => { // Return the error as a compile error Error::new(e.span(), format!("Failed to parse node function: {}", e)).to_compile_error() @@ -403,7 +466,7 @@ mod tests { use super::*; use proc_macro2::Span; use proc_macro_crate::FoundCrate; - use quote::quote; + use quote::{quote, quote_spanned}; use syn::parse_quote; fn pat_ident(name: &str) -> PatIdent { PatIdent { @@ -869,4 +932,59 @@ mod tests { ); parse_node_fn(attr, input).unwrap(); } + + #[test] + fn test_invalid_implementation_syntax() { + let attr = quote!(category("Test")); + let input = quote!( + fn test_node(_: (), #[implementations((Footprint, Color), (Footprint, ImageFrame))] input: impl Node) -> T { + // Implementation details... + } + ); + + let result = parse_node_fn(attr, input); + assert!(result.is_err()); + let error = result.unwrap_err(); + let error_message = error.to_string(); + assert!(error_message.contains("Invalid #[implementations(...)] for argument `input`")); + assert!(error_message.contains("Expected a comma-separated list of `InputType -> OutputType` pairs")); + assert!(error_message.contains("Expected `->` arrow after input type in #[implementations(...)] on a field of type `impl Node`")); + } + + #[test] + fn test_implementation_on_first_arg() { + let attr = quote!(category("Test")); + + // Use quote_spanned! to attach a specific span to the problematic part + let problem_span = proc_macro2::Span::call_site(); // You could create a custom span here if needed + let tuples = quote_spanned!(problem_span=> () ()); + let input = quote! { + fn test_node( + #[implementations((), #tuples, Footprint)] footprint: F, + #[implementations( + () -> Color, + () -> ImageFrame, + () -> GradientStops, + Footprint -> Color, + Footprint -> ImageFrame, + Footprint -> GradientStops, + )] + image: impl Node, + ) -> T { + // Implementation details... + } + }; + + let result = parse_node_fn(attr, input); + assert!(result.is_err(), "Expected an error, but parsing succeeded"); + + let error = result.unwrap_err(); + let error_string = error.to_string(); + assert!(error_string.contains("Failed to parse implementations for argument 'footprint'")); + assert!(error_string.contains("expected `,`")); + + // Instead of checking for exact line and column, + // verify that the error span is the one we specified + assert_eq!(error.span().start(), problem_span.start()); + } } diff --git a/node-graph/node-macro/src/validation.rs b/node-graph/node-macro/src/validation.rs new file mode 100644 index 000000000..b832e83ca --- /dev/null +++ b/node-graph/node-macro/src/validation.rs @@ -0,0 +1,109 @@ +use crate::parsing::{Implementation, ParsedField, ParsedNodeFn}; + +use proc_macro_error::emit_error; +use quote::quote; +use syn::{spanned::Spanned, GenericParam, Type}; + +pub fn validate_node_fn(parsed: &ParsedNodeFn) -> syn::Result<()> { + let validators: &[fn(&ParsedNodeFn)] = &[ + // Add more validators here as needed + validate_implementations_for_generics, + validate_primary_input_expose, + ]; + + for validator in validators { + validator(parsed); + } + + Ok(()) +} + +fn validate_primary_input_expose(parsed: &ParsedNodeFn) { + if let Some(ParsedField::Regular { exposed: true, pat_ident, .. }) = parsed.fields.first() { + emit_error!( + pat_ident.span(), + "Unnecessary #[expose] attribute on primary input `{}`. Primary inputs are always exposed.", + pat_ident.ident; + help = "You can safely remove the #[expose] attribute from this field."; + note = "The function's second argument, `{}`, is the node's primary input and it's always exposed by default", pat_ident.ident + ); + } +} + +fn validate_implementations_for_generics(parsed: &ParsedNodeFn) { + let has_skip_impl = parsed.attributes.skip_impl; + + if !has_skip_impl && !parsed.fn_generics.is_empty() { + for field in &parsed.fields { + match field { + ParsedField::Regular { ty, implementations, pat_ident, .. } => { + if contains_generic_param(ty, &parsed.fn_generics) && implementations.is_empty() { + emit_error!( + ty.span(), + "Generic type `{}` in field `{}` requires an #[implementations(...)] attribute", + quote!(#ty), + pat_ident.ident; + help = "Add #[implementations(ConcreteType1, ConcreteType2)] to field '{}'", pat_ident.ident; + help = "Or use #[skip_impl] if you want to manually implement the node" + ); + } + } + ParsedField::Node { + input_type, + output_type, + implementations, + pat_ident, + .. + } => { + if (contains_generic_param(input_type, &parsed.fn_generics) || contains_generic_param(output_type, &parsed.fn_generics)) && implementations.is_empty() { + emit_error!( + pat_ident.span(), + "Generic types in Node field `{}` require an #[implementations(...)] attribute", + pat_ident.ident; + help = "Add #[implementations(InputType1 -> OutputType1, InputType2 -> OutputType2)] to field '{}'", pat_ident.ident; + help = "Or use #[skip_impl] if you want to manually implement the node" + ); + } + // Additional check for Node implementations + for impl_ in implementations { + validate_node_implementation(impl_, input_type, output_type, &parsed.fn_generics); + } + } + } + } + } +} + +fn validate_node_implementation(impl_: &Implementation, input_type: &Type, output_type: &Type, fn_generics: &[GenericParam]) { + if contains_generic_param(&impl_.input, fn_generics) || contains_generic_param(&impl_.output, fn_generics) { + emit_error!( + impl_.input.span(), + "Implementation types `{}` and `{}` must be concrete, not generic", + quote!(#input_type), quote!(#output_type); + help = "Replace generic types with concrete types in the implementation" + ); + } +} + +fn contains_generic_param(ty: &Type, fn_generics: &[GenericParam]) -> bool { + struct GenericParamChecker<'a> { + fn_generics: &'a [GenericParam], + found: bool, + } + + impl<'a> syn::visit::Visit<'a> for GenericParamChecker<'a> { + fn visit_ident(&mut self, ident: &'a syn::Ident) { + if self + .fn_generics + .iter() + .any(|param| if let GenericParam::Type(type_param) = param { type_param.ident == *ident } else { false }) + { + self.found = true; + } + } + } + + let mut checker = GenericParamChecker { fn_generics, found: false }; + syn::visit::visit_type(&mut checker, ty); + checker.found +}