Refactor DistFinder to allow handling errors (#709)

For the install tests, I need the ability to ignore failures in the
`DistFinder`. To avoid copy-and-pasting a version that collects errors
separately, I followed
https://gendignoux.com/blog/2021/04/01/rust-async-streams-futures-part1.html
and switched the custom channel over to an async stream yielding
`Result` items.

I like that the async streams mirror the normal iterator API.
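
Roughly, the pattern is the following (a standalone sketch using only the `futures` crate, not code from this commit; `fake_lookup` is a hypothetical stand-in for the cached index query):

```rust
use std::collections::HashMap;

use futures::{stream, StreamExt, TryStreamExt};

// Hypothetical lookup: resolve a package name to some value, or fail.
async fn fake_lookup(name: &str) -> Result<(String, usize), String> {
    if name == "missing" {
        Err(format!("not found: {name}"))
    } else {
        Ok((name.to_string(), name.len()))
    }
}

fn main() -> Result<(), String> {
    futures::executor::block_on(async {
        let names = ["flask", "black", "requests"];

        // Mirrors `iter().map(..).collect::<Result<_, _>>()`, but with up to
        // 32 lookups in flight; `try_collect` stops at the first `Err`.
        let resolved: HashMap<String, usize> = stream::iter(names)
            .map(|name| fake_lookup(name))
            .buffer_unordered(32)
            .try_collect()
            .await?;

        println!("{resolved:?}");
        Ok(())
    })
}
```

In the `DistFinder`, `resolve` collects the stream with `try_collect` as before, while callers that need to tolerate individual failures can consume the `Result` items from `resolve_stream` themselves.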
konsti 2023-12-20 05:07:55 +01:00 committed by GitHub
parent 12eedb1c12
commit 9f8b7e7e12

@@ -2,10 +2,8 @@
//!
//! This is similar to running `pip install` with the `--no-deps` flag.
use std::hash::BuildHasherDefault;
use anyhow::Result;
use futures::StreamExt;
use futures::{stream, Stream, StreamExt, TryStreamExt};
use rustc_hash::FxHashMap;
use distribution_types::{Dist, Resolution};
@@ -46,84 +44,57 @@ impl<'a> DistFinder<'a> {
}
}
/// Resolve a single pinned package, either as cached network request
/// (version or no constraint) or by constructing a URL [`Dist`] from the
/// specifier URL.
async fn resolve_requirement(
&self,
requirement: &Requirement,
) -> Result<(PackageName, Dist), ResolveError> {
match requirement.version_or_url.as_ref() {
None | Some(VersionOrUrl::VersionSpecifier(_)) => {
// Query the index(es) (cached) to get the URLs for the available files.
let (index, metadata) = self.client.simple(&requirement.name).await?;
// Pick a version that satisfies the requirement.
let Some(distribution) = self.select(requirement, &index, metadata) else {
return Err(ResolveError::NotFound(requirement.clone()));
};
if let Some(reporter) = self.reporter.as_ref() {
reporter.on_progress(&distribution);
}
let normalized_name = requirement.name.clone();
Ok((normalized_name, distribution))
}
Some(VersionOrUrl::Url(url)) => {
// We have a URL; fetch the distribution directly.
let package_name = requirement.name.clone();
let package = Dist::from_url(package_name.clone(), url.clone())?;
Ok((package_name, package))
}
}
}
/// Resolve the pinned packages in parallel
pub fn resolve_stream<'data>(
&'data self,
requirements: &'data [Requirement],
) -> impl Stream<Item = Result<(PackageName, Dist), ResolveError>> + 'data {
stream::iter(requirements)
.map(move |requirement| self.resolve_requirement(requirement))
.buffer_unordered(32)
}
/// Resolve a set of pinned packages into a set of wheels.
pub async fn resolve(&self, requirements: &[Requirement]) -> Result<Resolution, ResolveError> {
if requirements.is_empty() {
return Ok(Resolution::default());
}
// A channel to fetch package metadata (e.g., given `flask`, fetch all versions).
let (package_sink, package_stream) = futures::channel::mpsc::unbounded();
// Initialize the package stream.
let mut package_stream = package_stream
.map(|request: Request| match request {
Request::Package(requirement) => {
async move {
let (index, metadata) = self.client.simple(&requirement.name).await?;
Ok::<_, puffin_client::Error>(Response::Package(
requirement,
index,
metadata,
))
}
}
})
.buffer_unordered(32)
.ready_chunks(32);
// Resolve the requirements.
let mut resolution: FxHashMap<PackageName, Dist> =
FxHashMap::with_capacity_and_hasher(requirements.len(), BuildHasherDefault::default());
// Push all the requirements into the package sink.
for requirement in requirements {
match requirement.version_or_url.as_ref() {
None | Some(VersionOrUrl::VersionSpecifier(_)) => {
package_sink.unbounded_send(Request::Package(requirement.clone()))?;
}
Some(VersionOrUrl::Url(url)) => {
let package_name = requirement.name.clone();
let package = Dist::from_url(package_name.clone(), url.clone())?;
resolution.insert(package_name, package);
}
}
}
// If all the dependencies were already resolved, we're done.
if resolution.len() == requirements.len() {
if let Some(reporter) = self.reporter.as_ref() {
reporter.on_complete();
}
return Ok(Resolution::new(resolution));
}
// Otherwise, wait for the package stream to complete.
while let Some(chunk) = package_stream.next().await {
for result in chunk {
let result: Response = result?;
match result {
Response::Package(requirement, index, metadata) => {
// Pick a version that satisfies the requirement.
let Some(distribution) = self.select(&requirement, &index, metadata) else {
return Err(ResolveError::NotFound(requirement));
};
if let Some(reporter) = self.reporter.as_ref() {
reporter.on_progress(&distribution);
}
// Add to the resolved set.
let normalized_name = requirement.name.clone();
resolution.insert(normalized_name, distribution);
}
}
}
if resolution.len() == requirements.len() {
break;
}
}
let resolution: FxHashMap<PackageName, Dist> =
self.resolve_stream(requirements).try_collect().await?;
if let Some(reporter) = self.reporter.as_ref() {
reporter.on_complete();
@@ -221,18 +192,6 @@ impl<'a> DistFinder<'a> {
}
}
#[derive(Debug)]
enum Request {
/// A request to fetch the metadata for a package.
Package(Requirement),
}
#[derive(Debug)]
enum Response {
/// The returned metadata for a package.
Package(Requirement, IndexUrl, SimpleMetadata),
}
pub trait Reporter: Send + Sync {
/// Callback to invoke when a package is resolved to a specific distribution.
fn on_progress(&self, dist: &Dist);