replace arc-swap with manual AtomicPtr

This commit is contained in:
Ibraheem Ahmed 2025-02-19 00:32:20 -05:00 committed by David Barsky
parent 26aeeec9b1
commit 2a751d559b
9 changed files with 166 additions and 149 deletions

View file

@ -12,7 +12,6 @@ rust-version = "1.80"
salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" } salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
salsa-macros = { version = "0.18.0", path = "components/salsa-macros" } salsa-macros = { version = "0.18.0", path = "components/salsa-macros" }
arc-swap = "1"
boxcar = "0.2.9" boxcar = "0.2.9"
crossbeam-queue = "0.3.11" crossbeam-queue = "0.3.11"
dashmap = { version = "6", features = ["raw-api"] } dashmap = { version = "6", features = ["raw-api"] }

View file

@ -1,4 +1,4 @@
use std::{any::Any, fmt, mem::ManuallyDrop, sync::Arc}; use std::{any::Any, fmt, ptr::NonNull};
use crate::{ use crate::{
accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
@ -176,10 +176,14 @@ where
memo: memo::Memo<C::Output<'db>>, memo: memo::Memo<C::Output<'db>>,
memo_ingredient_index: MemoIngredientIndex, memo_ingredient_index: MemoIngredientIndex,
) -> &'db memo::Memo<C::Output<'db>> { ) -> &'db memo::Memo<C::Output<'db>> {
let memo = Arc::new(memo); // We convert to a `NonNull` here as soon as possible because we are going to alias
// into the `Box`, which is a `noalias` type.
let memo = unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(memo))) };
// Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this // Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this
// value is returned) and anything removed from map is added to deleted entries (ensured elsewhere). // value is returned) and anything removed from map is added to deleted entries (ensured elsewhere).
let db_memo = unsafe { self.extend_memo_lifetime(&memo) }; let db_memo = unsafe { self.extend_memo_lifetime(memo.as_ref()) };
// Safety: We delay the drop of `old_value` until a new revision starts which ensures no // Safety: We delay the drop of `old_value` until a new revision starts which ensures no
// references will exist for the memo contents. // references will exist for the memo contents.
if let Some(old_value) = if let Some(old_value) =
@ -187,8 +191,10 @@ where
{ {
// In case there is a reference to the old memo out there, we have to store it // In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts. // in the deleted entries. This will get cleared when a new revision starts.
self.deleted_entries //
.push(ManuallyDrop::into_inner(old_value)); // SAFETY: Once the revision starts, there will be no oustanding borrows to the
// memo contents, and so it will be safe to free.
unsafe { self.deleted_entries.push(old_value) };
} }
db_memo db_memo
} }
@ -254,7 +260,6 @@ where
let ingredient_index = table.ingredient_index(evict); let ingredient_index = table.ingredient_index(evict);
Self::evict_value_from_memo_for( Self::evict_value_from_memo_for(
table.memos_mut(evict), table.memos_mut(evict),
&self.deleted_entries,
self.memo_ingredient_indices.get(ingredient_index), self.memo_ingredient_indices.get(ingredient_index),
) )
}); });

View file

