diff --git a/Cargo.lock b/Cargo.lock index 77a17a4f3..a30a0cbe1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4946,6 +4946,7 @@ dependencies = [ "uv-torch", "uv-version", "uv-warnings", + "wiremock", ] [[package]] diff --git a/crates/uv-client/Cargo.toml b/crates/uv-client/Cargo.toml index 81d1909fe..bc7fc611f 100644 --- a/crates/uv-client/Cargo.toml +++ b/crates/uv-client/Cargo.toml @@ -65,3 +65,4 @@ hyper = { version = "1.4.1", features = ["server", "http1"] } hyper-util = { version = "0.1.8", features = ["tokio"] } insta = { version = "1.40.0", features = ["filters", "json", "redactions"] } tokio = { workspace = true } +wiremock = { workspace = true } diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs index f5fda246d..85c384b0d 100644 --- a/crates/uv-client/src/base_client.rs +++ b/crates/uv-client/src/base_client.rs @@ -6,14 +6,23 @@ use std::sync::Arc; use std::time::Duration; use std::{env, io, iter}; +use anyhow::anyhow; +use http::{ + HeaderMap, HeaderName, HeaderValue, Method, StatusCode, + header::{ + AUTHORIZATION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, LOCATION, + PROXY_AUTHORIZATION, REFERER, TRANSFER_ENCODING, WWW_AUTHENTICATE, + }, +}; use itertools::Itertools; -use reqwest::{Client, ClientBuilder, Proxy, Response}; +use reqwest::{Client, ClientBuilder, IntoUrl, Proxy, Request, Response, multipart}; use reqwest_middleware::{ClientWithMiddleware, Middleware}; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::{ DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy, }; use tracing::{debug, trace}; +use url::ParseError; use url::Url; use uv_auth::{AuthMiddleware, Indexes}; @@ -32,6 +41,10 @@ use crate::middleware::OfflineMiddleware; use crate::tls::read_identity; pub const DEFAULT_RETRIES: u32 = 3; +/// Maximum number of redirects to follow before giving up. +/// +/// This is the default used by [`reqwest`]. +const DEFAULT_MAX_REDIRECTS: u32 = 10; /// Selectively skip parts or the entire auth middleware. #[derive(Debug, Clone, Copy, Default)] @@ -61,6 +74,31 @@ pub struct BaseClientBuilder<'a> { default_timeout: Duration, extra_middleware: Option, proxies: Vec, + redirect_policy: RedirectPolicy, + /// Whether credentials should be propagated during cross-origin redirects. + /// + /// A policy allowing propagation is insecure and should only be available for test code. + cross_origin_credential_policy: CrossOriginCredentialsPolicy, +} + +/// The policy for handling HTTP redirects. +#[derive(Debug, Default, Clone, Copy)] +pub enum RedirectPolicy { + /// Use reqwest's built-in redirect handling. This bypasses our custom middleware + /// on redirect. + #[default] + BypassMiddleware, + /// Handle redirects manually, re-triggering our custom middleware for each request. + RetriggerMiddleware, +} + +impl RedirectPolicy { + pub fn reqwest_policy(self) -> reqwest::redirect::Policy { + match self { + RedirectPolicy::BypassMiddleware => reqwest::redirect::Policy::default(), + RedirectPolicy::RetriggerMiddleware => reqwest::redirect::Policy::none(), + } + } } /// A list of user-defined middlewares to be applied to the client. @@ -96,6 +134,8 @@ impl BaseClientBuilder<'_> { default_timeout: Duration::from_secs(30), extra_middleware: None, proxies: vec![], + redirect_policy: RedirectPolicy::default(), + cross_origin_credential_policy: CrossOriginCredentialsPolicy::Secure, } } } @@ -173,6 +213,24 @@ impl<'a> BaseClientBuilder<'a> { self } + #[must_use] + pub fn redirect(mut self, policy: RedirectPolicy) -> Self { + self.redirect_policy = policy; + self + } + + /// Allows credentials to be propagated on cross-origin redirects. + /// + /// WARNING: This should only be available for tests. In production code, propagating credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + #[cfg(test)] + #[must_use] + pub fn allow_cross_origin_credentials(mut self) -> Self { + self.cross_origin_credential_policy = CrossOriginCredentialsPolicy::Insecure; + self + } + pub fn is_offline(&self) -> bool { matches!(self.connectivity, Connectivity::Offline) } @@ -229,6 +287,7 @@ impl<'a> BaseClientBuilder<'a> { timeout, ssl_cert_file_exists, Security::Secure, + self.redirect_policy, ); // Create an insecure client that accepts invalid certificates. @@ -237,11 +296,20 @@ impl<'a> BaseClientBuilder<'a> { timeout, ssl_cert_file_exists, Security::Insecure, + self.redirect_policy, ); // Wrap in any relevant middleware and handle connectivity. - let client = self.apply_middleware(raw_client.clone()); - let dangerous_client = self.apply_middleware(raw_dangerous_client.clone()); + let client = RedirectClientWithMiddleware { + client: self.apply_middleware(raw_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; + let dangerous_client = RedirectClientWithMiddleware { + client: self.apply_middleware(raw_dangerous_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; BaseClient { connectivity: self.connectivity, @@ -258,8 +326,16 @@ impl<'a> BaseClientBuilder<'a> { /// Share the underlying client between two different middleware configurations. pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient { // Wrap in any relevant middleware and handle connectivity. - let client = self.apply_middleware(existing.raw_client.clone()); - let dangerous_client = self.apply_middleware(existing.raw_dangerous_client.clone()); + let client = RedirectClientWithMiddleware { + client: self.apply_middleware(existing.raw_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; + let dangerous_client = RedirectClientWithMiddleware { + client: self.apply_middleware(existing.raw_dangerous_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; BaseClient { connectivity: self.connectivity, @@ -279,6 +355,7 @@ impl<'a> BaseClientBuilder<'a> { timeout: Duration, ssl_cert_file_exists: bool, security: Security, + redirect_policy: RedirectPolicy, ) -> Client { // Configure the builder. let client_builder = ClientBuilder::new() @@ -286,7 +363,8 @@ impl<'a> BaseClientBuilder<'a> { .user_agent(user_agent) .pool_max_idle_per_host(20) .read_timeout(timeout) - .tls_built_in_root_certs(false); + .tls_built_in_root_certs(false) + .redirect(redirect_policy.reqwest_policy()); // If necessary, accept invalid certificates. let client_builder = match security { @@ -381,9 +459,9 @@ impl<'a> BaseClientBuilder<'a> { #[derive(Debug, Clone)] pub struct BaseClient { /// The underlying HTTP client that enforces valid certificates. - client: ClientWithMiddleware, + client: RedirectClientWithMiddleware, /// The underlying HTTP client that accepts invalid certificates. - dangerous_client: ClientWithMiddleware, + dangerous_client: RedirectClientWithMiddleware, /// The HTTP client without middleware. raw_client: Client, /// The HTTP client that accepts invalid certificates without middleware. @@ -408,7 +486,7 @@ enum Security { impl BaseClient { /// Selects the appropriate client based on the host's trustworthiness. - pub fn for_host(&self, url: &DisplaySafeUrl) -> &ClientWithMiddleware { + pub fn for_host(&self, url: &DisplaySafeUrl) -> &RedirectClientWithMiddleware { if self.disable_ssl(url) { &self.dangerous_client } else { @@ -416,6 +494,12 @@ impl BaseClient { } } + /// Executes a request, applying redirect policy. + pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { + let client = self.for_host(&DisplaySafeUrl::from(req.url().clone())); + client.execute(req).await + } + /// Returns `true` if the host is trusted to use the insecure client. pub fn disable_ssl(&self, url: &DisplaySafeUrl) -> bool { self.allow_insecure_host @@ -439,6 +523,316 @@ impl BaseClient { } } +/// Wrapper around [`ClientWithMiddleware`] that manages redirects. +#[derive(Debug, Clone)] +pub struct RedirectClientWithMiddleware { + client: ClientWithMiddleware, + redirect_policy: RedirectPolicy, + /// Whether credentials should be preserved during cross-origin redirects. + /// + /// WARNING: This should only be available for tests. In production code, preserving credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + cross_origin_credentials_policy: CrossOriginCredentialsPolicy, +} + +impl RedirectClientWithMiddleware { + /// Convenience method to make a `GET` request to a URL. + pub fn get(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.get(url), self) + } + + /// Convenience method to make a `POST` request to a URL. + pub fn post(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.post(url), self) + } + + /// Convenience method to make a `HEAD` request to a URL. + pub fn head(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.head(url), self) + } + + /// Executes a request, applying the redirect policy. + pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { + match self.redirect_policy { + RedirectPolicy::BypassMiddleware => self.client.execute(req).await, + RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await, + } + } + + /// Executes a request. If the response is a redirect (one of HTTP 301, 302, 303, 307, or 308), the + /// request is executed again with the redirect location URL (up to a maximum number of + /// redirects). + /// + /// Unlike the built-in reqwest redirect policies, this sends the redirect request through the + /// entire middleware pipeline again. + /// + /// See RFC 7231 7.1.2 for details on + /// redirect semantics. + async fn execute_with_redirect_handling( + &self, + req: Request, + ) -> reqwest_middleware::Result { + let mut request = req; + let mut redirects = 0; + let max_redirects = DEFAULT_MAX_REDIRECTS; + + loop { + let result = self + .client + .execute(request.try_clone().expect("HTTP request must be cloneable")) + .await; + let Ok(response) = result else { + return result; + }; + + if redirects >= max_redirects { + return Ok(response); + } + + let Some(redirect_request) = + request_into_redirect(request, &response, self.cross_origin_credentials_policy)? + else { + return Ok(response); + }; + + redirects += 1; + request = redirect_request; + } + } + + pub fn raw_client(&self) -> &ClientWithMiddleware { + &self.client + } +} + +impl From for ClientWithMiddleware { + fn from(item: RedirectClientWithMiddleware) -> ClientWithMiddleware { + item.client + } +} + +/// Check if this is should be a redirect and, if so, return a new redirect request. +/// +/// This implementation is based on the [`reqwest`] crate redirect implementation. +/// It takes ownership of the original [`Request`] and mutates it to create the new +/// redirect [`Request`]. +fn request_into_redirect( + mut req: Request, + res: &Response, + cross_origin_credentials_policy: CrossOriginCredentialsPolicy, +) -> reqwest_middleware::Result> { + let original_req_url = DisplaySafeUrl::from(req.url().clone()); + let status = res.status(); + let should_redirect = match status { + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::TEMPORARY_REDIRECT + | StatusCode::PERMANENT_REDIRECT => true, + StatusCode::SEE_OTHER => { + // Per RFC 7231, HTTP 303 is intended for the user agent + // to perform a GET or HEAD request to the redirect target. + // Historically, some browsers also changed method from POST + // to GET on 301 or 302, but this is not required by RFC 7231 + // and was not intended by the HTTP spec. + *req.body_mut() = None; + for header in &[ + TRANSFER_ENCODING, + CONTENT_ENCODING, + CONTENT_TYPE, + CONTENT_LENGTH, + ] { + req.headers_mut().remove(header); + } + + match *req.method() { + Method::GET | Method::HEAD => {} + _ => { + *req.method_mut() = Method::GET; + } + } + true + } + _ => false, + }; + if !should_redirect { + return Ok(None); + } + + let location = res + .headers() + .get(LOCATION) + .ok_or(reqwest_middleware::Error::Middleware(anyhow!( + "Server returned redirect (HTTP {status}) without destination URL. This may indicate a server configuration issue" + )))? + .to_str() + .map_err(|_| { + reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value: must only contain visible ascii characters" + )) + })?; + + let mut redirect_url = match DisplaySafeUrl::parse(location) { + Ok(url) => url, + // Per RFC 7231, URLs should be resolved against the request URL. + Err(ParseError::RelativeUrlWithoutBase) => original_req_url.join(location).map_err(|err| { + reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{location}` relative to `{original_req_url}`: {err}" + )) + })?, + Err(err) => { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{location}`: {err}" + ))); + } + }; + // Per RFC 7231, fragments must be propagated + if let Some(fragment) = original_req_url.fragment() { + redirect_url.set_fragment(Some(fragment)); + } + + // Ensure the URL is a valid HTTP URI. + if let Err(err) = redirect_url.as_str().parse::() { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "HTTP {status} 'Location' value `{redirect_url}` is not a valid HTTP URI: {err}" + ))); + } + + if redirect_url.scheme() != "http" && redirect_url.scheme() != "https" { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{redirect_url}`: scheme needs to be https or http" + ))); + } + + let mut headers = HeaderMap::new(); + std::mem::swap(req.headers_mut(), &mut headers); + + let cross_host = redirect_url.host_str() != original_req_url.host_str() + || redirect_url.port_or_known_default() != original_req_url.port_or_known_default(); + if cross_host { + if cross_origin_credentials_policy == CrossOriginCredentialsPolicy::Secure { + debug!("Received a cross-origin redirect. Removing sensitive headers."); + headers.remove(AUTHORIZATION); + headers.remove(COOKIE); + headers.remove(PROXY_AUTHORIZATION); + headers.remove(WWW_AUTHENTICATE); + } + // If the redirect request is not a cross-origin request and the original request already + // had a Referer header, attempt to set the Referer header for the redirect request. + } else if headers.contains_key(REFERER) { + if let Some(referer) = make_referer(&redirect_url, &original_req_url) { + headers.insert(REFERER, referer); + } + } + + std::mem::swap(req.headers_mut(), &mut headers); + *req.url_mut() = Url::from(redirect_url); + debug!( + "Received HTTP {status}. Redirecting to {}", + DisplaySafeUrl::ref_cast(req.url()) + ); + Ok(Some(req)) +} + +/// Return a Referer [`HeaderValue`] according to RFC 7231. +/// +/// Return [`None`] if https has been downgraded in the redirect location. +fn make_referer( + redirect_url: &DisplaySafeUrl, + original_url: &DisplaySafeUrl, +) -> Option { + if redirect_url.scheme() == "http" && original_url.scheme() == "https" { + return None; + } + + let mut referer = original_url.clone(); + referer.remove_credentials(); + referer.set_fragment(None); + referer.as_str().parse().ok() +} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] +pub(crate) enum CrossOriginCredentialsPolicy { + /// Do not propagate credentials on cross-origin requests. + #[default] + Secure, + + /// Propagate credentials on cross-origin requests. + /// + /// WARNING: This should only be available for tests. In production code, preserving credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + #[cfg(test)] + Insecure, +} + +/// A builder to construct the properties of a `Request`. +/// +/// This wraps [`reqwest_middleware::RequestBuilder`] to ensure that the [`BaseClient`] +/// redirect policy is respected if `send()` is called. +#[derive(Debug)] +#[must_use] +pub struct RequestBuilder<'a> { + builder: reqwest_middleware::RequestBuilder, + client: &'a RedirectClientWithMiddleware, +} + +impl<'a> RequestBuilder<'a> { + pub fn new( + builder: reqwest_middleware::RequestBuilder, + client: &'a RedirectClientWithMiddleware, + ) -> Self { + Self { builder, client } + } + + /// Add a `Header` to this Request. + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, + { + self.builder = self.builder.header(key, value); + self + } + + /// Add a set of Headers to the existing ones on this Request. + /// + /// The headers will be merged in to any already set. + pub fn headers(mut self, headers: HeaderMap) -> Self { + self.builder = self.builder.headers(headers); + self + } + + #[cfg(not(target_arch = "wasm32"))] + pub fn version(mut self, version: reqwest::Version) -> Self { + self.builder = self.builder.version(version); + self + } + + #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] + pub fn multipart(mut self, multipart: multipart::Form) -> Self { + self.builder = self.builder.multipart(multipart); + self + } + + /// Build a `Request`. + pub fn build(self) -> reqwest::Result { + self.builder.build() + } + + /// Constructs the Request and sends it to the target URL, returning a + /// future Response. + pub async fn send(self) -> reqwest_middleware::Result { + self.client.execute(self.build()?).await + } + + pub fn raw_builder(&self) -> &reqwest_middleware::RequestBuilder { + &self.builder + } +} + /// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases. pub struct UvRetryableStrategy; @@ -528,3 +922,165 @@ fn find_source(orig: &dyn Error) -> Option<&E> { fn find_sources(orig: &dyn Error) -> impl Iterator { iter::successors(find_source::(orig), |&err| find_source(err)) } + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + + use reqwest::{Client, Method}; + use wiremock::matchers::method; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::base_client::request_into_redirect; + + #[tokio::test] + async fn test_redirect_preserves_authorization_header_on_same_origin() -> Result<()> { + for status in &[301, 302, 303, 307, 308] { + let server = MockServer::start().await; + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(*status) + .insert_header("location", format!("{}/redirect", server.uri())), + ) + .mount(&server) + .await; + + let request = Client::new() + .get(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert!(request.headers().contains_key(AUTHORIZATION)); + + let response = Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? + .unwrap(); + assert!(redirect_request.headers().contains_key(AUTHORIZATION)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_removes_authorization_header_on_cross_origin() -> Result<()> { + for status in &[301, 302, 303, 307, 308] { + let server = MockServer::start().await; + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(*status) + .insert_header("location", "https://cross-origin.com/simple"), + ) + .mount(&server) + .await; + + let request = Client::new() + .get(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert!(request.headers().contains_key(AUTHORIZATION)); + + let response = Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? + .unwrap(); + assert!(!redirect_request.headers().contains_key(AUTHORIZATION)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_303_changes_post_to_get() -> Result<()> { + let server = MockServer::start().await; + Mock::given(method("POST")) + .respond_with( + ResponseTemplate::new(303) + .insert_header("location", format!("{}/redirect", server.uri())), + ) + .mount(&server) + .await; + + let request = Client::new() + .post(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert_eq!(request.method(), Method::POST); + + let response = Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? + .unwrap(); + assert_eq!(redirect_request.method(), Method::GET); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_no_referer_if_disabled() -> Result<()> { + for status in &[301, 302, 303, 307, 308] { + let server = MockServer::start().await; + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(*status) + .insert_header("location", format!("{}/redirect", server.uri())), + ) + .mount(&server) + .await; + + let request = Client::builder() + .referer(false) + .build() + .unwrap() + .get(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert!(!request.headers().contains_key(REFERER)); + + let response = Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? + .unwrap(); + + assert!(!redirect_request.headers().contains_key(REFERER)); + } + + Ok(()) + } +} diff --git a/crates/uv-client/src/cached_client.rs b/crates/uv-client/src/cached_client.rs index d19f95ec7..ee3314d1c 100644 --- a/crates/uv-client/src/cached_client.rs +++ b/crates/uv-client/src/cached_client.rs @@ -523,7 +523,6 @@ impl CachedClient { debug!("Sending revalidation request for: {url}"); let response = self .0 - .for_host(&url) .execute(req) .instrument(info_span!("revalidation_request", url = url.as_str())) .await @@ -564,7 +563,6 @@ impl CachedClient { let cache_policy_builder = CachePolicyBuilder::new(&req); let response = self .0 - .for_host(&url) .execute(req) .await .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?; diff --git a/crates/uv-client/src/lib.rs b/crates/uv-client/src/lib.rs index 3ea33204c..e42c86620 100644 --- a/crates/uv-client/src/lib.rs +++ b/crates/uv-client/src/lib.rs @@ -1,6 +1,6 @@ pub use base_client::{ AuthIntegration, BaseClient, BaseClientBuilder, DEFAULT_RETRIES, ExtraMiddleware, - UvRetryableStrategy, is_extended_transient_error, + RedirectClientWithMiddleware, RequestBuilder, UvRetryableStrategy, is_extended_transient_error, }; pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use error::{Error, ErrorKind, WrappedReqwestError}; diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 6271b7d20..b53e1ed9a 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -10,7 +10,6 @@ use futures::{FutureExt, StreamExt, TryStreamExt}; use http::{HeaderMap, StatusCode}; use itertools::Either; use reqwest::{Proxy, Response}; -use reqwest_middleware::ClientWithMiddleware; use rustc_hash::FxHashMap; use tokio::sync::{Mutex, Semaphore}; use tracing::{Instrument, debug, info_span, instrument, trace, warn}; @@ -35,13 +34,16 @@ use uv_redacted::DisplaySafeUrl; use uv_small_str::SmallString; use uv_torch::TorchStrategy; -use crate::base_client::{BaseClientBuilder, ExtraMiddleware}; +use crate::base_client::{BaseClientBuilder, ExtraMiddleware, RedirectPolicy}; use crate::cached_client::CacheControl; use crate::flat_index::FlatIndexEntry; use crate::html::SimpleHtml; use crate::remote_metadata::wheel_metadata_from_remote_zip; use crate::rkyvutil::OwnedArchive; -use crate::{BaseClient, CachedClient, Error, ErrorKind, FlatIndexClient, FlatIndexEntries}; +use crate::{ + BaseClient, CachedClient, Error, ErrorKind, FlatIndexClient, FlatIndexEntries, + RedirectClientWithMiddleware, +}; /// A builder for an [`RegistryClient`]. #[derive(Debug, Clone)] @@ -149,9 +151,23 @@ impl<'a> RegistryClientBuilder<'a> { self } + /// Allows credentials to be propagated on cross-origin redirects. + /// + /// WARNING: This should only be available for tests. In production code, propagating credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + #[cfg(test)] + #[must_use] + pub fn allow_cross_origin_credentials(mut self) -> Self { + self.base_client_builder = self.base_client_builder.allow_cross_origin_credentials(); + self + } + pub fn build(self) -> RegistryClient { // Build a base client - let builder = self.base_client_builder; + let builder = self + .base_client_builder + .redirect(RedirectPolicy::RetriggerMiddleware); let client = builder.build(); @@ -248,7 +264,7 @@ impl RegistryClient { } /// Return the [`BaseClient`] used by this client. - pub fn uncached_client(&self, url: &DisplaySafeUrl) -> &ClientWithMiddleware { + pub fn uncached_client(&self, url: &DisplaySafeUrl) -> &RedirectClientWithMiddleware { self.client.uncached().for_host(url) } @@ -1215,12 +1231,229 @@ impl Connectivity { mod tests { use std::str::FromStr; + use url::Url; use uv_normalize::PackageName; use uv_pypi_types::{JoinRelativeError, SimpleJson}; use uv_redacted::DisplaySafeUrl; use crate::{SimpleMetadata, SimpleMetadatum, html::SimpleHtml}; + use uv_cache::Cache; + use wiremock::matchers::{basic_auth, method, path_regex}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::RegistryClientBuilder; + + type Error = Box; + + async fn start_test_server(username: &'static str, password: &'static str) -> MockServer { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(401)) + .mount(&server) + .await; + + server + } + + #[tokio::test] + async fn test_redirect_to_server_with_credentials() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let auth_server = start_test_server(username, password).await; + let auth_base_url = DisplaySafeUrl::parse(&auth_server.uri())?; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 302 to the auth server + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", format!("{auth_base_url}")), + ) + .mount(&redirect_server) + .await; + + let redirect_server_url = DisplaySafeUrl::parse(&redirect_server.uri())?; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache) + .allow_cross_origin_credentials() + .build(); + let client = registry_client.cached_client().uncached(); + + assert_eq!( + client + .for_host(&redirect_server_url) + .get(redirect_server.uri()) + .send() + .await? + .status(), + 401, + "Requests should fail if credentials are missing" + ); + + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&redirect_server_url) + .get(Url::from(url)) + .send() + .await? + .status(), + 200, + "Requests should succeed if credentials are present" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_root_relative_url() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .and(path_regex("/foo/")) + .respond_with( + ResponseTemplate::new(307).insert_header("Location", "/bar/baz/".to_string()), + ) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/bar/baz/")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + let redirect_server_url = DisplaySafeUrl::parse(&redirect_server.uri())?.join("foo/")?; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache) + .allow_cross_origin_credentials() + .build(); + let client = registry_client.cached_client().uncached(); + + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&url) + .get(Url::from(url)) + .send() + .await? + .status(), + 200, + "Requests should succeed for relative URL" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_relative_url() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .and(path_regex("/foo/bar/baz/")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/foo/")) + .and(basic_auth(username, password)) + .respond_with( + ResponseTemplate::new(307).insert_header("Location", "bar/baz/".to_string()), + ) + .mount(&redirect_server) + .await; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache) + .allow_cross_origin_credentials() + .build(); + let client = registry_client.cached_client().uncached(); + + let redirect_server_url = DisplaySafeUrl::parse(&redirect_server.uri())?.join("foo/")?; + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&url) + .get(Url::from(url)) + .send() + .await? + .status(), + 200, + "Requests should succeed for relative URL" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_preserve_fragment() -> Result<(), Error> { + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(307).insert_header("Location", "/foo".to_string())) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/foo")) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache).build(); + let client = registry_client.cached_client().uncached(); + + let mut url = DisplaySafeUrl::parse(&redirect_server.uri())?; + url.set_fragment(Some("fragment")); + + assert_eq!( + client + .for_host(&url) + .get(Url::from(url.clone())) + .send() + .await? + .url() + .to_string(), + format!("{}/foo#fragment", redirect_server.uri()), + "Requests should preserve fragment" + ); + + Ok(()) + } + #[test] fn ignore_failing_files() { // 1.7.7 has an invalid requires-python field (double comma), 1.7.8 is valid diff --git a/crates/uv-distribution/src/source/mod.rs b/crates/uv-distribution/src/source/mod.rs index 2cf148f2b..90d77bd90 100644 --- a/crates/uv-distribution/src/source/mod.rs +++ b/crates/uv-distribution/src/source/mod.rs @@ -1583,7 +1583,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { client .unmanaged .uncached_client(resource.git.repository()) - .clone(), + .raw_client(), ) .await { @@ -1866,7 +1866,10 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { .git() .github_fast_path( git, - client.unmanaged.uncached_client(git.repository()).clone(), + client + .unmanaged + .uncached_client(git.repository()) + .raw_client(), ) .await? .is_some() diff --git a/crates/uv-git/src/resolver.rs b/crates/uv-git/src/resolver.rs index fd90ff587..d404390f3 100644 --- a/crates/uv-git/src/resolver.rs +++ b/crates/uv-git/src/resolver.rs @@ -53,7 +53,7 @@ impl GitResolver { pub async fn github_fast_path( &self, url: &GitUrl, - client: ClientWithMiddleware, + client: &ClientWithMiddleware, ) -> Result, GitResolverError> { if std::env::var_os(EnvVars::UV_NO_GITHUB_FAST_PATH).is_some() { return Ok(None); @@ -117,7 +117,7 @@ impl GitResolver { pub async fn fetch( &self, url: &GitUrl, - client: ClientWithMiddleware, + client: impl Into, disable_ssl: bool, offline: bool, cache: PathBuf, diff --git a/crates/uv-publish/src/lib.rs b/crates/uv-publish/src/lib.rs index dd8358439..ec19713cc 100644 --- a/crates/uv-publish/src/lib.rs +++ b/crates/uv-publish/src/lib.rs @@ -12,7 +12,6 @@ use itertools::Itertools; use reqwest::header::AUTHORIZATION; use reqwest::multipart::Part; use reqwest::{Body, Response, StatusCode}; -use reqwest_middleware::RequestBuilder; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::{RetryPolicy, Retryable, RetryableStrategy}; use rustc_hash::FxHashSet; @@ -29,7 +28,7 @@ use uv_auth::Credentials; use uv_cache::{Cache, Refresh}; use uv_client::{ BaseClient, DEFAULT_RETRIES, MetadataFormat, OwnedArchive, RegistryClientBuilder, - UvRetryableStrategy, + RequestBuilder, UvRetryableStrategy, }; use uv_configuration::{KeyringProviderType, TrustedPublishing}; use uv_distribution_filename::{DistFilename, SourceDistExtension, SourceDistFilename}; @@ -330,7 +329,9 @@ pub async fn check_trusted_publishing( debug!( "Running on GitHub Actions without explicit credentials, checking for trusted publishing" ); - match trusted_publishing::get_token(registry, client.for_host(registry)).await { + match trusted_publishing::get_token(registry, client.for_host(registry).raw_client()) + .await + { Ok(token) => Ok(TrustedPublishResult::Configured(token)), Err(err) => { // TODO(konsti): It would be useful if we could differentiate between actual errors @@ -364,7 +365,9 @@ pub async fn check_trusted_publishing( ); } - let token = trusted_publishing::get_token(registry, client.for_host(registry)).await?; + let token = + trusted_publishing::get_token(registry, client.for_host(registry).raw_client()) + .await?; Ok(TrustedPublishResult::Configured(token)) } TrustedPublishing::Never => Ok(TrustedPublishResult::Skipped), @@ -748,16 +751,16 @@ async fn form_metadata( /// Build the upload request. /// /// Returns the request and the reporter progress bar id. -async fn build_request( +async fn build_request<'a>( file: &Path, raw_filename: &str, filename: &DistFilename, registry: &DisplaySafeUrl, - client: &BaseClient, + client: &'a BaseClient, credentials: &Credentials, form_metadata: &[(&'static str, String)], reporter: Arc, -) -> Result<(RequestBuilder, usize), PublishPrepareError> { +) -> Result<(RequestBuilder<'a>, usize), PublishPrepareError> { let mut form = reqwest::multipart::Form::new(); for (key, value) in form_metadata { form = form.text(*key, value.clone()); @@ -969,12 +972,13 @@ mod tests { project_urls: Source, https://github.com/unknown/tqdm "###); + let client = BaseClientBuilder::new().build(); let (request, _) = build_request( &file, raw_filename, &filename, &DisplaySafeUrl::parse("https://example.org/upload").unwrap(), - &BaseClientBuilder::new().build(), + &client, &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &form_metadata, Arc::new(DummyReporter), @@ -985,7 +989,7 @@ mod tests { insta::with_settings!({ filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], }, { - assert_debug_snapshot!(&request, @r#" + assert_debug_snapshot!(&request.raw_builder(), @r#" RequestBuilder { inner: RequestBuilder { method: POST, @@ -1118,12 +1122,13 @@ mod tests { requires_dist: requests ; extra == 'telegram' "###); + let client = BaseClientBuilder::new().build(); let (request, _) = build_request( &file, raw_filename, &filename, &DisplaySafeUrl::parse("https://example.org/upload").unwrap(), - &BaseClientBuilder::new().build(), + &client, &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &form_metadata, Arc::new(DummyReporter), @@ -1134,7 +1139,7 @@ mod tests { insta::with_settings!({ filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], }, { - assert_debug_snapshot!(&request, @r#" + assert_debug_snapshot!(&request.raw_builder(), @r#" RequestBuilder { inner: RequestBuilder { method: POST, diff --git a/crates/uv/tests/it/common/mod.rs b/crates/uv/tests/it/common/mod.rs index f997561a9..7ef7cfff6 100644 --- a/crates/uv/tests/it/common/mod.rs +++ b/crates/uv/tests/it/common/mod.rs @@ -1668,9 +1668,9 @@ pub async fn download_to_disk(url: &str, path: &Path) { .allow_insecure_host(trusted_hosts) .build(); let url = url.parse().unwrap(); - let client = client.for_host(&url); let response = client - .request(http::Method::GET, reqwest::Url::from(url)) + .for_host(&url) + .get(reqwest::Url::from(url)) .send() .await .unwrap(); diff --git a/crates/uv/tests/it/edit.rs b/crates/uv/tests/it/edit.rs index 68eae5110..d117b1e8c 100644 --- a/crates/uv/tests/it/edit.rs +++ b/crates/uv/tests/it/edit.rs @@ -14,6 +14,7 @@ use assert_fs::prelude::*; use indoc::{formatdoc, indoc}; use insta::assert_snapshot; use std::path::Path; +use url::Url; use uv_fs::Simplified; use wiremock::{Mock, MockServer, ResponseTemplate, matchers::method}; @@ -11838,6 +11839,196 @@ fn add_auth_policy_never_without_credentials() -> Result<()> { Ok(()) } +/// If uv receives a 302 redirect to a cross-origin server, it should not forward +/// credentials. In the absence of a netrc entry for the new location, +/// it should fail. +#[tokio::test] +async fn add_redirect_cross_origin() -> Result<()> { + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1:\d*", "[LOCALHOST]")]) + .collect::>(); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! { r#" + [project] + name = "foo" + version = "1.0.0" + requires-python = ">=3.12" + dependencies = [] + "# + })?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + let _ = redirect_url.set_password(Some("heron")); + + uv_snapshot!(filters, context.add().arg("--default-index").arg(redirect_url.as_str()).arg("anyio"), @r" + success: false + exit_code: 1 + ----- stdout ----- + + ----- stderr ----- + × No solution found when resolving dependencies: + ╰─▶ Because anyio was not found in the package registry and your project depends on anyio, we can conclude that your project's requirements are unsatisfiable. + + hint: An index URL (http://[LOCALHOST]/) could not be queried due to a lack of valid authentication credentials (401 Unauthorized). + help: If you want to add the package regardless of the failed resolution, provide the `--frozen` flag to skip locking and syncing. + " + ); + + Ok(()) +} + +/// uv currently fails to look up keyring credentials on a cross-origin redirect. +#[tokio::test] +async fn add_redirect_with_keyring_cross_origin() -> Result<()> { + let keyring_context = TestContext::new("3.12"); + + // Install our keyring plugin + keyring_context + .pip_install() + .arg( + keyring_context + .workspace_root + .join("scripts") + .join("packages") + .join("keyring_test_plugin"), + ) + .assert() + .success(); + + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1:\d*", "[LOCALHOST]")]) + .collect::>(); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! { r#" + [project] + name = "foo" + version = "1.0.0" + requires-python = ">=3.12" + dependencies = [] + + [tool.uv] + keyring-provider = "subprocess" + "#, + })?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + + uv_snapshot!(filters, context.add().arg("--default-index") + .arg(redirect_url.as_str()) + .arg("anyio") + .env(EnvVars::KEYRING_TEST_CREDENTIALS, r#"{"pypi-proxy.fly.dev": {"public": "heron"}}"#) + .env(EnvVars::PATH, venv_bin_path(&keyring_context.venv)), @r" + success: false + exit_code: 1 + ----- stdout ----- + + ----- stderr ----- + Keyring request for public@http://[LOCALHOST]/ + Keyring request for public@[LOCALHOST] + × No solution found when resolving dependencies: + ╰─▶ Because anyio was not found in the package registry and your project depends on anyio, we can conclude that your project's requirements are unsatisfiable. + + hint: An index URL (http://[LOCALHOST]/) could not be queried due to a lack of valid authentication credentials (401 Unauthorized). + help: If you want to add the package regardless of the failed resolution, provide the `--frozen` flag to skip locking and syncing. + " + ); + + Ok(()) +} + +/// If uv receives a cross-origin 302 redirect, it should use credentials from netrc +/// for the new location. +#[tokio::test] +async fn pip_install_redirect_with_netrc_cross_origin() -> Result<()> { + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1:\d*", "[LOCALHOST]")]) + .collect::>(); + + let netrc = context.temp_dir.child(".netrc"); + netrc.write_str("machine pypi-proxy.fly.dev login public password heron")?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + + uv_snapshot!(filters, context.pip_install() + .arg("anyio") + .arg("--index-url") + .arg(redirect_url.as_str()) + .env(EnvVars::NETRC, netrc.to_str().unwrap()) + .arg("--strict"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + + ----- stderr ----- + Resolved 3 packages in [TIME] + Prepared 3 packages in [TIME] + Installed 3 packages in [TIME] + + anyio==4.3.0 + + idna==3.6 + + sniffio==1.3.1 + "### + ); + + context.assert_command("import anyio").success(); + + Ok(()) +} + +fn redirect_url_to_pypi_proxy(req: &wiremock::Request) -> String { + let last_path_segment = req + .url + .path_segments() + .expect("path has segments") + .filter(|segment| !segment.is_empty()) // Filter out empty segments + .next_back() + .expect("path has a package segment"); + format!("https://pypi-proxy.fly.dev/basic-auth/simple/{last_path_segment}/") +} + /// Test the error message when adding a package with multiple existing references in /// `pyproject.toml`. #[test]