refactor: Clean up some unsafety (#830)

Lukas Wirth, 2025-05-02 14:16:44 +02:00, committed by GitHub
parent fa8409212d
commit 2c041763b7
7 changed files with 123 additions and 174 deletions


@@ -282,15 +282,12 @@ macro_rules! setup_tracked_fn {
 )
 };
-// SAFETY: We pass the MemoEntryType for this Configuration, and we lookup the memo types table correctly.
-let fn_ingredient = unsafe {
-<$zalsa::function::IngredientImpl<$Configuration>>::new(
-first_index,
-memo_ingredient_indices,
-$lru,
-zalsa.views().downcaster_for::<dyn $Db>(),
-)
-};
+let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new(
+first_index,
+memo_ingredient_indices,
+$lru,
+zalsa.views().downcaster_for::<dyn $Db>(),
+);
 $zalsa::macro_if! {
 if $needs_interner {
 vec![


@@ -143,10 +143,7 @@ impl<C> IngredientImpl<C>
 where
 C: Configuration,
 {
-/// # Safety
-///
-/// `memo_type` and `memo_table_types` must be correct.
-pub unsafe fn new(
+pub fn new(
 index: IngredientIndex,
 memo_ingredient_indices: <C::SalsaStruct<'static> as SalsaStructInDb>::MemoIngredientMap,
 lru: usize,
@@ -195,19 +192,11 @@
 ) -> &'db memo::Memo<C::Output<'db>> {
 // We convert to a `NonNull` here as soon as possible because we are going to alias
 // into the `Box`, which is a `noalias` type.
-// SAFETY: memo is not null
-let memo = unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(memo))) };
-// SAFETY: memo must be in the map (it's not yet, but it will be by the time this
-// value is returned) and anything removed from map is added to deleted entries (ensured elsewhere).
-let db_memo = unsafe { self.extend_memo_lifetime(memo.as_ref()) };
+// FIXME: Use `Box::into_non_null` once stable
+let memo = NonNull::from(Box::leak(Box::new(memo)));
 if let Some(old_value) =
-// SAFETY: We delay the drop of `old_value` until a new revision starts which ensures no
-// references will exist for the memo contents.
-unsafe {
-self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index)
-}
+self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index)
 {
 // In case there is a reference to the old memo out there, we have to store it
 // in the deleted entries. This will get cleared when a new revision starts.
@@ -216,7 +205,8 @@ where
 // memo contents, and so it will be safe to free.
 unsafe { self.deleted_entries.push(old_value) };
 }
-db_memo
+// SAFETY: memo has been inserted into the table
+unsafe { self.extend_memo_lifetime(memo.as_ref()) }
 }
 #[inline]
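The `Box::into_raw` + `NonNull::new_unchecked` pair above is replaced by `NonNull::from(Box::leak(..))`, which encodes the same "fresh, non-null heap pointer" fact without an unsafe block; the FIXME notes that `Box::into_non_null` would express this directly once it stabilizes. A minimal stand-alone sketch of that pattern (illustrative names, not salsa code):

```rust
use std::ptr::NonNull;

// Hypothetical helper standing in for the still-unstable `Box::into_non_null`:
// leak the box to get a guaranteed non-null pointer without `unsafe`.
fn into_non_null<T>(value: T) -> NonNull<T> {
    // `Box::leak` returns `&'static mut T`, which can never be null, so
    // `NonNull::from` needs no safety argument. Ownership of the allocation
    // now rests with whoever holds the raw pointer.
    NonNull::from(Box::leak(Box::new(value)))
}

fn main() {
    let ptr = into_non_null(42u32);
    // SAFETY: `ptr` came from `Box::new` above and has not been freed or aliased.
    let boxed = unsafe { Box::from_raw(ptr.as_ptr()) };
    assert_eq!(*boxed, 42);
}
```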


@@ -1,7 +1,6 @@
-#![allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety
 use std::any::Any;
 use std::fmt::{Debug, Formatter};
+use std::mem::transmute;
 use std::ptr::NonNull;
 use std::sync::atomic::Ordering;
@@ -15,69 +14,46 @@ use crate::zalsa_local::{QueryOrigin, QueryRevisions};
 use crate::{Event, EventKind, Id, Revision};
 impl<C: Configuration> IngredientImpl<C> {
-/// Memos have to be stored internally using `'static` as the database lifetime.
-/// This (unsafe) function call converts from something tied to self to static.
-/// Values transmuted this way have to be transmuted back to being tied to self
-/// when they are returned to the user.
-unsafe fn to_static<'db>(
-&'db self,
-memo: NonNull<Memo<C::Output<'db>>>,
-) -> NonNull<Memo<C::Output<'static>>> {
-memo.cast()
-}
-/// Convert from an internal memo (which uses `'static`) to one tied to self
-/// so it can be publicly released.
-unsafe fn to_self<'db>(
-&'db self,
-memo: NonNull<Memo<C::Output<'static>>>,
-) -> NonNull<Memo<C::Output<'db>>> {
-memo.cast()
-}
-/// Convert from an internal memo (which uses `'static`) to one tied to self
-/// so it can be publicly released.
-unsafe fn to_self_ref<'db>(
-&'db self,
-memo: &'db Memo<C::Output<'static>>,
-) -> &'db Memo<C::Output<'db>> {
-unsafe { std::mem::transmute(memo) }
-}
 /// Inserts the memo for the given key; (atomically) overwrites and returns any previously existing memo
-///
-/// # Safety
-///
-/// The caller needs to make sure to not drop the returned value until no more references into
-/// the database exist as there may be outstanding borrows into the `Arc` contents.
-pub(super) unsafe fn insert_memo_into_table_for<'db>(
-&'db self,
+pub(super) fn insert_memo_into_table_for<'db>(
+&self,
 zalsa: &'db Zalsa,
 id: Id,
 memo: NonNull<Memo<C::Output<'db>>>,
 memo_ingredient_index: MemoIngredientIndex,
 ) -> Option<NonNull<Memo<C::Output<'db>>>> {
-let static_memo = unsafe { self.to_static(memo) };
-let old_static_memo = unsafe {
-zalsa
-.memo_table_for(id)
-.insert(memo_ingredient_index, static_memo)
-}?;
-Some(unsafe { self.to_self(old_static_memo) })
+// SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid
+// for `'db` though as we delay their dropping to the end of a revision.
+let static_memo = unsafe {
+transmute::<NonNull<Memo<C::Output<'db>>>, NonNull<Memo<C::Output<'static>>>>(memo)
+};
+let old_static_memo = zalsa
+.memo_table_for(id)
+.insert(memo_ingredient_index, static_memo)?;
+// SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid
+// for `'db` though as we delay their dropping to the end of a revision.
+Some(unsafe {
+transmute::<NonNull<Memo<C::Output<'static>>>, NonNull<Memo<C::Output<'db>>>>(
+old_static_memo,
+)
+})
 }
 /// Loads the current memo for `key_index`. This does not hold any sort of
 /// lock on the `memo_map` once it returns, so this memo could immediately
 /// become outdated if other threads store into the `memo_map`.
 pub(super) fn get_memo_from_table_for<'db>(
-&'db self,
+&self,
 zalsa: &'db Zalsa,
 id: Id,
 memo_ingredient_index: MemoIngredientIndex,
 ) -> Option<&'db Memo<C::Output<'db>>> {
 let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?;
-unsafe { Some(self.to_self_ref(static_memo)) }
+// SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid
+// for `'db` though as we delay their dropping to the end of a revision.
+Some(unsafe {
+transmute::<&Memo<C::Output<'static>>, &'db Memo<C::Output<'db>>>(static_memo.as_ref())
+})
 }
 /// Evicts the existing memo for the given key, replacing it
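The `to_static`/`to_self`/`to_self_ref` helpers are gone; this hunk inlines the lifetime-erasing casts as `transmute` calls with fully spelled-out source and target types, each carrying a SAFETY comment explaining why the `'static` claim is sound (memos are only dropped at a revision boundary). The `'static` requirement itself comes from type erasure: `Any` and `TypeId` only exist for `'static` types. A small illustration of that constraint (toy code, not the salsa types):

```rust
use std::any::Any;

// Type-erasing storage only works for `T: 'static`, which is why the memo
// table stores `Memo<Output<'static>>` and shrinks the lifetime back to `'db`
// at the access sites shown above.
fn erase<T: Any>(value: T) -> Box<dyn Any> {
    Box::new(value)
}

fn main() {
    let erased = erase(5u32);
    assert_eq!(erased.downcast_ref::<u32>(), Some(&5));

    let local = String::from("db data");
    let _borrowed: &str = &local;
    // `erase(_borrowed)` would not compile: a reference borrowed from `local`
    // is not `'static`, so it has no `TypeId` and cannot go behind `dyn Any`.
}
```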


@@ -35,11 +35,6 @@ impl Id {
 /// # Safety
 ///
 /// The supplied value must be less than [`Self::MAX_U32`].
-///
-/// Additionally, creating an arbitrary `Id` can lead to unsoundness if such an ID ends up being used to index
-/// the internal allocation tables which end up being out of bounds. Care must be taken that the
-/// ID is either constructed with a valid value or that it never ends up being used as keys to
-/// salsa computations.
 #[doc(hidden)]
 #[track_caller]
 #[inline]


@@ -20,7 +20,7 @@ pub(crate) mod memo;
 const PAGE_LEN_BITS: usize = 10;
 const PAGE_LEN_MASK: usize = PAGE_LEN - 1;
 const PAGE_LEN: usize = 1 << PAGE_LEN_BITS;
-const MAX_PAGES: usize = 1 << (32 - PAGE_LEN_BITS);
+const MAX_PAGES: usize = 1 << (u32::BITS as usize - PAGE_LEN_BITS);
 /// A typed [`Page`] view.
 pub(crate) struct PageView<'p, T: Slot>(&'p Page, PhantomData<&'p T>);
@@ -50,7 +50,7 @@ type SlotMemosFn<T> = unsafe fn(&T, current_revision: Revision) -> &MemoTable;
 /// [Slot::memos_mut]
 type SlotMemosMutFnRaw = unsafe fn(*mut ()) -> *mut MemoTable;
 /// [Slot::memos_mut]
-type SlotMemosMutFn<T> = unsafe fn(&mut T) -> &mut MemoTable;
+type SlotMemosMutFn<T> = fn(&mut T) -> &mut MemoTable;
 struct SlotVTable {
 layout: Layout,
@@ -127,10 +127,10 @@ struct Page {
 // SAFETY: `Page` is `Send` as we make sure to only ever store `Slot` types in it which
 // requires `Send`.`
-unsafe impl Send for Page {}
+unsafe impl Send for Page /* where for<M: Memo> M: Send */ {}
 // SAFETY: `Page` is `Sync` as we make sure to only ever store `Slot` types in it which
 // requires `Sync`.`
-unsafe impl Sync for Page {}
+unsafe impl Sync for Page /* where for<M: Memo> M: Sync */ {}
 #[derive(Copy, Clone, Debug)]
 pub struct PageIndex(usize);
@@ -283,6 +283,7 @@ impl Table {
 }
 impl<'p, T: Slot> PageView<'p, T> {
+#[inline]
 fn page_data(&self) -> &[PageDataEntry<T>] {
 let len = self.0.allocated.load(Ordering::Acquire);
 // SAFETY: `len` is the initialized length of the page
@@ -390,7 +391,7 @@ impl Drop for Page {
 fn make_id(page: PageIndex, slot: SlotIndex) -> Id {
 let page = page.0 as u32;
 let slot = slot.0 as u32;
-// SAFETY: `page` and `slot` are derived from proper indices.
+// SAFETY: `slot` is guaranteed to be small enough that the resulting Id won't be bigger than `Id::MAX_U32`
 unsafe { Id::from_u32((page << PAGE_LEN_BITS) | slot) }
 }
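The reworded SAFETY comment on `make_id` rests on the arithmetic of the constants above: a slot index is always below `PAGE_LEN = 1 << 10`, so it occupies only the low `PAGE_LEN_BITS` bits and never collides with the shifted page index. A quick sketch of the packing and its inverse (constants copied from this diff; `pack`/`unpack` are illustrative helpers, not the crate's API):

```rust
const PAGE_LEN_BITS: usize = 10;
const PAGE_LEN: usize = 1 << PAGE_LEN_BITS; // 1024 slots per page
const PAGE_LEN_MASK: usize = PAGE_LEN - 1;
const MAX_PAGES: usize = 1 << (u32::BITS as usize - PAGE_LEN_BITS);

// Pack a (page, slot) pair into a single u32, mirroring the shape of `make_id`.
fn pack(page: u32, slot: u32) -> u32 {
    debug_assert!((page as usize) < MAX_PAGES && (slot as usize) < PAGE_LEN);
    (page << PAGE_LEN_BITS) | slot
}

// Recover the pair again; the low bits are the slot, the rest the page.
fn unpack(id: u32) -> (u32, u32) {
    (id >> PAGE_LEN_BITS, id & PAGE_LEN_MASK as u32)
}

fn main() {
    let id = pack(3, 17);
    assert_eq!(unpack(id), (3, 17));
}
```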


@@ -75,7 +75,8 @@ impl MemoEntryType {
 const fn to_dyn_fn<M: Memo>() -> fn(NonNull<DummyMemo>) -> NonNull<dyn Memo> {
 let f: fn(NonNull<M>) -> NonNull<dyn Memo> = |x| x;
-#[allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety
+// SAFETY: `M: Sized` and `DummyMemo: Sized`, as such they are ABI compatible behind a
+// `NonNull` making it safe to do type erasure.
 unsafe {
 mem::transmute::<
 fn(NonNull<M>) -> NonNull<dyn Memo>,
@@ -144,7 +145,7 @@ impl MemoTableTypes {
 ///
 /// The types table must be the correct one of `memos`.
 #[inline]
-pub(crate) unsafe fn attach_memos<'a>(
+pub(super) unsafe fn attach_memos<'a>(
 &'a self,
 memos: &'a MemoTable,
 ) -> MemoTableWithTypes<'a> {
@@ -168,12 +169,8 @@ pub(crate) struct MemoTableWithTypes<'a> {
 memos: &'a MemoTable,
 }
-impl<'a> MemoTableWithTypes<'a> {
-/// # Safety
-///
-/// The caller needs to make sure to not drop the returned value until no more references into
-/// the database exist as there may be outstanding borrows into the `Arc` contents.
-pub(crate) unsafe fn insert<M: Memo>(
+impl MemoTableWithTypes<'_> {
+pub(crate) fn insert<M: Memo>(
 self,
 memo_ingredient_index: MemoIngredientIndex,
 memo: NonNull<M>,
@@ -207,15 +204,11 @@ impl<'a> MemoTableWithTypes<'a> {
 }
 // Otherwise we need the write lock.
-// SAFETY: The caller is responsible for dropping
-unsafe { self.insert_cold(memo_ingredient_index, memo) }
+self.insert_cold(memo_ingredient_index, memo)
 }
-/// # Safety
-///
-/// The caller needs to make sure to not drop the returned value until no more references into
-/// the database exist as there may be outstanding borrows into the `Arc` contents.
-unsafe fn insert_cold<M: Memo>(
+#[cold]
+fn insert_cold<M: Memo>(
 self,
 memo_ingredient_index: MemoIngredientIndex,
 memo: NonNull<M>,
@@ -237,7 +230,10 @@ impl<'a> MemoTableWithTypes<'a> {
 }
 #[inline]
-pub(crate) fn get<M: Memo>(self, memo_ingredient_index: MemoIngredientIndex) -> Option<&'a M> {
+pub(crate) fn get<M: Memo>(
+self,
+memo_ingredient_index: MemoIngredientIndex,
+) -> Option<NonNull<M>> {
 let read = self.memos.memos.read();
 let memo = read.get(memo_ingredient_index.as_usize())?;
 let type_ = self
@@ -250,9 +246,9 @@ impl<'a> MemoTableWithTypes<'a> {
 TypeId::of::<M>(),
 "inconsistent type-id for `{memo_ingredient_index:?}`"
 );
-let memo = NonNull::new(memo.atomic_memo.load(Ordering::Acquire));
+let memo = NonNull::new(memo.atomic_memo.load(Ordering::Acquire))?;
 // SAFETY: `type_id` check asserted above
-memo.map(|memo| unsafe { MemoEntryType::from_dummy(memo).as_ref() })
+Some(unsafe { MemoEntryType::from_dummy(memo) })
 }
 }
@@ -300,12 +296,19 @@ impl MemoTableWithTypesMut<'_> {
 }
 /// To drop an entry, we need its type, so we don't implement `Drop`, and instead have this method.
+///
+/// Note that calling this multiple times is safe, dropping an uninitialized entry is a no-op.
+///
+/// # Safety
+///
+/// The caller needs to make sure to not call this function until no more references into
+/// the database exist as there may be outstanding borrows into the pointer contents.
 #[inline]
-pub fn drop(self) {
+pub unsafe fn drop(&mut self) {
 let types = self.types.types.iter();
 for ((_, type_), memo) in std::iter::zip(types, self.memos.memos.get_mut()) {
-// SAFETY: The types match because this is an invariant of `MemoTableWithTypesMut`.
-unsafe { memo.drop(type_) };
+// SAFETY: The types match as per our constructor invariant.
+unsafe { memo.take(type_) };
 }
 }
@@ -313,22 +316,19 @@ impl MemoTableWithTypesMut<'_> {
 ///
 /// The caller needs to make sure to not call this function until no more references into
 /// the database exist as there may be outstanding borrows into the pointer contents.
-pub(crate) unsafe fn with_memos(self, mut f: impl FnMut(MemoIngredientIndex, Box<dyn Memo>)) {
+pub(crate) unsafe fn take_memos(
+&mut self,
+mut f: impl FnMut(MemoIngredientIndex, Box<dyn Memo>),
+) {
 let memos = self.memos.memos.get_mut();
 memos
 .iter_mut()
 .zip(self.types.types.iter())
-.zip(0..)
-.filter_map(|((memo, (_, type_)), index)| {
-let memo = mem::replace(memo.atomic_memo.get_mut(), ptr::null_mut());
-let memo = NonNull::new(memo)?;
-Some((memo, type_.load()?, index))
-})
-.map(|(memo, type_, index)| {
-// SAFETY: We took ownership of the memo, and converted it to the correct type.
-// The caller guarantees that there are no outstanding borrows into the `Box` contents.
-let memo = unsafe { Box::from_raw((type_.to_dyn_fn)(memo).as_ptr()) };
-(MemoIngredientIndex::from_usize(index), memo)
+.enumerate()
+.filter_map(|(index, (memo, (_, type_)))| {
+// SAFETY: The types match as per our constructor invariant.
+let memo = unsafe { memo.take(type_)? };
+Some((MemoIngredientIndex::from_usize(index), memo))
 })
 .for_each(|(index, memo)| f(index, memo));
 }
@@ -339,14 +339,11 @@ impl MemoEntry {
 ///
 /// The type must match.
 #[inline]
-unsafe fn drop(&mut self, type_: &MemoEntryType) {
-if let Some(memo) = NonNull::new(mem::replace(self.atomic_memo.get_mut(), ptr::null_mut()))
-{
-if let Some(type_) = type_.load() {
-// SAFETY: Our preconditions.
-mem::drop(unsafe { Box::from_raw((type_.to_dyn_fn)(memo).as_ptr()) });
-}
-}
+unsafe fn take(&mut self, type_: &MemoEntryType) -> Option<Box<dyn Memo>> {
+let memo = NonNull::new(mem::replace(self.atomic_memo.get_mut(), ptr::null_mut()))?;
+let type_ = type_.load()?;
+// SAFETY: Our preconditions.
+Some(unsafe { Box::from_raw((type_.to_dyn_fn)(memo).as_ptr()) })
 }
 }
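`MemoEntry::drop` becomes `MemoEntry::take`, which swaps the stored pointer for null and hands back an owning `Box<dyn Memo>` at most once; the guarded `drop` above and `take_memos` are both built on it, and `get` returning `NonNull<M>` pushes reference creation to the caller. A simplified sketch of the take-from-a-slot idea (the `Entry` type here is a stand-in, not `MemoEntry`):

```rust
use std::ptr;
use std::sync::atomic::AtomicPtr;

// Stand-in for `MemoEntry`: an owning slot behind an atomic raw pointer.
struct Entry {
    slot: AtomicPtr<String>,
}

impl Entry {
    fn set(&mut self, value: String) {
        *self.slot.get_mut() = Box::into_raw(Box::new(value));
    }

    /// Swap the pointer for null and reclaim the box at most once;
    /// taking from an empty (or already-taken) entry is a no-op.
    fn take(&mut self) -> Option<Box<String>> {
        let raw = std::mem::replace(self.slot.get_mut(), ptr::null_mut());
        if raw.is_null() {
            None
        } else {
            // SAFETY: `raw` was produced by `Box::into_raw` in `set` and has
            // just been replaced with null, so ownership transfers exactly once.
            Some(unsafe { Box::from_raw(raw) })
        }
    }
}

fn main() {
    let mut entry = Entry { slot: AtomicPtr::new(ptr::null_mut()) };
    entry.set("memo".to_owned());
    assert!(entry.take().is_some());
    assert!(entry.take().is_none());
}
```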


@@ -1,11 +1,11 @@
 #![allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety
 use std::any::TypeId;
-use std::fmt;
 use std::hash::Hash;
 use std::marker::PhantomData;
 use std::ops::Index;
 use std::sync::Arc;
+use std::{fmt, mem};
 use crossbeam_queue::SegQueue;
 use tracked_field::FieldIngredientImpl;
@@ -18,7 +18,7 @@ use crate::plumbing::ZalsaLocal;
 use crate::revision::OptionalAtomicRevision;
 use crate::runtime::StampedValue;
 use crate::salsa_struct::SalsaStructInDb;
-use crate::table::memo::{MemoTable, MemoTableTypes};
+use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut};
 use crate::table::{Slot, Table};
 use crate::zalsa::{IngredientIndex, Zalsa};
 use crate::{Database, Durability, Event, EventKind, Id, Revision};
@@ -308,6 +308,7 @@ where
 revisions: C::Revisions,
 /// Memo table storing the results of query functions etc.
+/*unsafe */
 memos: MemoTable,
 }
 // ANCHOR_END: ValueStruct
@@ -353,22 +354,6 @@ impl<C> IngredientImpl<C>
 where
 C: Configuration,
 {
-/// Convert the fields from a `'db` lifetime to `'static`: used when storing
-/// the data into this ingredient, should never be released outside this type.
-unsafe fn to_static<'db>(&'db self, fields: C::Fields<'db>) -> C::Fields<'static> {
-unsafe { std::mem::transmute(fields) }
-}
-unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> {
-unsafe { std::mem::transmute(fields) }
-}
-/// Convert from static back to the db lifetime; used when returning data
-/// out from this ingredient.
-unsafe fn to_self_ptr<'db>(&'db self, fields: *mut C::Fields<'static>) -> *mut C::Fields<'db> {
-unsafe { std::mem::transmute(fields) }
-}
 /// Create a tracked struct ingredient. Generated by the `#[tracked]` macro,
 /// not meant to be called directly by end-users.
 fn new(index: IngredientIndex) -> Self {
@@ -440,7 +425,8 @@ where
 created_at: current_revision,
 updated_at: OptionalAtomicRevision::new(Some(current_revision)),
 durability: current_deps.durability,
-fields: unsafe { self.to_static(fields) },
+// lifetime erase for storage
+fields: unsafe { mem::transmute::<C::Fields<'db>, C::Fields<'static>>(fields) },
 revisions: C::new_revisions(current_deps.changed_at),
 memos: Default::default(),
 };
@@ -552,7 +538,9 @@ where
 if C::update_fields(
 current_deps.changed_at,
 &mut data.revisions,
-self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
+mem::transmute::<*mut C::Fields<'static>, *mut C::Fields<'db>>(
+std::ptr::addr_of_mut!(data.fields),
+),
 fields,
 ) {
 // Consider this a new tracked-struct (even though it still uses the same id)
@@ -606,8 +594,8 @@ where
 // We want to set `updated_at` to `None`, signalling that other field values
 // cannot be read. The current value should be `Some(R0)` for some older revision.
-let data_ref = unsafe { &*data };
-match data_ref.updated_at.load() {
+let updated_at = unsafe { &(*data).updated_at };
+match updated_at.load() {
 None => {
 panic!("cannot delete write-locked id `{id:?}`; value leaked across threads");
 }
@@ -615,41 +603,45 @@ where
 "cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic"
 ),
 Some(r) => {
-if data_ref.updated_at.compare_exchange(Some(r), None).is_err() {
+if updated_at.compare_exchange(Some(r), None).is_err() {
 panic!("race occurred when deleting value `{id:?}`")
 }
 }
 }
-// Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None`
-// and the code that references the memo-table has a read-lock.
-struct MemoTableWithTypes<'a>(MemoTable, &'a MemoTableTypes);
-impl Drop for MemoTableWithTypes<'_> {
+// SAFETY: We have acquired the write lock
+let data = unsafe { &mut *data };
+let mut memo_table = data.take_memo_table();
+// SAFETY: We use the correct types table.
+let table = unsafe { self.memo_table_types.attach_memos_mut(&mut memo_table) };
+// `Database::salsa_event` is a user supplied callback which may panic
+// in that case we need a drop guard to free the memo table
+struct TableDropGuard<'a>(MemoTableWithTypesMut<'a>);
+impl Drop for TableDropGuard<'_> {
 fn drop(&mut self) {
-// SAFETY: We use the correct types table.
-unsafe { self.1.attach_memos_mut(&mut self.0) }.drop();
+// SAFETY: We have verified that no more references to these memos exist and so we are good
+// to drop them.
+unsafe { self.0.drop() };
 }
 }
-let mut memo_table =
-MemoTableWithTypes(unsafe { (*data).take_memo_table() }, &self.memo_table_types);
+let mut table_guard = TableDropGuard(table);
 // SAFETY: We have verified that no more references to these memos exist and so we are good
 // to drop them.
 unsafe {
-memo_table.1.attach_memos_mut(&mut memo_table.0).with_memos(
-|memo_ingredient_index, memo| {
-let ingredient_index = zalsa
-.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index);
+table_guard.0.take_memos(|memo_ingredient_index, memo| {
+let ingredient_index =
+zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index);
-let executor = DatabaseKeyIndex::new(ingredient_index, id);
+let executor = DatabaseKeyIndex::new(ingredient_index, id);
-db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor }));
+db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor }));
-for stale_output in memo.origin().outputs() {
-stale_output.remove_stale_output(zalsa, db, executor, provisional);
-}
-},
-)
+for stale_output in memo.origin().outputs() {
+stale_output.remove_stale_output(zalsa, db, executor, provisional);
+}
+})
 };
+mem::forget(table_guard);
 // now that all cleanup has occurred, make available for re-use
 self.free_list.push(id);
@@ -663,8 +655,8 @@ where
 s: C::Struct<'db>,
 ) -> &'db C::Fields<'db> {
 let id = AsId::as_id(&s);
-let value = Self::data(db.zalsa().table(), id);
-unsafe { self.to_self_ref(&value.fields) }
+let data = Self::data(db.zalsa().table(), id);
+data.fields()
 }
 /// Access to this tracked field.
@@ -692,7 +684,7 @@ where
 field_changed_at,
 );
-unsafe { self.to_self_ref(&data.fields) }
+data.fields()
 }
 /// Access to this untracked field.
@@ -717,7 +709,7 @@ where
 data.created_at,
 );
-unsafe { self.to_self_ref(&data.fields) }
+data.fields()
 }
 #[cfg(feature = "salsa_unstable")]
@@ -811,9 +803,10 @@ where
 ///
 /// They can change across revisions, but they do not change within
 /// a particular revision.
-#[cfg(feature = "salsa_unstable")]
-pub fn fields(&self) -> &C::Fields<'static> {
-&self.fields
+#[cfg_attr(not(feature = "salsa_unstable"), doc(hidden))]
+pub fn fields(&self) -> &C::Fields<'_> {
+// SAFETY: We are shrinking the lifetime from storage back to the db lifetime.
+unsafe { mem::transmute::<&C::Fields<'static>, &C::Fields<'_>>(&self.fields) }
 }
 fn take_memo_table(&mut self) -> MemoTable {
@@ -822,7 +815,7 @@ where
 // (and that the `&mut self` is accurate...).
 assert!(self.updated_at.load().is_none());
-std::mem::take(&mut self.memos)
+mem::take(&mut self.memos)
 }
 fn read_lock(&self, current_revision: Revision) {
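The `TableDropGuard` introduced in the deletion path above exists because `Database::salsa_event` is a user-supplied callback that may panic: if it unwinds, the guard's `Drop` still frees the remaining memos, and on the success path `mem::forget` skips the guard because everything has already been taken. A generic sketch of that panic-guard shape (illustrative types, not the salsa ones):

```rust
use std::mem;

// Guard that cleans up whatever is left if the code between its creation
// and the `mem::forget` call unwinds.
struct Guard<'a> {
    memos: &'a mut Vec<String>,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // Only reached on the panic path; discard anything not yet processed.
        self.memos.clear();
    }
}

fn discard_all(memos: &mut Vec<String>, mut callback: impl FnMut(String)) {
    let guard = Guard { memos };
    while let Some(memo) = guard.memos.pop() {
        // May panic; the guard then frees the memos that are still queued.
        callback(memo);
    }
    // Success path: the table is already empty, so skip the guard's Drop.
    mem::forget(guard);
}

fn main() {
    let mut memos = vec!["a".to_owned(), "b".to_owned()];
    discard_all(&mut memos, |m| println!("discarding {m}"));
    assert!(memos.is_empty());
}
```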