Merge commit 'ddf105b646' into sync-from-ra

This commit is contained in:
Laurențiu Nicola 2024-02-11 08:40:19 +02:00
parent 0816d49d83
commit e41ab350d6
378 changed files with 14720 additions and 3111 deletions

View file

@ -0,0 +1,493 @@
use std::panic::UnwindSafe;
use expect_test::expect;
use salsa::{Durability, ParallelDatabase, Snapshot};
use test_log::test;
// Axes:
//
// Threading
// * Intra-thread
// * Cross-thread -- part of cycle is on one thread, part on another
//
// Recovery strategies:
// * Panic
// * Fallback
// * Mixed -- multiple strategies within cycle participants
//
// Across revisions:
// * N/A -- only one revision
// * Present in new revision, not old
// * Present in old revision, not new
// * Present in both revisions
//
// Dependencies
// * Tracked
// * Untracked -- cycle participant(s) contain untracked reads
//
// Layers
// * Direct -- cycle participant is directly invoked from test
// * Indirect -- invoked a query that invokes the cycle
//
//
// | Thread | Recovery | Old, New | Dep style | Layers | Test Name |
// | ------ | -------- | -------- | --------- | ------ | --------- |
// | Intra | Panic | N/A | Tracked | direct | cycle_memoized |
// | Intra | Panic | N/A | Untracked | direct | cycle_volatile |
// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle |
// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle |
// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate |
// | Intra | Fallback | New | Tracked | direct | cycle_appears |
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears |
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability |
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 |
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 |
// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle |
// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle |
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
struct Error {
cycle: Vec<String>,
}
#[salsa::database(GroupStruct)]
#[derive(Default)]
struct DatabaseImpl {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseImpl {}
impl ParallelDatabase for DatabaseImpl {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(DatabaseImpl { storage: self.storage.snapshot() })
}
}
/// The queries A, B, and C in `Database` can be configured
/// to invoke one another in arbitrary ways using this
/// enum.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CycleQuery {
None,
A,
B,
C,
AthenC,
}
#[salsa::query_group(GroupStruct)]
trait Database: salsa::Database {
// `a` and `b` depend on each other and form a cycle
fn memoized_a(&self) -> ();
fn memoized_b(&self) -> ();
fn volatile_a(&self) -> ();
fn volatile_b(&self) -> ();
#[salsa::input]
fn a_invokes(&self) -> CycleQuery;
#[salsa::input]
fn b_invokes(&self) -> CycleQuery;
#[salsa::input]
fn c_invokes(&self) -> CycleQuery;
#[salsa::cycle(recover_a)]
fn cycle_a(&self) -> Result<(), Error>;
#[salsa::cycle(recover_b)]
fn cycle_b(&self) -> Result<(), Error>;
fn cycle_c(&self) -> Result<(), Error>;
}
fn recover_a(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
Err(Error { cycle: cycle.all_participants(db) })
}
fn recover_b(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
Err(Error { cycle: cycle.all_participants(db) })
}
fn memoized_a(db: &dyn Database) {
db.memoized_b()
}
fn memoized_b(db: &dyn Database) {
db.memoized_a()
}
fn volatile_a(db: &dyn Database) {
db.salsa_runtime().report_untracked_read();
db.volatile_b()
}
fn volatile_b(db: &dyn Database) {
db.salsa_runtime().report_untracked_read();
db.volatile_a()
}
impl CycleQuery {
fn invoke(self, db: &dyn Database) -> Result<(), Error> {
match self {
CycleQuery::A => db.cycle_a(),
CycleQuery::B => db.cycle_b(),
CycleQuery::C => db.cycle_c(),
CycleQuery::AthenC => {
let _ = db.cycle_a();
db.cycle_c()
}
CycleQuery::None => Ok(()),
}
}
}
fn cycle_a(db: &dyn Database) -> Result<(), Error> {
db.a_invokes().invoke(db)
}
fn cycle_b(db: &dyn Database) -> Result<(), Error> {
db.b_invokes().invoke(db)
}
fn cycle_c(db: &dyn Database) -> Result<(), Error> {
db.c_invokes().invoke(db)
}
#[track_caller]
fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle {
let v = std::panic::catch_unwind(f);
if let Err(d) = &v {
if let Some(cycle) = d.downcast_ref::<salsa::Cycle>() {
return cycle.clone();
}
}
panic!("unexpected value: {:?}", v)
}
#[test]
fn cycle_memoized() {
let db = DatabaseImpl::default();
let cycle = extract_cycle(|| db.memoized_a());
expect![[r#"
[
"memoized_a(())",
"memoized_b(())",
]
"#]]
.assert_debug_eq(&cycle.unexpected_participants(&db));
}
#[test]
fn cycle_volatile() {
let db = DatabaseImpl::default();
let cycle = extract_cycle(|| db.volatile_a());
expect![[r#"
[
"volatile_a(())",
"volatile_b(())",
]
"#]]
.assert_debug_eq(&cycle.unexpected_participants(&db));
}
#[test]
fn cycle_cycle() {
let mut query = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
query.set_a_invokes(CycleQuery::B);
query.set_b_invokes(CycleQuery::A);
assert!(query.cycle_a().is_err());
}
#[test]
fn inner_cycle() {
let mut query = DatabaseImpl::default();
// A --> B <-- C
// ^ |
// +-----+
query.set_a_invokes(CycleQuery::B);
query.set_b_invokes(CycleQuery::A);
query.set_c_invokes(CycleQuery::B);
let err = query.cycle_c();
assert!(err.is_err());
let cycle = err.unwrap_err().cycle;
expect![[r#"
[
"cycle_a(())",
"cycle_b(())",
]
"#]]
.assert_debug_eq(&cycle);
}
#[test]
fn cycle_revalidate() {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
assert!(db.cycle_a().is_err());
db.set_b_invokes(CycleQuery::A); // same value as default
assert!(db.cycle_a().is_err());
}
#[test]
fn cycle_revalidate_unchanged_twice() {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
assert!(db.cycle_a().is_err());
db.set_c_invokes(CycleQuery::A); // force new revisi5on
// on this run
expect![[r#"
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
)
"#]]
.assert_debug_eq(&db.cycle_a());
}
#[test]
fn cycle_appears() {
let mut db = DatabaseImpl::default();
// A --> B
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::None);
assert!(db.cycle_a().is_ok());
// A --> B
// ^ |
// +-----+
db.set_b_invokes(CycleQuery::A);
tracing::debug!("Set Cycle Leaf");
assert!(db.cycle_a().is_err());
}
#[test]
fn cycle_disappears() {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
assert!(db.cycle_a().is_err());
// A --> B
db.set_b_invokes(CycleQuery::None);
assert!(db.cycle_a().is_ok());
}
/// A variant on `cycle_disappears` in which the values of
/// `a_invokes` and `b_invokes` are set with durability values.
/// If we are not careful, this could cause us to overlook
/// the fact that the cycle will no longer occur.
#[test]
fn cycle_disappears_durability() {
let mut db = DatabaseImpl::default();
db.set_a_invokes_with_durability(CycleQuery::B, Durability::LOW);
db.set_b_invokes_with_durability(CycleQuery::A, Durability::HIGH);
let res = db.cycle_a();
assert!(res.is_err());
// At this point, `a` read `LOW` input, and `b` read `HIGH` input. However,
// because `b` participates in the same cycle as `a`, its final durability
// should be `LOW`.
//
// Check that setting a `LOW` input causes us to re-execute `b` query, and
// observe that the cycle goes away.
db.set_a_invokes_with_durability(CycleQuery::None, Durability::LOW);
let res = db.cycle_b();
assert!(res.is_ok());
}
#[test]
fn cycle_mixed_1() {
let mut db = DatabaseImpl::default();
// A --> B <-- C
// | ^
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::C);
db.set_c_invokes(CycleQuery::B);
let u = db.cycle_c();
expect![[r#"
Err(
Error {
cycle: [
"cycle_b(())",
"cycle_c(())",
],
},
)
"#]]
.assert_debug_eq(&u);
}
#[test]
fn cycle_mixed_2() {
let mut db = DatabaseImpl::default();
// Configuration:
//
// A --> B --> C
// ^ |
// +-----------+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::C);
db.set_c_invokes(CycleQuery::A);
let u = db.cycle_a();
expect![[r#"
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
"cycle_c(())",
],
},
)
"#]]
.assert_debug_eq(&u);
}
#[test]
fn cycle_deterministic_order() {
// No matter whether we start from A or B, we get the same set of participants:
let db = || {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
db
};
let a = db().cycle_a();
let b = db().cycle_b();
expect![[r#"
(
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
)
"#]]
.assert_debug_eq(&(a, b));
}
#[test]
fn cycle_multiple() {
// No matter whether we start from A or B, we get the same set of participants:
let mut db = DatabaseImpl::default();
// Configuration:
//
// A --> B <-- C
// ^ | ^
// +-----+ |
// | |
// +-----+
//
// Here, conceptually, B encounters a cycle with A and then
// recovers.
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::AthenC);
db.set_c_invokes(CycleQuery::B);
let c = db.cycle_c();
let b = db.cycle_b();
let a = db.cycle_a();
expect![[r#"
(
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
)
"#]]
.assert_debug_eq(&(a, b, c));
}
#[test]
fn cycle_recovery_set_but_not_participating() {
let mut db = DatabaseImpl::default();
// A --> C -+
// ^ |
// +--+
db.set_a_invokes(CycleQuery::C);
db.set_c_invokes(CycleQuery::C);
// Here we expect C to panic and A not to recover:
let r = extract_cycle(|| drop(db.cycle_a()));
expect![[r#"
[
"cycle_c(())",
]
"#]]
.assert_debug_eq(&r.all_participants(&db));
}

View file

@ -0,0 +1,28 @@
//! Test that you can implement a query using a `dyn Trait` setup.
#[salsa::database(DynTraitStorage)]
#[derive(Default)]
struct DynTraitDatabase {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DynTraitDatabase {}
#[salsa::query_group(DynTraitStorage)]
trait DynTrait {
#[salsa::input]
fn input(&self, x: u32) -> u32;
fn output(&self, x: u32) -> u32;
}
fn output(db: &dyn DynTrait, x: u32) -> u32 {
db.input(x) * 2
}
#[test]
fn dyn_trait() {
let mut query = DynTraitDatabase::default();
query.set_input(22, 23);
assert_eq!(query.output(22), 46);
}

View file

@ -0,0 +1,145 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::debug::DebugQueryTable;
use salsa::Durability;
#[salsa::query_group(Constants)]
pub(crate) trait ConstantsDatabase: TestContext {
#[salsa::input]
fn input(&self, key: char) -> usize;
fn add(&self, key1: char, key2: char) -> usize;
fn add3(&self, key1: char, key2: char, key3: char) -> usize;
}
fn add(db: &dyn ConstantsDatabase, key1: char, key2: char) -> usize {
db.log().add(format!("add({}, {})", key1, key2));
db.input(key1) + db.input(key2)
}
fn add3(db: &dyn ConstantsDatabase, key1: char, key2: char, key3: char) -> usize {
db.log().add(format!("add3({}, {}, {})", key1, key2, key3));
db.add(key1, key2) + db.input(key3)
}
// Test we can assign a constant and things will be correctly
// recomputed afterwards.
#[test]
fn invalidate_constant() {
let db = &mut TestContextImpl::default();
db.set_input_with_durability('a', 44, Durability::HIGH);
db.set_input_with_durability('b', 22, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 66);
db.set_input_with_durability('a', 66, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 88);
}
#[test]
fn invalidate_constant_1() {
let db = &mut TestContextImpl::default();
// Not constant:
db.set_input('a', 44);
assert_eq!(db.add('a', 'a'), 88);
// Becomes constant:
db.set_input_with_durability('a', 44, Durability::HIGH);
assert_eq!(db.add('a', 'a'), 88);
// Invalidates:
db.set_input_with_durability('a', 33, Durability::HIGH);
assert_eq!(db.add('a', 'a'), 66);
}
// Test cases where we assign same value to 'a' after declaring it a
// constant.
#[test]
fn set_after_constant_same_value() {
let db = &mut TestContextImpl::default();
db.set_input_with_durability('a', 44, Durability::HIGH);
db.set_input_with_durability('a', 44, Durability::HIGH);
db.set_input('a', 44);
}
#[test]
fn not_constant() {
let mut db = TestContextImpl::default();
db.set_input('a', 22);
db.set_input('b', 44);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
}
#[test]
fn durability() {
let mut db = TestContextImpl::default();
db.set_input_with_durability('a', 22, Durability::HIGH);
db.set_input_with_durability('b', 44, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
}
#[test]
fn mixed_constant() {
let mut db = TestContextImpl::default();
db.set_input_with_durability('a', 22, Durability::HIGH);
db.set_input('b', 44);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
}
#[test]
fn becomes_constant_with_change() {
let mut db = TestContextImpl::default();
db.set_input('a', 22);
db.set_input('b', 44);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
db.set_input_with_durability('a', 23, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 67);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
db.set_input_with_durability('b', 45, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 68);
assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
db.set_input_with_durability('b', 45, Durability::MEDIUM);
assert_eq!(db.add('a', 'b'), 68);
assert_eq!(Durability::MEDIUM, AddQuery.in_db(&db).durability(('a', 'b')));
}
// Test a subtle case in which an input changes from constant to
// non-constant, but its value doesn't change. If we're not careful,
// this can cause us to incorrectly consider derived values as still
// being constant.
#[test]
fn constant_to_non_constant() {
let mut db = TestContextImpl::default();
db.set_input_with_durability('a', 11, Durability::HIGH);
db.set_input_with_durability('b', 22, Durability::HIGH);
db.set_input_with_durability('c', 33, Durability::HIGH);
// Here, `add3` invokes `add`, which yields 33. Both calls are
// constant.
assert_eq!(db.add3('a', 'b', 'c'), 66);
db.set_input('a', 11);
// Here, `add3` invokes `add`, which *still* yields 33, but which
// is no longer constant. Since value didn't change, we might
// preserve `add3` unchanged, not noticing that it is no longer
// constant.
assert_eq!(db.add3('a', 'b', 'c'), 66);
// In that case, we would not get the correct result here, when
// 'a' changes *again*.
db.set_input('a', 22);
assert_eq!(db.add3('a', 'b', 'c'), 77);
}

View file

@ -0,0 +1,14 @@
use std::cell::Cell;
#[derive(Default)]
pub(crate) struct Counter {
value: Cell<usize>,
}
impl Counter {
pub(crate) fn increment(&self) -> usize {
let v = self.value.get();
self.value.set(v + 1);
v
}
}

View file

@ -0,0 +1,59 @@
use crate::constants;
use crate::counter::Counter;
use crate::log::Log;
use crate::memoized_dep_inputs;
use crate::memoized_inputs;
use crate::memoized_volatile;
pub(crate) trait TestContext: salsa::Database {
fn clock(&self) -> &Counter;
fn log(&self) -> &Log;
}
#[salsa::database(
constants::Constants,
memoized_dep_inputs::MemoizedDepInputs,
memoized_inputs::MemoizedInputs,
memoized_volatile::MemoizedVolatile
)]
#[derive(Default)]
pub(crate) struct TestContextImpl {
storage: salsa::Storage<TestContextImpl>,
clock: Counter,
log: Log,
}
impl TestContextImpl {
#[track_caller]
pub(crate) fn assert_log(&self, expected_log: &[&str]) {
let expected_text = &format!("{:#?}", expected_log);
let actual_text = &format!("{:#?}", self.log().take());
if expected_text == actual_text {
return;
}
#[allow(clippy::print_stdout)]
for diff in dissimilar::diff(expected_text, actual_text) {
match diff {
dissimilar::Chunk::Delete(l) => println!("-{}", l),
dissimilar::Chunk::Equal(l) => println!(" {}", l),
dissimilar::Chunk::Insert(r) => println!("+{}", r),
}
}
panic!("incorrect log results");
}
}
impl TestContext for TestContextImpl {
fn clock(&self) -> &Counter {
&self.clock
}
fn log(&self) -> &Log {
&self.log
}
}
impl salsa::Database for TestContextImpl {}

View file

@ -0,0 +1,16 @@
use std::cell::RefCell;
#[derive(Default)]
pub(crate) struct Log {
data: RefCell<Vec<String>>,
}
impl Log {
pub(crate) fn add(&self, text: impl Into<String>) {
self.data.borrow_mut().push(text.into());
}
pub(crate) fn take(&self) -> Vec<String> {
self.data.take()
}
}

View file

@ -0,0 +1,9 @@
mod constants;
mod counter;
mod implementation;
mod log;
mod memoized_dep_inputs;
mod memoized_inputs;
mod memoized_volatile;
fn main() {}

View file

@ -0,0 +1,60 @@
use crate::implementation::{TestContext, TestContextImpl};
#[salsa::query_group(MemoizedDepInputs)]
pub(crate) trait MemoizedDepInputsContext: TestContext {
fn dep_memoized2(&self) -> usize;
fn dep_memoized1(&self) -> usize;
#[salsa::dependencies]
fn dep_derived1(&self) -> usize;
#[salsa::input]
fn dep_input1(&self) -> usize;
#[salsa::input]
fn dep_input2(&self) -> usize;
}
fn dep_memoized2(db: &dyn MemoizedDepInputsContext) -> usize {
db.log().add("Memoized2 invoked");
db.dep_memoized1()
}
fn dep_memoized1(db: &dyn MemoizedDepInputsContext) -> usize {
db.log().add("Memoized1 invoked");
db.dep_derived1() * 2
}
fn dep_derived1(db: &dyn MemoizedDepInputsContext) -> usize {
db.log().add("Derived1 invoked");
db.dep_input1() / 2
}
#[test]
fn revalidate() {
let db = &mut TestContextImpl::default();
db.set_dep_input1(0);
// Initial run starts from Memoized2:
let v = db.dep_memoized2();
assert_eq!(v, 0);
db.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Derived1 invoked"]);
// After that, we first try to validate Memoized1 but wind up
// running Memoized2. Note that we don't try to validate
// Derived1, so it is invoked by Memoized1.
db.set_dep_input1(44);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]);
// Here validation of Memoized1 succeeds so Memoized2 never runs.
db.set_dep_input1(45);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]);
// Here, a change to input2 doesn't affect us, so nothing runs.
db.set_dep_input2(45);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&[]);
}

View file

@ -0,0 +1,76 @@
use crate::implementation::{TestContext, TestContextImpl};
#[salsa::query_group(MemoizedInputs)]
pub(crate) trait MemoizedInputsContext: TestContext {
fn max(&self) -> usize;
#[salsa::input]
fn input1(&self) -> usize;
#[salsa::input]
fn input2(&self) -> usize;
}
fn max(db: &dyn MemoizedInputsContext) -> usize {
db.log().add("Max invoked");
std::cmp::max(db.input1(), db.input2())
}
#[test]
fn revalidate() {
let db = &mut TestContextImpl::default();
db.set_input1(0);
db.set_input2(0);
let v = db.max();
assert_eq!(v, 0);
db.assert_log(&["Max invoked"]);
let v = db.max();
assert_eq!(v, 0);
db.assert_log(&[]);
db.set_input1(44);
db.assert_log(&[]);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&[]);
db.set_input1(44);
db.assert_log(&[]);
db.set_input2(66);
db.assert_log(&[]);
db.set_input1(64);
db.assert_log(&[]);
let v = db.max();
assert_eq!(v, 66);
db.assert_log(&["Max invoked"]);
let v = db.max();
assert_eq!(v, 66);
db.assert_log(&[]);
}
/// Test that invoking `set` on an input with the same value still
/// triggers a new revision.
#[test]
fn set_after_no_change() {
let db = &mut TestContextImpl::default();
db.set_input2(0);
db.set_input1(44);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);
db.set_input1(44);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);
}

View file

@ -0,0 +1,77 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::{Database, Durability};
#[salsa::query_group(MemoizedVolatile)]
pub(crate) trait MemoizedVolatileContext: TestContext {
// Queries for testing a "volatile" value wrapped by
// memoization.
fn memoized2(&self) -> usize;
fn memoized1(&self) -> usize;
fn volatile(&self) -> usize;
}
fn memoized2(db: &dyn MemoizedVolatileContext) -> usize {
db.log().add("Memoized2 invoked");
db.memoized1()
}
fn memoized1(db: &dyn MemoizedVolatileContext) -> usize {
db.log().add("Memoized1 invoked");
let v = db.volatile();
v / 2
}
fn volatile(db: &dyn MemoizedVolatileContext) -> usize {
db.log().add("Volatile invoked");
db.salsa_runtime().report_untracked_read();
db.clock().increment()
}
#[test]
fn volatile_x2() {
let query = TestContextImpl::default();
// Invoking volatile twice doesn't execute twice, because volatile
// queries are memoized by default.
query.volatile();
query.volatile();
query.assert_log(&["Volatile invoked"]);
}
/// Test that:
///
/// - On the first run of R0, we recompute everything.
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result
/// did not change).
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change).
#[test]
fn revalidate() {
let mut query = TestContextImpl::default();
query.memoized2();
query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]);
query.memoized2();
query.assert_log(&[]);
// Second generation: volatile will change (to 1) but memoized1
// will not (still 0, as 1/2 = 0)
query.salsa_runtime_mut().synthetic_write(Durability::LOW);
query.memoized2();
query.assert_log(&["Volatile invoked", "Memoized1 invoked"]);
query.memoized2();
query.assert_log(&[]);
// Third generation: volatile will change (to 2) and memoized1
// will too (to 1). Therefore, after validating that Memoized1
// changed, we now invoke Memoized2.
query.salsa_runtime_mut().synthetic_write(Durability::LOW);
query.memoized2();
query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]);
query.memoized2();
query.assert_log(&[]);
}

