Lazily allocate extra memo state (#888)

* lazily allocate extra memo state

* lazily allocate accumulators

* simplify `QueryRevisionsExtra`
Ibraheem Ahmed 2025-05-30 09:28:40 -04:00, committed by GitHub
parent 5750c8448f
commit 0c39c08360
11 changed files with 173 additions and 60 deletions
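At its core the change is a small-struct optimization: three rarely-populated pieces of per-query state (`AccumulatedMap`, `IdentityMap`, `CycleHeads`) move out of `QueryRevisions` and behind a single `Option<Box<...>>`, which costs one pointer-sized word when unused. A minimal, self-contained sketch of the pattern — all type names below are illustrative stand-ins, not the real salsa types:

```rust
// Sketch of the lazy-allocation pattern this commit applies.
// `Revisions`, `Extra`, and the field types are stand-ins, not the salsa types.
use std::mem::size_of;

#[derive(Debug, Default)]
struct Extra {
    accumulated: Vec<u64>,        // stand-in for AccumulatedMap
    tracked_struct_ids: Vec<u64>, // stand-in for IdentityMap
    cycle_heads: Vec<u64>,        // stand-in for CycleHeads
}

#[derive(Debug, Default)]
struct Revisions {
    changed_at: u64,
    // The rarely-used state lives behind one pointer-sized field:
    // `None` in the common case means no allocation at all.
    extra: Option<Box<Extra>>,
}

fn main() {
    // `Box` is non-null, so the `Option` uses the null niche: one word.
    assert_eq!(size_of::<Option<Box<Extra>>>(), size_of::<usize>());
    // The hot struct stays small regardless of how large `Extra` grows.
    println!(
        "Revisions: {} bytes, Extra: {} bytes",
        size_of::<Revisions>(),
        size_of::<Extra>()
    );
}
```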

View file

@@ -1,4 +1,3 @@
-use std::ops::Not;
 use std::{fmt, mem, ops};
 use crate::accumulator::accumulated_map::{
@@ -11,7 +10,7 @@ use crate::key::DatabaseKeyIndex;
 use crate::runtime::Stamp;
 use crate::sync::atomic::AtomicBool;
 use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap};
-use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions};
+use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions, QueryRevisionsExtra};
 use crate::{Accumulator, IngredientIndex, Revision};

 #[derive(Debug)]
@@ -199,22 +198,22 @@ impl ActiveQuery {
             QueryOrigin::derived(input_outputs.drain(..))
         };
         disambiguator_map.clear();

-        let accumulated = accumulated
-            .is_empty()
-            .not()
-            .then(|| Box::new(mem::take(accumulated)));
-        let tracked_struct_ids = mem::take(tracked_struct_ids);
+        let verified_final = cycle_heads.is_empty();
+        let extra = QueryRevisionsExtra::new(
+            mem::take(accumulated),
+            mem::take(tracked_struct_ids),
+            mem::take(cycle_heads),
+        );
         let accumulated_inputs = AtomicInputAccumulatedValues::new(accumulated_inputs);
-        let cycle_heads = mem::take(cycle_heads);
         QueryRevisions {
             changed_at,
             durability,
             origin,
-            tracked_struct_ids,
             accumulated_inputs,
-            accumulated,
-            verified_final: AtomicBool::new(cycle_heads.is_empty()),
-            cycle_heads,
+            verified_final: AtomicBool::new(verified_final),
+            extra,
         }
     }
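The method above now funnels all three buffers through `mem::take`, which swaps `Default::default()` into place and returns the previous value, so the active query's scratch buffers move into `QueryRevisionsExtra::new` without cloning. A small sketch of that idiom:

```rust
use std::mem;

fn main() {
    let mut buffer: Vec<u32> = vec![1, 2, 3];
    // `take` replaces the value with its default and hands back the old one,
    // so the heap allocation moves out without a clone.
    let moved = mem::take(&mut buffer);
    assert_eq!(moved, [1, 2, 3]);
    assert!(buffer.is_empty()); // the original binding is reset and still usable
}
```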

View file

@@ -191,7 +191,10 @@ where
         mut memo: memo::Memo<C::Output<'db>>,
         memo_ingredient_index: MemoIngredientIndex,
     ) -> &'db memo::Memo<C::Output<'db>> {
-        memo.revisions.tracked_struct_ids.shrink_to_fit();
+        if let Some(tracked_struct_ids) = memo.revisions.tracked_struct_ids_mut() {
+            tracked_struct_ids.shrink_to_fit();
+        }
+
         // We convert to a `NonNull` here as soon as possible because we are going to alias
         // into the `Box`, which is a `noalias` type.
         // FIXME: Use `Box::into_non_null` once stable
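Since the identity map may no longer exist at all, the `shrink_to_fit` call is now guarded by `tracked_struct_ids_mut()`, which returns `Option<&mut IdentityMap>`. A simplified sketch of that guarded-mutation shape, with stand-in types:

```rust
// Stand-in types; this mirrors the accessor's shape, not the salsa internals.
use std::collections::HashMap;

struct Extra {
    tracked_struct_ids: HashMap<u64, u64>,
}

struct Revisions {
    extra: Option<Box<Extra>>,
}

impl Revisions {
    // `None` when nothing was ever allocated (or the map is empty).
    fn tracked_struct_ids_mut(&mut self) -> Option<&mut HashMap<u64, u64>> {
        self.extra
            .as_mut()
            .map(|extra| &mut extra.tracked_struct_ids)
            .filter(|map| !map.is_empty())
    }
}

fn main() {
    let mut revisions = Revisions { extra: None };
    // The common case costs a single `None` check: no map, no `shrink_to_fit`.
    if let Some(ids) = revisions.tracked_struct_ids_mut() {
        ids.shrink_to_fit();
    }
    assert!(revisions.extra.is_none());
}
```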

View file

@@ -99,7 +99,7 @@ where
         // NEXT STEP: stash and refactor `fetch` to return an `&Memo` so we can make this work
         let memo = self.refresh_memo(db, db.zalsa(), key);
         (
-            memo.revisions.accumulated.as_deref(),
+            memo.revisions.accumulated(),
             memo.revisions.accumulated_inputs.load(),
         )
     }

View file

@@ -22,7 +22,7 @@ where
         // right now whether backdating could be made safe for queries participating in cycles.
         // TODO: Write a test that demonstrates that backdating queries participating in a cycle isn't safe
         // OR write many tests showing that it is (and fixing the case where it didn't correctly account for today).
-        if !revisions.cycle_heads.is_empty() {
+        if !revisions.cycle_heads().is_empty() {
             return;
         }

View file

@@ -50,11 +50,12 @@ where
             return;
         }

-        // Remove the outputs that are no longer present in the current revision
-        // to prevent the next revision from being seeded with an id mapping that no longer exists.
-        revisions
-            .tracked_struct_ids
-            .retain(|&k, &mut value| !old_outputs.contains(&(k.ingredient_index(), value.index())));
+        if let Some(tracked_struct_ids) = revisions.tracked_struct_ids_mut() {
+            // Remove the outputs that are no longer present in the current revision
+            // to prevent the next revision from being seeded with an id mapping that no longer exists.
+            tracked_struct_ids
+                .retain(|k, value| !old_outputs.contains(&(k.ingredient_index(), value.index())));
+        }

         for (ingredient_index, key_index) in old_outputs {
             // SAFETY: key_index acquired from valid output
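`retain` keeps exactly the entries for which the closure returns `true`, so the negated `contains` check drops every mapping whose output disappeared in the new revision. A tiny illustration with stand-in types:

```rust
use std::collections::HashMap;

fn main() {
    let mut ids: HashMap<u32, u32> = HashMap::from([(1, 10), (2, 20), (3, 30)]);
    let old_outputs = [(2u32, 20u32)]; // stand-in for the stale-outputs set

    // Keep only the mappings that are NOT among the stale outputs.
    ids.retain(|k, v| !old_outputs.contains(&(*k, *v)));

    assert_eq!(ids.len(), 2); // the stale entry (2, 20) is gone
}
```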

View file

@@ -52,9 +52,9 @@ where
                     id,
                 );

-                if !revisions.cycle_heads.is_empty() {
+                if let Some(cycle_heads) = revisions.cycle_heads_mut() {
                     // Did the new result we got depend on our own provisional value, in a cycle?
-                    if revisions.cycle_heads.contains(&database_key_index) {
+                    if cycle_heads.contains(&database_key_index) {
                         // Ignore the computed value, leave the fallback value there.
                         let memo = self
                             .get_memo_from_table_for(zalsa, id, memo_ingredient_index)
@@ -73,15 +73,16 @@ where
                         // If we're in the middle of a cycle and we have a fallback, use it instead.
                         // Cycle participants that don't have a fallback will be discarded in
                         // `validate_provisional()`.
-                        let cycle_heads = revisions.cycle_heads;
+                        let cycle_heads = std::mem::take(cycle_heads);
                         let active_query = db.zalsa_local().push_query(database_key_index, 0);
                         new_value = C::cycle_initial(db, C::id_to_input(db, id));
                         revisions = active_query.pop();
                         // We need to set `cycle_heads` and `verified_final` because they need to propagate to the callers.
                         // When verifying this, we will see we have a fallback and mark ourselves verified.
-                        revisions.cycle_heads = cycle_heads;
+                        revisions.set_cycle_heads(cycle_heads);
                         revisions.verified_final = AtomicBool::new(false);
                     }
+                }

                 (new_value, revisions)
             }
             CycleRecoveryStrategy::Fixpoint => self.execute_maybe_iterate(
@@ -142,7 +143,10 @@ where
             );

             // Did the new result we got depend on our own provisional value, in a cycle?
-            if revisions.cycle_heads.contains(&database_key_index) {
+            if let Some(cycle_heads) = revisions
+                .cycle_heads_mut()
+                .filter(|cycle_heads| cycle_heads.contains(&database_key_index))
+            {
                 let last_provisional_value = if let Some(last_provisional) = opt_last_provisional {
                     // We have a last provisional value from our previous time around the loop.
                     last_provisional.value.as_ref()
@@ -215,9 +219,7 @@ where
                         fell_back,
                     })
                 });
-                revisions
-                    .cycle_heads
-                    .update_iteration_count(database_key_index, iteration_count);
+                cycle_heads.update_iteration_count(database_key_index, iteration_count);
                 opt_last_provisional = Some(self.insert_memo(
                     zalsa,
                     id,
@@ -234,7 +236,7 @@ where
             tracing::debug!(
                 "{database_key_index:?}: execute: fixpoint iteration has a final value"
             );
-            revisions.cycle_heads.remove(&database_key_index);
+            cycle_heads.remove(&database_key_index);
         }

         tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}");
@@ -254,7 +256,9 @@ where
         if let Some(old_memo) = opt_old_memo {
             // If we already executed this query once, then use the tracked-struct ids from the
             // previous execution as the starting point for the new one.
-            active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids);
+            if let Some(tracked_struct_ids) = old_memo.revisions.tracked_struct_ids() {
+                active_query.seed_tracked_struct_ids(tracked_struct_ids);
+            }

             // Copy over all inputs and outputs from a previous iteration.
             // This is necessary to:
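One subtlety in the `execute` changes above: `cycle_heads_mut()` hands out a `&mut CycleHeads`, and Rust will not let you move out of a mutable reference. The fallback path therefore takes the list with `std::mem::take`, rebuilds `revisions` by popping the query, and finally writes the heads back with `set_cycle_heads` so they propagate to callers. A sketch of that take-then-restore dance, under stand-in types:

```rust
use std::mem;

#[derive(Debug, Default, PartialEq)]
struct CycleHeads(Vec<u32>);

fn main() {
    let mut heads = CycleHeads(vec![1, 2]);
    let borrowed: &mut CycleHeads = &mut heads;

    // Take ownership of the contents; `borrowed` now holds the empty default.
    let owned = mem::take(borrowed);
    assert_eq!(*borrowed, CycleHeads::default());

    // ... recompute revisions here ...

    // Restore the heads so they propagate (mirrors `set_cycle_heads`).
    *borrowed = owned;
    assert_eq!(heads, CycleHeads(vec![1, 2]));
}
```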

View file

@@ -28,7 +28,7 @@ where
             database_key_index,
             memo.revisions.durability,
             memo.revisions.changed_at,
-            memo.revisions.accumulated.is_some(),
+            memo.revisions.accumulated().is_some(),
             &memo.revisions.accumulated_inputs,
             memo.cycle_heads(),
         );
@@ -124,7 +124,7 @@ where
         let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
         if let Some(memo) = memo_guard {
             if memo.value.is_some()
-                && memo.revisions.cycle_heads.contains(&database_key_index)
+                && memo.revisions.cycle_heads().contains(&database_key_index)
             {
                 let can_shallow_update =
                     self.shallow_verify_memo(zalsa, database_key_index, memo);
@@ -164,7 +164,7 @@ where
         let active_query = db.zalsa_local().push_query(database_key_index, 0);
         let fallback_value = C::cycle_initial(db, C::id_to_input(db, id));
         let mut revisions = active_query.pop();
-        revisions.cycle_heads = CycleHeads::initial(database_key_index);
+        revisions.set_cycle_heads(CycleHeads::initial(database_key_index));
         // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`.
         *revisions.verified_final.get_mut() = false;
         Some(self.insert_memo(

View file

@@ -159,7 +159,7 @@ where
             return Some(if changed_at > revision {
                 VerifyResult::Changed
             } else {
-                VerifyResult::Unchanged(match &memo.revisions.accumulated {
+                VerifyResult::Unchanged(match memo.revisions.accumulated() {
                     Some(_) => InputAccumulatedValues::Any,
                     None => memo.revisions.accumulated_inputs.load(),
                 })
@@ -257,7 +257,7 @@ where
             "{database_key_index:?}: validate_provisional(memo = {memo:#?})",
             memo = memo.tracing_debug()
         );
-        for cycle_head in &memo.revisions.cycle_heads {
+        for cycle_head in memo.revisions.cycle_heads() {
             let kind = zalsa
                 .lookup_ingredient(cycle_head.database_key_index.ingredient_index())
                 .cycle_head_kind(zalsa, cycle_head.database_key_index.key_index());
@@ -303,7 +303,7 @@ where
             memo = memo.tracing_debug()
         );

-        let cycle_heads = &memo.revisions.cycle_heads;
+        let cycle_heads = memo.revisions.cycle_heads();
         if cycle_heads.is_empty() {
             return true;
         }
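All the read-only call sites above go through `cycle_heads()`, which falls back to `empty_cycle_heads()` (imported from `crate::cycle` in the `zalsa_local.rs` hunk below) when no extra state was allocated. Returning a shared empty value preserves the plain `&CycleHeads` signature, so callers never have to match on an `Option`. A sketch of that shared-empty-sentinel pattern, with stand-in types:

```rust
// Stand-in types; the real `empty_cycle_heads` lives in `crate::cycle`.
#[derive(Debug, Default)]
struct CycleHeads(Vec<u32>);

impl CycleHeads {
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

// `Vec::new()` is const, so a `&'static` empty value costs nothing.
static EMPTY: CycleHeads = CycleHeads(Vec::new());

fn empty_cycle_heads() -> &'static CycleHeads {
    &EMPTY
}

struct Revisions {
    extra: Option<Box<CycleHeads>>, // simplified: the extra holds just the heads here
}

impl Revisions {
    fn cycle_heads(&self) -> &CycleHeads {
        match &self.extra {
            Some(heads) => heads,
            None => empty_cycle_heads(),
        }
    }
}

fn main() {
    let r = Revisions { extra: None };
    // Works with no allocation and no `Option` at the call site.
    assert!(r.cycle_heads().is_empty());
}
```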

View file

@@ -101,7 +101,7 @@ pub struct Memo<V> {
 #[cfg(not(feature = "shuttle"))]
 #[cfg(target_pointer_width = "64")]
 const _: [(); std::mem::size_of::<Memo<std::num::NonZeroUsize>>()] =
-    [(); std::mem::size_of::<[usize; 11]>()];
+    [(); std::mem::size_of::<[usize; 6]>()];

 impl<V> Memo<V> {
     pub(super) fn new(value: Option<V>, revision_now: Revision, revisions: QueryRevisions) -> Self {
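The `const _: [(); size_of::<T>()]` trick asserts a type's size at compile time: the two array types only unify when the lengths match, so `Memo` silently growing past six words becomes a build error. A stand-in illustration (assumes a 64-bit target, like the `cfg` guard above):

```rust
// Stand-in for the real `Memo`; on 64-bit targets this is exactly 6 words:
// 1 (niche-optimized Option<NonZeroUsize>) + 1 (usize) + 4 (array).
struct Memo {
    value: Option<std::num::NonZeroUsize>,
    verified_at: usize,
    revisions: [usize; 4], // stand-in for the now four-word QueryRevisions
}

// Fails to compile if `Memo` is not exactly 6 machine words.
const _: [(); std::mem::size_of::<Memo>()] = [(); std::mem::size_of::<[usize; 6]>()];

fn main() {}
```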
@@ -134,7 +134,7 @@ impl<V> Memo<V> {
         zalsa: &Zalsa,
         database_key_index: DatabaseKeyIndex,
     ) -> bool {
-        if self.revisions.cycle_heads.is_empty() {
+        if self.revisions.cycle_heads().is_empty() {
             return false;
         }
@@ -142,7 +142,7 @@ impl<V> Memo<V> {
             return false;
         };

-        return provisional_retry_cold(zalsa, database_key_index, &self.revisions.cycle_heads);
+        return provisional_retry_cold(zalsa, database_key_index, self.revisions.cycle_heads());

         #[inline(never)]
         fn provisional_retry_cold(
@@ -204,7 +204,7 @@ impl<V> Memo<V> {
     #[inline(always)]
     pub(super) fn cycle_heads(&self) -> &CycleHeads {
         if self.may_be_provisional() {
-            &self.revisions.cycle_heads
+            self.revisions.cycle_heads()
         } else {
             empty_cycle_heads()
         }

View file

@@ -5,7 +5,7 @@ use crate::revision::AtomicRevision;
 use crate::sync::atomic::AtomicBool;
 use crate::tracked_struct::TrackedStructInDb;
 use crate::zalsa::{Zalsa, ZalsaDatabase};
-use crate::zalsa_local::{QueryOrigin, QueryOriginRef, QueryRevisions};
+use crate::zalsa_local::{QueryOrigin, QueryOriginRef, QueryRevisions, QueryRevisionsExtra};
 use crate::{DatabaseKeyIndex, Id};

 impl<C> IngredientImpl<C>
@@ -66,11 +66,9 @@ where
             changed_at: current_deps.changed_at,
             durability: current_deps.durability,
             origin: QueryOrigin::assigned(active_query_key),
-            tracked_struct_ids: Default::default(),
-            accumulated: Default::default(),
             accumulated_inputs: Default::default(),
             verified_final: AtomicBool::new(true),
-            cycle_heads: Default::default(),
+            extra: QueryRevisionsExtra::default(),
         };

         let memo_ingredient_index = self.memo_ingredient_index(zalsa, key);

View file

@@ -7,7 +7,7 @@ use tracing::debug;
 use crate::accumulator::accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues};
 use crate::active_query::QueryStack;
-use crate::cycle::CycleHeads;
+use crate::cycle::{empty_cycle_heads, CycleHeads};
 use crate::durability::Durability;
 use crate::key::DatabaseKeyIndex;
 use crate::runtime::Stamp;
@@ -328,6 +328,57 @@ pub(crate) struct QueryRevisions {
     /// How was this query computed?
     pub(crate) origin: QueryOrigin,

+    /// [`InputAccumulatedValues::Empty`] if no input read during the query's execution
+    /// has any direct or indirect accumulated values.
+    ///
+    /// Note that this field could be in `QueryRevisionsExtra` as it is only relevant
+    /// for accumulators, but we get it for free anyway due to padding.
+    pub(super) accumulated_inputs: AtomicInputAccumulatedValues,
+
+    /// Are the `cycle_heads` verified to not be provisional anymore?
+    ///
+    /// Note that this field could be in `QueryRevisionsExtra` as it is only
+    /// relevant for queries that participate in a cycle, but we get it for
+    /// free anyway due to padding.
+    pub(super) verified_final: AtomicBool,
+
+    /// Lazily allocated state.
+    pub(super) extra: QueryRevisionsExtra,
+}
+
+/// Data on `QueryRevisions` that is lazily allocated to save memory
+/// in the common case.
+///
+/// In particular, not all queries create tracked structs, participate
+/// in cycles, or create accumulators.
+#[derive(Debug, Default)]
+pub(crate) struct QueryRevisionsExtra(Option<Box<QueryRevisionsExtraInner>>);
+
+impl QueryRevisionsExtra {
+    pub fn new(
+        accumulated: AccumulatedMap,
+        tracked_struct_ids: IdentityMap,
+        cycle_heads: CycleHeads,
+    ) -> Self {
+        let inner =
+            if tracked_struct_ids.is_empty() && cycle_heads.is_empty() && accumulated.is_empty() {
+                None
+            } else {
+                Some(Box::new(QueryRevisionsExtraInner {
+                    accumulated,
+                    cycle_heads,
+                    tracked_struct_ids,
+                }))
+            };
+
+        Self(inner)
+    }
+}
+
+#[derive(Debug)]
+struct QueryRevisionsExtraInner {
+    accumulated: AccumulatedMap,
+
     /// The ids of tracked structs created by this query.
     ///
     /// This table plays an important role when queries are
@@ -345,16 +396,7 @@ pub(crate) struct QueryRevisions {
     /// previous revision. To handle this, `diff_outputs` compares
     /// the structs from the old/new revision and retains
     /// only entries that appeared in the new revision.
-    pub(super) tracked_struct_ids: IdentityMap,
-
-    pub(super) accumulated: Option<Box<AccumulatedMap>>,
-
-    /// [`InputAccumulatedValues::Empty`] if any input read during the query's execution
-    /// has any direct or indirect accumulated values.
-    pub(super) accumulated_inputs: AtomicInputAccumulatedValues,
-
-    /// Are the `cycle_heads` verified to not be provisional anymore?
-    pub(super) verified_final: AtomicBool,
+    tracked_struct_ids: IdentityMap,

     /// This result was computed based on provisional values from
     /// these cycle heads. The "cycle head" is the query responsible
@@ -364,12 +406,17 @@ pub(crate) struct QueryRevisions {
     /// which must provide the initial provisional value and decide,
     /// after each iteration, whether the cycle has converged or must
     /// iterate again.
-    pub(super) cycle_heads: CycleHeads,
+    cycle_heads: CycleHeads,
 }

 #[cfg(not(feature = "shuttle"))]
 #[cfg(target_pointer_width = "64")]
-const _: [(); std::mem::size_of::<QueryRevisions>()] = [(); std::mem::size_of::<[usize; 9]>()];
+const _: [(); std::mem::size_of::<QueryRevisions>()] = [(); std::mem::size_of::<[usize; 4]>()];
+
+#[cfg(not(feature = "shuttle"))]
+#[cfg(target_pointer_width = "64")]
+const _: [(); std::mem::size_of::<QueryRevisionsExtraInner>()] =
+    [(); std::mem::size_of::<[usize; 9]>()];

 impl QueryRevisions {
     pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self {
@@ -377,14 +424,75 @@ impl QueryRevisions {
             changed_at: Revision::start(),
             durability: Durability::MAX,
             origin: QueryOrigin::fixpoint_initial(),
-            tracked_struct_ids: Default::default(),
-            accumulated: Default::default(),
             accumulated_inputs: Default::default(),
             verified_final: AtomicBool::new(false),
-            cycle_heads: CycleHeads::initial(query),
+            extra: QueryRevisionsExtra::new(
+                AccumulatedMap::default(),
+                IdentityMap::default(),
+                CycleHeads::initial(query),
+            ),
         }
     }
+
+    /// Returns a reference to the `AccumulatedMap` for this query, or `None` if the map is empty.
+    pub(crate) fn accumulated(&self) -> Option<&AccumulatedMap> {
+        self.extra
+            .0
+            .as_ref()
+            .map(|extra| &extra.accumulated)
+            .filter(|map| !map.is_empty())
+    }
+
+    /// Returns a reference to the `CycleHeads` for this query.
+    pub(crate) fn cycle_heads(&self) -> &CycleHeads {
+        match &self.extra.0 {
+            Some(extra) => &extra.cycle_heads,
+            None => empty_cycle_heads(),
+        }
+    }
+
+    /// Returns a mutable reference to the `CycleHeads` for this query, or `None` if the list is empty.
+    pub(crate) fn cycle_heads_mut(&mut self) -> Option<&mut CycleHeads> {
+        self.extra
+            .0
+            .as_mut()
+            .map(|extra| &mut extra.cycle_heads)
+            .filter(|cycle_heads| !cycle_heads.is_empty())
+    }
+
+    /// Sets the `CycleHeads` for this query.
+    pub(crate) fn set_cycle_heads(&mut self, cycle_heads: CycleHeads) {
+        match &mut self.extra.0 {
+            Some(extra) => extra.cycle_heads = cycle_heads,
+            None => {
+                self.extra = QueryRevisionsExtra::new(
+                    AccumulatedMap::default(),
+                    IdentityMap::default(),
+                    cycle_heads,
+                );
+            }
+        };
+    }
+
+    /// Returns a reference to the `IdentityMap` for this query, or `None` if the map is empty.
+    pub fn tracked_struct_ids(&self) -> Option<&IdentityMap> {
+        self.extra
+            .0
+            .as_ref()
+            .map(|extra| &extra.tracked_struct_ids)
+            .filter(|tracked_struct_ids| !tracked_struct_ids.is_empty())
+    }
+
+    /// Returns a mutable reference to the `IdentityMap` for this query, or `None` if the map is empty.
+    pub fn tracked_struct_ids_mut(&mut self) -> Option<&mut IdentityMap> {
+        self.extra
+            .0
+            .as_mut()
+            .map(|extra| &mut extra.tracked_struct_ids)
+            .filter(|tracked_struct_ids| !tracked_struct_ids.is_empty())
+    }
 }

 /// Tracks the way that a memoized value for a query was created.
 ///
 /// This is a read-only reference to a `PackedQueryOrigin`.
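Putting it together: the new surface collapses empty state to `None` at construction and filters empty collections out at the accessors, so call sites can treat "no extra allocated" and "extra allocated but empty" identically. A self-contained mock of that behavior (stand-in types, not the salsa API):

```rust
use std::collections::HashMap;

#[derive(Debug, Default)]
struct Extra {
    accumulated: Vec<u32>,                 // stand-in for AccumulatedMap
    tracked_struct_ids: HashMap<u32, u32>, // stand-in for IdentityMap
    cycle_heads: Vec<u32>,                 // stand-in for CycleHeads
}

#[derive(Debug, Default)]
struct RevisionsExtra(Option<Box<Extra>>);

impl RevisionsExtra {
    // Mirrors `QueryRevisionsExtra::new`: allocate only if something is non-empty.
    fn new(accumulated: Vec<u32>, ids: HashMap<u32, u32>, heads: Vec<u32>) -> Self {
        if accumulated.is_empty() && ids.is_empty() && heads.is_empty() {
            Self(None)
        } else {
            Self(Some(Box::new(Extra {
                accumulated,
                tracked_struct_ids: ids,
                cycle_heads: heads,
            })))
        }
    }

    // Mirrors the filtered accessors: empty maps read as `None`.
    fn accumulated(&self) -> Option<&Vec<u32>> {
        self.0
            .as_ref()
            .map(|extra| &extra.accumulated)
            .filter(|accumulated| !accumulated.is_empty())
    }
}

fn main() {
    // The common case: nothing to store, nothing allocated.
    let empty = RevisionsExtra::new(Vec::new(), HashMap::new(), Vec::new());
    assert!(empty.0.is_none());
    assert!(empty.accumulated().is_none());

    // Any non-empty input forces the single boxed allocation.
    let with_heads = RevisionsExtra::new(Vec::new(), HashMap::new(), vec![7]);
    assert!(with_heads.0.is_some());
    assert!(with_heads.accumulated().is_none()); // still filtered: the map is empty
}
```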