mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-08-04 11:00:05 +00:00
initial late field impl
This commit is contained in:
parent
7ab42086d1
commit
d6eeae3e5a
11 changed files with 426 additions and 43 deletions
|
@ -29,7 +29,43 @@ macro_rules! maybe_backdate {
|
|||
$zalsa:ident,
|
||||
) => {
|
||||
if $maybe_update(std::ptr::addr_of_mut!($old_field_place), $new_field_place) {
|
||||
$revision_place = $current_revision;
|
||||
$revision_place.store($current_revision);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Conditionally update field value and backdate revisions
|
||||
#[macro_export]
|
||||
macro_rules! maybe_backdate_late {
|
||||
(
|
||||
($return_mode:ident, no_backdate, $maybe_default:ident),
|
||||
$maybe_update:tt,
|
||||
$old_field_place:expr,
|
||||
$new_field_place:expr,
|
||||
$revision_place:expr,
|
||||
$current_revision:expr,
|
||||
$zalsa:ident,
|
||||
|
||||
) => {
|
||||
$zalsa::always_update(
|
||||
&mut $revision_place,
|
||||
$current_revision,
|
||||
&mut $old_field_place,
|
||||
$new_field_place,
|
||||
);
|
||||
};
|
||||
|
||||
(
|
||||
($return_mode:ident, backdate, $maybe_default:ident),
|
||||
$maybe_update:tt,
|
||||
$old_field_place:expr,
|
||||
$new_field_place:expr,
|
||||
$revision_place:expr,
|
||||
$current_revision:expr,
|
||||
$zalsa:ident,
|
||||
) => {
|
||||
if $zalsa::LateField::maybe_update(&mut $old_field_place, $new_field_place, $maybe_update, $revision_place.load()) {
|
||||
$revision_place.store($current_revision);
|
||||
}
|
||||
};
|
||||
}
|
|
@ -23,12 +23,18 @@ macro_rules! setup_tracked_struct {
|
|||
// Tracked field names.
|
||||
tracked_ids: [$($tracked_id:ident),*],
|
||||
|
||||
// Non late field names.
|
||||
non_late_ids: [$($non_late_id:ident),*],
|
||||
|
||||
// Visibility and names of tracked fields.
|
||||
tracked_getters: [$($tracked_getter_vis:vis $tracked_getter_id:ident),*],
|
||||
|
||||
// Visibility and names of untracked fields.
|
||||
untracked_getters: [$($untracked_getter_vis:vis $untracked_getter_id:ident),*],
|
||||
|
||||
// Names of setters for late fields.
|
||||
tracked_setters: [$($tracked_setter_id:ident),*],
|
||||
|
||||
// Field types, may reference `db_lt`.
|
||||
field_tys: [$($field_ty:ty),*],
|
||||
|
||||
|
@ -38,6 +44,9 @@ macro_rules! setup_tracked_struct {
|
|||
// Untracked field types.
|
||||
untracked_tys: [$($untracked_ty:ty),*],
|
||||
|
||||
// Non late field types.
|
||||
non_late_tys: [$($non_late_ty:ty),*],
|
||||
|
||||
// Indices for each field from 0..N -- must be unsuffixed (e.g., `0`, `1`).
|
||||
field_indices: [$($field_index:tt),*],
|
||||
|
||||
|
@ -56,12 +65,17 @@ macro_rules! setup_tracked_struct {
|
|||
// Untracked field types.
|
||||
untracked_maybe_updates: [$($untracked_maybe_update:tt),*],
|
||||
|
||||
// If tracked field can be set after new.
|
||||
field_is_late: [$($field_is_late:tt),*],
|
||||
tracked_is_late: [$($tracked_is_late:tt),*],
|
||||
|
||||
// A set of "field options" for each tracked field.
|
||||
//
|
||||
// Each field option is a tuple `(return_mode, maybe_backdate)` where:
|
||||
//
|
||||
// * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
|
||||
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
|
||||
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
|
||||
//
|
||||
// These are used to drive conditional logic for each field via recursive macro invocation
|
||||
// (see e.g. @return_mode below).
|
||||
|
@ -131,9 +145,9 @@ macro_rules! setup_tracked_struct {
|
|||
$($relative_tracked_index,)*
|
||||
];
|
||||
|
||||
type Fields<$db_lt> = ($($field_ty,)*);
|
||||
type Fields<$db_lt> = ($($zalsa::macro_if!(if $field_is_late { $zalsa::LateField<$field_ty> } else { $field_ty }),)*);
|
||||
|
||||
type Revisions = [$Revision; $N];
|
||||
type Revisions = [$zalsa::AtomicRevision; $N];
|
||||
|
||||
type Struct<$db_lt> = $Struct<$db_lt>;
|
||||
|
||||
|
@ -142,7 +156,7 @@ macro_rules! setup_tracked_struct {
|
|||
}
|
||||
|
||||
fn new_revisions(current_revision: $Revision) -> Self::Revisions {
|
||||
[current_revision; $N]
|
||||
std::array::from_fn(|_| $zalsa::AtomicRevision::from(current_revision))
|
||||
}
|
||||
|
||||
unsafe fn update_fields<$db_lt>(
|
||||
|
@ -154,15 +168,29 @@ macro_rules! setup_tracked_struct {
|
|||
use $zalsa::UpdateFallback as _;
|
||||
unsafe {
|
||||
$(
|
||||
$crate::maybe_backdate!(
|
||||
$tracked_option,
|
||||
$tracked_maybe_update,
|
||||
(*old_fields).$absolute_tracked_index,
|
||||
new_fields.$absolute_tracked_index,
|
||||
revisions[$relative_tracked_index],
|
||||
current_revision,
|
||||
$zalsa,
|
||||
);
|
||||
$zalsa::macro_if! {
|
||||
if $tracked_is_late {
|
||||
$crate::maybe_backdate_late!(
|
||||
$tracked_option,
|
||||
$tracked_maybe_update,
|
||||
(*old_fields).$absolute_tracked_index,
|
||||
new_fields.$absolute_tracked_index,
|
||||
revisions[$relative_tracked_index],
|
||||
current_revision,
|
||||
$zalsa,
|
||||
);
|
||||
} else {
|
||||
$crate::maybe_backdate!(
|
||||
$tracked_option,
|
||||
$tracked_maybe_update,
|
||||
(*old_fields).$absolute_tracked_index,
|
||||
new_fields.$absolute_tracked_index,
|
||||
revisions[$relative_tracked_index],
|
||||
current_revision,
|
||||
$zalsa,
|
||||
);
|
||||
}
|
||||
}
|
||||
)*;
|
||||
|
||||
// If any untracked field has changed, return `true`, indicating that the tracked struct
|
||||
|
@ -254,31 +282,52 @@ macro_rules! setup_tracked_struct {
|
|||
}
|
||||
|
||||
impl<$db_lt> $Struct<$db_lt> {
|
||||
pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self
|
||||
where
|
||||
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
|
||||
$Db: ?Sized + $zalsa::Database,
|
||||
pub fn $new_fn(db: &$db_lt dyn $zalsa::Database, $($non_late_id: $non_late_ty),*) -> Self
|
||||
{
|
||||
$Configuration::ingredient(db.as_dyn_database()).new_struct(
|
||||
db.as_dyn_database(),
|
||||
($($field_id,)*)
|
||||
$Configuration::ingredient(db).new_struct(
|
||||
db,
|
||||
($($zalsa::macro_if!(if $field_is_late {$zalsa::LateField::new()} else {$field_id}),)*)
|
||||
)
|
||||
}
|
||||
|
||||
$(
|
||||
$(#[$tracked_field_attr])*
|
||||
$tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
|
||||
where
|
||||
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
|
||||
$Db: ?Sized + $zalsa::Database,
|
||||
$tracked_getter_vis fn $tracked_getter_id(self, db: &$db_lt dyn $zalsa::Database) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
|
||||
{
|
||||
let db = db.as_dyn_database();
|
||||
let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index);
|
||||
$crate::return_mode_expression!(
|
||||
$tracked_option,
|
||||
$tracked_ty,
|
||||
&fields.$absolute_tracked_index,
|
||||
)
|
||||
$zalsa::macro_if! { if $tracked_is_late {
|
||||
$crate::return_mode_expression!(
|
||||
$tracked_option,
|
||||
$tracked_ty,
|
||||
&fields.$absolute_tracked_index.get().expect("can't get late field without initialization"),
|
||||
)
|
||||
} else {
|
||||
$crate::return_mode_expression!(
|
||||
$tracked_option,
|
||||
$tracked_ty,
|
||||
&fields.$absolute_tracked_index,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
)*
|
||||
|
||||
$(
|
||||
$zalsa::macro_if! { if $tracked_is_late {
|
||||
pub fn $tracked_setter_id(
|
||||
self,
|
||||
db: &$db_lt dyn $zalsa::Database,
|
||||
value: $tracked_ty,
|
||||
) {
|
||||
let ingredient = $Configuration::ingredient(db);
|
||||
if let Some(old_rev) = ingredient.tracked_field(db, self, $relative_tracked_index)
|
||||
.$absolute_tracked_index
|
||||
.set_maybe_backdate(value) {
|
||||
ingredient.revisions(db, self)[$relative_tracked_index].store(old_rev);
|
||||
}
|
||||
}
|
||||
} else {}
|
||||
}
|
||||
)*
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ pub(crate) struct SalsaField<'s> {
|
|||
pub(crate) returns: syn::Ident,
|
||||
pub(crate) has_no_eq_attr: bool,
|
||||
pub(crate) maybe_update_attr: Option<(syn::Path, syn::Expr)>,
|
||||
pub(crate) has_late_attr: bool,
|
||||
get_name: syn::Ident,
|
||||
set_name: syn::Ident,
|
||||
unknown_attrs: Vec<&'s syn::Attribute>,
|
||||
|
@ -83,6 +84,11 @@ pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(
|
|||
ef.has_tracked_attr = true;
|
||||
Ok(())
|
||||
}),
|
||||
("late", |_, ef| {
|
||||
ef.has_tracked_attr = true;
|
||||
ef.has_late_attr = true;
|
||||
Ok(())
|
||||
}),
|
||||
("default", |_, ef| {
|
||||
ef.has_default_attr = true;
|
||||
Ok(())
|
||||
|
@ -139,6 +145,7 @@ where
|
|||
this.maybe_disallow_maybe_update_fields()?;
|
||||
this.maybe_disallow_tracked_fields()?;
|
||||
this.maybe_disallow_default_fields()?;
|
||||
// this.disallow_late_id_fields()?;
|
||||
|
||||
this.check_generics()?;
|
||||
|
||||
|
@ -244,6 +251,27 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Disallow `#[default]` attributes on the fields of this struct.
|
||||
///
|
||||
/// If an `#[default]` field is found, return an error.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `kind`, the attribute name (e.g., `input` or `interned`)
|
||||
fn disallow_late_id_fields(&self) -> syn::Result<()> {
|
||||
// Check if any field has the `#[default]` attribute.
|
||||
for ef in &self.fields {
|
||||
if ef.has_late_attr && ef.has_tracked_attr {
|
||||
return Err(syn::Error::new_spanned(
|
||||
ef.field,
|
||||
format!("`#[late]` cannot be used with `#[tracked]`"),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check that the generic parameters look as expected for this kind of struct.
|
||||
fn check_generics(&self) -> syn::Result<()> {
|
||||
if A::HAS_LIFETIME {
|
||||
|
@ -337,6 +365,12 @@ where
|
|||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn tracked_setter_ids(&self) -> Vec<&syn::Ident> {
|
||||
self.tracked_fields_iter()
|
||||
.map(|(_, f)| &f.set_name)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn untracked_getter_ids(&self) -> Vec<&syn::Ident> {
|
||||
self.untracked_fields_iter()
|
||||
.map(|(_, f)| &f.get_name)
|
||||
|
@ -410,6 +444,7 @@ where
|
|||
.collect()
|
||||
}
|
||||
|
||||
|
||||
pub fn generate_debug_impl(&self) -> bool {
|
||||
self.args.debug.is_some()
|
||||
}
|
||||
|
@ -431,6 +466,37 @@ where
|
|||
.enumerate()
|
||||
.filter(|(_, f)| !f.has_tracked_attr)
|
||||
}
|
||||
|
||||
pub fn non_late_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
|
||||
self.fields
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, f)| !f.has_late_attr)
|
||||
}
|
||||
|
||||
pub(crate) fn field_is_late(&self) -> Vec<TokenStream> {
|
||||
self.fields.iter()
|
||||
.map(|f| f.is_late())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn tracked_is_late(&self) -> Vec<TokenStream> {
|
||||
self.tracked_fields_iter()
|
||||
.map(|(_, f)| f.is_late())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn non_late_ids(&self) -> Vec<&syn::Ident> {
|
||||
self.non_late_iter()
|
||||
.map(|(_, f)| f.field.ident.as_ref().unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn non_late_tys(&self) -> Vec<&syn::Type> {
|
||||
self.non_late_iter()
|
||||
.map(|(_, f)| &f.field.ty)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'s> SalsaField<'s> {
|
||||
|
@ -453,6 +519,7 @@ impl<'s> SalsaField<'s> {
|
|||
returns,
|
||||
has_default_attr: false,
|
||||
has_no_eq_attr: false,
|
||||
has_late_attr: false,
|
||||
maybe_update_attr: None,
|
||||
get_name,
|
||||
set_name,
|
||||
|
@ -488,6 +555,14 @@ impl<'s> SalsaField<'s> {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
fn is_late(&self) -> TokenStream {
|
||||
if self.has_late_attr {
|
||||
quote! { true }
|
||||
} else {
|
||||
quote! { false }
|
||||
}
|
||||
}
|
||||
|
||||
fn options(&self) -> TokenStream {
|
||||
let returns = &self.returns;
|
||||
|
||||
|
|
|
@ -107,6 +107,8 @@ impl Macro {
|
|||
let tracked_getter_ids = salsa_struct.tracked_getter_ids();
|
||||
let untracked_getter_ids = salsa_struct.untracked_getter_ids();
|
||||
|
||||
let tracked_setter_ids = salsa_struct.tracked_setter_ids();
|
||||
|
||||
let field_indices = salsa_struct.field_indices();
|
||||
|
||||
let absolute_tracked_indices = salsa_struct.tracked_field_indices();
|
||||
|
@ -126,12 +128,14 @@ impl Macro {
|
|||
|
||||
let tracked_maybe_update = salsa_struct.tracked_fields_iter().map(|(_, field)| {
|
||||
let field_ty = &field.field.ty;
|
||||
if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
|
||||
let update = if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
|
||||
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #maybe_update; maybe_update }) }
|
||||
} else {
|
||||
quote! {(#zalsa::UpdateDispatch::<#field_ty>::maybe_update)}
|
||||
}
|
||||
};
|
||||
update
|
||||
});
|
||||
|
||||
let untracked_maybe_update = salsa_struct.untracked_fields_iter().map(|(_, field)| {
|
||||
let field_ty = &field.field.ty;
|
||||
if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
|
||||
|
@ -144,6 +148,11 @@ impl Macro {
|
|||
let num_tracked_fields = salsa_struct.num_tracked_fields();
|
||||
let generate_debug_impl = salsa_struct.generate_debug_impl();
|
||||
|
||||
let field_is_late = salsa_struct.field_is_late();
|
||||
let tracked_is_late = salsa_struct.tracked_is_late();
|
||||
let non_late_ids = salsa_struct.non_late_ids();
|
||||
let non_late_tys = salsa_struct.non_late_tys();
|
||||
|
||||
let zalsa_struct = self.hygiene.ident("zalsa_struct");
|
||||
let Configuration = self.hygiene.ident("Configuration");
|
||||
let CACHE = self.hygiene.ident("CACHE");
|
||||
|
@ -162,13 +171,17 @@ impl Macro {
|
|||
|
||||
field_ids: [#(#field_ids),*],
|
||||
tracked_ids: [#(#tracked_ids),*],
|
||||
non_late_ids: [#(#non_late_ids),*],
|
||||
|
||||
tracked_getters: [#(#tracked_vis #tracked_getter_ids),*],
|
||||
untracked_getters: [#(#untracked_vis #untracked_getter_ids),*],
|
||||
|
||||
tracked_setters: [#(#tracked_setter_ids),*],
|
||||
|
||||
field_tys: [#(#field_tys),*],
|
||||
tracked_tys: [#(#tracked_tys),*],
|
||||
untracked_tys: [#(#untracked_tys),*],
|
||||
non_late_tys: [#(#non_late_tys),*],
|
||||
|
||||
field_indices: [#(#field_indices),*],
|
||||
|
||||
|
@ -180,6 +193,9 @@ impl Macro {
|
|||
tracked_maybe_updates: [#(#tracked_maybe_update),*],
|
||||
untracked_maybe_updates: [#(#untracked_maybe_update),*],
|
||||
|
||||
field_is_late: [#(#field_is_late),*],
|
||||
tracked_is_late: [#(#tracked_is_late),*],
|
||||
|
||||
tracked_options: [#(#tracked_options),*],
|
||||
untracked_options: [#(#untracked_options),*],
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ pub mod plumbing {
|
|||
IngredientIndices, MemoIngredientIndices, MemoIngredientMap, MemoIngredientSingletonIndex,
|
||||
NewMemoIngredientIndices,
|
||||
};
|
||||
pub use crate::revision::Revision;
|
||||
pub use crate::revision::{Revision, AtomicRevision};
|
||||
pub use crate::runtime::{stamp, Runtime, Stamp};
|
||||
pub use crate::salsa_struct::SalsaStructInDb;
|
||||
pub use crate::storage::{HasStorage, Storage};
|
||||
|
@ -106,6 +106,7 @@ pub mod plumbing {
|
|||
transmute_data_ptr, views, IngredientCache, IngredientIndex, Zalsa, ZalsaDatabase,
|
||||
};
|
||||
pub use crate::zalsa_local::ZalsaLocal;
|
||||
pub use crate::tracked_struct::late_field::LateField;
|
||||
|
||||
pub mod accumulator {
|
||||
pub use crate::accumulator::{IngredientImpl, JarImpl};
|
||||
|
|
|
@ -61,7 +61,7 @@ impl std::fmt::Debug for Revision {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AtomicRevision {
|
||||
pub struct AtomicRevision {
|
||||
data: AtomicUsize,
|
||||
}
|
||||
|
||||
|
@ -74,13 +74,13 @@ impl From<Revision> for AtomicRevision {
|
|||
}
|
||||
|
||||
impl AtomicRevision {
|
||||
pub(crate) const fn start() -> Self {
|
||||
pub const fn start() -> Self {
|
||||
Self {
|
||||
data: AtomicUsize::new(START),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load(&self) -> Revision {
|
||||
pub fn load(&self) -> Revision {
|
||||
Revision {
|
||||
// SAFETY: We know that the value is non-zero because we only ever store `START` which 1, or a
|
||||
// Revision which is guaranteed to be non-zero.
|
||||
|
@ -88,7 +88,7 @@ impl AtomicRevision {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn store(&self, r: Revision) {
|
||||
pub fn store(&self, r: Revision) {
|
||||
self.data.store(r.as_usize(), Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ use crate::id::{AsId, FromId};
|
|||
use crate::ingredient::{Ingredient, Jar};
|
||||
use crate::key::DatabaseKeyIndex;
|
||||
use crate::plumbing::ZalsaLocal;
|
||||
use crate::revision::OptionalAtomicRevision;
|
||||
use crate::revision::{AtomicRevision, OptionalAtomicRevision};
|
||||
use crate::runtime::Stamp;
|
||||
use crate::salsa_struct::SalsaStructInDb;
|
||||
use crate::sync::Arc;
|
||||
|
@ -26,7 +26,7 @@ use crate::zalsa::{IngredientIndex, Zalsa};
|
|||
use crate::{Database, Durability, Event, EventKind, Id, Revision};
|
||||
|
||||
pub mod tracked_field;
|
||||
|
||||
pub mod late_field;
|
||||
// ANCHOR: Configuration
|
||||
/// Trait that defines the key properties of a tracked struct.
|
||||
///
|
||||
|
@ -51,7 +51,7 @@ pub trait Configuration: Sized + 'static {
|
|||
/// When a struct is re-recreated in a new revision, the corresponding
|
||||
/// entries for each field are updated to the new revision if their
|
||||
/// values have changed (or if the field is marked as `#[no_eq]`).
|
||||
type Revisions: Send + Sync + Index<usize, Output = Revision>;
|
||||
type Revisions: Send + Sync + Index<usize, Output = AtomicRevision>;
|
||||
|
||||
type Struct<'db>: Copy + FromId + AsId;
|
||||
|
||||
|
@ -755,7 +755,7 @@ where
|
|||
|
||||
data.read_lock(zalsa.current_revision());
|
||||
|
||||
let field_changed_at = data.revisions[relative_tracked_index];
|
||||
let field_changed_at = data.revisions[relative_tracked_index].load();
|
||||
|
||||
zalsa_local.report_tracked_read_simple(
|
||||
DatabaseKeyIndex::new(field_ingredient_index, id),
|
||||
|
@ -766,6 +766,20 @@ where
|
|||
data.fields()
|
||||
}
|
||||
|
||||
pub fn revisions<'db>(
|
||||
&'db self,
|
||||
db: &'db dyn crate::Database,
|
||||
s: C::Struct<'db>,
|
||||
) -> &'db C::Revisions {
|
||||
let (zalsa, _) = db.zalsas();
|
||||
let id = AsId::as_id(&s);
|
||||
let data = Self::data(zalsa.table(), id);
|
||||
|
||||
data.read_lock(zalsa.current_revision());
|
||||
|
||||
&data.revisions
|
||||
}
|
||||
|
||||
/// Access to this untracked field.
|
||||
///
|
||||
/// Note that this function returns the entire tuple of value fields.
|
||||
|
|
117
src/tracked_struct/late_field.rs
Normal file
117
src/tracked_struct/late_field.rs
Normal file
|
@ -0,0 +1,117 @@
|
|||
use std::{cell::UnsafeCell, mem::MaybeUninit, sync::atomic::{AtomicUsize, Ordering}};
|
||||
|
||||
use crate::{Revision, Update};
|
||||
|
||||
const EMPTY: usize = 0;
|
||||
const ACQUIRED: usize = 1;
|
||||
const SET: usize = 2;
|
||||
const DIRTY: usize = 3;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LateField<T: Update> {
|
||||
state: AtomicUsize,
|
||||
// Last valid revision of DIRTY state
|
||||
old_revision: Option<Revision>,
|
||||
|
||||
data: UnsafeCell<MaybeUninit<T>>
|
||||
}
|
||||
|
||||
unsafe impl<T: Update + Send> Send for LateField<T> {}
|
||||
unsafe impl<T: Update + Sync> Sync for LateField<T> {}
|
||||
|
||||
impl<T: Update> LateField<T> {
|
||||
pub fn new() -> LateField<T> {
|
||||
LateField {
|
||||
state: AtomicUsize::new(EMPTY),
|
||||
old_revision: None,
|
||||
|
||||
data: UnsafeCell::new(MaybeUninit::uninit())
|
||||
}
|
||||
}
|
||||
|
||||
// Update self, store old revision to probably backdate later
|
||||
pub fn maybe_update(&mut self, mut value: Self, maybe_update_inner: unsafe fn(*mut T, T) -> bool, old_revision: Revision) -> bool {
|
||||
let old_state = self.state.load(Ordering::Relaxed);
|
||||
let new_state = value.state.load(Ordering::Relaxed);
|
||||
let t = match (old_state, new_state) {
|
||||
(EMPTY, EMPTY) => {
|
||||
self.old_revision = None;
|
||||
self.state.store(EMPTY, Ordering::Release);
|
||||
|
||||
false
|
||||
},
|
||||
(EMPTY, SET) => {
|
||||
self.old_revision = None;
|
||||
self.data = value.data;
|
||||
self.state.store(SET, Ordering::Release);
|
||||
|
||||
true
|
||||
},
|
||||
(DIRTY, SET) => {
|
||||
// SAFETY: DIRTY and SET state assumes that data is initialized
|
||||
let changed = unsafe {
|
||||
maybe_update_inner(self.data.get_mut().assume_init_mut(), value.data.get_mut().assume_init_read())
|
||||
};
|
||||
self.state.store(SET, Ordering::Release);
|
||||
|
||||
changed
|
||||
},
|
||||
(SET, EMPTY) => {
|
||||
self.old_revision = Some(old_revision);
|
||||
// Save old value to probably backdate later
|
||||
self.state.store(DIRTY, Ordering::Release);
|
||||
|
||||
true
|
||||
}
|
||||
_ => panic!("unexpected state"),
|
||||
};
|
||||
|
||||
t
|
||||
}
|
||||
|
||||
/// Set new value and returns saved revision if its not updated
|
||||
pub fn set_maybe_backdate(&self, value: T) -> Option<Revision> {
|
||||
let old_state = self.state.load(Ordering::Relaxed);
|
||||
match old_state {
|
||||
EMPTY => {},
|
||||
DIRTY => {},
|
||||
SET => {
|
||||
panic!("set on late field called twice")
|
||||
},
|
||||
ACQUIRED => {
|
||||
panic!("concurrent set on late field is not allowed")
|
||||
}
|
||||
_ => panic!("unexpected state"),
|
||||
}
|
||||
self.state.compare_exchange(old_state, ACQUIRED, Ordering::Acquire, Ordering::Relaxed).expect("concurrent set on late field is not allowed");
|
||||
let updated = if old_state == EMPTY {
|
||||
unsafe {
|
||||
(&mut *self.data.get()).write(value);
|
||||
}
|
||||
true
|
||||
} else {
|
||||
unsafe {
|
||||
Update::maybe_update((&mut *self.data.get()).assume_init_mut(), value)
|
||||
}
|
||||
};
|
||||
|
||||
self.state.store(SET, Ordering::Release);
|
||||
|
||||
if updated {
|
||||
None
|
||||
} else {
|
||||
self.old_revision
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self) -> Option<&T> {
|
||||
if self.state.load(Ordering::Acquire) != SET {
|
||||
return None;
|
||||
};
|
||||
|
||||
// SAFETY: we can't move from SET to any other state while we have ref to self
|
||||
Some(unsafe {
|
||||
(&*self.data.get()).assume_init_ref()
|
||||
})
|
||||
}
|
||||
}
|
|
@ -64,7 +64,7 @@ where
|
|||
) -> VerifyResult {
|
||||
let zalsa = db.zalsa();
|
||||
let data = <super::IngredientImpl<C>>::data(zalsa.table(), input);
|
||||
let field_changed_at = data.revisions[self.field_index];
|
||||
let field_changed_at = data.revisions[self.field_index].load();
|
||||
VerifyResult::changed_if(field_changed_at > revision)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ use std::path::PathBuf;
|
|||
#[cfg(feature = "rayon")]
|
||||
use rayon::iter::Either;
|
||||
|
||||
use crate::revision::AtomicRevision;
|
||||
use crate::sync::Arc;
|
||||
use crate::Revision;
|
||||
|
||||
|
@ -104,12 +105,12 @@ where
|
|||
/// and updates `*old_revision` with `new_revision.` Used for fields
|
||||
/// tagged with `#[no_eq]`
|
||||
pub fn always_update<T>(
|
||||
old_revision: &mut Revision,
|
||||
old_revision: &mut AtomicRevision,
|
||||
new_revision: Revision,
|
||||
old_pointer: &mut T,
|
||||
new_value: T,
|
||||
) {
|
||||
*old_revision = new_revision;
|
||||
old_revision.store(new_revision);
|
||||
*old_pointer = new_value;
|
||||
}
|
||||
|
||||
|
|
74
tests/tracked_struct_late_fields.rs
Normal file
74
tests/tracked_struct_late_fields.rs
Normal file
|
@ -0,0 +1,74 @@
|
|||
mod common;
|
||||
|
||||
use salsa::{Database, Setter};
|
||||
|
||||
// A tracked struct with mixed tracked and untracked fields to ensure
|
||||
// the correct field indices are used when tracking dependencies.
|
||||
#[salsa::tracked(debug)]
|
||||
struct TrackedWithLateField<'db> {
|
||||
untracked_1: usize,
|
||||
|
||||
#[late]
|
||||
tracked_1: usize,
|
||||
|
||||
#[late]
|
||||
tracked_2: usize,
|
||||
|
||||
untracked_2: usize,
|
||||
|
||||
untracked_3: usize,
|
||||
|
||||
untracked_4: usize,
|
||||
}
|
||||
|
||||
#[salsa::input]
|
||||
struct MyInput {
|
||||
field1: usize,
|
||||
field2: usize,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn intermediate(db: &dyn salsa::Database, input: MyInput) -> TrackedWithLateField<'_> {
|
||||
input.field1(db);
|
||||
input.field2(db);
|
||||
let t = TrackedWithLateField::new(db, 0, 1, 2, 3);
|
||||
t.set_tracked_1(db, input.field1(db));
|
||||
t.set_tracked_2(db, input.field2(db));
|
||||
t
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn accumulate(db: &dyn salsa::Database, input: MyInput) -> (usize, usize) {
|
||||
let tracked = intermediate(db, input);
|
||||
let one = read_tracked_1(db, tracked);
|
||||
let two = read_tracked_2(db, tracked);
|
||||
|
||||
(one, two)
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn read_tracked_1<'db>(db: &'db dyn Database, tracked: TrackedWithLateField<'db>) -> usize {
|
||||
tracked.tracked_1(db)
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn read_tracked_2<'db>(db: &'db dyn Database, tracked: TrackedWithLateField<'db>) -> usize {
|
||||
tracked.tracked_2(db)
|
||||
}
|
||||
|
||||
#[test_log::test]
|
||||
fn execute() {
|
||||
let mut db = salsa::DatabaseImpl::default();
|
||||
let input = MyInput::new(&db, 1, 1);
|
||||
|
||||
assert_eq!(accumulate(&db, input), (1, 1));
|
||||
|
||||
// Should only re-execute `read_tracked_1`.
|
||||
input.set_field1(&mut db).to(2);
|
||||
input.set_field2(&mut db).to(1);
|
||||
assert_eq!(accumulate(&db, input), (2, 1));
|
||||
|
||||
// Should only re-execute `read_tracked_2`.
|
||||
input.set_field2(&mut db).to(2);
|
||||
assert_eq!(accumulate(&db, input), (2, 2));
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue