Use Fallback trick for tracked function Update constraint, implement Update for smallvec and compact_str

This commit is contained in:
Micha Reiser 2025-02-08 20:17:00 +01:00
parent 538eaadbe0
commit acbee94fbe
No known key found for this signature in database
10 changed files with 166 additions and 47 deletions

View file

@ -10,6 +10,7 @@ rust-version = "1.76"
[dependencies] [dependencies]
arc-swap = "1" arc-swap = "1"
compact_str = { version = "0.8", optional = true }
crossbeam = "0.8" crossbeam = "0.8"
dashmap = { version = "6", features = ["raw-api"] } dashmap = { version = "6", features = ["raw-api"] }
hashlink = "0.9" hashlink = "0.9"

View file

@ -145,6 +145,20 @@ macro_rules! setup_tracked_fn {
} }
} }
/// This method isn't used anywhere. It only exitst to enforce the `Self::Output: Update` constraint
/// for types that aren't `'static`.
///
/// # Safety
/// The same safety rules as for `Update` apply.
unsafe fn _implements_update<'db>(old_pointer: *mut $output_ty, new_value: $output_ty) -> bool {
unsafe {
use $zalsa::UpdateFallback;
$zalsa::UpdateDispatch::<$output_ty>::maybe_update(
old_pointer, new_value
)
}
}
impl $zalsa::function::Configuration for $Configuration { impl $zalsa::function::Configuration for $Configuration {
const DEBUG_NAME: &'static str = stringify!($fn_name); const DEBUG_NAME: &'static str = stringify!($fn_name);

View file

@ -1,4 +1,5 @@
use proc_macro2::{Literal, TokenStream}; use proc_macro2::{Literal, TokenStream};
use syn::spanned::Spanned;
use synstructure::BindStyle; use synstructure::BindStyle;
use crate::hygiene::Hygiene; use crate::hygiene::Hygiene;
@ -34,7 +35,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
.bindings() .bindings()
.iter() .iter()
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,)); .fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
let make_new_value = quote! { let make_new_value = quote_spanned! {variant.ast().ident.span()=>
let #new_value = if let #variant_pat = #new_value { let #new_value = if let #variant_pat = #new_value {
(#make_tuple) (#make_tuple)
} else { } else {
@ -46,20 +47,28 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
// For each field, invoke `maybe_update` recursively to update its value. // For each field, invoke `maybe_update` recursively to update its value.
// Or the results together (using `|`, not `||`, to avoid shortcircuiting) // Or the results together (using `|`, not `||`, to avoid shortcircuiting)
// to get the final return value. // to get the final return value.
let update_fields = variant.bindings().iter().zip(0..).fold( let update_fields = variant.bindings().iter().enumerate().fold(
quote!(false), quote!(false),
|tokens, (binding, index)| { |tokens, (index, binding)| {
let field_ty = &binding.ast().ty; let field_ty = &binding.ast().ty;
let field_index = Literal::usize_unsuffixed(index); let field_index = Literal::usize_unsuffixed(index);
let field_span = binding
.ast()
.ident
.as_ref()
.map(Spanned::span)
.unwrap_or(binding.ast().span());
let update_field = quote_spanned! {field_span=>
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
#binding,
#new_value.#field_index,
)
};
quote! { quote! {
#tokens | #tokens | unsafe { #update_field }
unsafe {
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
#binding,
#new_value.#field_index,
)
}
} }
}, },
); );
@ -77,6 +86,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let tokens = quote! { let tokens = quote! {
#[allow(clippy::all)] #[allow(clippy::all)]
#[automatically_derived]
unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause { unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause {
unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool { unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool {
use ::salsa::plumbing::UpdateFallback as _; use ::salsa::plumbing::UpdateFallback as _;

View file

@ -65,7 +65,7 @@ pub enum ExpressionData<'db> {
Call(FunctionId<'db>, Vec<Expression<'db>>), Call(FunctionId<'db>, Vec<Expression<'db>>),
} }
#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug, salsa::Update)] #[derive(Eq, PartialEq, Copy, Clone, Hash, Debug)]
pub enum Op { pub enum Op {
Add, Add,
Subtract, Subtract,

View file

@ -9,7 +9,7 @@ use crate::{
salsa_struct::SalsaStructInDb, salsa_struct::SalsaStructInDb,
zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa},
zalsa_local::QueryOrigin, zalsa_local::QueryOrigin,
Cycle, Database, Id, Revision, Update, Cycle, Database, Id, Revision,
}; };
use self::delete::DeletedEntries; use self::delete::DeletedEntries;
@ -43,7 +43,7 @@ pub trait Configuration: Any {
type Input<'db>: Send + Sync; type Input<'db>: Send + Sync;
/// The value computed by the function. /// The value computed by the function.
type Output<'db>: fmt::Debug + Send + Sync + Update; type Output<'db>: fmt::Debug + Send + Sync;
/// Determines whether this function can recover from being a participant in a cycle /// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how). /// (and, if so, how).

View file

@ -1,6 +1,7 @@
use std::{ use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet}, collections::{BTreeMap, BTreeSet, HashMap, HashSet},
hash::{BuildHasher, Hash}, hash::{BuildHasher, Hash},
marker::PhantomData,
path::PathBuf, path::PathBuf,
sync::Arc, sync::Arc,
}; };
@ -188,6 +189,29 @@ where
} }
} }
unsafe impl<A> Update for smallvec::SmallVec<A>
where
A: smallvec::Array,
A::Item: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
let old_vec: &mut smallvec::SmallVec<A> = unsafe { &mut *old_pointer };
if old_vec.len() != new_vec.len() {
old_vec.clear();
old_vec.extend(new_vec);
return true;
}
let mut changed = false;
for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) {
changed |= A::Item::maybe_update(old_element, new_element);
}
changed
}
}
macro_rules! maybe_update_set { macro_rules! maybe_update_set {
($old_pointer: expr, $new_set: expr) => {{ ($old_pointer: expr, $new_set: expr) => {{
let old_pointer = $old_pointer; let old_pointer = $old_pointer;
@ -291,6 +315,26 @@ where
} }
} }
unsafe impl<T> Update for Box<[T]>
where
T: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_box: Self) -> bool {
let old_box: &mut Box<[T]> = unsafe { &mut *old_pointer };
if old_box.len() == new_box.len() {
let mut changed = false;
for (old_element, new_element) in old_box.iter_mut().zip(new_box) {
changed |= T::maybe_update(old_element, new_element);
}
changed
} else {
*old_box = new_box;
true
}
}
}
unsafe impl<T> Update for Arc<T> unsafe impl<T> Update for Arc<T>
where where
T: Update, T: Update,
@ -398,6 +442,9 @@ fallback_impl! {
PathBuf, PathBuf,
} }
#[cfg(feature = "compact_str")]
fallback_impl! { compact_str::CompactString, }
macro_rules! tuple_impl { macro_rules! tuple_impl {
($($t:ident),*; $($u:ident),*) => { ($($t:ident),*; $($u:ident),*) => {
unsafe impl<$($t),*> Update for ($($t,)*) unsafe impl<$($t),*> Update for ($($t,)*)
@ -451,3 +498,9 @@ where
} }
} }
} }
unsafe impl<T> Update for PhantomData<T> {
unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool {
false
}
}

