Merge commit 'ddf105b646' into sync-from-ra

This commit is contained in:
Laurențiu Nicola 2024-02-11 08:40:19 +02:00
parent 0816d49d83
commit e41ab350d6
378 changed files with 14720 additions and 3111 deletions


@@ -0,0 +1,132 @@
use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::{Cancelled, ParallelDatabase};
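/// Joins the given thread handle and asserts that it panicked with a
/// `salsa::Cancelled` payload; any other panic payload is propagated.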
macro_rules! assert_cancelled {
($thread:expr) => {
match $thread.join() {
Ok(value) => panic!("expected cancellation, got {:?}", value),
Err(payload) => match payload.downcast::<Cancelled>() {
Ok(_) => {}
Err(payload) => ::std::panic::resume_unwind(payload),
},
}
};
}
/// Test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in the next revision, even
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation_immediate() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
assert_eq!(db.sum("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// Here, we check that `sum`'s cancellation is propagated
/// to `sum2` properly.
#[test]
fn in_par_get_set_cancellation_transitive() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum2("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum2("abc")
});
assert_eq!(db.sum2("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// <https://github.com/salsa-rs/salsa/issues/66>
#[test]
fn no_back_dating_in_cancellation() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// Here we compute a long chain of queries,
// but the last one gets cancelled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum3("a"))
})
}
});
db.wait_for(1);
// Set unrelated input to bump revision
db.set_input('b', 2);
// Here we should recompute the whole chain again, clearing the cancellation
// state. If we get `usize::max()` here, it is a bug!
assert_eq!(db.sum3("a"), 1);
assert_cancelled!(thread1);
db.set_input('a', 3);
db.set_input('a', 4);
assert_eq!(db.sum3("ab"), 6);
}


@@ -0,0 +1,57 @@
use crate::setup::{ParDatabase, ParDatabaseImpl};
use crate::signal::Signal;
use salsa::{Database, ParallelDatabase};
use std::{
panic::{catch_unwind, AssertUnwindSafe},
sync::Arc,
};
/// Test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in the next revision, even
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
let signal = Arc::new(Signal::default());
let thread1 = std::thread::spawn({
let db = db.snapshot();
let signal = signal.clone();
move || {
// Check that the cancellation flag is not yet set, because
// `set` cannot have been called yet.
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
// Signal other thread to proceed.
signal.signal(1);
// Wait for other thread to signal cancellation
catch_unwind(AssertUnwindSafe(|| loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}))
.unwrap_err();
}
});
let thread2 = std::thread::spawn({
move || {
// Wait until thread 1 has asserted that it is not cancelled
// before we invoke `set`.
signal.wait_for(1);
// This will block until thread1 drops the revision lock.
db.set_input('a', 2);
db.input('a')
}
});
thread1.join().unwrap();
let c = thread2.join().unwrap();
assert_eq!(c, 2);
}


@@ -0,0 +1,29 @@
use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::ParallelDatabase;
/// Test two `sum` queries (on distinct keys) executing in different
/// threads. Really just a test that `snapshot` etc compiles.
#[test]
fn in_par_two_independent_queries() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 200);
db.set_input('e', 20);
db.set_input('f', 2);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("def")
});
assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 222);
}


@@ -0,0 +1,13 @@
mod setup;
mod cancellation;
mod frozen;
mod independent;
mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recovers;
mod race;
mod signal;
mod stress;
mod true_parallel;


@@ -0,0 +1,110 @@
//! Test for cycle recovery spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
//   Thread A                   Thread B
//   --------                   --------
//   a1                         b1
//   |                          wait for stage 1 (blocks)
//   signal stage 1             |
//   wait for stage 2 (blocks)  (unblocked)
//   |                          signal stage 2
//   (unblocked)                wait for stage 3 (blocks)
//   a2                         |
//   b1 (blocks -> stage 3)     |
//   |                          (unblocked)
//   |                          b2
//   |                          a1 (cycle detected, recovers)
//   |                          b2 completes, recovers
//   |                          b1 completes, recovers
//   a2 sees cycle, recovers
//   a1 completes, recovers
#[test]
fn parallel_cycle_all_recover() {
let db = ParDatabaseImpl::default();
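// When a thread is about to block waiting on another, the database will
// signal stage 3 (see `signal_on_will_block` in `setup.rs`).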
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
assert_eq!(thread_a.join().unwrap(), 11);
assert_eq!(thread_b.join().unwrap(), 21);
}
#[salsa::query_group(ParallelCycleAllRecover)]
pub(crate) trait TestDatabase: Knobs {
#[salsa::cycle(recover_a1)]
fn a1(&self, key: i32) -> i32;
#[salsa::cycle(recover_a2)]
fn a2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b1)]
fn b1(&self, key: i32) -> i32;
#[salsa::cycle(recover_b2)]
fn b2(&self, key: i32) -> i32;
}
fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_a1");
key * 10 + 1
}
fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_a2");
key * 10 + 2
}
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 20 + 1
}
fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b2");
key * 20 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
db.b2(key)
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
db.a1(key)
}


@@ -0,0 +1,110 @@
//! Test for cycle recovery spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
//   Thread A                   Thread B
//   --------                   --------
//   a1                         b1
//   |                          wait for stage 1 (blocks)
//   signal stage 1             |
//   wait for stage 2 (blocks)  (unblocked)
//   |                          |
//   |                          b2
//   |                          b3
//   |                          a1 (blocks -> stage 2)
//   (unblocked)                |
//   a2 (cycle detected)        |
//                              b3 recovers
//                              b2 resumes
//                              b1 panics because bug
#[test]
fn parallel_cycle_mid_recovers() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(2);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}
#[salsa::query_group(ParallelCycleMidRecovers)]
pub(crate) trait TestDatabase: Knobs {
fn a1(&self, key: i32) -> i32;
fn a2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b1)]
fn b1(&self, key: i32) -> i32;
fn b2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b3)]
fn b3(&self, key: i32) -> i32;
}
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 20 + 2
}
fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 200 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// tell thread b we have started
db.signal(1);
// wait for thread b to block on a1
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
// create the cycle
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// wait for thread a to have started
db.wait_for(1);
db.b2(key);
0
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
// will encounter a cycle but recover
db.b3(key);
db.b1(key); // hasn't recovered yet
0
}
fn b3(db: &dyn TestDatabase, key: i32) -> i32 {
// will block on thread a, signaling stage 2
db.a1(key)
}


@@ -0,0 +1,69 @@
//! Test a cycle that occurs across threads and where no queries recover.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use expect_test::expect;
use salsa::ParallelDatabase;
use test_log::test;
#[test]
fn parallel_cycle_none_recover() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a(-1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b(-1)
});
// We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
// Its panic payload is the `salsa::Cycle` itself.
let err_b = thread_b.join().unwrap_err();
if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
expect![[r#"
[
"a(-1)",
"b(-1)",
]
"#]]
.assert_debug_eq(&c.unexpected_participants(&db));
} else {
panic!("b failed in an unexpected way: {:?}", err_b);
}
// We expect A to unwind as well; its panic payload should also be the
// `salsa::Cycle`.
assert!(thread_a.join().unwrap_err().downcast_ref::<salsa::Cycle>().is_some());
}
#[salsa::query_group(ParallelCycleNoneRecover)]
pub(crate) trait TestDatabase: Knobs {
fn a(&self, key: i32) -> i32;
fn b(&self, key: i32) -> i32;
}
fn a(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.b(key)
}
fn b(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
// Now try to execute A
db.a(key)
}


@@ -0,0 +1,95 @@
//! Test for cycle recovery spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
//   Thread A                   Thread B
//   --------                   --------
//   a1                         b1
//   |                          wait for stage 1 (blocks)
//   signal stage 1             |
//   wait for stage 2 (blocks)  (unblocked)
//   |                          signal stage 2
//   (unblocked)                wait for stage 3 (blocks)
//   a2                         |
//   b1 (blocks -> stage 3)     |
//   |                          (unblocked)
//   |                          b2
//   |                          a1 (cycle detected)
//   a2 recovery fn executes    |
//   a1 completes normally      |
//                              b2 completes, recovers
//                              b1 completes, recovers
#[test]
fn parallel_cycle_one_recovers() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}
#[salsa::query_group(ParallelCycleOneRecovers)]
pub(crate) trait TestDatabase: Knobs {
fn a1(&self, key: i32) -> i32;
#[salsa::cycle(recover)]
fn a2(&self, key: i32) -> i32;
fn b1(&self, key: i32) -> i32;
fn b2(&self, key: i32) -> i32;
}
fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover");
key * 20 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
db.b2(key)
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
db.a1(key)
}


@@ -0,0 +1,37 @@
use std::panic::AssertUnwindSafe;
use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::{Cancelled, ParallelDatabase};
/// Test where a read and a set are racing with one another.
/// Should be atomic.
#[test]
fn in_par_get_set_race() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || Cancelled::catch(AssertUnwindSafe(|| db.sum("abc")))
});
let thread2 = std::thread::spawn(move || {
db.set_input('a', 1000);
db.sum("a")
});
// If the 1st thread runs first, you get 111, otherwise you get
// 1011; if they run concurrently and the 1st thread observes the
// cancellation, it'll unwind.
let result1 = thread1.join().unwrap();
if let Ok(value1) = result1 {
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
}
// thread2 cannot observe a cancellation because it performs a
// database write before running any other queries.
assert_eq!(thread2.join().unwrap(), 1000);
}


@@ -0,0 +1,197 @@
use crate::signal::Signal;
use salsa::Database;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use std::sync::Arc;
use std::{
cell::Cell,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
};
#[salsa::query_group(Par)]
pub(crate) trait ParDatabase: Knobs {
#[salsa::input]
fn input(&self, key: char) -> usize;
fn sum(&self, key: &'static str) -> usize;
/// Invokes `sum`
fn sum2(&self, key: &'static str) -> usize;
/// Invokes `sum` but doesn't really care about the result.
fn sum2_drop_sum(&self, key: &'static str) -> usize;
/// Invokes `sum2`
fn sum3(&self, key: &'static str) -> usize;
/// Invokes `sum2_drop_sum`
fn sum3_drop_sum(&self, key: &'static str) -> usize;
}
/// Various "knobs" and utilities used by tests to force
/// a certain behavior.
pub(crate) trait Knobs {
fn knobs(&self) -> &KnobsStruct;
fn signal(&self, stage: usize);
fn wait_for(&self, stage: usize);
}
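/// Helper for temporarily overriding a `Cell`-based knob while a closure
/// runs; the previous value is restored afterwards, even if the closure panics.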
pub(crate) trait WithValue<T> {
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R;
}
impl<T> WithValue<T> for Cell<T> {
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
let old_value = self.replace(value);
let result = catch_unwind(AssertUnwindSafe(closure));
self.set(old_value);
match result {
Ok(r) => r,
Err(payload) => resume_unwind(payload),
}
}
}
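/// Whether a query should run to completion normally (`Down`) or spin,
/// calling `unwind_if_cancelled`, until a pending write cancels it (`Panic`).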
#[derive(Default, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CancellationFlag {
#[default]
Down,
Panic,
}
/// Various "knobs" that can be used to customize how the queries
/// behave on one specific thread. Note that this state is
/// intentionally thread-local (apart from `signal`).
#[derive(Clone, Default)]
pub(crate) struct KnobsStruct {
/// A kind of flexible barrier used to coordinate execution across
/// threads to ensure we reach various weird states.
pub(crate) signal: Arc<Signal>,
/// When this database is about to block, send a signal.
pub(crate) signal_on_will_block: Cell<usize>,
/// Invocations of `sum` will signal this stage on entry.
pub(crate) sum_signal_on_entry: Cell<usize>,
/// Invocations of `sum` will wait for this stage on entry.
pub(crate) sum_wait_for_on_entry: Cell<usize>,
/// If true, invocations of `sum` will panic before they exit.
pub(crate) sum_should_panic: Cell<bool>,
/// If set to `CancellationFlag::Panic`, invocations of `sum` will spin until
/// they observe cancellation (and then unwind) instead of returning normally.
pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
/// Invocations of `sum` will wait for this stage prior to exiting.
pub(crate) sum_wait_for_on_exit: Cell<usize>,
/// Invocations of `sum` will signal this stage prior to exiting.
pub(crate) sum_signal_on_exit: Cell<usize>,
/// Invocations of `sum3_drop_sum` will panic unconditionally
pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
}
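// The `sum` query consults the knobs at each step, letting tests pause it,
// force a panic, or make it spin until it observes cancellation.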
fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
let mut sum = 0;
db.signal(db.knobs().sum_signal_on_entry.get());
db.wait_for(db.knobs().sum_wait_for_on_entry.get());
if db.knobs().sum_should_panic.get() {
panic!("query set to panic before exit")
}
for ch in key.chars() {
sum += db.input(ch);
}
match db.knobs().sum_wait_for_cancellation.get() {
CancellationFlag::Down => (),
CancellationFlag::Panic => {
tracing::debug!("waiting for cancellation");
loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}
}
}
db.wait_for(db.knobs().sum_wait_for_on_exit.get());
db.signal(db.knobs().sum_signal_on_exit.get());
sum
}
fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize {
db.sum(key)
}
fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
let _ = db.sum(key);
22
}
fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize {
db.sum2(key)
}
fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
if db.knobs().sum3_drop_sum_should_panic.get() {
panic!("sum3_drop_sum executed")
}
db.sum2_drop_sum(key)
}
#[salsa::database(
Par,
crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers,
crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
)]
#[derive(Default)]
pub(crate) struct ParDatabaseImpl {
storage: salsa::Storage<Self>,
knobs: KnobsStruct,
}
impl Database for ParDatabaseImpl {
fn salsa_event(&self, event: salsa::Event) {
if let salsa::EventKind::WillBlockOn { .. } = event.kind {
self.signal(self.knobs().signal_on_will_block.get());
}
}
}
impl ParallelDatabase for ParDatabaseImpl {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(ParDatabaseImpl {
storage: self.storage.snapshot(),
knobs: self.knobs.clone(),
})
}
}
impl Knobs for ParDatabaseImpl {
fn knobs(&self) -> &KnobsStruct {
&self.knobs
}
fn signal(&self, stage: usize) {
self.knobs.signal.signal(stage);
}
fn wait_for(&self, stage: usize) {
self.knobs.signal.wait_for(stage);
}
}


@@ -0,0 +1,40 @@
use parking_lot::{Condvar, Mutex};
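/// A monotonically increasing "stage" counter: threads bump it with
/// `signal` and block in `wait_for` until it reaches a given stage.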
#[derive(Default)]
pub(crate) struct Signal {
value: Mutex<usize>,
cond_var: Condvar,
}
impl Signal {
pub(crate) fn signal(&self, stage: usize) {
tracing::debug!("signal({})", stage);
// This check avoids acquiring the lock for things that will
// clearly be a no-op. Not *necessary* but helps to ensure we
// are more likely to encounter weird race conditions;
// otherwise calls to `sum` will tend to be unnecessarily
// synchronous.
if stage > 0 {
let mut v = self.value.lock();
if stage > *v {
*v = stage;
self.cond_var.notify_all();
}
}
}
/// Waits until the signal value reaches (at least) the given stage.
pub(crate) fn wait_for(&self, stage: usize) {
tracing::debug!("wait_for({})", stage);
// As above, avoid lock if clearly a no-op.
if stage > 0 {
let mut v = self.value.lock();
while *v < stage {
self.cond_var.wait(&mut v);
}
}
}
}


@@ -0,0 +1,168 @@
use rand::seq::SliceRandom;
use rand::Rng;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use salsa::{Cancelled, Database};
// Number of operations the mutator thread performs
const N_MUTATOR_OPS: usize = 100;
// Number of operations each reader thread performs
const N_READER_OPS: usize = 100;
#[salsa::query_group(Stress)]
trait StressDatabase: salsa::Database {
#[salsa::input]
fn a(&self, key: usize) -> usize;
fn b(&self, key: usize) -> usize;
fn c(&self, key: usize) -> usize;
}
fn b(db: &dyn StressDatabase, key: usize) -> usize {
db.unwind_if_cancelled();
db.a(key)
}
fn c(db: &dyn StressDatabase, key: usize) -> usize {
db.b(key)
}
#[salsa::database(Stress)]
#[derive(Default)]
struct StressDatabaseImpl {
storage: salsa::Storage<Self>,
}
impl salsa::Database for StressDatabaseImpl {}
impl salsa::ParallelDatabase for StressDatabaseImpl {
fn snapshot(&self) -> Snapshot<StressDatabaseImpl> {
Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() })
}
}
#[derive(Clone, Copy, Debug)]
enum Query {
A,
B,
C,
}
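/// An operation performed by the mutator thread: either a direct write, or
/// spawning a reader thread that runs a batch of reads against a snapshot.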
enum MutatorOp {
WriteOp(WriteOp),
LaunchReader { ops: Vec<ReadOp>, check_cancellation: bool },
}
#[derive(Debug)]
enum WriteOp {
SetA(usize, usize),
}
#[derive(Debug)]
enum ReadOp {
Get(Query, usize),
}
impl rand::distributions::Distribution<Query> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Query {
*[Query::A, Query::B, Query::C].choose(rng).unwrap()
}
}
impl rand::distributions::Distribution<MutatorOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> MutatorOp {
if rng.gen_bool(0.5) {
MutatorOp::WriteOp(rng.gen())
} else {
MutatorOp::LaunchReader {
ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(),
check_cancellation: rng.gen(),
}
}
}
}
impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
let key = rng.gen::<usize>() % 10;
let value = rng.gen::<usize>() % 10;
WriteOp::SetA(key, value)
}
}
impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
let query = rng.gen::<Query>();
let key = rng.gen::<usize>() % 10;
ReadOp::Get(query, key)
}
}
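// Runs a batch of read ops against a snapshot; with `check_cancellation`,
// it polls between ops so a concurrent write can unwind the thread.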
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
for op in ops {
if check_cancellation {
db.unwind_if_cancelled();
}
op.execute(db);
}
}
impl WriteOp {
fn execute(self, db: &mut StressDatabaseImpl) {
match self {
WriteOp::SetA(key, value) => {
db.set_a(key, value);
}
}
}
}
impl ReadOp {
fn execute(self, db: &StressDatabaseImpl) {
match self {
ReadOp::Get(query, key) => match query {
Query::A => {
db.a(key);
}
Query::B => {
let _ = db.b(key);
}
Query::C => {
let _ = db.c(key);
}
},
}
}
}
#[test]
fn stress_test() {
let mut db = StressDatabaseImpl::default();
for i in 0..10 {
db.set_a(i, i);
}
let mut rng = rand::thread_rng();
// generate the ops that the mutator thread will perform
let write_ops: Vec<MutatorOp> = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect();
// execute the "main thread", which sometimes snapshots off other threads
let mut all_threads = vec![];
for op in write_ops {
match op {
MutatorOp::WriteOp(w) => w.execute(&mut db),
MutatorOp::LaunchReader { ops, check_cancellation } => {
all_threads.push(std::thread::spawn({
let db = db.snapshot();
move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
}))
}
}
}
for thread in all_threads {
thread.join().unwrap().ok();
}
}


@@ -0,0 +1,125 @@
use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::ParallelDatabase;
use std::panic::{self, AssertUnwindSafe};
/// Test where two threads are executing sum. We show that they can
/// both be executing sum in parallel by having thread1 wait for
/// thread2 to send a signal before it leaves (similarly, thread2
/// waits for thread1 to send a signal before it enters).
#[test]
fn true_parallel_different_keys() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
// Thread 1 will signal stage 1 when it enters and wait for stage 2 before it exits.
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_signal_on_entry
.with_value(1, || db.knobs().sum_wait_for_on_exit.with_value(2, || db.sum("a")));
v
}
});
// Thread 2 will wait_for stage 1 when it enters and signal stage 2
// when it leaves.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_wait_for_on_entry
.with_value(1, || db.knobs().sum_signal_on_exit.with_value(2, || db.sum("b")));
v
}
});
assert_eq!(thread1.join().unwrap(), 100);
assert_eq!(thread2.join().unwrap(), 10);
}
/// Test that tries to trigger a conflict: we fetch
/// `sum("abc")` from two threads simultaneously, and one of them
/// therefore has to block.
#[test]
fn true_parallel_same_keys() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
// Thread 1 will signal stage 1 and then wait_for stage 2 at the start of `sum`.
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_signal_on_entry
.with_value(1, || db.knobs().sum_wait_for_on_entry.with_value(2, || db.sum("abc")));
v
}
});
// Thread 2 will wait until Thread 1 has entered sum and then --
// once it has set itself to block -- signal Thread 1 to
// continue. This way, we test out the mechanism of one thread
// blocking on another.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
db.knobs().signal.wait_for(1);
db.knobs().signal_on_will_block.set(2);
db.sum("abc")
}
});
assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 111);
}
/// Test that tries to trigger a conflict: we fetch `sum("a")`
/// from two threads simultaneously. After `thread2` begins blocking,
/// we force `thread1` to panic and check that the panic propagates to `thread2`.
#[test]
fn true_parallel_propagate_panic() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
// `thread1` will wait_for a barrier in the start of `sum`. Once it can
// continue, it will panic.
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_on_entry
.with_value(2, || db.knobs().sum_should_panic.with_value(true, || db.sum("a")))
});
v
}
});
// `thread2` will wait until `thread1` has entered sum and then -- once it
// has set itself to block -- signal `thread1` to continue.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
db.knobs().signal.wait_for(1);
db.knobs().signal_on_will_block.set(2);
db.sum("a")
}
});
let result1 = panic::catch_unwind(AssertUnwindSafe(|| thread1.join().unwrap()));
let result2 = panic::catch_unwind(AssertUnwindSafe(|| thread2.join().unwrap()));
assert!(result1.is_err());
assert!(result2.is_err());
}