initial late field impl

This commit is contained in:
puuuuh 2025-07-06 06:45:14 +03:00
parent 7ab42086d1
commit d6eeae3e5a
No known key found for this signature in database
GPG key ID: 171E3E1356CEE151
11 changed files with 426 additions and 43 deletions

View file

@ -29,7 +29,43 @@ macro_rules! maybe_backdate {
$zalsa:ident,
) => {
if $maybe_update(std::ptr::addr_of_mut!($old_field_place), $new_field_place) {
$revision_place = $current_revision;
$revision_place.store($current_revision);
}
};
}
/// Conditionally update field value and backdate revisions
#[macro_export]
macro_rules! maybe_backdate_late {
(
($return_mode:ident, no_backdate, $maybe_default:ident),
$maybe_update:tt,
$old_field_place:expr,
$new_field_place:expr,
$revision_place:expr,
$current_revision:expr,
$zalsa:ident,
) => {
$zalsa::always_update(
&mut $revision_place,
$current_revision,
&mut $old_field_place,
$new_field_place,
);
};
(
($return_mode:ident, backdate, $maybe_default:ident),
$maybe_update:tt,
$old_field_place:expr,
$new_field_place:expr,
$revision_place:expr,
$current_revision:expr,
$zalsa:ident,
) => {
if $zalsa::LateField::maybe_update(&mut $old_field_place, $new_field_place, $maybe_update, $revision_place.load()) {
$revision_place.store($current_revision);
}
};
}

View file