View file

@ -0,0 +1,90 @@
//! Test that you can implement a query using a `dyn Trait` setup.
use salsa::InternId;
#[salsa::database(InternStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {}
impl salsa::ParallelDatabase for Database {
fn snapshot(&self) -> salsa::Snapshot<Self> {
salsa::Snapshot::new(Database { storage: self.storage.snapshot() })
}
}
#[salsa::query_group(InternStorage)]
trait Intern {
#[salsa::interned]
fn intern1(&self, x: String) -> InternId;
#[salsa::interned]
fn intern2(&self, x: String, y: String) -> InternId;
#[salsa::interned]
fn intern_key(&self, x: String) -> InternKey;
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct InternKey(InternId);
impl salsa::InternKey for InternKey {
fn from_intern_id(v: InternId) -> Self {
InternKey(v)
}
fn as_intern_id(&self) -> InternId {
self.0
}
}
#[test]
fn test_intern1() {
let db = Database::default();
let foo0 = db.intern1("foo".to_owned());
let bar0 = db.intern1("bar".to_owned());
let foo1 = db.intern1("foo".to_owned());
let bar1 = db.intern1("bar".to_owned());
assert_eq!(foo0, foo1);
assert_eq!(bar0, bar1);
assert_ne!(foo0, bar0);
assert_eq!("foo".to_owned(), db.lookup_intern1(foo0));
assert_eq!("bar".to_owned(), db.lookup_intern1(bar0));
}
#[test]
fn test_intern2() {
let db = Database::default();
let foo0 = db.intern2("x".to_owned(), "foo".to_owned());
let bar0 = db.intern2("x".to_owned(), "bar".to_owned());
let foo1 = db.intern2("x".to_owned(), "foo".to_owned());
let bar1 = db.intern2("x".to_owned(), "bar".to_owned());
assert_eq!(foo0, foo1);
assert_eq!(bar0, bar1);
assert_ne!(foo0, bar0);
assert_eq!(("x".to_owned(), "foo".to_owned()), db.lookup_intern2(foo0));
assert_eq!(("x".to_owned(), "bar".to_owned()), db.lookup_intern2(bar0));
}
#[test]
fn test_intern_key() {
let db = Database::default();
let foo0 = db.intern_key("foo".to_owned());
let bar0 = db.intern_key("bar".to_owned());
let foo1 = db.intern_key("foo".to_owned());
let bar1 = db.intern_key("bar".to_owned());
assert_eq!(foo0, foo1);
assert_eq!(bar0, bar1);
assert_ne!(foo0, bar0);
assert_eq!("foo".to_owned(), db.lookup_intern_key(foo0));
assert_eq!("bar".to_owned(), db.lookup_intern_key(bar0));
}

102
crates/salsa/tests/lru.rs Normal file
View file

@ -0,0 +1,102 @@
//! Test setting LRU actually limits the number of things in the database;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
#[derive(Debug, PartialEq, Eq)]
struct HotPotato(u32);
static N_POTATOES: AtomicUsize = AtomicUsize::new(0);
impl HotPotato {
fn new(id: u32) -> HotPotato {
N_POTATOES.fetch_add(1, Ordering::SeqCst);
HotPotato(id)
}
}
impl Drop for HotPotato {
fn drop(&mut self) {
N_POTATOES.fetch_sub(1, Ordering::SeqCst);
}
}
#[salsa::query_group(QueryGroupStorage)]
trait QueryGroup: salsa::Database {
fn get(&self, x: u32) -> Arc<HotPotato>;
fn get_volatile(&self, x: u32) -> usize;
}
fn get(_db: &dyn QueryGroup, x: u32) -> Arc<HotPotato> {
Arc::new(HotPotato::new(x))
}
fn get_volatile(db: &dyn QueryGroup, _x: u32) -> usize {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
db.salsa_runtime().report_untracked_read();
COUNTER.fetch_add(1, Ordering::SeqCst)
}
#[salsa::database(QueryGroupStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {}
#[test]
fn lru_works() {
let mut db = Database::default();
GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
GetQuery.in_db_mut(&mut db).set_lru_capacity(64);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
// Special case: setting capacity to zero disables LRU
GetQuery.in_db_mut(&mut db).set_lru_capacity(0);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128);
drop(db);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
}
#[test]
fn lru_doesnt_break_volatile_queries() {
let mut db = Database::default();
GetVolatileQuery.in_db_mut(&mut db).set_lru_capacity(32);
// Here, we check that we execute each volatile query at most once, despite
// LRU. That does mean that we have more values in DB than the LRU capacity,
// but it's much better than inconsistent results from volatile queries!
for i in (0..3).flat_map(|_| 0..128usize) {
let x = db.get_volatile(i as u32);
assert_eq!(x, i)
}
}