@ -1,14 +1,20 @@
use std::ptr::NonNull;
use crossbeam_queue::SegQueue; use crossbeam_queue::SegQueue;
use super::{memo::ArcMemo, Configuration}; use super::memo::Memo;
use super::Configuration;
/// Stores the list of memos that have been deleted so they can be freed /// Stores the list of memos that have been deleted so they can be freed
/// once the next revision starts. See the comment on the field /// once the next revision starts. See the comment on the field
/// `deleted_entries` of [`FunctionIngredient`][] for more details. /// `deleted_entries` of [`FunctionIngredient`][] for more details.
pub(super) struct DeletedEntries<C: Configuration> { pub(super) struct DeletedEntries<C: Configuration> {
seg_queue: SegQueue<ArcMemo<'static, C>>, seg_queue: SegQueue<SharedBox<Memo<C::Output<'static>>>>,
} }
unsafe impl<C: Configuration> Send for DeletedEntries<C> {}
unsafe impl<C: Configuration> Sync for DeletedEntries<C> {}
impl<C: Configuration> Default for DeletedEntries<C> { impl<C: Configuration> Default for DeletedEntries<C> {
fn default() -> Self { fn default() -> Self {
Self { Self {
@ -18,8 +24,29 @@ impl<C: Configuration> Default for DeletedEntries<C> {
} }
impl<C: Configuration> DeletedEntries<C> { impl<C: Configuration> DeletedEntries<C> {
pub(super) fn push(&self, memo: ArcMemo<'_, C>) { /// # Safety
let memo = unsafe { std::mem::transmute::<ArcMemo<'_, C>, ArcMemo<'static, C>>(memo) }; ///
self.seg_queue.push(memo); /// The memo must be valid and safe to free when the `DeletedEntries` list is dropped.
pub(super) unsafe fn push(&self, memo: NonNull<Memo<C::Output<'_>>>) {
let memo = unsafe {
std::mem::transmute::<NonNull<Memo<C::Output<'_>>>, NonNull<Memo<C::Output<'static>>>>(
memo,
)
};
self.seg_queue.push(SharedBox(memo));
}
}
/// A wrapper around `NonNull` that frees the allocation when it is dropped.
///
/// `crossbeam::SegQueue` does not expose mutable accessors so we have to create
/// a wrapper to run code during `Drop`.
struct SharedBox<T>(NonNull<T>);
impl<T> Drop for SharedBox<T> {
fn drop(&mut self) {
// SAFETY: Guaranteed by the caller of `DeletedEntries::push`.
unsafe { drop(Box::from_raw(self.0.as_ptr())) };
} }
} }

View file

@ -1,5 +1,3 @@
use std::sync::Arc;
use crate::{ use crate::{
zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind,
}; };
@ -23,7 +21,7 @@ where
&'db self, &'db self,
db: &'db C::DbView, db: &'db C::DbView,
active_query: ActiveQueryGuard<'_>, active_query: ActiveQueryGuard<'_>,
opt_old_memo: Option<Arc<Memo<C::Output<'_>>>>, opt_old_memo: Option<&Memo<C::Output<'_>>>,
) -> &'db Memo<C::Output<'db>> { ) -> &'db Memo<C::Output<'db>> {
let zalsa = db.zalsa(); let zalsa = db.zalsa();
let revision_now = zalsa.current_revision(); let revision_now = zalsa.current_revision();

View file

@ -38,7 +38,6 @@ where
MaybeChangedAfter::No(memo.revisions.accumulated_inputs.load()) MaybeChangedAfter::No(memo.revisions.accumulated_inputs.load())
}; };
} }
drop(memo_guard); // release the arc-swap guard before cold path
if let Some(mcs) = if let Some(mcs) =
self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index) self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index)
{ {
@ -86,7 +85,7 @@ where
); );
// Check if the inputs are still valid. We can just compare `changed_at`. // Check if the inputs are still valid. We can just compare `changed_at`.
if self.deep_verify_memo(db, zalsa, &old_memo, &active_query) { if self.deep_verify_memo(db, zalsa, old_memo, &active_query) {
return Some(if old_memo.revisions.changed_at > revision { return Some(if old_memo.revisions.changed_at > revision {
MaybeChangedAfter::Yes MaybeChangedAfter::Yes
} else { } else {

View file

@ -1,11 +1,9 @@
use std::any::Any; use std::any::Any;
use std::fmt::Debug; use std::fmt::Debug;
use std::fmt::Formatter; use std::fmt::Formatter;
use std::mem::ManuallyDrop; use std::ptr::NonNull;
use std::sync::Arc;
use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::accumulator::accumulated_map::InputAccumulatedValues;
use crate::function::DeletedEntries;
use crate::revision::AtomicRevision; use crate::revision::AtomicRevision;
use crate::table::memo::MemoTable; use crate::table::memo::MemoTable;
use crate::zalsa::MemoIngredientIndex; use crate::zalsa::MemoIngredientIndex;
@ -17,21 +15,33 @@ use crate::{
use super::{Configuration, IngredientImpl}; use super::{Configuration, IngredientImpl};
#[allow(type_alias_bounds)]
pub(super) type ArcMemo<'lt, C: Configuration> = Arc<Memo<<C as Configuration>::Output<'lt>>>;
impl<C: Configuration> IngredientImpl<C> { impl<C: Configuration> IngredientImpl<C> {
/// Memos have to be stored internally using `'static` as the database lifetime. /// Memos have to be stored internally using `'static` as the database lifetime.
/// This (unsafe) function call converts from something tied to self to static. /// This (unsafe) function call converts from something tied to self to static.
/// Values transmuted this way have to be transmuted back to being tied to self /// Values transmuted this way have to be transmuted back to being tied to self
/// when they are returned to the user. /// when they are returned to the user.
unsafe fn to_static<'db>(&'db self, memo: ArcMemo<'db, C>) -> ArcMemo<'static, C> { unsafe fn to_static<'db>(
unsafe { std::mem::transmute(memo) } &'db self,
memo: NonNull<Memo<C::Output<'db>>>,
) -> NonNull<Memo<C::Output<'static>>> {
memo.cast()
}
/// Convert from an internal memo (which uses `'static``) to one tied to self
/// so it can be publicly released.
unsafe fn to_self<'db>(
&'db self,
memo: NonNull<Memo<C::Output<'static>>>,
) -> NonNull<Memo<C::Output<'db>>> {
memo.cast()
} }
/// Convert from an internal memo (which uses `'static`) to one tied to self /// Convert from an internal memo (which uses `'static`) to one tied to self
/// so it can be publicly released. /// so it can be publicly released.
unsafe fn to_self<'db>(&'db self, memo: ArcMemo<'static, C>) -> ArcMemo<'db, C> { unsafe fn to_self_ref<'db>(
&'db self,
memo: &'db Memo<C::Output<'static>>,
) -> &'db Memo<C::Output<'db>> {
unsafe { std::mem::transmute(memo) } unsafe { std::mem::transmute(memo) }
} }
@ -45,17 +55,16 @@ impl<C: Configuration> IngredientImpl<C> {
&'db self, &'db self,
zalsa: &'db Zalsa, zalsa: &'db Zalsa,
id: Id, id: Id,
memo: ArcMemo<'db, C>, memo: NonNull<Memo<C::Output<'db>>>,
memo_ingredient_index: MemoIngredientIndex, memo_ingredient_index: MemoIngredientIndex,
) -> Option<ManuallyDrop<ArcMemo<'db, C>>> { ) -> Option<NonNull<Memo<C::Output<'db>>>> {
let static_memo = unsafe { self.to_static(memo) }; let static_memo = unsafe { self.to_static(memo) };
let old_static_memo = unsafe { let old_static_memo = unsafe {
zalsa zalsa
.memo_table_for(id) .memo_table_for(id)
.insert(memo_ingredient_index, static_memo) .insert(memo_ingredient_index, static_memo)
}?; }?;
let old_static_memo = ManuallyDrop::into_inner(old_static_memo); Some(unsafe { self.to_self(old_static_memo) })
Some(ManuallyDrop::new(unsafe { self.to_self(old_static_memo) }))
} }
/// Loads the current memo for `key_index`. This does not hold any sort of /// Loads the current memo for `key_index`. This does not hold any sort of
@ -66,9 +75,10 @@ impl<C: Configuration> IngredientImpl<C> {
zalsa: &'db Zalsa, zalsa: &'db Zalsa,
id: Id, id: Id,
memo_ingredient_index: MemoIngredientIndex, memo_ingredient_index: MemoIngredientIndex,
) -> Option<ArcMemo<'db, C>> { ) -> Option<&'db Memo<C::Output<'db>>> {
let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?; let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?;
unsafe { Some(self.to_self(static_memo)) }
unsafe { Some(self.to_self_ref(static_memo)) }
} }
/// Evicts the existing memo for the given key, replacing it /// Evicts the existing memo for the given key, replacing it
@ -76,10 +86,9 @@ impl<C: Configuration> IngredientImpl<C> {
/// or has values assigned as output of another query, this has no effect. /// or has values assigned as output of another query, this has no effect.
pub(super) fn evict_value_from_memo_for( pub(super) fn evict_value_from_memo_for(
table: &mut MemoTable, table: &mut MemoTable,
deleted_entries: &DeletedEntries<C>,
memo_ingredient_index: MemoIngredientIndex, memo_ingredient_index: MemoIngredientIndex,
) { ) {
let map = |memo: ArcMemo<'static, C>| -> ArcMemo<'static, C> { let map = |memo: &mut Memo<C::Output<'static>>| {
match &memo.revisions.origin { match &memo.revisions.origin {
QueryOrigin::Assigned(_) QueryOrigin::Assigned(_)
| QueryOrigin::DerivedUntracked(_) | QueryOrigin::DerivedUntracked(_)
@ -88,43 +97,15 @@ impl<C: Configuration> IngredientImpl<C> {
// assigned as output of another query // assigned as output of another query
// or those with untracked inputs // or those with untracked inputs
// as their values cannot be reconstructed. // as their values cannot be reconstructed.
memo
} }
QueryOrigin::Derived(_) => { QueryOrigin::Derived(_) => {
// Note that we cannot use `Arc::get_mut` here as the use of `ArcSwap` makes it // Set the memo value to `None`.
// impossible to get unique access to the interior Arc memo.value = None;
// QueryRevisions: !Clone to discourage cloning, we need it here though
let &QueryRevisions {
changed_at,
durability,
ref origin,
ref tracked_struct_ids,
ref accumulated,
ref accumulated_inputs,
} = &memo.revisions;
// Re-assemble the memo but with the value set to `None`
Arc::new(Memo::new(
None,
memo.verified_at.load(),
QueryRevisions {
changed_at,
durability,
origin: origin.clone(),
tracked_struct_ids: tracked_struct_ids.clone(),
accumulated: accumulated.clone(),
accumulated_inputs: accumulated_inputs.clone(),
},
))
} }
} }
}; };
// SAFETY: We queue the old value for deletion, delaying its drop until the next revision bump.
let old = unsafe { table.map_memo(memo_ingredient_index, map) }; table.map_memo(memo_ingredient_index, map)
if let Some(old) = old {
// In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts.
deleted_entries.push(ManuallyDrop::into_inner(old));
}
} }
} }