@ -23,12 +23,18 @@ macro_rules! setup_tracked_struct {
// Tracked field names.
tracked_ids: [$($tracked_id:ident),*],
// Non late field names.
non_late_ids: [$($non_late_id:ident),*],
// Visibility and names of tracked fields.
tracked_getters: [$($tracked_getter_vis:vis $tracked_getter_id:ident),*],
// Visibility and names of untracked fields.
untracked_getters: [$($untracked_getter_vis:vis $untracked_getter_id:ident),*],
// Names of setters for late fields.
tracked_setters: [$($tracked_setter_id:ident),*],
// Field types, may reference `db_lt`.
field_tys: [$($field_ty:ty),*],
@ -38,6 +44,9 @@ macro_rules! setup_tracked_struct {
// Untracked field types.
untracked_tys: [$($untracked_ty:ty),*],
// Non late field types.
non_late_tys: [$($non_late_ty:ty),*],
// Indices for each field from 0..N -- must be unsuffixed (e.g., `0`, `1`).
field_indices: [$($field_index:tt),*],
@ -56,12 +65,17 @@ macro_rules! setup_tracked_struct {
// Untracked field types.
untracked_maybe_updates: [$($untracked_maybe_update:tt),*],
// If tracked field can be set after new.
field_is_late: [$($field_is_late:tt),*],
tracked_is_late: [$($tracked_is_late:tt),*],
// A set of "field options" for each tracked field.
//
// Each field option is a tuple `(return_mode, maybe_backdate)` where:
//
// * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
//
// These are used to drive conditional logic for each field via recursive macro invocation
// (see e.g. @return_mode below).
@ -131,9 +145,9 @@ macro_rules! setup_tracked_struct {
$($relative_tracked_index,)*
];
type Fields<$db_lt> = ($($field_ty,)*);
type Fields<$db_lt> = ($($zalsa::macro_if!(if $field_is_late { $zalsa::LateField<$field_ty> } else { $field_ty }),)*);
type Revisions = [$Revision; $N];
type Revisions = [$zalsa::AtomicRevision; $N];
type Struct<$db_lt> = $Struct<$db_lt>;
@ -142,7 +156,7 @@ macro_rules! setup_tracked_struct {
}
fn new_revisions(current_revision: $Revision) -> Self::Revisions {
[current_revision; $N]
std::array::from_fn(|_| $zalsa::AtomicRevision::from(current_revision))
}
unsafe fn update_fields<$db_lt>(
@ -154,15 +168,29 @@ macro_rules! setup_tracked_struct {
use $zalsa::UpdateFallback as _;
unsafe {
$(
$crate::maybe_backdate!(
$tracked_option,
$tracked_maybe_update,
(*old_fields).$absolute_tracked_index,
new_fields.$absolute_tracked_index,
revisions[$relative_tracked_index],
current_revision,
$zalsa,
);
$zalsa::macro_if! {
if $tracked_is_late {
$crate::maybe_backdate_late!(
$tracked_option,
$tracked_maybe_update,
(*old_fields).$absolute_tracked_index,
new_fields.$absolute_tracked_index,
revisions[$relative_tracked_index],
current_revision,
$zalsa,
);
} else {
$crate::maybe_backdate!(
$tracked_option,
$tracked_maybe_update,
(*old_fields).$absolute_tracked_index,
new_fields.$absolute_tracked_index,
revisions[$relative_tracked_index],
current_revision,
$zalsa,
);
}
}
)*;
// If any untracked field has changed, return `true`, indicating that the tracked struct
@ -254,31 +282,52 @@ macro_rules! setup_tracked_struct {
}
impl<$db_lt> $Struct<$db_lt> {
pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
pub fn $new_fn(db: &$db_lt dyn $zalsa::Database, $($non_late_id: $non_late_ty),*) -> Self
{
$Configuration::ingredient(db.as_dyn_database()).new_struct(
db.as_dyn_database(),
($($field_id,)*)
$Configuration::ingredient(db).new_struct(
db,
($($zalsa::macro_if!(if $field_is_late {$zalsa::LateField::new()} else {$field_id}),)*)
)
}
$(
$(#[$tracked_field_attr])*
$tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
$tracked_getter_vis fn $tracked_getter_id(self, db: &$db_lt dyn $zalsa::Database) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
{
let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index);
$crate::return_mode_expression!(
$tracked_option,
$tracked_ty,
&fields.$absolute_tracked_index,
)
$zalsa::macro_if! { if $tracked_is_late {
$crate::return_mode_expression!(
$tracked_option,
$tracked_ty,
&fields.$absolute_tracked_index.get().expect("can't get late field without initialization"),
)
} else {
$crate::return_mode_expression!(
$tracked_option,
$tracked_ty,
&fields.$absolute_tracked_index,
)
}
}
}
)*
$(
$zalsa::macro_if! { if $tracked_is_late {
pub fn $tracked_setter_id(
self,
db: &$db_lt dyn $zalsa::Database,
value: $tracked_ty,
) {
let ingredient = $Configuration::ingredient(db);
if let Some(old_rev) = ingredient.tracked_field(db, self, $relative_tracked_index)
.$absolute_tracked_index
.set_maybe_backdate(value) {
ingredient.revisions(db, self)[$relative_tracked_index].store(old_rev);
}
}
} else {}
}
)*

View file

@ -66,6 +66,7 @@ pub(crate) struct SalsaField<'s> {
pub(crate) returns: syn::Ident,
pub(crate) has_no_eq_attr: bool,
pub(crate) maybe_update_attr: Option<(syn::Path, syn::Expr)>,
pub(crate) has_late_attr: bool,
get_name: syn::Ident,
set_name: syn::Ident,
unknown_attrs: Vec<&'s syn::Attribute>,
@ -83,6 +84,11 @@ pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(
ef.has_tracked_attr = true;
Ok(())
}),
("late", |_, ef| {
ef.has_tracked_attr = true;
ef.has_late_attr = true;
Ok(())
}),
("default", |_, ef| {
ef.has_default_attr = true;
Ok(())
@ -139,6 +145,7 @@ where
this.maybe_disallow_maybe_update_fields()?;
this.maybe_disallow_tracked_fields()?;
this.maybe_disallow_default_fields()?;
// this.disallow_late_id_fields()?;
this.check_generics()?;
@ -244,6 +251,27 @@ where
Ok(())
}
/// Disallow `#[default]` attributes on the fields of this struct.
///
/// If an `#[default]` field is found, return an error.
///
/// # Parameters
///
/// * `kind`, the attribute name (e.g., `input` or `interned`)
fn disallow_late_id_fields(&self) -> syn::Result<()> {
// Check if any field has the `#[default]` attribute.
for ef in &self.fields {
if ef.has_late_attr && ef.has_tracked_attr {
return Err(syn::Error::new_spanned(
ef.field,
format!("`#[late]` cannot be used with `#[tracked]`"),
));
}
}
Ok(())
}
/// Check that the generic parameters look as expected for this kind of struct.
fn check_generics(&self) -> syn::Result<()> {
if A::HAS_LIFETIME {
@ -337,6 +365,12 @@ where
.collect()
}
pub(crate) fn tracked_setter_ids(&self) -> Vec<&syn::Ident> {
self.tracked_fields_iter()
.map(|(_, f)| &f.set_name)
.collect()
}
pub(crate) fn untracked_getter_ids(&self) -> Vec<&syn::Ident> {
self.untracked_fields_iter()
.map(|(_, f)| &f.get_name)
@ -410,6 +444,7 @@ where
.collect()
}
pub fn generate_debug_impl(&self) -> bool {
self.args.debug.is_some()
}
@ -431,6 +466,37 @@ where
.enumerate()
.filter(|(_, f)| !f.has_tracked_attr)
}
pub fn non_late_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
self.fields
.iter()
.enumerate()
.filter(|(_, f)| !f.has_late_attr)
}
pub(crate) fn field_is_late(&self) -> Vec<TokenStream> {
self.fields.iter()
.map(|f| f.is_late())
.collect()
}
pub(crate) fn tracked_is_late(&self) -> Vec<TokenStream> {
self.tracked_fields_iter()
.map(|(_, f)| f.is_late())
.collect()
}
pub(crate) fn non_late_ids(&self) -> Vec<&syn::Ident> {
self.non_late_iter()
.map(|(_, f)| f.field.ident.as_ref().unwrap())
.collect()
}
pub(crate) fn non_late_tys(&self) -> Vec<&syn::Type> {
self.non_late_iter()
.map(|(_, f)| &f.field.ty)
.collect()
}
}
impl<'s> SalsaField<'s> {
@ -453,6 +519,7 @@ impl<'s> SalsaField<'s> {
returns,
has_default_attr: false,
has_no_eq_attr: false,
has_late_attr: false,
maybe_update_attr: None,
get_name,
set_name,
@ -488,6 +555,14 @@ impl<'s> SalsaField<'s> {
Ok(result)
}
fn is_late(&self) -> TokenStream {
if self.has_late_attr {
quote! { true }
} else {
quote! { false }
}
}
fn options(&self) -> TokenStream {
let returns = &self.returns;

View file

@ -107,6 +107,8 @@ impl Macro {
let tracked_getter_ids = salsa_struct.tracked_getter_ids();
let untracked_getter_ids = salsa_struct.untracked_getter_ids();
let tracked_setter_ids = salsa_struct.tracked_setter_ids();
let field_indices = salsa_struct.field_indices();
let absolute_tracked_indices = salsa_struct.tracked_field_indices();
@ -126,12 +128,14 @@ impl Macro {
let tracked_maybe_update = salsa_struct.tracked_fields_iter().map(|(_, field)| {
let field_ty = &field.field.ty;
if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
let update = if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #maybe_update; maybe_update }) }
} else {
quote! {(#zalsa::UpdateDispatch::<#field_ty>::maybe_update)}
}
};
update
});
let untracked_maybe_update = salsa_struct.untracked_fields_iter().map(|(_, field)| {
let field_ty = &field.field.ty;
if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
@ -144,6 +148,11 @@ impl Macro {
let num_tracked_fields = salsa_struct.num_tracked_fields();
let generate_debug_impl = salsa_struct.generate_debug_impl();
let field_is_late = salsa_struct.field_is_late();
let tracked_is_late = salsa_struct.tracked_is_late();
let non_late_ids = salsa_struct.non_late_ids();
let non_late_tys = salsa_struct.non_late_tys();
let zalsa_struct = self.hygiene.ident("zalsa_struct");
let Configuration = self.hygiene.ident("Configuration");
let CACHE = self.hygiene.ident("CACHE");
@ -162,13 +171,17 @@ impl Macro {
field_ids: [#(#field_ids),*],
tracked_ids: [#(#tracked_ids),*],
non_late_ids: [#(#non_late_ids),*],
tracked_getters: [#(#tracked_vis #tracked_getter_ids),*],
untracked_getters: [#(#untracked_vis #untracked_getter_ids),*],
tracked_setters: [#(#tracked_setter_ids),*],
field_tys: [#(#field_tys),*],
tracked_tys: [#(#tracked_tys),*],
untracked_tys: [#(#untracked_tys),*],
non_late_tys: [#(#non_late_tys),*],
field_indices: [#(#field_indices),*],
@ -180,6 +193,9 @@ impl Macro {
tracked_maybe_updates: [#(#tracked_maybe_update),*],
untracked_maybe_updates: [#(#untracked_maybe_update),*],
field_is_late: [#(#field_is_late),*],
tracked_is_late: [#(#tracked_is_late),*],
tracked_options: [#(#tracked_options),*],
untracked_options: [#(#untracked_options),*],

View file

@ -95,7 +95,7 @@ pub mod plumbing {
IngredientIndices, MemoIngredientIndices, MemoIngredientMap, MemoIngredientSingletonIndex,
NewMemoIngredientIndices,
};
pub use crate::revision::Revision;
pub use crate::revision::{Revision, AtomicRevision};
pub use crate::runtime::{stamp, Runtime, Stamp};
pub use crate::salsa_struct::SalsaStructInDb;
pub use crate::storage::{HasStorage, Storage};
@ -106,6 +106,7 @@ pub mod plumbing {
transmute_data_ptr, views, IngredientCache, IngredientIndex, Zalsa, ZalsaDatabase,
};
pub use crate::zalsa_local::ZalsaLocal;
pub use crate::tracked_struct::late_field::LateField;
pub mod accumulator {
pub use crate::accumulator::{IngredientImpl, JarImpl};

View file

@ -61,7 +61,7 @@ impl std::fmt::Debug for Revision {
}
#[derive(Debug)]
pub(crate) struct AtomicRevision {
pub struct AtomicRevision {
data: AtomicUsize,
}
@ -74,13 +74,13 @@ impl From<Revision> for AtomicRevision {
}
impl AtomicRevision {
pub(crate) const fn start() -> Self {
pub const fn start() -> Self {
Self {
data: AtomicUsize::new(START),
}
}
pub(crate) fn load(&self) -> Revision {
pub fn load(&self) -> Revision {
Revision {
// SAFETY: We know that the value is non-zero because we only ever store `START` which 1, or a
// Revision which is guaranteed to be non-zero.
@ -88,7 +88,7 @@ impl AtomicRevision {
}
}
pub(crate) fn store(&self, r: Revision) {
pub fn store(&self, r: Revision) {
self.data.store(r.as_usize(), Ordering::Release);
}
}

View file

@ -16,7 +16,7 @@ use crate::id::{AsId, FromId};
use crate::ingredient::{Ingredient, Jar};
use crate::key::DatabaseKeyIndex;
use crate::plumbing::ZalsaLocal;
use crate::revision::OptionalAtomicRevision;
use crate::revision::{AtomicRevision, OptionalAtomicRevision};
use crate::runtime::Stamp;
use crate::salsa_struct::SalsaStructInDb;
use crate::sync::Arc;
@ -26,7 +26,7 @@ use crate::zalsa::{IngredientIndex, Zalsa};
use crate::{Database, Durability, Event, EventKind, Id, Revision};
pub mod tracked_field;
pub mod late_field;
// ANCHOR: Configuration
/// Trait that defines the key properties of a tracked struct.
///
@ -51,7 +51,7 @@ pub trait Configuration: Sized + 'static {
/// When a struct is re-recreated in a new revision, the corresponding
/// entries for each field are updated to the new revision if their
/// values have changed (or if the field is marked as `#[no_eq]`).
type Revisions: Send + Sync + Index<usize, Output = Revision>;
type Revisions: Send + Sync + Index<usize, Output = AtomicRevision>;
type Struct<'db>: Copy + FromId + AsId;
@ -755,7 +755,7 @@ where
data.read_lock(zalsa.current_revision());
let field_changed_at = data.revisions[relative_tracked_index];
let field_changed_at = data.revisions[relative_tracked_index].load();
zalsa_local.report_tracked_read_simple(
DatabaseKeyIndex::new(field_ingredient_index, id),
@ -766,6 +766,20 @@ where
data.fields()
}
pub fn revisions<'db>(
&'db self,
db: &'db dyn crate::Database,
s: C::Struct<'db>,
) -> &'db C::Revisions {
let (zalsa, _) = db.zalsas();
let id = AsId::as_id(&s);
let data = Self::data(zalsa.table(), id);
data.read_lock(zalsa.current_revision());
&data.revisions
}
/// Access to this untracked field.
///
/// Note that this function returns the entire tuple of value fields.

View file

@ -0,0 +1,117 @@
use std::{cell::UnsafeCell, mem::MaybeUninit, sync::atomic::{AtomicUsize, Ordering}};
use crate::{Revision, Update};
const EMPTY: usize = 0;
const ACQUIRED: usize = 1;
const SET: usize = 2;
const DIRTY: usize = 3;
#[derive(Debug)]
pub struct LateField<T: Update> {
state: AtomicUsize,
// Last valid revision of DIRTY state
old_revision: Option<Revision>,
data: UnsafeCell<MaybeUninit<T>>
}
unsafe impl<T: Update + Send> Send for LateField<T> {}
unsafe impl<T: Update + Sync> Sync for LateField<T> {}
impl<T: Update> LateField<T> {
pub fn new() -> LateField<T> {
LateField {
state: AtomicUsize::new(EMPTY),
old_revision: None,
data: UnsafeCell::new(MaybeUninit::uninit())
}
}
// Update self, store old revision to probably backdate later
pub fn maybe_update(&mut self, mut value: Self, maybe_update_inner: unsafe fn(*mut T, T) -> bool, old_revision: Revision) -> bool {
let old_state = self.state.load(Ordering::Relaxed);
let new_state = value.state.load(Ordering::Relaxed);
let t = match (old_state, new_state) {
(EMPTY, EMPTY) => {
self.old_revision = None;
self.state.store(EMPTY, Ordering::Release);
false
},
(EMPTY, SET) => {
self.old_revision = None;
self.data = value.data;
self.state.store(SET, Ordering::Release);
true
},
(DIRTY, SET) => {
// SAFETY: DIRTY and SET state assumes that data is initialized
let changed = unsafe {
maybe_update_inner(self.data.get_mut().assume_init_mut(), value.data.get_mut().assume_init_read())
};
self.state.store(SET, Ordering::Release);
changed
},
(SET, EMPTY) => {
self.old_revision = Some(old_revision);
// Save old value to probably backdate later
self.state.store(DIRTY, Ordering::Release);
true
}
_ => panic!("unexpected state"),
};
t
}
/// Set new value and returns saved revision if its not updated
pub fn set_maybe_backdate(&self, value: T) -> Option<Revision> {
let old_state = self.state.load(Ordering::Relaxed);
match old_state {
EMPTY => {},
DIRTY => {},
SET => {
panic!("set on late field called twice")
},
ACQUIRED => {
panic!("concurrent set on late field is not allowed")
}
_ => panic!("unexpected state"),
}
self.state.compare_exchange(old_state, ACQUIRED, Ordering::Acquire, Ordering::Relaxed).expect("concurrent set on late field is not allowed");
let updated = if old_state == EMPTY {
unsafe {
(&mut *self.data.get()).write(value);
}
true
} else {
unsafe {
Update::maybe_update((&mut *self.data.get()).assume_init_mut(), value)
}
};
self.state.store(SET, Ordering::Release);
if updated {
None
} else {
self.old_revision
}
}
pub fn get(&self) -> Option<&T> {
if self.state.load(Ordering::Acquire) != SET {
return None;
};
// SAFETY: we can't move from SET to any other state while we have ref to self
Some(unsafe {
(&*self.data.get()).assume_init_ref()
})
}
}

View file

@ -64,7 +64,7 @@ where
) -> VerifyResult {
let zalsa = db.zalsa();
let data = <super::IngredientImpl<C>>::data(zalsa.table(), input);
let field_changed_at = data.revisions[self.field_index];
let field_changed_at = data.revisions[self.field_index].load();
VerifyResult::changed_if(field_changed_at > revision)
}

View file

@ -8,6 +8,7 @@ use std::path::PathBuf;
#[cfg(feature = "rayon")]
use rayon::iter::Either;
use crate::revision::AtomicRevision;
use crate::sync::Arc;
use crate::Revision;
@ -104,12 +105,12 @@ where
/// and updates `*old_revision` with `new_revision.` Used for fields
/// tagged with `#[no_eq]`
pub fn always_update<T>(
old_revision: &mut Revision,
old_revision: &mut AtomicRevision,
new_revision: Revision,
old_pointer: &mut T,
new_value: T,
) {
*old_revision = new_revision;
old_revision.store(new_revision);
*old_pointer = new_value;
}

View file

@ -0,0 +1,74 @@
mod common;
use salsa::{Database, Setter};
// A tracked struct with mixed tracked and untracked fields to ensure
// the correct field indices are used when tracking dependencies.
#[salsa::tracked(debug)]
struct TrackedWithLateField<'db> {
untracked_1: usize,
#[late]
tracked_1: usize,
#[late]
tracked_2: usize,
untracked_2: usize,
untracked_3: usize,
untracked_4: usize,
}
#[salsa::input]
struct MyInput {
field1: usize,
field2: usize,
}
#[salsa::tracked]
fn intermediate(db: &dyn salsa::Database, input: MyInput) -> TrackedWithLateField<'_> {
input.field1(db);
input.field2(db);
let t = TrackedWithLateField::new(db, 0, 1, 2, 3);
t.set_tracked_1(db, input.field1(db));
t.set_tracked_2(db, input.field2(db));
t
}
#[salsa::tracked]
fn accumulate(db: &dyn salsa::Database, input: MyInput) -> (usize, usize) {
let tracked = intermediate(db, input);
let one = read_tracked_1(db, tracked);
let two = read_tracked_2(db, tracked);
(one, two)
}
#[salsa::tracked]
fn read_tracked_1<'db>(db: &'db dyn Database, tracked: TrackedWithLateField<'db>) -> usize {
tracked.tracked_1(db)
}
#[salsa::tracked]
fn read_tracked_2<'db>(db: &'db dyn Database, tracked: TrackedWithLateField<'db>) -> usize {
tracked.tracked_2(db)
}
#[test_log::test]
fn execute() {
let mut db = salsa::DatabaseImpl::default();
let input = MyInput::new(&db, 1, 1);
assert_eq!(accumulate(&db, input), (1, 1));
// Should only re-execute `read_tracked_1`.
input.set_field1(&mut db).to(2);
input.set_field2(&mut db).to(1);
assert_eq!(accumulate(&db, input), (2, 1));
// Should only re-execute `read_tracked_2`.
input.set_field2(&mut db).to(2);
assert_eq!(accumulate(&db, input), (2, 2));
}