mirror of
https://github.com/rust-lang/rust-analyzer.git
synced 2025-10-01 22:31:43 +00:00
Merge commit 'ddf105b646
' into sync-from-ra
This commit is contained in:
parent
0816d49d83
commit
e41ab350d6
378 changed files with 14720 additions and 3111 deletions
493
crates/salsa/tests/cycles.rs
Normal file
493
crates/salsa/tests/cycles.rs
Normal file
|
@ -0,0 +1,493 @@
|
|||
use std::panic::UnwindSafe;
|
||||
|
||||
use expect_test::expect;
|
||||
use salsa::{Durability, ParallelDatabase, Snapshot};
|
||||
use test_log::test;
|
||||
|
||||
// Axes:
|
||||
//
|
||||
// Threading
|
||||
// * Intra-thread
|
||||
// * Cross-thread -- part of cycle is on one thread, part on another
|
||||
//
|
||||
// Recovery strategies:
|
||||
// * Panic
|
||||
// * Fallback
|
||||
// * Mixed -- multiple strategies within cycle participants
|
||||
//
|
||||
// Across revisions:
|
||||
// * N/A -- only one revision
|
||||
// * Present in new revision, not old
|
||||
// * Present in old revision, not new
|
||||
// * Present in both revisions
|
||||
//
|
||||
// Dependencies
|
||||
// * Tracked
|
||||
// * Untracked -- cycle participant(s) contain untracked reads
|
||||
//
|
||||
// Layers
|
||||
// * Direct -- cycle participant is directly invoked from test
|
||||
// * Indirect -- invoked a query that invokes the cycle
|
||||
//
|
||||
//
|
||||
// | Thread | Recovery | Old, New | Dep style | Layers | Test Name |
|
||||
// | ------ | -------- | -------- | --------- | ------ | --------- |
|
||||
// | Intra | Panic | N/A | Tracked | direct | cycle_memoized |
|
||||
// | Intra | Panic | N/A | Untracked | direct | cycle_volatile |
|
||||
// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle |
|
||||
// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle |
|
||||
// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate |
|
||||
// | Intra | Fallback | New | Tracked | direct | cycle_appears |
|
||||
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears |
|
||||
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability |
|
||||
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 |
|
||||
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 |
|
||||
// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle |
|
||||
// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle |
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
|
||||
struct Error {
|
||||
cycle: Vec<String>,
|
||||
}
|
||||
|
||||
#[salsa::database(GroupStruct)]
|
||||
#[derive(Default)]
|
||||
struct DatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for DatabaseImpl {}
|
||||
|
||||
impl ParallelDatabase for DatabaseImpl {
|
||||
fn snapshot(&self) -> Snapshot<Self> {
|
||||
Snapshot::new(DatabaseImpl { storage: self.storage.snapshot() })
|
||||
}
|
||||
}
|
||||
|
||||
/// The queries A, B, and C in `Database` can be configured
|
||||
/// to invoke one another in arbitrary ways using this
|
||||
/// enum.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
enum CycleQuery {
|
||||
None,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
AthenC,
|
||||
}
|
||||
|
||||
#[salsa::query_group(GroupStruct)]
|
||||
trait Database: salsa::Database {
|
||||
// `a` and `b` depend on each other and form a cycle
|
||||
fn memoized_a(&self) -> ();
|
||||
fn memoized_b(&self) -> ();
|
||||
fn volatile_a(&self) -> ();
|
||||
fn volatile_b(&self) -> ();
|
||||
|
||||
#[salsa::input]
|
||||
fn a_invokes(&self) -> CycleQuery;
|
||||
|
||||
#[salsa::input]
|
||||
fn b_invokes(&self) -> CycleQuery;
|
||||
|
||||
#[salsa::input]
|
||||
fn c_invokes(&self) -> CycleQuery;
|
||||
|
||||
#[salsa::cycle(recover_a)]
|
||||
fn cycle_a(&self) -> Result<(), Error>;
|
||||
|
||||
#[salsa::cycle(recover_b)]
|
||||
fn cycle_b(&self) -> Result<(), Error>;
|
||||
|
||||
fn cycle_c(&self) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
fn recover_a(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
|
||||
Err(Error { cycle: cycle.all_participants(db) })
|
||||
}
|
||||
|
||||
fn recover_b(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
|
||||
Err(Error { cycle: cycle.all_participants(db) })
|
||||
}
|
||||
|
||||
fn memoized_a(db: &dyn Database) {
|
||||
db.memoized_b()
|
||||
}
|
||||
|
||||
fn memoized_b(db: &dyn Database) {
|
||||
db.memoized_a()
|
||||
}
|
||||
|
||||
fn volatile_a(db: &dyn Database) {
|
||||
db.salsa_runtime().report_untracked_read();
|
||||
db.volatile_b()
|
||||
}
|
||||
|
||||
fn volatile_b(db: &dyn Database) {
|
||||
db.salsa_runtime().report_untracked_read();
|
||||
db.volatile_a()
|
||||
}
|
||||
|
||||
impl CycleQuery {
|
||||
fn invoke(self, db: &dyn Database) -> Result<(), Error> {
|
||||
match self {
|
||||
CycleQuery::A => db.cycle_a(),
|
||||
CycleQuery::B => db.cycle_b(),
|
||||
CycleQuery::C => db.cycle_c(),
|
||||
CycleQuery::AthenC => {
|
||||
let _ = db.cycle_a();
|
||||
db.cycle_c()
|
||||
}
|
||||
CycleQuery::None => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cycle_a(db: &dyn Database) -> Result<(), Error> {
|
||||
db.a_invokes().invoke(db)
|
||||
}
|
||||
|
||||
fn cycle_b(db: &dyn Database) -> Result<(), Error> {
|
||||
db.b_invokes().invoke(db)
|
||||
}
|
||||
|
||||
fn cycle_c(db: &dyn Database) -> Result<(), Error> {
|
||||
db.c_invokes().invoke(db)
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle {
|
||||
let v = std::panic::catch_unwind(f);
|
||||
if let Err(d) = &v {
|
||||
if let Some(cycle) = d.downcast_ref::<salsa::Cycle>() {
|
||||
return cycle.clone();
|
||||
}
|
||||
}
|
||||
panic!("unexpected value: {:?}", v)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_memoized() {
|
||||
let db = DatabaseImpl::default();
|
||||
let cycle = extract_cycle(|| db.memoized_a());
|
||||
expect![[r#"
|
||||
[
|
||||
"memoized_a(())",
|
||||
"memoized_b(())",
|
||||
]
|
||||
"#]]
|
||||
.assert_debug_eq(&cycle.unexpected_participants(&db));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_volatile() {
|
||||
let db = DatabaseImpl::default();
|
||||
let cycle = extract_cycle(|| db.volatile_a());
|
||||
expect![[r#"
|
||||
[
|
||||
"volatile_a(())",
|
||||
"volatile_b(())",
|
||||
]
|
||||
"#]]
|
||||
.assert_debug_eq(&cycle.unexpected_participants(&db));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_cycle() {
|
||||
let mut query = DatabaseImpl::default();
|
||||
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
|
||||
query.set_a_invokes(CycleQuery::B);
|
||||
query.set_b_invokes(CycleQuery::A);
|
||||
|
||||
assert!(query.cycle_a().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inner_cycle() {
|
||||
let mut query = DatabaseImpl::default();
|
||||
|
||||
// A --> B <-- C
|
||||
// ^ |
|
||||
// +-----+
|
||||
|
||||
query.set_a_invokes(CycleQuery::B);
|
||||
query.set_b_invokes(CycleQuery::A);
|
||||
query.set_c_invokes(CycleQuery::B);
|
||||
|
||||
let err = query.cycle_c();
|
||||
assert!(err.is_err());
|
||||
let cycle = err.unwrap_err().cycle;
|
||||
expect![[r#"
|
||||
[
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
]
|
||||
"#]]
|
||||
.assert_debug_eq(&cycle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_revalidate() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::A);
|
||||
|
||||
assert!(db.cycle_a().is_err());
|
||||
db.set_b_invokes(CycleQuery::A); // same value as default
|
||||
assert!(db.cycle_a().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_revalidate_unchanged_twice() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::A);
|
||||
|
||||
assert!(db.cycle_a().is_err());
|
||||
db.set_c_invokes(CycleQuery::A); // force new revisi5on
|
||||
|
||||
// on this run
|
||||
expect![[r#"
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
],
|
||||
},
|
||||
)
|
||||
"#]]
|
||||
.assert_debug_eq(&db.cycle_a());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_appears() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
|
||||
// A --> B
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::None);
|
||||
assert!(db.cycle_a().is_ok());
|
||||
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
db.set_b_invokes(CycleQuery::A);
|
||||
tracing::debug!("Set Cycle Leaf");
|
||||
assert!(db.cycle_a().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_disappears() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::A);
|
||||
assert!(db.cycle_a().is_err());
|
||||
|
||||
// A --> B
|
||||
db.set_b_invokes(CycleQuery::None);
|
||||
assert!(db.cycle_a().is_ok());
|
||||
}
|
||||
|
||||
/// A variant on `cycle_disappears` in which the values of
|
||||
/// `a_invokes` and `b_invokes` are set with durability values.
|
||||
/// If we are not careful, this could cause us to overlook
|
||||
/// the fact that the cycle will no longer occur.
|
||||
#[test]
|
||||
fn cycle_disappears_durability() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
db.set_a_invokes_with_durability(CycleQuery::B, Durability::LOW);
|
||||
db.set_b_invokes_with_durability(CycleQuery::A, Durability::HIGH);
|
||||
|
||||
let res = db.cycle_a();
|
||||
assert!(res.is_err());
|
||||
|
||||
// At this point, `a` read `LOW` input, and `b` read `HIGH` input. However,
|
||||
// because `b` participates in the same cycle as `a`, its final durability
|
||||
// should be `LOW`.
|
||||
//
|
||||
// Check that setting a `LOW` input causes us to re-execute `b` query, and
|
||||
// observe that the cycle goes away.
|
||||
db.set_a_invokes_with_durability(CycleQuery::None, Durability::LOW);
|
||||
|
||||
let res = db.cycle_b();
|
||||
assert!(res.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_mixed_1() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
// A --> B <-- C
|
||||
// | ^
|
||||
// +-----+
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::C);
|
||||
db.set_c_invokes(CycleQuery::B);
|
||||
|
||||
let u = db.cycle_c();
|
||||
expect![[r#"
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_b(())",
|
||||
"cycle_c(())",
|
||||
],
|
||||
},
|
||||
)
|
||||
"#]]
|
||||
.assert_debug_eq(&u);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_mixed_2() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
|
||||
// Configuration:
|
||||
//
|
||||
// A --> B --> C
|
||||
// ^ |
|
||||
// +-----------+
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::C);
|
||||
db.set_c_invokes(CycleQuery::A);
|
||||
|
||||
let u = db.cycle_a();
|
||||
expect![[r#"
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
"cycle_c(())",
|
||||
],
|
||||
},
|
||||
)
|
||||
"#]]
|
||||
.assert_debug_eq(&u);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_deterministic_order() {
|
||||
// No matter whether we start from A or B, we get the same set of participants:
|
||||
let db = || {
|
||||
let mut db = DatabaseImpl::default();
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::A);
|
||||
db
|
||||
};
|
||||
let a = db().cycle_a();
|
||||
let b = db().cycle_b();
|
||||
expect![[r#"
|
||||
(
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
],
|
||||
},
|
||||
),
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
"#]]
|
||||
.assert_debug_eq(&(a, b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_multiple() {
|
||||
// No matter whether we start from A or B, we get the same set of participants:
|
||||
let mut db = DatabaseImpl::default();
|
||||
|
||||
// Configuration:
|
||||
//
|
||||
// A --> B <-- C
|
||||
// ^ | ^
|
||||
// +-----+ |
|
||||
// | |
|
||||
// +-----+
|
||||
//
|
||||
// Here, conceptually, B encounters a cycle with A and then
|
||||
// recovers.
|
||||
db.set_a_invokes(CycleQuery::B);
|
||||
db.set_b_invokes(CycleQuery::AthenC);
|
||||
db.set_c_invokes(CycleQuery::B);
|
||||
|
||||
let c = db.cycle_c();
|
||||
let b = db.cycle_b();
|
||||
let a = db.cycle_a();
|
||||
expect![[r#"
|
||||
(
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
],
|
||||
},
|
||||
),
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
],
|
||||
},
|
||||
),
|
||||
Err(
|
||||
Error {
|
||||
cycle: [
|
||||
"cycle_a(())",
|
||||
"cycle_b(())",
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
"#]]
|
||||
.assert_debug_eq(&(a, b, c));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_recovery_set_but_not_participating() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
|
||||
// A --> C -+
|
||||
// ^ |
|
||||
// +--+
|
||||
db.set_a_invokes(CycleQuery::C);
|
||||
db.set_c_invokes(CycleQuery::C);
|
||||
|
||||
// Here we expect C to panic and A not to recover:
|
||||
let r = extract_cycle(|| drop(db.cycle_a()));
|
||||
expect![[r#"
|
||||
[
|
||||
"cycle_c(())",
|
||||
]
|
||||
"#]]
|
||||
.assert_debug_eq(&r.all_participants(&db));
|
||||
}
|
28
crates/salsa/tests/dyn_trait.rs
Normal file
28
crates/salsa/tests/dyn_trait.rs
Normal file
|
@ -0,0 +1,28 @@
|
|||
//! Test that you can implement a query using a `dyn Trait` setup.
|
||||
|
||||
#[salsa::database(DynTraitStorage)]
|
||||
#[derive(Default)]
|
||||
struct DynTraitDatabase {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for DynTraitDatabase {}
|
||||
|
||||
#[salsa::query_group(DynTraitStorage)]
|
||||
trait DynTrait {
|
||||
#[salsa::input]
|
||||
fn input(&self, x: u32) -> u32;
|
||||
|
||||
fn output(&self, x: u32) -> u32;
|
||||
}
|
||||
|
||||
fn output(db: &dyn DynTrait, x: u32) -> u32 {
|
||||
db.input(x) * 2
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dyn_trait() {
|
||||
let mut query = DynTraitDatabase::default();
|
||||
query.set_input(22, 23);
|
||||
assert_eq!(query.output(22), 46);
|
||||
}
|
145
crates/salsa/tests/incremental/constants.rs
Normal file
145
crates/salsa/tests/incremental/constants.rs
Normal file
|
@ -0,0 +1,145 @@
|
|||
use crate::implementation::{TestContext, TestContextImpl};
|
||||
use salsa::debug::DebugQueryTable;
|
||||
use salsa::Durability;
|
||||
|
||||
#[salsa::query_group(Constants)]
|
||||
pub(crate) trait ConstantsDatabase: TestContext {
|
||||
#[salsa::input]
|
||||
fn input(&self, key: char) -> usize;
|
||||
|
||||
fn add(&self, key1: char, key2: char) -> usize;
|
||||
|
||||
fn add3(&self, key1: char, key2: char, key3: char) -> usize;
|
||||
}
|
||||
|
||||
fn add(db: &dyn ConstantsDatabase, key1: char, key2: char) -> usize {
|
||||
db.log().add(format!("add({}, {})", key1, key2));
|
||||
db.input(key1) + db.input(key2)
|
||||
}
|
||||
|
||||
fn add3(db: &dyn ConstantsDatabase, key1: char, key2: char, key3: char) -> usize {
|
||||
db.log().add(format!("add3({}, {}, {})", key1, key2, key3));
|
||||
db.add(key1, key2) + db.input(key3)
|
||||
}
|
||||
|
||||
// Test we can assign a constant and things will be correctly
|
||||
// recomputed afterwards.
|
||||
#[test]
|
||||
fn invalidate_constant() {
|
||||
let db = &mut TestContextImpl::default();
|
||||
db.set_input_with_durability('a', 44, Durability::HIGH);
|
||||
db.set_input_with_durability('b', 22, Durability::HIGH);
|
||||
assert_eq!(db.add('a', 'b'), 66);
|
||||
|
||||
db.set_input_with_durability('a', 66, Durability::HIGH);
|
||||
assert_eq!(db.add('a', 'b'), 88);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalidate_constant_1() {
|
||||
let db = &mut TestContextImpl::default();
|
||||
|
||||
// Not constant:
|
||||
db.set_input('a', 44);
|
||||
assert_eq!(db.add('a', 'a'), 88);
|
||||
|
||||
// Becomes constant:
|
||||
db.set_input_with_durability('a', 44, Durability::HIGH);
|
||||
assert_eq!(db.add('a', 'a'), 88);
|
||||
|
||||
// Invalidates:
|
||||
db.set_input_with_durability('a', 33, Durability::HIGH);
|
||||
assert_eq!(db.add('a', 'a'), 66);
|
||||
}
|
||||
|
||||
// Test cases where we assign same value to 'a' after declaring it a
|
||||
// constant.
|
||||
#[test]
|
||||
fn set_after_constant_same_value() {
|
||||
let db = &mut TestContextImpl::default();
|
||||
db.set_input_with_durability('a', 44, Durability::HIGH);
|
||||
db.set_input_with_durability('a', 44, Durability::HIGH);
|
||||
db.set_input('a', 44);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_constant() {
|
||||
let mut db = TestContextImpl::default();
|
||||
|
||||
db.set_input('a', 22);
|
||||
db.set_input('b', 44);
|
||||
assert_eq!(db.add('a', 'b'), 66);
|
||||
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn durability() {
|
||||
let mut db = TestContextImpl::default();
|
||||
|
||||
db.set_input_with_durability('a', 22, Durability::HIGH);
|
||||
db.set_input_with_durability('b', 44, Durability::HIGH);
|
||||
assert_eq!(db.add('a', 'b'), 66);
|
||||
assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_constant() {
|
||||
let mut db = TestContextImpl::default();
|
||||
|
||||
db.set_input_with_durability('a', 22, Durability::HIGH);
|
||||
db.set_input('b', 44);
|
||||
assert_eq!(db.add('a', 'b'), 66);
|
||||
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn becomes_constant_with_change() {
|
||||
let mut db = TestContextImpl::default();
|
||||
|
||||
db.set_input('a', 22);
|
||||
db.set_input('b', 44);
|
||||
assert_eq!(db.add('a', 'b'), 66);
|
||||
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
|
||||
|
||||
db.set_input_with_durability('a', 23, Durability::HIGH);
|
||||
assert_eq!(db.add('a', 'b'), 67);
|
||||
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
|
||||
|
||||
db.set_input_with_durability('b', 45, Durability::HIGH);
|
||||
assert_eq!(db.add('a', 'b'), 68);
|
||||
assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
|
||||
|
||||
db.set_input_with_durability('b', 45, Durability::MEDIUM);
|
||||
assert_eq!(db.add('a', 'b'), 68);
|
||||
assert_eq!(Durability::MEDIUM, AddQuery.in_db(&db).durability(('a', 'b')));
|
||||
}
|
||||
|
||||
// Test a subtle case in which an input changes from constant to
|
||||
// non-constant, but its value doesn't change. If we're not careful,
|
||||
// this can cause us to incorrectly consider derived values as still
|
||||
// being constant.
|
||||
#[test]
|
||||
fn constant_to_non_constant() {
|
||||
let mut db = TestContextImpl::default();
|
||||
|
||||
db.set_input_with_durability('a', 11, Durability::HIGH);
|
||||
db.set_input_with_durability('b', 22, Durability::HIGH);
|
||||
db.set_input_with_durability('c', 33, Durability::HIGH);
|
||||
|
||||
// Here, `add3` invokes `add`, which yields 33. Both calls are
|
||||
// constant.
|
||||
assert_eq!(db.add3('a', 'b', 'c'), 66);
|
||||
|
||||
db.set_input('a', 11);
|
||||
|
||||
// Here, `add3` invokes `add`, which *still* yields 33, but which
|
||||
// is no longer constant. Since value didn't change, we might
|
||||
// preserve `add3` unchanged, not noticing that it is no longer
|
||||
// constant.
|
||||
assert_eq!(db.add3('a', 'b', 'c'), 66);
|
||||
|
||||
// In that case, we would not get the correct result here, when
|
||||
// 'a' changes *again*.
|
||||
db.set_input('a', 22);
|
||||
assert_eq!(db.add3('a', 'b', 'c'), 77);
|
||||
}
|
14
crates/salsa/tests/incremental/counter.rs
Normal file
14
crates/salsa/tests/incremental/counter.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use std::cell::Cell;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Counter {
|
||||
value: Cell<usize>,
|
||||
}
|
||||
|
||||
impl Counter {
|
||||
pub(crate) fn increment(&self) -> usize {
|
||||
let v = self.value.get();
|
||||
self.value.set(v + 1);
|
||||
v
|
||||
}
|
||||
}
|
59
crates/salsa/tests/incremental/implementation.rs
Normal file
59
crates/salsa/tests/incremental/implementation.rs
Normal file
|
@ -0,0 +1,59 @@
|
|||
use crate::constants;
|
||||
use crate::counter::Counter;
|
||||
use crate::log::Log;
|
||||
use crate::memoized_dep_inputs;
|
||||
use crate::memoized_inputs;
|
||||
use crate::memoized_volatile;
|
||||
|
||||
pub(crate) trait TestContext: salsa::Database {
|
||||
fn clock(&self) -> &Counter;
|
||||
fn log(&self) -> &Log;
|
||||
}
|
||||
|
||||
#[salsa::database(
|
||||
constants::Constants,
|
||||
memoized_dep_inputs::MemoizedDepInputs,
|
||||
memoized_inputs::MemoizedInputs,
|
||||
memoized_volatile::MemoizedVolatile
|
||||
)]
|
||||
#[derive(Default)]
|
||||
pub(crate) struct TestContextImpl {
|
||||
storage: salsa::Storage<TestContextImpl>,
|
||||
clock: Counter,
|
||||
log: Log,
|
||||
}
|
||||
|
||||
impl TestContextImpl {
|
||||
#[track_caller]
|
||||
pub(crate) fn assert_log(&self, expected_log: &[&str]) {
|
||||
let expected_text = &format!("{:#?}", expected_log);
|
||||
let actual_text = &format!("{:#?}", self.log().take());
|
||||
|
||||
if expected_text == actual_text {
|
||||
return;
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stdout)]
|
||||
for diff in dissimilar::diff(expected_text, actual_text) {
|
||||
match diff {
|
||||
dissimilar::Chunk::Delete(l) => println!("-{}", l),
|
||||
dissimilar::Chunk::Equal(l) => println!(" {}", l),
|
||||
dissimilar::Chunk::Insert(r) => println!("+{}", r),
|
||||
}
|
||||
}
|
||||
|
||||
panic!("incorrect log results");
|
||||
}
|
||||
}
|
||||
|
||||
impl TestContext for TestContextImpl {
|
||||
fn clock(&self) -> &Counter {
|
||||
&self.clock
|
||||
}
|
||||
|
||||
fn log(&self) -> &Log {
|
||||
&self.log
|
||||
}
|
||||
}
|
||||
|
||||
impl salsa::Database for TestContextImpl {}
|
16
crates/salsa/tests/incremental/log.rs
Normal file
16
crates/salsa/tests/incremental/log.rs
Normal file
|
@ -0,0 +1,16 @@
|
|||
use std::cell::RefCell;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Log {
|
||||
data: RefCell<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Log {
|
||||
pub(crate) fn add(&self, text: impl Into<String>) {
|
||||
self.data.borrow_mut().push(text.into());
|
||||
}
|
||||
|
||||
pub(crate) fn take(&self) -> Vec<String> {
|
||||
self.data.take()
|
||||
}
|
||||
}
|
9
crates/salsa/tests/incremental/main.rs
Normal file
9
crates/salsa/tests/incremental/main.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
mod constants;
|
||||
mod counter;
|
||||
mod implementation;
|
||||
mod log;
|
||||
mod memoized_dep_inputs;
|
||||
mod memoized_inputs;
|
||||
mod memoized_volatile;
|
||||
|
||||
fn main() {}
|
60
crates/salsa/tests/incremental/memoized_dep_inputs.rs
Normal file
60
crates/salsa/tests/incremental/memoized_dep_inputs.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use crate::implementation::{TestContext, TestContextImpl};
|
||||
|
||||
#[salsa::query_group(MemoizedDepInputs)]
|
||||
pub(crate) trait MemoizedDepInputsContext: TestContext {
|
||||
fn dep_memoized2(&self) -> usize;
|
||||
fn dep_memoized1(&self) -> usize;
|
||||
#[salsa::dependencies]
|
||||
fn dep_derived1(&self) -> usize;
|
||||
#[salsa::input]
|
||||
fn dep_input1(&self) -> usize;
|
||||
#[salsa::input]
|
||||
fn dep_input2(&self) -> usize;
|
||||
}
|
||||
|
||||
fn dep_memoized2(db: &dyn MemoizedDepInputsContext) -> usize {
|
||||
db.log().add("Memoized2 invoked");
|
||||
db.dep_memoized1()
|
||||
}
|
||||
|
||||
fn dep_memoized1(db: &dyn MemoizedDepInputsContext) -> usize {
|
||||
db.log().add("Memoized1 invoked");
|
||||
db.dep_derived1() * 2
|
||||
}
|
||||
|
||||
fn dep_derived1(db: &dyn MemoizedDepInputsContext) -> usize {
|
||||
db.log().add("Derived1 invoked");
|
||||
db.dep_input1() / 2
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn revalidate() {
|
||||
let db = &mut TestContextImpl::default();
|
||||
|
||||
db.set_dep_input1(0);
|
||||
|
||||
// Initial run starts from Memoized2:
|
||||
let v = db.dep_memoized2();
|
||||
assert_eq!(v, 0);
|
||||
db.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Derived1 invoked"]);
|
||||
|
||||
// After that, we first try to validate Memoized1 but wind up
|
||||
// running Memoized2. Note that we don't try to validate
|
||||
// Derived1, so it is invoked by Memoized1.
|
||||
db.set_dep_input1(44);
|
||||
let v = db.dep_memoized2();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]);
|
||||
|
||||
// Here validation of Memoized1 succeeds so Memoized2 never runs.
|
||||
db.set_dep_input1(45);
|
||||
let v = db.dep_memoized2();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]);
|
||||
|
||||
// Here, a change to input2 doesn't affect us, so nothing runs.
|
||||
db.set_dep_input2(45);
|
||||
let v = db.dep_memoized2();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&[]);
|
||||
}
|
76
crates/salsa/tests/incremental/memoized_inputs.rs
Normal file
76
crates/salsa/tests/incremental/memoized_inputs.rs
Normal file
|
@ -0,0 +1,76 @@
|
|||
use crate::implementation::{TestContext, TestContextImpl};
|
||||
|
||||
#[salsa::query_group(MemoizedInputs)]
|
||||
pub(crate) trait MemoizedInputsContext: TestContext {
|
||||
fn max(&self) -> usize;
|
||||
#[salsa::input]
|
||||
fn input1(&self) -> usize;
|
||||
#[salsa::input]
|
||||
fn input2(&self) -> usize;
|
||||
}
|
||||
|
||||
fn max(db: &dyn MemoizedInputsContext) -> usize {
|
||||
db.log().add("Max invoked");
|
||||
std::cmp::max(db.input1(), db.input2())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn revalidate() {
|
||||
let db = &mut TestContextImpl::default();
|
||||
|
||||
db.set_input1(0);
|
||||
db.set_input2(0);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 0);
|
||||
db.assert_log(&["Max invoked"]);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 0);
|
||||
db.assert_log(&[]);
|
||||
|
||||
db.set_input1(44);
|
||||
db.assert_log(&[]);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&["Max invoked"]);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&[]);
|
||||
|
||||
db.set_input1(44);
|
||||
db.assert_log(&[]);
|
||||
db.set_input2(66);
|
||||
db.assert_log(&[]);
|
||||
db.set_input1(64);
|
||||
db.assert_log(&[]);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 66);
|
||||
db.assert_log(&["Max invoked"]);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 66);
|
||||
db.assert_log(&[]);
|
||||
}
|
||||
|
||||
/// Test that invoking `set` on an input with the same value still
|
||||
/// triggers a new revision.
|
||||
#[test]
|
||||
fn set_after_no_change() {
|
||||
let db = &mut TestContextImpl::default();
|
||||
|
||||
db.set_input2(0);
|
||||
|
||||
db.set_input1(44);
|
||||
let v = db.max();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&["Max invoked"]);
|
||||
|
||||
db.set_input1(44);
|
||||
let v = db.max();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&["Max invoked"]);
|
||||
}
|
77
crates/salsa/tests/incremental/memoized_volatile.rs
Normal file
77
crates/salsa/tests/incremental/memoized_volatile.rs
Normal file
|
@ -0,0 +1,77 @@
|
|||
use crate::implementation::{TestContext, TestContextImpl};
|
||||
use salsa::{Database, Durability};
|
||||
|
||||
#[salsa::query_group(MemoizedVolatile)]
|
||||
pub(crate) trait MemoizedVolatileContext: TestContext {
|
||||
// Queries for testing a "volatile" value wrapped by
|
||||
// memoization.
|
||||
fn memoized2(&self) -> usize;
|
||||
fn memoized1(&self) -> usize;
|
||||
fn volatile(&self) -> usize;
|
||||
}
|
||||
|
||||
fn memoized2(db: &dyn MemoizedVolatileContext) -> usize {
|
||||
db.log().add("Memoized2 invoked");
|
||||
db.memoized1()
|
||||
}
|
||||
|
||||
fn memoized1(db: &dyn MemoizedVolatileContext) -> usize {
|
||||
db.log().add("Memoized1 invoked");
|
||||
let v = db.volatile();
|
||||
v / 2
|
||||
}
|
||||
|
||||
fn volatile(db: &dyn MemoizedVolatileContext) -> usize {
|
||||
db.log().add("Volatile invoked");
|
||||
db.salsa_runtime().report_untracked_read();
|
||||
db.clock().increment()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn volatile_x2() {
|
||||
let query = TestContextImpl::default();
|
||||
|
||||
// Invoking volatile twice doesn't execute twice, because volatile
|
||||
// queries are memoized by default.
|
||||
query.volatile();
|
||||
query.volatile();
|
||||
query.assert_log(&["Volatile invoked"]);
|
||||
}
|
||||
|
||||
/// Test that:
|
||||
///
|
||||
/// - On the first run of R0, we recompute everything.
|
||||
/// - On the second run of R1, we recompute nothing.
|
||||
/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result
|
||||
/// did not change).
|
||||
/// - On the second run of R1, we recompute nothing.
|
||||
/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change).
|
||||
#[test]
|
||||
fn revalidate() {
|
||||
let mut query = TestContextImpl::default();
|
||||
|
||||
query.memoized2();
|
||||
query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]);
|
||||
|
||||
query.memoized2();
|
||||
query.assert_log(&[]);
|
||||
|
||||
// Second generation: volatile will change (to 1) but memoized1
|
||||
// will not (still 0, as 1/2 = 0)
|
||||
query.salsa_runtime_mut().synthetic_write(Durability::LOW);
|
||||
query.memoized2();
|
||||
query.assert_log(&["Volatile invoked", "Memoized1 invoked"]);
|
||||
query.memoized2();
|
||||
query.assert_log(&[]);
|
||||
|
||||
// Third generation: volatile will change (to 2) and memoized1
|
||||
// will too (to 1). Therefore, after validating that Memoized1
|
||||
// changed, we now invoke Memoized2.
|
||||
query.salsa_runtime_mut().synthetic_write(Durability::LOW);
|
||||
|
||||
query.memoized2();
|
||||
query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]);
|
||||
|
||||
query.memoized2();
|
||||
query.assert_log(&[]);
|
||||
}
|
90
crates/salsa/tests/interned.rs
Normal file
90
crates/salsa/tests/interned.rs
Normal file
|
@ -0,0 +1,90 @@
|
|||
//! Test that you can implement a query using a `dyn Trait` setup.
|
||||
|
||||
use salsa::InternId;
|
||||
|
||||
#[salsa::database(InternStorage)]
|
||||
#[derive(Default)]
|
||||
struct Database {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for Database {}
|
||||
|
||||
impl salsa::ParallelDatabase for Database {
|
||||
fn snapshot(&self) -> salsa::Snapshot<Self> {
|
||||
salsa::Snapshot::new(Database { storage: self.storage.snapshot() })
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::query_group(InternStorage)]
|
||||
trait Intern {
|
||||
#[salsa::interned]
|
||||
fn intern1(&self, x: String) -> InternId;
|
||||
|
||||
#[salsa::interned]
|
||||
fn intern2(&self, x: String, y: String) -> InternId;
|
||||
|
||||
#[salsa::interned]
|
||||
fn intern_key(&self, x: String) -> InternKey;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct InternKey(InternId);
|
||||
|
||||
impl salsa::InternKey for InternKey {
|
||||
fn from_intern_id(v: InternId) -> Self {
|
||||
InternKey(v)
|
||||
}
|
||||
|
||||
fn as_intern_id(&self) -> InternId {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_intern1() {
|
||||
let db = Database::default();
|
||||
let foo0 = db.intern1("foo".to_owned());
|
||||
let bar0 = db.intern1("bar".to_owned());
|
||||
let foo1 = db.intern1("foo".to_owned());
|
||||
let bar1 = db.intern1("bar".to_owned());
|
||||
|
||||
assert_eq!(foo0, foo1);
|
||||
assert_eq!(bar0, bar1);
|
||||
assert_ne!(foo0, bar0);
|
||||
|
||||
assert_eq!("foo".to_owned(), db.lookup_intern1(foo0));
|
||||
assert_eq!("bar".to_owned(), db.lookup_intern1(bar0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_intern2() {
|
||||
let db = Database::default();
|
||||
let foo0 = db.intern2("x".to_owned(), "foo".to_owned());
|
||||
let bar0 = db.intern2("x".to_owned(), "bar".to_owned());
|
||||
let foo1 = db.intern2("x".to_owned(), "foo".to_owned());
|
||||
let bar1 = db.intern2("x".to_owned(), "bar".to_owned());
|
||||
|
||||
assert_eq!(foo0, foo1);
|
||||
assert_eq!(bar0, bar1);
|
||||
assert_ne!(foo0, bar0);
|
||||
|
||||
assert_eq!(("x".to_owned(), "foo".to_owned()), db.lookup_intern2(foo0));
|
||||
assert_eq!(("x".to_owned(), "bar".to_owned()), db.lookup_intern2(bar0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_intern_key() {
|
||||
let db = Database::default();
|
||||
let foo0 = db.intern_key("foo".to_owned());
|
||||
let bar0 = db.intern_key("bar".to_owned());
|
||||
let foo1 = db.intern_key("foo".to_owned());
|
||||
let bar1 = db.intern_key("bar".to_owned());
|
||||
|
||||
assert_eq!(foo0, foo1);
|
||||
assert_eq!(bar0, bar1);
|
||||
assert_ne!(foo0, bar0);
|
||||
|
||||
assert_eq!("foo".to_owned(), db.lookup_intern_key(foo0));
|
||||
assert_eq!("bar".to_owned(), db.lookup_intern_key(bar0));
|
||||
}
|
102
crates/salsa/tests/lru.rs
Normal file
102
crates/salsa/tests/lru.rs
Normal file
|
@ -0,0 +1,102 @@
|
|||
//! Test setting LRU actually limits the number of things in the database;
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
struct HotPotato(u32);
|
||||
|
||||
static N_POTATOES: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
impl HotPotato {
|
||||
fn new(id: u32) -> HotPotato {
|
||||
N_POTATOES.fetch_add(1, Ordering::SeqCst);
|
||||
HotPotato(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HotPotato {
|
||||
fn drop(&mut self) {
|
||||
N_POTATOES.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::query_group(QueryGroupStorage)]
|
||||
trait QueryGroup: salsa::Database {
|
||||
fn get(&self, x: u32) -> Arc<HotPotato>;
|
||||
fn get_volatile(&self, x: u32) -> usize;
|
||||
}
|
||||
|
||||
fn get(_db: &dyn QueryGroup, x: u32) -> Arc<HotPotato> {
|
||||
Arc::new(HotPotato::new(x))
|
||||
}
|
||||
|
||||
fn get_volatile(db: &dyn QueryGroup, _x: u32) -> usize {
|
||||
static COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
db.salsa_runtime().report_untracked_read();
|
||||
COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
#[salsa::database(QueryGroupStorage)]
|
||||
#[derive(Default)]
|
||||
struct Database {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for Database {}
|
||||
|
||||
#[test]
|
||||
fn lru_works() {
|
||||
let mut db = Database::default();
|
||||
GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
|
||||
|
||||
for i in 0..128u32 {
|
||||
let p = db.get(i);
|
||||
assert_eq!(p.0, i)
|
||||
}
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||
|
||||
for i in 0..128u32 {
|
||||
let p = db.get(i);
|
||||
assert_eq!(p.0, i)
|
||||
}
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||
|
||||
GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||
|
||||
GetQuery.in_db_mut(&mut db).set_lru_capacity(64);
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
|
||||
for i in 0..128u32 {
|
||||
let p = db.get(i);
|
||||
assert_eq!(p.0, i)
|
||||
}
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
|
||||
|
||||
// Special case: setting capacity to zero disables LRU
|
||||
GetQuery.in_db_mut(&mut db).set_lru_capacity(0);
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
|
||||
for i in 0..128u32 {
|
||||
let p = db.get(i);
|
||||
assert_eq!(p.0, i)
|
||||
}
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128);
|
||||
|
||||
drop(db);
|
||||
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lru_doesnt_break_volatile_queries() {
|
||||
let mut db = Database::default();
|
||||
GetVolatileQuery.in_db_mut(&mut db).set_lru_capacity(32);
|
||||
// Here, we check that we execute each volatile query at most once, despite
|
||||
// LRU. That does mean that we have more values in DB than the LRU capacity,
|
||||
// but it's much better than inconsistent results from volatile queries!
|
||||
for i in (0..3).flat_map(|_| 0..128usize) {
|
||||
let x = db.get_volatile(i as u32);
|
||||
assert_eq!(x, i)
|
||||
}
|
||||
}
|
11
crates/salsa/tests/macros.rs
Normal file
11
crates/salsa/tests/macros.rs
Normal file
|
@ -0,0 +1,11 @@
|
|||
#[salsa::query_group(MyStruct)]
|
||||
trait MyDatabase: salsa::Database {
|
||||
#[salsa::invoke(another_module::another_name)]
|
||||
fn my_query(&self, key: ()) -> ();
|
||||
}
|
||||
|
||||
mod another_module {
|
||||
pub(crate) fn another_name(_: &dyn crate::MyDatabase, (): ()) {}
|
||||
}
|
||||
|
||||
fn main() {}
|
31
crates/salsa/tests/no_send_sync.rs
Normal file
31
crates/salsa/tests/no_send_sync.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use std::rc::Rc;
|
||||
|
||||
#[salsa::query_group(NoSendSyncStorage)]
|
||||
trait NoSendSyncDatabase: salsa::Database {
|
||||
fn no_send_sync_value(&self, key: bool) -> Rc<bool>;
|
||||
fn no_send_sync_key(&self, key: Rc<bool>) -> bool;
|
||||
}
|
||||
|
||||
fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc<bool> {
|
||||
Rc::new(key)
|
||||
}
|
||||
|
||||
fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc<bool>) -> bool {
|
||||
*key
|
||||
}
|
||||
|
||||
#[salsa::database(NoSendSyncStorage)]
|
||||
#[derive(Default)]
|
||||
struct DatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for DatabaseImpl {}
|
||||
|
||||
#[test]
|
||||
fn no_send_sync() {
|
||||
let db = DatabaseImpl::default();
|
||||
|
||||
assert_eq!(db.no_send_sync_value(true), Rc::new(true));
|
||||
assert!(!db.no_send_sync_key(Rc::new(false)));
|
||||
}
|
147
crates/salsa/tests/on_demand_inputs.rs
Normal file
147
crates/salsa/tests/on_demand_inputs.rs
Normal file
|
@ -0,0 +1,147 @@
|
|||
//! Test that "on-demand" input pattern works.
|
||||
//!
|
||||
//! On-demand inputs are inputs computed lazily on the fly. They are simulated
|
||||
//! via a b query with zero inputs, which uses `add_synthetic_read` to
|
||||
//! tweak durability and `invalidate` to clear the input.
|
||||
|
||||
#![allow(clippy::disallowed_types, clippy::type_complexity)]
|
||||
|
||||
use std::{cell::RefCell, collections::HashMap, rc::Rc};
|
||||
|
||||
use salsa::{Database as _, Durability, EventKind};
|
||||
|
||||
#[salsa::query_group(QueryGroupStorage)]
|
||||
trait QueryGroup: salsa::Database + AsRef<HashMap<u32, u32>> {
|
||||
fn a(&self, x: u32) -> u32;
|
||||
fn b(&self, x: u32) -> u32;
|
||||
fn c(&self, x: u32) -> u32;
|
||||
}
|
||||
|
||||
fn a(db: &dyn QueryGroup, x: u32) -> u32 {
|
||||
let durability = if x % 2 == 0 { Durability::LOW } else { Durability::HIGH };
|
||||
db.salsa_runtime().report_synthetic_read(durability);
|
||||
let external_state: &HashMap<u32, u32> = db.as_ref();
|
||||
external_state[&x]
|
||||
}
|
||||
|
||||
fn b(db: &dyn QueryGroup, x: u32) -> u32 {
|
||||
db.a(x)
|
||||
}
|
||||
|
||||
fn c(db: &dyn QueryGroup, x: u32) -> u32 {
|
||||
db.b(x)
|
||||
}
|
||||
|
||||
#[salsa::database(QueryGroupStorage)]
|
||||
#[derive(Default)]
|
||||
struct Database {
|
||||
storage: salsa::Storage<Self>,
|
||||
external_state: HashMap<u32, u32>,
|
||||
on_event: Option<Box<dyn Fn(&Database, salsa::Event)>>,
|
||||
}
|
||||
|
||||
impl salsa::Database for Database {
|
||||
fn salsa_event(&self, event: salsa::Event) {
|
||||
if let Some(cb) = &self.on_event {
|
||||
cb(self, event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<HashMap<u32, u32>> for Database {
|
||||
fn as_ref(&self) -> &HashMap<u32, u32> {
|
||||
&self.external_state
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn on_demand_input_works() {
|
||||
let mut db = Database::default();
|
||||
|
||||
db.external_state.insert(1, 10);
|
||||
assert_eq!(db.b(1), 10);
|
||||
assert_eq!(db.a(1), 10);
|
||||
|
||||
// We changed external state, but haven't signaled about this yet,
|
||||
// so we expect to see the old answer
|
||||
db.external_state.insert(1, 92);
|
||||
assert_eq!(db.b(1), 10);
|
||||
assert_eq!(db.a(1), 10);
|
||||
|
||||
AQuery.in_db_mut(&mut db).invalidate(&1);
|
||||
assert_eq!(db.b(1), 92);
|
||||
assert_eq!(db.a(1), 92);
|
||||
|
||||
// Downstream queries should also be rerun if we call `a` first.
|
||||
db.external_state.insert(1, 50);
|
||||
AQuery.in_db_mut(&mut db).invalidate(&1);
|
||||
assert_eq!(db.a(1), 50);
|
||||
assert_eq!(db.b(1), 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn on_demand_input_durability() {
|
||||
let mut db = Database::default();
|
||||
|
||||
let events = Rc::new(RefCell::new(vec![]));
|
||||
db.on_event = Some(Box::new({
|
||||
let events = events.clone();
|
||||
move |db, event| {
|
||||
if let EventKind::WillCheckCancellation = event.kind {
|
||||
// these events are not interesting
|
||||
} else {
|
||||
events.borrow_mut().push(format!("{:?}", event.debug(db)))
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
events.replace(vec![]);
|
||||
db.external_state.insert(1, 10);
|
||||
db.external_state.insert(2, 20);
|
||||
assert_eq!(db.b(1), 10);
|
||||
assert_eq!(db.b(2), 20);
|
||||
expect_test::expect![[r#"
|
||||
RefCell {
|
||||
value: [
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(1) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(2) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
|
||||
],
|
||||
}
|
||||
"#]].assert_debug_eq(&events);
|
||||
|
||||
db.salsa_runtime_mut().synthetic_write(Durability::LOW);
|
||||
events.replace(vec![]);
|
||||
assert_eq!(db.c(1), 10);
|
||||
assert_eq!(db.c(2), 20);
|
||||
// Re-execute `a(2)` because that has low durability, but not `a(1)`
|
||||
expect_test::expect![[r#"
|
||||
RefCell {
|
||||
value: [
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(1) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(2) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }",
|
||||
],
|
||||
}
|
||||
"#]].assert_debug_eq(&events);
|
||||
|
||||
db.salsa_runtime_mut().synthetic_write(Durability::HIGH);
|
||||
events.replace(vec![]);
|
||||
assert_eq!(db.c(1), 10);
|
||||
assert_eq!(db.c(2), 20);
|
||||
// Re-execute both `a(1)` and `a(2)`, but we don't re-execute any `b` queries as the
|
||||
// result didn't actually change.
|
||||
expect_test::expect![[r#"
|
||||
RefCell {
|
||||
value: [
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }",
|
||||
],
|
||||
}
|
||||
"#]].assert_debug_eq(&events);
|
||||
}
|
93
crates/salsa/tests/panic_safely.rs
Normal file
93
crates/salsa/tests/panic_safely.rs
Normal file
|
@ -0,0 +1,93 @@
|
|||
use salsa::{Database, ParallelDatabase, Snapshot};
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
|
||||
|
||||
#[salsa::query_group(PanicSafelyStruct)]
|
||||
trait PanicSafelyDatabase: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn one(&self) -> usize;
|
||||
|
||||
fn panic_safely(&self) -> ();
|
||||
|
||||
fn outer(&self) -> ();
|
||||
}
|
||||
|
||||
fn panic_safely(db: &dyn PanicSafelyDatabase) {
|
||||
assert_eq!(db.one(), 1);
|
||||
}
|
||||
|
||||
static OUTER_CALLS: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
fn outer(db: &dyn PanicSafelyDatabase) {
|
||||
OUTER_CALLS.fetch_add(1, SeqCst);
|
||||
db.panic_safely();
|
||||
}
|
||||
|
||||
#[salsa::database(PanicSafelyStruct)]
|
||||
#[derive(Default)]
|
||||
struct DatabaseStruct {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for DatabaseStruct {}
|
||||
|
||||
impl salsa::ParallelDatabase for DatabaseStruct {
|
||||
fn snapshot(&self) -> Snapshot<Self> {
|
||||
Snapshot::new(DatabaseStruct { storage: self.storage.snapshot() })
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_panic_safely() {
|
||||
let mut db = DatabaseStruct::default();
|
||||
db.set_one(0);
|
||||
|
||||
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
|
||||
// return 0 and we should catch the panic.
|
||||
let result = panic::catch_unwind(AssertUnwindSafe({
|
||||
let db = db.snapshot();
|
||||
move || db.panic_safely()
|
||||
}));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Set `db.one` to 1 and assert ok
|
||||
db.set_one(1);
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Check, that memoized outer is not invalidated by a panic
|
||||
{
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 0);
|
||||
db.outer();
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
|
||||
|
||||
db.set_one(0);
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer()));
|
||||
assert!(result.is_err());
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
|
||||
|
||||
db.set_one(1);
|
||||
db.outer();
|
||||
assert_eq!(OUTER_CALLS.load(SeqCst), 2);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn storages_are_unwind_safe() {
|
||||
fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
|
||||
check_unwind_safe::<&DatabaseStruct>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn panics_clear_query_stack() {
|
||||
let db = DatabaseStruct::default();
|
||||
|
||||
// Invoke `db.panic_if_not_one() without having set `db.input`. `db.input`
|
||||
// will default to 0 and we should catch the panic.
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||
assert!(result.is_err());
|
||||
|
||||
// The database has been poisoned and any attempt to increment the
|
||||
// revision should panic.
|
||||
assert_eq!(db.salsa_runtime().active_query(), None);
|
||||
}
|
132
crates/salsa/tests/parallel/cancellation.rs
Normal file
132
crates/salsa/tests/parallel/cancellation.rs
Normal file
|
@ -0,0 +1,132 @@
|
|||
use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
|
||||
use salsa::{Cancelled, ParallelDatabase};
|
||||
|
||||
macro_rules! assert_cancelled {
|
||||
($thread:expr) => {
|
||||
match $thread.join() {
|
||||
Ok(value) => panic!("expected cancellation, got {:?}", value),
|
||||
Err(payload) => match payload.downcast::<Cancelled>() {
|
||||
Ok(_) => {}
|
||||
Err(payload) => ::std::panic::resume_unwind(payload),
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Add test where a call to `sum` is cancelled by a simultaneous
|
||||
/// write. Check that we recompute the result in next revision, even
|
||||
/// though none of the inputs have changed.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation_immediate() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
db.set_input('d', 0);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("abc")
|
||||
});
|
||||
|
||||
assert_eq!(db.sum("d"), 1000);
|
||||
assert_cancelled!(thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// Here, we check that `sum`'s cancellation is propagated
|
||||
/// to `sum2` properly.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation_transitive() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
db.set_input('d', 0);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum2("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum2("abc")
|
||||
});
|
||||
|
||||
assert_eq!(db.sum2("d"), 1000);
|
||||
assert_cancelled!(thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// https://github.com/salsa-rs/salsa/issues/66
|
||||
#[test]
|
||||
fn no_back_dating_in_cancellation() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// Here we compute a long-chain of queries,
|
||||
// but the last one gets cancelled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum3("a"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
db.wait_for(1);
|
||||
|
||||
// Set unrelated input to bump revision
|
||||
db.set_input('b', 2);
|
||||
|
||||
// Here we should recompuet the whole chain again, clearing the cancellation
|
||||
// state. If we get `usize::max()` here, it is a bug!
|
||||
assert_eq!(db.sum3("a"), 1);
|
||||
|
||||
assert_cancelled!(thread1);
|
||||
|
||||
db.set_input('a', 3);
|
||||
db.set_input('a', 4);
|
||||
assert_eq!(db.sum3("ab"), 6);
|
||||
}
|
57
crates/salsa/tests/parallel/frozen.rs
Normal file
57
crates/salsa/tests/parallel/frozen.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use crate::signal::Signal;
|
||||
use salsa::{Database, ParallelDatabase};
|
||||
use std::{
|
||||
panic::{catch_unwind, AssertUnwindSafe},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
/// Add test where a call to `sum` is cancelled by a simultaneous
|
||||
/// write. Check that we recompute the result in next revision, even
|
||||
/// though none of the inputs have changed.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
|
||||
let signal = Arc::new(Signal::default());
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
let signal = signal.clone();
|
||||
move || {
|
||||
// Check that cancellation flag is not yet set, because
|
||||
// `set` cannot have been called yet.
|
||||
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
|
||||
|
||||
// Signal other thread to proceed.
|
||||
signal.signal(1);
|
||||
|
||||
// Wait for other thread to signal cancellation
|
||||
catch_unwind(AssertUnwindSafe(|| loop {
|
||||
db.unwind_if_cancelled();
|
||||
std::thread::yield_now();
|
||||
}))
|
||||
.unwrap_err();
|
||||
}
|
||||
});
|
||||
|
||||
let thread2 = std::thread::spawn({
|
||||
move || {
|
||||
// Wait until thread 1 has asserted that they are not cancelled
|
||||
// before we invoke `set.`
|
||||
signal.wait_for(1);
|
||||
|
||||
// This will block until thread1 drops the revision lock.
|
||||
db.set_input('a', 2);
|
||||
|
||||
db.input('a')
|
||||
}
|
||||
});
|
||||
|
||||
thread1.join().unwrap();
|
||||
|
||||
let c = thread2.join().unwrap();
|
||||
assert_eq!(c, 2);
|
||||
}
|
29
crates/salsa/tests/parallel/independent.rs
Normal file
29
crates/salsa/tests/parallel/independent.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
|
||||
/// Test two `sum` queries (on distinct keys) executing in different
|
||||
/// threads. Really just a test that `snapshot` etc compiles.
|
||||
#[test]
|
||||
fn in_par_two_independent_queries() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
db.set_input('d', 200);
|
||||
db.set_input('e', 20);
|
||||
db.set_input('f', 2);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("abc")
|
||||
});
|
||||
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("def")
|
||||
});
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 111);
|
||||
assert_eq!(thread2.join().unwrap(), 222);
|
||||
}
|
13
crates/salsa/tests/parallel/main.rs
Normal file
13
crates/salsa/tests/parallel/main.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
mod setup;
|
||||
|
||||
mod cancellation;
|
||||
mod frozen;
|
||||
mod independent;
|
||||
mod parallel_cycle_all_recover;
|
||||
mod parallel_cycle_mid_recover;
|
||||
mod parallel_cycle_none_recover;
|
||||
mod parallel_cycle_one_recovers;
|
||||
mod race;
|
||||
mod signal;
|
||||
mod stress;
|
||||
mod true_parallel;
|
110
crates/salsa/tests/parallel/parallel_cycle_all_recover.rs
Normal file
110
crates/salsa/tests/parallel/parallel_cycle_all_recover.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
//! Test for cycle recover spread across two threads.
|
||||
//! See `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
// Recover cycle test:
|
||||
//
|
||||
// The pattern is as follows.
|
||||
//
|
||||
// Thread A Thread B
|
||||
// -------- --------
|
||||
// a1 b1
|
||||
// | wait for stage 1 (blocks)
|
||||
// signal stage 1 |
|
||||
// wait for stage 2 (blocks) (unblocked)
|
||||
// | signal stage 2
|
||||
// (unblocked) wait for stage 3 (blocks)
|
||||
// a2 |
|
||||
// b1 (blocks -> stage 3) |
|
||||
// | (unblocked)
|
||||
// | b2
|
||||
// | a1 (cycle detected, recovers)
|
||||
// | b2 completes, recovers
|
||||
// | b1 completes, recovers
|
||||
// a2 sees cycle, recovers
|
||||
// a1 completes, recovers
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_all_recover() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(3);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a1(1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b1(1)
|
||||
});
|
||||
|
||||
assert_eq!(thread_a.join().unwrap(), 11);
|
||||
assert_eq!(thread_b.join().unwrap(), 21);
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleAllRecover)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
#[salsa::cycle(recover_a1)]
|
||||
fn a1(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_a2)]
|
||||
fn a2(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b1)]
|
||||
fn b1(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b2)]
|
||||
fn b2(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_a1");
|
||||
key * 10 + 1
|
||||
}
|
||||
|
||||
fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_a2");
|
||||
key * 10 + 2
|
||||
}
|
||||
|
||||
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b1");
|
||||
key * 20 + 1
|
||||
}
|
||||
|
||||
fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b2");
|
||||
key * 20 + 2
|
||||
}
|
||||
|
||||
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.signal(1);
|
||||
db.wait_for(2);
|
||||
|
||||
db.a2(key)
|
||||
}
|
||||
|
||||
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.b1(key)
|
||||
}
|
||||
|
||||
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.wait_for(1);
|
||||
db.signal(2);
|
||||
|
||||
// Wait for thread A to block on this thread
|
||||
db.wait_for(3);
|
||||
|
||||
db.b2(key)
|
||||
}
|
||||
|
||||
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.a1(key)
|
||||
}
|
110
crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs
Normal file
110
crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
//! Test for cycle recover spread across two threads.
|
||||
//! See `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
// Recover cycle test:
|
||||
//
|
||||
// The pattern is as follows.
|
||||
//
|
||||
// Thread A Thread B
|
||||
// -------- --------
|
||||
// a1 b1
|
||||
// | wait for stage 1 (blocks)
|
||||
// signal stage 1 |
|
||||
// wait for stage 2 (blocks) (unblocked)
|
||||
// | |
|
||||
// | b2
|
||||
// | b3
|
||||
// | a1 (blocks -> stage 2)
|
||||
// (unblocked) |
|
||||
// a2 (cycle detected) |
|
||||
// b3 recovers
|
||||
// b2 resumes
|
||||
// b1 panics because bug
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_mid_recovers() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(2);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a1(1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b1(1)
|
||||
});
|
||||
|
||||
// We expect that the recovery function yields
|
||||
// `1 * 20 + 2`, which is returned (and forwarded)
|
||||
// to b1, and from there to a2 and a1.
|
||||
assert_eq!(thread_a.join().unwrap(), 22);
|
||||
assert_eq!(thread_b.join().unwrap(), 22);
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleMidRecovers)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
fn a1(&self, key: i32) -> i32;
|
||||
|
||||
fn a2(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b1)]
|
||||
fn b1(&self, key: i32) -> i32;
|
||||
|
||||
fn b2(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b3)]
|
||||
fn b3(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b1");
|
||||
key * 20 + 2
|
||||
}
|
||||
|
||||
fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b1");
|
||||
key * 200 + 2
|
||||
}
|
||||
|
||||
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// tell thread b we have started
|
||||
db.signal(1);
|
||||
|
||||
// wait for thread b to block on a1
|
||||
db.wait_for(2);
|
||||
|
||||
db.a2(key)
|
||||
}
|
||||
|
||||
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// create the cycle
|
||||
db.b1(key)
|
||||
}
|
||||
|
||||
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// wait for thread a to have started
|
||||
db.wait_for(1);
|
||||
|
||||
db.b2(key);
|
||||
|
||||
0
|
||||
}
|
||||
|
||||
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// will encounter a cycle but recover
|
||||
db.b3(key);
|
||||
db.b1(key); // hasn't recovered yet
|
||||
0
|
||||
}
|
||||
|
||||
fn b3(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// will block on thread a, signaling stage 2
|
||||
db.a1(key)
|
||||
}
|
69
crates/salsa/tests/parallel/parallel_cycle_none_recover.rs
Normal file
69
crates/salsa/tests/parallel/parallel_cycle_none_recover.rs
Normal file
|
@ -0,0 +1,69 @@
|
|||
//! Test a cycle where no queries recover that occurs across threads.
|
||||
//! See the `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use expect_test::expect;
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_none_recover() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(3);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a(-1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b(-1)
|
||||
});
|
||||
|
||||
// We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
|
||||
// Right now, it panics with a string.
|
||||
let err_b = thread_b.join().unwrap_err();
|
||||
if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
|
||||
expect![[r#"
|
||||
[
|
||||
"a(-1)",
|
||||
"b(-1)",
|
||||
]
|
||||
"#]]
|
||||
.assert_debug_eq(&c.unexpected_participants(&db));
|
||||
} else {
|
||||
panic!("b failed in an unexpected way: {:?}", err_b);
|
||||
}
|
||||
|
||||
// We expect A to propagate a panic, which causes us to use the sentinel
|
||||
// type `Canceled`.
|
||||
assert!(thread_a.join().unwrap_err().downcast_ref::<salsa::Cycle>().is_some());
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleNoneRecover)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
fn a(&self, key: i32) -> i32;
|
||||
fn b(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn a(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.signal(1);
|
||||
db.wait_for(2);
|
||||
|
||||
db.b(key)
|
||||
}
|
||||
|
||||
fn b(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.wait_for(1);
|
||||
db.signal(2);
|
||||
|
||||
// Wait for thread A to block on this thread
|
||||
db.wait_for(3);
|
||||
|
||||
// Now try to execute A
|
||||
db.a(key)
|
||||
}
|
95
crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs
Normal file
95
crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs
Normal file
|
@ -0,0 +1,95 @@
|
|||
//! Test for cycle recover spread across two threads.
|
||||
//! See `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
// Recover cycle test:
|
||||
//
|
||||
// The pattern is as follows.
|
||||
//
|
||||
// Thread A Thread B
|
||||
// -------- --------
|
||||
// a1 b1
|
||||
// | wait for stage 1 (blocks)
|
||||
// signal stage 1 |
|
||||
// wait for stage 2 (blocks) (unblocked)
|
||||
// | signal stage 2
|
||||
// (unblocked) wait for stage 3 (blocks)
|
||||
// a2 |
|
||||
// b1 (blocks -> stage 3) |
|
||||
// | (unblocked)
|
||||
// | b2
|
||||
// | a1 (cycle detected)
|
||||
// a2 recovery fn executes |
|
||||
// a1 completes normally |
|
||||
// b2 completes, recovers
|
||||
// b1 completes, recovers
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_one_recovers() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(3);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a1(1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b1(1)
|
||||
});
|
||||
|
||||
// We expect that the recovery function yields
|
||||
// `1 * 20 + 2`, which is returned (and forwarded)
|
||||
// to b1, and from there to a2 and a1.
|
||||
assert_eq!(thread_a.join().unwrap(), 22);
|
||||
assert_eq!(thread_b.join().unwrap(), 22);
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleOneRecovers)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
fn a1(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover)]
|
||||
fn a2(&self, key: i32) -> i32;
|
||||
|
||||
fn b1(&self, key: i32) -> i32;
|
||||
|
||||
fn b2(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover");
|
||||
key * 20 + 2
|
||||
}
|
||||
|
||||
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.signal(1);
|
||||
db.wait_for(2);
|
||||
|
||||
db.a2(key)
|
||||
}
|
||||
|
||||
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.b1(key)
|
||||
}
|
||||
|
||||
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.wait_for(1);
|
||||
db.signal(2);
|
||||
|
||||
// Wait for thread A to block on this thread
|
||||
db.wait_for(3);
|
||||
|
||||
db.b2(key)
|
||||
}
|
||||
|
||||
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.a1(key)
|
||||
}
|
37
crates/salsa/tests/parallel/race.rs
Normal file
37
crates/salsa/tests/parallel/race.rs
Normal file
|
@ -0,0 +1,37 @@
|
|||
use std::panic::AssertUnwindSafe;
|
||||
|
||||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use salsa::{Cancelled, ParallelDatabase};
|
||||
|
||||
/// Test where a read and a set are racing with one another.
|
||||
/// Should be atomic.
|
||||
#[test]
|
||||
fn in_par_get_set_race() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || Cancelled::catch(AssertUnwindSafe(|| db.sum("abc")))
|
||||
});
|
||||
|
||||
let thread2 = std::thread::spawn(move || {
|
||||
db.set_input('a', 1000);
|
||||
db.sum("a")
|
||||
});
|
||||
|
||||
// If the 1st thread runs first, you get 111, otherwise you get
|
||||
// 1011; if they run concurrently and the 1st thread observes the
|
||||
// cancellation, it'll unwind.
|
||||
let result1 = thread1.join().unwrap();
|
||||
if let Ok(value1) = result1 {
|
||||
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
|
||||
}
|
||||
|
||||
// thread2 can not observe a cancellation because it performs a
|
||||
// database write before running any other queries.
|
||||
assert_eq!(thread2.join().unwrap(), 1000);
|
||||
}
|
197
crates/salsa/tests/parallel/setup.rs
Normal file
197
crates/salsa/tests/parallel/setup.rs
Normal file
|
@ -0,0 +1,197 @@
|
|||
use crate::signal::Signal;
|
||||
use salsa::Database;
|
||||
use salsa::ParallelDatabase;
|
||||
use salsa::Snapshot;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
cell::Cell,
|
||||
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
|
||||
};
|
||||
|
||||
#[salsa::query_group(Par)]
|
||||
pub(crate) trait ParDatabase: Knobs {
|
||||
#[salsa::input]
|
||||
fn input(&self, key: char) -> usize;
|
||||
|
||||
fn sum(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum`
|
||||
fn sum2(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum` but doesn't really care about the result.
|
||||
fn sum2_drop_sum(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum2`
|
||||
fn sum3(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum2_drop_sum`
|
||||
fn sum3_drop_sum(&self, key: &'static str) -> usize;
|
||||
}
|
||||
|
||||
/// Various "knobs" and utilities used by tests to force
|
||||
/// a certain behavior.
|
||||
pub(crate) trait Knobs {
|
||||
fn knobs(&self) -> &KnobsStruct;
|
||||
|
||||
fn signal(&self, stage: usize);
|
||||
|
||||
fn wait_for(&self, stage: usize);
|
||||
}
|
||||
|
||||
pub(crate) trait WithValue<T> {
|
||||
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R;
|
||||
}
|
||||
|
||||
impl<T> WithValue<T> for Cell<T> {
|
||||
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
|
||||
let old_value = self.replace(value);
|
||||
|
||||
let result = catch_unwind(AssertUnwindSafe(closure));
|
||||
|
||||
self.set(old_value);
|
||||
|
||||
match result {
|
||||
Ok(r) => r,
|
||||
Err(payload) => resume_unwind(payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum CancellationFlag {
|
||||
#[default]
|
||||
Down,
|
||||
Panic,
|
||||
}
|
||||
|
||||
/// Various "knobs" that can be used to customize how the queries
|
||||
/// behave on one specific thread. Note that this state is
|
||||
/// intentionally thread-local (apart from `signal`).
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct KnobsStruct {
|
||||
/// A kind of flexible barrier used to coordinate execution across
|
||||
/// threads to ensure we reach various weird states.
|
||||
pub(crate) signal: Arc<Signal>,
|
||||
|
||||
/// When this database is about to block, send a signal.
|
||||
pub(crate) signal_on_will_block: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum` will signal this stage on entry.
|
||||
pub(crate) sum_signal_on_entry: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum` will wait for this stage on entry.
|
||||
pub(crate) sum_wait_for_on_entry: Cell<usize>,
|
||||
|
||||
/// If true, invocations of `sum` will panic before they exit.
|
||||
pub(crate) sum_should_panic: Cell<bool>,
|
||||
|
||||
/// If true, invocations of `sum` will wait for cancellation before
|
||||
/// they exit.
|
||||
pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
|
||||
|
||||
/// Invocations of `sum` will wait for this stage prior to exiting.
|
||||
pub(crate) sum_wait_for_on_exit: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum` will signal this stage prior to exiting.
|
||||
pub(crate) sum_signal_on_exit: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum3_drop_sum` will panic unconditionally
|
||||
pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
|
||||
}
|
||||
|
||||
fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
let mut sum = 0;
|
||||
|
||||
db.signal(db.knobs().sum_signal_on_entry.get());
|
||||
|
||||
db.wait_for(db.knobs().sum_wait_for_on_entry.get());
|
||||
|
||||
if db.knobs().sum_should_panic.get() {
|
||||
panic!("query set to panic before exit")
|
||||
}
|
||||
|
||||
for ch in key.chars() {
|
||||
sum += db.input(ch);
|
||||
}
|
||||
|
||||
match db.knobs().sum_wait_for_cancellation.get() {
|
||||
CancellationFlag::Down => (),
|
||||
CancellationFlag::Panic => {
|
||||
tracing::debug!("waiting for cancellation");
|
||||
loop {
|
||||
db.unwind_if_cancelled();
|
||||
std::thread::yield_now();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.wait_for(db.knobs().sum_wait_for_on_exit.get());
|
||||
|
||||
db.signal(db.knobs().sum_signal_on_exit.get());
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
db.sum(key)
|
||||
}
|
||||
|
||||
fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
let _ = db.sum(key);
|
||||
22
|
||||
}
|
||||
|
||||
fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
db.sum2(key)
|
||||
}
|
||||
|
||||
fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
if db.knobs().sum3_drop_sum_should_panic.get() {
|
||||
panic!("sum3_drop_sum executed")
|
||||
}
|
||||
db.sum2_drop_sum(key)
|
||||
}
|
||||
|
||||
#[salsa::database(
|
||||
Par,
|
||||
crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
|
||||
crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
|
||||
crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers,
|
||||
crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
|
||||
)]
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ParDatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
knobs: KnobsStruct,
|
||||
}
|
||||
|
||||
impl Database for ParDatabaseImpl {
|
||||
fn salsa_event(&self, event: salsa::Event) {
|
||||
if let salsa::EventKind::WillBlockOn { .. } = event.kind {
|
||||
self.signal(self.knobs().signal_on_will_block.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParallelDatabase for ParDatabaseImpl {
|
||||
fn snapshot(&self) -> Snapshot<Self> {
|
||||
Snapshot::new(ParDatabaseImpl {
|
||||
storage: self.storage.snapshot(),
|
||||
knobs: self.knobs.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Knobs for ParDatabaseImpl {
|
||||
fn knobs(&self) -> &KnobsStruct {
|
||||
&self.knobs
|
||||
}
|
||||
|
||||
fn signal(&self, stage: usize) {
|
||||
self.knobs.signal.signal(stage);
|
||||
}
|
||||
|
||||
fn wait_for(&self, stage: usize) {
|
||||
self.knobs.signal.wait_for(stage);
|
||||
}
|
||||
}
|
40
crates/salsa/tests/parallel/signal.rs
Normal file
40
crates/salsa/tests/parallel/signal.rs
Normal file
|
@ -0,0 +1,40 @@
|
|||
use parking_lot::{Condvar, Mutex};
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Signal {
|
||||
value: Mutex<usize>,
|
||||
cond_var: Condvar,
|
||||
}
|
||||
|
||||
impl Signal {
|
||||
pub(crate) fn signal(&self, stage: usize) {
|
||||
tracing::debug!("signal({})", stage);
|
||||
|
||||
// This check avoids acquiring the lock for things that will
|
||||
// clearly be a no-op. Not *necessary* but helps to ensure we
|
||||
// are more likely to encounter weird race conditions;
|
||||
// otherwise calls to `sum` will tend to be unnecessarily
|
||||
// synchronous.
|
||||
if stage > 0 {
|
||||
let mut v = self.value.lock();
|
||||
if stage > *v {
|
||||
*v = stage;
|
||||
self.cond_var.notify_all();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Waits until the given condition is true; the fn is invoked
|
||||
/// with the current stage.
|
||||
pub(crate) fn wait_for(&self, stage: usize) {
|
||||
tracing::debug!("wait_for({})", stage);
|
||||
|
||||
// As above, avoid lock if clearly a no-op.
|
||||
if stage > 0 {
|
||||
let mut v = self.value.lock();
|
||||
while *v < stage {
|
||||
self.cond_var.wait(&mut v);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
168
crates/salsa/tests/parallel/stress.rs
Normal file
168
crates/salsa/tests/parallel/stress.rs
Normal file
|
@ -0,0 +1,168 @@
|
|||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
|
||||
use salsa::ParallelDatabase;
|
||||
use salsa::Snapshot;
|
||||
use salsa::{Cancelled, Database};
|
||||
|
||||
// Number of operations a reader performs
|
||||
const N_MUTATOR_OPS: usize = 100;
|
||||
const N_READER_OPS: usize = 100;
|
||||
|
||||
#[salsa::query_group(Stress)]
|
||||
trait StressDatabase: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn a(&self, key: usize) -> usize;
|
||||
|
||||
fn b(&self, key: usize) -> usize;
|
||||
|
||||
fn c(&self, key: usize) -> usize;
|
||||
}
|
||||
|
||||
fn b(db: &dyn StressDatabase, key: usize) -> usize {
|
||||
db.unwind_if_cancelled();
|
||||
db.a(key)
|
||||
}
|
||||
|
||||
fn c(db: &dyn StressDatabase, key: usize) -> usize {
|
||||
db.b(key)
|
||||
}
|
||||
|
||||
#[salsa::database(Stress)]
|
||||
#[derive(Default)]
|
||||
struct StressDatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for StressDatabaseImpl {}
|
||||
|
||||
impl salsa::ParallelDatabase for StressDatabaseImpl {
|
||||
fn snapshot(&self) -> Snapshot<StressDatabaseImpl> {
|
||||
Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum Query {
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
}
|
||||
|
||||
enum MutatorOp {
|
||||
WriteOp(WriteOp),
|
||||
LaunchReader { ops: Vec<ReadOp>, check_cancellation: bool },
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum WriteOp {
|
||||
SetA(usize, usize),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ReadOp {
|
||||
Get(Query, usize),
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<Query> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Query {
|
||||
*[Query::A, Query::B, Query::C].choose(rng).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<MutatorOp> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> MutatorOp {
|
||||
if rng.gen_bool(0.5) {
|
||||
MutatorOp::WriteOp(rng.gen())
|
||||
} else {
|
||||
MutatorOp::LaunchReader {
|
||||
ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(),
|
||||
check_cancellation: rng.gen(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
|
||||
let key = rng.gen::<usize>() % 10;
|
||||
let value = rng.gen::<usize>() % 10;
|
||||
WriteOp::SetA(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
|
||||
let query = rng.gen::<Query>();
|
||||
let key = rng.gen::<usize>() % 10;
|
||||
ReadOp::Get(query, key)
|
||||
}
|
||||
}
|
||||
|
||||
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
|
||||
for op in ops {
|
||||
if check_cancellation {
|
||||
db.unwind_if_cancelled();
|
||||
}
|
||||
op.execute(db);
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteOp {
|
||||
fn execute(self, db: &mut StressDatabaseImpl) {
|
||||
match self {
|
||||
WriteOp::SetA(key, value) => {
|
||||
db.set_a(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadOp {
|
||||
fn execute(self, db: &StressDatabaseImpl) {
|
||||
match self {
|
||||
ReadOp::Get(query, key) => match query {
|
||||
Query::A => {
|
||||
db.a(key);
|
||||
}
|
||||
Query::B => {
|
||||
let _ = db.b(key);
|
||||
}
|
||||
Query::C => {
|
||||
let _ = db.c(key);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_test() {
|
||||
let mut db = StressDatabaseImpl::default();
|
||||
for i in 0..10 {
|
||||
db.set_a(i, i);
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// generate the ops that the mutator thread will perform
|
||||
let write_ops: Vec<MutatorOp> = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect();
|
||||
|
||||
// execute the "main thread", which sometimes snapshots off other threads
|
||||
let mut all_threads = vec![];
|
||||
for op in write_ops {
|
||||
match op {
|
||||
MutatorOp::WriteOp(w) => w.execute(&mut db),
|
||||
MutatorOp::LaunchReader { ops, check_cancellation } => {
|
||||
all_threads.push(std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for thread in all_threads {
|
||||
thread.join().unwrap().ok();
|
||||
}
|
||||
}
|
125
crates/salsa/tests/parallel/true_parallel.rs
Normal file
125
crates/salsa/tests/parallel/true_parallel.rs
Normal file
|
@ -0,0 +1,125 @@
|
|||
use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl, WithValue};
|
||||
use salsa::ParallelDatabase;
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
|
||||
/// Test where two threads are executing sum. We show that they can
|
||||
/// both be executing sum in parallel by having thread1 wait for
|
||||
/// thread2 to send a signal before it leaves (similarly, thread2
|
||||
/// waits for thread1 to send a signal before it enters).
|
||||
#[test]
|
||||
fn true_parallel_different_keys() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
|
||||
// Thread 1 will signal stage 1 when it enters and wait for stage 2.
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db
|
||||
.knobs()
|
||||
.sum_signal_on_entry
|
||||
.with_value(1, || db.knobs().sum_wait_for_on_exit.with_value(2, || db.sum("a")));
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
// Thread 2 will wait_for stage 1 when it enters and signal stage 2
|
||||
// when it leaves.
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db
|
||||
.knobs()
|
||||
.sum_wait_for_on_entry
|
||||
.with_value(1, || db.knobs().sum_signal_on_exit.with_value(2, || db.sum("b")));
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 100);
|
||||
assert_eq!(thread2.join().unwrap(), 10);
|
||||
}
|
||||
|
||||
/// Add a test that tries to trigger a conflict, where we fetch
|
||||
/// `sum("abc")` from two threads simultaneously, and of them
|
||||
/// therefore has to block.
|
||||
#[test]
|
||||
fn true_parallel_same_keys() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
|
||||
// Thread 1 will wait_for a barrier in the start of `sum`
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db
|
||||
.knobs()
|
||||
.sum_signal_on_entry
|
||||
.with_value(1, || db.knobs().sum_wait_for_on_entry.with_value(2, || db.sum("abc")));
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
// Thread 2 will wait until Thread 1 has entered sum and then --
|
||||
// once it has set itself to block -- signal Thread 1 to
|
||||
// continue. This way, we test out the mechanism of one thread
|
||||
// blocking on another.
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
db.knobs().signal.wait_for(1);
|
||||
db.knobs().signal_on_will_block.set(2);
|
||||
db.sum("abc")
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 111);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// Add a test that tries to trigger a conflict, where we fetch `sum("a")`
|
||||
/// from two threads simultaneously. After `thread2` begins blocking,
|
||||
/// we force `thread1` to panic and should see that propagate to `thread2`.
|
||||
#[test]
|
||||
fn true_parallel_propagate_panic() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
|
||||
// `thread1` will wait_for a barrier in the start of `sum`. Once it can
|
||||
// continue, it will panic.
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_on_entry
|
||||
.with_value(2, || db.knobs().sum_should_panic.with_value(true, || db.sum("a")))
|
||||
});
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
// `thread2` will wait until `thread1` has entered sum and then -- once it
|
||||
// has set itself to block -- signal `thread1` to continue.
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
db.knobs().signal.wait_for(1);
|
||||
db.knobs().signal_on_will_block.set(2);
|
||||
db.sum("a")
|
||||
}
|
||||
});
|
||||
|
||||
let result1 = panic::catch_unwind(AssertUnwindSafe(|| thread1.join().unwrap()));
|
||||
let result2 = panic::catch_unwind(AssertUnwindSafe(|| thread2.join().unwrap()));
|
||||
|
||||
assert!(result1.is_err());
|
||||
assert!(result2.is_err());
|
||||
}
|
19
crates/salsa/tests/storage_varieties/implementation.rs
Normal file
19
crates/salsa/tests/storage_varieties/implementation.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
use crate::queries;
|
||||
use std::cell::Cell;
|
||||
|
||||
#[salsa::database(queries::GroupStruct)]
|
||||
#[derive(Default)]
|
||||
pub(crate) struct DatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
counter: Cell<usize>,
|
||||
}
|
||||
|
||||
impl queries::Counter for DatabaseImpl {
|
||||
fn increment(&self) -> usize {
|
||||
let v = self.counter.get();
|
||||
self.counter.set(v + 1);
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl salsa::Database for DatabaseImpl {}
|
5
crates/salsa/tests/storage_varieties/main.rs
Normal file
5
crates/salsa/tests/storage_varieties/main.rs
Normal file
|
@ -0,0 +1,5 @@
|
|||
mod implementation;
|
||||
mod queries;
|
||||
mod tests;
|
||||
|
||||
fn main() {}
|
22
crates/salsa/tests/storage_varieties/queries.rs
Normal file
22
crates/salsa/tests/storage_varieties/queries.rs
Normal file
|
@ -0,0 +1,22 @@
|
|||
pub(crate) trait Counter: salsa::Database {
|
||||
fn increment(&self) -> usize;
|
||||
}
|
||||
|
||||
#[salsa::query_group(GroupStruct)]
|
||||
pub(crate) trait Database: Counter {
|
||||
fn memoized(&self) -> usize;
|
||||
fn volatile(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Because this query is memoized, we only increment the counter
|
||||
/// the first time it is invoked.
|
||||
fn memoized(db: &dyn Database) -> usize {
|
||||
db.volatile()
|
||||
}
|
||||
|
||||
/// Because this query is volatile, each time it is invoked,
|
||||
/// we will increment the counter.
|
||||
fn volatile(db: &dyn Database) -> usize {
|
||||
db.salsa_runtime().report_untracked_read();
|
||||
db.increment()
|
||||
}
|
49
crates/salsa/tests/storage_varieties/tests.rs
Normal file
49
crates/salsa/tests/storage_varieties/tests.rs
Normal file
|
@ -0,0 +1,49 @@
|
|||
#![cfg(test)]
|
||||
|
||||
use crate::implementation::DatabaseImpl;
|
||||
use crate::queries::Database;
|
||||
use salsa::Database as _Database;
|
||||
use salsa::Durability;
|
||||
|
||||
#[test]
|
||||
fn memoized_twice() {
|
||||
let db = DatabaseImpl::default();
|
||||
let v1 = db.memoized();
|
||||
let v2 = db.memoized();
|
||||
assert_eq!(v1, v2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn volatile_twice() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
let v1 = db.volatile();
|
||||
let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same
|
||||
assert_eq!(v1, v2);
|
||||
|
||||
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
|
||||
|
||||
let v3 = db.volatile(); // will re-increment the counter
|
||||
let v4 = db.volatile(); // second call will be cached
|
||||
assert_eq!(v1 + 1, v3);
|
||||
assert_eq!(v3, v4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn intermingled() {
|
||||
let mut db = DatabaseImpl::default();
|
||||
let v1 = db.volatile();
|
||||
let v2 = db.memoized();
|
||||
let v3 = db.volatile(); // cached
|
||||
let v4 = db.memoized(); // cached
|
||||
|
||||
assert_eq!(v1, v2);
|
||||
assert_eq!(v1, v3);
|
||||
assert_eq!(v2, v4);
|
||||
|
||||
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
|
||||
|
||||
let v5 = db.memoized(); // re-executes volatile, caches new result
|
||||
let v6 = db.memoized(); // re-use cached result
|
||||
assert_eq!(v4 + 1, v5);
|
||||
assert_eq!(v5, v6);
|
||||
}
|
39
crates/salsa/tests/transparent.rs
Normal file
39
crates/salsa/tests/transparent.rs
Normal file
|
@ -0,0 +1,39 @@
|
|||
//! Test that transparent (uncached) queries work
|
||||
|
||||
#[salsa::query_group(QueryGroupStorage)]
|
||||
trait QueryGroup {
|
||||
#[salsa::input]
|
||||
fn input(&self, x: u32) -> u32;
|
||||
#[salsa::transparent]
|
||||
fn wrap(&self, x: u32) -> u32;
|
||||
fn get(&self, x: u32) -> u32;
|
||||
}
|
||||
|
||||
fn wrap(db: &dyn QueryGroup, x: u32) -> u32 {
|
||||
db.input(x)
|
||||
}
|
||||
|
||||
fn get(db: &dyn QueryGroup, x: u32) -> u32 {
|
||||
db.wrap(x)
|
||||
}
|
||||
|
||||
#[salsa::database(QueryGroupStorage)]
|
||||
#[derive(Default)]
|
||||
struct Database {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for Database {}
|
||||
|
||||
#[test]
|
||||
fn transparent_queries_work() {
|
||||
let mut db = Database::default();
|
||||
|
||||
db.set_input(1, 10);
|
||||
assert_eq!(db.get(1), 10);
|
||||
assert_eq!(db.get(1), 10);
|
||||
|
||||
db.set_input(1, 92);
|
||||
assert_eq!(db.get(1), 92);
|
||||
assert_eq!(db.get(1), 92);
|
||||
}
|
51
crates/salsa/tests/variadic.rs
Normal file
51
crates/salsa/tests/variadic.rs
Normal file
|
@ -0,0 +1,51 @@
|
|||
#[salsa::query_group(HelloWorld)]
|
||||
trait HelloWorldDatabase: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn input(&self, a: u32, b: u32) -> u32;
|
||||
|
||||
fn none(&self) -> u32;
|
||||
|
||||
fn one(&self, k: u32) -> u32;
|
||||
|
||||
fn two(&self, a: u32, b: u32) -> u32;
|
||||
|
||||
fn trailing(&self, a: u32, b: u32) -> u32;
|
||||
}
|
||||
|
||||
fn none(_db: &dyn HelloWorldDatabase) -> u32 {
|
||||
22
|
||||
}
|
||||
|
||||
fn one(_db: &dyn HelloWorldDatabase, k: u32) -> u32 {
|
||||
k * 2
|
||||
}
|
||||
|
||||
fn two(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
|
||||
a * b
|
||||
}
|
||||
|
||||
fn trailing(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
|
||||
a - b
|
||||
}
|
||||
|
||||
#[salsa::database(HelloWorld)]
|
||||
#[derive(Default)]
|
||||
struct DatabaseStruct {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for DatabaseStruct {}
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let mut db = DatabaseStruct::default();
|
||||
|
||||
// test what happens with inputs:
|
||||
db.set_input(1, 2, 3);
|
||||
assert_eq!(db.input(1, 2), 3);
|
||||
|
||||
assert_eq!(db.none(), 22);
|
||||
assert_eq!(db.one(11), 22);
|
||||
assert_eq!(db.two(11, 2), 22);
|
||||
assert_eq!(db.trailing(24, 2), 22);
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue