Update derive field overwrite support

This commit is contained in:
Lukas Wirth 2025-03-05 15:22:20 +01:00 committed by Lukas Wirth
parent 04053c1ce3
commit 9cc1cc2e9c
9 changed files with 247 additions and 31 deletions

View file

@ -80,12 +80,12 @@ pub fn tracked(args: TokenStream, input: TokenStream) -> TokenStream {
tracked::tracked(args, input)
}
#[proc_macro_derive(Update)]
#[proc_macro_derive(Update, attributes(update))]
pub fn update(input: TokenStream) -> TokenStream {
let item = parse_macro_input!(input as syn::DeriveInput);
match update::update_derive(item) {
Ok(tokens) => tokens.into(),
Err(error) => token_stream_with_error(input, error),
Err(error) => error.into_compile_error().into(),
}
}

View file

@ -1,5 +1,5 @@
use proc_macro2::{Literal, TokenStream};
use syn::spanned::Spanned;
use proc_macro2::{Literal, Span, TokenStream};
use syn::{parenthesized, parse::ParseStream, spanned::Spanned, Token};
use synstructure::BindStyle;
use crate::hygiene::Hygiene;
@ -7,9 +7,9 @@ use crate::hygiene::Hygiene;
pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream> {
let hygiene = Hygiene::from2(&input);
if let syn::Data::Union(_) = &input.data {
if let syn::Data::Union(u) = &input.data {
return Err(syn::Error::new_spanned(
&input.ident,
u.union_token,
"`derive(Update)` does not support `union`",
));
}
@ -27,6 +27,24 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
.variants()
.iter()
.map(|variant| {
let err = variant
.ast()
.attrs
.iter()
.filter(|attr| attr.path().is_ident("update"))
.map(|attr| {
syn::Error::new(
attr.path().span(),
"unexpected attribute `#[update]` on variant",
)
})
.reduce(|mut acc, err| {
acc.combine(err);
acc
});
if let Some(err) = err {
return Err(err);
}
let variant_pat = variant.pat();
// First check that the `new_value` has same variant.
@ -35,7 +53,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
.bindings()
.iter()
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
let make_new_value = quote_spanned! {variant.ast().ident.span()=>
let make_new_value = quote! {
let #new_value = if let #variant_pat = #new_value {
(#make_tuple)
} else {
@ -47,40 +65,73 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
// For each field, invoke `maybe_update` recursively to update its value.
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
// to get the final return value.
let update_fields = variant.bindings().iter().enumerate().fold(
quote!(false),
|tokens, (index, binding)| {
let field_ty = &binding.ast().ty;
let field_index = Literal::usize_unsuffixed(index);
let mut update_fields = quote!(false);
for (index, binding) in variant.bindings().iter().enumerate() {
let mut attrs = binding
.ast()
.attrs
.iter()
.filter(|attr| attr.path().is_ident("update"));
let attr = attrs.next();
if let Some(attr) = attrs.next() {
return Err(syn::Error::new(
attr.path().span(),
"multiple #[update(with)] attributes on field",
));
}
let field_span = binding
.ast()
.ident
.as_ref()
.map(Spanned::span)
.unwrap_or(binding.ast().span());
let field_ty = &binding.ast().ty;
let field_index = Literal::usize_unsuffixed(index);
let update_field = quote_spanned! {field_span=>
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
#binding,
#new_value.#field_index,
)
};
let (maybe_update, unsafe_token) = match attr {
Some(attr) => {
mod kw {
syn::custom_keyword!(with);
}
attr.parse_args_with(|parser: ParseStream| {
let mut content;
quote! {
#tokens | unsafe { #update_field }
let unsafe_token = parser.parse::<Token![unsafe]>()?;
parenthesized!(content in parser);
let with_token = content.parse::<kw::with>()?;
parenthesized!(content in content);
let expr = content.parse::<syn::Expr>()?;
Ok((
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #expr; maybe_update }) },
// quote_spanned! { with_token.span() => ((#expr) as unsafe fn(*mut #field_ty, #field_ty) -> bool) },
unsafe_token,
))
})?
}
},
);
None => {
(
quote!(
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update
),
Token![unsafe](Span::call_site()),
)
}
};
let update_field = quote! {
#maybe_update(
#binding,
#new_value.#field_index,
)
};
quote!(
update_fields = quote! {
#update_fields | #unsafe_token { #update_field }
};
}
Ok(quote!(
#variant_pat => {
#make_new_value
#update_fields
}
)
))
})
.collect();
.collect::<syn::Result<_>>()?;
let ident = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

View file

@ -0,0 +1,16 @@
#[derive(salsa::Update)]
union U {
field: i32,
}
#[derive(salsa::Update)]
struct S {
#[update(with(missing_unsafe))]
bad: i32,
}
fn missing_unsafe(_: *mut i32, _: i32) -> bool {
true
}
fn main() {}

View file

@ -0,0 +1,11 @@
error: `derive(Update)` does not support `union`
--> tests/compile-fail/derive_update_expansion_failure.rs:2:1
|
2 | union U {
| ^^^^^
error: expected `unsafe`
--> tests/compile-fail/derive_update_expansion_failure.rs:8:14
|
8 | #[update(with(missing_unsafe))]
| ^^^^

View file

@ -0,0 +1,6 @@
#[derive(salsa::Update)]
struct S2<'a> {
bad2: &'a str,
}
fn main() {}

View file

@ -0,0 +1,9 @@
error: lifetime may not live long enough
--> tests/compile-fail/invalid_update_field.rs:1:10
|
1 | #[derive(salsa::Update)]
| ^^^^^^^^^^^^^ requires that `'a` must outlive `'static`
2 | struct S2<'a> {
| -- lifetime `'a` defined here
|
= note: this error originates in the derive macro `salsa::Update` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -0,0 +1,19 @@
#[derive(salsa::Update)]
struct S2 {
#[update(unsafe(with(my_wrong_update)))]
bad: i32,
#[update(unsafe(with(my_wrong_update2)))]
bad2: i32,
#[update(unsafe(with(my_wrong_update3)))]
bad3: i32,
#[update(unsafe(with(true)))]
bad4: &'static str,
}
fn my_wrong_update() {}
fn my_wrong_update2(_: (), _: ()) -> bool {
true
}
fn my_wrong_update3(_: *mut i32, _: i32) -> () {}
fn main() {}

View file

@ -0,0 +1,43 @@
error[E0308]: mismatched types
--> tests/compile-fail/invalid_update_with.rs:3:26
|
3 | #[update(unsafe(with(my_wrong_update)))]
| ---- ^^^^^^^^^^^^^^^ incorrect number of function parameters
| |
| expected due to this
|
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
found fn item `fn() -> () {my_wrong_update}`
error[E0308]: mismatched types
--> tests/compile-fail/invalid_update_with.rs:5:26
|
5 | #[update(unsafe(with(my_wrong_update2)))]
| ---- ^^^^^^^^^^^^^^^^ expected fn pointer, found fn item
| |
| expected due to this
|
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
found fn item `fn((), ()) -> bool {my_wrong_update2}`
error[E0308]: mismatched types
--> tests/compile-fail/invalid_update_with.rs:7:26
|
7 | #[update(unsafe(with(my_wrong_update3)))]
| ---- ^^^^^^^^^^^^^^^^ expected fn pointer, found fn item
| |
| expected due to this
|
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
found fn item `fn(*mut i32, i32) -> () {my_wrong_update3}`
error[E0308]: mismatched types
--> tests/compile-fail/invalid_update_with.rs:9:26
|
9 | #[update(unsafe(with(true)))]
| ---- ^^^^ expected fn pointer, found `bool`
| |
| expected due to this
|
= note: expected fn pointer `unsafe fn(*mut &'static str, &'static str) -> bool`
found type `bool`

61
tests/derive_update.rs Normal file
View file

@ -0,0 +1,61 @@
//! Test that the `Update` derive works as expected
#[derive(salsa::Update)]
struct MyInput {
field: &'static str,
}
#[derive(salsa::Update)]
struct MyInput2 {
#[update(unsafe(with(custom_update)))]
field: &'static str,
#[update(unsafe(with(|dest, data| { *dest = data; true })))]
field2: &'static str,
}
unsafe fn custom_update(dest: *mut &'static str, _data: &'static str) -> bool {
unsafe { *dest = "ill-behaved for testing purposes" };
true
}
#[test]
fn derived() {
let mut m = MyInput { field: "foo" };
assert_eq!(m.field, "foo");
assert!(unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
assert_eq!(m.field, "bar");
assert!(!unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
assert_eq!(m.field, "bar");
}
#[test]
fn derived_with() {
let mut m = MyInput2 {
field: "foo",
field2: "foo",
};
assert_eq!(m.field, "foo");
assert_eq!(m.field2, "foo");
assert!(unsafe {
salsa::Update::maybe_update(
&mut m,
MyInput2 {
field: "bar",
field2: "bar",
},
)
});
assert_eq!(m.field, "ill-behaved for testing purposes");
assert_eq!(m.field2, "bar");
assert!(unsafe {
salsa::Update::maybe_update(
&mut m,
MyInput2 {
field: "ill-behaved for testing purposes",
field2: "foo",
},
)
});
assert_eq!(m.field, "ill-behaved for testing purposes");
assert_eq!(m.field2, "foo");
}