mirror of
https://github.com/astral-sh/uv.git
synced 2025-07-07 21:35:00 +00:00
Concurrent progress bars (#3252)
## Summary
Implements concurrent progress bars. Resolves
https://github.com/astral-sh/uv/issues/1209.
## Test Plan
b21bdfbb
-8817-4873-a65c-16c9e8c7c460
This commit is contained in:
parent
70cbc32565
commit
7dc322665c
5 changed files with 477 additions and 230 deletions
|
@ -1,12 +1,14 @@
|
|||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use futures::{FutureExt, TryStreamExt};
|
||||
use tempfile::TempDir;
|
||||
use tokio::io::AsyncSeekExt;
|
||||
use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio_util::compat::FuturesAsyncReadCompatExt;
|
||||
use tracing::{info_span, instrument, warn, Instrument};
|
||||
|
@ -49,6 +51,7 @@ pub struct DistributionDatabase<'a, Context: BuildContext> {
|
|||
builder: SourceDistributionBuilder<'a, Context>,
|
||||
locks: Rc<Locks>,
|
||||
client: ManagedClient<'a>,
|
||||
reporter: Option<Arc<dyn Reporter>>,
|
||||
}
|
||||
|
||||
impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
||||
|
@ -62,6 +65,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
builder: SourceDistributionBuilder::new(build_context),
|
||||
locks: Rc::new(Locks::default()),
|
||||
client: ManagedClient::new(client, concurrent_downloads),
|
||||
reporter: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,6 +74,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
pub fn with_reporter(self, reporter: impl Reporter + 'static) -> Self {
|
||||
let reporter = Arc::new(reporter);
|
||||
Self {
|
||||
reporter: Some(reporter.clone()),
|
||||
builder: self.builder.with_reporter(reporter),
|
||||
..self
|
||||
}
|
||||
|
@ -168,6 +173,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
NoBinary::All => true,
|
||||
NoBinary::Packages(packages) => packages.contains(dist.name()),
|
||||
};
|
||||
|
||||
if no_binary {
|
||||
return Err(Error::NoBinary);
|
||||
}
|
||||
|
@ -188,6 +194,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
|
||||
wheel.filename.stem(),
|
||||
);
|
||||
|
||||
return self
|
||||
.load_wheel(path, &wheel.filename, cache_entry, dist, hashes)
|
||||
.await;
|
||||
|
@ -203,7 +210,14 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
|
||||
// Download and unzip.
|
||||
match self
|
||||
.stream_wheel(url.clone(), &wheel.filename, &wheel_entry, dist, hashes)
|
||||
.stream_wheel(
|
||||
url.clone(),
|
||||
&wheel.filename,
|
||||
wheel.file.size,
|
||||
&wheel_entry,
|
||||
dist,
|
||||
hashes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(archive) => Ok(LocalWheel {
|
||||
|
@ -220,8 +234,16 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
// If the request failed because streaming is unsupported, download the
|
||||
// wheel directly.
|
||||
let archive = self
|
||||
.download_wheel(url, &wheel.filename, &wheel_entry, dist, hashes)
|
||||
.download_wheel(
|
||||
url,
|
||||
&wheel.filename,
|
||||
wheel.file.size,
|
||||
&wheel_entry,
|
||||
dist,
|
||||
hashes,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(LocalWheel {
|
||||
dist: Dist::Built(dist.clone()),
|
||||
archive: self.build_context.cache().archive(&archive.id),
|
||||
|
@ -246,6 +268,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
.stream_wheel(
|
||||
wheel.url.raw().clone(),
|
||||
&wheel.filename,
|
||||
None,
|
||||
&wheel_entry,
|
||||
dist,
|
||||
hashes,
|
||||
|
@ -269,6 +292,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
.download_wheel(
|
||||
wheel.url.raw().clone(),
|
||||
&wheel.filename,
|
||||
None,
|
||||
&wheel_entry,
|
||||
dist,
|
||||
hashes,
|
||||
|
@ -427,6 +451,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
&self,
|
||||
url: Url,
|
||||
filename: &WheelFilename,
|
||||
size: Option<u64>,
|
||||
wheel_entry: &CacheEntry,
|
||||
dist: &BuiltDist,
|
||||
hashes: HashPolicy<'_>,
|
||||
|
@ -434,8 +459,19 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
// Create an entry for the HTTP cache.
|
||||
let http_entry = wheel_entry.with_file(format!("{}.http", filename.stem()));
|
||||
|
||||
// Fetch the archive from the cache, or download it if necessary.
|
||||
let req = self.request(url.clone())?;
|
||||
|
||||
// Extract the size from the `Content-Length` header, if not provided by the registry.
|
||||
let size = size.or_else(|| content_length(&req));
|
||||
|
||||
let download = |response: reqwest::Response| {
|
||||
async {
|
||||
let progress = self
|
||||
.reporter
|
||||
.as_ref()
|
||||
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
|
||||
|
||||
let reader = response
|
||||
.bytes_stream()
|
||||
.map_err(|err| self.handle_response_errors(err))
|
||||
|
@ -449,7 +485,16 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
// Download and unzip the wheel to a temporary directory.
|
||||
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
|
||||
.map_err(Error::CacheWrite)?;
|
||||
uv_extract::stream::unzip(&mut hasher, temp_dir.path()).await?;
|
||||
|
||||
match progress {
|
||||
Some((reporter, progress)) => {
|
||||
let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
|
||||
uv_extract::stream::unzip(&mut reader, temp_dir.path()).await?;
|
||||
}
|
||||
None => {
|
||||
uv_extract::stream::unzip(&mut hasher, temp_dir.path()).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// If necessary, exhaust the reader to compute the hash.
|
||||
if !hashes.is_none() {
|
||||
|
@ -464,6 +509,10 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
.await
|
||||
.map_err(Error::CacheRead)?;
|
||||
|
||||
if let Some((reporter, progress)) = progress {
|
||||
reporter.on_download_complete(dist.name(), progress);
|
||||
}
|
||||
|
||||
Ok(Archive::new(
|
||||
id,
|
||||
hashers.into_iter().map(HashDigest::from).collect(),
|
||||
|
@ -523,6 +572,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
&self,
|
||||
url: Url,
|
||||
filename: &WheelFilename,
|
||||
size: Option<u64>,
|
||||
wheel_entry: &CacheEntry,
|
||||
dist: &BuiltDist,
|
||||
hashes: HashPolicy<'_>,
|
||||
|
@ -530,8 +580,18 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
// Create an entry for the HTTP cache.
|
||||
let http_entry = wheel_entry.with_file(format!("{}.http", filename.stem()));
|
||||
|
||||
let req = self.request(url.clone())?;
|
||||
|
||||
// Extract the size from the `Content-Length` header, if not provided by the registry.
|
||||
let size = size.or_else(|| content_length(&req));
|
||||
|
||||
let download = |response: reqwest::Response| {
|
||||
async {
|
||||
let progress = self
|
||||
.reporter
|
||||
.as_ref()
|
||||
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
|
||||
|
||||
let reader = response
|
||||
.bytes_stream()
|
||||
.map_err(|err| self.handle_response_errors(err))
|
||||
|
@ -541,9 +601,25 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
|
||||
.map_err(Error::CacheWrite)?;
|
||||
let mut writer = tokio::io::BufWriter::new(tokio::fs::File::from_std(temp_file));
|
||||
tokio::io::copy(&mut reader.compat(), &mut writer)
|
||||
.await
|
||||
.map_err(Error::CacheWrite)?;
|
||||
|
||||
match progress {
|
||||
Some((reporter, progress)) => {
|
||||
// Wrap the reader in a progress reporter. This will report 100% progress
|
||||
// after the download is complete, even if we still have to unzip and hash
|
||||
// part of the file.
|
||||
let mut reader =
|
||||
ProgressReader::new(reader.compat(), progress, &**reporter);
|
||||
|
||||
tokio::io::copy(&mut reader, &mut writer)
|
||||
.await
|
||||
.map_err(Error::CacheWrite)?;
|
||||
}
|
||||
None => {
|
||||
tokio::io::copy(&mut reader.compat(), &mut writer)
|
||||
.await
|
||||
.map_err(Error::CacheWrite)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Unzip the wheel to a temporary directory.
|
||||
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
|
||||
|
@ -588,6 +664,10 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
|
|||
.await
|
||||
.map_err(Error::CacheRead)?;
|
||||
|
||||
if let Some((reporter, progress)) = progress {
|
||||
reporter.on_download_complete(dist.name(), progress);
|
||||
}
|
||||
|
||||
Ok(Archive::new(id, hashes))
|
||||
}
|
||||
.instrument(info_span!("wheel", wheel = %dist))
|
||||
|
@ -813,6 +893,50 @@ impl<'a> ManagedClient<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the value of the `Content-Length` header from the [`reqwest::Request`], if present.
|
||||
fn content_length(req: &reqwest::Request) -> Option<u64> {
|
||||
req.headers()
|
||||
.get(reqwest::header::CONTENT_LENGTH)
|
||||
.and_then(|val| val.to_str().ok())
|
||||
.and_then(|val| val.parse::<u64>().ok())
|
||||
}
|
||||
|
||||
/// An asynchronous reader that reports progress as bytes are read.
|
||||
struct ProgressReader<'a, R> {
|
||||
reader: R,
|
||||
index: usize,
|
||||
reporter: &'a dyn Reporter,
|
||||
}
|
||||
|
||||
impl<'a, R> ProgressReader<'a, R> {
|
||||
/// Create a new [`ProgressReader`] that wraps another reader.
|
||||
fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
|
||||
Self {
|
||||
reader,
|
||||
index,
|
||||
reporter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> AsyncRead for ProgressReader<'_, R>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.as_mut().reader)
|
||||
.poll_read(cx, buf)
|
||||
.map_ok(|()| {
|
||||
self.reporter
|
||||
.on_download_progress(self.index, buf.filled().len() as u64);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A pointer to an archive in the cache, fetched from an HTTP archive.
|
||||
///
|
||||
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue