mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-08-04 19:08:32 +00:00
Update
derive field overwrite support
This commit is contained in:
parent
04053c1ce3
commit
9cc1cc2e9c
9 changed files with 247 additions and 31 deletions
|
@ -80,12 +80,12 @@ pub fn tracked(args: TokenStream, input: TokenStream) -> TokenStream {
|
|||
tracked::tracked(args, input)
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Update)]
|
||||
#[proc_macro_derive(Update, attributes(update))]
|
||||
pub fn update(input: TokenStream) -> TokenStream {
|
||||
let item = parse_macro_input!(input as syn::DeriveInput);
|
||||
match update::update_derive(item) {
|
||||
Ok(tokens) => tokens.into(),
|
||||
Err(error) => token_stream_with_error(input, error),
|
||||
Err(error) => error.into_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use proc_macro2::{Literal, TokenStream};
|
||||
use syn::spanned::Spanned;
|
||||
use proc_macro2::{Literal, Span, TokenStream};
|
||||
use syn::{parenthesized, parse::ParseStream, spanned::Spanned, Token};
|
||||
use synstructure::BindStyle;
|
||||
|
||||
use crate::hygiene::Hygiene;
|
||||
|
@ -7,9 +7,9 @@ use crate::hygiene::Hygiene;
|
|||
pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream> {
|
||||
let hygiene = Hygiene::from2(&input);
|
||||
|
||||
if let syn::Data::Union(_) = &input.data {
|
||||
if let syn::Data::Union(u) = &input.data {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&input.ident,
|
||||
u.union_token,
|
||||
"`derive(Update)` does not support `union`",
|
||||
));
|
||||
}
|
||||
|
@ -27,6 +27,24 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
|
|||
.variants()
|
||||
.iter()
|
||||
.map(|variant| {
|
||||
let err = variant
|
||||
.ast()
|
||||
.attrs
|
||||
.iter()
|
||||
.filter(|attr| attr.path().is_ident("update"))
|
||||
.map(|attr| {
|
||||
syn::Error::new(
|
||||
attr.path().span(),
|
||||
"unexpected attribute `#[update]` on variant",
|
||||
)
|
||||
})
|
||||
.reduce(|mut acc, err| {
|
||||
acc.combine(err);
|
||||
acc
|
||||
});
|
||||
if let Some(err) = err {
|
||||
return Err(err);
|
||||
}
|
||||
let variant_pat = variant.pat();
|
||||
|
||||
// First check that the `new_value` has same variant.
|
||||
|
@ -35,7 +53,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
|
|||
.bindings()
|
||||
.iter()
|
||||
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
|
||||
let make_new_value = quote_spanned! {variant.ast().ident.span()=>
|
||||
let make_new_value = quote! {
|
||||
let #new_value = if let #variant_pat = #new_value {
|
||||
(#make_tuple)
|
||||
} else {
|
||||
|
@ -47,40 +65,73 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
|
|||
// For each field, invoke `maybe_update` recursively to update its value.
|
||||
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
|
||||
// to get the final return value.
|
||||
let update_fields = variant.bindings().iter().enumerate().fold(
|
||||
quote!(false),
|
||||
|tokens, (index, binding)| {
|
||||
let field_ty = &binding.ast().ty;
|
||||
let field_index = Literal::usize_unsuffixed(index);
|
||||
let mut update_fields = quote!(false);
|
||||
for (index, binding) in variant.bindings().iter().enumerate() {
|
||||
let mut attrs = binding
|
||||
.ast()
|
||||
.attrs
|
||||
.iter()
|
||||
.filter(|attr| attr.path().is_ident("update"));
|
||||
let attr = attrs.next();
|
||||
if let Some(attr) = attrs.next() {
|
||||
return Err(syn::Error::new(
|
||||
attr.path().span(),
|
||||
"multiple #[update(with)] attributes on field",
|
||||
));
|
||||
}
|
||||
|
||||
let field_span = binding
|
||||
.ast()
|
||||
.ident
|
||||
.as_ref()
|
||||
.map(Spanned::span)
|
||||
.unwrap_or(binding.ast().span());
|
||||
let field_ty = &binding.ast().ty;
|
||||
let field_index = Literal::usize_unsuffixed(index);
|
||||
|
||||
let update_field = quote_spanned! {field_span=>
|
||||
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
|
||||
#binding,
|
||||
#new_value.#field_index,
|
||||
)
|
||||
};
|
||||
let (maybe_update, unsafe_token) = match attr {
|
||||
Some(attr) => {
|
||||
mod kw {
|
||||
syn::custom_keyword!(with);
|
||||
}
|
||||
attr.parse_args_with(|parser: ParseStream| {
|
||||
let mut content;
|
||||
|
||||
quote! {
|
||||
#tokens | unsafe { #update_field }
|
||||
let unsafe_token = parser.parse::<Token![unsafe]>()?;
|
||||
parenthesized!(content in parser);
|
||||
let with_token = content.parse::<kw::with>()?;
|
||||
parenthesized!(content in content);
|
||||
let expr = content.parse::<syn::Expr>()?;
|
||||
Ok((
|
||||
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #expr; maybe_update }) },
|
||||
// quote_spanned! { with_token.span() => ((#expr) as unsafe fn(*mut #field_ty, #field_ty) -> bool) },
|
||||
unsafe_token,
|
||||
))
|
||||
})?
|
||||
}
|
||||
},
|
||||
);
|
||||
None => {
|
||||
(
|
||||
quote!(
|
||||
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update
|
||||
),
|
||||
Token),
|
||||
)
|
||||
}
|
||||
};
|
||||
let update_field = quote! {
|
||||
#maybe_update(
|
||||
#binding,
|
||||
#new_value.#field_index,
|
||||
)
|
||||
};
|
||||
|
||||
quote!(
|
||||
update_fields = quote! {
|
||||
#update_fields | #unsafe_token { #update_field }
|
||||
};
|
||||
}
|
||||
|
||||
Ok(quote!(
|
||||
#variant_pat => {
|
||||
#make_new_value
|
||||
#update_fields
|
||||
}
|
||||
)
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
.collect::<syn::Result<_>>()?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
|
||||
|
|
16
tests/compile-fail/derive_update_expansion_failure.rs
Normal file
16
tests/compile-fail/derive_update_expansion_failure.rs
Normal file
|
@ -0,0 +1,16 @@
|
|||
#[derive(salsa::Update)]
|
||||
union U {
|
||||
field: i32,
|
||||
}
|
||||
|
||||
#[derive(salsa::Update)]
|
||||
struct S {
|
||||
#[update(with(missing_unsafe))]
|
||||
bad: i32,
|
||||
}
|
||||
|
||||
fn missing_unsafe(_: *mut i32, _: i32) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn main() {}
|
11
tests/compile-fail/derive_update_expansion_failure.stderr
Normal file
11
tests/compile-fail/derive_update_expansion_failure.stderr
Normal file
|
@ -0,0 +1,11 @@
|
|||
error: `derive(Update)` does not support `union`
|
||||
--> tests/compile-fail/derive_update_expansion_failure.rs:2:1
|
||||
|
|
||||
2 | union U {
|
||||
| ^^^^^
|
||||
|
||||
error: expected `unsafe`
|
||||
--> tests/compile-fail/derive_update_expansion_failure.rs:8:14
|
||||
|
|
||||
8 | #[update(with(missing_unsafe))]
|
||||
| ^^^^
|
6
tests/compile-fail/invalid_update_field.rs
Normal file
6
tests/compile-fail/invalid_update_field.rs
Normal file
|
@ -0,0 +1,6 @@
|
|||
#[derive(salsa::Update)]
|
||||
struct S2<'a> {
|
||||
bad2: &'a str,
|
||||
}
|
||||
|
||||
fn main() {}
|
9
tests/compile-fail/invalid_update_field.stderr
Normal file
9
tests/compile-fail/invalid_update_field.stderr
Normal file
|
@ -0,0 +1,9 @@
|
|||
error: lifetime may not live long enough
|
||||
--> tests/compile-fail/invalid_update_field.rs:1:10
|
||||
|
|
||||
1 | #[derive(salsa::Update)]
|
||||
| ^^^^^^^^^^^^^ requires that `'a` must outlive `'static`
|
||||
2 | struct S2<'a> {
|
||||
| -- lifetime `'a` defined here
|
||||
|
|
||||
= note: this error originates in the derive macro `salsa::Update` (in Nightly builds, run with -Z macro-backtrace for more info)
|
19
tests/compile-fail/invalid_update_with.rs
Normal file
19
tests/compile-fail/invalid_update_with.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
#[derive(salsa::Update)]
|
||||
struct S2 {
|
||||
#[update(unsafe(with(my_wrong_update)))]
|
||||
bad: i32,
|
||||
#[update(unsafe(with(my_wrong_update2)))]
|
||||
bad2: i32,
|
||||
#[update(unsafe(with(my_wrong_update3)))]
|
||||
bad3: i32,
|
||||
#[update(unsafe(with(true)))]
|
||||
bad4: &'static str,
|
||||
}
|
||||
|
||||
fn my_wrong_update() {}
|
||||
fn my_wrong_update2(_: (), _: ()) -> bool {
|
||||
true
|
||||
}
|
||||
fn my_wrong_update3(_: *mut i32, _: i32) -> () {}
|
||||
|
||||
fn main() {}
|
43
tests/compile-fail/invalid_update_with.stderr
Normal file
43
tests/compile-fail/invalid_update_with.stderr
Normal file
|
@ -0,0 +1,43 @@
|
|||
error[E0308]: mismatched types
|
||||
--> tests/compile-fail/invalid_update_with.rs:3:26
|
||||
|
|
||||
3 | #[update(unsafe(with(my_wrong_update)))]
|
||||
| ---- ^^^^^^^^^^^^^^^ incorrect number of function parameters
|
||||
| |
|
||||
| expected due to this
|
||||
|
|
||||
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
|
||||
found fn item `fn() -> () {my_wrong_update}`
|
||||
|
||||
error[E0308]: mismatched types
|
||||
--> tests/compile-fail/invalid_update_with.rs:5:26
|
||||
|
|
||||
5 | #[update(unsafe(with(my_wrong_update2)))]
|
||||
| ---- ^^^^^^^^^^^^^^^^ expected fn pointer, found fn item
|
||||
| |
|
||||
| expected due to this
|
||||
|
|
||||
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
|
||||
found fn item `fn((), ()) -> bool {my_wrong_update2}`
|
||||
|
||||
error[E0308]: mismatched types
|
||||
--> tests/compile-fail/invalid_update_with.rs:7:26
|
||||
|
|
||||
7 | #[update(unsafe(with(my_wrong_update3)))]
|
||||
| ---- ^^^^^^^^^^^^^^^^ expected fn pointer, found fn item
|
||||
| |
|
||||
| expected due to this
|
||||
|
|
||||
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
|
||||
found fn item `fn(*mut i32, i32) -> () {my_wrong_update3}`
|
||||
|
||||
error[E0308]: mismatched types
|
||||
--> tests/compile-fail/invalid_update_with.rs:9:26
|
||||
|
|
||||
9 | #[update(unsafe(with(true)))]
|
||||
| ---- ^^^^ expected fn pointer, found `bool`
|
||||
| |
|
||||
| expected due to this
|
||||
|
|
||||
= note: expected fn pointer `unsafe fn(*mut &'static str, &'static str) -> bool`
|
||||
found type `bool`
|
61
tests/derive_update.rs
Normal file
61
tests/derive_update.rs
Normal file
|
@ -0,0 +1,61 @@
|
|||
//! Test that the `Update` derive works as expected
|
||||
|
||||
#[derive(salsa::Update)]
|
||||
struct MyInput {
|
||||
field: &'static str,
|
||||
}
|
||||
|
||||
#[derive(salsa::Update)]
|
||||
struct MyInput2 {
|
||||
#[update(unsafe(with(custom_update)))]
|
||||
field: &'static str,
|
||||
#[update(unsafe(with(|dest, data| { *dest = data; true })))]
|
||||
field2: &'static str,
|
||||
}
|
||||
|
||||
unsafe fn custom_update(dest: *mut &'static str, _data: &'static str) -> bool {
|
||||
unsafe { *dest = "ill-behaved for testing purposes" };
|
||||
true
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derived() {
|
||||
let mut m = MyInput { field: "foo" };
|
||||
assert_eq!(m.field, "foo");
|
||||
assert!(unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
|
||||
assert_eq!(m.field, "bar");
|
||||
assert!(!unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
|
||||
assert_eq!(m.field, "bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derived_with() {
|
||||
let mut m = MyInput2 {
|
||||
field: "foo",
|
||||
field2: "foo",
|
||||
};
|
||||
assert_eq!(m.field, "foo");
|
||||
assert_eq!(m.field2, "foo");
|
||||
assert!(unsafe {
|
||||
salsa::Update::maybe_update(
|
||||
&mut m,
|
||||
MyInput2 {
|
||||
field: "bar",
|
||||
field2: "bar",
|
||||
},
|
||||
)
|
||||
});
|
||||
assert_eq!(m.field, "ill-behaved for testing purposes");
|
||||
assert_eq!(m.field2, "bar");
|
||||
assert!(unsafe {
|
||||
salsa::Update::maybe_update(
|
||||
&mut m,
|
||||
MyInput2 {
|
||||
field: "ill-behaved for testing purposes",
|
||||
field2: "foo",
|
||||
},
|
||||
)
|
||||
});
|
||||
assert_eq!(m.field, "ill-behaved for testing purposes");
|
||||
assert_eq!(m.field2, "foo");
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue