Introduce a BaseClient for construction of canonical configured client (#2431)

In preparation for support of
https://github.com/astral-sh/uv/issues/2357 (see
https://github.com/astral-sh/uv/pull/2434)
This commit is contained in:
Zanie Blue 2024-03-15 12:07:38 -05:00 committed by GitHub
parent 8463d6d672
commit 9c27f92203
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 240 additions and 115 deletions

View file

@ -0,0 +1,188 @@
use reqwest::{Client, ClientBuilder};
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use std::env;
use std::fmt::Debug;
use std::ops::Deref;
use std::path::Path;
use tracing::debug;
use uv_auth::{AuthMiddleware, KeyringProvider};
use uv_fs::Simplified;
use uv_version::version;
use uv_warnings::warn_user_once;
use crate::middleware::OfflineMiddleware;
use crate::tls::Roots;
use crate::{tls, Connectivity};
/// A builder for a [`BaseClient`].
///
/// Collects the settings consumed by `BaseClientBuilder::build`: the keyring provider, the TLS
/// root selection, the retry count, the connectivity mode, and an optional pre-built [`Client`].
#[derive(Debug, Clone)]
pub struct BaseClientBuilder {
/// Keyring provider passed to the [`AuthMiddleware`] when an online client is built.
keyring_provider: KeyringProvider,
/// When `true`, `build` loads `Roots::Native` instead of `Roots::Webpki` (native roots are also
/// selected when `SSL_CERT_FILE` points at an existing file).
native_tls: bool,
/// Maximum number of retries given to the exponential-backoff retry policy of an online client.
retries: u32,
/// Whether the built client is online or offline; selects which middleware is attached.
connectivity: Connectivity,
/// A pre-built `reqwest` client to wrap; when `None`, `build` constructs one itself.
client: Option<Client>,
}
impl BaseClientBuilder {
    /// Create a builder with the default configuration.
    ///
    /// The defaults are: the default [`KeyringProvider`], webpki TLS roots (`native_tls` is
    /// `false`), [`Connectivity::Online`], three retries, and no pre-built client.
    pub fn new() -> Self {
        Self {
            keyring_provider: KeyringProvider::default(),
            native_tls: false,
            connectivity: Connectivity::Online,
            retries: 3,
            client: None,
        }
    }
}

impl Default for BaseClientBuilder {
    /// Equivalent to [`BaseClientBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl BaseClientBuilder {
    /// Set the keyring provider passed to the [`AuthMiddleware`] of an online client.
    #[must_use]
    pub fn keyring_provider(mut self, keyring_provider: KeyringProvider) -> Self {
        self.keyring_provider = keyring_provider;
        self
    }

    /// Set whether the built client is online or offline.
    #[must_use]
    pub fn connectivity(mut self, connectivity: Connectivity) -> Self {
        self.connectivity = connectivity;
        self
    }

    /// Set the maximum number of retries used by the retry middleware of an online client.
    #[must_use]
    pub fn retries(mut self, retries: u32) -> Self {
        self.retries = retries;
        self
    }

    /// Set whether to load native TLS roots (`Roots::Native`) instead of `Roots::Webpki`.
    #[must_use]
    pub fn native_tls(mut self, native_tls: bool) -> Self {
        self.native_tls = native_tls;
        self
    }

    /// Supply a pre-built [`Client`] to wrap instead of constructing one in [`Self::build`].
    #[must_use]
    pub fn client(mut self, client: Client) -> Self {
        self.client = Some(client);
        self
    }

    /// Build the [`BaseClient`].
    ///
    /// The request timeout (in seconds) is read from the first of `UV_HTTP_TIMEOUT`,
    /// `UV_REQUEST_TIMEOUT`, and `HTTP_TIMEOUT` that is set; an unparsable value triggers a
    /// warning and falls back to the default of five minutes.
    ///
    /// When no client was supplied via [`Self::client`], a `reqwest` client is constructed with
    /// the `uv/<version>` user agent, the resolved timeout, and the selected TLS roots. A supplied
    /// client is wrapped as-is; none of those settings are applied to it, although the resolved
    /// timeout is still recorded on the returned [`BaseClient`].
    ///
    /// An online client is wrapped with the retry and authentication middleware; an offline
    /// client is wrapped with [`OfflineMiddleware`] only.
    ///
    /// # Panics
    ///
    /// Panics if the TLS configuration cannot be loaded or the HTTP client cannot be built. Both
    /// can only happen when no pre-built client was supplied.
    pub fn build(self) -> BaseClient {
        // Timeout options, matching https://doc.rust-lang.org/nightly/cargo/reference/config.html#httptimeout
        // `UV_REQUEST_TIMEOUT` is provided for backwards compatibility with v0.1.6
        let default_timeout = 5 * 60;
        let timeout = env::var("UV_HTTP_TIMEOUT")
            .or_else(|_| env::var("UV_REQUEST_TIMEOUT"))
            .or_else(|_| env::var("HTTP_TIMEOUT"))
            .and_then(|value| {
                value.parse::<u64>()
                    .or_else(|_| {
                        // On parse error, warn and use the default timeout
                        warn_user_once!("Ignoring invalid value from environment for UV_HTTP_TIMEOUT. Expected integer number of seconds, got \"{value}\".");
                        Ok(default_timeout)
                    })
            })
            .unwrap_or(default_timeout);
        debug!("Using registry request timeout of {}s", timeout);

        // Initialize the base client, unless the caller provided one.
        let client = self.client.unwrap_or_else(|| {
            // The user agent is only needed when we construct the client ourselves.
            let user_agent_string = format!("uv/{}", version());

            // Check for the presence of an `SSL_CERT_FILE`.
            let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| {
                let path_exists = Path::new(&path).exists();
                if !path_exists {
                    warn_user_once!(
                        "Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.",
                        path.simplified_display()
                    );
                }
                path_exists
            });

            // Load the TLS configuration.
            let tls = tls::load(if self.native_tls || ssl_cert_file_exists {
                Roots::Native
            } else {
                Roots::Webpki
            })
            .expect("Failed to load TLS configuration.");

            let client_core = ClientBuilder::new()
                .user_agent(user_agent_string)
                .pool_max_idle_per_host(20)
                .timeout(std::time::Duration::from_secs(timeout))
                .use_preconfigured_tls(tls);

            client_core.build().expect("Failed to build HTTP client.")
        });

        // Wrap in any relevant middleware. The core client is owned and not used again, so it is
        // moved into the middleware builder rather than cloned.
        let client = match self.connectivity {
            Connectivity::Online => {
                let client = reqwest_middleware::ClientBuilder::new(client);

                // Initialize the retry strategy.
                let retry_policy =
                    ExponentialBackoff::builder().build_with_max_retries(self.retries);
                let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy);
                let client = client.with(retry_strategy);

                // Initialize the authentication middleware to set headers.
                let client = client.with(AuthMiddleware::new(self.keyring_provider));

                client.build()
            }
            Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client)
                .with(OfflineMiddleware)
                .build(),
        };

        BaseClient {
            connectivity: self.connectivity,
            client,
            timeout,
        }
    }
}
/// A base client for HTTP requests.
///
/// Produced by `BaseClientBuilder::build`, which wraps a `reqwest` client in middleware chosen by
/// the connectivity mode.
#[derive(Debug, Clone)]
pub struct BaseClient {
/// The underlying HTTP client, including the middleware attached at build time.
client: ClientWithMiddleware,
/// The connectivity mode to use.
connectivity: Connectivity,
/// Configured client timeout, in seconds.
///
/// This is the value resolved from the environment at build time. It is applied to the
/// `reqwest` client only when the builder constructs that client itself; a caller-supplied
/// client keeps whatever timeout it was created with.
timeout: u64,
}
impl BaseClient {
/// The underyling [`ClientWithMiddleware`].
pub fn client(&self) -> ClientWithMiddleware {
self.client.clone()
}
/// The configured client timeout, in seconds.
pub fn timeout(&self) -> u64 {
self.timeout
}
/// The configured connectivity mode.
pub fn connectivity(&self) -> Connectivity {
self.connectivity
}
}
// The [`BaseClient`] is often nested within other client types; dereferencing straight to the
// inner client keeps call chains from becoming excessively verbose.
impl Deref for BaseClient {
    type Target = ClientWithMiddleware;