View file

@ -0,0 +1,11 @@
#[salsa::query_group(MyStruct)]
trait MyDatabase: salsa::Database {
#[salsa::invoke(another_module::another_name)]
fn my_query(&self, key: ()) -> ();
}
mod another_module {
pub(crate) fn another_name(_: &dyn crate::MyDatabase, (): ()) {}
}
fn main() {}

View file

@ -0,0 +1,31 @@
use std::rc::Rc;
#[salsa::query_group(NoSendSyncStorage)]
trait NoSendSyncDatabase: salsa::Database {
fn no_send_sync_value(&self, key: bool) -> Rc<bool>;
fn no_send_sync_key(&self, key: Rc<bool>) -> bool;
}
fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc<bool> {
Rc::new(key)
}
fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc<bool>) -> bool {
*key
}
#[salsa::database(NoSendSyncStorage)]
#[derive(Default)]
struct DatabaseImpl {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseImpl {}
#[test]
fn no_send_sync() {
let db = DatabaseImpl::default();
assert_eq!(db.no_send_sync_value(true), Rc::new(true));
assert!(!db.no_send_sync_key(Rc::new(false)));
}

View file

@ -0,0 +1,147 @@
//! Test that "on-demand" input pattern works.
//!
//! On-demand inputs are inputs computed lazily on the fly. They are simulated
//! via a b query with zero inputs, which uses `add_synthetic_read` to
//! tweak durability and `invalidate` to clear the input.
#![allow(clippy::disallowed_types, clippy::type_complexity)]
use std::{cell::RefCell, collections::HashMap, rc::Rc};
use salsa::{Database as _, Durability, EventKind};
#[salsa::query_group(QueryGroupStorage)]
trait QueryGroup: salsa::Database + AsRef<HashMap<u32, u32>> {
fn a(&self, x: u32) -> u32;
fn b(&self, x: u32) -> u32;
fn c(&self, x: u32) -> u32;
}
fn a(db: &dyn QueryGroup, x: u32) -> u32 {
let durability = if x % 2 == 0 { Durability::LOW } else { Durability::HIGH };
db.salsa_runtime().report_synthetic_read(durability);
let external_state: &HashMap<u32, u32> = db.as_ref();
external_state[&x]
}
fn b(db: &dyn QueryGroup, x: u32) -> u32 {
db.a(x)
}
fn c(db: &dyn QueryGroup, x: u32) -> u32 {
db.b(x)
}
#[salsa::database(QueryGroupStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
external_state: HashMap<u32, u32>,
on_event: Option<Box<dyn Fn(&Database, salsa::Event)>>,
}
impl salsa::Database for Database {
fn salsa_event(&self, event: salsa::Event) {
if let Some(cb) = &self.on_event {
cb(self, event)
}
}
}
impl AsRef<HashMap<u32, u32>> for Database {
fn as_ref(&self) -> &HashMap<u32, u32> {
&self.external_state
}
}
#[test]
fn on_demand_input_works() {
let mut db = Database::default();
db.external_state.insert(1, 10);
assert_eq!(db.b(1), 10);
assert_eq!(db.a(1), 10);
// We changed external state, but haven't signaled about this yet,
// so we expect to see the old answer
db.external_state.insert(1, 92);
assert_eq!(db.b(1), 10);
assert_eq!(db.a(1), 10);
AQuery.in_db_mut(&mut db).invalidate(&1);
assert_eq!(db.b(1), 92);
assert_eq!(db.a(1), 92);
// Downstream queries should also be rerun if we call `a` first.
db.external_state.insert(1, 50);
AQuery.in_db_mut(&mut db).invalidate(&1);
assert_eq!(db.a(1), 50);
assert_eq!(db.b(1), 50);
}
#[test]
fn on_demand_input_durability() {
let mut db = Database::default();
let events = Rc::new(RefCell::new(vec![]));
db.on_event = Some(Box::new({
let events = events.clone();
move |db, event| {
if let EventKind::WillCheckCancellation = event.kind {
// these events are not interesting
} else {
events.borrow_mut().push(format!("{:?}", event.debug(db)))
}
}
}));
events.replace(vec![]);
db.external_state.insert(1, 10);
db.external_state.insert(2, 20);
assert_eq!(db.b(1), 10);
assert_eq!(db.b(2), 20);
expect_test::expect![[r#"
RefCell {
value: [
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
],
}
"#]].assert_debug_eq(&events);
db.salsa_runtime_mut().synthetic_write(Durability::LOW);
events.replace(vec![]);
assert_eq!(db.c(1), 10);
assert_eq!(db.c(2), 20);
// Re-execute `a(2)` because that has low durability, but not `a(1)`
expect_test::expect![[r#"
RefCell {
value: [
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }",
],
}
"#]].assert_debug_eq(&events);
db.salsa_runtime_mut().synthetic_write(Durability::HIGH);
events.replace(vec![]);
assert_eq!(db.c(1), 10);
assert_eq!(db.c(2), 20);
// Re-execute both `a(1)` and `a(2)`, but we don't re-execute any `b` queries as the
// result didn't actually change.
expect_test::expect![[r#"
RefCell {
value: [
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }",
],
}
"#]].assert_debug_eq(&events);
}

View file

@ -0,0 +1,93 @@
use salsa::{Database, ParallelDatabase, Snapshot};
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
#[salsa::query_group(PanicSafelyStruct)]
trait PanicSafelyDatabase: salsa::Database {
#[salsa::input]
fn one(&self) -> usize;
fn panic_safely(&self) -> ();
fn outer(&self) -> ();
}
fn panic_safely(db: &dyn PanicSafelyDatabase) {
assert_eq!(db.one(), 1);
}
static OUTER_CALLS: AtomicU32 = AtomicU32::new(0);
fn outer(db: &dyn PanicSafelyDatabase) {
OUTER_CALLS.fetch_add(1, SeqCst);
db.panic_safely();
}
#[salsa::database(PanicSafelyStruct)]
#[derive(Default)]
struct DatabaseStruct {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseStruct {}
impl salsa::ParallelDatabase for DatabaseStruct {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(DatabaseStruct { storage: self.storage.snapshot() })
}
}
#[test]
fn should_panic_safely() {
let mut db = DatabaseStruct::default();
db.set_one(0);
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
// return 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe({
let db = db.snapshot();
move || db.panic_safely()
}));
assert!(result.is_err());
// Set `db.one` to 1 and assert ok
db.set_one(1);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_ok());
// Check, that memoized outer is not invalidated by a panic
{
assert_eq!(OUTER_CALLS.load(SeqCst), 0);
db.outer();
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
db.set_one(0);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer()));
assert!(result.is_err());
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
db.set_one(1);
db.outer();
assert_eq!(OUTER_CALLS.load(SeqCst), 2);
}
}
#[test]
fn storages_are_unwind_safe() {
fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
check_unwind_safe::<&DatabaseStruct>();
}
#[test]
fn panics_clear_query_stack() {
let db = DatabaseStruct::default();
// Invoke `db.panic_if_not_one() without having set `db.input`. `db.input`
// will default to 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_err());
// The database has been poisoned and any attempt to increment the
// revision should panic.
assert_eq!(db.salsa_runtime().active_query(), None);
}

View file

@ -0,0 +1,132 @@
use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::{Cancelled, ParallelDatabase};
macro_rules! assert_cancelled {
($thread:expr) => {
match $thread.join() {
Ok(value) => panic!("expected cancellation, got {:?}", value),
Err(payload) => match payload.downcast::<Cancelled>() {
Ok(_) => {}
Err(payload) => ::std::panic::resume_unwind(payload),
},
}
};
}
/// Add test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in next revision, even
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation_immediate() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
assert_eq!(db.sum("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// Here, we check that `sum`'s cancellation is propagated
/// to `sum2` properly.
#[test]
fn in_par_get_set_cancellation_transitive() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum2("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum2("abc")
});
assert_eq!(db.sum2("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// https://github.com/salsa-rs/salsa/issues/66
#[test]
fn no_back_dating_in_cancellation() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// Here we compute a long-chain of queries,
// but the last one gets cancelled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum3("a"))
})
}
});
db.wait_for(1);
// Set unrelated input to bump revision
db.set_input('b', 2);
// Here we should recompuet the whole chain again, clearing the cancellation
// state. If we get `usize::max()` here, it is a bug!
assert_eq!(db.sum3("a"), 1);
assert_cancelled!(thread1);
db.set_input('a', 3);
db.set_input('a', 4);
assert_eq!(db.sum3("ab"), 6);
}

View file

@ -0,0 +1,57 @@
use crate::setup::{ParDatabase, ParDatabaseImpl};
use crate::signal::Signal;
use salsa::{Database, ParallelDatabase};
use std::{
panic::{catch_unwind, AssertUnwindSafe},
sync::Arc,
};
/// Add test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in next revision, even
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
let signal = Arc::new(Signal::default());
let thread1 = std::thread::spawn({
let db = db.snapshot();
let signal = signal.clone();
move || {
// Check that cancellation flag is not yet set, because
// `set` cannot have been called yet.
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
// Signal other thread to proceed.
signal.signal(1);
// Wait for other thread to signal cancellation
catch_unwind(AssertUnwindSafe(|| loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}))
.unwrap_err();
}
});
let thread2 = std::thread::spawn({
move || {
// Wait until thread 1 has asserted that they are not cancelled
// before we invoke `set.`
signal.wait_for(1);
// This will block until thread1 drops the revision lock.
db.set_input('a', 2);
db.input('a')
}
});
thread1.join().unwrap();
let c = thread2.join().unwrap();
assert_eq!(c, 2);
}

View file

@ -0,0 +1,29 @@
use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::ParallelDatabase;
/// Test two `sum` queries (on distinct keys) executing in different
/// threads. Really just a test that `snapshot` etc compiles.
#[test]
fn in_par_two_independent_queries() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 200);
db.set_input('e', 20);
db.set_input('f', 2);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("def")
});
assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 222);
}

