diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs index b2b2829cf..8b06099d1 100644 --- a/crates/uv-client/src/base_client.rs +++ b/crates/uv-client/src/base_client.rs @@ -1,6 +1,6 @@ use itertools::Itertools; use reqwest::{Client, ClientBuilder, Response}; -use reqwest_middleware::ClientWithMiddleware; +use reqwest_middleware::{ClientWithMiddleware, Middleware}; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::{ DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy, @@ -8,6 +8,7 @@ use reqwest_retry::{ use std::error::Error; use std::fmt::Debug; use std::path::Path; +use std::sync::Arc; use std::time::Duration; use std::{env, iter}; use tracing::debug; @@ -54,6 +55,19 @@ pub struct BaseClientBuilder<'a> { platform: Option<&'a Platform>, auth_integration: AuthIntegration, default_timeout: Duration, + extra_middleware: Option, +} + +/// A list of user-defined middlewares to be applied to the client. +#[derive(Clone)] +pub struct ExtraMiddleware(pub Vec>); + +impl Debug for ExtraMiddleware { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtraMiddleware") + .field("0", &format!("{} middlewares", self.0.len())) + .finish() + } } impl Default for BaseClientBuilder<'_> { @@ -75,6 +89,7 @@ impl BaseClientBuilder<'_> { platform: None, auth_integration: AuthIntegration::default(), default_timeout: Duration::from_secs(30), + extra_middleware: None, } } } @@ -140,6 +155,12 @@ impl<'a> BaseClientBuilder<'a> { self } + #[must_use] + pub fn extra_middleware(mut self, middleware: ExtraMiddleware) -> Self { + self.extra_middleware = Some(middleware); + self + } + pub fn is_offline(&self) -> bool { matches!(self.connectivity, Connectivity::Offline) } @@ -313,6 +334,13 @@ impl<'a> BaseClientBuilder<'a> { } } + // When supplied add the extra middleware + if let Some(extra_middleware) = &self.extra_middleware { + for middleware in &extra_middleware.0 { + client = client.with_arc(middleware.clone()); + } + } + client.build() } Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client) diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 0052b92d0..46b47d244 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -26,7 +26,7 @@ use uv_pep508::MarkerEnvironment; use uv_platform_tags::Platform; use uv_pypi_types::{ResolutionMetadata, SimpleJson}; -use crate::base_client::BaseClientBuilder; +use crate::base_client::{BaseClientBuilder, ExtraMiddleware}; use crate::cached_client::CacheControl; use crate::html::SimpleHtml; use crate::remote_metadata::wheel_metadata_from_remote_zip; @@ -110,6 +110,12 @@ impl<'a> RegistryClientBuilder<'a> { self } + #[must_use] + pub fn extra_middleware(mut self, middleware: ExtraMiddleware) -> Self { + self.base_client_builder = self.base_client_builder.extra_middleware(middleware); + self + } + #[must_use] pub fn markers(mut self, markers: &'a MarkerEnvironment) -> Self { self.base_client_builder = self.base_client_builder.markers(markers);