From 2a751d559bafc746e5723d9ee35dcb75b04d3de9 Mon Sep 17 00:00:00 2001
From: Ibraheem Ahmed
Date: Wed, 19 Feb 2025 00:32:20 -0500
Subject: [PATCH] replace `arc-swap` with manual `AtomicPtr`

---
 Cargo.toml                          |   1 -
 src/function.rs                     |  17 ++-
 src/function/delete.rs              |  37 ++++++-
 src/function/execute.rs             |   4 +-
 src/function/maybe_changed_after.rs |   3 +-
 src/function/memo.rs                |  79 +++++--------
 src/function/specify.rs             |   4 +-
 src/table/memo.rs                   | 166 +++++++++++++++-------------
 src/tracked_struct.rs               |   4 +-
 9 files changed, 166 insertions(+), 149 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index d6550b17..0e505f61 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,6 @@ rust-version = "1.80"
 salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
 salsa-macros = { version = "0.18.0", path = "components/salsa-macros" }
 
-arc-swap = "1"
 boxcar = "0.2.9"
 crossbeam-queue = "0.3.11"
 dashmap = { version = "6", features = ["raw-api"] }
diff --git a/src/function.rs b/src/function.rs
index 3c1f31ce..ba25d96f 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -1,4 +1,4 @@
-use std::{any::Any, fmt, mem::ManuallyDrop, sync::Arc};
+use std::{any::Any, fmt, ptr::NonNull};
 
 use crate::{
     accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
@@ -176,10 +176,14 @@ where
         memo: memo::Memo<C::Output<'db>>,
         memo_ingredient_index: MemoIngredientIndex,
     ) -> &'db memo::Memo<C::Output<'db>> {
-        let memo = Arc::new(memo);
+        // We convert to a `NonNull` here as soon as possible because we are going to alias
+        // into the `Box`, which is a `noalias` type.
+        let memo = unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(memo))) };
+
         // Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this
         // value is returned) and anything removed from map is added to deleted entries (ensured elsewhere).
-        let db_memo = unsafe { self.extend_memo_lifetime(&memo) };
+        let db_memo = unsafe { self.extend_memo_lifetime(memo.as_ref()) };
+
         // Safety: We delay the drop of `old_value` until a new revision starts which ensures no
         // references will exist for the memo contents.
         if let Some(old_value) =
@@ -187,8 +191,10 @@ where
         {
             // In case there is a reference to the old memo out there, we have to store it
             // in the deleted entries. This will get cleared when a new revision starts.
-            self.deleted_entries
-                .push(ManuallyDrop::into_inner(old_value));
+            //
+            // SAFETY: Once the revision starts, there will be no outstanding borrows to the
+            // memo contents, and so it will be safe to free.
+            unsafe { self.deleted_entries.push(old_value) };
         }
         db_memo
     }
@@ -254,7 +260,6 @@ where
             let ingredient_index = table.ingredient_index(evict);
             Self::evict_value_from_memo_for(
                 table.memos_mut(evict),
-                &self.deleted_entries,
                 self.memo_ingredient_indices.get(ingredient_index),
             )
         });
diff --git a/src/function/delete.rs b/src/function/delete.rs
index 01507f2e..ed2f0322 100644
--- a/src/function/delete.rs
+++ b/src/function/delete.rs
@@ -1,14 +1,20 @@
+use std::ptr::NonNull;
+
 use crossbeam_queue::SegQueue;
 
-use super::{memo::ArcMemo, Configuration};
+use super::memo::Memo;
+use super::Configuration;
 
 /// Stores the list of memos that have been deleted so they can be freed
 /// once the next revision starts. See the comment on the field
 /// `deleted_entries` of [`FunctionIngredient`][] for more details.
 pub(super) struct DeletedEntries<C: Configuration> {
-    seg_queue: SegQueue<ArcMemo<'static, C>>,
+    seg_queue: SegQueue<SharedBox<Memo<C::Output<'static>>>>,
 }
 
+unsafe impl<C: Configuration> Send for DeletedEntries<C> {}
+unsafe impl<C: Configuration> Sync for DeletedEntries<C> {}
+
 impl<C: Configuration> Default for DeletedEntries<C> {
     fn default() -> Self {
         Self {
@@ -18,8 +24,29 @@ impl<C: Configuration> DeletedEntries<C> {
-    pub(super) fn push(&self, memo: ArcMemo<'_, C>) {
-        let memo = unsafe { std::mem::transmute::<ArcMemo<'_, C>, ArcMemo<'static, C>>(memo) };
-        self.seg_queue.push(memo);
+    /// # Safety
+    ///
+    /// The memo must be valid and safe to free when the `DeletedEntries` list is dropped.
+    pub(super) unsafe fn push(&self, memo: NonNull<Memo<C::Output<'_>>>) {
+        let memo = unsafe {
+            std::mem::transmute::<NonNull<Memo<C::Output<'_>>>, NonNull<Memo<C::Output<'static>>>>(
+                memo,
+            )
+        };
+
+        self.seg_queue.push(SharedBox(memo));
+    }
+}
+
+/// A wrapper around `NonNull` that frees the allocation when it is dropped.
+///
+/// `crossbeam::SegQueue` does not expose mutable accessors so we have to create
+/// a wrapper to run code during `Drop`.
+struct SharedBox<T>(NonNull<T>);
+
+impl<T> Drop for SharedBox<T> {
+    fn drop(&mut self) {
+        // SAFETY: Guaranteed by the caller of `DeletedEntries::push`.
+        unsafe { drop(Box::from_raw(self.0.as_ptr())) };
     }
 }
diff --git a/src/function/execute.rs b/src/function/execute.rs
index 843047fa..7c5abdb1 100644
--- a/src/function/execute.rs
+++ b/src/function/execute.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
 use crate::{
     zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind,
 };
@@ -23,7 +21,7 @@ where
         &'db self,
         db: &'db C::DbView,
         active_query: ActiveQueryGuard<'_>,
-        opt_old_memo: Option<Arc<Memo<C::Output<'_>>>>,
+        opt_old_memo: Option<&Memo<C::Output<'_>>>,
     ) -> &'db Memo<C::Output<'db>> {
         let zalsa = db.zalsa();
         let revision_now = zalsa.current_revision();
diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs
index 0134ffbe..901c5a50 100644
--- a/src/function/maybe_changed_after.rs
+++ b/src/function/maybe_changed_after.rs
@@ -38,7 +38,6 @@ where
                     MaybeChangedAfter::No(memo.revisions.accumulated_inputs.load())
                 };
             }
-            drop(memo_guard); // release the arc-swap guard before cold path
             if let Some(mcs) =
                 self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index)
             {
@@ -86,7 +85,7 @@ where
         );
 
         // Check if the inputs are still valid. We can just compare `changed_at`.
-        if self.deep_verify_memo(db, zalsa, &old_memo, &active_query) {
+        if self.deep_verify_memo(db, zalsa, old_memo, &active_query) {
             return Some(if old_memo.revisions.changed_at > revision {
                 MaybeChangedAfter::Yes
             } else {
diff --git a/src/function/memo.rs b/src/function/memo.rs
index c6c7d2db..fea95445 100644
--- a/src/function/memo.rs
+++ b/src/function/memo.rs
@@ -1,11 +1,9 @@
 use std::any::Any;
 use std::fmt::Debug;
 use std::fmt::Formatter;
-use std::mem::ManuallyDrop;
-use std::sync::Arc;
+use std::ptr::NonNull;
 
 use crate::accumulator::accumulated_map::InputAccumulatedValues;
-use crate::function::DeletedEntries;
 use crate::revision::AtomicRevision;
 use crate::table::memo::MemoTable;
 use crate::zalsa::MemoIngredientIndex;
@@ -17,21 +15,33 @@ use crate::{
 
 use super::{Configuration, IngredientImpl};
 
-#[allow(type_alias_bounds)]
-pub(super) type ArcMemo<'lt, C: Configuration> = Arc<Memo<<C as Configuration>::Output<'lt>>>;
-
 impl<C: Configuration> IngredientImpl<C> {
     /// Memos have to be stored internally using `'static` as the database lifetime.
     /// This (unsafe) function call converts from something tied to self to static.
     /// Values transmuted this way have to be transmuted back to being tied to self
     /// when they are returned to the user.
-    unsafe fn to_static<'db>(&'db self, memo: ArcMemo<'db, C>) -> ArcMemo<'static, C> {
-        unsafe { std::mem::transmute(memo) }
+    unsafe fn to_static<'db>(
+        &'db self,
+        memo: NonNull<Memo<C::Output<'db>>>,
+    ) -> NonNull<Memo<C::Output<'static>>> {
+        memo.cast()
+    }
+
+    /// Convert from an internal memo (which uses `'static`) to one tied to self
+    /// so it can be publicly released.
+    unsafe fn to_self<'db>(
+        &'db self,
+        memo: NonNull<Memo<C::Output<'static>>>,
+    ) -> NonNull<Memo<C::Output<'db>>> {
+        memo.cast()
     }
 
     /// Convert from an internal memo (which uses `'static`) to one tied to self
     /// so it can be publicly released.
-    unsafe fn to_self<'db>(&'db self, memo: ArcMemo<'static, C>) -> ArcMemo<'db, C> {
+    unsafe fn to_self_ref<'db>(
+        &'db self,
+        memo: &'db Memo<C::Output<'static>>,
+    ) -> &'db Memo<C::Output<'db>> {
         unsafe { std::mem::transmute(memo) }
     }
 
@@ -45,17 +55,16 @@ impl<C: Configuration> IngredientImpl<C> {
         &'db self,
         zalsa: &'db Zalsa,
         id: Id,
-        memo: ArcMemo<'db, C>,
+        memo: NonNull<Memo<C::Output<'db>>>,
         memo_ingredient_index: MemoIngredientIndex,
-    ) -> Option<ManuallyDrop<ArcMemo<'db, C>>> {
+    ) -> Option<NonNull<Memo<C::Output<'db>>>> {
         let static_memo = unsafe { self.to_static(memo) };
         let old_static_memo = unsafe {
             zalsa
                 .memo_table_for(id)
                 .insert(memo_ingredient_index, static_memo)
         }?;
-        let old_static_memo = ManuallyDrop::into_inner(old_static_memo);
-        Some(ManuallyDrop::new(unsafe { self.to_self(old_static_memo) }))
+        Some(unsafe { self.to_self(old_static_memo) })
     }
 
     /// Loads the current memo for `key_index`. This does not hold any sort of
@@ -66,9 +75,10 @@ impl<C: Configuration> IngredientImpl<C> {
         zalsa: &'db Zalsa,
         id: Id,
         memo_ingredient_index: MemoIngredientIndex,
-    ) -> Option<ArcMemo<'db, C>> {
+    ) -> Option<&'db Memo<C::Output<'db>>> {
         let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?;
-        unsafe { Some(self.to_self(static_memo)) }
+
+        unsafe { Some(self.to_self_ref(static_memo)) }
     }
 
     /// Evicts the existing memo for the given key, replacing it
@@ -76,10 +86,9 @@ impl<C: Configuration> IngredientImpl<C> {
     /// or has values assigned as output of another query, this has no effect.
     pub(super) fn evict_value_from_memo_for(
         table: &mut MemoTable,
-        deleted_entries: &DeletedEntries<C>,
         memo_ingredient_index: MemoIngredientIndex,
     ) {
-        let map = |memo: ArcMemo<'static, C>| -> ArcMemo<'static, C> {
+        let map = |memo: &mut Memo<C::Output<'static>>| {
             match &memo.revisions.origin {
                 QueryOrigin::Assigned(_)
                 | QueryOrigin::DerivedUntracked(_)
@@ -88,43 +97,15 @@ impl<C: Configuration> IngredientImpl<C> {
                     // assigned as output of another query
                    // or those with untracked inputs
                    // as their values cannot be reconstructed.
-                    memo
                }
                QueryOrigin::Derived(_) => {
-                    // Note that we cannot use `Arc::get_mut` here as the use of `ArcSwap` makes it
-                    // impossible to get unique access to the interior Arc
-                    // QueryRevisions: !Clone to discourage cloning, we need it here though
-                    let &QueryRevisions {
-                        changed_at,
-                        durability,
-                        ref origin,
-                        ref tracked_struct_ids,
-                        ref accumulated,
-                        ref accumulated_inputs,
-                    } = &memo.revisions;
-                    // Re-assemble the memo but with the value set to `None`
-                    Arc::new(Memo::new(
-                        None,
-                        memo.verified_at.load(),
-                        QueryRevisions {
-                            changed_at,
-                            durability,
-                            origin: origin.clone(),
-                            tracked_struct_ids: tracked_struct_ids.clone(),
-                            accumulated: accumulated.clone(),
-                            accumulated_inputs: accumulated_inputs.clone(),
-                        },
-                    ))
+                    // Set the memo value to `None`.
+                    memo.value = None;
                }
            }
        };
-        // SAFETY: We queue the old value for deletion, delaying its drop until the next revision bump.
-        let old = unsafe { table.map_memo(memo_ingredient_index, map) };
-        if let Some(old) = old {
-            // In case there is a reference to the old memo out there, we have to store it
-            // in the deleted entries. This will get cleared when a new revision starts.
-            deleted_entries.push(ManuallyDrop::into_inner(old));
-        }
+
+        table.map_memo(memo_ingredient_index, map)
     }
 }
diff --git a/src/function/specify.rs b/src/function/specify.rs
index 98f915d0..4338e915 100644
--- a/src/function/specify.rs
+++ b/src/function/specify.rs
@@ -75,8 +75,8 @@ where
         let memo_ingredient_index = self.memo_ingredient_index(zalsa, key);
 
         if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) {
-            self.backdate_if_appropriate(&old_memo, &mut revisions, &value);
-            self.diff_outputs(zalsa, db, database_key_index, &old_memo, &mut revisions);
+            self.backdate_if_appropriate(old_memo, &mut revisions, &value);
+            self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions);
         }
 
         let memo = Memo {
diff --git a/src/table/memo.rs b/src/table/memo.rs
index c14e1337..70eef215 100644
--- a/src/table/memo.rs
+++ b/src/table/memo.rs
@@ -1,11 +1,10 @@
 use std::{
     any::{Any, TypeId},
     fmt::Debug,
-    mem::ManuallyDrop,
-    sync::Arc,
+    ptr::NonNull,
+    sync::atomic::{AtomicPtr, Ordering},
 };
 
-use arc_swap::ArcSwap;
 use parking_lot::RwLock;
 
 use crate::{zalsa::MemoIngredientIndex, zalsa_local::QueryOrigin};
@@ -32,7 +31,7 @@ struct MemoEntry {
 }
 
 /// Data for a memoized entry.
-/// This is a type-erased `Arc<M>`, where `M` is the type of memo associated
+/// This is a type-erased `Box<M>`, where `M` is the type of memo associated
 /// with that particular ingredient index.
 ///
 /// # Implementation note
 ///
 /// Therefore, once a given entry goes from `Empty` to `Full`,
 /// the type-id associated with that entry should never change.
 ///
-/// We take advantage of this and use an `ArcSwap` to store the actual memo.
+/// We take advantage of this and use an `AtomicPtr` to store the actual memo.
 /// This allows us to store into the memo-entry without acquiring a write-lock.
-/// However, using `ArcSwap` means we cannot use a `Arc<dyn Memo>` or any other wide pointer.
+/// However, using `AtomicPtr` means we cannot use a `Box<dyn Memo>` or any other wide pointer.
 /// Therefore, we hide the type by transmuting to `DummyMemo`; but we must then be very careful
 /// when freeing `MemoEntryData` values to transmute things back. See the `Drop` impl for
 /// [`MemoEntry`][] for details.
 struct MemoEntryData {
     /// The `type_id` of the erased memo type `M`
     type_id: TypeId,
 
-    /// A pointer to `std::mem::drop::<Arc<M>>` for the erased memo type `M`
-    to_dyn_fn: fn(Arc<DummyMemo>) -> Arc<dyn Memo>,
+    /// A type-coercion function for the erased memo type `M`
+    to_dyn_fn: fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
 
-    /// An [`ArcSwap`][] to a `Arc<M>` for the erased memo type `M`
-    arc_swap: ArcSwap<DummyMemo>,
+    /// An [`AtomicPtr`][] to a `Box<M>` for the erased memo type `M`
+    atomic_memo: AtomicPtr<DummyMemo>,
 }
 
 /// Dummy placeholder type that we use when erasing the memo type `M` in [`MemoEntryData`][].
 struct DummyMemo {}
 
 impl MemoTable {
-    fn to_dummy<M: Memo>(memo: Arc<M>) -> Arc<DummyMemo> {
-        unsafe { std::mem::transmute::<Arc<M>, Arc<DummyMemo>>(memo) }
+    fn to_dummy<M: Memo>(memo: NonNull<M>) -> NonNull<DummyMemo> {
+        memo.cast()
     }
 
-    unsafe fn from_dummy<M: Memo>(memo: Arc<DummyMemo>) -> Arc<M> {
-        unsafe { std::mem::transmute::<Arc<DummyMemo>, Arc<M>>(memo) }
+    unsafe fn from_dummy<M: Memo>(memo: NonNull<DummyMemo>) -> NonNull<M> {
+        memo.cast()
     }
 
-    fn to_dyn_fn<M: Memo>() -> fn(Arc<DummyMemo>) -> Arc<dyn Memo> {
-        let f: fn(Arc<M>) -> Arc<dyn Memo> = |x| x;
+    fn to_dyn_fn<M: Memo>() -> fn(NonNull<DummyMemo>) -> NonNull<dyn Memo> {
+        let f: fn(NonNull<M>) -> NonNull<dyn Memo> = |x| x;
+
         unsafe {
-            std::mem::transmute::<fn(Arc<M>) -> Arc<dyn Memo>, fn(Arc<DummyMemo>) -> Arc<dyn Memo>>(
-                f,
-            )
+            std::mem::transmute::<
+                fn(NonNull<M>) -> NonNull<dyn Memo>,
+                fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
+            >(f)
         }
     }
 
     /// # Safety
     ///
-    /// The caller needs to make sure to not drop the returned value until no more references into
-    /// the database exist as there may be outstanding borrows into the `Arc` contents.
+    /// The caller needs to make sure to not free the returned value until no more references into
+    /// the database exist as there may be outstanding borrows into the pointer contents.
     pub(crate) unsafe fn insert<M: Memo>(
         &self,
         memo_ingredient_index: MemoIngredientIndex,
-        memo: Arc<M>,
-    ) -> Option<ManuallyDrop<Arc<M>>> {
+        memo: NonNull<M>,
+    ) -> Option<NonNull<M>> {
         // If the memo slot is already occupied, it must already have the
         // right type info etc, and we only need the read-lock.
         if let Some(MemoEntry {
             data:
                 Some(MemoEntryData {
                     type_id,
                     to_dyn_fn: _,
-                    arc_swap,
+                    atomic_memo,
                 }),
         }) = self.memos.read().get(memo_ingredient_index.as_usize())
         {
             assert_eq!(
                 *type_id,
                 TypeId::of::<M>(),
                 "inconsistent type-id for `{memo_ingredient_index:?}`"
             );
-            let old_memo = arc_swap.swap(Self::to_dummy(memo));
-            return Some(ManuallyDrop::new(unsafe { Self::from_dummy(old_memo) }));
+
+            let old_memo = atomic_memo.swap(Self::to_dummy(memo).as_ptr(), Ordering::AcqRel);
+
+            // SAFETY: The `atomic_memo` field is never null.
+            let old_memo = unsafe { NonNull::new_unchecked(old_memo) };
+
+            // SAFETY: `type_id` check asserted above
+            return Some(unsafe { Self::from_dummy(old_memo) });
         }
 
         // Otherwise we need the write lock.
@@ -116,13 +123,13 @@ impl MemoTable {
 
     /// # Safety
     ///
-    /// The caller needs to make sure to not drop the returned value until no more references into
-    /// the database exist as there may be outstanding borrows into the `Arc` contents.
+    /// The caller needs to make sure to not free the returned value until no more references into
+    /// the database exist as there may be outstanding borrows into the pointer contents.
     unsafe fn insert_cold<M: Memo>(
         &self,
         memo_ingredient_index: MemoIngredientIndex,
-        memo: Arc<M>,
-    ) -> Option<ManuallyDrop<Arc<M>>> {
+        memo: NonNull<M>,
+    ) -> Option<NonNull<M>> {
         let mut memos = self.memos.write();
         let memo_ingredient_index = memo_ingredient_index.as_usize();
         if memos.len() < memo_ingredient_index + 1 {
@@ -133,24 +140,22 @@ impl MemoTable {
             Some(MemoEntryData {
                 type_id: TypeId::of::<M>(),
                 to_dyn_fn: Self::to_dyn_fn::<M>(),
-                arc_swap: ArcSwap::new(Self::to_dummy(memo)),
+                atomic_memo: AtomicPtr::new(Self::to_dummy(memo).as_ptr()),
             }),
         );
-        old_entry
-            .map(
-                |MemoEntryData {
-                     type_id: _,
-                     to_dyn_fn: _,
-                     arc_swap,
-                 }| unsafe { Self::from_dummy(arc_swap.into_inner()) },
-            )
-            .map(ManuallyDrop::new)
+        old_entry.map(
+            |MemoEntryData {
+                 type_id: _,
+                 to_dyn_fn: _,
+                 atomic_memo,
+             }| unsafe {
+                // SAFETY: The `atomic_memo` field is never null.
+                Self::from_dummy(NonNull::new_unchecked(atomic_memo.into_inner()))
+            },
+        )
     }
 
-    pub(crate) fn get<M: Memo>(
-        &self,
-        memo_ingredient_index: MemoIngredientIndex,
-    ) -> Option<Arc<M>> {
+    pub(crate) fn get<M: Memo>(&self, memo_ingredient_index: MemoIngredientIndex) -> Option<&M> {
         let memos = self.memos.read();
 
         let Some(MemoEntry {
             data:
                 Some(MemoEntryData {
                     type_id,
                     to_dyn_fn: _,
-                    arc_swap,
+                    atomic_memo,
                 }),
         }) = memos.get(memo_ingredient_index.as_usize())
         else {
@@ -171,57 +176,54 @@ impl MemoTable {
             "inconsistent type-id for `{memo_ingredient_index:?}`"
         );
 
-        // SAFETY: type_id check asserted above
-        unsafe { Some(Self::from_dummy(arc_swap.load_full())) }
+        // SAFETY: The `atomic_memo` field is never null.
+        let memo = unsafe { NonNull::new_unchecked(atomic_memo.load(Ordering::Acquire)) };
+
+        // SAFETY: `type_id` check asserted above
+        unsafe { Some(Self::from_dummy(memo).as_ref()) }
     }
 
-    /// Calls `f` on the memo at `memo_ingredient_index` and replaces the memo with the result of `f`.
+    /// Calls `f` on the memo at `memo_ingredient_index`.
+    ///
     /// If the memo is not present, `f` is not called.
-    ///
-    /// # Safety
-    ///
-    /// The caller needs to make sure to not drop the returned value until no more references into
-    /// the database exist as there may be outstanding borrows into the `Arc` contents.
-    pub(crate) unsafe fn map_memo<M: Memo>(
+    pub(crate) fn map_memo<M: Memo>(
         &mut self,
         memo_ingredient_index: MemoIngredientIndex,
-        f: impl FnOnce(Arc<M>) -> Arc<M>,
-    ) -> Option<ManuallyDrop<Arc<M>>> {
+        f: impl FnOnce(&mut M),
+    ) {
         let memos = self.memos.get_mut();
 
         let Some(MemoEntry {
             data:
                 Some(MemoEntryData {
                     type_id,
                     to_dyn_fn: _,
-                    arc_swap,
+                    atomic_memo,
                 }),
         }) = memos.get_mut(memo_ingredient_index.as_usize())
        else {
-            return None;
+            return;
        };
+
        assert_eq!(
            *type_id,
            TypeId::of::<M>(),
            "inconsistent type-id for `{memo_ingredient_index:?}`"
        );
-        // arc-swap does not expose accessing the interior mutably at all unfortunately
-        // https://github.com/vorner/arc-swap/issues/131
-        // so we are required to allocate a new arc within `f` instead of being able
-        // to swap out the interior
-        // SAFETY: type_id check asserted above
-        let memo = f(unsafe { Self::from_dummy(arc_swap.load_full()) });
-        Some(ManuallyDrop::new(unsafe {
-            Self::from_dummy::<M>(arc_swap.swap(Self::to_dummy(memo)))
-        }))
+
+        // SAFETY: The `atomic_memo` field is never null.
+        let memo = unsafe { NonNull::new_unchecked(*atomic_memo.get_mut()) };
+
+        // SAFETY: `type_id` check asserted above
+        f(unsafe { Self::from_dummy(memo).as_mut() });
     }
 
     /// # Safety
     ///
-    /// The caller needs to make sure to not drop the returned value until no more references into
-    /// the database exist as there may be outstanding borrows into the `Arc` contents.
+    /// The caller needs to make sure to not call this function until no more references into
+    /// the database exist as there may be outstanding borrows into the pointer contents.
     pub(crate) unsafe fn into_memos(
         self,
-    ) -> impl Iterator<Item = (MemoIngredientIndex, ManuallyDrop<Arc<dyn Memo>>)> {
+    ) -> impl Iterator<Item = (MemoIngredientIndex, Box<dyn Memo>)> {
         self.memos
             .into_inner()
             .into_iter()
@@ -232,14 +234,17 @@ impl MemoTable {
                     MemoEntryData {
                         type_id: _,
                         to_dyn_fn,
-                        arc_swap,
+                        atomic_memo,
                    },
                    index,
                )| {
-                    (
-                        MemoIngredientIndex::from_usize(index),
-                        ManuallyDrop::new(to_dyn_fn(arc_swap.into_inner())),
-                    )
+                    // SAFETY: The `atomic_memo` field is never null.
+                    let memo =
+                        unsafe { to_dyn_fn(NonNull::new_unchecked(atomic_memo.into_inner())) };
+                    // SAFETY: The caller guarantees that there are no outstanding borrows into the `Box` contents.
+                    let memo = unsafe { Box::from_raw(memo.as_ptr()) };
+
+                    (MemoIngredientIndex::from_usize(index), memo)
                },
            )
    }
 }
@@ -250,11 +255,14 @@ impl Drop for MemoEntry {
     fn drop(&mut self) {
         if let Some(MemoEntryData {
             type_id: _,
             to_dyn_fn,
-            arc_swap,
+            atomic_memo,
         }) = self.data.take() {
-            let arc = arc_swap.into_inner();
-            std::mem::drop(to_dyn_fn(arc));
+            // SAFETY: The `atomic_memo` field is never null.
+            let memo = unsafe { to_dyn_fn(NonNull::new_unchecked(atomic_memo.into_inner())) };
+            // SAFETY: We have `&mut self`, so there are no outstanding borrows into the `Box` contents.
+            let memo = unsafe { Box::from_raw(memo.as_ptr()) };
+            std::mem::drop(memo);
         }
     }
 }
diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs
index 082bcc47..ec5ae11a 100644
--- a/src/tracked_struct.rs
+++ b/src/tracked_struct.rs
@@ -1,4 +1,4 @@
-use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, mem::ManuallyDrop, ops::DerefMut};
+use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut};
 
 use crossbeam_queue::SegQueue;
 use tracked_field::FieldIngredientImpl;
@@ -617,10 +617,10 @@ where
         // Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None`
         // and the code that references the memo-table has a read-lock.
         let memo_table = unsafe { (*data).take_memo_table() };
+
         // SAFETY: We have verified that no more references to these memos exist and so we are good
         // to drop them.
         for (memo_ingredient_index, memo) in unsafe { memo_table.into_memos() } {
-            let memo = ManuallyDrop::into_inner(memo);
             let ingredient_index =
                 zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index);