Introduce a BaseClient for construction of canonical configured client (#2431)

In preparation for support of
https://github.com/astral-sh/uv/issues/2357 (see
https://github.com/astral-sh/uv/pull/2434)
This commit is contained in:
Zanie Blue 2024-03-15 12:07:38 -05:00 committed by GitHub
parent 8463d6d672
commit 9c27f92203
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 240 additions and 115 deletions

View file

@ -0,0 +1,188 @@
use reqwest::{Client, ClientBuilder};
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use std::env;
use std::fmt::Debug;
use std::ops::Deref;
use std::path::Path;
use tracing::debug;
use uv_auth::{AuthMiddleware, KeyringProvider};
use uv_fs::Simplified;
use uv_version::version;
use uv_warnings::warn_user_once;
use crate::middleware::OfflineMiddleware;
use crate::tls::Roots;
use crate::{tls, Connectivity};
/// A builder for an [`RegistryClient`].
#[derive(Debug, Clone)]
pub struct BaseClientBuilder {
keyring_provider: KeyringProvider,
native_tls: bool,
retries: u32,
connectivity: Connectivity,
client: Option<Client>,
}
impl BaseClientBuilder {
pub fn new() -> Self {
Self {
keyring_provider: KeyringProvider::default(),
native_tls: false,
connectivity: Connectivity::Online,
retries: 3,
client: None,
}
}
}
impl BaseClientBuilder {
#[must_use]
pub fn keyring_provider(mut self, keyring_provider: KeyringProvider) -> Self {
self.keyring_provider = keyring_provider;
self
}
#[must_use]
pub fn connectivity(mut self, connectivity: Connectivity) -> Self {
self.connectivity = connectivity;
self
}
#[must_use]
pub fn retries(mut self, retries: u32) -> Self {
self.retries = retries;
self
}
#[must_use]
pub fn native_tls(mut self, native_tls: bool) -> Self {
self.native_tls = native_tls;
self
}
#[must_use]
pub fn client(mut self, client: Client) -> Self {
self.client = Some(client);
self
}
pub fn build(self) -> BaseClient {
// Create user agent.
let user_agent_string = format!("uv/{}", version());
// Timeout options, matching https://doc.rust-lang.org/nightly/cargo/reference/config.html#httptimeout
// `UV_REQUEST_TIMEOUT` is provided for backwards compatibility with v0.1.6
let default_timeout = 5 * 60;
let timeout = env::var("UV_HTTP_TIMEOUT")
.or_else(|_| env::var("UV_REQUEST_TIMEOUT"))
.or_else(|_| env::var("HTTP_TIMEOUT"))
.and_then(|value| {
value.parse::<u64>()
.or_else(|_| {
// On parse error, warn and use the default timeout
warn_user_once!("Ignoring invalid value from environment for UV_HTTP_TIMEOUT. Expected integer number of seconds, got \"{value}\".");
Ok(default_timeout)
})
})
.unwrap_or(default_timeout);
debug!("Using registry request timeout of {}s", timeout);
// Initialize the base client.
let client = self.client.unwrap_or_else(|| {
// Check for the presence of an `SSL_CERT_FILE`.
let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| {
let path_exists = Path::new(&path).exists();
if !path_exists {
warn_user_once!(
"Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.",
path.simplified_display()
);
}
path_exists
});
// Load the TLS configuration.
let tls = tls::load(if self.native_tls || ssl_cert_file_exists {
Roots::Native
} else {
Roots::Webpki
})
.expect("Failed to load TLS configuration.");
let client_core = ClientBuilder::new()
.user_agent(user_agent_string)
.pool_max_idle_per_host(20)
.timeout(std::time::Duration::from_secs(timeout))
.use_preconfigured_tls(tls);
client_core.build().expect("Failed to build HTTP client.")
});
// Wrap in any relevant middleware.
let client = match self.connectivity {
Connectivity::Online => {
let client = reqwest_middleware::ClientBuilder::new(client.clone());
// Initialize the retry strategy.
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(self.retries);
let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy);
let client = client.with(retry_strategy);
// Initialize the authentication middleware to set headers.
let client = client.with(AuthMiddleware::new(self.keyring_provider));
client.build()
}
Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone())
.with(OfflineMiddleware)
.build(),
};
BaseClient {
connectivity: self.connectivity,
client,
timeout,
}
}
}
/// A base client for HTTP requests
#[derive(Debug, Clone)]
pub struct BaseClient {
/// The underlying HTTP client.
client: ClientWithMiddleware,
/// The connectivity mode to use.
connectivity: Connectivity,
/// Configured client timeout, in seconds.
timeout: u64,
}
impl BaseClient {
/// The underyling [`ClientWithMiddleware`].
pub fn client(&self) -> ClientWithMiddleware {
self.client.clone()
}
/// The configured client timeout, in seconds.
pub fn timeout(&self) -> u64 {
self.timeout
}
/// The configured connectivity mode.
pub fn connectivity(&self) -> Connectivity {
self.connectivity
}
}
// To avoid excessively verbose call chains, as the [`BaseClient`] is often nested within other client types.
impl Deref for BaseClient {
type Target = ClientWithMiddleware;
/// Deference to the underlying [`ClientWithMiddleware`].
fn deref(&self) -> &Self::Target {
&self.client
}
}

