Concurrent progress bars (#3252)

## Summary

Implements concurrent progress bars. Resolves
https://github.com/astral-sh/uv/issues/1209.

## Test Plan

b21bdfbb-8817-4873-a65c-16c9e8c7c460
This commit is contained in:
Ibraheem Ahmed 2024-05-26 21:21:07 -04:00 committed by GitHub
parent 70cbc32565
commit 7dc322665c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 477 additions and 230 deletions

View file

@ -1,12 +1,14 @@
use std::future::Future;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::{FutureExt, TryStreamExt};
use tempfile::TempDir;
use tokio::io::AsyncSeekExt;
use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
use tokio::sync::Semaphore;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{info_span, instrument, warn, Instrument};
@ -49,6 +51,7 @@ pub struct DistributionDatabase<'a, Context: BuildContext> {
builder: SourceDistributionBuilder<'a, Context>,
locks: Rc<Locks>,
client: ManagedClient<'a>,
reporter: Option<Arc<dyn Reporter>>,
}
impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
@ -62,6 +65,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
builder: SourceDistributionBuilder::new(build_context),
locks: Rc::new(Locks::default()),
client: ManagedClient::new(client, concurrent_downloads),
reporter: None,
}
}
@ -70,6 +74,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
pub fn with_reporter(self, reporter: impl Reporter + 'static) -> Self {
let reporter = Arc::new(reporter);
Self {
reporter: Some(reporter.clone()),
builder: self.builder.with_reporter(reporter),
..self
}
@ -168,6 +173,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
NoBinary::All => true,
NoBinary::Packages(packages) => packages.contains(dist.name()),
};
if no_binary {
return Err(Error::NoBinary);
}
@ -188,6 +194,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
wheel.filename.stem(),
);
return self
.load_wheel(path, &wheel.filename, cache_entry, dist, hashes)
.await;
@ -203,7 +210,14 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
// Download and unzip.
match self
.stream_wheel(url.clone(), &wheel.filename, &wheel_entry, dist, hashes)
.stream_wheel(
url.clone(),
&wheel.filename,
wheel.file.size,
&wheel_entry,
dist,
hashes,
)
.await
{
Ok(archive) => Ok(LocalWheel {
@ -220,8 +234,16 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
// If the request failed because streaming is unsupported, download the
// wheel directly.
let archive = self
.download_wheel(url, &wheel.filename, &wheel_entry, dist, hashes)
.download_wheel(
url,
&wheel.filename,
wheel.file.size,
&wheel_entry,
dist,
hashes,
)
.await?;
Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self.build_context.cache().archive(&archive.id),
@ -246,6 +268,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
.stream_wheel(
wheel.url.raw().clone(),
&wheel.filename,
None,
&wheel_entry,
dist,
hashes,
@ -269,6 +292,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
.download_wheel(
wheel.url.raw().clone(),
&wheel.filename,
None,
&wheel_entry,
dist,
hashes,
@ -427,6 +451,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
&self,
url: Url,
filename: &WheelFilename,
size: Option<u64>,
wheel_entry: &CacheEntry,
dist: &BuiltDist,
hashes: HashPolicy<'_>,
@ -434,8 +459,19 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
// Create an entry for the HTTP cache.
let http_entry = wheel_entry.with_file(format!("{}.http", filename.stem()));
// Fetch the archive from the cache, or download it if necessary.
let req = self.request(url.clone())?;
// Extract the size from the `Content-Length` header, if not provided by the registry.
let size = size.or_else(|| content_length(&req));
let download = |response: reqwest::Response| {
async {
let progress = self
.reporter
.as_ref()
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
let reader = response
.bytes_stream()
.map_err(|err| self.handle_response_errors(err))
@ -449,7 +485,16 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
// Download and unzip the wheel to a temporary directory.
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
.map_err(Error::CacheWrite)?;
uv_extract::stream::unzip(&mut hasher, temp_dir.path()).await?;
match progress {
Some((reporter, progress)) => {
let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
uv_extract::stream::unzip(&mut reader, temp_dir.path()).await?;
}
None => {
uv_extract::stream::unzip(&mut hasher, temp_dir.path()).await?;
}
}
// If necessary, exhaust the reader to compute the hash.
if !hashes.is_none() {
@ -464,6 +509,10 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
.await
.map_err(Error::CacheRead)?;
if let Some((reporter, progress)) = progress {
reporter.on_download_complete(dist.name(), progress);
}
Ok(Archive::new(
id,
hashers.into_iter().map(HashDigest::from).collect(),
@ -523,6 +572,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
&self,
url: Url,
filename: &WheelFilename,
size: Option<u64>,
wheel_entry: &CacheEntry,
dist: &BuiltDist,
hashes: HashPolicy<'_>,
@ -530,8 +580,18 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
// Create an entry for the HTTP cache.
let http_entry = wheel_entry.with_file(format!("{}.http", filename.stem()));
let req = self.request(url.clone())?;
// Extract the size from the `Content-Length` header, if not provided by the registry.
let size = size.or_else(|| content_length(&req));
let download = |response: reqwest::Response| {
async {
let progress = self
.reporter
.as_ref()
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
let reader = response
.bytes_stream()
.map_err(|err| self.handle_response_errors(err))
@ -541,9 +601,25 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
.map_err(Error::CacheWrite)?;
let mut writer = tokio::io::BufWriter::new(tokio::fs::File::from_std(temp_file));
tokio::io::copy(&mut reader.compat(), &mut writer)
.await
.map_err(Error::CacheWrite)?;
match progress {
Some((reporter, progress)) => {
// Wrap the reader in a progress reporter. This will report 100% progress
// after the download is complete, even if we still have to unzip and hash
// part of the file.
let mut reader =
ProgressReader::new(reader.compat(), progress, &**reporter);
tokio::io::copy(&mut reader, &mut writer)
.await
.map_err(Error::CacheWrite)?;
}
None => {
tokio::io::copy(&mut reader.compat(), &mut writer)
.await
.map_err(Error::CacheWrite)?;
}
}
// Unzip the wheel to a temporary directory.
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
@ -588,6 +664,10 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
.await
.map_err(Error::CacheRead)?;
if let Some((reporter, progress)) = progress {
reporter.on_download_complete(dist.name(), progress);
}
Ok(Archive::new(id, hashes))
}
.instrument(info_span!("wheel", wheel = %dist))
@ -813,6 +893,50 @@ impl<'a> ManagedClient<'a> {
}
}
/// Returns the value of the `Content-Length` header from the [`reqwest::Request`], if present.
fn content_length(req: &reqwest::Request) -> Option<u64> {
req.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|val| val.to_str().ok())
.and_then(|val| val.parse::<u64>().ok())
}
/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
reader: R,
index: usize,
reporter: &'a dyn Reporter,
}
impl<'a, R> ProgressReader<'a, R> {
/// Create a new [`ProgressReader`] that wraps another reader.
fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
Self {
reader,
index,
reporter,
}
}
}
impl<R> AsyncRead for ProgressReader<'_, R>
where
R: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.as_mut().reader)
.poll_read(cx, buf)
.map_ok(|()| {
self.reporter
.on_download_progress(self.index, buf.filled().len() as u64);
})
}
}
/// A pointer to an archive in the cache, fetched from an HTTP archive.
///
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.