refactor: Clean up cache priming cancellation handling

This commit is contained in:
Lukas Wirth 2025-04-29 10:40:06 +02:00
parent fe7b4f2ad9
commit 7d9b839f9c
7 changed files with 54 additions and 39 deletions

View file

@ -13,13 +13,13 @@ use hir_def::{
use hir_expand::{HirFileId, name::Name};
use hir_ty::{
db::HirDatabase,
display::{DisplayTarget, HirDisplay, hir_display_with_store},
display::{HirDisplay, hir_display_with_store},
};
use intern::Symbol;
use rustc_hash::FxHashMap;
use syntax::{AstNode, AstPtr, SmolStr, SyntaxNode, SyntaxNodePtr, ToSmolStr, ast::HasName};
use crate::{Module, ModuleDef, Semantics};
use crate::{HasCrate, Module, ModuleDef, Semantics};
pub type FxIndexSet<T> = indexmap::IndexSet<T, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>;
@ -66,7 +66,6 @@ pub struct SymbolCollector<'a> {
symbols: FxIndexSet<FileSymbol>,
work: Vec<SymbolCollectorWork>,
current_container_name: Option<SmolStr>,
display_target: DisplayTarget,
}
/// Given a [`ModuleId`] and a [`HirDatabase`], use the DefMap for the module's crate to collect
@ -78,10 +77,6 @@ impl<'a> SymbolCollector<'a> {
symbols: Default::default(),
work: Default::default(),
current_container_name: None,
display_target: DisplayTarget::from_crate(
db,
*db.all_crates().last().expect("no crate graph present"),
),
}
}
@ -93,8 +88,7 @@ impl<'a> SymbolCollector<'a> {
pub fn collect(&mut self, module: Module) {
let _p = tracing::info_span!("SymbolCollector::collect", ?module).entered();
tracing::info!(?module, "SymbolCollector::collect",);
self.display_target = module.krate().to_display_target(self.db);
tracing::info!(?module, "SymbolCollector::collect");
// The initial work is the root module we're collecting, additional work will
// be populated as we traverse the module's definitions.
@ -321,7 +315,10 @@ impl<'a> SymbolCollector<'a> {
let impl_data = self.db.impl_signature(impl_id);
let impl_name = Some(
hir_display_with_store(impl_data.self_ty, &impl_data.store)
.display(self.db, self.display_target)
.display(
self.db,
crate::Impl::from(impl_id).krate(self.db).to_display_target(self.db),
)
.to_smolstr(),
);
self.with_container_name(impl_name, |s| {

View file

@ -78,6 +78,8 @@ pub type FileRange = FileRangeWrapper<FileId>;
#[salsa::db]
pub struct RootDatabase {
// FIXME: Revisit this commit now that we migrated to the new salsa, given we store arcs in this
// db directly now
// We use `ManuallyDrop` here because every codegen unit that contains a
// `&RootDatabase -> &dyn OtherDatabase` cast will instantiate its drop glue in the vtable,
// which duplicates `Weak::drop` and `Arc::drop` tens of thousands of times, which makes
@ -234,14 +236,6 @@ impl RootDatabase {
// );
// hir::db::BodyWithSourceMapQuery.in_db_mut(self).set_lru_capacity(2048);
}
pub fn snapshot(&self) -> Self {
Self {
storage: self.storage.clone(),
files: self.files.clone(),
crates_map: self.crates_map.clone(),
}
}
}
#[query_group::query_group]

View file

@ -51,6 +51,7 @@ pub fn parallel_prime_caches(
enum ParallelPrimeCacheWorkerProgress {
BeginCrate { crate_id: Crate, crate_name: Symbol },
EndCrate { crate_id: Crate },
Cancelled(Cancelled),
}
// We split off def map computation from other work,
@ -71,26 +72,32 @@ pub fn parallel_prime_caches(
progress_sender
.send(ParallelPrimeCacheWorkerProgress::BeginCrate { crate_id, crate_name })?;
match kind {
let cancelled = Cancelled::catch(|| match kind {
PrimingPhase::DefMap => _ = db.crate_def_map(crate_id),
PrimingPhase::ImportMap => _ = db.import_map(crate_id),
PrimingPhase::CrateSymbols => _ = db.crate_symbols(crate_id.into()),
}
});
progress_sender.send(ParallelPrimeCacheWorkerProgress::EndCrate { crate_id })?;
match cancelled {
Ok(()) => progress_sender
.send(ParallelPrimeCacheWorkerProgress::EndCrate { crate_id })?,
Err(cancelled) => progress_sender
.send(ParallelPrimeCacheWorkerProgress::Cancelled(cancelled))?,
}
}
Ok::<_, crossbeam_channel::SendError<_>>(())
};
for id in 0..num_worker_threads {
let worker = prime_caches_worker.clone();
let db = db.snapshot();
stdx::thread::Builder::new(stdx::thread::ThreadIntent::Worker)
.allow_leak(true)
.name(format!("PrimeCaches#{id}"))
.spawn(move || Cancelled::catch(|| worker(db.snapshot())))
.spawn({
let worker = prime_caches_worker.clone();
let db = db.clone();
move || worker(db)
})
.expect("failed to spawn thread");
}
@ -142,9 +149,14 @@ pub fn parallel_prime_caches(
continue;
}
Err(crossbeam_channel::RecvTimeoutError::Disconnected) => {
// our workers may have died from a cancelled task, so we'll check and re-raise here.
db.unwind_if_revision_cancelled();
break;
// all our workers have exited, mark us as finished and exit
cb(ParallelPrimeCachesProgress {
crates_currently_indexing: vec![],
crates_done,
crates_total: crates_done,
work_type: "Indexing",
});
return;
}
};
match worker_progress {
@ -156,6 +168,10 @@ pub fn parallel_prime_caches(
crates_to_prime.mark_done(crate_id);
crates_done += 1;
}
ParallelPrimeCacheWorkerProgress::Cancelled(cancelled) => {
// Cancelled::throw should probably be public
std::panic::resume_unwind(Box::new(cancelled));
}
};
let progress = ParallelPrimeCachesProgress {
@ -186,9 +202,14 @@ pub fn parallel_prime_caches(
continue;
}
Err(crossbeam_channel::RecvTimeoutError::Disconnected) => {
// our workers may have died from a cancelled task, so we'll check and re-raise here.
db.unwind_if_revision_cancelled();
break;
// all our workers have exited, mark us as finished and exit
cb(ParallelPrimeCachesProgress {
crates_currently_indexing: vec![],
crates_done,
crates_total: crates_done,
work_type: "Populating symbols",
});
return;
}
};
match worker_progress {
@ -199,6 +220,10 @@ pub fn parallel_prime_caches(
crates_currently_indexing.swap_remove(&crate_id);
crates_done += 1;
}
ParallelPrimeCacheWorkerProgress::Cancelled(cancelled) => {
// Cancelled::throw should probably be public
std::panic::resume_unwind(Box::new(cancelled));
}
};
let progress = ParallelPrimeCachesProgress {

View file

@ -182,7 +182,7 @@ impl AnalysisHost {
/// Returns a snapshot of the current state, which you can query for
/// semantic information.
pub fn analysis(&self) -> Analysis {
Analysis { db: self.db.snapshot() }
Analysis { db: self.db.clone() }
}
/// Applies changes to the current state of the world. If there are
@ -864,7 +864,7 @@ impl Analysis {
where
F: FnOnce(&RootDatabase) -> T + std::panic::UnwindSafe,
{
let snap = self.db.snapshot();
let snap = self.db.clone();
Cancelled::catch(|| f(&snap))
}
}

View file

@ -701,10 +701,9 @@ impl flags::AnalysisStats {
if self.parallel {
let mut inference_sw = self.stop_watch();
let snap = db.snapshot();
bodies
.par_iter()
.map_with(snap, |snap, &body| {
.map_with(db.clone(), |snap, &body| {
snap.body(body.into());
snap.infer(body.into());
})

View file

@ -126,10 +126,8 @@ impl CargoParser<DiscoverProjectMessage> for DiscoverProjectParser {
Some(msg)
}
Err(err) => {
let err = DiscoverProjectData::Error {
error: format!("{:#?}\n{}", err, line),
source: None,
};
let err =
DiscoverProjectData::Error { error: format!("{err:#?}\n{line}"), source: None };
Some(DiscoverProjectMessage::new(err))
}
}

View file

@ -56,6 +56,8 @@ impl Builder {
Self { inner: self.inner.stack_size(size), ..self }
}
/// Whether dropping should detach the thread
/// instead of joining it.
#[must_use]
pub fn allow_leak(self, allow_leak: bool) -> Self {
Self { allow_leak, ..self }