View file

@ -6,37 +6,24 @@ warning: unused import: `salsa::Update`
| |
= note: `#[warn(unused_imports)]` on by default = note: `#[warn(unused_imports)]` on by default
error[E0277]: the trait bound `&'db str: Update` is not satisfied error: lifetime may not live long enough
--> tests/compile-fail/tracked_fn_return_ref.rs:16:67 --> tests/compile-fail/tracked_fn_return_ref.rs:15:1
| |
16 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str { 15 | #[salsa::tracked]
| ^^^^^^^^ the trait `Update` is not implemented for `&'db str` | ^^^^^^^^^^^^^^^^^
| |
| lifetime `'db` defined here
| requires that `'db` must outlive `'static`
| |
= help: the trait `Update` is implemented for `String` = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
--> src/function.rs
|
| type Output<'db>: fmt::Debug + Send + Sync + Update;
| ^^^^^^ required by this bound in `Configuration::Output`
error[E0277]: the trait bound `ContainsRef<'db>: Update` is not satisfied error: lifetime may not live long enough
--> tests/compile-fail/tracked_fn_return_ref.rs:24:6 --> tests/compile-fail/tracked_fn_return_ref.rs:20:1
| |
24 | ) -> ContainsRef<'db> { 20 | #[salsa::tracked]
| ^^^^^^^^^^^^^^^^ the trait `Update` is not implemented for `ContainsRef<'db>` | ^^^^^^^^^^^^^^^^^
| |
| lifetime `'db` defined here
| requires that `'db` must outlive `'static`
| |
= help: the following other types implement trait `Update`: = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)
()
(A, B)
(A, B, C)
(A, B, C, D)
(A, B, C, D, E)
(A, B, C, D, E, F)
(A, B, C, D, E, F, G)
(A, B, C, D, E, F, G, H)
and $N others
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
--> src/function.rs
|
| type Output<'db>: fmt::Debug + Send + Sync + Update;
| ^^^^^^ required by this bound in `Configuration::Output`

View file

@ -8,10 +8,10 @@ use std::sync::{
mod common; mod common;
use common::LogDatabase; use common::LogDatabase;
use salsa::{Database as _, Update}; use salsa::Database as _;
use test_log::test; use test_log::test;
#[derive(Debug, PartialEq, Eq, Update)] #[derive(Debug, PartialEq, Eq)]
struct HotPotato(u32); struct HotPotato(u32);
thread_local! { thread_local! {

56
tests/tracked_struct.rs Normal file
View file

@ -0,0 +1,56 @@
mod common;
use salsa::{Database, Setter};
#[salsa::tracked]
struct Tracked<'db> {
untracked_1: usize,
untracked_2: usize,
}
#[salsa::input]
struct MyInput {
field1: usize,
field2: usize,
}
#[salsa::tracked]
fn intermediate(db: &dyn salsa::Database, input: MyInput) -> Tracked<'_> {
Tracked::new(db, input.field1(db), input.field2(db))
}
#[salsa::tracked]
fn accumulate(db: &dyn salsa::Database, input: MyInput) -> (usize, usize) {
let tracked = intermediate(db, input);
let one = read_tracked_1(db, tracked);
let two = read_tracked_2(db, tracked);
(one, two)
}
#[salsa::tracked]
fn read_tracked_1<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
tracked.untracked_1(db)
}
#[salsa::tracked]
fn read_tracked_2<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
tracked.untracked_2(db)
}
#[test]
fn execute() {
let mut db = salsa::DatabaseImpl::default();
let input = MyInput::new(&db, 1, 1);
assert_eq!(accumulate(&db, input), (1, 1));
// Should only re-execute `read_tracked_1`.
input.set_field1(&mut db).to(2);
assert_eq!(accumulate(&db, input), (2, 1));
// Should only re-execute `read_tracked_2`.
input.set_field2(&mut db).to(2);
assert_eq!(accumulate(&db, input), (2, 2));
}

View file

@ -1,9 +1,7 @@
use salsa::Update;
#[salsa::db] #[salsa::db]
pub trait Db: salsa::Database {} pub trait Db: salsa::Database {}
#[derive(Debug, PartialEq, Eq, Hash, Update)] #[derive(Debug, PartialEq, Eq, Hash)]
pub struct Item {} pub struct Item {}
#[salsa::tracked] #[salsa::tracked]