Improve node macro and add more diagnostics (#1999)

* Improve node macro ergonomics

* Fix type error in stub import

* Fix wasm nodes

* Code review

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2024-09-21 21:57:45 +02:00 committed by GitHub
parent 3eb98c6d6d
commit cd4124a596
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 358 additions and 122 deletions

1
Cargo.lock generated
View file

@ -3850,6 +3850,7 @@ dependencies = [
"graphene-core",
"indoc",
"proc-macro-crate 3.2.0",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.77",

View file

@ -234,8 +234,8 @@ impl ArtboardGroup {
#[node_macro::node(category(""))]
async fn layer<F: 'n + Copy + Send>(
#[implementations((), Footprint)] footprint: F,
#[implementations(((), GraphicGroup), (Footprint, GraphicGroup))] stack: impl Node<F, Output = GraphicGroup>,
#[implementations(((), GraphicElement), (Footprint, GraphicElement))] graphic_element: impl Node<F, Output = GraphicElement>,
#[implementations(() -> GraphicGroup, Footprint -> GraphicGroup)] stack: impl Node<F, Output = GraphicGroup>,
#[implementations(() -> GraphicElement, Footprint -> GraphicElement)] graphic_element: impl Node<F, Output = GraphicElement>,
node_path: Vec<NodeId>,
) -> GraphicGroup {
let mut element = graphic_element.eval(footprint).await;
@ -257,15 +257,15 @@ async fn layer<F: 'n + Copy + Send>(
async fn to_element<F: 'n + Send, Data: Into<GraphicElement> + 'n>(
#[implementations((), (), (), (), Footprint)] footprint: F,
#[implementations(
((), VectorData),
((), ImageFrame<Color>),
((), GraphicGroup),
((), TextureFrame),
(Footprint, VectorData),
(Footprint, ImageFrame<Color>),
(Footprint, GraphicGroup),
(Footprint, TextureFrame),
)]
() -> VectorData,
() -> ImageFrame<Color>,
() -> GraphicGroup,
() -> TextureFrame,
Footprint -> VectorData,
Footprint -> ImageFrame<Color>,
Footprint -> GraphicGroup,
Footprint -> TextureFrame,
)]
data: impl Node<F, Output = Data>,
) -> GraphicElement {
data.eval(footprint).await.into()
@ -275,14 +275,14 @@ async fn to_element<F: 'n + Send, Data: Into<GraphicElement> + 'n>(
async fn to_group<F: 'n + Send, Data: Into<GraphicGroup> + 'n>(
#[implementations((), (), (), (), Footprint)] footprint: F,
#[implementations(
((), VectorData),
((), ImageFrame<Color>),
((), GraphicGroup),
((), TextureFrame),
(Footprint, VectorData),
(Footprint, ImageFrame<Color>),
(Footprint, GraphicGroup),
(Footprint, TextureFrame),
() -> VectorData,
() -> ImageFrame<Color>,
() -> GraphicGroup,
() -> TextureFrame,
Footprint -> VectorData,
Footprint -> ImageFrame<Color>,
Footprint -> GraphicGroup,
Footprint -> TextureFrame,
)]
element: impl Node<F, Output = Data>,
) -> GraphicGroup {
@ -292,7 +292,7 @@ async fn to_group<F: 'n + Send, Data: Into<GraphicGroup> + 'n>(
#[node_macro::node(category(""))]
async fn to_artboard<F: 'n + Copy + Send + ApplyTransform>(
#[implementations((), Footprint)] mut footprint: F,
#[implementations(((), GraphicGroup), (Footprint, GraphicGroup))] contents: impl Node<F, Output = GraphicGroup>,
#[implementations(() -> GraphicGroup, Footprint -> GraphicGroup)] contents: impl Node<F, Output = GraphicGroup>,
label: String,
location: IVec2,
dimensions: IVec2,
@ -314,8 +314,8 @@ async fn to_artboard<F: 'n + Copy + Send + ApplyTransform>(
#[node_macro::node(category(""))]
async fn append_artboard<F: 'n + Copy + Send>(
#[implementations((), Footprint)] footprint: F,
#[implementations(((), ArtboardGroup), (Footprint, ArtboardGroup))] artboards: impl Node<F, Output = ArtboardGroup>,
#[implementations(((), Artboard), (Footprint, Artboard))] artboard: impl Node<F, Output = Artboard>,
#[implementations(() -> ArtboardGroup, Footprint -> ArtboardGroup)] artboards: impl Node<F, Output = ArtboardGroup>,
#[implementations(() -> Artboard, Footprint -> Artboard)] artboard: impl Node<F, Output = Artboard>,
node_path: Vec<NodeId>,
) -> ArtboardGroup {
let artboard = artboard.eval(footprint).await;

View file

@ -310,7 +310,7 @@ impl SetBlendMode for ImageFrame<Color> {
#[node_macro::node(category("Style"))]
async fn blend_mode<T: SetBlendMode>(
footprint: Footprint,
#[implementations((Footprint, crate::vector::VectorData), (Footprint, crate::GraphicGroup), (Footprint, ImageFrame<Color>))] value: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> crate::vector::VectorData, Footprint -> crate::GraphicGroup, Footprint -> ImageFrame<Color>)] value: impl Node<Footprint, Output = T>,
blend_mode: BlendMode,
) -> T {
let mut value = value.eval(footprint).await;
@ -321,7 +321,7 @@ async fn blend_mode<T: SetBlendMode>(
#[node_macro::node(category("Style"))]
async fn opacity<T: MultiplyAlpha>(
footprint: Footprint,
#[implementations((Footprint, crate::vector::VectorData), (Footprint, crate::GraphicGroup), (Footprint, ImageFrame<Color>))] value: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> crate::vector::VectorData, Footprint -> crate::GraphicGroup, Footprint -> ImageFrame<Color>)] value: impl Node<Footprint, Output = T>,
#[default(100.)] factor: Percentage,
) -> T {
let mut value = value.eval(footprint).await;

View file

@ -271,7 +271,7 @@ impl From<BlendMode> for vello::peniko::Mix {
#[node_macro::node(category("Raster: Adjustment"))] // Unique to Graphite
async fn luminance<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] input: impl Node<Footprint, Output = T>,
luminance_calc: LuminanceCalculation,
) -> T {
let mut input = input.eval(footprint).await;
@ -291,7 +291,7 @@ async fn luminance<T: Adjust<Color>>(
#[node_macro::node(category("Raster"))]
async fn extract_channel<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] input: impl Node<Footprint, Output = T>,
channel: RedGreenBlueAlpha,
) -> T {
let mut input = input.eval(footprint).await;
@ -308,7 +308,7 @@ async fn extract_channel<T: Adjust<Color>>(
}
#[node_macro::node(category("Raster"))]
async fn make_opaque<T: Adjust<Color>>(footprint: Footprint, #[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>) -> T {
async fn make_opaque<T: Adjust<Color>>(footprint: Footprint, #[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] input: impl Node<Footprint, Output = T>) -> T {
let mut input = input.eval(footprint).await;
input.adjust(|color| {
if color.a() == 0. {
@ -323,7 +323,7 @@ async fn make_opaque<T: Adjust<Color>>(footprint: Footprint, #[implementations((
#[node_macro::node(category("Raster: Adjustment"))]
async fn levels<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] image: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] image: impl Node<Footprint, Output = T>,
#[default(0.)] shadows: Percentage,
#[default(50.)] midtones: Percentage,
#[default(100.)] highlights: Percentage,
@ -381,7 +381,7 @@ async fn levels<T: Adjust<Color>>(
#[node_macro::node(name("Black & White"), category("Raster: Adjustment"))]
async fn black_and_white<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] image: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] image: impl Node<Footprint, Output = T>,
#[default(000000ff)] tint: Color,
#[default(40.)]
#[range((-200., 300.))]
@ -446,7 +446,7 @@ async fn black_and_white<T: Adjust<Color>>(
#[node_macro::node(name("Hue/Saturation"), category("Raster: Adjustment"))]
async fn hue_saturation<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] input: impl Node<Footprint, Output = T>,
hue_shift: Angle,
saturation_shift: SignedPercentage,
lightness_shift: SignedPercentage,
@ -472,7 +472,7 @@ async fn hue_saturation<T: Adjust<Color>>(
}
#[node_macro::node(category("Raster: Adjustment"))]
async fn invert<T: Adjust<Color>>(footprint: Footprint, #[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>) -> T {
async fn invert<T: Adjust<Color>>(footprint: Footprint, #[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] input: impl Node<Footprint, Output = T>) -> T {
let mut input = input.eval(footprint).await;
input.adjust(|color| {
let color = color.to_gamma_srgb();
@ -487,7 +487,7 @@ async fn invert<T: Adjust<Color>>(footprint: Footprint, #[implementations((Footp
#[node_macro::node(category("Raster: Adjustment"))]
async fn threshold<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] image: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] image: impl Node<Footprint, Output = T>,
#[default(50.)] min_luminance: Percentage,
#[default(100.)] max_luminance: Percentage,
luminance_calc: LuminanceCalculation,
@ -574,22 +574,22 @@ impl Blend<Color> for GradientStops {
async fn blend<F: 'n + Copy + Send, T: Blend<Color> + Send>(
#[implementations((), (), (), Footprint)] footprint: F,
#[implementations(
((), Color),
((), ImageFrame<Color>),
((), GradientStops),
(Footprint, Color),
(Footprint, ImageFrame<Color>),
(Footprint, GradientStops),
() -> Color,
() -> ImageFrame<Color>,
() -> GradientStops,
Footprint -> Color,
Footprint -> ImageFrame<Color>,
Footprint -> GradientStops,
)]
over: impl Node<F, Output = T>,
#[expose]
#[implementations(
((), Color),
((), ImageFrame<Color>),
((), GradientStops),
(Footprint, Color),
(Footprint, ImageFrame<Color>),
(Footprint, GradientStops),
() -> Color,
() -> ImageFrame<Color>,
() -> GradientStops,
Footprint -> Color,
Footprint -> ImageFrame<Color>,
Footprint -> GradientStops,
)]
under: impl Node<F, Output = T>,
blend_mode: BlendMode,
@ -693,12 +693,12 @@ pub fn blend_colors(foreground: Color, background: Color, blend_mode: BlendMode,
async fn gradient_map<F: 'n + Copy + Send, T: Adjust<Color>>(
#[implementations((), (), (), Footprint)] footprint: F,
#[implementations(
((), Color),
((), ImageFrame<Color>),
((), GradientStops),
(Footprint, Color),
(Footprint, ImageFrame<Color>),
(Footprint, GradientStops),
() -> Color,
() -> ImageFrame<Color>,
() -> GradientStops,
Footprint -> Color,
Footprint -> ImageFrame<Color>,
Footprint -> GradientStops,
)]
image: impl Node<F, Output = T>,
gradient: GradientStops,
@ -720,7 +720,7 @@ async fn gradient_map<F: 'n + Copy + Send, T: Adjust<Color>>(
#[node_macro::node(category("Raster: Adjustment"))]
async fn vibrance<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] image: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] image: impl Node<Footprint, Output = T>,
vibrance: SignedPercentage,
) -> T {
let mut input = image.eval(footprint).await;
@ -1003,7 +1003,7 @@ impl DomainWarpType {
#[node_macro::node(category("Raster: Adjustment"))]
async fn channel_mixer<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] image: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] image: impl Node<Footprint, Output = T>,
monochrome: bool,
#[default(40.)]
@ -1145,7 +1145,7 @@ impl core::fmt::Display for SelectiveColorChoice {
#[node_macro::node(category("Raster: Adjustment"))]
async fn selective_color<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] image: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] image: impl Node<Footprint, Output = T>,
mode: RelativeAbsolute,
#[name("(Reds) Cyan")] r_c: f64,
#[name("(Reds) Magenta")] r_m: f64,
@ -1293,7 +1293,7 @@ impl<P: Pixel> MultiplyAlpha for ImageFrame<P> {
#[node_macro::node(category("Raster: Adjustment"))]
async fn posterize<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] input: impl Node<Footprint, Output = T>,
#[default(4)]
#[min(2.)]
levels: u32,
@ -1317,7 +1317,7 @@ async fn posterize<T: Adjust<Color>>(
#[node_macro::node(category("Raster: Adjustment"))]
async fn exposure<T: Adjust<Color>>(
footprint: Footprint,
#[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> Color, Footprint -> ImageFrame<Color>)] input: impl Node<Footprint, Output = T>,
exposure: f64,
offset: f64,
#[default(1.)]
@ -1387,12 +1387,12 @@ fn generate_curves<C: Channel + super::Linear>(_: (), curve: Curve, #[implementa
async fn color_overlay<F: 'n + Copy + Send, T: Adjust<Color>>(
#[implementations((), (), (), Footprint)] footprint: F,
#[implementations(
((), Color),
((), ImageFrame<Color>),
((), GradientStops),
(Footprint, Color),
(Footprint, ImageFrame<Color>),
(Footprint, GradientStops),
() -> Color,
() -> ImageFrame<Color>,
() -> GradientStops,
Footprint -> Color,
Footprint -> ImageFrame<Color>,
Footprint -> GradientStops,
)]
image: impl Node<F, Output = T>,
#[default(000000ff)] color: Color,

View file

@ -220,12 +220,12 @@ impl ApplyTransform for () {
async fn transform<I: Into<Footprint> + ApplyTransform + 'n + Clone + Send + Sync, T: TransformMut + 'n>(
#[implementations(Footprint, Footprint, Footprint, (), (), ())] mut input: I,
#[implementations(
(Footprint, VectorData),
(Footprint, GraphicGroup),
(Footprint, ImageFrame<crate::Color>),
((), VectorData),
((), GraphicGroup),
((), ImageFrame<crate::Color>),
Footprint -> VectorData,
Footprint -> GraphicGroup,
Footprint -> ImageFrame<crate::Color>,
() -> VectorData,
() -> GraphicGroup,
() -> ImageFrame<crate::Color>,
)]
transform_target: impl Node<I, Output = T>,
translate: DVec2,

View file

@ -426,7 +426,7 @@ use crate::transform::Footprint;
#[node_macro::node(category(""))]
async fn path_modify<F: 'n + Send + Sync + Clone>(
#[implementations((), Footprint)] input: F,
#[implementations(((), VectorData), (Footprint, VectorData))] vector_data: impl Node<F, Output = VectorData>,
#[implementations(() -> VectorData, Footprint -> VectorData)] vector_data: impl Node<F, Output = VectorData>,
modification: Box<VectorModification>,
) -> VectorData {
let mut vector_data = vector_data.eval(input).await;

View file

@ -29,7 +29,7 @@ impl VectorIterMut for VectorData {
#[node_macro::node(category("Vector: Style"), path(graphene_core::vector))]
async fn assign_colors<T: VectorIterMut>(
footprint: Footprint,
#[implementations((Footprint, GraphicGroup), (Footprint, VectorData))] vector_group: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> GraphicGroup, Footprint -> VectorData)] vector_group: impl Node<Footprint, Output = T>,
#[default(true)] fill: bool,
stroke: bool,
gradient: GradientStops,
@ -177,7 +177,7 @@ async fn circular_repeat(
#[node_macro::node(category("Vector"), path(graphene_core::vector))]
async fn bounding_box<F: 'n + Copy + Send>(
#[implementations((), Footprint)] footprint: F,
#[implementations(((), VectorData), (Footprint, VectorData))] vector_data: impl Node<F, Output = VectorData>,
#[implementations(() -> VectorData, Footprint -> VectorData)] vector_data: impl Node<F, Output = VectorData>,
) -> VectorData {
let vector_data = vector_data.eval(footprint).await;
@ -253,7 +253,7 @@ async fn copy_to_points<I: GraphicElementRendered + Default + ConcatElement + Tr
footprint: Footprint,
points: impl Node<Footprint, Output = VectorData>,
#[expose]
#[implementations((Footprint, VectorData), (Footprint, GraphicGroup))]
#[implementations(Footprint -> VectorData, Footprint -> GraphicGroup)]
instance: impl Node<Footprint, Output = I>,
#[default(1)] random_scale_min: f64,
#[default(1)] random_scale_max: f64,
@ -387,7 +387,7 @@ async fn sample_points(
#[node_macro::node(category(""), path(graphene_core::vector))]
async fn poisson_disk_points<F: 'n + Copy + Send>(
#[implementations((), Footprint)] footprint: F,
#[implementations(((), VectorData), (Footprint, VectorData))] vector_data: impl Node<F, Output = VectorData>,
#[implementations(() -> VectorData, Footprint -> VectorData)] vector_data: impl Node<F, Output = VectorData>,
#[default(10.)]
#[min(0.01)]
separation_disk_diameter: f64,

View file

@ -5,7 +5,7 @@ use graphene_core::Color;
#[node_macro::node(category("Raster"))]
async fn image_color_palette<F: 'n + Send>(
#[implementations((), Footprint)] footprint: F,
#[implementations(((), ImageFrame<Color>), (Footprint, ImageFrame<Color>))] image: impl Node<F, Output = ImageFrame<Color>>,
#[implementations(() -> ImageFrame<Color>, Footprint -> ImageFrame<Color>)] image: impl Node<F, Output = ImageFrame<Color>>,
#[min(1.)]
#[max(28.)]
max_size: u32,

View file

@ -13,7 +13,7 @@ use std::ops::{Div, Mul};
#[node_macro::node(category(""))]
async fn boolean_operation<F: 'n + Copy + Send>(
#[implementations((), Footprint)] footprint: F,
#[implementations(((), GraphicGroup), (Footprint, GraphicGroup))] group_of_paths: impl Node<F, Output = GraphicGroup>,
#[implementations(() -> GraphicGroup, Footprint -> GraphicGroup)] group_of_paths: impl Node<F, Output = GraphicGroup>,
operation: BooleanOperation,
) -> VectorData {
let group_of_paths = group_of_paths.eval(footprint).await;

View file

@ -133,7 +133,7 @@ async fn render_canvas(render_config: RenderConfig, data: impl GraphicElementRen
#[cfg(target_arch = "wasm32")]
async fn rasterize<T: GraphicElementRendered + graphene_core::transform::TransformMut + WasmNotSend + 'n>(
_: (),
#[implementations((Footprint, VectorData), (Footprint, ImageFrame<Color>), (Footprint, GraphicGroup))] data: impl Node<Footprint, Output = T>,
#[implementations(Footprint -> VectorData, Footprint -> ImageFrame<Color>, Footprint -> GraphicGroup)] data: impl Node<Footprint, Output = T>,
footprint: Footprint,
surface_handle: Arc<SurfaceHandle<HtmlCanvasElement>>,
) -> ImageFrame<Color> {
@ -190,17 +190,17 @@ async fn render<'a: 'n, T: 'n + GraphicElementRendered + WasmNotSend>(
render_config: RenderConfig,
editor_api: &'a WasmEditorApi,
#[implementations(
(Footprint, VectorData),
(Footprint, ImageFrame<Color>),
(Footprint, GraphicGroup),
(Footprint, graphene_core::Artboard),
(Footprint, graphene_core::ArtboardGroup),
(Footprint, Option<Color>),
(Footprint, Vec<Color>),
(Footprint, bool),
(Footprint, f32),
(Footprint, f64),
(Footprint, String),
Footprint -> VectorData,
Footprint -> ImageFrame<Color>,
Footprint -> GraphicGroup,
Footprint -> graphene_core::Artboard,
Footprint -> graphene_core::ArtboardGroup,
Footprint -> Option<Color>,
Footprint -> Vec<Color>,
Footprint -> bool,
Footprint -> f32,
Footprint -> f64,
Footprint -> String,
)]
data: impl Node<Footprint, Output = T>,
_surface_handle: impl Node<(), Output = Option<wgpu_executor::WgpuSurface>>,

View file

@ -15,13 +15,14 @@ proc-macro = true
[dependencies]
# Workspace dependencies
syn = { workspace = true, features = ["extra-traits", "full", "printing", "parsing", "clone-impls", "proc-macro", "visit-mut"] }
proc-macro2 = { workspace = true }
syn = { workspace = true, features = [ "extra-traits", "full", "printing", "parsing", "clone-impls", "proc-macro", "visit-mut", "visit"] }
proc-macro2 = { workspace = true, features = [ "span-locations" ] }
quote = { workspace = true }
convert_case = { workspace = true }
indoc = "2.0.5"
proc-macro-crate = "3.1.0"
proc-macro-error = "1.0"
[dev-dependencies]
graphene-core = { workspace = true }

View file

@ -140,7 +140,10 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let all_implementation_types = fields.iter().flat_map(|field| match field {
ParsedField::Regular { implementations, .. } => implementations.into_iter().cloned().collect::<Vec<_>>(),
ParsedField::Node { implementations, .. } => implementations.into_iter().map(|tuple| syn::Type::Tuple(tuple.clone())).collect(),
ParsedField::Node { implementations, .. } => implementations
.into_iter()
.flat_map(|implementation| [implementation.input.clone(), implementation.output.clone()])
.collect(),
});
let all_implementation_types = all_implementation_types.chain(input.implementations.iter().cloned());
@ -200,6 +203,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let identifier = quote!(format!("{}::{}", #path, stringify!(#struct_name)));
let register_node_impl = generate_register_node_impl(parsed, &field_names, &struct_name, &identifier)?;
let import_name = format_ident!("_IMPORT_STUB_{}", mod_name.to_string().to_case(Case::UpperSnake));
Ok(quote! {
/// Underlying implementation for [#struct_name]
@ -227,8 +231,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
use gcore::ctor::ctor;
// Use the types specified in the implementation
#[cfg(__never_compiled)]
static _IMPORTS: core::marker::PhantomData<#(#all_implementation_types,)*> = core::marker::PhantomData;
static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData;
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct #struct_name<#(#struct_generics,)*> {
@ -286,19 +290,19 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
match field {
ParsedField::Regular { implementations, ty, .. } => {
if !implementations.is_empty() {
implementations.into_iter().map(|ty| (&unit, ty, false)).collect()
implementations.iter().map(|ty| (&unit, ty, false)).collect()
} else {
vec![(&unit, ty, false)]
}
}
ParsedField::Node {
implementations,
output_type,
input_type,
output_type,
..
} => {
if !implementations.is_empty() {
implementations.into_iter().map(|tup| (&tup.elems[0], &tup.elems[1], true)).collect()
implementations.iter().map(|impl_| (&impl_.input, &impl_.output, true)).collect()
} else {
vec![(input_type, output_type, true)]
}

View file

@ -1,5 +1,6 @@
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro_error::proc_macro_error;
use quote::{format_ident, quote, ToTokens};
use syn::{
parse_macro_input, punctuated::Punctuated, token::Comma, AngleBracketedGenericArguments, AssocType, FnArg, GenericArgument, GenericParam, Ident, ItemFn, Lifetime, Pat, PatIdent, PathArguments,
@ -8,6 +9,7 @@ use syn::{
mod codegen;
mod parsing;
mod validation;
/// A macro used to construct a proto node implementation from the given struct and the decorated function.
///
@ -102,6 +104,7 @@ pub fn old_node_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
new_constructor
}
#[proc_macro_error]
#[proc_macro_attribute]
pub fn node(attr: TokenStream, item: TokenStream) -> TokenStream {
// Performs the `node_impl` macro's functionality of attaching an `impl Node for TheGivenStruct` block to the node struct

View file

@ -1,14 +1,21 @@
use convert_case::{Case, Casing};
use indoc::indoc;
use indoc::{formatdoc, indoc};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, ToTokens};
use syn::parse::{Parse, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{Attribute, Error, ExprTuple, FnArg, GenericParam, Ident, ItemFn, LitFloat, LitStr, Meta, Pat, PatIdent, PatType, Path, ReturnType, Type, TypeTuple, WhereClause};
use syn::token::{Comma, RArrow};
use syn::{Attribute, Error, ExprTuple, FnArg, GenericParam, Ident, ItemFn, LitFloat, LitStr, Meta, Pat, PatIdent, PatType, Path, ReturnType, Type, WhereClause};
use crate::codegen::generate_node_code;
#[derive(Debug)]
pub(crate) struct Implementation {
pub(crate) input: Type,
pub(crate) _arrow: RArrow,
pub(crate) output: Type,
}
#[derive(Debug)]
pub(crate) struct ParsedNodeFn {
pub(crate) attributes: NodeFnAttributes,
@ -60,7 +67,7 @@ pub(crate) enum ParsedField {
name: Option<LitStr>,
input_type: Type,
output_type: Type,
implementations: Punctuated<TypeTuple, Comma>,
implementations: Punctuated<Implementation, Comma>,
},
}
#[derive(Debug)]
@ -70,6 +77,46 @@ pub(crate) struct Input {
pub(crate) implementations: Punctuated<Type, Comma>,
}
impl Parse for Implementation {
fn parse(input: ParseStream) -> syn::Result<Self> {
let input_type: Type = input.parse().map_err(|e| {
Error::new(
input.span(),
formatdoc!(
"Failed to parse input type for #[implementation(...)]. Expected a valid Rust type.
Error: {}",
e,
),
)
})?;
let arrow: RArrow = input.parse().map_err(|_| {
Error::new(
input.span(),
indoc!(
"Expected `->` arrow after input type in #[implementations(...)] on a field of type `impl Node`.
The correct syntax is `InputType -> OutputType`."
),
)
})?;
let output_type: Type = input.parse().map_err(|e| {
Error::new(
input.span(),
formatdoc!(
"Failed to parse output type for #[implementation(...)]. Expected a valid Rust type after `->`.
Error: {}",
e
),
)
})?;
Ok(Implementation {
input: input_type,
_arrow: arrow,
output: output_type,
})
}
}
impl Parse for NodeFnAttributes {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut category = None;
@ -228,14 +275,31 @@ fn parse_inputs(inputs: &Punctuated<FnArg, Comma>) -> syn::Result<(Input, Vec<Pa
Ok((input, fields))
}
fn parse_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::Result<Punctuated<T, Comma>> {
let content: TokenStream2 = attr
.parse_args()
.map_err(|e| Error::new_spanned(attr, format!("Invalid implementations for argument '{}': {}", name, e)))?;
fn parse_implementations(attr: &Attribute, name: &Ident) -> syn::Result<Punctuated<Type, Comma>> {
let content: TokenStream2 = attr.parse_args()?;
let parser = Punctuated::<Type, Comma>::parse_terminated;
parser.parse2(content.clone()).map_err(|e| {
let span = e.span(); // Get the span of the error
Error::new(span, format!("Failed to parse implementations for argument '{}': {}", name, e))
})
}
fn parse_node_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::Result<Punctuated<T, Comma>> {
let content: TokenStream2 = attr.parse_args()?;
let parser = Punctuated::<T, Comma>::parse_terminated;
parser
.parse2(content)
.map_err(|e| Error::new_spanned(attr, format!("Failed to parse implementations for argument '{}': {}", name, e)))
parser.parse2(content.clone()).map_err(|e| {
Error::new(
e.span(),
formatdoc!(
"Invalid #[implementations(...)] for argument `{}`.
Expected a comma-separated list of `InputType -> OutputType` pairs.
Example: #[implementations(i32 -> f64, String -> Vec<u8>)]
Error: {}",
name,
e
),
)
})
}
fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> {
@ -300,11 +364,6 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
}
}
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
let (is_node, node_input_type, node_output_type) = parse_node_type(&ty);
if is_node {
@ -315,7 +374,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
return Err(Error::new_spanned(&ty, "No default values for `impl Node` allowed"));
}
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.map(|attr| parse_node_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
@ -327,6 +386,10 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
implementations,
})
} else {
let implementations = extract_attribute(attrs, "implementations")
.map(|attr| parse_implementations(attr, ident))
.transpose()?
.unwrap_or_default();
Ok(ParsedField::Regular {
pat_ident,
name,
@ -381,16 +444,16 @@ fn extract_attribute<'a>(attrs: &'a [Attribute], name: &str) -> Option<&'a Attri
// Modify the new_node_fn function to use the code generation
pub fn new_node_fn(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
match parse_node_fn(attr, item.clone()).and_then(|x| generate_node_code(&x)) {
Ok(parsed) => {
/*let generated_code = generate_node_code(&parsed);
// panic!("{}", generated_code.to_string());
quote! {
// #item
#generated_code
}*/
parsed
}
let parse_result = parse_node_fn(attr, item.clone());
let Ok(parsed_node) = parse_result else {
let e = parse_result.unwrap_err();
return Error::new(e.span(), format!("Failed to parse node function: {e}")).to_compile_error();
};
if let Err(e) = crate::validation::validate_node_fn(&parsed_node) {
return Error::new(e.span(), format!("Validation Error:\n{e}")).to_compile_error();
}
match generate_node_code(&parsed_node) {
Ok(parsed) => parsed,
Err(e) => {
// Return the error as a compile error
Error::new(e.span(), format!("Failed to parse node function: {}", e)).to_compile_error()
@ -403,7 +466,7 @@ mod tests {
use super::*;
use proc_macro2::Span;
use proc_macro_crate::FoundCrate;
use quote::quote;
use quote::{quote, quote_spanned};
use syn::parse_quote;
fn pat_ident(name: &str) -> PatIdent {
PatIdent {
@ -869,4 +932,59 @@ mod tests {
);
parse_node_fn(attr, input).unwrap();
}
#[test]
fn test_invalid_implementation_syntax() {
let attr = quote!(category("Test"));
let input = quote!(
fn test_node(_: (), #[implementations((Footprint, Color), (Footprint, ImageFrame<Color>))] input: impl Node<Footprint, Output = T>) -> T {
// Implementation details...
}
);
let result = parse_node_fn(attr, input);
assert!(result.is_err());
let error = result.unwrap_err();
let error_message = error.to_string();
assert!(error_message.contains("Invalid #[implementations(...)] for argument `input`"));
assert!(error_message.contains("Expected a comma-separated list of `InputType -> OutputType` pairs"));
assert!(error_message.contains("Expected `->` arrow after input type in #[implementations(...)] on a field of type `impl Node`"));
}
#[test]
fn test_implementation_on_first_arg() {
let attr = quote!(category("Test"));
// Use quote_spanned! to attach a specific span to the problematic part
let problem_span = proc_macro2::Span::call_site(); // You could create a custom span here if needed
let tuples = quote_spanned!(problem_span=> () ());
let input = quote! {
fn test_node(
#[implementations((), #tuples, Footprint)] footprint: F,
#[implementations(
() -> Color,
() -> ImageFrame<Color>,
() -> GradientStops,
Footprint -> Color,
Footprint -> ImageFrame<Color>,
Footprint -> GradientStops,
)]
image: impl Node<F, Output = T>,
) -> T {
// Implementation details...
}
};
let result = parse_node_fn(attr, input);
assert!(result.is_err(), "Expected an error, but parsing succeeded");
let error = result.unwrap_err();
let error_string = error.to_string();
assert!(error_string.contains("Failed to parse implementations for argument 'footprint'"));
assert!(error_string.contains("expected `,`"));
// Instead of checking for exact line and column,
// verify that the error span is the one we specified
assert_eq!(error.span().start(), problem_span.start());
}
}

View file

@ -0,0 +1,109 @@
use crate::parsing::{Implementation, ParsedField, ParsedNodeFn};
use proc_macro_error::emit_error;
use quote::quote;
use syn::{spanned::Spanned, GenericParam, Type};
pub fn validate_node_fn(parsed: &ParsedNodeFn) -> syn::Result<()> {
let validators: &[fn(&ParsedNodeFn)] = &[
// Add more validators here as needed
validate_implementations_for_generics,
validate_primary_input_expose,
];
for validator in validators {
validator(parsed);
}
Ok(())
}
fn validate_primary_input_expose(parsed: &ParsedNodeFn) {
if let Some(ParsedField::Regular { exposed: true, pat_ident, .. }) = parsed.fields.first() {
emit_error!(
pat_ident.span(),
"Unnecessary #[expose] attribute on primary input `{}`. Primary inputs are always exposed.",
pat_ident.ident;
help = "You can safely remove the #[expose] attribute from this field.";
note = "The function's second argument, `{}`, is the node's primary input and it's always exposed by default", pat_ident.ident
);
}
}
fn validate_implementations_for_generics(parsed: &ParsedNodeFn) {
let has_skip_impl = parsed.attributes.skip_impl;
if !has_skip_impl && !parsed.fn_generics.is_empty() {
for field in &parsed.fields {
match field {
ParsedField::Regular { ty, implementations, pat_ident, .. } => {
if contains_generic_param(ty, &parsed.fn_generics) && implementations.is_empty() {
emit_error!(
ty.span(),
"Generic type `{}` in field `{}` requires an #[implementations(...)] attribute",
quote!(#ty),
pat_ident.ident;
help = "Add #[implementations(ConcreteType1, ConcreteType2)] to field '{}'", pat_ident.ident;
help = "Or use #[skip_impl] if you want to manually implement the node"
);
}
}
ParsedField::Node {
input_type,
output_type,
implementations,
pat_ident,
..
} => {
if (contains_generic_param(input_type, &parsed.fn_generics) || contains_generic_param(output_type, &parsed.fn_generics)) && implementations.is_empty() {
emit_error!(
pat_ident.span(),
"Generic types in Node field `{}` require an #[implementations(...)] attribute",
pat_ident.ident;
help = "Add #[implementations(InputType1 -> OutputType1, InputType2 -> OutputType2)] to field '{}'", pat_ident.ident;
help = "Or use #[skip_impl] if you want to manually implement the node"
);
}
// Additional check for Node implementations
for impl_ in implementations {
validate_node_implementation(impl_, input_type, output_type, &parsed.fn_generics);
}
}
}
}
}
}
fn validate_node_implementation(impl_: &Implementation, input_type: &Type, output_type: &Type, fn_generics: &[GenericParam]) {
if contains_generic_param(&impl_.input, fn_generics) || contains_generic_param(&impl_.output, fn_generics) {
emit_error!(
impl_.input.span(),
"Implementation types `{}` and `{}` must be concrete, not generic",
quote!(#input_type), quote!(#output_type);
help = "Replace generic types with concrete types in the implementation"
);
}
}
fn contains_generic_param(ty: &Type, fn_generics: &[GenericParam]) -> bool {
struct GenericParamChecker<'a> {
fn_generics: &'a [GenericParam],
found: bool,
}
impl<'a> syn::visit::Visit<'a> for GenericParamChecker<'a> {
fn visit_ident(&mut self, ident: &'a syn::Ident) {
if self
.fn_generics
.iter()
.any(|param| if let GenericParam::Type(type_param) = param { type_param.ident == *ident } else { false })
{
self.found = true;
}
}
}
let mut checker = GenericParamChecker { fn_generics, found: false };
syn::visit::visit_type(&mut checker, ty);
checker.found
}