View file

@ -0,0 +1,13 @@
mod setup;
mod cancellation;
mod frozen;
mod independent;
mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recovers;
mod race;
mod signal;
mod stress;
mod true_parallel;

View file

@ -0,0 +1,110 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | signal stage 2
// (unblocked) wait for stage 3 (blocks)
// a2 |
// b1 (blocks -> stage 3) |
// | (unblocked)
// | b2
// | a1 (cycle detected, recovers)
// | b2 completes, recovers
// | b1 completes, recovers
// a2 sees cycle, recovers
// a1 completes, recovers
#[test]
fn parallel_cycle_all_recover() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
assert_eq!(thread_a.join().unwrap(), 11);
assert_eq!(thread_b.join().unwrap(), 21);
}
#[salsa::query_group(ParallelCycleAllRecover)]
pub(crate) trait TestDatabase: Knobs {
#[salsa::cycle(recover_a1)]
fn a1(&self, key: i32) -> i32;
#[salsa::cycle(recover_a2)]
fn a2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b1)]
fn b1(&self, key: i32) -> i32;
#[salsa::cycle(recover_b2)]
fn b2(&self, key: i32) -> i32;
}
fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_a1");
key * 10 + 1
}
fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_a2");
key * 10 + 2
}
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 20 + 1
}
fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b2");
key * 20 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
db.b2(key)
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
db.a1(key)
}

View file

@ -0,0 +1,110 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | |
// | b2
// | b3
// | a1 (blocks -> stage 2)
// (unblocked) |
// a2 (cycle detected) |
// b3 recovers
// b2 resumes
// b1 panics because bug
#[test]
fn parallel_cycle_mid_recovers() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(2);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}
#[salsa::query_group(ParallelCycleMidRecovers)]
pub(crate) trait TestDatabase: Knobs {
fn a1(&self, key: i32) -> i32;
fn a2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b1)]
fn b1(&self, key: i32) -> i32;
fn b2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b3)]
fn b3(&self, key: i32) -> i32;
}
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 20 + 2
}
fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 200 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// tell thread b we have started
db.signal(1);
// wait for thread b to block on a1
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
// create the cycle
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// wait for thread a to have started
db.wait_for(1);
db.b2(key);
0
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
// will encounter a cycle but recover
db.b3(key);
db.b1(key); // hasn't recovered yet
0
}
fn b3(db: &dyn TestDatabase, key: i32) -> i32 {
// will block on thread a, signaling stage 2
db.a1(key)
}

View file

@ -0,0 +1,69 @@
//! Test a cycle where no queries recover that occurs across threads.
//! See the `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use expect_test::expect;
use salsa::ParallelDatabase;
use test_log::test;
#[test]
fn parallel_cycle_none_recover() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a(-1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b(-1)
});
// We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
// Right now, it panics with a string.
let err_b = thread_b.join().unwrap_err();
if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
expect![[r#"
[
"a(-1)",
"b(-1)",
]
"#]]
.assert_debug_eq(&c.unexpected_participants(&db));
} else {
panic!("b failed in an unexpected way: {:?}", err_b);
}
// We expect A to propagate a panic, which causes us to use the sentinel
// type `Canceled`.
assert!(thread_a.join().unwrap_err().downcast_ref::<salsa::Cycle>().is_some());
}
#[salsa::query_group(ParallelCycleNoneRecover)]
pub(crate) trait TestDatabase: Knobs {
fn a(&self, key: i32) -> i32;
fn b(&self, key: i32) -> i32;
}
fn a(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.b(key)
}
fn b(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
// Now try to execute A
db.a(key)
}

View file

@ -0,0 +1,95 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | signal stage 2
// (unblocked) wait for stage 3 (blocks)
// a2 |
// b1 (blocks -> stage 3) |
// | (unblocked)
// | b2
// | a1 (cycle detected)
// a2 recovery fn executes |
// a1 completes normally |
// b2 completes, recovers
// b1 completes, recovers
#[test]
fn parallel_cycle_one_recovers() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}
#[salsa::query_group(ParallelCycleOneRecovers)]
pub(crate) trait TestDatabase: Knobs {
fn a1(&self, key: i32) -> i32;
#[salsa::cycle(recover)]
fn a2(&self, key: i32) -> i32;
fn b1(&self, key: i32) -> i32;
fn b2(&self, key: i32) -> i32;
}
fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover");
key * 20 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
db.b2(key)
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
db.a1(key)
}

View file

@ -0,0 +1,37 @@
use std::panic::AssertUnwindSafe;
use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::{Cancelled, ParallelDatabase};
/// Test where a read and a set are racing with one another.
/// Should be atomic.
#[test]
fn in_par_get_set_race() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || Cancelled::catch(AssertUnwindSafe(|| db.sum("abc")))
});
let thread2 = std::thread::spawn(move || {
db.set_input('a', 1000);
db.sum("a")
});
// If the 1st thread runs first, you get 111, otherwise you get
// 1011; if they run concurrently and the 1st thread observes the
// cancellation, it'll unwind.
let result1 = thread1.join().unwrap();
if let Ok(value1) = result1 {
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
}
// thread2 can not observe a cancellation because it performs a
// database write before running any other queries.
assert_eq!(thread2.join().unwrap(), 1000);
}

View file

