diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs
index 8b06099d1..8e84d4d4d 100644
--- a/crates/uv-client/src/base_client.rs
+++ b/crates/uv-client/src/base_client.rs
@@ -165,6 +165,11 @@ impl<'a> BaseClientBuilder<'a> {
         matches!(self.connectivity, Connectivity::Offline)
     }
 
+    /// Create a [`RetryPolicy`] for the client.
+    fn retry_policy(&self) -> ExponentialBackoff {
+        ExponentialBackoff::builder().build_with_max_retries(self.retries)
+    }
+
     pub fn build(&self) -> BaseClient {
         // Create user agent.
         let mut user_agent_string = format!("uv/{}", version());
@@ -229,6 +234,7 @@ impl<'a> BaseClientBuilder<'a> {
         BaseClient {
             connectivity: self.connectivity,
             allow_insecure_host: self.allow_insecure_host.clone(),
+            retries: self.retries,
             client,
             raw_client,
             dangerous_client,
@@ -246,6 +252,7 @@ impl<'a> BaseClientBuilder<'a> {
         BaseClient {
             connectivity: self.connectivity,
             allow_insecure_host: self.allow_insecure_host.clone(),
+            retries: self.retries,
             client,
             dangerous_client,
             raw_client: existing.raw_client.clone(),
@@ -307,10 +314,8 @@ impl<'a> BaseClientBuilder<'a> {
         // Avoid uncloneable errors with a streaming body during publish.
         if self.retries > 0 {
             // Initialize the retry strategy.
-            let retry_policy =
-                ExponentialBackoff::builder().build_with_max_retries(self.retries);
             let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy(
-                retry_policy,
+                self.retry_policy(),
                 UvRetryableStrategy,
             );
             client = client.with(retry_strategy);
@@ -367,6 +372,8 @@ pub struct BaseClient {
     timeout: Duration,
     /// Hosts that are trusted to use the insecure client.
     allow_insecure_host: Vec<TrustedHost>,
+    /// The number of retries to attempt on transient errors.
+    retries: u32,
 }
 
 #[derive(Debug, Clone, Copy)]
@@ -400,6 +407,16 @@ impl BaseClient {
     pub fn connectivity(&self) -> Connectivity {
         self.connectivity
     }
+
+    /// The number of retries to attempt on transient errors.
+    pub fn retries(&self) -> u32 {
+        self.retries
+    }
+
+    /// The [`RetryPolicy`] for the client.
+    pub fn retry_policy(&self) -> ExponentialBackoff {
+        ExponentialBackoff::builder().build_with_max_retries(self.retries)
+    }
 }
 
 /// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases.
@@ -409,7 +426,11 @@ impl RetryableStrategy for UvRetryableStrategy {
     fn handle(&self, res: &Result<Response, reqwest_middleware::Error>) -> Option<Retryable> {
         // Use the default strategy and check for additional transient error cases.
         let retryable = match DefaultRetryableStrategy.handle(res) {
-            None | Some(Retryable::Fatal) if is_extended_transient_error(res) => {
+            None | Some(Retryable::Fatal)
+                if res
+                    .as_ref()
+                    .is_err_and(|err| is_extended_transient_error(err)) =>
+            {
                 Some(Retryable::Transient)
             }
             default => default,
@@ -427,7 +448,7 @@ impl RetryableStrategy for UvRetryableStrategy {
                 .join("\n");
             debug!(
                 "Transient request failure for {}, retrying: {err}\n{context}",
-                err.url().map(reqwest::Url::as_str).unwrap_or("unknown URL")
+                err.url().map(Url::as_str).unwrap_or("unknown URL")
             );
         }
     }
@@ -439,9 +460,18 @@ impl RetryableStrategy for UvRetryableStrategy {
 /// Check for additional transient error kinds not supported by the default retry strategy in `reqwest_retry`.
 ///
 /// These cases should be safe to retry with [`Retryable::Transient`].
-fn is_extended_transient_error(res: &Result<Response, reqwest_middleware::Error>) -> bool {
-    // Check for connection reset errors, these are usually `Body` errors which are not retried by default.
-    if let Err(reqwest_middleware::Error::Reqwest(err)) = res {
+pub(crate) fn is_extended_transient_error(err: &dyn Error) -> bool {
+    if let Some(err) = find_source::<reqwest_middleware::Error>(&err) {
+        if let Some(io) = find_source::<std::io::Error>(&err) {
+            if io.kind() == std::io::ErrorKind::ConnectionReset
+                || io.kind() == std::io::ErrorKind::UnexpectedEof
+            {
+                return true;
+            }
+        }
+    }
+
+    if let Some(err) = find_source::<reqwest::Error>(&err) {
         if let Some(io) = find_source::<std::io::Error>(&err) {
             if io.kind() == std::io::ErrorKind::ConnectionReset
                 || io.kind() == std::io::ErrorKind::UnexpectedEof
             {
                 return true;
             }
@@ -457,7 +487,7 @@ fn is_extended_transient_error(res: &Result<Response, reqwest_middleware::Error>) -> bool {
 /// Find the first source error of a specific type.
 ///
 /// See
-fn find_source<E: std::error::Error + 'static>(orig: &dyn std::error::Error) -> Option<&E> {
+fn find_source<E: std::error::Error + 'static>(orig: &dyn Error) -> Option<&E> {
     let mut cause = orig.source();
     while let Some(err) = cause {
         if let Some(typed) = err.downcast_ref() {
@@ -465,7 +495,5 @@ fn find_source<E: std::error::Error + 'static>(orig: &dyn std::error::Error) ->
         }
         cause = err.source();
     }
-
-    // else
     None
 }
diff --git a/crates/uv-client/src/cached_client.rs b/crates/uv-client/src/cached_client.rs
index 56266c6f8..509429014 100644
--- a/crates/uv-client/src/cached_client.rs
+++ b/crates/uv-client/src/cached_client.rs
@@ -1,7 +1,10 @@
+use std::fmt::{Debug, Display, Formatter};
+use std::time::{Duration, SystemTime};
 use std::{borrow::Cow, future::Future, path::Path};
 
 use futures::FutureExt;
 use reqwest::{Request, Response};
+use reqwest_retry::RetryPolicy;
 use rkyv::util::AlignedVec;
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
@@ -10,6 +13,7 @@ use tracing::{debug, info_span, instrument, trace, warn, Instrument};
 use uv_cache::{CacheEntry, Freshness};
 use uv_fs::write_atomic;
 
+use crate::base_client::is_extended_transient_error;
 use crate::BaseClient;
 use crate::{
     httpcache::{AfterResponse, BeforeRequest, CachePolicy, CachePolicyBuilder},
@@ -39,7 +43,7 @@ pub trait Cacheable: Sized {
     type Target;
 
     /// Deserialize a value from bytes aligned to a 16-byte boundary.
-    fn from_aligned_bytes(bytes: AlignedVec) -> Result<Self::Target, Error>;
+    fn from_aligned_bytes(bytes: AlignedVec) -> Result<Self::Target, crate::Error>;
     /// Serialize bytes to a possibly owned byte buffer.
     fn to_bytes(&self) -> Result<Cow<'_, [u8]>, crate::Error>;
     /// Convert this type into its final form.
@@ -75,9 +79,6 @@ impl Cacheable for SerdeCacheable {
 /// All `OwnedArchive` values are cacheable.
 impl<A> Cacheable for OwnedArchive<A>
 where
-    // A: rkyv::Archive + rkyv::Serialize>,
-    // A::Archived: for<'a> rkyv::bytecheck::CheckBytes>
-    //     + rkyv::Deserialize,
     A: rkyv::Archive + for<'a> rkyv::Serialize<crate::rkyvutil::Serializer<'a>>,
     A::Archived: rkyv::Portable
         + rkyv::Deserialize<A, crate::rkyvutil::Deserializer>
@@ -99,25 +100,55 @@ where
 }
 
 /// Either a cached client error or a (user specified) error from the callback
-#[derive(Debug)]
-pub enum CachedClientError<CallbackError> {
+pub enum CachedClientError<CallbackError: std::error::Error + 'static> {
     Client(Error),
     Callback(CallbackError),
 }
 
-impl<CallbackError> From<Error> for CachedClientError<CallbackError> {
+impl<CallbackError: std::error::Error + 'static> Display for CachedClientError<CallbackError> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            CachedClientError::Client(err) => write!(f, "{err}"),
+            CachedClientError::Callback(err) => write!(f, "{err}"),
+        }
+    }
+}
+
+impl<CallbackError: std::error::Error + 'static> Debug for CachedClientError<CallbackError> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            CachedClientError::Client(err) => write!(f, "{err:?}"),
+            CachedClientError::Callback(err) => write!(f, "{err:?}"),
+        }
+    }
+}
+
+impl<CallbackError: std::error::Error + 'static> std::error::Error
+    for CachedClientError<CallbackError>
+{
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match self {
+            CachedClientError::Client(err) => Some(err),
+            CachedClientError::Callback(err) => Some(err),
+        }
+    }
+}
+
+impl<CallbackError: std::error::Error + 'static> From<Error> for CachedClientError<CallbackError> {
     fn from(error: Error) -> Self {
         Self::Client(error)
     }
 }
 
-impl<CallbackError> From<ErrorKind> for CachedClientError<CallbackError> {
+impl<CallbackError: std::error::Error + 'static> From<ErrorKind>
+    for CachedClientError<CallbackError>
+{
     fn from(error: ErrorKind) -> Self {
         Self::Client(error.into())
     }
 }
 
-impl<E: Into<Self>> From<CachedClientError<E>> for Error {
+impl<E: Into<Self> + std::error::Error + 'static> From<CachedClientError<E>> for Error {
     fn from(error: CachedClientError<E>) -> Self {
         match error {
             CachedClientError::Client(error) => error,
@@ -184,7 +215,7 @@ impl CachedClient {
     #[instrument(skip_all)]
     pub async fn get_serde<
         Payload: Serialize + DeserializeOwned + 'static,
-        CallBackError,
+        CallBackError: std::error::Error + 'static,
         Callback,
         CallbackReturn,
     >(
@@ -195,11 +226,11 @@ impl CachedClient {
         response_callback: Callback,
     ) -> Result<Payload, CachedClientError<CallBackError>>
     where
-        Callback: FnOnce(Response) -> CallbackReturn,
+        Callback: Fn(Response) -> CallbackReturn,
         CallbackReturn: Future<Output = Result<Payload, CallBackError>>,
     {
         let payload = self
-            .get_cacheable(req, cache_entry, cache_control, move |resp| async {
+            .get_cacheable(req, cache_entry, cache_control, |resp| async {
                 let payload = response_callback(resp).await?;
                 Ok(SerdeCacheable { inner: payload })
             })
@@ -220,7 +251,12 @@ impl CachedClient {
     /// only the result is cached and returned. The `response_callback` is
     /// allowed to make subsequent requests, e.g. through the uncached client.
     #[instrument(skip_all)]
-    pub async fn get_cacheable<Payload: Cacheable, CallBackError, Callback, CallbackReturn>(
+    pub async fn get_cacheable<
+        Payload: Cacheable,
+        CallBackError: std::error::Error + 'static,
+        Callback,
+        CallbackReturn,
+    >(
         &self,
         req: Request,
         cache_entry: &CacheEntry,
@@ -228,7 +264,7 @@ impl CachedClient {
         response_callback: Callback,
     ) -> Result<Payload::Target, CachedClientError<CallBackError>>
     where
-        Callback: FnOnce(Response) -> CallbackReturn,
+        Callback: Fn(Response) -> CallbackReturn,
         CallbackReturn: Future<Output = Result<Payload, CallBackError>>,
     {
         let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
@@ -307,7 +343,7 @@ impl CachedClient {
     /// Make a request without checking whether the cache is fresh.
     pub async fn skip_cache<
         Payload: Serialize + DeserializeOwned + 'static,
-        CallBackError,
+        CallBackError: std::error::Error + 'static,
         Callback,
         CallbackReturn,
     >(
@@ -332,7 +368,12 @@ impl CachedClient {
         Ok(payload)
     }
 
-    async fn resend_and_heal_cache<Payload: Cacheable, CallBackError, Callback, CallbackReturn>(
+    async fn resend_and_heal_cache<
+        Payload: Cacheable,
+        CallBackError: std::error::Error + 'static,
+        Callback,
+        CallbackReturn,
+    >(
         &self,
         req: Request,
         cache_entry: &CacheEntry,
@@ -348,7 +389,12 @@ impl CachedClient {
             .await
     }
 
-    async fn run_response_callback<Payload: Cacheable, CallBackError, Callback, CallbackReturn>(
+    async fn run_response_callback<
+        Payload: Cacheable,
+        CallBackError: std::error::Error + 'static,
+        Callback,
+        CallbackReturn,
+    >(
         &self,
         cache_entry: &CacheEntry,
         cache_policy: Option<Box<CachePolicy>>,
@@ -519,6 +565,133 @@ impl CachedClient {
         };
         Ok((response, cache_policy))
     }
+
+    /// Perform a [`CachedClient::get_serde`] request with a default retry strategy.
+    #[instrument(skip_all)]
+    pub async fn get_serde_with_retry<
+        Payload: Serialize + DeserializeOwned + 'static,
+        CallBackError: std::error::Error + 'static,
+        Callback,
+        CallbackReturn,
+    >(
+        &self,
+        req: Request,
+        cache_entry: &CacheEntry,
+        cache_control: CacheControl,
+        response_callback: Callback,
+    ) -> Result<Payload, CachedClientError<CallBackError>>
+    where
+        Callback: Fn(Response) -> CallbackReturn,
+        CallbackReturn: Future<Output = Result<Payload, CallBackError>>,
+    {
+        let payload = self
+            .get_cacheable_with_retry(req, cache_entry, cache_control, |resp| async {
+                let payload = response_callback(resp).await?;
+                Ok(SerdeCacheable { inner: payload })
+            })
+            .await?;
+        Ok(payload)
+    }
+
+    /// Perform a [`CachedClient::get_cacheable`] request with a default retry strategy.
+    ///
+    /// See:
+    #[instrument(skip_all)]
+    pub async fn get_cacheable_with_retry<
+        Payload: Cacheable,
+        CallBackError: std::error::Error + 'static,
+        Callback,
+        CallbackReturn,
+    >(
+        &self,
+        req: Request,
+        cache_entry: &CacheEntry,
+        cache_control: CacheControl,
+        response_callback: Callback,
+    ) -> Result<Payload::Target, CachedClientError<CallBackError>>
+    where
+        Callback: Fn(Response) -> CallbackReturn,
+        CallbackReturn: Future<Output = Result<Payload, CallBackError>>,
+    {
+        let mut n_past_retries = 0;
+        let start_time = SystemTime::now();
+        let retry_policy = self.uncached().retry_policy();
+        loop {
+            let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
+            let result = self
+                .get_cacheable(fresh_req, cache_entry, cache_control, &response_callback)
+                .await;
+            if let Some(err) = result
+                .as_ref()
+                .err()
+                .filter(|err| is_extended_transient_error(err))
+            {
+                let retry_decision = retry_policy.should_retry(start_time, n_past_retries);
+                if let reqwest_retry::RetryDecision::Retry { execute_after } = retry_decision {
+                    debug!(
+                        "Transient failure while handling response from {}; retrying: {err}",
+                        req.url(),
+                    );
+                    let duration = execute_after
+                        .duration_since(SystemTime::now())
+                        .unwrap_or_else(|_| Duration::default());
+                    tokio::time::sleep(duration).await;
+                    n_past_retries += 1;
+                    continue;
+                }
+            }
+            return result;
+        }
+    }
+
+    /// Perform a [`CachedClient::skip_cache`] request with a default retry strategy.
+    ///
+    /// See:
+    pub async fn skip_cache_with_retry<
+        Payload: Serialize + DeserializeOwned + 'static,
+        CallBackError: std::error::Error + 'static,
+        Callback,
+        CallbackReturn,
+    >(
+        &self,
+        req: Request,
+        cache_entry: &CacheEntry,
+        response_callback: Callback,
+    ) -> Result<Payload, CachedClientError<CallBackError>>
+    where
+        Callback: Fn(Response) -> CallbackReturn,
+        CallbackReturn: Future<Output = Result<Payload, CallBackError>>,
+    {
+        let mut n_past_retries = 0;
+        let start_time = SystemTime::now();
+        let retry_policy = self.uncached().retry_policy();
+        loop {
+            let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
+            let result = self
+                .skip_cache(fresh_req, cache_entry, &response_callback)
+                .await;
+            if let Some(err) = result
+                .as_ref()
+                .err()
+                .filter(|err| is_extended_transient_error(err))
+            {
+                let retry_decision = retry_policy.should_retry(start_time, n_past_retries);
+                if let reqwest_retry::RetryDecision::Retry { execute_after } = retry_decision {
+                    debug!(
+                        "Transient failure while handling response from {}; retrying: {err}",
+                        req.url(),
+                    );
+                    let duration = execute_after
+                        .duration_since(SystemTime::now())
+                        .unwrap_or_else(|_| Duration::default());
+                    tokio::time::sleep(duration).await;
+                    n_past_retries += 1;
+                    continue;
+                }
+            }
+            return result;
+        }
+    }
 }
 
 #[derive(Debug)]
diff --git a/crates/uv-client/src/flat_index.rs b/crates/uv-client/src/flat_index.rs
index ef7caf7fd..50620fe09 100644
--- a/crates/uv-client/src/flat_index.rs
+++ b/crates/uv-client/src/flat_index.rs
@@ -195,7 +195,7 @@ impl<'a> FlatIndexClient<'a> {
         let response = self
             .client
             .cached_client()
-            .get_cacheable(
+            .get_cacheable_with_retry(
                 flat_index_request,
                 &cache_entry,
                 cache_control,
diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs
index 58982426a..6cbb0d8c1 100644
--- a/crates/uv-client/src/registry_client.rs
+++ b/crates/uv-client/src/registry_client.rs
@@ -397,7 +397,7 @@ impl RegistryClient {
                 .instrument(info_span!("parse_simple_api", package = %package_name))
         };
         self.cached_client()
-            .get_cacheable(
+            .get_cacheable_with_retry(
                 simple_request,
                 cache_entry,
                 cache_control,
@@ -589,7 +589,7 @@ impl RegistryClient {
                     .in_scope(|| ResolutionMetadata::parse_metadata(bytes.as_ref()))
                     .map_err(|err| {
                         Error::from(ErrorKind::MetadataParseError(
-                            filename,
+                            filename.clone(),
                             url.to_string(),
                             Box::new(err),
                         ))
@@ -602,7 +602,7 @@ impl RegistryClient {
             .map_err(|err| ErrorKind::from_reqwest(url.clone(), err))?;
         Ok(self
             .cached_client()
-            .get_serde(req, &cache_entry, cache_control, response_callback)
+            .get_serde_with_retry(req, &cache_entry, cache_control, response_callback)
             .await?)
         } else {
             // If we lack PEP 658 support, try using HTTP range requests to read only the
@@ -668,7 +668,7 @@ impl RegistryClient {
                     self.uncached_client(url).clone(),
                     response,
                     url.clone(),
-                    headers,
+                    headers.clone(),
                 )
                 .await
                 .map_err(|err| ErrorKind::AsyncHttpRangeReader(url.clone(), err))?;
@@ -690,7 +690,7 @@ impl RegistryClient {
 
                 let result = self
                     .cached_client()
-                    .get_serde(
+                    .get_serde_with_retry(
                         req,
                         &cache_entry,
                         cache_control,
@@ -748,7 +748,7 @@ impl RegistryClient {
         };
 
         self.cached_client()
-            .get_serde(req, &cache_entry, cache_control, read_metadata_stream)
+            .get_serde_with_retry(req, &cache_entry, cache_control, read_metadata_stream)
             .await
             .map_err(crate::Error::from)
     }
diff --git a/crates/uv-distribution/src/distribution_database.rs b/crates/uv-distribution/src/distribution_database.rs
index f96154f33..1dbdf064f 100644
--- a/crates/uv-distribution/src/distribution_database.rs
+++ b/crates/uv-distribution/src/distribution_database.rs
@@ -556,9 +556,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
         let archive = self
             .client
             .managed(|client| {
-                client
-                    .cached_client()
-                    .get_serde(req, &http_entry, cache_control, download)
+                client.cached_client().get_serde_with_retry(
+                    req,
+                    &http_entry,
+                    cache_control,
+                    download,
+                )
             })
             .await
             .map_err(|err| match err {
@@ -578,7 +581,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
             .managed(|client| async {
                 client
                     .cached_client()
-                    .skip_cache(self.request(url)?, &http_entry, download)
+                    .skip_cache_with_retry(self.request(url)?, &http_entry, download)
                     .await
                     .map_err(|err| match err {
                         CachedClientError::Callback(err) => err,
@@ -710,9 +713,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
         let archive = self
             .client
             .managed(|client| {
-                client
-                    .cached_client()
-                    .get_serde(req, &http_entry, cache_control, download)
+                client.cached_client().get_serde_with_retry(
+                    req,
+                    &http_entry,
+                    cache_control,
+                    download,
+                )
             })
             .await
             .map_err(|err| match err {
@@ -732,7 +738,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
             .managed(|client| async {
                 client
                     .cached_client()
-                    .skip_cache(self.request(url)?, &http_entry, download)
+                    .skip_cache_with_retry(self.request(url)?, &http_entry, download)
                     .await
                     .map_err(|err| match err {
                         CachedClientError::Callback(err) => err,
diff --git a/crates/uv-distribution/src/source/mod.rs b/crates/uv-distribution/src/source/mod.rs
index 5739b2c8e..daaff2952 100644
--- a/crates/uv-distribution/src/source/mod.rs
+++ b/crates/uv-distribution/src/source/mod.rs
@@ -654,9 +654,12 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> {
         let req = Self::request(url.clone(), client.unmanaged)?;
         let revision = client
             .managed(|client| {
-                client
-                    .cached_client()
-                    .get_serde(req, &cache_entry, cache_control, download)
+                client.cached_client().get_serde_with_retry(
+                    req,
+                    &cache_entry,
+                    cache_control,
+                    download,
+                )
             })
             .await
             .map_err(|err| match err {
@@ -672,7 +675,11 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> {
             .managed(|client| async move {
                 client
                     .cached_client()
-                    .skip_cache(Self::request(url.clone(), client)?, &cache_entry, download)
+                    .skip_cache_with_retry(
+                        Self::request(url.clone(), client)?,
+                        &cache_entry,
+                        download,
+                    )
                     .await
                     .map_err(|err| match err {
                         CachedClientError::Callback(err) => err,
@@ -1584,7 +1591,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> {
                     return Err(Error::CacheHeal(source.to_string(), existing.algorithm()));
                 }
             }
-            Ok(revision.with_hashes(hashes))
+            Ok(revision.clone().with_hashes(hashes))
         }
         .boxed_local()
         .instrument(info_span!("download", source_dist = %source))
@@ -1593,7 +1600,11 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> {
             .managed(|client| async move {
                 client
                     .cached_client()
-                    .skip_cache(Self::request(url.clone(), client)?, &cache_entry, download)
+                    .skip_cache_with_retry(
+                        Self::request(url.clone(), client)?,
+                        &cache_entry,
+                        download,
+                    )
                     .await
                     .map_err(|err| match err {
                         CachedClientError::Callback(err) => err,
diff --git a/crates/uv-publish/src/lib.rs b/crates/uv-publish/src/lib.rs
index 338bb1ef4..4ecc17f89 100644
--- a/crates/uv-publish/src/lib.rs
+++ b/crates/uv-publish/src/lib.rs
@@ -378,7 +378,6 @@ pub async fn upload(
     // Retry loop
     let mut attempt = 0;
     loop {
-        attempt += 1;
         let (request, idx) = build_request(
             file,
             raw_filename,
@@ -397,6 +396,7 @@ pub async fn upload(
         if attempt < retries && UvRetryableStrategy.handle(&result) == Some(Retryable::Transient) {
             reporter.on_download_complete(idx);
             warn_user!("Transient request failure for {}, retrying", registry);
+            attempt += 1;
             continue;
         }
 
diff --git a/crates/uv-resolver/src/bare.rs b/crates/uv-resolver/src/bare.rs
deleted file mode 100644
index 8b1378917..000000000
--- a/crates/uv-resolver/src/bare.rs
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/crates/uv-resolver/src/lib.rs b/crates/uv-resolver/src/lib.rs
index af379a95f..af0193d9a 100644
--- a/crates/uv-resolver/src/lib.rs
+++ b/crates/uv-resolver/src/lib.rs
@@ -32,9 +32,7 @@ pub use yanks::AllowedYanks;
 /// `ConflictItemRef`. i.e., We can avoid allocs on lookups.
 type FxHashbrownSet<T> = hashbrown::HashSet<T, rustc_hash::FxBuildHasher>;
 
-mod bare;
 mod candidate_selector;
-
 mod dependency_mode;
 mod dependency_provider;
 mod error;
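The three new `*_with_retry` helpers in `cached_client.rs` share one pattern: run the underlying call, and retry only when `is_extended_transient_error` matches the error and the client's `ExponentialBackoff` policy still allows another attempt. The sketch below restates that loop as a standalone function using the same `reqwest_retry` APIs the diff relies on (`ExponentialBackoff`, `RetryPolicy::should_retry`, `RetryDecision`); the `with_retry` helper, the `is_transient` stand-in, and their signatures are illustrative assumptions, not part of uv.

use std::time::{Duration, SystemTime};

use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{RetryDecision, RetryPolicy};

/// Illustrative stand-in for `is_extended_transient_error`: treat a connection
/// reset or unexpected EOF anywhere in the error's source chain as transient.
fn is_transient(err: &(dyn std::error::Error + 'static)) -> bool {
    let mut cause = Some(err);
    while let Some(err) = cause {
        if let Some(io) = err.downcast_ref::<std::io::Error>() {
            if matches!(
                io.kind(),
                std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::UnexpectedEof
            ) {
                return true;
            }
        }
        cause = err.source();
    }
    false
}

/// Re-run `run_once` with exponential backoff on transient failures,
/// mirroring the loop in `get_cacheable_with_retry`.
async fn with_retry<T, E, F, Fut>(retries: u32, mut run_once: F) -> Result<T, E>
where
    E: std::error::Error + 'static,
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retries);
    let start_time = SystemTime::now();
    let mut n_past_retries = 0;
    loop {
        let result = run_once().await;
        if let Some(err) = result.as_ref().err().filter(|err| is_transient(err)) {
            let decision = retry_policy.should_retry(start_time, n_past_retries);
            if let RetryDecision::Retry { execute_after } = decision {
                // Wait until the policy's next execution time, then try again.
                let duration = execute_after
                    .duration_since(SystemTime::now())
                    .unwrap_or_else(|_| Duration::default());
                tokio::time::sleep(duration).await;
                n_past_retries += 1;
                continue;
            }
        }
        return result;
    }
}

In the diff itself, each attempt also re-clones the `Request` via `try_clone()` (hence the `expect("HTTP request must be cloneable")`), since the response callback consumes the body on every attempt; that is why these response-handling retries sit above the `RetryTransientMiddleware` layer rather than inside it.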