Allow switching out the resolver's IO (#517)

I'm working off of @konstin's commit here to implement arbitrary unsat
test cases for the resolver.

The entirety of the resolver's IO consists of two functions: get the version
map for a package (PEP 440 version -> distribution) and get the metadata for
a distribution. A new trait, `ResolverProvider`, abstracts these two away and
allows replacing the real network requests, e.g., with stored responses
(https://github.com/pradyunsg/pip-resolver-benchmarks/blob/main/scenarios/pyrax_198.json).

---------

Co-authored-by: konsti <konstin@mailbox.org>
This commit is contained in:
Zanie Blue 2023-12-06 11:53:16 -06:00 committed by GitHub
parent 7acfda889f
commit 2bb04771ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 181 additions and 70 deletions

View file

@ -20,7 +20,7 @@ use puffin_client::RegistryClientBuilder;
use puffin_dispatch::BuildDispatch; use puffin_dispatch::BuildDispatch;
use puffin_interpreter::Virtualenv; use puffin_interpreter::Virtualenv;
use puffin_normalize::ExtraName; use puffin_normalize::ExtraName;
use puffin_resolver::{Manifest, PreReleaseMode, ResolutionMode, ResolutionOptions}; use puffin_resolver::{Manifest, PreReleaseMode, ResolutionMode, ResolutionOptions, Resolver};
use pypi_types::IndexUrls; use pypi_types::IndexUrls;
use crate::commands::reporters::ResolverReporter; use crate::commands::reporters::ResolverReporter;
@ -149,14 +149,7 @@ pub(crate) async fn pip_compile(
.with_options(options); .with_options(options);
// Resolve the dependencies. // Resolve the dependencies.
let resolver = puffin_resolver::Resolver::new( let resolver = Resolver::new(manifest, options, &markers, &tags, &client, &build_dispatch)
manifest,
options,
&markers,
&tags,
&client,
&build_dispatch,
)
.with_reporter(ResolverReporter::from(printer)); .with_reporter(ResolverReporter::from(printer));
let resolution = match resolver.resolve().await { let resolution = match resolver.resolve().await {
Err(puffin_resolver::ResolveError::PubGrub(err)) => { Err(puffin_resolver::ResolveError::PubGrub(err)) => {

View file

@ -119,7 +119,10 @@ impl RegistryClient {
/// "simple" here refers to [PEP 503 Simple Repository API](https://peps.python.org/pep-0503/) /// "simple" here refers to [PEP 503 Simple Repository API](https://peps.python.org/pep-0503/)
/// and [PEP 691 JSON-based Simple API for Python Package Indexes](https://peps.python.org/pep-0691/), /// and [PEP 691 JSON-based Simple API for Python Package Indexes](https://peps.python.org/pep-0691/),
/// which the pypi json api approximately implements. /// which the pypi json api approximately implements.
pub async fn simple(&self, package_name: PackageName) -> Result<(IndexUrl, SimpleJson), Error> { pub async fn simple(
&self,
package_name: &PackageName,
) -> Result<(IndexUrl, SimpleJson), Error> {
if self.index_urls.no_index() { if self.index_urls.no_index() {
return Err(Error::NoIndex(package_name.as_ref().to_string())); return Err(Error::NoIndex(package_name.as_ref().to_string()));
} }
@ -131,11 +134,7 @@ impl RegistryClient {
url.path_segments_mut().unwrap().push(""); url.path_segments_mut().unwrap().push("");
url.set_query(Some("format=application/vnd.pypi.simple.v1+json")); url.set_query(Some("format=application/vnd.pypi.simple.v1+json"));
trace!( trace!("Fetching metadata for {} from {}", package_name, url);
"Fetching metadata for {} from {}",
package_name.as_ref(),
url
);
let cache_entry = self.cache.entry( let cache_entry = self.cache.entry(
CacheBucket::Simple, CacheBucket::Simple,
@ -180,7 +179,7 @@ impl RegistryClient {
} }
} }
Err(Error::PackageNotFound(package_name.as_ref().to_string())) Err(Error::PackageNotFound(package_name.to_string()))
} }
/// Fetch the metadata for a remote wheel file. /// Fetch the metadata for a remote wheel file.

View file

@ -6,7 +6,7 @@ use std::hash::BuildHasherDefault;
use std::str::FromStr; use std::str::FromStr;
use anyhow::Result; use anyhow::Result;
use futures::{StreamExt, TryFutureExt}; use futures::StreamExt;
use fxhash::FxHashMap; use fxhash::FxHashMap;
use distribution_filename::{SourceDistFilename, WheelFilename}; use distribution_filename::{SourceDistFilename, WheelFilename};
@ -61,12 +61,16 @@ impl<'a> DistFinder<'a> {
// Initialize the package stream. // Initialize the package stream.
let mut package_stream = package_stream let mut package_stream = package_stream
.map(|request: Request| match request { .map(|request: Request| match request {
Request::Package(requirement) => self Request::Package(requirement) => {
.client async move {
.simple(requirement.name.clone()) let (index, metadata) = self.client.simple(&requirement.name).await?;
.map_ok(move |(index, metadata)| { Ok::<_, puffin_client::Error>(Response::Package(
Response::Package(requirement, index, metadata) requirement,
}), index,
metadata,
))
}
}
}) })
.buffer_unordered(32) .buffer_unordered(32)
.ready_chunks(32); .ready_chunks(32);

View file

@ -6,7 +6,9 @@ pub use pubgrub::PubGrubReportFormatter;
pub use resolution::Graph; pub use resolution::Graph;
pub use resolution_mode::ResolutionMode; pub use resolution_mode::ResolutionMode;
pub use resolution_options::ResolutionOptions; pub use resolution_options::ResolutionOptions;
pub use resolver::{BuildId, Reporter as ResolverReporter, Resolver}; pub use resolver::{
BuildId, DefaultResolverProvider, Reporter as ResolverReporter, Resolver, ResolverProvider,
};
mod candidate_selector; mod candidate_selector;
mod error; mod error;

View file

@ -4,9 +4,10 @@ use chrono::{DateTime, Utc};
/// Options for resolving a manifest. /// Options for resolving a manifest.
#[derive(Debug, Default, Copy, Clone)] #[derive(Debug, Default, Copy, Clone)]
pub struct ResolutionOptions { pub struct ResolutionOptions {
pub(crate) resolution_mode: ResolutionMode, // TODO(konstin): These should be pub(crate) again
pub(crate) prerelease_mode: PreReleaseMode, pub resolution_mode: ResolutionMode,
pub(crate) exclude_newer: Option<DateTime<Utc>>, pub prerelease_mode: PreReleaseMode,
pub exclude_newer: Option<DateTime<Utc>>,
} }
impl ResolutionOptions { impl ResolutionOptions {

View file

@ -1,5 +1,7 @@
//! Given a set of requirements, find a set of compatible packages. //! Given a set of requirements, find a set of compatible packages.
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
@ -22,10 +24,10 @@ use pep508_rs::{MarkerEnvironment, Requirement};
use platform_tags::Tags; use platform_tags::Tags;
use puffin_cache::CanonicalUrl; use puffin_cache::CanonicalUrl;
use puffin_client::RegistryClient; use puffin_client::RegistryClient;
use puffin_distribution::{DistributionDatabase, Download}; use puffin_distribution::{DistributionDatabase, DistributionDatabaseError, Download};
use puffin_normalize::{ExtraName, PackageName}; use puffin_normalize::{ExtraName, PackageName};
use puffin_traits::BuildContext; use puffin_traits::BuildContext;
use pypi_types::{File, IndexUrl, Metadata21, SimpleJson}; use pypi_types::{File, IndexUrl, Metadata21};
use crate::candidate_selector::CandidateSelector; use crate::candidate_selector::CandidateSelector;
use crate::error::ResolveError; use crate::error::ResolveError;
@ -39,25 +41,128 @@ use crate::version_map::VersionMap;
use crate::yanks::AllowedYanks; use crate::yanks::AllowedYanks;
use crate::ResolutionOptions; use crate::ResolutionOptions;
pub struct Resolver<'a, Context: BuildContext + Send + Sync> { type VersionMapResponse = Result<(IndexUrl, VersionMap), puffin_client::Error>;
type WheelMetadataResponse = Result<(Metadata21, Option<Url>), DistributionDatabaseError>;
/// Abstraction over the IO the resolver performs.
///
/// The resolver needs two operations: looking up the version map for a package and
/// obtaining the metadata for a distribution. Implementing this trait allows swapping the
/// real network requests for another backend (e.g. stored responses in tests).
///
/// Both lookup methods return a boxed, `Send` future that borrows `self` and the argument
/// for the lifetime `'io`.
pub trait ResolverProvider: Send + Sync {
/// Get the version map for a package.
///
/// The future resolves to a [`VersionMapResponse`]: on success, the [`IndexUrl`] together
/// with the [`VersionMap`] for `package_name`.
fn get_version_map<'io>(
&'io self,
package_name: &'io PackageName,
) -> Pin<Box<dyn Future<Output = VersionMapResponse> + Send + 'io>>;
/// Get the metadata for a distribution.
///
/// For a wheel, this is done by querying its (remote) metadata; for a source dist we
/// (fetch and) build the source distribution and return the metadata from the built
/// distribution.
///
/// The future resolves to a [`WheelMetadataResponse`]: on success, the [`Metadata21`] and
/// an optional URL (the resolver binds it as `precise` -- presumably a more precise URL for
/// the distribution; TODO confirm against `DistributionDatabase`).
fn get_or_build_wheel_metadata<'io>(
&'io self,
dist: &'io Dist,
) -> Pin<Box<dyn Future<Output = WheelMetadataResponse> + Send + 'io>>;
/// Set the [`puffin_distribution::Reporter`] to use for this provider.
///
/// Consumes the provider and returns it with the reporter attached.
#[must_use]
fn with_reporter(self, reporter: impl puffin_distribution::Reporter + 'static) -> Self;
}
/// The main IO backend for the resolver, which performs cached network requests using the
/// [`RegistryClient`] and [`DistributionDatabase`].
pub struct DefaultResolverProvider<'a, Context: BuildContext + Send + Sync> {
/// Registry client queried (via `simple`) for a package's index listing.
client: &'a RegistryClient,
/// Distribution database used to get or build wheel metadata.
fetcher: DistributionDatabase<'a, Context>,
/// Build context; its interpreter is passed on when constructing a [`VersionMap`].
build_context: &'a Context,
/// Platform tags passed on when constructing a [`VersionMap`].
tags: &'a Tags,
/// Marker environment passed on when constructing a [`VersionMap`].
markers: &'a MarkerEnvironment,
/// Optional timestamp passed on when constructing a [`VersionMap`].
exclude_newer: Option<DateTime<Utc>>,
/// Allowed yanked versions passed on when constructing a [`VersionMap`].
allowed_yanks: AllowedYanks,
}
impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Context> {
/// Create a provider from its parts.
///
/// Every argument is stored unchanged in the field of the same name; no IO is performed
/// here.
pub fn new(
client: &'a RegistryClient,
fetcher: DistributionDatabase<'a, Context>,
build_context: &'a Context,
tags: &'a Tags,
markers: &'a MarkerEnvironment,
exclude_newer: Option<DateTime<Utc>>,
allowed_yanks: AllowedYanks,
) -> Self {
Self {
client,
fetcher,
build_context,
tags,
markers,
exclude_newer,
allowed_yanks,
}
}
}
impl<'a, Context: BuildContext + Send + Sync> ResolverProvider
    for DefaultResolverProvider<'a, Context>
{
    /// Query the registry client for `package_name` and turn the returned listing into a
    /// [`VersionMap`], paired with the index it came from.
    ///
    /// Client errors are passed through unchanged.
    fn get_version_map<'io>(
        &'io self,
        package_name: &'io PackageName,
    ) -> Pin<Box<dyn Future<Output = VersionMapResponse> + Send + 'io>> {
        Box::pin(async move {
            let response = self.client.simple(package_name).await;
            response.map(|(index, metadata)| {
                // TODO(konstin): I think the client should return something in between
                // `SimpleJson` and `VersionMap`, with source dists and wheels grouped by
                // version, but python version and exclude newer not yet applied. This should
                // work well with caching, testing and PEP 503 html APIs.
                // (https://github.com/astral-sh/puffin/issues/412)
                let version_map = VersionMap::from_metadata(
                    metadata,
                    package_name,
                    self.tags,
                    self.markers,
                    self.build_context.interpreter(),
                    &self.allowed_yanks,
                    self.exclude_newer.as_ref(),
                );
                (index, version_map)
            })
        })
    }

    /// Delegate to the distribution database to get (or build) the metadata for `dist`.
    fn get_or_build_wheel_metadata<'io>(
        &'io self,
        dist: &'io Dist,
    ) -> Pin<Box<dyn Future<Output = WheelMetadataResponse> + Send + 'io>> {
        let metadata_future = self.fetcher.get_or_build_wheel_metadata(dist);
        Box::pin(metadata_future)
    }

    /// Attach a [`puffin_distribution::Reporter`] to the underlying distribution database,
    /// leaving every other field untouched.
    #[must_use]
    fn with_reporter(self, reporter: impl puffin_distribution::Reporter + 'static) -> Self {
        let Self {
            client,
            fetcher,
            build_context,
            tags,
            markers,
            exclude_newer,
            allowed_yanks,
        } = self;
        Self {
            client,
            fetcher: fetcher.with_reporter(reporter),
            build_context,
            tags,
            markers,
            exclude_newer,
            allowed_yanks,
        }
    }
}
pub struct Resolver<'a, Provider: ResolverProvider> {
project: Option<PackageName>, project: Option<PackageName>,
requirements: Vec<Requirement>, requirements: Vec<Requirement>,
constraints: Vec<Requirement>, constraints: Vec<Requirement>,
allowed_urls: AllowedUrls, allowed_urls: AllowedUrls,
allowed_yanks: AllowedYanks,
markers: &'a MarkerEnvironment, markers: &'a MarkerEnvironment,
tags: &'a Tags,
client: &'a RegistryClient,
selector: CandidateSelector, selector: CandidateSelector,
index: Arc<Index>, index: Arc<Index>,
exclude_newer: Option<DateTime<Utc>>,
fetcher: DistributionDatabase<'a, Context>,
build_context: &'a Context,
reporter: Option<Arc<dyn Reporter>>, reporter: Option<Arc<dyn Reporter>>,
provider: Provider,
} }
impl<'a, Context: BuildContext + Send + Sync> Resolver<'a, Context> { impl<'a, Context: BuildContext + Send + Sync> Resolver<'a, DefaultResolverProvider<'a, Context>> {
/// Initialize a new resolver. /// Initialize a new resolver using the default backend doing real requests.
pub fn new( pub fn new(
manifest: Manifest, manifest: Manifest,
options: ResolutionOptions, options: ResolutionOptions,
@ -65,6 +170,31 @@ impl<'a, Context: BuildContext + Send + Sync> Resolver<'a, Context> {
tags: &'a Tags, tags: &'a Tags,
client: &'a RegistryClient, client: &'a RegistryClient,
build_context: &'a Context, build_context: &'a Context,
) -> Self {
let provider = DefaultResolverProvider::new(
client,
DistributionDatabase::new(build_context.cache(), tags, client, build_context),
build_context,
tags,
markers,
options.exclude_newer,
manifest
.requirements
.iter()
.chain(manifest.constraints.iter())
.collect(),
);
Self::new_custom_io(manifest, options, markers, provider)
}
}
impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
/// Initialize a new resolver using a user provided backend.
pub fn new_custom_io(
manifest: Manifest,
options: ResolutionOptions,
markers: &'a MarkerEnvironment,
provider: Provider,
) -> Self { ) -> Self {
Self { Self {
index: Arc::new(Index::default()), index: Arc::new(Index::default()),
@ -81,21 +211,12 @@ impl<'a, Context: BuildContext + Send + Sync> Resolver<'a, Context> {
} }
}) })
.collect(), .collect(),
allowed_yanks: manifest
.requirements
.iter()
.chain(manifest.constraints.iter())
.collect(),
project: manifest.project, project: manifest.project,
requirements: manifest.requirements, requirements: manifest.requirements,
constraints: manifest.constraints, constraints: manifest.constraints,
exclude_newer: options.exclude_newer,
markers, markers,
tags,
client,
fetcher: DistributionDatabase::new(build_context.cache(), tags, client, build_context),
build_context,
reporter: None, reporter: None,
provider,
} }
} }
@ -105,7 +226,7 @@ impl<'a, Context: BuildContext + Send + Sync> Resolver<'a, Context> {
let reporter = Arc::new(reporter); let reporter = Arc::new(reporter);
Self { Self {
reporter: Some(reporter.clone()), reporter: Some(reporter.clone()),
fetcher: self.fetcher.with_reporter(Facade { reporter }), provider: self.provider.with_reporter(Facade { reporter }),
..self ..self
} }
} }
@ -552,17 +673,9 @@ impl<'a, Context: BuildContext + Send + Sync> Resolver<'a, Context> {
while let Some(response) = response_stream.next().await { while let Some(response) = response_stream.next().await {
match response? { match response? {
Response::Package(package_name, index, metadata) => { Response::Package(package_name, index, version_map) => {
trace!("Received package metadata for: {package_name}"); trace!("Received package metadata for: {package_name}");
let version_map = VersionMap::from_metadata(
metadata,
&package_name,
self.tags,
self.markers,
self.build_context.interpreter(),
&self.allowed_yanks,
self.exclude_newer.as_ref(),
);
self.index self.index
.packages .packages
.insert(package_name, (index, version_map)); .insert(package_name, (index, version_map));
@ -603,18 +716,17 @@ impl<'a, Context: BuildContext + Send + Sync> Resolver<'a, Context> {
match request { match request {
// Fetch package metadata from the registry. // Fetch package metadata from the registry.
Request::Package(package_name) => { Request::Package(package_name) => {
self.client let (index, metadata) = self
.simple(package_name.clone()) .provider
.map_ok(move |(index, metadata)| { .get_version_map(&package_name)
Response::Package(package_name, index, metadata)
})
.map_err(ResolveError::Client)
.await .await
.map_err(ResolveError::Client)?;
Ok(Response::Package(package_name, index, metadata))
} }
Request::Dist(dist) => { Request::Dist(dist) => {
let (metadata, precise) = self let (metadata, precise) = self
.fetcher .provider
.get_or_build_wheel_metadata(&dist) .get_or_build_wheel_metadata(&dist)
.await .await
.map_err(|err| match dist.clone() { .map_err(|err| match dist.clone() {
@ -732,7 +844,7 @@ enum Request {
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
enum Response { enum Response {
/// The returned metadata for a package hosted on a registry. /// The returned metadata for a package hosted on a registry.
Package(PackageName, IndexUrl, SimpleJson), Package(PackageName, IndexUrl, VersionMap),
/// The returned metadata for a distribution. /// The returned metadata for a distribution.
Dist(Dist, Metadata21, Option<Url>), Dist(Dist, Metadata21, Option<Url>),
} }

View file

@ -19,7 +19,7 @@ use crate::yanks::AllowedYanks;
/// A map from versions to distributions. /// A map from versions to distributions.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct VersionMap(BTreeMap<PubGrubVersion, PrioritizedDistribution>); pub struct VersionMap(BTreeMap<PubGrubVersion, PrioritizedDistribution>);
impl VersionMap { impl VersionMap {
/// Initialize a [`VersionMap`] from the given metadata. /// Initialize a [`VersionMap`] from the given metadata.

View file

@ -7,7 +7,7 @@ use puffin_normalize::PackageName;
/// A set of package versions that are permitted, even if they're marked as yanked by the /// A set of package versions that are permitted, even if they're marked as yanked by the
/// relevant index. /// relevant index.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct AllowedYanks(FxHashMap<PackageName, FxHashSet<Version>>); pub struct AllowedYanks(FxHashMap<PackageName, FxHashSet<Version>>);
impl AllowedYanks { impl AllowedYanks {
/// Returns `true` if the given package version is allowed, even if it's marked as yanked by /// Returns `true` if the given package version is allowed, even if it's marked as yanked by