@ -0,0 +1,197 @@
use crate::signal::Signal;
use salsa::Database;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use std::sync::Arc;
use std::{
cell::Cell,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
};
#[salsa::query_group(Par)]
pub(crate) trait ParDatabase: Knobs {
#[salsa::input]
fn input(&self, key: char) -> usize;
fn sum(&self, key: &'static str) -> usize;
/// Invokes `sum`
fn sum2(&self, key: &'static str) -> usize;
/// Invokes `sum` but doesn't really care about the result.
fn sum2_drop_sum(&self, key: &'static str) -> usize;
/// Invokes `sum2`
fn sum3(&self, key: &'static str) -> usize;
/// Invokes `sum2_drop_sum`
fn sum3_drop_sum(&self, key: &'static str) -> usize;
}
/// Various "knobs" and utilities used by tests to force
/// a certain behavior.
pub(crate) trait Knobs {
fn knobs(&self) -> &KnobsStruct;
fn signal(&self, stage: usize);
fn wait_for(&self, stage: usize);
}
pub(crate) trait WithValue<T> {
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R;
}
impl<T> WithValue<T> for Cell<T> {
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
let old_value = self.replace(value);
let result = catch_unwind(AssertUnwindSafe(closure));
self.set(old_value);
match result {
Ok(r) => r,
Err(payload) => resume_unwind(payload),
}
}
}
#[derive(Default, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CancellationFlag {
#[default]
Down,
Panic,
}
/// Various "knobs" that can be used to customize how the queries
/// behave on one specific thread. Note that this state is
/// intentionally thread-local (apart from `signal`).
#[derive(Clone, Default)]
pub(crate) struct KnobsStruct {
/// A kind of flexible barrier used to coordinate execution across
/// threads to ensure we reach various weird states.
pub(crate) signal: Arc<Signal>,
/// When this database is about to block, send a signal.
pub(crate) signal_on_will_block: Cell<usize>,
/// Invocations of `sum` will signal this stage on entry.
pub(crate) sum_signal_on_entry: Cell<usize>,
/// Invocations of `sum` will wait for this stage on entry.
pub(crate) sum_wait_for_on_entry: Cell<usize>,
/// If true, invocations of `sum` will panic before they exit.
pub(crate) sum_should_panic: Cell<bool>,
/// If true, invocations of `sum` will wait for cancellation before
/// they exit.
pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
/// Invocations of `sum` will wait for this stage prior to exiting.
pub(crate) sum_wait_for_on_exit: Cell<usize>,
/// Invocations of `sum` will signal this stage prior to exiting.
pub(crate) sum_signal_on_exit: Cell<usize>,
/// Invocations of `sum3_drop_sum` will panic unconditionally
pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
}
fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
let mut sum = 0;
db.signal(db.knobs().sum_signal_on_entry.get());
db.wait_for(db.knobs().sum_wait_for_on_entry.get());
if db.knobs().sum_should_panic.get() {
panic!("query set to panic before exit")
}
for ch in key.chars() {
sum += db.input(ch);
}
match db.knobs().sum_wait_for_cancellation.get() {
CancellationFlag::Down => (),
CancellationFlag::Panic => {
tracing::debug!("waiting for cancellation");
loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}
}
}
db.wait_for(db.knobs().sum_wait_for_on_exit.get());
db.signal(db.knobs().sum_signal_on_exit.get());
sum
}
fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize {
db.sum(key)
}
fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
let _ = db.sum(key);
22
}
fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize {
db.sum2(key)
}
fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
if db.knobs().sum3_drop_sum_should_panic.get() {
panic!("sum3_drop_sum executed")
}
db.sum2_drop_sum(key)
}
#[salsa::database(
Par,
crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers,
crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
)]
#[derive(Default)]
pub(crate) struct ParDatabaseImpl {
storage: salsa::Storage<Self>,
knobs: KnobsStruct,
}
impl Database for ParDatabaseImpl {
fn salsa_event(&self, event: salsa::Event) {
if let salsa::EventKind::WillBlockOn { .. } = event.kind {
self.signal(self.knobs().signal_on_will_block.get());
}
}
}
impl ParallelDatabase for ParDatabaseImpl {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(ParDatabaseImpl {
storage: self.storage.snapshot(),
knobs: self.knobs.clone(),
})
}
}
impl Knobs for ParDatabaseImpl {
fn knobs(&self) -> &KnobsStruct {
&self.knobs
}
fn signal(&self, stage: usize) {
self.knobs.signal.signal(stage);
}
fn wait_for(&self, stage: usize) {
self.knobs.signal.wait_for(stage);
}
}

View file

@ -0,0 +1,40 @@
use parking_lot::{Condvar, Mutex};
#[derive(Default)]
pub(crate) struct Signal {
value: Mutex<usize>,
cond_var: Condvar,
}
impl Signal {
pub(crate) fn signal(&self, stage: usize) {
tracing::debug!("signal({})", stage);
// This check avoids acquiring the lock for things that will
// clearly be a no-op. Not *necessary* but helps to ensure we
// are more likely to encounter weird race conditions;
// otherwise calls to `sum` will tend to be unnecessarily
// synchronous.
if stage > 0 {
let mut v = self.value.lock();
if stage > *v {
*v = stage;
self.cond_var.notify_all();
}
}
}
/// Waits until the given condition is true; the fn is invoked
/// with the current stage.
pub(crate) fn wait_for(&self, stage: usize) {
tracing::debug!("wait_for({})", stage);
// As above, avoid lock if clearly a no-op.
if stage > 0 {
let mut v = self.value.lock();
while *v < stage {
self.cond_var.wait(&mut v);
}
}
}
}

View file

@ -0,0 +1,168 @@
use rand::seq::SliceRandom;
use rand::Rng;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use salsa::{Cancelled, Database};
// Number of operations a reader performs
const N_MUTATOR_OPS: usize = 100;
const N_READER_OPS: usize = 100;
#[salsa::query_group(Stress)]
trait StressDatabase: salsa::Database {
#[salsa::input]
fn a(&self, key: usize) -> usize;
fn b(&self, key: usize) -> usize;
fn c(&self, key: usize) -> usize;
}
fn b(db: &dyn StressDatabase, key: usize) -> usize {
db.unwind_if_cancelled();
db.a(key)
}
fn c(db: &dyn StressDatabase, key: usize) -> usize {
db.b(key)
}
#[salsa::database(Stress)]
#[derive(Default)]
struct StressDatabaseImpl {
storage: salsa::Storage<Self>,
}
impl salsa::Database for StressDatabaseImpl {}
impl salsa::ParallelDatabase for StressDatabaseImpl {
fn snapshot(&self) -> Snapshot<StressDatabaseImpl> {
Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() })
}
}
#[derive(Clone, Copy, Debug)]
enum Query {
A,
B,
C,
}
enum MutatorOp {
WriteOp(WriteOp),
LaunchReader { ops: Vec<ReadOp>, check_cancellation: bool },
}
#[derive(Debug)]
enum WriteOp {
SetA(usize, usize),
}
#[derive(Debug)]
enum ReadOp {
Get(Query, usize),
}
impl rand::distributions::Distribution<Query> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Query {
*[Query::A, Query::B, Query::C].choose(rng).unwrap()
}
}
impl rand::distributions::Distribution<MutatorOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> MutatorOp {
if rng.gen_bool(0.5) {
MutatorOp::WriteOp(rng.gen())
} else {
MutatorOp::LaunchReader {
ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(),
check_cancellation: rng.gen(),
}
}
}
}
impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
let key = rng.gen::<usize>() % 10;
let value = rng.gen::<usize>() % 10;
WriteOp::SetA(key, value)
}
}
impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
let query = rng.gen::<Query>();
let key = rng.gen::<usize>() % 10;
ReadOp::Get(query, key)
}
}
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
for op in ops {
if check_cancellation {
db.unwind_if_cancelled();
}
op.execute(db);
}
}
impl WriteOp {
fn execute(self, db: &mut StressDatabaseImpl) {
match self {
WriteOp::SetA(key, value) => {
db.set_a(key, value);
}
}
}
}
impl ReadOp {
fn execute(self, db: &StressDatabaseImpl) {
match self {
ReadOp::Get(query, key) => match query {
Query::A => {
db.a(key);
}
Query::B => {
let _ = db.b(key);
}
Query::C => {
let _ = db.c(key);
}
},
}
}
}
#[test]
fn stress_test() {
let mut db = StressDatabaseImpl::default();
for i in 0..10 {
db.set_a(i, i);
}
let mut rng = rand::thread_rng();
// generate the ops that the mutator thread will perform
let write_ops: Vec<MutatorOp> = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect();
// execute the "main thread", which sometimes snapshots off other threads
let mut all_threads = vec![];
for op in write_ops {
match op {
MutatorOp::WriteOp(w) => w.execute(&mut db),
MutatorOp::LaunchReader { ops, check_cancellation } => {
all_threads.push(std::thread::spawn({
let db = db.snapshot();
move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
}))
}
}
}
for thread in all_threads {
thread.join().unwrap().ok();
}
}

