mirror of
https://github.com/rust-lang/rust-analyzer.git
synced 2025-09-30 05:45:12 +00:00
Merge commit 'ddf105b646
' into sync-from-ra
This commit is contained in:
parent
0816d49d83
commit
e41ab350d6
378 changed files with 14720 additions and 3111 deletions
132
crates/salsa/tests/parallel/cancellation.rs
Normal file
132
crates/salsa/tests/parallel/cancellation.rs
Normal file
|
@ -0,0 +1,132 @@
|
|||
use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
|
||||
use salsa::{Cancelled, ParallelDatabase};
|
||||
|
||||
macro_rules! assert_cancelled {
|
||||
($thread:expr) => {
|
||||
match $thread.join() {
|
||||
Ok(value) => panic!("expected cancellation, got {:?}", value),
|
||||
Err(payload) => match payload.downcast::<Cancelled>() {
|
||||
Ok(_) => {}
|
||||
Err(payload) => ::std::panic::resume_unwind(payload),
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Add test where a call to `sum` is cancelled by a simultaneous
|
||||
/// write. Check that we recompute the result in next revision, even
|
||||
/// though none of the inputs have changed.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation_immediate() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
db.set_input('d', 0);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("abc")
|
||||
});
|
||||
|
||||
assert_eq!(db.sum("d"), 1000);
|
||||
assert_cancelled!(thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// Here, we check that `sum`'s cancellation is propagated
|
||||
/// to `sum2` properly.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation_transitive() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
db.set_input('d', 0);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum2("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum2("abc")
|
||||
});
|
||||
|
||||
assert_eq!(db.sum2("d"), 1000);
|
||||
assert_cancelled!(thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// https://github.com/salsa-rs/salsa/issues/66
|
||||
#[test]
|
||||
fn no_back_dating_in_cancellation() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// Here we compute a long-chain of queries,
|
||||
// but the last one gets cancelled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum3("a"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
db.wait_for(1);
|
||||
|
||||
// Set unrelated input to bump revision
|
||||
db.set_input('b', 2);
|
||||
|
||||
// Here we should recompuet the whole chain again, clearing the cancellation
|
||||
// state. If we get `usize::max()` here, it is a bug!
|
||||
assert_eq!(db.sum3("a"), 1);
|
||||
|
||||
assert_cancelled!(thread1);
|
||||
|
||||
db.set_input('a', 3);
|
||||
db.set_input('a', 4);
|
||||
assert_eq!(db.sum3("ab"), 6);
|
||||
}
|
57
crates/salsa/tests/parallel/frozen.rs
Normal file
57
crates/salsa/tests/parallel/frozen.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use crate::signal::Signal;
|
||||
use salsa::{Database, ParallelDatabase};
|
||||
use std::{
|
||||
panic::{catch_unwind, AssertUnwindSafe},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
/// Add test where a call to `sum` is cancelled by a simultaneous
|
||||
/// write. Check that we recompute the result in next revision, even
|
||||
/// though none of the inputs have changed.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
|
||||
let signal = Arc::new(Signal::default());
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
let signal = signal.clone();
|
||||
move || {
|
||||
// Check that cancellation flag is not yet set, because
|
||||
// `set` cannot have been called yet.
|
||||
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
|
||||
|
||||
// Signal other thread to proceed.
|
||||
signal.signal(1);
|
||||
|
||||
// Wait for other thread to signal cancellation
|
||||
catch_unwind(AssertUnwindSafe(|| loop {
|
||||
db.unwind_if_cancelled();
|
||||
std::thread::yield_now();
|
||||
}))
|
||||
.unwrap_err();
|
||||
}
|
||||
});
|
||||
|
||||
let thread2 = std::thread::spawn({
|
||||
move || {
|
||||
// Wait until thread 1 has asserted that they are not cancelled
|
||||
// before we invoke `set.`
|
||||
signal.wait_for(1);
|
||||
|
||||
// This will block until thread1 drops the revision lock.
|
||||
db.set_input('a', 2);
|
||||
|
||||
db.input('a')
|
||||
}
|
||||
});
|
||||
|
||||
thread1.join().unwrap();
|
||||
|
||||
let c = thread2.join().unwrap();
|
||||
assert_eq!(c, 2);
|
||||
}
|
29
crates/salsa/tests/parallel/independent.rs
Normal file
29
crates/salsa/tests/parallel/independent.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
|
||||
/// Test two `sum` queries (on distinct keys) executing in different
|
||||
/// threads. Really just a test that `snapshot` etc compiles.
|
||||
#[test]
|
||||
fn in_par_two_independent_queries() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
db.set_input('d', 200);
|
||||
db.set_input('e', 20);
|
||||
db.set_input('f', 2);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("abc")
|
||||
});
|
||||
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("def")
|
||||
});
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 111);
|
||||
assert_eq!(thread2.join().unwrap(), 222);
|
||||
}
|
13
crates/salsa/tests/parallel/main.rs
Normal file
13
crates/salsa/tests/parallel/main.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
mod setup;
|
||||
|
||||
mod cancellation;
|
||||
mod frozen;
|
||||
mod independent;
|
||||
mod parallel_cycle_all_recover;
|
||||
mod parallel_cycle_mid_recover;
|
||||
mod parallel_cycle_none_recover;
|
||||
mod parallel_cycle_one_recovers;
|
||||
mod race;
|
||||
mod signal;
|
||||
mod stress;
|
||||
mod true_parallel;
|
110
crates/salsa/tests/parallel/parallel_cycle_all_recover.rs
Normal file
110
crates/salsa/tests/parallel/parallel_cycle_all_recover.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
//! Test for cycle recover spread across two threads.
|
||||
//! See `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
// Recover cycle test:
|
||||
//
|
||||
// The pattern is as follows.
|
||||
//
|
||||
// Thread A Thread B
|
||||
// -------- --------
|
||||
// a1 b1
|
||||
// | wait for stage 1 (blocks)
|
||||
// signal stage 1 |
|
||||
// wait for stage 2 (blocks) (unblocked)
|
||||
// | signal stage 2
|
||||
// (unblocked) wait for stage 3 (blocks)
|
||||
// a2 |
|
||||
// b1 (blocks -> stage 3) |
|
||||
// | (unblocked)
|
||||
// | b2
|
||||
// | a1 (cycle detected, recovers)
|
||||
// | b2 completes, recovers
|
||||
// | b1 completes, recovers
|
||||
// a2 sees cycle, recovers
|
||||
// a1 completes, recovers
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_all_recover() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(3);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a1(1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b1(1)
|
||||
});
|
||||
|
||||
assert_eq!(thread_a.join().unwrap(), 11);
|
||||
assert_eq!(thread_b.join().unwrap(), 21);
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleAllRecover)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
#[salsa::cycle(recover_a1)]
|
||||
fn a1(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_a2)]
|
||||
fn a2(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b1)]
|
||||
fn b1(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b2)]
|
||||
fn b2(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_a1");
|
||||
key * 10 + 1
|
||||
}
|
||||
|
||||
fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_a2");
|
||||
key * 10 + 2
|
||||
}
|
||||
|
||||
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b1");
|
||||
key * 20 + 1
|
||||
}
|
||||
|
||||
fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b2");
|
||||
key * 20 + 2
|
||||
}
|
||||
|
||||
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.signal(1);
|
||||
db.wait_for(2);
|
||||
|
||||
db.a2(key)
|
||||
}
|
||||
|
||||
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.b1(key)
|
||||
}
|
||||
|
||||
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.wait_for(1);
|
||||
db.signal(2);
|
||||
|
||||
// Wait for thread A to block on this thread
|
||||
db.wait_for(3);
|
||||
|
||||
db.b2(key)
|
||||
}
|
||||
|
||||
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.a1(key)
|
||||
}
|
110
crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs
Normal file
110
crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
//! Test for cycle recover spread across two threads.
|
||||
//! See `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
// Recover cycle test:
|
||||
//
|
||||
// The pattern is as follows.
|
||||
//
|
||||
// Thread A Thread B
|
||||
// -------- --------
|
||||
// a1 b1
|
||||
// | wait for stage 1 (blocks)
|
||||
// signal stage 1 |
|
||||
// wait for stage 2 (blocks) (unblocked)
|
||||
// | |
|
||||
// | b2
|
||||
// | b3
|
||||
// | a1 (blocks -> stage 2)
|
||||
// (unblocked) |
|
||||
// a2 (cycle detected) |
|
||||
// b3 recovers
|
||||
// b2 resumes
|
||||
// b1 panics because bug
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_mid_recovers() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(2);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a1(1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b1(1)
|
||||
});
|
||||
|
||||
// We expect that the recovery function yields
|
||||
// `1 * 20 + 2`, which is returned (and forwarded)
|
||||
// to b1, and from there to a2 and a1.
|
||||
assert_eq!(thread_a.join().unwrap(), 22);
|
||||
assert_eq!(thread_b.join().unwrap(), 22);
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleMidRecovers)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
fn a1(&self, key: i32) -> i32;
|
||||
|
||||
fn a2(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b1)]
|
||||
fn b1(&self, key: i32) -> i32;
|
||||
|
||||
fn b2(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover_b3)]
|
||||
fn b3(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b1");
|
||||
key * 20 + 2
|
||||
}
|
||||
|
||||
fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover_b1");
|
||||
key * 200 + 2
|
||||
}
|
||||
|
||||
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// tell thread b we have started
|
||||
db.signal(1);
|
||||
|
||||
// wait for thread b to block on a1
|
||||
db.wait_for(2);
|
||||
|
||||
db.a2(key)
|
||||
}
|
||||
|
||||
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// create the cycle
|
||||
db.b1(key)
|
||||
}
|
||||
|
||||
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// wait for thread a to have started
|
||||
db.wait_for(1);
|
||||
|
||||
db.b2(key);
|
||||
|
||||
0
|
||||
}
|
||||
|
||||
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// will encounter a cycle but recover
|
||||
db.b3(key);
|
||||
db.b1(key); // hasn't recovered yet
|
||||
0
|
||||
}
|
||||
|
||||
fn b3(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// will block on thread a, signaling stage 2
|
||||
db.a1(key)
|
||||
}
|
69
crates/salsa/tests/parallel/parallel_cycle_none_recover.rs
Normal file
69
crates/salsa/tests/parallel/parallel_cycle_none_recover.rs
Normal file
|
@ -0,0 +1,69 @@
|
|||
//! Test a cycle where no queries recover that occurs across threads.
|
||||
//! See the `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use expect_test::expect;
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_none_recover() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(3);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a(-1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b(-1)
|
||||
});
|
||||
|
||||
// We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
|
||||
// Right now, it panics with a string.
|
||||
let err_b = thread_b.join().unwrap_err();
|
||||
if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
|
||||
expect![[r#"
|
||||
[
|
||||
"a(-1)",
|
||||
"b(-1)",
|
||||
]
|
||||
"#]]
|
||||
.assert_debug_eq(&c.unexpected_participants(&db));
|
||||
} else {
|
||||
panic!("b failed in an unexpected way: {:?}", err_b);
|
||||
}
|
||||
|
||||
// We expect A to propagate a panic, which causes us to use the sentinel
|
||||
// type `Canceled`.
|
||||
assert!(thread_a.join().unwrap_err().downcast_ref::<salsa::Cycle>().is_some());
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleNoneRecover)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
fn a(&self, key: i32) -> i32;
|
||||
fn b(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn a(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.signal(1);
|
||||
db.wait_for(2);
|
||||
|
||||
db.b(key)
|
||||
}
|
||||
|
||||
fn b(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.wait_for(1);
|
||||
db.signal(2);
|
||||
|
||||
// Wait for thread A to block on this thread
|
||||
db.wait_for(3);
|
||||
|
||||
// Now try to execute A
|
||||
db.a(key)
|
||||
}
|
95
crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs
Normal file
95
crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs
Normal file
|
@ -0,0 +1,95 @@
|
|||
//! Test for cycle recover spread across two threads.
|
||||
//! See `../cycles.rs` for a complete listing of cycle tests,
|
||||
//! both intra and cross thread.
|
||||
|
||||
use crate::setup::{Knobs, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
use test_log::test;
|
||||
|
||||
// Recover cycle test:
|
||||
//
|
||||
// The pattern is as follows.
|
||||
//
|
||||
// Thread A Thread B
|
||||
// -------- --------
|
||||
// a1 b1
|
||||
// | wait for stage 1 (blocks)
|
||||
// signal stage 1 |
|
||||
// wait for stage 2 (blocks) (unblocked)
|
||||
// | signal stage 2
|
||||
// (unblocked) wait for stage 3 (blocks)
|
||||
// a2 |
|
||||
// b1 (blocks -> stage 3) |
|
||||
// | (unblocked)
|
||||
// | b2
|
||||
// | a1 (cycle detected)
|
||||
// a2 recovery fn executes |
|
||||
// a1 completes normally |
|
||||
// b2 completes, recovers
|
||||
// b1 completes, recovers
|
||||
|
||||
#[test]
|
||||
fn parallel_cycle_one_recovers() {
|
||||
let db = ParDatabaseImpl::default();
|
||||
db.knobs().signal_on_will_block.set(3);
|
||||
|
||||
let thread_a = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.a1(1)
|
||||
});
|
||||
|
||||
let thread_b = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.b1(1)
|
||||
});
|
||||
|
||||
// We expect that the recovery function yields
|
||||
// `1 * 20 + 2`, which is returned (and forwarded)
|
||||
// to b1, and from there to a2 and a1.
|
||||
assert_eq!(thread_a.join().unwrap(), 22);
|
||||
assert_eq!(thread_b.join().unwrap(), 22);
|
||||
}
|
||||
|
||||
#[salsa::query_group(ParallelCycleOneRecovers)]
|
||||
pub(crate) trait TestDatabase: Knobs {
|
||||
fn a1(&self, key: i32) -> i32;
|
||||
|
||||
#[salsa::cycle(recover)]
|
||||
fn a2(&self, key: i32) -> i32;
|
||||
|
||||
fn b1(&self, key: i32) -> i32;
|
||||
|
||||
fn b2(&self, key: i32) -> i32;
|
||||
}
|
||||
|
||||
fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
|
||||
tracing::debug!("recover");
|
||||
key * 20 + 2
|
||||
}
|
||||
|
||||
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.signal(1);
|
||||
db.wait_for(2);
|
||||
|
||||
db.a2(key)
|
||||
}
|
||||
|
||||
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.b1(key)
|
||||
}
|
||||
|
||||
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
// Wait to create the cycle until both threads have entered
|
||||
db.wait_for(1);
|
||||
db.signal(2);
|
||||
|
||||
// Wait for thread A to block on this thread
|
||||
db.wait_for(3);
|
||||
|
||||
db.b2(key)
|
||||
}
|
||||
|
||||
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
|
||||
db.a1(key)
|
||||
}
|
37
crates/salsa/tests/parallel/race.rs
Normal file
37
crates/salsa/tests/parallel/race.rs
Normal file
|
@ -0,0 +1,37 @@
|
|||
use std::panic::AssertUnwindSafe;
|
||||
|
||||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use salsa::{Cancelled, ParallelDatabase};
|
||||
|
||||
/// Test where a read and a set are racing with one another.
|
||||
/// Should be atomic.
|
||||
#[test]
|
||||
fn in_par_get_set_race() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || Cancelled::catch(AssertUnwindSafe(|| db.sum("abc")))
|
||||
});
|
||||
|
||||
let thread2 = std::thread::spawn(move || {
|
||||
db.set_input('a', 1000);
|
||||
db.sum("a")
|
||||
});
|
||||
|
||||
// If the 1st thread runs first, you get 111, otherwise you get
|
||||
// 1011; if they run concurrently and the 1st thread observes the
|
||||
// cancellation, it'll unwind.
|
||||
let result1 = thread1.join().unwrap();
|
||||
if let Ok(value1) = result1 {
|
||||
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
|
||||
}
|
||||
|
||||
// thread2 can not observe a cancellation because it performs a
|
||||
// database write before running any other queries.
|
||||
assert_eq!(thread2.join().unwrap(), 1000);
|
||||
}
|
197
crates/salsa/tests/parallel/setup.rs
Normal file
197
crates/salsa/tests/parallel/setup.rs
Normal file
|
@ -0,0 +1,197 @@
|
|||
use crate::signal::Signal;
|
||||
use salsa::Database;
|
||||
use salsa::ParallelDatabase;
|
||||
use salsa::Snapshot;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
cell::Cell,
|
||||
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
|
||||
};
|
||||
|
||||
#[salsa::query_group(Par)]
|
||||
pub(crate) trait ParDatabase: Knobs {
|
||||
#[salsa::input]
|
||||
fn input(&self, key: char) -> usize;
|
||||
|
||||
fn sum(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum`
|
||||
fn sum2(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum` but doesn't really care about the result.
|
||||
fn sum2_drop_sum(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum2`
|
||||
fn sum3(&self, key: &'static str) -> usize;
|
||||
|
||||
/// Invokes `sum2_drop_sum`
|
||||
fn sum3_drop_sum(&self, key: &'static str) -> usize;
|
||||
}
|
||||
|
||||
/// Various "knobs" and utilities used by tests to force
|
||||
/// a certain behavior.
|
||||
pub(crate) trait Knobs {
|
||||
fn knobs(&self) -> &KnobsStruct;
|
||||
|
||||
fn signal(&self, stage: usize);
|
||||
|
||||
fn wait_for(&self, stage: usize);
|
||||
}
|
||||
|
||||
pub(crate) trait WithValue<T> {
|
||||
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R;
|
||||
}
|
||||
|
||||
impl<T> WithValue<T> for Cell<T> {
|
||||
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
|
||||
let old_value = self.replace(value);
|
||||
|
||||
let result = catch_unwind(AssertUnwindSafe(closure));
|
||||
|
||||
self.set(old_value);
|
||||
|
||||
match result {
|
||||
Ok(r) => r,
|
||||
Err(payload) => resume_unwind(payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum CancellationFlag {
|
||||
#[default]
|
||||
Down,
|
||||
Panic,
|
||||
}
|
||||
|
||||
/// Various "knobs" that can be used to customize how the queries
|
||||
/// behave on one specific thread. Note that this state is
|
||||
/// intentionally thread-local (apart from `signal`).
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct KnobsStruct {
|
||||
/// A kind of flexible barrier used to coordinate execution across
|
||||
/// threads to ensure we reach various weird states.
|
||||
pub(crate) signal: Arc<Signal>,
|
||||
|
||||
/// When this database is about to block, send a signal.
|
||||
pub(crate) signal_on_will_block: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum` will signal this stage on entry.
|
||||
pub(crate) sum_signal_on_entry: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum` will wait for this stage on entry.
|
||||
pub(crate) sum_wait_for_on_entry: Cell<usize>,
|
||||
|
||||
/// If true, invocations of `sum` will panic before they exit.
|
||||
pub(crate) sum_should_panic: Cell<bool>,
|
||||
|
||||
/// If true, invocations of `sum` will wait for cancellation before
|
||||
/// they exit.
|
||||
pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
|
||||
|
||||
/// Invocations of `sum` will wait for this stage prior to exiting.
|
||||
pub(crate) sum_wait_for_on_exit: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum` will signal this stage prior to exiting.
|
||||
pub(crate) sum_signal_on_exit: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum3_drop_sum` will panic unconditionally
|
||||
pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
|
||||
}
|
||||
|
||||
fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
let mut sum = 0;
|
||||
|
||||
db.signal(db.knobs().sum_signal_on_entry.get());
|
||||
|
||||
db.wait_for(db.knobs().sum_wait_for_on_entry.get());
|
||||
|
||||
if db.knobs().sum_should_panic.get() {
|
||||
panic!("query set to panic before exit")
|
||||
}
|
||||
|
||||
for ch in key.chars() {
|
||||
sum += db.input(ch);
|
||||
}
|
||||
|
||||
match db.knobs().sum_wait_for_cancellation.get() {
|
||||
CancellationFlag::Down => (),
|
||||
CancellationFlag::Panic => {
|
||||
tracing::debug!("waiting for cancellation");
|
||||
loop {
|
||||
db.unwind_if_cancelled();
|
||||
std::thread::yield_now();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.wait_for(db.knobs().sum_wait_for_on_exit.get());
|
||||
|
||||
db.signal(db.knobs().sum_signal_on_exit.get());
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
db.sum(key)
|
||||
}
|
||||
|
||||
fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
let _ = db.sum(key);
|
||||
22
|
||||
}
|
||||
|
||||
fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
db.sum2(key)
|
||||
}
|
||||
|
||||
fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
|
||||
if db.knobs().sum3_drop_sum_should_panic.get() {
|
||||
panic!("sum3_drop_sum executed")
|
||||
}
|
||||
db.sum2_drop_sum(key)
|
||||
}
|
||||
|
||||
#[salsa::database(
|
||||
Par,
|
||||
crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
|
||||
crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
|
||||
crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers,
|
||||
crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
|
||||
)]
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ParDatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
knobs: KnobsStruct,
|
||||
}
|
||||
|
||||
impl Database for ParDatabaseImpl {
|
||||
fn salsa_event(&self, event: salsa::Event) {
|
||||
if let salsa::EventKind::WillBlockOn { .. } = event.kind {
|
||||
self.signal(self.knobs().signal_on_will_block.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParallelDatabase for ParDatabaseImpl {
|
||||
fn snapshot(&self) -> Snapshot<Self> {
|
||||
Snapshot::new(ParDatabaseImpl {
|
||||
storage: self.storage.snapshot(),
|
||||
knobs: self.knobs.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Knobs for ParDatabaseImpl {
|
||||
fn knobs(&self) -> &KnobsStruct {
|
||||
&self.knobs
|
||||
}
|
||||
|
||||
fn signal(&self, stage: usize) {
|
||||
self.knobs.signal.signal(stage);
|
||||
}
|
||||
|
||||
fn wait_for(&self, stage: usize) {
|
||||
self.knobs.signal.wait_for(stage);
|
||||
}
|
||||
}
|
40
crates/salsa/tests/parallel/signal.rs
Normal file
40
crates/salsa/tests/parallel/signal.rs
Normal file
|
@ -0,0 +1,40 @@
|
|||
use parking_lot::{Condvar, Mutex};
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Signal {
|
||||
value: Mutex<usize>,
|
||||
cond_var: Condvar,
|
||||
}
|
||||
|
||||
impl Signal {
|
||||
pub(crate) fn signal(&self, stage: usize) {
|
||||
tracing::debug!("signal({})", stage);
|
||||
|
||||
// This check avoids acquiring the lock for things that will
|
||||
// clearly be a no-op. Not *necessary* but helps to ensure we
|
||||
// are more likely to encounter weird race conditions;
|
||||
// otherwise calls to `sum` will tend to be unnecessarily
|
||||
// synchronous.
|
||||
if stage > 0 {
|
||||
let mut v = self.value.lock();
|
||||
if stage > *v {
|
||||
*v = stage;
|
||||
self.cond_var.notify_all();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Waits until the given condition is true; the fn is invoked
|
||||
/// with the current stage.
|
||||
pub(crate) fn wait_for(&self, stage: usize) {
|
||||
tracing::debug!("wait_for({})", stage);
|
||||
|
||||
// As above, avoid lock if clearly a no-op.
|
||||
if stage > 0 {
|
||||
let mut v = self.value.lock();
|
||||
while *v < stage {
|
||||
self.cond_var.wait(&mut v);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
168
crates/salsa/tests/parallel/stress.rs
Normal file
168
crates/salsa/tests/parallel/stress.rs
Normal file
|
@ -0,0 +1,168 @@
|
|||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
|
||||
use salsa::ParallelDatabase;
|
||||
use salsa::Snapshot;
|
||||
use salsa::{Cancelled, Database};
|
||||
|
||||
// Number of operations a reader performs
|
||||
const N_MUTATOR_OPS: usize = 100;
|
||||
const N_READER_OPS: usize = 100;
|
||||
|
||||
#[salsa::query_group(Stress)]
|
||||
trait StressDatabase: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn a(&self, key: usize) -> usize;
|
||||
|
||||
fn b(&self, key: usize) -> usize;
|
||||
|
||||
fn c(&self, key: usize) -> usize;
|
||||
}
|
||||
|
||||
fn b(db: &dyn StressDatabase, key: usize) -> usize {
|
||||
db.unwind_if_cancelled();
|
||||
db.a(key)
|
||||
}
|
||||
|
||||
fn c(db: &dyn StressDatabase, key: usize) -> usize {
|
||||
db.b(key)
|
||||
}
|
||||
|
||||
#[salsa::database(Stress)]
|
||||
#[derive(Default)]
|
||||
struct StressDatabaseImpl {
|
||||
storage: salsa::Storage<Self>,
|
||||
}
|
||||
|
||||
impl salsa::Database for StressDatabaseImpl {}
|
||||
|
||||
impl salsa::ParallelDatabase for StressDatabaseImpl {
|
||||
fn snapshot(&self) -> Snapshot<StressDatabaseImpl> {
|
||||
Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum Query {
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
}
|
||||
|
||||
enum MutatorOp {
|
||||
WriteOp(WriteOp),
|
||||
LaunchReader { ops: Vec<ReadOp>, check_cancellation: bool },
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum WriteOp {
|
||||
SetA(usize, usize),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ReadOp {
|
||||
Get(Query, usize),
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<Query> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Query {
|
||||
*[Query::A, Query::B, Query::C].choose(rng).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<MutatorOp> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> MutatorOp {
|
||||
if rng.gen_bool(0.5) {
|
||||
MutatorOp::WriteOp(rng.gen())
|
||||
} else {
|
||||
MutatorOp::LaunchReader {
|
||||
ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(),
|
||||
check_cancellation: rng.gen(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
|
||||
let key = rng.gen::<usize>() % 10;
|
||||
let value = rng.gen::<usize>() % 10;
|
||||
WriteOp::SetA(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
|
||||
let query = rng.gen::<Query>();
|
||||
let key = rng.gen::<usize>() % 10;
|
||||
ReadOp::Get(query, key)
|
||||
}
|
||||
}
|
||||
|
||||
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
|
||||
for op in ops {
|
||||
if check_cancellation {
|
||||
db.unwind_if_cancelled();
|
||||
}
|
||||
op.execute(db);
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteOp {
|
||||
fn execute(self, db: &mut StressDatabaseImpl) {
|
||||
match self {
|
||||
WriteOp::SetA(key, value) => {
|
||||
db.set_a(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadOp {
|
||||
fn execute(self, db: &StressDatabaseImpl) {
|
||||
match self {
|
||||
ReadOp::Get(query, key) => match query {
|
||||
Query::A => {
|
||||
db.a(key);
|
||||
}
|
||||
Query::B => {
|
||||
let _ = db.b(key);
|
||||
}
|
||||
Query::C => {
|
||||
let _ = db.c(key);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_test() {
|
||||
let mut db = StressDatabaseImpl::default();
|
||||
for i in 0..10 {
|
||||
db.set_a(i, i);
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// generate the ops that the mutator thread will perform
|
||||
let write_ops: Vec<MutatorOp> = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect();
|
||||
|
||||
// execute the "main thread", which sometimes snapshots off other threads
|
||||
let mut all_threads = vec![];
|
||||
for op in write_ops {
|
||||
match op {
|
||||
MutatorOp::WriteOp(w) => w.execute(&mut db),
|
||||
MutatorOp::LaunchReader { ops, check_cancellation } => {
|
||||
all_threads.push(std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for thread in all_threads {
|
||||
thread.join().unwrap().ok();
|
||||
}
|
||||
}
|
125
crates/salsa/tests/parallel/true_parallel.rs
Normal file
125
crates/salsa/tests/parallel/true_parallel.rs
Normal file
|
@ -0,0 +1,125 @@
|
|||
use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl, WithValue};
|
||||
use salsa::ParallelDatabase;
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
|
||||
/// Test where two threads are executing sum. We show that they can
|
||||
/// both be executing sum in parallel by having thread1 wait for
|
||||
/// thread2 to send a signal before it leaves (similarly, thread2
|
||||
/// waits for thread1 to send a signal before it enters).
|
||||
#[test]
|
||||
fn true_parallel_different_keys() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
|
||||
// Thread 1 will signal stage 1 when it enters and wait for stage 2.
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db
|
||||
.knobs()
|
||||
.sum_signal_on_entry
|
||||
.with_value(1, || db.knobs().sum_wait_for_on_exit.with_value(2, || db.sum("a")));
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
// Thread 2 will wait_for stage 1 when it enters and signal stage 2
|
||||
// when it leaves.
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db
|
||||
.knobs()
|
||||
.sum_wait_for_on_entry
|
||||
.with_value(1, || db.knobs().sum_signal_on_exit.with_value(2, || db.sum("b")));
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 100);
|
||||
assert_eq!(thread2.join().unwrap(), 10);
|
||||
}
|
||||
|
||||
/// Add a test that tries to trigger a conflict, where we fetch
|
||||
/// `sum("abc")` from two threads simultaneously, and of them
|
||||
/// therefore has to block.
|
||||
#[test]
|
||||
fn true_parallel_same_keys() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 10);
|
||||
db.set_input('c', 1);
|
||||
|
||||
// Thread 1 will wait_for a barrier in the start of `sum`
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db
|
||||
.knobs()
|
||||
.sum_signal_on_entry
|
||||
.with_value(1, || db.knobs().sum_wait_for_on_entry.with_value(2, || db.sum("abc")));
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
// Thread 2 will wait until Thread 1 has entered sum and then --
|
||||
// once it has set itself to block -- signal Thread 1 to
|
||||
// continue. This way, we test out the mechanism of one thread
|
||||
// blocking on another.
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
db.knobs().signal.wait_for(1);
|
||||
db.knobs().signal_on_will_block.set(2);
|
||||
db.sum("abc")
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 111);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// Add a test that tries to trigger a conflict, where we fetch `sum("a")`
|
||||
/// from two threads simultaneously. After `thread2` begins blocking,
|
||||
/// we force `thread1` to panic and should see that propagate to `thread2`.
|
||||
#[test]
|
||||
fn true_parallel_propagate_panic() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
|
||||
// `thread1` will wait_for a barrier in the start of `sum`. Once it can
|
||||
// continue, it will panic.
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_on_entry
|
||||
.with_value(2, || db.knobs().sum_should_panic.with_value(true, || db.sum("a")))
|
||||
});
|
||||
v
|
||||
}
|
||||
});
|
||||
|
||||
// `thread2` will wait until `thread1` has entered sum and then -- once it
|
||||
// has set itself to block -- signal `thread1` to continue.
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
db.knobs().signal.wait_for(1);
|
||||
db.knobs().signal_on_will_block.set(2);
|
||||
db.sum("a")
|
||||
}
|
||||
});
|
||||
|
||||
let result1 = panic::catch_unwind(AssertUnwindSafe(|| thread1.join().unwrap()));
|
||||
let result2 = panic::catch_unwind(AssertUnwindSafe(|| thread2.join().unwrap()));
|
||||
|
||||
assert!(result1.is_err());
|
||||
assert!(result2.is_err());
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue