From c145596ef8049691cb6269b938eba400ce85c46e Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 20 Jun 2025 13:15:43 -0400 Subject: [PATCH] Add API to dump memory usage (#916) * add memory usage information hooks * gate memory usage API under `salsa_unstable` feature * use snapshot tests for memory usage API --- src/accumulator/accumulated_map.rs | 8 +- src/cycle.rs | 5 ++ src/database.rs | 106 ++++++++++++++++++++++++ src/function/memo.rs | 12 +++ src/ingredient.rs | 7 ++ src/input.rs | 30 +++++++ src/interned.rs | 51 +++++++++++- src/lib.rs | 3 + src/table.rs | 2 + src/table/memo.rs | 52 +++++++++--- src/tracked_struct.rs | 30 +++++++ src/zalsa.rs | 7 ++ src/zalsa_local.rs | 45 ++++++++++ tests/memory-usage.rs | 128 +++++++++++++++++++++++++++++ 14 files changed, 472 insertions(+), 14 deletions(-) create mode 100644 tests/memory-usage.rs diff --git a/src/accumulator/accumulated_map.rs b/src/accumulator/accumulated_map.rs index bb38e40f..19ad6366 100644 --- a/src/accumulator/accumulated_map.rs +++ b/src/accumulator/accumulated_map.rs @@ -1,6 +1,6 @@ use std::ops; -use rustc_hash::FxHashMap; +use rustc_hash::FxBuildHasher; use crate::accumulator::accumulated::Accumulated; use crate::accumulator::{Accumulator, AnyAccumulated}; @@ -9,7 +9,7 @@ use crate::IngredientIndex; #[derive(Default)] pub struct AccumulatedMap { - map: FxHashMap>, + map: hashbrown::HashMap, FxBuildHasher>, } impl std::fmt::Debug for AccumulatedMap { @@ -50,6 +50,10 @@ impl AccumulatedMap { pub fn clear(&mut self) { self.map.clear() } + + pub fn allocation_size(&self) -> usize { + self.map.allocation_size() + } } /// Tracks whether any input read during a query's execution has any accumulated values. diff --git a/src/cycle.rs b/src/cycle.rs index fa17d81e..66f20544 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -210,6 +210,11 @@ impl CycleHeads { } } } + + #[cfg(feature = "salsa_unstable")] + pub(crate) fn allocation_size(&self) -> usize { + std::mem::size_of_val(self.0.as_slice()) + } } impl IntoIterator for CycleHeads { diff --git a/src/database.rs b/src/database.rs index 72204a58..a92b5791 100644 --- a/src/database.rs +++ b/src/database.rs @@ -133,3 +133,109 @@ impl dyn Database { views.downcaster_for().downcast(self) } } + +#[cfg(feature = "salsa_unstable")] +pub use memory_usage::{IngredientInfo, SlotInfo}; + +#[cfg(feature = "salsa_unstable")] +mod memory_usage { + use crate::Database; + use hashbrown::HashMap; + + impl dyn Database { + /// Returns information about any Salsa structs. + pub fn structs_info(&self) -> Vec { + self.zalsa() + .ingredients() + .filter_map(|ingredient| { + let mut size_of_fields = 0; + let mut size_of_metadata = 0; + let mut instances = 0; + + for slot in ingredient.memory_usage(self)? { + instances += 1; + size_of_fields += slot.size_of_fields; + size_of_metadata += slot.size_of_metadata; + } + + Some(IngredientInfo { + count: instances, + size_of_fields, + size_of_metadata, + debug_name: ingredient.debug_name(), + }) + }) + .collect() + } + + /// Returns information about any memoized Salsa queries. + /// + /// The returned map holds memory usage information for memoized values of a given query, keyed + /// by its `(input, output)` type names. + pub fn queries_info(&self) -> HashMap<(&'static str, &'static str), IngredientInfo> { + let mut queries = HashMap::new(); + + for input_ingredient in self.zalsa().ingredients() { + let Some(input_info) = input_ingredient.memory_usage(self) else { + continue; + }; + + for input in input_info { + for output in input.memos { + let info = queries + .entry((input.debug_name, output.debug_name)) + .or_insert(IngredientInfo { + debug_name: output.debug_name, + ..Default::default() + }); + + info.count += 1; + info.size_of_fields += output.size_of_fields; + info.size_of_metadata += output.size_of_metadata; + } + } + } + + queries + } + } + + /// Information about instances of a particular Salsa ingredient. + #[derive(Default, Debug, PartialEq, Eq, PartialOrd, Ord)] + pub struct IngredientInfo { + debug_name: &'static str, + count: usize, + size_of_metadata: usize, + size_of_fields: usize, + } + + impl IngredientInfo { + /// Returns the debug name of the ingredient. + pub fn debug_name(&self) -> &'static str { + self.debug_name + } + + /// Returns the total size of the fields of any instances of this ingredient, in bytes. + pub fn size_of_fields(&self) -> usize { + self.size_of_fields + } + + /// Returns the total size of Salsa metadata of any instances of this ingredient, in bytes. + pub fn size_of_metadata(&self) -> usize { + self.size_of_metadata + } + + /// Returns the number of instances of this ingredient. + pub fn count(&self) -> usize { + self.count + } + } + + /// Memory usage information about a particular instance of struct, input or output. + pub struct SlotInfo { + pub(crate) debug_name: &'static str, + pub(crate) size_of_metadata: usize, + pub(crate) size_of_fields: usize, + pub(crate) memos: Vec, + } +} diff --git a/src/function/memo.rs b/src/function/memo.rs index 3955f593..77efe8d6 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -316,6 +316,18 @@ impl crate::table::memo::Memo for Memo { fn origin(&self) -> QueryOriginRef<'_> { self.revisions.origin.as_ref() } + + #[cfg(feature = "salsa_unstable")] + fn memory_usage(&self) -> crate::SlotInfo { + let size_of = std::mem::size_of::>() + self.revisions.allocation_size(); + + crate::SlotInfo { + size_of_metadata: size_of - std::mem::size_of::(), + debug_name: std::any::type_name::(), + size_of_fields: std::mem::size_of::(), + memos: Vec::new(), + } + } } pub(super) enum TryClaimHeadsResult<'me> { diff --git a/src/ingredient.rs b/src/ingredient.rs index bfbcc2d3..a5e233df 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -181,6 +181,13 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { let _ = (db, key_index); (None, InputAccumulatedValues::Empty) } + + /// Returns memory usage information about any instances of the ingredient, + /// if applicable. + #[cfg(feature = "salsa_unstable")] + fn memory_usage(&self, _db: &dyn Database) -> Option> { + None + } } impl dyn Ingredient { diff --git a/src/input.rs b/src/input.rs index f209fe0b..814b6e9a 100644 --- a/src/input.rs +++ b/src/input.rs @@ -241,6 +241,18 @@ impl Ingredient for IngredientImpl { fn memo_table_types(&self) -> Arc { self.memo_table_types.clone() } + + /// Returns memory usage information about any inputs. + #[cfg(feature = "salsa_unstable")] + fn memory_usage(&self, db: &dyn Database) -> Option> { + let memory_usage = self + .entries(db) + // SAFETY: The memo table belongs to a value that we allocated, so it + // has the correct type. + .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) + .collect(); + Some(memory_usage) + } } impl std::fmt::Debug for IngredientImpl { @@ -284,6 +296,24 @@ where pub fn fields(&self) -> &C::Fields { &self.fields } + + /// Returns memory usage information about the input. + /// + /// # Safety + /// + /// The `MemoTable` must belong to a `Value` of the correct type. + #[cfg(feature = "salsa_unstable")] + unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::SlotInfo { + // SAFETY: The caller guarantees this is the correct types table. + let memos = unsafe { memo_table_types.attach_memos(&self.memos) }; + + crate::SlotInfo { + debug_name: C::DEBUG_NAME, + size_of_metadata: std::mem::size_of::() - std::mem::size_of::(), + size_of_fields: std::mem::size_of::(), + memos: memos.memory_usage(), + } + } } pub trait HasBuilder { diff --git a/src/interned.rs b/src/interned.rs index 269da3b7..18f0d56c 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -190,6 +190,28 @@ where // ensures that they are not reused while being accessed. unsafe { &*self.fields.get() } } + + /// Returns memory usage information about the interned value. + /// + /// # Safety + /// + /// The `MemoTable` must belong to a `Value` of the correct type. Additionally, the + /// lock must be held for the shard containing the value. + #[cfg(all(not(feature = "shuttle"), feature = "salsa_unstable"))] + unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::SlotInfo { + // SAFETY: The caller guarantees we hold the lock for the shard containing the value, so we + // have at-least read-only access to the value's memos. + let memos = unsafe { &*self.memos.get() }; + // SAFETY: The caller guarantees this is the correct types table. + let memos = unsafe { memo_table_types.attach_memos(memos) }; + + crate::SlotInfo { + debug_name: C::DEBUG_NAME, + size_of_metadata: std::mem::size_of::() - std::mem::size_of::>(), + size_of_fields: std::mem::size_of::>(), + memos: memos.memory_usage(), + } + } } impl Default for JarImpl { @@ -680,7 +702,7 @@ where // // # Safety // - // The lock must be held. + // The lock must be held for the shard containing the value. unsafe fn value_hash<'db>(&'db self, id: Id, zalsa: &'db Zalsa) -> u64 { // This closure is only called if the table is resized. So while it's expensive // to lookup all values, it will only happen rarely. @@ -694,7 +716,7 @@ where // // # Safety // - // The lock must be held. + // The lock must be held for the shard containing the value. unsafe fn value_eq<'db, Key>( id: Id, key: &Key, @@ -830,6 +852,31 @@ where fn memo_table_types(&self) -> Arc { self.memo_table_types.clone() } + + /// Returns memory usage information about any interned values. + #[cfg(all(not(feature = "shuttle"), feature = "salsa_unstable"))] + fn memory_usage(&self, db: &dyn Database) -> Option> { + use parking_lot::lock_api::RawMutex; + + for shard in self.shards.iter() { + // SAFETY: We do not hold any active mutex guards. + unsafe { shard.raw().lock() }; + } + + let memory_usage = self + .entries(db) + // SAFETY: The memo table belongs to a value that we allocated, so it + // has the correct type. Additionally, we are holding the locks for all shards. + .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) + .collect(); + + for shard in self.shards.iter() { + // SAFETY: We acquired the locks for all shards. + unsafe { shard.raw().unlock() }; + } + + Some(memory_usage) + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/lib.rs b/src/lib.rs index 2d1465ee..bf12206f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,6 +39,9 @@ pub use parallel::{join, par_map}; #[cfg(feature = "macros")] pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update}; +#[cfg(feature = "salsa_unstable")] +pub use self::database::{IngredientInfo, SlotInfo}; + pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; diff --git a/src/table.rs b/src/table.rs index 6b7344ff..3f08841d 100644 --- a/src/table.rs +++ b/src/table.rs @@ -253,6 +253,7 @@ impl Table { unsafe { page.memo_types.attach_memos_mut(memos) } } + #[cfg(feature = "salsa_unstable")] pub(crate) fn slots_of(&self) -> impl Iterator + '_ { self.pages .iter() @@ -392,6 +393,7 @@ impl Page { PageView(self, PhantomData) } + #[cfg(feature = "salsa_unstable")] fn cast_type(&self) -> Option> { if self.slot_type_id == TypeId::of::() { Some(PageView(self, PhantomData)) diff --git a/src/table/memo.rs b/src/table/memo.rs index 0daaca6c..821f7c4e 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -1,9 +1,7 @@ -use std::{ - any::{Any, TypeId}, - fmt::Debug, - mem, - ptr::{self, NonNull}, -}; +use std::any::{Any, TypeId}; +use std::fmt::Debug; +use std::mem; +use std::ptr::{self, NonNull}; use portable_atomic::hint::spin_loop; use thin_vec::ThinVec; @@ -23,6 +21,10 @@ pub(crate) struct MemoTable { pub trait Memo: Any + Send + Sync { /// Returns the `origin` of this memo fn origin(&self) -> QueryOriginRef<'_>; + + /// Returns memory usage information about the memoized value. + #[cfg(feature = "salsa_unstable")] + fn memory_usage(&self) -> crate::SlotInfo; } /// Data for a memoized entry. @@ -53,7 +55,7 @@ pub struct MemoEntryType { data: OnceLock, } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] struct MemoEntryTypeData { /// The `type_id` of the erased memo type `M` type_id: TypeId, @@ -102,12 +104,22 @@ impl MemoEntryType { /// Dummy placeholder type that we use when erasing the memo type `M` in [`MemoEntryData`][]. #[derive(Debug)] -struct DummyMemo {} +struct DummyMemo; impl Memo for DummyMemo { fn origin(&self) -> QueryOriginRef<'_> { unreachable!("should not get here") } + + #[cfg(feature = "salsa_unstable")] + fn memory_usage(&self) -> crate::SlotInfo { + crate::SlotInfo { + debug_name: "dummy", + size_of_metadata: 0, + size_of_fields: 0, + memos: Vec::new(), + } + } } #[derive(Default)] @@ -146,7 +158,6 @@ impl MemoTableTypes { "cannot provide an empty `MemoEntryType` for `MemoEntryType::set()`", ), ) - .ok() .expect("memo type should only be set once"); break; } @@ -156,7 +167,7 @@ impl MemoTableTypes { /// /// The types table must be the correct one of `memos`. #[inline] - pub(super) unsafe fn attach_memos<'a>( + pub(crate) unsafe fn attach_memos<'a>( &'a self, memos: &'a MemoTable, ) -> MemoTableWithTypes<'a> { @@ -266,6 +277,27 @@ impl MemoTableWithTypes<'_> { // SAFETY: `type_id` check asserted above Some(unsafe { MemoEntryType::from_dummy(memo) }) } + + #[cfg(feature = "salsa_unstable")] + pub(crate) fn memory_usage(&self) -> Vec { + let mut memory_usage = Vec::new(); + let memos = self.memos.memos.read(); + for (index, memo) in memos.iter().enumerate() { + let Some(memo) = NonNull::new(memo.atomic_memo.load(Ordering::Acquire)) else { + continue; + }; + + let Some(type_) = self.types.types.get(index).and_then(MemoEntryType::load) else { + continue; + }; + + // SAFETY: The `TypeId` is asserted in `insert()`. + let dyn_memo: &dyn Memo = unsafe { (type_.to_dyn_fn)(memo).as_ref() }; + memory_usage.push(dyn_memo.memory_usage()); + } + + memory_usage + } } pub(crate) struct MemoTableWithTypesMut<'a> { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 3190c9e7..83817021 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -852,6 +852,18 @@ where fn memo_table_types(&self) -> Arc { self.memo_table_types.clone() } + + /// Returns memory usage information about any tracked structs. + #[cfg(feature = "salsa_unstable")] + fn memory_usage(&self, db: &dyn Database) -> Option> { + let memory_usage = self + .entries(db) + // SAFETY: The memo table belongs to a value that we allocated, so it + // has the correct type. + .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) + .collect(); + Some(memory_usage) + } } impl std::fmt::Debug for IngredientImpl @@ -910,6 +922,24 @@ where } } } + + /// Returns memory usage information about the tracked struct. + /// + /// # Safety + /// + /// The `MemoTable` must belong to a `Value` of the correct type. + #[cfg(feature = "salsa_unstable")] + unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::SlotInfo { + // SAFETY: The caller guarantees this is the correct types table. + let memos = unsafe { memo_table_types.attach_memos(&self.memos) }; + + crate::SlotInfo { + debug_name: C::DEBUG_NAME, + size_of_metadata: mem::size_of::() - mem::size_of::>(), + size_of_fields: mem::size_of::>(), + memos: memos.memory_usage(), + } + } } // SAFETY: `Value` is our private type branded over the unique configuration `C`. diff --git a/src/zalsa.rs b/src/zalsa.rs index b5b90a04..d1be2160 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -235,6 +235,13 @@ impl Zalsa { [memo_ingredient_index.as_usize()] } + #[cfg(feature = "salsa_unstable")] + pub(crate) fn ingredients(&self) -> impl Iterator { + self.ingredients_vec + .iter() + .map(|(_, ingredient)| ingredient.as_ref()) + } + /// Starts unwinding the stack if the current revision is cancelled. /// /// This method can be called by query implementations that perform diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 587a4fb6..80e24e7f 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -347,6 +347,35 @@ pub(crate) struct QueryRevisions { pub(super) extra: QueryRevisionsExtra, } +impl QueryRevisions { + #[cfg(feature = "salsa_unstable")] + pub(crate) fn allocation_size(&self) -> usize { + let QueryRevisions { + changed_at: _, + durability: _, + accumulated_inputs: _, + verified_final: _, + origin, + extra, + } = self; + + let mut memory = 0; + + if let QueryOriginRef::Derived(query_edges) + | QueryOriginRef::DerivedUntracked(query_edges) = origin.as_ref() + { + memory += std::mem::size_of_val(query_edges); + } + + if let Some(extra) = extra.0.as_ref() { + memory += std::mem::size_of::(); + memory += extra.allocation_size(); + } + + memory + } +} + /// Data on `QueryRevisions` that is lazily allocated to save memory /// in the common case. /// @@ -417,6 +446,22 @@ struct QueryRevisionsExtraInner { iteration: IterationCount, } +impl QueryRevisionsExtraInner { + #[cfg(feature = "salsa_unstable")] + fn allocation_size(&self) -> usize { + let QueryRevisionsExtraInner { + accumulated, + tracked_struct_ids, + cycle_heads, + iteration: _, + } = self; + + accumulated.allocation_size() + + cycle_heads.allocation_size() + + std::mem::size_of_val(tracked_struct_ids.as_slice()) + } +} + #[cfg(not(feature = "shuttle"))] #[cfg(target_pointer_width = "64")] const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; 4]>()]; diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs new file mode 100644 index 00000000..c16a9464 --- /dev/null +++ b/tests/memory-usage.rs @@ -0,0 +1,128 @@ +use expect_test::expect; + +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::tracked] +struct MyTracked<'db> { + field: u32, +} + +#[salsa::interned] +struct MyInterned<'db> { + field: u32, +} + +#[salsa::tracked] +fn input_to_interned<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyInterned<'db> { + MyInterned::new(db, input.field(db)) +} + +#[salsa::tracked] +fn input_to_tracked<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { + MyTracked::new(db, input.field(db)) +} + +#[salsa::tracked] +fn input_to_tracked_tuple<'db>( + db: &'db dyn salsa::Database, + input: MyInput, +) -> (MyTracked<'db>, MyTracked<'db>) { + ( + MyTracked::new(db, input.field(db)), + MyTracked::new(db, input.field(db)), + ) +} + +#[test] +fn test() { + let db = salsa::DatabaseImpl::new(); + + let input1 = MyInput::new(&db, 1); + let input2 = MyInput::new(&db, 2); + let input3 = MyInput::new(&db, 3); + + let _tracked1 = input_to_tracked(&db, input1); + let _tracked2 = input_to_tracked(&db, input2); + + let _tracked_tuple = input_to_tracked_tuple(&db, input1); + + let _interned1 = input_to_interned(&db, input1); + let _interned2 = input_to_interned(&db, input2); + let _interned3 = input_to_interned(&db, input3); + + let structs_info = ::structs_info(&db); + + let expected = expect![[r#" + [ + IngredientInfo { + debug_name: "MyInput", + count: 3, + size_of_metadata: 84, + size_of_fields: 12, + }, + IngredientInfo { + debug_name: "MyTracked", + count: 4, + size_of_metadata: 112, + size_of_fields: 16, + }, + IngredientInfo { + debug_name: "MyInterned", + count: 3, + size_of_metadata: 156, + size_of_fields: 12, + }, + ]"#]]; + + expected.assert_eq(&format!("{structs_info:#?}")); + + let mut queries_info = ::queries_info(&db) + .into_iter() + .collect::>(); + queries_info.sort(); + + let expected = expect![[r#" + [ + ( + ( + "MyInput", + "(memory_usage::MyTracked, memory_usage::MyTracked)", + ), + IngredientInfo { + debug_name: "(memory_usage::MyTracked, memory_usage::MyTracked)", + count: 1, + size_of_metadata: 132, + size_of_fields: 16, + }, + ), + ( + ( + "MyInput", + "memory_usage::MyInterned", + ), + IngredientInfo { + debug_name: "memory_usage::MyInterned", + count: 3, + size_of_metadata: 192, + size_of_fields: 24, + }, + ), + ( + ( + "MyInput", + "memory_usage::MyTracked", + ), + IngredientInfo { + debug_name: "memory_usage::MyTracked", + count: 2, + size_of_metadata: 192, + size_of_fields: 16, + }, + ), + ]"#]]; + + expected.assert_eq(&format!("{queries_info:#?}")); +}