View file

@ -0,0 +1,125 @@
use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::ParallelDatabase;
use std::panic::{self, AssertUnwindSafe};
/// Test where two threads are executing sum. We show that they can
/// both be executing sum in parallel by having thread1 wait for
/// thread2 to send a signal before it leaves (similarly, thread2
/// waits for thread1 to send a signal before it enters).
#[test]
fn true_parallel_different_keys() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
// Thread 1 will signal stage 1 when it enters and wait for stage 2.
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_signal_on_entry
.with_value(1, || db.knobs().sum_wait_for_on_exit.with_value(2, || db.sum("a")));
v
}
});
// Thread 2 will wait_for stage 1 when it enters and signal stage 2
// when it leaves.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_wait_for_on_entry
.with_value(1, || db.knobs().sum_signal_on_exit.with_value(2, || db.sum("b")));
v
}
});
assert_eq!(thread1.join().unwrap(), 100);
assert_eq!(thread2.join().unwrap(), 10);
}
/// Add a test that tries to trigger a conflict, where we fetch
/// `sum("abc")` from two threads simultaneously, and of them
/// therefore has to block.
#[test]
fn true_parallel_same_keys() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
// Thread 1 will wait_for a barrier in the start of `sum`
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_signal_on_entry
.with_value(1, || db.knobs().sum_wait_for_on_entry.with_value(2, || db.sum("abc")));
v
}
});
// Thread 2 will wait until Thread 1 has entered sum and then --
// once it has set itself to block -- signal Thread 1 to
// continue. This way, we test out the mechanism of one thread
// blocking on another.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
db.knobs().signal.wait_for(1);
db.knobs().signal_on_will_block.set(2);
db.sum("abc")
}
});
assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 111);
}
/// Add a test that tries to trigger a conflict, where we fetch `sum("a")`
/// from two threads simultaneously. After `thread2` begins blocking,
/// we force `thread1` to panic and should see that propagate to `thread2`.
#[test]
fn true_parallel_propagate_panic() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
// `thread1` will wait_for a barrier in the start of `sum`. Once it can
// continue, it will panic.
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_on_entry
.with_value(2, || db.knobs().sum_should_panic.with_value(true, || db.sum("a")))
});
v
}
});
// `thread2` will wait until `thread1` has entered sum and then -- once it
// has set itself to block -- signal `thread1` to continue.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
db.knobs().signal.wait_for(1);
db.knobs().signal_on_will_block.set(2);
db.sum("a")
}
});
let result1 = panic::catch_unwind(AssertUnwindSafe(|| thread1.join().unwrap()));
let result2 = panic::catch_unwind(AssertUnwindSafe(|| thread2.join().unwrap()));
assert!(result1.is_err());
assert!(result2.is_err());
}

View file

@ -0,0 +1,19 @@
use crate::queries;
use std::cell::Cell;
#[salsa::database(queries::GroupStruct)]
#[derive(Default)]
pub(crate) struct DatabaseImpl {
storage: salsa::Storage<Self>,
counter: Cell<usize>,
}
impl queries::Counter for DatabaseImpl {
fn increment(&self) -> usize {
let v = self.counter.get();
self.counter.set(v + 1);
v
}
}
impl salsa::Database for DatabaseImpl {}

View file

@ -0,0 +1,5 @@
mod implementation;
mod queries;
mod tests;
fn main() {}

View file

@ -0,0 +1,22 @@
pub(crate) trait Counter: salsa::Database {
fn increment(&self) -> usize;
}
#[salsa::query_group(GroupStruct)]
pub(crate) trait Database: Counter {
fn memoized(&self) -> usize;
fn volatile(&self) -> usize;
}
/// Because this query is memoized, we only increment the counter
/// the first time it is invoked.
fn memoized(db: &dyn Database) -> usize {
db.volatile()
}
/// Because this query is volatile, each time it is invoked,
/// we will increment the counter.
fn volatile(db: &dyn Database) -> usize {
db.salsa_runtime().report_untracked_read();
db.increment()
}

View file

@ -0,0 +1,49 @@
#![cfg(test)]
use crate::implementation::DatabaseImpl;
use crate::queries::Database;
use salsa::Database as _Database;
use salsa::Durability;
#[test]
fn memoized_twice() {
let db = DatabaseImpl::default();
let v1 = db.memoized();
let v2 = db.memoized();
assert_eq!(v1, v2);
}
#[test]
fn volatile_twice() {
let mut db = DatabaseImpl::default();
let v1 = db.volatile();
let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same
assert_eq!(v1, v2);
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
let v3 = db.volatile(); // will re-increment the counter
let v4 = db.volatile(); // second call will be cached
assert_eq!(v1 + 1, v3);
assert_eq!(v3, v4);
}
#[test]
fn intermingled() {
let mut db = DatabaseImpl::default();
let v1 = db.volatile();
let v2 = db.memoized();
let v3 = db.volatile(); // cached
let v4 = db.memoized(); // cached
assert_eq!(v1, v2);
assert_eq!(v1, v3);
assert_eq!(v2, v4);
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
let v5 = db.memoized(); // re-executes volatile, caches new result
let v6 = db.memoized(); // re-use cached result
assert_eq!(v4 + 1, v5);
assert_eq!(v5, v6);
}

View file

@ -0,0 +1,39 @@
//! Test that transparent (uncached) queries work
#[salsa::query_group(QueryGroupStorage)]
trait QueryGroup {
#[salsa::input]
fn input(&self, x: u32) -> u32;
#[salsa::transparent]
fn wrap(&self, x: u32) -> u32;
fn get(&self, x: u32) -> u32;
}
fn wrap(db: &dyn QueryGroup, x: u32) -> u32 {
db.input(x)
}
fn get(db: &dyn QueryGroup, x: u32) -> u32 {
db.wrap(x)
}
#[salsa::database(QueryGroupStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {}
#[test]
fn transparent_queries_work() {
let mut db = Database::default();
db.set_input(1, 10);
assert_eq!(db.get(1), 10);
assert_eq!(db.get(1), 10);
db.set_input(1, 92);
assert_eq!(db.get(1), 92);
assert_eq!(db.get(1), 92);
}

View file

@ -0,0 +1,51 @@
#[salsa::query_group(HelloWorld)]
trait HelloWorldDatabase: salsa::Database {
#[salsa::input]
fn input(&self, a: u32, b: u32) -> u32;
fn none(&self) -> u32;
fn one(&self, k: u32) -> u32;
fn two(&self, a: u32, b: u32) -> u32;
fn trailing(&self, a: u32, b: u32) -> u32;
}
fn none(_db: &dyn HelloWorldDatabase) -> u32 {
22
}
fn one(_db: &dyn HelloWorldDatabase, k: u32) -> u32 {
k * 2
}
fn two(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
a * b
}
fn trailing(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
a - b
}
#[salsa::database(HelloWorld)]
#[derive(Default)]
struct DatabaseStruct {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseStruct {}
#[test]
fn execute() {
let mut db = DatabaseStruct::default();
// test what happens with inputs:
db.set_input(1, 2, 3);
assert_eq!(db.input(1, 2), 3);
assert_eq!(db.none(), 22);
assert_eq!(db.one(11), 22);
assert_eq!(db.two(11, 2), 22);
assert_eq!(db.trailing(24, 2), 22);
}