Yield after channel send and move cpu tasks to thread (#1163)

## Summary

Previously, we were blocking operations that could run in parallel. We
would send request through our main requests channel, but not yield so
that the receiver could only start processing requests much later than
necessary. We solve this by switching to the async
`tokio::sync::mpsc::channel`, where send is an async functions that
yields.

Due to the increased parallelism cache deserialization and the
conversion from simple api request to version map became bottlenecks, so
i moved them to `spawn_blocking`. Together these result in a 30-60%
speedup for larger warm cache resolution. Small cases such as black
already resolve in 5.7 ms on my machine so there's no speedup to be
gained, refresh and no cache were to noisy to get signal from.

Note for the future: Revisit the bounded channel if we want to produce
requests from `process_request`, too, (this would be good for
prefetching) to avoid deadlocks.

## Details

We can look at the behavior change through the spans:

```
RUST_LOG=puffin=info TRACING_DURATIONS_FILE=target/traces/jupyter-warm-branch.ndjson cargo run --features tracing-durations-export --bin puffin-dev --profile profiling -- resolve jupyter 2> /dev/null
```

Below, you can see how on main, we have discrete phases: All (cached)
simple api requests in parallel, then all (cached) metadata requests in
parallel, repeat until done. The solver is mostly waiting until it has
it's version map from the simple API query to be able to choose a
version. The main thread is blocked by process requests.

In the PR branch, the simple api requests succeeds much earlier,
allowing the solver to advance and also to schedule more prefetching.
Due to that `parse_cache` and `from_metadata` became bottlenecks, so i
moved them off the main thread (green color, and their spans can now
overlap because they can run on multiple threads in parallel). The main
thread isn't blocked on `process_request` anymore, instead it has
frequent idle times. The spans are all much shorter, which indicates
that on main they could have finished much earlier, but a task didn't
yield so they weren't scheduled to finish (though i haven't dug deep
enough to understand the exact scheduling between the process request
stream and the solver here).

**main**


![jupyter-warm-main](693c53cc-1090-41b7-b02a-a607fcd2cd99)

**PR**


![jupyter-warm-branch](33435f34-b39b-4b0a-a9d7-4bfc22f55f05)

## Benchmarks

```
$ hyperfine --warmup 3 "target/profiling/main-dev resolve jupyter" "target/profiling/branch-dev resolve jupyter"
Benchmark 1: target/profiling/main-dev resolve jupyter
  Time (mean ± σ):      29.1 ms ±   0.7 ms    [User: 22.9 ms, System: 11.1 ms]
  Range (min … max):    27.7 ms …  32.2 ms    103 runs
 
Benchmark 2: target/profiling/branch-dev resolve jupyter
  Time (mean ± σ):      18.8 ms ±   1.1 ms    [User: 37.0 ms, System: 22.7 ms]
  Range (min … max):    16.5 ms …  21.9 ms    154 runs
 
Summary
  target/profiling/branch-dev resolve jupyter ran
    1.55 ± 0.10 times faster than target/profiling/main-dev resolve jupyter

$ hyperfine --warmup 3 "target/profiling/main-dev resolve meine_stadt_transparent" "target/profiling/branch-dev resolve meine_stadt_transparent"
Benchmark 1: target/profiling/main-dev resolve meine_stadt_transparent
  Time (mean ± σ):      37.8 ms ±   0.9 ms    [User: 30.7 ms, System: 14.1 ms]
  Range (min … max):    36.6 ms …  41.5 ms    79 runs
 
Benchmark 2: target/profiling/branch-dev resolve meine_stadt_transparent
  Time (mean ± σ):      24.7 ms ±   1.5 ms    [User: 47.0 ms, System: 39.3 ms]
  Range (min … max):    21.5 ms …  28.7 ms    113 runs
 
Summary
  target/profiling/branch-dev resolve meine_stadt_transparent ran
    1.53 ± 0.10 times faster than target/profiling/main-dev resolve meine_stadt_transparent

$ hyperfine --warmup 3 "target/profiling/main pip compile scripts/requirements/home-assistant.in" "target/profiling/branch pip compile scripts/requirements/home-assistant.in"
Benchmark 1: target/profiling/main pip compile scripts/requirements/home-assistant.in
  Time (mean ± σ):     229.0 ms ±   2.8 ms    [User: 197.3 ms, System: 63.7 ms]
  Range (min … max):   225.8 ms … 234.0 ms    13 runs
 
Benchmark 2: target/profiling/branch pip compile scripts/requirements/home-assistant.in
  Time (mean ± σ):      91.4 ms ±   5.3 ms    [User: 289.2 ms, System: 176.9 ms]
  Range (min … max):    81.0 ms … 104.7 ms    32 runs
 
Summary
  target/profiling/branch pip compile scripts/requirements/home-assistant.in ran
    2.50 ± 0.15 times faster than target/profiling/main pip compile scripts/requirements/home-assistant.in
```
This commit is contained in:
konsti 2024-02-02 18:18:24 +01:00 committed by GitHub
parent 3771f6656e
commit f10f902570
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 117 additions and 93 deletions

View file

@ -104,7 +104,7 @@ impl CachedClient {
/// client.
#[instrument(skip_all)]
pub async fn get_cached_with_callback<
Payload: Serialize + DeserializeOwned + Send,
Payload: Serialize + DeserializeOwned + Send + 'static,
CallBackError,
Callback,
CallbackReturn,
@ -172,7 +172,7 @@ impl CachedClient {
}
}
async fn read_cache<Payload: Serialize + DeserializeOwned + Send>(
async fn read_cache<Payload: Serialize + DeserializeOwned + Send + 'static>(
cache_entry: &CacheEntry,
) -> Option<DataWithCachePolicy<Payload>> {
let read_span = info_span!("read_cache", file = %cache_entry.path().display());
@ -185,8 +185,12 @@ impl CachedClient {
"parse_cache",
path = %cache_entry.path().display()
);
let parse_result = parse_span
.in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached));
let parse_result = tokio::task::spawn_blocking(move || {
parse_span
.in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached))
})
.await
.expect("Tokio executor failed, was there a panic?");
match parse_result {
Ok(data) => Some(data),
Err(err) => {

View file

@ -54,6 +54,7 @@ sha2 = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros"] }
tokio-stream = { workspace = true }
tokio-util = { workspace = true, features = ["compat"] }
tracing = { workspace = true }
url = { workspace = true }

View file

@ -24,14 +24,11 @@ pub enum ResolveError {
#[error("Failed to find a version of {0} that satisfies the requirement")]
NotFound(Requirement),
#[error("The request stream terminated unexpectedly")]
StreamTermination,
#[error(transparent)]
Client(#[from] puffin_client::Error),
#[error(transparent)]
TrySend(#[from] futures::channel::mpsc::SendError),
#[error("The channel is closed, was there a panic?")]
ChannelClosed,
#[error(transparent)]
Join(#[from] tokio::task::JoinError),
@ -88,9 +85,11 @@ pub enum ResolveError {
Failure(String),
}
impl<T> From<futures::channel::mpsc::TrySendError<T>> for ResolveError {
fn from(value: futures::channel::mpsc::TrySendError<T>) -> Self {
value.into_send_error().into()
impl<T> From<tokio::sync::mpsc::error::SendError<T>> for ResolveError {
/// Drop the value we want to send to not leak the private type we're sending.
/// The tokio error only says "channel closed", so we don't lose information.
fn from(_value: tokio::sync::mpsc::error::SendError<T>) -> Self {
Self::ChannelClosed
}
}

View file

@ -5,7 +5,6 @@ use std::sync::Arc;
use anyhow::Result;
use dashmap::{DashMap, DashSet};
use futures::channel::mpsc::UnboundedReceiver;
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use pubgrub::error::PubGrubError;
@ -14,6 +13,7 @@ use pubgrub::solver::{Incompatibility, State};
use pubgrub::type_aliases::DependencyConstraints;
use rustc_hash::{FxHashMap, FxHashSet};
use tokio::select;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, info_span, instrument, trace, Instrument};
use url::Url;
@ -202,7 +202,8 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
pub async fn resolve(self) -> Result<ResolutionGraph, ResolveError> {
// A channel to fetch package metadata (e.g., given `flask`, fetch all versions) and version
// metadata (e.g., given `flask==1.0.0`, fetch the metadata for that version).
let (request_sink, request_stream) = futures::channel::mpsc::unbounded();
// Channel size is set to the same size as the task buffer for simplicity.
let (request_sink, request_stream) = tokio::sync::mpsc::channel(50);
// Run the fetcher.
let requests_fut = self.fetch(request_stream).fuse();
@ -213,7 +214,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
let resolution = select! {
result = requests_fut => {
result?;
return Err(ResolveError::StreamTermination);
return Err(ResolveError::ChannelClosed);
}
resolution = resolve_fut => {
resolution.map_err(|err| {
@ -241,7 +242,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
#[instrument(skip_all)]
async fn solve(
&self,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &tokio::sync::mpsc::Sender<Request>,
) -> Result<ResolutionGraph, ResolveError> {
let root = PubGrubPackage::Root(self.project.clone());
@ -265,7 +266,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
state.unit_propagation(next)?;
// Pre-visit all candidate packages, to allow metadata to be fetched in parallel.
Self::pre_visit(state.partial_solution.prioritized_packages(), request_sink)?;
Self::pre_visit(state.partial_solution.prioritized_packages(), request_sink).await?;
// Choose a package version.
let Some(highest_priority_pkg) =
@ -386,7 +387,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
&self,
package: &PubGrubPackage,
priorities: &mut PubGrubPriorities,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &tokio::sync::mpsc::Sender<Request>,
) -> Result<(), ResolveError> {
match package {
PubGrubPackage::Root(_) => {}
@ -395,10 +396,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
// Emit a request to fetch the metadata for this package.
if self.index.packages.register(package_name.clone()) {
priorities.add(package_name.clone());
request_sink.unbounded_send(Request::Package(package_name.clone()))?;
// Yield to allow subscribers to continue, as the channel is sync.
tokio::task::yield_now().await;
request_sink
.send(Request::Package(package_name.clone()))
.await?;
}
}
PubGrubPackage::Package(package_name, _extra, Some(url)) => {
@ -406,10 +406,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
let dist = Dist::from_url(package_name.clone(), url.clone())?;
if self.index.distributions.register(dist.package_id()) {
priorities.add(dist.name().clone());
request_sink.unbounded_send(Request::Dist(dist))?;
// Yield to allow subscribers to continue, as the channel is sync.
tokio::task::yield_now().await;
request_sink.send(Request::Dist(dist)).await?;
}
}
}
@ -418,9 +415,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
/// Visit the set of [`PubGrubPackage`] candidates prior to selection. This allows us to fetch
/// metadata for all of the packages in parallel.
fn pre_visit<'data>(
async fn pre_visit<'data>(
packages: impl Iterator<Item = (&'data PubGrubPackage, &'data Range<Version>)>,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &tokio::sync::mpsc::Sender<Request>,
) -> Result<(), ResolveError> {
// Iterate over the potential packages, and fetch file metadata for any of them. These
// represent our current best guesses for the versions that we _might_ select.
@ -428,7 +425,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
let PubGrubPackage::Package(package_name, _extra, None) = package else {
continue;
};
request_sink.unbounded_send(Request::Prefetch(package_name.clone(), range.clone()))?;
request_sink
.send(Request::Prefetch(package_name.clone(), range.clone()))
.await?;
}
Ok(())
}
@ -441,9 +440,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
package: &PubGrubPackage,
range: &Range<Version>,
pins: &mut FilePins,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &tokio::sync::mpsc::Sender<Request>,
) -> Result<Option<Version>, ResolveError> {
return match package {
match package {
PubGrubPackage::Root(_) => Ok(Some(MIN_VERSION.clone())),
PubGrubPackage::Python(PubGrubPython::Installed) => {
@ -576,24 +575,22 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
// Emit a request to fetch the metadata for this version.
if self.index.distributions.register(candidate.package_id()) {
let dist = candidate.resolve().dist.clone();
request_sink.unbounded_send(Request::Dist(dist))?;
// Yield to allow subscribers to continue, as the channel is sync.
tokio::task::yield_now().await;
request_sink.send(Request::Dist(dist)).await?;
}
Ok(Some(version))
}
};
}
}
/// Given a candidate package and version, return its dependencies.
#[instrument(skip_all, fields(%package, %version))]
async fn get_dependencies(
&self,
package: &PubGrubPackage,
version: &Version,
priorities: &mut PubGrubPriorities,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &tokio::sync::mpsc::Sender<Request>,
) -> Result<Dependencies, ResolveError> {
match package {
PubGrubPackage::Root(_) => {
@ -724,8 +721,11 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
}
/// Fetch the metadata for a stream of packages and versions.
async fn fetch(&self, request_stream: UnboundedReceiver<Request>) -> Result<(), ResolveError> {
let mut response_stream = request_stream
async fn fetch(
&self,
request_stream: tokio::sync::mpsc::Receiver<Request>,
) -> Result<(), ResolveError> {
let mut response_stream = ReceiverStream::new(request_stream)
.map(|request| self.process_request(request).boxed())
.buffer_unordered(50);
@ -769,9 +769,6 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
}
None => {}
}
// Yield to allow subscribers to continue, as the channel is sync.
tokio::task::yield_now().await;
}
Ok::<(), ResolveError>(())
@ -902,7 +899,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
/// Fetch the metadata for an item
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum Request {
pub(crate) enum Request {
/// A request to fetch the metadata for a package.
Package(PackageName),
/// A request to fetch the metadata for a built or source distribution.
@ -915,10 +912,10 @@ impl Display for Request {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Request::Package(package_name) => {
write!(f, "Package {package_name}")
write!(f, "Versions {package_name}")
}
Request::Dist(dist) => {
write!(f, "Dist {dist}")
write!(f, "Metadata {dist}")
}
Request::Prefetch(package_name, range) => {
write!(f, "Prefetch {package_name} {range}")

View file

@ -1,8 +1,9 @@
use std::future::Future;
use std::ops::Deref;
use std::sync::Arc;
use anyhow::Result;
use chrono::{DateTime, Utc};
use futures::FutureExt;
use url::Url;
use distribution_types::Dist;
@ -45,17 +46,30 @@ pub trait ResolverProvider: Send + Sync {
/// The main IO backend for the resolver, which does cached requests network requests using the
/// [`RegistryClient`] and [`DistributionDatabase`].
pub struct DefaultResolverProvider<'a, Context: BuildContext + Send + Sync> {
/// The [`RegistryClient`] used to query the index.
client: &'a RegistryClient,
/// The [`DistributionDatabase`] used to build source distributions.
fetcher: DistributionDatabase<'a, Context>,
/// Allow moving the parameters to `VersionMap::from_metadata` to a different thread.
inner: Arc<DefaultResolverProviderInner>,
}
pub struct DefaultResolverProviderInner {
/// The [`RegistryClient`] used to query the index.
client: RegistryClient,
/// These are the entries from `--find-links` that act as overrides for index responses.
flat_index: &'a FlatIndex,
tags: &'a Tags,
flat_index: FlatIndex,
tags: Tags,
python_requirement: PythonRequirement,
exclude_newer: Option<DateTime<Utc>>,
allowed_yanks: AllowedYanks,
no_binary: &'a NoBinary,
no_binary: NoBinary,
}
impl<'a, Context: BuildContext + Send + Sync> Deref for DefaultResolverProvider<'a, Context> {
type Target = DefaultResolverProviderInner;
fn deref(&self) -> &Self::Target {
self.inner.as_ref()
}
}
impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Context> {
@ -72,14 +86,16 @@ impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Contex
no_binary: &'a NoBinary,
) -> Self {
Self {
client,
fetcher,
flat_index,
tags,
python_requirement,
exclude_newer,
allowed_yanks,
no_binary,
inner: Arc::new(DefaultResolverProviderInner {
client: client.clone(),
flat_index: flat_index.clone(),
tags: tags.clone(),
python_requirement,
exclude_newer,
allowed_yanks,
no_binary: no_binary.clone(),
}),
}
}
}
@ -87,43 +103,48 @@ impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Contex
impl<'a, Context: BuildContext + Send + Sync> ResolverProvider
for DefaultResolverProvider<'a, Context>
{
fn get_version_map<'io>(
&'io self,
package_name: &'io PackageName,
) -> impl Future<Output = VersionMapResponse> + Send + 'io {
self.client
.simple(package_name)
.map(move |result| match result {
Ok((index, metadata)) => Ok(VersionMap::from_metadata(
metadata,
package_name,
&index,
self.tags,
&self.python_requirement,
&self.allowed_yanks,
self.exclude_newer.as_ref(),
self.flat_index.get(package_name).cloned(),
self.no_binary,
)),
Err(err) => match err.into_kind() {
kind @ (puffin_client::ErrorKind::PackageNotFound(_)
| puffin_client::ErrorKind::NoIndex(_)) => {
if let Some(flat_index) = self.flat_index.get(package_name).cloned() {
Ok(VersionMap::from(flat_index))
} else {
Err(kind.into())
}
/// Make a simple api request for the package and convert the result to a [`VersionMap`].
async fn get_version_map<'io>(&'io self, package_name: &'io PackageName) -> VersionMapResponse {
let result = self.client.simple(package_name).await;
// If the simple api request was successful, perform on the slow conversion to `VersionMap` on the tokio
// threadpool
match result {
Ok((index, metadata)) => {
let self_send = self.inner.clone();
let package_name_owned = package_name.clone();
Ok(tokio::task::spawn_blocking(move || {
VersionMap::from_metadata(
metadata,
&package_name_owned,
&index,
&self_send.tags,
&self_send.python_requirement,
&self_send.allowed_yanks,
self_send.exclude_newer.as_ref(),
self_send.flat_index.get(&package_name_owned).cloned(),
&self_send.no_binary,
)
})
.await
.expect("Tokio executor failed, was there a panic?"))
}
Err(err) => match err.into_kind() {
kind @ (puffin_client::ErrorKind::PackageNotFound(_)
| puffin_client::ErrorKind::NoIndex(_)) => {
if let Some(flat_index) = self.flat_index.get(package_name).cloned() {
Ok(VersionMap::from(flat_index))
} else {
Err(kind.into())
}
kind => Err(kind.into()),
},
})
}
kind => Err(kind.into()),
},
}
}
fn get_or_build_wheel_metadata<'io>(
&'io self,
dist: &'io Dist,
) -> impl Future<Output = WheelMetadataResponse> + Send + 'io {
self.fetcher.get_or_build_wheel_metadata(dist)
async fn get_or_build_wheel_metadata<'io>(&'io self, dist: &'io Dist) -> WheelMetadataResponse {
self.fetcher.get_or_build_wheel_metadata(dist).await
}
/// Set the [`puffin_distribution::Reporter`] to use for this installer.

View file

@ -160,7 +160,7 @@ impl Display for BuildKind {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum NoBinary {
/// Allow installation of any wheel.
None,