    /// Dereference to the underlying [`ClientWithMiddleware`].
    fn deref(&self) -> &Self::Target {
        let Self { client, .. } = self;
        client
    }
}

View file

@ -2,7 +2,6 @@ use std::{borrow::Cow, future::Future, path::Path};
use futures::FutureExt; use futures::FutureExt;
use reqwest::{Request, Response}; use reqwest::{Request, Response};
use reqwest_middleware::ClientWithMiddleware;
use rkyv::util::AlignedVec; use rkyv::util::AlignedVec;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -11,6 +10,7 @@ use tracing::{debug, info_span, instrument, trace, warn, Instrument};
use uv_cache::{CacheEntry, Freshness}; use uv_cache::{CacheEntry, Freshness};
use uv_fs::write_atomic; use uv_fs::write_atomic;
use crate::BaseClient;
use crate::{ use crate::{
httpcache::{AfterResponse, BeforeRequest, CachePolicy, CachePolicyBuilder}, httpcache::{AfterResponse, BeforeRequest, CachePolicy, CachePolicyBuilder},
rkyvutil::OwnedArchive, rkyvutil::OwnedArchive,
@ -158,15 +158,15 @@ impl From<Freshness> for CacheControl {
/// Again unlike `http-cache`, the caller gets full control over the cache key with the assumption /// Again unlike `http-cache`, the caller gets full control over the cache key with the assumption
/// that it's a file. /// that it's a file.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CachedClient(ClientWithMiddleware); pub struct CachedClient(BaseClient);
impl CachedClient { impl CachedClient {
pub fn new(client: ClientWithMiddleware) -> Self { pub fn new(client: BaseClient) -> Self {
Self(client) Self(client)
} }
/// The middleware is the retry strategy /// The base client
pub fn uncached(&self) -> ClientWithMiddleware { pub fn uncached(&self) -> BaseClient {
self.0.clone() self.0.clone()
} }

View file

@ -143,10 +143,9 @@ impl<'a> FlatIndexClient<'a> {
Connectivity::Offline => CacheControl::AllowStale, Connectivity::Offline => CacheControl::AllowStale,
}; };
let cached_client = self.client.cached_client(); let flat_index_request = self
.client
let flat_index_request = cached_client .uncached_client()
.uncached()
.get(url.clone()) .get(url.clone())
.header("Accept-Encoding", "gzip") .header("Accept-Encoding", "gzip")
.header("Accept", "text/html") .header("Accept", "text/html")
@ -180,7 +179,9 @@ impl<'a> FlatIndexClient<'a> {
.boxed() .boxed()
.instrument(info_span!("parse_flat_index_html", url = % url)) .instrument(info_span!("parse_flat_index_html", url = % url))
}; };
let response = cached_client let response = self
.client
.cached_client()
.get_serde( .get_serde(
flat_index_request, flat_index_request,
&cache_entry, &cache_entry,

View file

@ -1,3 +1,4 @@
pub use base_client::BaseClient;
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy};
pub use error::{BetterReqwestError, Error, ErrorKind}; pub use error::{BetterReqwestError, Error, ErrorKind};
pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError}; pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError};
@ -7,6 +8,7 @@ pub use registry_client::{
}; };
pub use rkyvutil::OwnedArchive; pub use rkyvutil::OwnedArchive;
mod base_client;
mod cached_client; mod cached_client;
mod error; mod error;
mod flat_index; mod flat_index;

View file

@ -1,5 +1,4 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::env;
use std::fmt::Debug; use std::fmt::Debug;
use std::path::Path; use std::path::Path;
use std::str::FromStr; use std::str::FromStr;
@ -7,13 +6,11 @@ use std::str::FromStr;
use async_http_range_reader::AsyncHttpRangeReader; use async_http_range_reader::AsyncHttpRangeReader;
use futures::{FutureExt, TryStreamExt}; use futures::{FutureExt, TryStreamExt};
use http::HeaderMap; use http::HeaderMap;
use reqwest::{Client, ClientBuilder, Response, StatusCode}; use reqwest::{Client, Response, StatusCode};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio_util::compat::FuturesAsyncReadCompatExt; use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{debug, info_span, instrument, trace, warn, Instrument}; use tracing::{info_span, instrument, trace, warn, Instrument};
use url::Url; use url::Url;
use distribution_filename::{DistFilename, SourceDistFilename, WheelFilename}; use distribution_filename::{DistFilename, SourceDistFilename, WheelFilename};
@ -21,20 +18,16 @@ use distribution_types::{BuiltDist, File, FileLocation, IndexUrl, IndexUrls, Nam
use install_wheel_rs::metadata::{find_archive_dist_info, is_metadata_entry}; use install_wheel_rs::metadata::{find_archive_dist_info, is_metadata_entry};
use pep440_rs::Version; use pep440_rs::Version;
use pypi_types::{Metadata23, SimpleJson}; use pypi_types::{Metadata23, SimpleJson};
use uv_auth::{AuthMiddleware, KeyringProvider}; use uv_auth::KeyringProvider;
use uv_cache::{Cache, CacheBucket, WheelCache}; use uv_cache::{Cache, CacheBucket, WheelCache};
use uv_fs::Simplified;
use uv_normalize::PackageName; use uv_normalize::PackageName;
use uv_version::version;
use uv_warnings::warn_user_once;
use crate::base_client::{BaseClient, BaseClientBuilder};
use crate::cached_client::CacheControl; use crate::cached_client::CacheControl;
use crate::html::SimpleHtml; use crate::html::SimpleHtml;
use crate::middleware::OfflineMiddleware;
use crate::remote_metadata::wheel_metadata_from_remote_zip; use crate::remote_metadata::wheel_metadata_from_remote_zip;
use crate::rkyvutil::OwnedArchive; use crate::rkyvutil::OwnedArchive;
use crate::tls::Roots; use crate::{CachedClient, CachedClientError, Error, ErrorKind};
use crate::{tls, CachedClient, CachedClientError, Error, ErrorKind};
/// A builder for an [`RegistryClient`]. /// A builder for an [`RegistryClient`].
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -106,76 +99,22 @@ impl RegistryClientBuilder {
} }
pub fn build(self) -> RegistryClient { pub fn build(self) -> RegistryClient {
// Create user agent. // Build a base client
let user_agent_string = format!("uv/{}", version()); let mut builder = BaseClientBuilder::new();
// Timeout options, matching https://doc.rust-lang.org/nightly/cargo/reference/config.html#httptimeout if let Some(client) = self.client {
// `UV_REQUEST_TIMEOUT` is provided for backwards compatibility with v0.1.6 builder = builder.client(client)
let default_timeout = 5 * 60;
let timeout = env::var("UV_HTTP_TIMEOUT")
.or_else(|_| env::var("UV_REQUEST_TIMEOUT"))
.or_else(|_| env::var("HTTP_TIMEOUT"))
.and_then(|value| {
value.parse::<u64>()
.or_else(|_| {
// On parse error, warn and use the default timeout
warn_user_once!("Ignoring invalid value from environment for UV_HTTP_TIMEOUT. Expected integer number of seconds, got \"{value}\".");
Ok(default_timeout)
})
})
.unwrap_or(default_timeout);
debug!("Using registry request timeout of {}s", timeout);
// Initialize the base client.
let client = self.client.unwrap_or_else(|| {
// Check for the presence of an `SSL_CERT_FILE`.
let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| {
let path_exists = Path::new(&path).exists();
if !path_exists {
warn_user_once!(
"Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.",
path.simplified_display()
);
} }
path_exists
});
// Load the TLS configuration.
let tls = tls::load(if self.native_tls || ssl_cert_file_exists {
Roots::Native
} else {
Roots::Webpki
})
.expect("Failed to load TLS configuration.");
let client_core = ClientBuilder::new() let client = builder
.user_agent(user_agent_string) .retries(self.retries)
.pool_max_idle_per_host(20) .connectivity(self.connectivity)
.timeout(std::time::Duration::from_secs(timeout)) .native_tls(self.native_tls)
.use_preconfigured_tls(tls); .keyring_provider(self.keyring_provider)
.build();
client_core.build().expect("Failed to build HTTP client.") let timeout = client.timeout();
}); let connectivity = client.connectivity();
// Wrap in any relevant middleware.
let client = match self.connectivity {
Connectivity::Online => {
let client = reqwest_middleware::ClientBuilder::new(client.clone());
// Initialize the retry strategy.
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(self.retries);
let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy);
let client = client.with(retry_strategy);
// Initialize the authentication middleware to set headers.
let client = client.with(AuthMiddleware::new(self.keyring_provider));
client.build()
}
Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone())
.with(OfflineMiddleware)
.build(),
};
// Wrap in the cache middleware. // Wrap in the cache middleware.
let client = CachedClient::new(client); let client = CachedClient::new(client);
@ -183,7 +122,7 @@ impl RegistryClientBuilder {
RegistryClient { RegistryClient {
index_urls: self.index_urls, index_urls: self.index_urls,
cache: self.cache, cache: self.cache,
connectivity: self.connectivity, connectivity,
client, client,
timeout, timeout,
} }
@ -211,6 +150,11 @@ impl RegistryClient {
&self.client &self.client
} }
/// Return the [`BaseClient`] used by this client.
pub fn uncached_client(&self) -> BaseClient {
self.client.uncached()
}
/// Return the [`Connectivity`] mode used by this client. /// Return the [`Connectivity`] mode used by this client.
pub fn connectivity(&self) -> Connectivity { pub fn connectivity(&self) -> Connectivity {
self.connectivity self.connectivity
@ -306,8 +250,7 @@ impl RegistryClient {
}; };
let simple_request = self let simple_request = self
.client .uncached_client()
.uncached()
.get(url.clone()) .get(url.clone())
.header("Accept-Encoding", "gzip") .header("Accept-Encoding", "gzip")
.header("Accept", MediaType::accepts()) .header("Accept", MediaType::accepts())
@ -356,7 +299,7 @@ impl RegistryClient {
.instrument(info_span!("parse_simple_api", package = %package_name)) .instrument(info_span!("parse_simple_api", package = %package_name))
}; };
let result = self let result = self
.client .cached_client()
.get_cacheable( .get_cacheable(
simple_request, simple_request,
&cache_entry, &cache_entry,
@ -469,13 +412,12 @@ impl RegistryClient {
}) })
}; };
let req = self let req = self
.client .uncached_client()
.uncached()
.get(url.clone()) .get(url.clone())
.build() .build()
.map_err(ErrorKind::from)?; .map_err(ErrorKind::from)?;
Ok(self Ok(self
.client .cached_client()
.get_serde(req, &cache_entry, cache_control, response_callback) .get_serde(req, &cache_entry, cache_control, response_callback)
.await?) .await?)
} else { } else {
@ -509,8 +451,7 @@ impl RegistryClient {
}; };
let req = self let req = self
.client .uncached_client()
.uncached()
.head(url.clone()) .head(url.clone())
.header( .header(
"accept-encoding", "accept-encoding",
@ -530,7 +471,7 @@ impl RegistryClient {
let read_metadata_range_request = |response: Response| { let read_metadata_range_request = |response: Response| {
async { async {
let mut reader = AsyncHttpRangeReader::from_head_response( let mut reader = AsyncHttpRangeReader::from_head_response(
self.client.uncached(), self.uncached_client().client(),
response, response,
headers, headers,
) )
@ -552,7 +493,7 @@ impl RegistryClient {
}; };
let result = self let result = self
.client .cached_client()
.get_serde( .get_serde(
req, req,
&cache_entry, &cache_entry,
@ -577,8 +518,7 @@ impl RegistryClient {
// Create a request to stream the file. // Create a request to stream the file.
let req = self let req = self
.client .uncached_client()
.uncached()
.get(url.clone()) .get(url.clone())
.header( .header(
// `reqwest` defaults to accepting compressed responses. // `reqwest` defaults to accepting compressed responses.
@ -603,7 +543,7 @@ impl RegistryClient {
.instrument(info_span!("read_metadata_stream", wheel = %filename)) .instrument(info_span!("read_metadata_stream", wheel = %filename))
}; };
self.client self.cached_client()
.get_serde(req, &cache_entry, cache_control, read_metadata_stream) .get_serde(req, &cache_entry, cache_control, read_metadata_stream)
.await .await
.map_err(crate::Error::from) .map_err(crate::Error::from)

View file

@ -52,8 +52,7 @@ async fn test_client_with_netrc_credentials() -> Result<()> {
// Send request to our dummy server // Send request to our dummy server
let res = client let res = client
.cached_client() .uncached_client()
.uncached()
.get(format!("http://{addr}")) .get(format!("http://{addr}"))
.send() .send()
.await?; .await?;

View file

@ -44,8 +44,7 @@ async fn test_user_agent_has_version() -> Result<()> {
// Send request to our dummy server // Send request to our dummy server
let res = client let res = client
.cached_client() .uncached_client()
.uncached()
.get(format!("http://{addr}")) .get(format!("http://{addr}"))
.send() .send()
.await?; .await?;

View file

@ -460,8 +460,7 @@ impl<'a, Context: BuildContext + Send + Sync> DistributionDatabase<'a, Context>
let req = self let req = self
.client .client
.cached_client() .uncached_client()
.uncached()
.get(url) .get(url)
.header( .header(
// `reqwest` defaults to accepting compressed responses. // `reqwest` defaults to accepting compressed responses.
@ -542,8 +541,7 @@ impl<'a, Context: BuildContext + Send + Sync> DistributionDatabase<'a, Context>
let req = self let req = self
.client .client
.cached_client() .uncached_client()
.uncached()
.get(url) .get(url)
.header( .header(
// `reqwest` defaults to accepting compressed responses. // `reqwest` defaults to accepting compressed responses.

View file

@ -304,8 +304,7 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> {
}; };
let req = self let req = self
.client .client
.cached_client() .uncached_client()
.uncached()
.get(url.clone()) .get(url.clone())
.header( .header(
// `reqwest` defaults to accepting compressed responses. // `reqwest` defaults to accepting compressed responses.
@ -414,8 +413,7 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> {
}; };
let req = self let req = self
.client .client
.cached_client() .uncached_client()
.uncached()
.get(url.clone()) .get(url.clone())
.header( .header(
// `reqwest` defaults to accepting compressed responses. // `reqwest` defaults to accepting compressed responses.