Use select! instead of tokio::spawn for network thread (#110)

This commit is contained in:
Charlie Marsh 2023-10-16 15:41:25 -04:00 committed by GitHub
parent 1b433fdcee
commit 5b046a8102
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,6 +7,7 @@ use std::str::FromStr;
use std::sync::Arc;
use anyhow::Result;
use futures::channel::mpsc::UnboundedReceiver;
use futures::future::Either;
use futures::{pin_mut, FutureExt, StreamExt, TryFutureExt};
use pubgrub::error::PubGrubError;
@ -36,6 +37,7 @@ pub struct Resolver<'a> {
markers: &'a MarkerEnvironment,
tags: &'a Tags,
client: &'a PypiClient,
cache: Arc<SolverCache>,
}
impl<'a> Resolver<'a> {
@ -51,100 +53,15 @@ impl<'a> Resolver<'a> {
markers,
tags,
client,
cache: Arc::new(SolverCache::default()),
}
}
/// Resolve a set of requirements into a set of pinned versions.
pub async fn resolve(self) -> Result<Resolution, ResolveError> {
let client = Arc::new(self.client.clone());
let cache = Arc::new(SolverCache::default());
// A channel to fetch package metadata (e.g., given `flask`, fetch all versions) and version
// metadata (e.g., given `flask==1.0.0`, fetch the metadata for that version).
let (request_sink, request_stream) = futures::channel::mpsc::unbounded();
let requests_fut = tokio::spawn({
let tags = self.tags.clone();
let cache = cache.clone();
let client = client.clone();
async move {
let mut response_stream = request_stream
.map({
|request: Request| match request {
Request::Package(package_name) => {
Either::Left(client.simple(package_name.clone()).map_ok(
move |metadata| Response::Package(package_name, metadata),
))
}
Request::Version(file) => Either::Right(
client
.file(file.clone())
.map_ok(move |metadata| Response::Version(file, metadata)),
),
}
})
.buffer_unordered(32)
.ready_chunks(32);
while let Some(chunk) = response_stream.next().await {
for response in chunk {
match response? {
Response::Package(package_name, metadata) => {
trace!("Received package metadata for {}", package_name);
// Only bother storing platform-compatible wheels.
let wheels: Vec<Wheel> = metadata
.files
.into_iter()
.filter_map(|file| {
let Ok(filename) =
WheelFilename::from_str(file.filename.as_str())
else {
debug!("Ignoring non-wheel: {}", file.filename);
return None;
};
let Ok(version) =
pep440_rs::Version::from_str(&filename.version)
else {
debug!("Ignoring invalid version: {}", file.filename);
return None;
};
if !filename.is_compatible(&tags) {
debug!(
"Ignoring wheel with incompatible tags: {}",
file.filename
);
return None;
}
Some(Wheel {
name: PackageName::normalize(&filename.distribution),
version,
file,
})
})
.collect();
if wheels.is_empty() {
return Err(ResolveError::NoCompatibleDistributions(
package_name,
));
}
cache.packages.insert(package_name.clone(), wheels);
}
Response::Version(file, metadata) => {
trace!("Received file metadata for {}", file.filename);
cache.versions.insert(file.hashes.sha256.clone(), metadata);
}
}
}
}
Ok::<(), ResolveError>(())
}
});
// Push all the requirements into the package sink.
for requirement in &self.requirements {
@ -153,8 +70,11 @@ impl<'a> Resolver<'a> {
request_sink.unbounded_send(Request::Package(package_name))?;
}
// Run the fetcher.
let requests_fut = self.fetch(request_stream);
// Run the solver.
let resolve_fut = self.solve(&cache, &request_sink);
let resolve_fut = self.solve(&request_sink);
let requests_fut = requests_fut.fuse();
let resolve_fut = resolve_fut.fuse();
@ -162,7 +82,7 @@ impl<'a> Resolver<'a> {
let resolution = select! {
result = requests_fut => {
result??;
result?;
return Err(ResolveError::StreamTermination);
}
resolution = resolve_fut => {
@ -176,7 +96,6 @@ impl<'a> Resolver<'a> {
/// Run the `PubGrub` solver.
async fn solve(
&self,
cache: &SolverCache,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
) -> Result<PubGrubResolution, ResolveError> {
let root = PubGrubPackage::Root;
@ -212,7 +131,6 @@ impl<'a> Resolver<'a> {
let decision = self
.choose_package_version(
potential_packages,
cache,
&mut pins,
&mut requested_versions,
request_sink,
@ -248,7 +166,6 @@ impl<'a> Resolver<'a> {
.get_dependencies(
package,
&version,
cache,
&mut pins,
&mut requested_packages,
request_sink,
@ -321,7 +238,6 @@ impl<'a> Resolver<'a> {
async fn choose_package_version<T: Borrow<PubGrubPackage>, U: Borrow<Range<PubGrubVersion>>>(
&self,
mut potential_packages: Vec<(T, U)>,
cache: &SolverCache,
pins: &mut HashMap<PackageName, HashMap<pep440_rs::Version, File>>,
in_flight: &mut HashSet<String>,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
@ -336,7 +252,7 @@ impl<'a> Resolver<'a> {
};
// If we don't have metadata for this package, we can't make an early decision.
let Some(entry) = cache.packages.get(package_name) else {
let Some(entry) = self.cache.packages.get(package_name) else {
continue;
};
@ -369,7 +285,7 @@ impl<'a> Resolver<'a> {
// Wait for the metadata to be available.
// TODO(charlie): Ideally, we'd choose the first package for which metadata is
// available.
let entry = cache.packages.wait(package_name).await.unwrap();
let entry = self.cache.packages.wait(package_name).await.unwrap();
let wheels = entry.value();
debug!(
@ -422,7 +338,6 @@ impl<'a> Resolver<'a> {
&self,
package: &PubGrubPackage,
version: &PubGrubVersion,
cache: &SolverCache,
pins: &mut HashMap<PackageName, HashMap<pep440_rs::Version, File>>,
requested_packages: &mut HashSet<PackageName>,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
@ -454,7 +369,7 @@ impl<'a> Resolver<'a> {
// Wait for the metadata to be available.
let versions = pins.get(package_name).unwrap();
let file = versions.get(version.into()).unwrap();
let entry = cache.versions.wait(&file.hashes.sha256).await.unwrap();
let entry = self.cache.versions.wait(&file.hashes.sha256).await.unwrap();
let metadata = entry.value();
let mut constraints = DependencyConstraints::default();
@ -499,6 +414,84 @@ impl<'a> Resolver<'a> {
}
}
}
/// Fetch the metadata for a stream of packages and versions.
async fn fetch(&self, request_stream: UnboundedReceiver<Request>) -> Result<(), ResolveError> {
let mut response_stream = request_stream
.map({
|request: Request| match request {
Request::Package(package_name) => Either::Left(
self.client
.simple(package_name.clone())
.map_ok(move |metadata| Response::Package(package_name, metadata)),
),
Request::Version(file) => Either::Right(
self.client
.file(file.clone())
.map_ok(move |metadata| Response::Version(file, metadata)),
),
}
})
.buffer_unordered(32)
.ready_chunks(32);
while let Some(chunk) = response_stream.next().await {
for response in chunk {
match response? {
Response::Package(package_name, metadata) => {
trace!("Received package metadata for {}", package_name);
// Only bother storing platform-compatible wheels.
let wheels: Vec<Wheel> = metadata
.files
.into_iter()
.filter_map(|file| {
let Ok(filename) = WheelFilename::from_str(file.filename.as_str())
else {
debug!("Ignoring non-wheel: {}", file.filename);
return None;
};
let Ok(version) = pep440_rs::Version::from_str(&filename.version)
else {
debug!("Ignoring invalid version: {}", file.filename);
return None;
};
if !filename.is_compatible(self.tags) {
debug!(
"Ignoring wheel with incompatible tags: {}",
file.filename
);
return None;
}
Some(Wheel {
name: PackageName::normalize(&filename.distribution),
version,
file,
})
})
.collect();
if wheels.is_empty() {
return Err(ResolveError::NoCompatibleDistributions(package_name));
}
self.cache.packages.insert(package_name.clone(), wheels);
}
Response::Version(file, metadata) => {
trace!("Received file metadata for {}", file.filename);
self.cache
.versions
.insert(file.hashes.sha256.clone(), metadata);
}
}
}
}
Ok::<(), ResolveError>(())
}
}
#[derive(Debug)]