Fix cycle head durability (#1024)
Some checks failed
Release-plz / Release-plz release (push) Has been cancelled
Release-plz / Release-plz PR (push) Has been cancelled
Test / Test (push) Has been cancelled
Book / Book (push) Has been cancelled
Test / Miri (push) Has been cancelled
Test / Shuttle (push) Has been cancelled
Test / Benchmarks (push) Has been cancelled
Book / Deploy (push) Has been cancelled

This commit is contained in:
Micha Reiser 2025-11-13 10:17:44 +01:00 committed by GitHub
parent 05a9af7f55
commit a885bb4c4c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 11 deletions

View file

@ -170,7 +170,7 @@ where
// Our provisional value from the previous iteration, when doing fixpoint iteration.
// This is different from `opt_old_memo` which might be from a different revision.
let mut last_provisional_memo: Option<&Memo<'db, C>> = None;
let mut last_provisional_memo_opt: Option<&Memo<'db, C>> = None;
// TODO: Can we seed those somehow?
let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new();
@ -194,7 +194,7 @@ where
// Only use the last provisional memo if it was a cycle head in the last iteration. This is to
// force at least two executions.
if old_memo.cycle_heads().contains(&database_key_index) {
last_provisional_memo = Some(old_memo);
last_provisional_memo_opt = Some(old_memo);
}
iteration_count = old_memo.revisions.iteration();
@ -219,7 +219,7 @@ where
db,
zalsa,
active_query,
last_provisional_memo.or(opt_old_memo),
last_provisional_memo_opt.or(opt_old_memo),
);
// Take the cycle heads to not-fight-rust's-borrow-checker.
@ -329,10 +329,7 @@ where
// Get the last provisional value for this query so that we can compare it with the new value
// to test if the cycle converged.
let last_provisional_value = if let Some(last_provisional) = last_provisional_memo {
// We have a last provisional value from our previous time around the loop.
last_provisional.value.as_ref()
} else {
let last_provisional_memo = last_provisional_memo_opt.unwrap_or_else(|| {
// This is our first time around the loop; a provisional value must have been
// inserted into the memo table when the cycle was hit, so let's pull our
// initial provisional value from there.
@ -346,8 +343,10 @@ where
});
debug_assert!(memo.may_be_provisional());
memo.value.as_ref()
};
memo
});
let last_provisional_value = last_provisional_memo.value.as_ref();
let last_provisional_value = last_provisional_value.expect(
"`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial",
@ -389,9 +388,24 @@ where
}
}
let this_converged = C::values_equal(&new_value, last_provisional_value);
let mut completed_query = active_query.pop();
let value_converged = C::values_equal(&new_value, last_provisional_value);
// It's important to force a re-execution of the cycle if `changed_at` or `durability` has changed
// to ensure the reduced durability and changed propagates to all queries depending on this head.
let metadata_converged = last_provisional_memo.revisions.durability
== completed_query.revisions.durability
&& last_provisional_memo.revisions.changed_at
== completed_query.revisions.changed_at
&& last_provisional_memo
.revisions
.origin
.is_derived_untracked()
== completed_query.revisions.origin.is_derived_untracked();
let this_converged = value_converged && metadata_converged;
if let Some(outer_cycle) = outer_cycle {
tracing::info!(
"Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}"
@ -494,7 +508,7 @@ where
memo_ingredient_index,
);
last_provisional_memo = Some(new_memo);
last_provisional_memo_opt = Some(new_memo);
last_stale_tracked_ids = completed_query.stale_tracked_structs;

View file

@ -934,6 +934,10 @@ impl QueryOrigin {
}
}
pub fn is_derived_untracked(&self) -> bool {
matches!(self.kind, QueryOriginKind::DerivedUntracked)
}
/// Create a query origin of type `QueryOriginKind::Derived`, with the given edges.
pub fn derived(input_outputs: Box<[QueryEdge]>) -> QueryOrigin {
// Exceeding `u32::MAX` query edges should never happen in real-world usage.

View file

@ -0,0 +1,72 @@
#![cfg(feature = "inventory")]
//! Tests that the durability correctly propagates
//! to all cycle heads.
use salsa::Setter as _;
#[test_log::test]
fn low_durability_cycle_enter_from_different_head() {
let mut db = MyDbImpl::default();
// Start with 0, the same as returned by cycle initial
let input = Input::builder(0).new(&db);
db.input = Some(input);
assert_eq!(query_a(&db), 0); // Prime the Db
input.set_value(&mut db).to(10);
assert_eq!(query_b(&db), 10);
}
#[salsa::input]
struct Input {
value: u32,
}
#[salsa::db]
trait MyDb: salsa::Database {
fn input(&self) -> Input;
}
#[salsa::db]
#[derive(Clone, Default)]
struct MyDbImpl {
storage: salsa::Storage<Self>,
input: Option<Input>,
}
#[salsa::db]
impl salsa::Database for MyDbImpl {}
#[salsa::db]
impl MyDb for MyDbImpl {
fn input(&self) -> Input {
self.input.unwrap()
}
}
#[salsa::tracked(cycle_initial=cycle_initial)]
fn query_a(db: &dyn MyDb) -> u32 {
query_b(db);
db.input().value(db)
}
fn cycle_initial(_db: &dyn MyDb, _id: salsa::Id) -> u32 {
0
}
#[salsa::interned]
struct Interned {
value: u32,
}
#[salsa::tracked(cycle_initial=cycle_initial)]
fn query_b<'db>(db: &'db dyn MyDb) -> u32 {
query_c(db)
}
#[salsa::tracked]
fn query_c(db: &dyn MyDb) -> u32 {
query_a(db)
}