View file

@ -75,8 +75,8 @@ where
let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); let memo_ingredient_index = self.memo_ingredient_index(zalsa, key);
if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) {
self.backdate_if_appropriate(&old_memo, &mut revisions, &value); self.backdate_if_appropriate(old_memo, &mut revisions, &value);
self.diff_outputs(zalsa, db, database_key_index, &old_memo, &mut revisions); self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions);
} }
let memo = Memo { let memo = Memo {

View file

@ -1,11 +1,10 @@
use std::{ use std::{
any::{Any, TypeId}, any::{Any, TypeId},
fmt::Debug, fmt::Debug,
mem::ManuallyDrop, ptr::NonNull,
sync::Arc, sync::atomic::{AtomicPtr, Ordering},
}; };
use arc_swap::ArcSwap;
use parking_lot::RwLock; use parking_lot::RwLock;
use crate::{zalsa::MemoIngredientIndex, zalsa_local::QueryOrigin}; use crate::{zalsa::MemoIngredientIndex, zalsa_local::QueryOrigin};
@ -32,7 +31,7 @@ struct MemoEntry {
} }
/// Data for a memoized entry. /// Data for a memoized entry.
/// This is a type-erased `Arc<M>`, where `M` is the type of memo associated /// This is a type-erased `Box<M>`, where `M` is the type of memo associated
/// with that particular ingredient index. /// with that particular ingredient index.
/// ///
/// # Implementation note /// # Implementation note
@ -42,9 +41,9 @@ struct MemoEntry {
/// Therefore, once a given entry goes from `Empty` to `Full`, /// Therefore, once a given entry goes from `Empty` to `Full`,
/// the type-id associated with that entry should never change. /// the type-id associated with that entry should never change.
/// ///
/// We take advantage of this and use an `ArcSwap` to store the actual memo. /// We take advantage of this and use an `AtomicPtr` to store the actual memo.
/// This allows us to store into the memo-entry without acquiring a write-lock. /// This allows us to store into the memo-entry without acquiring a write-lock.
/// However, using `ArcSwap` means we cannot use a `Arc<dyn Any>` or any other wide pointer. /// However, using `AtomicPtr` means we cannot use a `Box<dyn Any>` or any other wide pointer.
/// Therefore, we hide the type by transmuting to `DummyMemo`; but we must then be very careful /// Therefore, we hide the type by transmuting to `DummyMemo`; but we must then be very careful
/// when freeing `MemoEntryData` values to transmute things back. See the `Drop` impl for /// when freeing `MemoEntryData` values to transmute things back. See the `Drop` impl for
/// [`MemoEntry`][] for details. /// [`MemoEntry`][] for details.
@ -52,43 +51,45 @@ struct MemoEntryData {
/// The `type_id` of the erased memo type `M` /// The `type_id` of the erased memo type `M`
type_id: TypeId, type_id: TypeId,
/// A pointer to `std::mem::drop::<Arc<M>>` for the erased memo type `M` /// A type-coercion function for the erased memo type `M`
to_dyn_fn: fn(Arc<DummyMemo>) -> Arc<dyn Memo>, to_dyn_fn: fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
/// An [`ArcSwap`][] to a `Arc<M>` for the erased memo type `M` /// An [`AtomicPtr`][] to a `Box<M>` for the erased memo type `M`
arc_swap: ArcSwap<DummyMemo>, atomic_memo: AtomicPtr<DummyMemo>,
} }
/// Dummy placeholder type that we use when erasing the memo type `M` in [`MemoEntryData`][]. /// Dummy placeholder type that we use when erasing the memo type `M` in [`MemoEntryData`][].
struct DummyMemo {} struct DummyMemo {}
impl MemoTable { impl MemoTable {
fn to_dummy<M: Memo>(memo: Arc<M>) -> Arc<DummyMemo> { fn to_dummy<M: Memo>(memo: NonNull<M>) -> NonNull<DummyMemo> {
unsafe { std::mem::transmute::<Arc<M>, Arc<DummyMemo>>(memo) } memo.cast()
} }
unsafe fn from_dummy<M: Memo>(memo: Arc<DummyMemo>) -> Arc<M> { unsafe fn from_dummy<M: Memo>(memo: NonNull<DummyMemo>) -> NonNull<M> {
unsafe { std::mem::transmute::<Arc<DummyMemo>, Arc<M>>(memo) } memo.cast()
} }
fn to_dyn_fn<M: Memo>() -> fn(Arc<DummyMemo>) -> Arc<dyn Memo> { fn to_dyn_fn<M: Memo>() -> fn(NonNull<DummyMemo>) -> NonNull<dyn Memo> {
let f: fn(Arc<M>) -> Arc<dyn Memo> = |x| x; let f: fn(NonNull<M>) -> NonNull<dyn Memo> = |x| x;
unsafe { unsafe {
std::mem::transmute::<fn(Arc<M>) -> Arc<dyn Memo>, fn(Arc<DummyMemo>) -> Arc<dyn Memo>>( std::mem::transmute::<
f, fn(NonNull<M>) -> NonNull<dyn Memo>,
) fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
>(f)
} }
} }
/// # Safety /// # Safety
/// ///
/// The caller needs to make sure to not drop the returned value until no more references into /// The caller needs to make sure to not free the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents. /// the database exist as there may be outstanding borrows into the pointer contents.
pub(crate) unsafe fn insert<M: Memo>( pub(crate) unsafe fn insert<M: Memo>(
&self, &self,
memo_ingredient_index: MemoIngredientIndex, memo_ingredient_index: MemoIngredientIndex,
memo: Arc<M>, memo: NonNull<M>,
) -> Option<ManuallyDrop<Arc<M>>> { ) -> Option<NonNull<M>> {
// If the memo slot is already occupied, it must already have the // If the memo slot is already occupied, it must already have the
// right type info etc, and we only need the read-lock. // right type info etc, and we only need the read-lock.
if let Some(MemoEntry { if let Some(MemoEntry {
@ -96,7 +97,7 @@ impl MemoTable {
Some(MemoEntryData { Some(MemoEntryData {
type_id, type_id,
to_dyn_fn: _, to_dyn_fn: _,
arc_swap, atomic_memo,
}), }),
}) = self.memos.read().get(memo_ingredient_index.as_usize()) }) = self.memos.read().get(memo_ingredient_index.as_usize())
{ {
@ -105,8 +106,14 @@ impl MemoTable {
TypeId::of::<M>(), TypeId::of::<M>(),
"inconsistent type-id for `{memo_ingredient_index:?}`" "inconsistent type-id for `{memo_ingredient_index:?}`"
); );
let old_memo = arc_swap.swap(Self::to_dummy(memo));
return Some(ManuallyDrop::new(unsafe { Self::from_dummy(old_memo) })); let old_memo = atomic_memo.swap(Self::to_dummy(memo).as_ptr(), Ordering::AcqRel);
// SAFETY: The `atomic_memo` field is never null.
let old_memo = unsafe { NonNull::new_unchecked(old_memo) };
// SAFETY: `type_id` check asserted above
return Some(unsafe { Self::from_dummy(old_memo) });
} }
// Otherwise we need the write lock. // Otherwise we need the write lock.
@ -116,13 +123,13 @@ impl MemoTable {
/// # Safety /// # Safety
/// ///
/// The caller needs to make sure to not drop the returned value until no more references into /// The caller needs to make sure to not free the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents. /// the database exist as there may be outstanding borrows into the pointer contents.
unsafe fn insert_cold<M: Memo>( unsafe fn insert_cold<M: Memo>(
&self, &self,
memo_ingredient_index: MemoIngredientIndex, memo_ingredient_index: MemoIngredientIndex,
memo: Arc<M>, memo: NonNull<M>,
) -> Option<ManuallyDrop<Arc<M>>> { ) -> Option<NonNull<M>> {
let mut memos = self.memos.write(); let mut memos = self.memos.write();
let memo_ingredient_index = memo_ingredient_index.as_usize(); let memo_ingredient_index = memo_ingredient_index.as_usize();
if memos.len() < memo_ingredient_index + 1 { if memos.len() < memo_ingredient_index + 1 {
@ -133,24 +140,22 @@ impl MemoTable {
Some(MemoEntryData { Some(MemoEntryData {
type_id: TypeId::of::<M>(), type_id: TypeId::of::<M>(),
to_dyn_fn: Self::to_dyn_fn::<M>(), to_dyn_fn: Self::to_dyn_fn::<M>(),
arc_swap: ArcSwap::new(Self::to_dummy(memo)), atomic_memo: AtomicPtr::new(Self::to_dummy(memo).as_ptr()),
}), }),
); );
old_entry old_entry.map(
.map(
|MemoEntryData { |MemoEntryData {
type_id: _, type_id: _,
to_dyn_fn: _, to_dyn_fn: _,
arc_swap, atomic_memo,
}| unsafe { Self::from_dummy(arc_swap.into_inner()) }, }| unsafe {
// SAFETY: The `atomic_memo` field is never null.
Self::from_dummy(NonNull::new_unchecked(atomic_memo.into_inner()))
},
) )
.map(ManuallyDrop::new)
} }
pub(crate) fn get<M: Memo>( pub(crate) fn get<M: Memo>(&self, memo_ingredient_index: MemoIngredientIndex) -> Option<&M> {
&self,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<Arc<M>> {
let memos = self.memos.read(); let memos = self.memos.read();
let Some(MemoEntry { let Some(MemoEntry {
@ -158,7 +163,7 @@ impl MemoTable {
Some(MemoEntryData { Some(MemoEntryData {
type_id, type_id,
to_dyn_fn: _, to_dyn_fn: _,
arc_swap, atomic_memo,
}), }),
}) = memos.get(memo_ingredient_index.as_usize()) }) = memos.get(memo_ingredient_index.as_usize())
else { else {
@ -171,57 +176,54 @@ impl MemoTable {
"inconsistent type-id for `{memo_ingredient_index:?}`" "inconsistent type-id for `{memo_ingredient_index:?}`"
); );
// SAFETY: type_id check asserted above // SAFETY: The `atomic_memo` field is never null.
unsafe { Some(Self::from_dummy(arc_swap.load_full())) } let memo = unsafe { NonNull::new_unchecked(atomic_memo.load(Ordering::Acquire)) };
// SAFETY: `type_id` check asserted above
unsafe { Some(Self::from_dummy(memo).as_ref()) }
} }
/// Calls `f` on the memo at `memo_ingredient_index` and replaces the memo with the result of `f`. /// Calls `f` on the memo at `memo_ingredient_index`.
///
/// If the memo is not present, `f` is not called. /// If the memo is not present, `f` is not called.
/// pub(crate) fn map_memo<M: Memo>(
/// # Safety
///
/// The caller needs to make sure to not drop the returned value until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents.
pub(crate) unsafe fn map_memo<M: Memo>(
&mut self, &mut self,
memo_ingredient_index: MemoIngredientIndex, memo_ingredient_index: MemoIngredientIndex,
f: impl FnOnce(Arc<M>) -> Arc<M>, f: impl FnOnce(&mut M),
) -> Option<ManuallyDrop<Arc<M>>> { ) {
let memos = self.memos.get_mut(); let memos = self.memos.get_mut();
let Some(MemoEntry { let Some(MemoEntry {
data: data:
Some(MemoEntryData { Some(MemoEntryData {
type_id, type_id,
to_dyn_fn: _, to_dyn_fn: _,
arc_swap, atomic_memo,
}), }),
}) = memos.get(memo_ingredient_index.as_usize()) }) = memos.get_mut(memo_ingredient_index.as_usize())
else { else {
return None; return;
}; };
assert_eq!( assert_eq!(
*type_id, *type_id,
TypeId::of::<M>(), TypeId::of::<M>(),
"inconsistent type-id for `{memo_ingredient_index:?}`" "inconsistent type-id for `{memo_ingredient_index:?}`"
); );
// arc-swap does not expose accessing the interior mutably at all unfortunately
// https://github.com/vorner/arc-swap/issues/131 // SAFETY: The `atomic_memo` field is never null.
// so we are required to allocate a new arc within `f` instead of being able let memo = unsafe { NonNull::new_unchecked(*atomic_memo.get_mut()) };
// to swap out the interior
// SAFETY: type_id check asserted above // SAFETY: `type_id` check asserted above
let memo = f(unsafe { Self::from_dummy(arc_swap.load_full()) }); f(unsafe { Self::from_dummy(memo).as_mut() });
Some(ManuallyDrop::new(unsafe {
Self::from_dummy::<M>(arc_swap.swap(Self::to_dummy(memo)))
}))
} }
/// # Safety /// # Safety
/// ///
/// The caller needs to make sure to not drop the returned value until no more references into /// The caller needs to make sure to not call this function until no more references into
/// the database exist as there may be outstanding borrows into the `Arc` contents. /// the database exist as there may be outstanding borrows into the pointer contents.
pub(crate) unsafe fn into_memos( pub(crate) unsafe fn into_memos(
self, self,
) -> impl Iterator<Item = (MemoIngredientIndex, ManuallyDrop<Arc<dyn Memo>>)> { ) -> impl Iterator<Item = (MemoIngredientIndex, Box<dyn Memo>)> {
self.memos self.memos
.into_inner() .into_inner()
.into_iter() .into_iter()
@ -232,14 +234,17 @@ impl MemoTable {
MemoEntryData { MemoEntryData {
type_id: _, type_id: _,
to_dyn_fn, to_dyn_fn,
arc_swap, atomic_memo,
}, },
index, index,
)| { )| {
( // SAFETY: The `atomic_memo` field is never null.
MemoIngredientIndex::from_usize(index), let memo =
ManuallyDrop::new(to_dyn_fn(arc_swap.into_inner())), unsafe { to_dyn_fn(NonNull::new_unchecked(atomic_memo.into_inner())) };
) // SAFETY: The caller guarantees that there are no outstanding borrows into the `Box` contents.
let memo = unsafe { Box::from_raw(memo.as_ptr()) };
(MemoIngredientIndex::from_usize(index), memo)
}, },
) )
} }
@ -250,11 +255,14 @@ impl Drop for MemoEntry {
if let Some(MemoEntryData { if let Some(MemoEntryData {
type_id: _, type_id: _,
to_dyn_fn, to_dyn_fn,
arc_swap, atomic_memo,
}) = self.data.take() }) = self.data.take()
{ {
let arc = arc_swap.into_inner(); // SAFETY: The `atomic_memo` field is never null.
std::mem::drop(to_dyn_fn(arc)); let memo = unsafe { to_dyn_fn(NonNull::new_unchecked(atomic_memo.into_inner())) };
// SAFETY: We have `&mut self`, so there are no outstanding borrows into the `Box` contents.
let memo = unsafe { Box::from_raw(memo.as_ptr()) };
std::mem::drop(memo);
} }
} }
} }

View file

@ -1,4 +1,4 @@
use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, mem::ManuallyDrop, ops::DerefMut}; use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut};
use crossbeam_queue::SegQueue; use crossbeam_queue::SegQueue;
use tracked_field::FieldIngredientImpl; use tracked_field::FieldIngredientImpl;
@ -617,10 +617,10 @@ where
// Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None` // Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None`
// and the code that references the memo-table has a read-lock. // and the code that references the memo-table has a read-lock.
let memo_table = unsafe { (*data).take_memo_table() }; let memo_table = unsafe { (*data).take_memo_table() };
// SAFETY: We have verified that no more references to these memos exist and so we are good // SAFETY: We have verified that no more references to these memos exist and so we are good
// to drop them. // to drop them.
for (memo_ingredient_index, memo) in unsafe { memo_table.into_memos() } { for (memo_ingredient_index, memo) in unsafe { memo_table.into_memos() } {
let memo = ManuallyDrop::into_inner(memo);
let ingredient_index = let ingredient_index =
zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index); zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index);