diff --git a/src/function/execute.rs b/src/function/execute.rs index d07bb45f..558ace73 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -170,7 +170,7 @@ where // Our provisional value from the previous iteration, when doing fixpoint iteration. // This is different from `opt_old_memo` which might be from a different revision. - let mut last_provisional_memo: Option<&Memo<'db, C>> = None; + let mut last_provisional_memo_opt: Option<&Memo<'db, C>> = None; // TODO: Can we seed those somehow? let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); @@ -194,7 +194,7 @@ where // Only use the last provisional memo if it was a cycle head in the last iteration. This is to // force at least two executions. if old_memo.cycle_heads().contains(&database_key_index) { - last_provisional_memo = Some(old_memo); + last_provisional_memo_opt = Some(old_memo); } iteration_count = old_memo.revisions.iteration(); @@ -219,7 +219,7 @@ where db, zalsa, active_query, - last_provisional_memo.or(opt_old_memo), + last_provisional_memo_opt.or(opt_old_memo), ); // Take the cycle heads to not-fight-rust's-borrow-checker. @@ -329,10 +329,7 @@ where // Get the last provisional value for this query so that we can compare it with the new value // to test if the cycle converged. - let last_provisional_value = if let Some(last_provisional) = last_provisional_memo { - // We have a last provisional value from our previous time around the loop. - last_provisional.value.as_ref() - } else { + let last_provisional_memo = last_provisional_memo_opt.unwrap_or_else(|| { // This is our first time around the loop; a provisional value must have been // inserted into the memo table when the cycle was hit, so let's pull our // initial provisional value from there. @@ -346,8 +343,10 @@ where }); debug_assert!(memo.may_be_provisional()); - memo.value.as_ref() - }; + memo + }); + + let last_provisional_value = last_provisional_memo.value.as_ref(); let last_provisional_value = last_provisional_value.expect( "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", @@ -389,9 +388,24 @@ where } } - let this_converged = C::values_equal(&new_value, last_provisional_value); let mut completed_query = active_query.pop(); + let value_converged = C::values_equal(&new_value, last_provisional_value); + + // It's important to force a re-execution of the cycle if `changed_at` or `durability` has changed + // to ensure the reduced durability and changed propagates to all queries depending on this head. + let metadata_converged = last_provisional_memo.revisions.durability + == completed_query.revisions.durability + && last_provisional_memo.revisions.changed_at + == completed_query.revisions.changed_at + && last_provisional_memo + .revisions + .origin + .is_derived_untracked() + == completed_query.revisions.origin.is_derived_untracked(); + + let this_converged = value_converged && metadata_converged; + if let Some(outer_cycle) = outer_cycle { tracing::info!( "Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}" @@ -494,7 +508,7 @@ where memo_ingredient_index, ); - last_provisional_memo = Some(new_memo); + last_provisional_memo_opt = Some(new_memo); last_stale_tracked_ids = completed_query.stale_tracked_structs; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index bde3b6b2..8f0239e5 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -934,6 +934,10 @@ impl QueryOrigin { } } + pub fn is_derived_untracked(&self) -> bool { + matches!(self.kind, QueryOriginKind::DerivedUntracked) + } + /// Create a query origin of type `QueryOriginKind::Derived`, with the given edges. pub fn derived(input_outputs: Box<[QueryEdge]>) -> QueryOrigin { // Exceeding `u32::MAX` query edges should never happen in real-world usage. diff --git a/tests/cycle_input_different_cycle_head.rs b/tests/cycle_input_different_cycle_head.rs new file mode 100644 index 00000000..d7f75143 --- /dev/null +++ b/tests/cycle_input_different_cycle_head.rs @@ -0,0 +1,72 @@ +#![cfg(feature = "inventory")] + +//! Tests that the durability correctly propagates +//! to all cycle heads. + +use salsa::Setter as _; + +#[test_log::test] +fn low_durability_cycle_enter_from_different_head() { + let mut db = MyDbImpl::default(); + // Start with 0, the same as returned by cycle initial + let input = Input::builder(0).new(&db); + db.input = Some(input); + + assert_eq!(query_a(&db), 0); // Prime the Db + + input.set_value(&mut db).to(10); + + assert_eq!(query_b(&db), 10); +} + +#[salsa::input] +struct Input { + value: u32, +} + +#[salsa::db] +trait MyDb: salsa::Database { + fn input(&self) -> Input; +} + +#[salsa::db] +#[derive(Clone, Default)] +struct MyDbImpl { + storage: salsa::Storage, + input: Option, +} + +#[salsa::db] +impl salsa::Database for MyDbImpl {} + +#[salsa::db] +impl MyDb for MyDbImpl { + fn input(&self) -> Input { + self.input.unwrap() + } +} + +#[salsa::tracked(cycle_initial=cycle_initial)] +fn query_a(db: &dyn MyDb) -> u32 { + query_b(db); + db.input().value(db) +} + +fn cycle_initial(_db: &dyn MyDb, _id: salsa::Id) -> u32 { + 0 +} + +#[salsa::interned] +struct Interned { + value: u32, +} + +#[salsa::tracked(cycle_initial=cycle_initial)] +fn query_b<'db>(db: &'db dyn MyDb) -> u32 { + query_c(db) +} + +#[salsa::tracked] +fn query_c(db: &dyn MyDb) -> u32 { + query_a(db) +}