Avoid using IDs across PubGrub states (#9538)

## Summary

Keying the prefetcher's maps by PubGrub `Id` isn't safe, because the prefetcher is
global but the IDs could come from different PubGrub states (i.e., different forks).
Key the maps by package name instead.
This commit is contained in:
Charlie Marsh 2024-11-30 08:59:21 -05:00 committed by GitHub
parent 2aca623691
commit 53fe301b1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 34 additions and 31 deletions

View file

@ -1,19 +1,19 @@
use std::cmp::min;
use itertools::Itertools;
use pubgrub::{Id, Range, State, Term};
use pubgrub::{Range, Term};
use rustc_hash::FxHashMap;
use tokio::sync::mpsc::Sender;
use tracing::{debug, trace};
use crate::candidate_selector::CandidateSelector;
use crate::dependency_provider::UvDependencyProvider;
use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner};
use crate::resolver::Request;
use crate::{
InMemoryIndex, PythonRequirement, ResolveError, ResolverEnvironment, VersionsResponse,
};
use uv_distribution_types::{CompatibleDist, DistributionMetadata, IndexCapabilities, IndexUrl};
use uv_normalize::PackageName;
use uv_pep440::Version;
enum BatchPrefetchStrategy {
@ -40,15 +40,14 @@ enum BatchPrefetchStrategy {
/// Note that these are all heuristics that could totally prefetch lots of irrelevant versions.
#[derive(Default)]
pub(crate) struct BatchPrefetcher {
tried_versions: FxHashMap<Id<PubGrubPackage>, usize>,
last_prefetch: FxHashMap<Id<PubGrubPackage>, usize>,
tried_versions: FxHashMap<PackageName, usize>,
last_prefetch: FxHashMap<PackageName, usize>,
}
impl BatchPrefetcher {
/// Prefetch a large number of versions if we already unsuccessfully tried many versions.
pub(crate) fn prefetch_batches(
&mut self,
id: Id<PubGrubPackage>,
next: &PubGrubPackage,
index: Option<&IndexUrl>,
version: &Version,
@ -71,7 +70,7 @@ impl BatchPrefetcher {
return Ok(());
};
let (num_tried, do_prefetch) = self.should_prefetch(id);
let (num_tried, do_prefetch) = self.should_prefetch(next);
if !do_prefetch {
return Ok(());
}
@ -222,32 +221,41 @@ impl BatchPrefetcher {
debug!("Prefetching {prefetch_count} {name} versions");
self.last_prefetch.insert(id, num_tried);
self.last_prefetch.insert(name.clone(), num_tried);
Ok(())
}
/// Each time we try a version for a package, we register that here.
pub(crate) fn version_tried(&mut self, id: Id<PubGrubPackage>, package: &PubGrubPackage) {
pub(crate) fn version_tried(&mut self, package: &PubGrubPackage) {
// Only track base packages, no virtual packages from extras.
if matches!(
&**package,
PubGrubPackageInner::Package {
extra: None,
dev: None,
marker: None,
..
}
) {
*self.tried_versions.entry(id).or_default() += 1;
}
let PubGrubPackageInner::Package {
name,
extra: None,
dev: None,
marker: None,
} = &**package
else {
return;
};
*self.tried_versions.entry(name.clone()).or_default() += 1;
}
/// After 5, 10, 20, 40 tried versions, prefetch that many versions to start early but not
/// too aggressively. Later we schedule the prefetch of 50 versions every 20 versions; this gives
/// us a good buffer until we see prefetch again and is high enough to saturate the task pool.
fn should_prefetch(&self, id: Id<PubGrubPackage>) -> (usize, bool) {
let num_tried = self.tried_versions.get(&id).copied().unwrap_or_default();
let previous_prefetch = self.last_prefetch.get(&id).copied().unwrap_or_default();
fn should_prefetch(&self, next: &PubGrubPackage) -> (usize, bool) {
let PubGrubPackageInner::Package {
name,
extra: None,
dev: None,
marker: None,
} = &**next
else {
return (0, false);
};
let num_tried = self.tried_versions.get(name).copied().unwrap_or_default();
let previous_prefetch = self.last_prefetch.get(name).copied().unwrap_or_default();
let do_prefetch = (num_tried >= 5 && previous_prefetch < 5)
|| (num_tried >= 10 && previous_prefetch < 10)
|| (num_tried >= 20 && previous_prefetch < 20)
@ -259,13 +267,9 @@ impl BatchPrefetcher {
///
/// Note that they may be inflated when we count the same version repeatedly during
/// backtracking.
pub(crate) fn log_tried_versions(&self, state: &State<UvDependencyProvider>) {
pub(crate) fn log_tried_versions(&self) {
let total_versions: usize = self.tried_versions.values().sum();
let mut tried_versions: Vec<_> = self
.tried_versions
.iter()
.map(|(id, count)| (&state.package_store[*id], *count))
.collect();
let mut tried_versions: Vec<_> = self.tried_versions.iter().collect();
tried_versions.sort_by(|(p1, c1), (p2, c2)| {
c1.cmp(c2)
.reverse()

View file

@ -359,7 +359,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
})
else {
if tracing::enabled!(Level::DEBUG) {
prefetcher.log_tried_versions(&state.pubgrub);
prefetcher.log_tried_versions();
}
debug!(
"{} resolution took {:.3}s",
@ -424,7 +424,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
// (idempotent due to caching).
self.request_package(next_package, url, index, &request_sink)?;
prefetcher.version_tried(next_id, next_package);
prefetcher.version_tried(next_package);
let term_intersection = state
.pubgrub
@ -490,7 +490,6 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
// Only consider registry packages for prefetch.
if url.is_none() {
prefetcher.prefetch_batches(
next_id,
next_package,
index,
&version,