View file

@ -2,7 +2,6 @@ use std::{borrow::Cow, future::Future, path::Path};
use futures::FutureExt;
use reqwest::{Request, Response};
use reqwest_middleware::ClientWithMiddleware;
use rkyv::util::AlignedVec;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
@ -11,6 +10,7 @@ use tracing::{debug, info_span, instrument, trace, warn, Instrument};
use uv_cache::{CacheEntry, Freshness};
use uv_fs::write_atomic;
use crate::BaseClient;
use crate::{
httpcache::{AfterResponse, BeforeRequest, CachePolicy, CachePolicyBuilder},
rkyvutil::OwnedArchive,
@ -158,15 +158,15 @@ impl From<Freshness> for CacheControl {
/// Again unlike `http-cache`, the caller gets full control over the cache key with the assumption
/// that it's a file.
#[derive(Debug, Clone)]
pub struct CachedClient(ClientWithMiddleware);
pub struct CachedClient(BaseClient);
impl CachedClient {
pub fn new(client: ClientWithMiddleware) -> Self {
pub fn new(client: BaseClient) -> Self {
Self(client)
}
/// The middleware is the retry strategy
pub fn uncached(&self) -> ClientWithMiddleware {
/// The base client
pub fn uncached(&self) -> BaseClient {
self.0.clone()
}

View file

@ -143,10 +143,9 @@ impl<'a> FlatIndexClient<'a> {
Connectivity::Offline => CacheControl::AllowStale,
};
let cached_client = self.client.cached_client();
let flat_index_request = cached_client
.uncached()
let flat_index_request = self
.client
.uncached_client()
.get(url.clone())
.header("Accept-Encoding", "gzip")
.header("Accept", "text/html")
@ -180,7 +179,9 @@ impl<'a> FlatIndexClient<'a> {
.boxed()
.instrument(info_span!("parse_flat_index_html", url = % url))
};
let response = cached_client
let response = self
.client
.cached_client()
.get_serde(
flat_index_request,
&cache_entry,

View file

@ -1,3 +1,4 @@
pub use base_client::BaseClient;
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy};
pub use error::{BetterReqwestError, Error, ErrorKind};
pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError};
@ -7,6 +8,7 @@ pub use registry_client::{
};
pub use rkyvutil::OwnedArchive;
mod base_client;
mod cached_client;
mod error;
mod flat_index;

View file

@ -1,5 +1,4 @@
use std::collections::BTreeMap;
use std::env;
use std::fmt::Debug;
use std::path::Path;
use std::str::FromStr;
@ -7,13 +6,11 @@ use std::str::FromStr;
use async_http_range_reader::AsyncHttpRangeReader;
use futures::{FutureExt, TryStreamExt};
use http::HeaderMap;
use reqwest::{Client, ClientBuilder, Response, StatusCode};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use reqwest::{Client, Response, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncReadExt;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{debug, info_span, instrument, trace, warn, Instrument};
use tracing::{info_span, instrument, trace, warn, Instrument};
use url::Url;
use distribution_filename::{DistFilename, SourceDistFilename, WheelFilename};
@ -21,20 +18,16 @@ use distribution_types::{BuiltDist, File, FileLocation, IndexUrl, IndexUrls, Nam
use install_wheel_rs::metadata::{find_archive_dist_info, is_metadata_entry};
use pep440_rs::Version;
use pypi_types::{Metadata23, SimpleJson};
use uv_auth::{AuthMiddleware, KeyringProvider};
use uv_auth::KeyringProvider;
use uv_cache::{Cache, CacheBucket, WheelCache};
use uv_fs::Simplified;
use uv_normalize::PackageName;
use uv_version::version;
use uv_warnings::warn_user_once;
use crate::base_client::{BaseClient, BaseClientBuilder};
use crate::cached_client::CacheControl;
use crate::html::SimpleHtml;
use crate::middleware::OfflineMiddleware;
use crate::remote_metadata::wheel_metadata_from_remote_zip;
use crate::rkyvutil::OwnedArchive;
use crate::tls::Roots;
use crate::{tls, CachedClient, CachedClientError, Error, ErrorKind};
use crate::{CachedClient, CachedClientError, Error, ErrorKind};
/// A builder for an [`RegistryClient`].
#[derive(Debug, Clone)]
@ -106,76 +99,22 @@ impl RegistryClientBuilder {
}
pub fn build(self) -> RegistryClient {
// Create user agent.
let user_agent_string = format!("uv/{}", version());
// Build a base client
let mut builder = BaseClientBuilder::new();
// Timeout options, matching https://doc.rust-lang.org/nightly/cargo/reference/config.html#httptimeout
// `UV_REQUEST_TIMEOUT` is provided for backwards compatibility with v0.1.6
let default_timeout = 5 * 60;
let timeout = env::var("UV_HTTP_TIMEOUT")
.or_else(|_| env::var("UV_REQUEST_TIMEOUT"))
.or_else(|_| env::var("HTTP_TIMEOUT"))
.and_then(|value| {
value.parse::<u64>()
.or_else(|_| {
// On parse error, warn and use the default timeout
warn_user_once!("Ignoring invalid value from environment for UV_HTTP_TIMEOUT. Expected integer number of seconds, got \"{value}\".");
Ok(default_timeout)
})
})
.unwrap_or(default_timeout);
debug!("Using registry request timeout of {}s", timeout);
// Initialize the base client.
let client = self.client.unwrap_or_else(|| {
// Check for the presence of an `SSL_CERT_FILE`.
let ssl_cert_file_exists = env::var_os("SSL_CERT_FILE").is_some_and(|path| {
let path_exists = Path::new(&path).exists();
if !path_exists {
warn_user_once!(
"Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.",
path.simplified_display()
);
if let Some(client) = self.client {
builder = builder.client(client)
}
path_exists
});
// Load the TLS configuration.
let tls = tls::load(if self.native_tls || ssl_cert_file_exists {
Roots::Native
} else {
Roots::Webpki
})
.expect("Failed to load TLS configuration.");
let client_core = ClientBuilder::new()
.user_agent(user_agent_string)
.pool_max_idle_per_host(20)
.timeout(std::time::Duration::from_secs(timeout))
.use_preconfigured_tls(tls);
let client = builder
.retries(self.retries)
.connectivity(self.connectivity)
.native_tls(self.native_tls)
.keyring_provider(self.keyring_provider)
.build();
client_core.build().expect("Failed to build HTTP client.")
});
// Wrap in any relevant middleware.
let client = match self.connectivity {
Connectivity::Online => {
let client = reqwest_middleware::ClientBuilder::new(client.clone());
// Initialize the retry strategy.
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(self.retries);
let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy);
let client = client.with(retry_strategy);
// Initialize the authentication middleware to set headers.
let client = client.with(AuthMiddleware::new(self.keyring_provider));
client.build()
}
Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone())
.with(OfflineMiddleware)
.build(),
};
let timeout = client.timeout();
let connectivity = client.connectivity();
// Wrap in the cache middleware.
let client = CachedClient::new(client);
@ -183,7 +122,7 @@ impl RegistryClientBuilder {
RegistryClient {
index_urls: self.index_urls,
cache: self.cache,
connectivity: self.connectivity,
connectivity,
client,
timeout,
}
@ -211,6 +150,11 @@ impl RegistryClient {
&self.client
}
/// Return the [`BaseClient`] used by this client.
pub fn uncached_client(&self) -> BaseClient {
self.client.uncached()
}
/// Return the [`Connectivity`] mode used by this client.
pub fn connectivity(&self) -> Connectivity {
self.connectivity
@ -306,8 +250,7 @@ impl RegistryClient {
};
let simple_request = self
.client
.uncached()
.uncached_client()
.get(url.clone())
.header("Accept-Encoding", "gzip")
.header("Accept", MediaType::accepts())
@ -356,7 +299,7 @@ impl RegistryClient {
.instrument(info_span!("parse_simple_api", package = %package_name))
};
let result = self
.client
.cached_client()
.get_cacheable(
simple_request,
&cache_entry,
@ -469,13 +412,12 @@ impl RegistryClient {
})
};
let req = self
.client
.uncached()
.uncached_client()
.get(url.clone())
.build()
.map_err(ErrorKind::from)?;
Ok(self
.client
.cached_client()
.get_serde(req, &cache_entry, cache_control, response_callback)
.await?)
} else {
@ -509,8 +451,7 @@ impl RegistryClient {
};
let req = self
.client
.uncached()
.uncached_client()
.head(url.clone())
.header(
"accept-encoding",
@ -530,7 +471,7 @@ impl RegistryClient {
let read_metadata_range_request = |response: Response| {
async {
let mut reader = AsyncHttpRangeReader::from_head_response(
self.client.uncached(),
self.uncached_client().client(),
response,
headers,
)
@ -552,7 +493,7 @@ impl RegistryClient {
};
let result = self
.client
.cached_client()
.get_serde(
req,
&cache_entry,
@ -577,8 +518,7 @@ impl RegistryClient {
// Create a request to stream the file.
let req = self
.client
.uncached()
.uncached_client()
.get(url.clone())
.header(
// `reqwest` defaults to accepting compressed responses.
@ -603,7 +543,7 @@ impl RegistryClient {
.instrument(info_span!("read_metadata_stream", wheel = %filename))
};
self.client
self.cached_client()
.get_serde(req, &cache_entry, cache_control, read_metadata_stream)
.await
.map_err(crate::Error::from)

View file

@ -52,8 +52,7 @@ async fn test_client_with_netrc_credentials() -> Result<()> {
// Send request to our dummy server
let res = client
.cached_client()
.uncached()
.uncached_client()
.get(format!("http://{addr}"))
.send()
.await?;

View file

@ -44,8 +44,7 @@ async fn test_user_agent_has_version() -> Result<()> {
// Send request to our dummy server
let res = client
.cached_client()
.uncached()
.uncached_client()
.get(format!("http://{addr}"))
.send()
.await?;

View file

@ -460,8 +460,7 @@ impl<'a, Context: BuildContext + Send + Sync> DistributionDatabase<'a, Context>
let req = self
.client
.cached_client()
.uncached()
.uncached_client()
.get(url)
.header(
// `reqwest` defaults to accepting compressed responses.
@ -542,8 +541,7 @@ impl<'a, Context: BuildContext + Send + Sync> DistributionDatabase<'a, Context>
let req = self
.client
.cached_client()
.uncached()
.uncached_client()
.get(url)
.header(
// `reqwest` defaults to accepting compressed responses.

View file

@ -304,8 +304,7 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> {
};
let req = self
.client
.cached_client()
.uncached()
.uncached_client()
.get(url.clone())
.header(
// `reqwest` defaults to accepting compressed responses.
@ -414,8 +413,7 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> {
};
let req = self
.client
.cached_client()
.uncached()
.uncached_client()
.get(url.clone())
.header(
// `reqwest` defaults to accepting compressed responses.