Fix cloning Symbols not increasing their ref count

This commit is contained in:
Lukas Wirth 2024-07-12 17:11:12 +02:00
parent 3fe815b0f3
commit dd626e78c7
2 changed files with 63 additions and 28 deletions

View file

@ -5,7 +5,7 @@ use std::{
borrow::Borrow, borrow::Borrow,
fmt, fmt,
hash::{BuildHasherDefault, Hash, Hasher}, hash::{BuildHasherDefault, Hash, Hasher},
mem, mem::{self, ManuallyDrop},
ptr::NonNull, ptr::NonNull,
sync::OnceLock, sync::OnceLock,
}; };
@ -25,6 +25,15 @@ const _: () = assert!(std::mem::align_of::<Box<str>>() == std::mem::align_of::<&
const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<&&str>()); const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<&&str>());
const _: () = assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<&&str>()); const _: () = assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<&&str>());
const _: () =
assert!(std::mem::size_of::<*const *const str>() == std::mem::size_of::<TaggedArcPtr>());
const _: () =
assert!(std::mem::align_of::<*const *const str>() == std::mem::align_of::<TaggedArcPtr>());
const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<TaggedArcPtr>());
const _: () =
assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<TaggedArcPtr>());
/// A pointer that points to a pointer to a `str`, it may be backed as a `&'static &'static str` or /// A pointer that points to a pointer to a `str`, it may be backed as a `&'static &'static str` or
/// `Arc<Box<str>>` but its size is that of a thin pointer. The active variant is encoded as a tag /// `Arc<Box<str>>` but its size is that of a thin pointer. The active variant is encoded as a tag
/// in the LSB of the alignment niche. /// in the LSB of the alignment niche.
@ -40,19 +49,24 @@ impl TaggedArcPtr {
const BOOL_BITS: usize = true as usize; const BOOL_BITS: usize = true as usize;
const fn non_arc(r: &'static &'static str) -> Self { const fn non_arc(r: &'static &'static str) -> Self {
Self { assert!(
// SAFETY: The pointer is non-null as it is derived from a reference mem::align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS
// Ideally we would call out to `pack_arc` but for a `false` tag, unfortunately the );
// packing stuff requires reading out the pointer to an integer which is not supported // SAFETY: The pointer is non-null as it is derived from a reference
// in const contexts, so here we make use of the fact that for the non-arc version the // Ideally we would call out to `pack_arc` but for a `false` tag, unfortunately the
// tag is false (0) and thus does not need touching the actual pointer value. // packing stuff requires reading out the pointer to an integer which is not supported
packed: unsafe { // in const contexts, so here we make use of the fact that for the non-arc version the
NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut()) // tag is false (0) and thus does not need touching the actual pointer value.
},
} let packed =
unsafe { NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut()) };
Self { packed }
} }
fn arc(arc: Arc<Box<str>>) -> Self { fn arc(arc: Arc<Box<str>>) -> Self {
assert!(
mem::align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS
);
Self { Self {
packed: Self::pack_arc( packed: Self::pack_arc(
// Safety: `Arc::into_raw` always returns a non null pointer // Safety: `Arc::into_raw` always returns a non null pointer
@ -63,12 +77,14 @@ impl TaggedArcPtr {
/// Retrieves the tag. /// Retrieves the tag.
#[inline] #[inline]
pub(crate) fn try_as_arc_owned(self) -> Option<Arc<Box<str>>> { pub(crate) fn try_as_arc_owned(self) -> Option<ManuallyDrop<Arc<Box<str>>>> {
// Unpack the tag from the alignment niche // Unpack the tag from the alignment niche
let tag = Strict::addr(self.packed.as_ptr()) & Self::BOOL_BITS; let tag = Strict::addr(self.packed.as_ptr()) & Self::BOOL_BITS;
if tag != 0 { if tag != 0 {
// Safety: We checked that the tag is non-zero -> true, so we are pointing to the data offset of an `Arc` // Safety: We checked that the tag is non-zero -> true, so we are pointing to the data offset of an `Arc`
Some(unsafe { Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>()) }) Some(ManuallyDrop::new(unsafe {
Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>())
}))
} else { } else {
None None
} }
@ -122,10 +138,11 @@ impl TaggedArcPtr {
} }
} }
#[derive(PartialEq, Eq, Hash, Clone, Debug)] #[derive(PartialEq, Eq, Hash, Debug)]
pub struct Symbol { pub struct Symbol {
repr: TaggedArcPtr, repr: TaggedArcPtr,
} }
const _: () = assert!(std::mem::size_of::<Symbol>() == std::mem::size_of::<NonNull<()>>()); const _: () = assert!(std::mem::size_of::<Symbol>() == std::mem::size_of::<NonNull<()>>());
const _: () = assert!(std::mem::align_of::<Symbol>() == std::mem::align_of::<NonNull<()>>()); const _: () = assert!(std::mem::align_of::<Symbol>() == std::mem::align_of::<NonNull<()>>());
@ -185,19 +202,27 @@ impl Symbol {
fn drop_slow(arc: &Arc<Box<str>>) { fn drop_slow(arc: &Arc<Box<str>>) {
let (mut shard, hash) = Self::select_shard(arc); let (mut shard, hash) = Self::select_shard(arc);
if Arc::count(arc) != 2 { match Arc::count(arc) {
// Another thread has interned another copy 0 => unreachable!(),
return; 1 => unreachable!(),
2 => (),
_ => {
// Another thread has interned another copy
return;
}
} }
match shard.raw_entry_mut().from_key_hashed_nocheck::<str>(hash, arc.as_ref()) { ManuallyDrop::into_inner(
RawEntryMut::Occupied(occ) => occ.remove_entry(), match shard.raw_entry_mut().from_key_hashed_nocheck::<str>(hash, arc.as_ref()) {
RawEntryMut::Vacant(_) => unreachable!(), RawEntryMut::Occupied(occ) => occ.remove_entry(),
} RawEntryMut::Vacant(_) => unreachable!(),
.0 }
.0 .0
.try_as_arc_owned() .0
.unwrap(); .try_as_arc_owned()
.unwrap(),
);
debug_assert_eq!(Arc::count(&arc), 1);
// Shrink the backing storage if the shard is less than 50% occupied. // Shrink the backing storage if the shard is less than 50% occupied.
if shard.len() * 2 < shard.capacity() { if shard.len() * 2 < shard.capacity() {
@ -219,7 +244,13 @@ impl Drop for Symbol {
Self::drop_slow(&arc); Self::drop_slow(&arc);
} }
// decrement the ref count // decrement the ref count
drop(arc); ManuallyDrop::into_inner(arc);
}
}
impl Clone for Symbol {
fn clone(&self) -> Self {
Self { repr: increase_arc_refcount(self.repr) }
} }
} }
@ -228,8 +259,7 @@ fn increase_arc_refcount(repr: TaggedArcPtr) -> TaggedArcPtr {
return repr; return repr;
}; };
// increase the ref count // increase the ref count
mem::forget(arc.clone()); mem::forget(Arc::clone(&arc));
mem::forget(arc);
repr repr
} }
@ -265,6 +295,7 @@ mod tests {
let base_len = MAP.get().unwrap().len(); let base_len = MAP.get().unwrap().len();
let hello = Symbol::intern("hello"); let hello = Symbol::intern("hello");
let world = Symbol::intern("world"); let world = Symbol::intern("world");
let more_worlds = world.clone();
let bang = Symbol::intern("!"); let bang = Symbol::intern("!");
let q = Symbol::intern("?"); let q = Symbol::intern("?");
assert_eq!(MAP.get().unwrap().len(), base_len + 4); assert_eq!(MAP.get().unwrap().len(), base_len + 4);
@ -275,6 +306,7 @@ mod tests {
drop(q); drop(q);
assert_eq!(MAP.get().unwrap().len(), base_len + 3); assert_eq!(MAP.get().unwrap().len(), base_len + 3);
let default = Symbol::intern("default"); let default = Symbol::intern("default");
let many_worlds = world.clone();
assert_eq!(MAP.get().unwrap().len(), base_len + 3); assert_eq!(MAP.get().unwrap().len(), base_len + 3);
assert_eq!( assert_eq!(
"hello default world!", "hello default world!",
@ -285,6 +317,8 @@ mod tests {
"hello world!", "hello world!",
format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str()) format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str())
); );
drop(many_worlds);
drop(more_worlds);
drop(hello); drop(hello);
drop(world); drop(world);
drop(bang); drop(bang);

View file

@ -10,6 +10,7 @@ use crate::{
symbol::{SymbolProxy, TaggedArcPtr}, symbol::{SymbolProxy, TaggedArcPtr},
Symbol, Symbol,
}; };
macro_rules! define_symbols { macro_rules! define_symbols {
(@WITH_NAME: $($alias:ident = $value:literal),* $(,)? @PLAIN: $($name:ident),* $(,)?) => { (@WITH_NAME: $($alias:ident = $value:literal),* $(,)? @PLAIN: $($name:ident),* $(,)?) => {
$( $(