From 62365d4ec86c146fc1d2050bff3fa70e53abfa08 Mon Sep 17 00:00:00 2001 From: John Mumm Date: Fri, 20 Jun 2025 03:21:32 -0400 Subject: [PATCH] Support netrc and same-origin credential propagation on index redirects (#14126) This PR is a combination of #12920 and #13754. Prior to these changes, following a redirect when searching indexes would bypass our authentication middleware. This PR updates uv to support propagating credentials through our middleware on same-origin redirects and to support netrc credentials for both same- and cross-origin redirects. It does not handle the case described in #11097 where the redirect location itself includes credentials (e.g., `https://user:pass@redirect-location.com`). That will be addressed in follow-up work. This includes unit tests for the new redirect logic and integration tests for credential propagation. The automated external registries test is also passing for AWS CodeArtifact, Azure Artifacts, GCP Artifact Registry, JFrog Artifactory, GitLab, Cloudsmith, and Gemfury. 
--- Cargo.lock | 1 + crates/uv-client/Cargo.toml | 1 + crates/uv-client/src/base_client.rs | 574 ++++++++++++++++++++++- crates/uv-client/src/cached_client.rs | 2 - crates/uv-client/src/lib.rs | 2 +- crates/uv-client/src/registry_client.rs | 243 +++++++++- crates/uv-distribution/src/source/mod.rs | 7 +- crates/uv-git/src/resolver.rs | 4 +- crates/uv-publish/src/lib.rs | 27 +- crates/uv/tests/it/common/mod.rs | 4 +- crates/uv/tests/it/edit.rs | 191 ++++++++ 11 files changed, 1022 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 77a17a4f3..a30a0cbe1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4946,6 +4946,7 @@ dependencies = [ "uv-torch", "uv-version", "uv-warnings", + "wiremock", ] [[package]] diff --git a/crates/uv-client/Cargo.toml b/crates/uv-client/Cargo.toml index 81d1909fe..bc7fc611f 100644 --- a/crates/uv-client/Cargo.toml +++ b/crates/uv-client/Cargo.toml @@ -65,3 +65,4 @@ hyper = { version = "1.4.1", features = ["server", "http1"] } hyper-util = { version = "0.1.8", features = ["tokio"] } insta = { version = "1.40.0", features = ["filters", "json", "redactions"] } tokio = { workspace = true } +wiremock = { workspace = true } diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs index f5fda246d..85c384b0d 100644 --- a/crates/uv-client/src/base_client.rs +++ b/crates/uv-client/src/base_client.rs @@ -6,14 +6,23 @@ use std::sync::Arc; use std::time::Duration; use std::{env, io, iter}; +use anyhow::anyhow; +use http::{ + HeaderMap, HeaderName, HeaderValue, Method, StatusCode, + header::{ + AUTHORIZATION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, LOCATION, + PROXY_AUTHORIZATION, REFERER, TRANSFER_ENCODING, WWW_AUTHENTICATE, + }, +}; use itertools::Itertools; -use reqwest::{Client, ClientBuilder, Proxy, Response}; +use reqwest::{Client, ClientBuilder, IntoUrl, Proxy, Request, Response, multipart}; use reqwest_middleware::{ClientWithMiddleware, Middleware}; use 
reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::{ DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy, }; use tracing::{debug, trace}; +use url::ParseError; use url::Url; use uv_auth::{AuthMiddleware, Indexes}; @@ -32,6 +41,10 @@ use crate::middleware::OfflineMiddleware; use crate::tls::read_identity; pub const DEFAULT_RETRIES: u32 = 3; +/// Maximum number of redirects to follow before giving up. +/// +/// This is the default used by [`reqwest`]. +const DEFAULT_MAX_REDIRECTS: u32 = 10; /// Selectively skip parts or the entire auth middleware. #[derive(Debug, Clone, Copy, Default)] @@ -61,6 +74,31 @@ pub struct BaseClientBuilder<'a> { default_timeout: Duration, extra_middleware: Option, proxies: Vec, + redirect_policy: RedirectPolicy, + /// Whether credentials should be propagated during cross-origin redirects. + /// + /// A policy allowing propagation is insecure and should only be available for test code. + cross_origin_credential_policy: CrossOriginCredentialsPolicy, +} + +/// The policy for handling HTTP redirects. +#[derive(Debug, Default, Clone, Copy)] +pub enum RedirectPolicy { + /// Use reqwest's built-in redirect handling. This bypasses our custom middleware + /// on redirect. + #[default] + BypassMiddleware, + /// Handle redirects manually, re-triggering our custom middleware for each request. + RetriggerMiddleware, +} + +impl RedirectPolicy { + pub fn reqwest_policy(self) -> reqwest::redirect::Policy { + match self { + RedirectPolicy::BypassMiddleware => reqwest::redirect::Policy::default(), + RedirectPolicy::RetriggerMiddleware => reqwest::redirect::Policy::none(), + } + } } /// A list of user-defined middlewares to be applied to the client. 
@@ -96,6 +134,8 @@ impl BaseClientBuilder<'_> { default_timeout: Duration::from_secs(30), extra_middleware: None, proxies: vec![], + redirect_policy: RedirectPolicy::default(), + cross_origin_credential_policy: CrossOriginCredentialsPolicy::Secure, } } } @@ -173,6 +213,24 @@ impl<'a> BaseClientBuilder<'a> { self } + #[must_use] + pub fn redirect(mut self, policy: RedirectPolicy) -> Self { + self.redirect_policy = policy; + self + } + + /// Allows credentials to be propagated on cross-origin redirects. + /// + /// WARNING: This should only be available for tests. In production code, propagating credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + #[cfg(test)] + #[must_use] + pub fn allow_cross_origin_credentials(mut self) -> Self { + self.cross_origin_credential_policy = CrossOriginCredentialsPolicy::Insecure; + self + } + pub fn is_offline(&self) -> bool { matches!(self.connectivity, Connectivity::Offline) } @@ -229,6 +287,7 @@ impl<'a> BaseClientBuilder<'a> { timeout, ssl_cert_file_exists, Security::Secure, + self.redirect_policy, ); // Create an insecure client that accepts invalid certificates. @@ -237,11 +296,20 @@ impl<'a> BaseClientBuilder<'a> { timeout, ssl_cert_file_exists, Security::Insecure, + self.redirect_policy, ); // Wrap in any relevant middleware and handle connectivity. 
- let client = self.apply_middleware(raw_client.clone()); - let dangerous_client = self.apply_middleware(raw_dangerous_client.clone()); + let client = RedirectClientWithMiddleware { + client: self.apply_middleware(raw_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; + let dangerous_client = RedirectClientWithMiddleware { + client: self.apply_middleware(raw_dangerous_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; BaseClient { connectivity: self.connectivity, @@ -258,8 +326,16 @@ impl<'a> BaseClientBuilder<'a> { /// Share the underlying client between two different middleware configurations. pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient { // Wrap in any relevant middleware and handle connectivity. - let client = self.apply_middleware(existing.raw_client.clone()); - let dangerous_client = self.apply_middleware(existing.raw_dangerous_client.clone()); + let client = RedirectClientWithMiddleware { + client: self.apply_middleware(existing.raw_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; + let dangerous_client = RedirectClientWithMiddleware { + client: self.apply_middleware(existing.raw_dangerous_client.clone()), + redirect_policy: self.redirect_policy, + cross_origin_credentials_policy: self.cross_origin_credential_policy, + }; BaseClient { connectivity: self.connectivity, @@ -279,6 +355,7 @@ impl<'a> BaseClientBuilder<'a> { timeout: Duration, ssl_cert_file_exists: bool, security: Security, + redirect_policy: RedirectPolicy, ) -> Client { // Configure the builder. 
let client_builder = ClientBuilder::new() @@ -286,7 +363,8 @@ impl<'a> BaseClientBuilder<'a> { .user_agent(user_agent) .pool_max_idle_per_host(20) .read_timeout(timeout) - .tls_built_in_root_certs(false); + .tls_built_in_root_certs(false) + .redirect(redirect_policy.reqwest_policy()); // If necessary, accept invalid certificates. let client_builder = match security { @@ -381,9 +459,9 @@ impl<'a> BaseClientBuilder<'a> { #[derive(Debug, Clone)] pub struct BaseClient { /// The underlying HTTP client that enforces valid certificates. - client: ClientWithMiddleware, + client: RedirectClientWithMiddleware, /// The underlying HTTP client that accepts invalid certificates. - dangerous_client: ClientWithMiddleware, + dangerous_client: RedirectClientWithMiddleware, /// The HTTP client without middleware. raw_client: Client, /// The HTTP client that accepts invalid certificates without middleware. @@ -408,7 +486,7 @@ enum Security { impl BaseClient { /// Selects the appropriate client based on the host's trustworthiness. - pub fn for_host(&self, url: &DisplaySafeUrl) -> &ClientWithMiddleware { + pub fn for_host(&self, url: &DisplaySafeUrl) -> &RedirectClientWithMiddleware { if self.disable_ssl(url) { &self.dangerous_client } else { @@ -416,6 +494,12 @@ impl BaseClient { } } + /// Executes a request, applying redirect policy. + pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { + let client = self.for_host(&DisplaySafeUrl::from(req.url().clone())); + client.execute(req).await + } + /// Returns `true` if the host is trusted to use the insecure client. pub fn disable_ssl(&self, url: &DisplaySafeUrl) -> bool { self.allow_insecure_host @@ -439,6 +523,316 @@ impl BaseClient { } } +/// Wrapper around [`ClientWithMiddleware`] that manages redirects. +#[derive(Debug, Clone)] +pub struct RedirectClientWithMiddleware { + client: ClientWithMiddleware, + redirect_policy: RedirectPolicy, + /// Whether credentials should be preserved during cross-origin redirects. 
+ /// + /// WARNING: This should only be available for tests. In production code, preserving credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + cross_origin_credentials_policy: CrossOriginCredentialsPolicy, +} + +impl RedirectClientWithMiddleware { + /// Convenience method to make a `GET` request to a URL. + pub fn get(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.get(url), self) + } + + /// Convenience method to make a `POST` request to a URL. + pub fn post(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.post(url), self) + } + + /// Convenience method to make a `HEAD` request to a URL. + pub fn head(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.head(url), self) + } + + /// Executes a request, applying the redirect policy. + pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { + match self.redirect_policy { + RedirectPolicy::BypassMiddleware => self.client.execute(req).await, + RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await, + } + } + + /// Executes a request. If the response is a redirect (one of HTTP 301, 302, 303, 307, or 308), the + /// request is executed again with the redirect location URL (up to a maximum number of + /// redirects). + /// + /// Unlike the built-in reqwest redirect policies, this sends the redirect request through the + /// entire middleware pipeline again. + /// + /// See RFC 7231 7.1.2 for details on + /// redirect semantics. 
+ async fn execute_with_redirect_handling( + &self, + req: Request, + ) -> reqwest_middleware::Result { + let mut request = req; + let mut redirects = 0; + let max_redirects = DEFAULT_MAX_REDIRECTS; + + loop { + let result = self + .client + .execute(request.try_clone().expect("HTTP request must be cloneable")) + .await; + let Ok(response) = result else { + return result; + }; + + if redirects >= max_redirects { + return Ok(response); + } + + let Some(redirect_request) = + request_into_redirect(request, &response, self.cross_origin_credentials_policy)? + else { + return Ok(response); + }; + + redirects += 1; + request = redirect_request; + } + } + + pub fn raw_client(&self) -> &ClientWithMiddleware { + &self.client + } +} + +impl From for ClientWithMiddleware { + fn from(item: RedirectClientWithMiddleware) -> ClientWithMiddleware { + item.client + } +} + +/// Check if this should be a redirect and, if so, return a new redirect request. +/// +/// This implementation is based on the [`reqwest`] crate redirect implementation. +/// It takes ownership of the original [`Request`] and mutates it to create the new +/// redirect [`Request`]. +fn request_into_redirect( + mut req: Request, + res: &Response, + cross_origin_credentials_policy: CrossOriginCredentialsPolicy, +) -> reqwest_middleware::Result> { + let original_req_url = DisplaySafeUrl::from(req.url().clone()); + let status = res.status(); + let should_redirect = match status { + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::TEMPORARY_REDIRECT + | StatusCode::PERMANENT_REDIRECT => true, + StatusCode::SEE_OTHER => { + // Per RFC 7231, HTTP 303 is intended for the user agent + // to perform a GET or HEAD request to the redirect target. + // Historically, some browsers also changed method from POST + // to GET on 301 or 302, but this is not required by RFC 7231 + // and was not intended by the HTTP spec. 
+ *req.body_mut() = None; + for header in &[ + TRANSFER_ENCODING, + CONTENT_ENCODING, + CONTENT_TYPE, + CONTENT_LENGTH, + ] { + req.headers_mut().remove(header); + } + + match *req.method() { + Method::GET | Method::HEAD => {} + _ => { + *req.method_mut() = Method::GET; + } + } + true + } + _ => false, + }; + if !should_redirect { + return Ok(None); + } + + let location = res + .headers() + .get(LOCATION) + .ok_or(reqwest_middleware::Error::Middleware(anyhow!( + "Server returned redirect (HTTP {status}) without destination URL. This may indicate a server configuration issue" + )))? + .to_str() + .map_err(|_| { + reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value: must only contain visible ascii characters" + )) + })?; + + let mut redirect_url = match DisplaySafeUrl::parse(location) { + Ok(url) => url, + // Per RFC 7231, URLs should be resolved against the request URL. + Err(ParseError::RelativeUrlWithoutBase) => original_req_url.join(location).map_err(|err| { + reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{location}` relative to `{original_req_url}`: {err}" + )) + })?, + Err(err) => { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{location}`: {err}" + ))); + } + }; + // Per RFC 7231, fragments must be propagated + if let Some(fragment) = original_req_url.fragment() { + redirect_url.set_fragment(Some(fragment)); + } + + // Ensure the URL is a valid HTTP URI. 
+ if let Err(err) = redirect_url.as_str().parse::() { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "HTTP {status} 'Location' value `{redirect_url}` is not a valid HTTP URI: {err}" + ))); + } + + if redirect_url.scheme() != "http" && redirect_url.scheme() != "https" { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{redirect_url}`: scheme needs to be https or http" + ))); + } + + let mut headers = HeaderMap::new(); + std::mem::swap(req.headers_mut(), &mut headers); + + let cross_host = redirect_url.host_str() != original_req_url.host_str() + || redirect_url.port_or_known_default() != original_req_url.port_or_known_default(); + if cross_host { + if cross_origin_credentials_policy == CrossOriginCredentialsPolicy::Secure { + debug!("Received a cross-origin redirect. Removing sensitive headers."); + headers.remove(AUTHORIZATION); + headers.remove(COOKIE); + headers.remove(PROXY_AUTHORIZATION); + headers.remove(WWW_AUTHENTICATE); + } + // If the redirect request is not a cross-origin request and the original request already + // had a Referer header, attempt to set the Referer header for the redirect request. + } else if headers.contains_key(REFERER) { + if let Some(referer) = make_referer(&redirect_url, &original_req_url) { + headers.insert(REFERER, referer); + } + } + + std::mem::swap(req.headers_mut(), &mut headers); + *req.url_mut() = Url::from(redirect_url); + debug!( + "Received HTTP {status}. Redirecting to {}", + DisplaySafeUrl::ref_cast(req.url()) + ); + Ok(Some(req)) +} + +/// Return a Referer [`HeaderValue`] according to RFC 7231. +/// +/// Return [`None`] if https has been downgraded in the redirect location. 
+fn make_referer( + redirect_url: &DisplaySafeUrl, + original_url: &DisplaySafeUrl, +) -> Option { + if redirect_url.scheme() == "http" && original_url.scheme() == "https" { + return None; + } + + let mut referer = original_url.clone(); + referer.remove_credentials(); + referer.set_fragment(None); + referer.as_str().parse().ok() +} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] +pub(crate) enum CrossOriginCredentialsPolicy { + /// Do not propagate credentials on cross-origin requests. + #[default] + Secure, + + /// Propagate credentials on cross-origin requests. + /// + /// WARNING: This should only be available for tests. In production code, preserving credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + #[cfg(test)] + Insecure, +} + +/// A builder to construct the properties of a `Request`. +/// +/// This wraps [`reqwest_middleware::RequestBuilder`] to ensure that the [`BaseClient`] +/// redirect policy is respected if `send()` is called. +#[derive(Debug)] +#[must_use] +pub struct RequestBuilder<'a> { + builder: reqwest_middleware::RequestBuilder, + client: &'a RedirectClientWithMiddleware, +} + +impl<'a> RequestBuilder<'a> { + pub fn new( + builder: reqwest_middleware::RequestBuilder, + client: &'a RedirectClientWithMiddleware, + ) -> Self { + Self { builder, client } + } + + /// Add a `Header` to this Request. + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, + { + self.builder = self.builder.header(key, value); + self + } + + /// Add a set of Headers to the existing ones on this Request. + /// + /// The headers will be merged in to any already set. 
+ pub fn headers(mut self, headers: HeaderMap) -> Self { + self.builder = self.builder.headers(headers); + self + } + + #[cfg(not(target_arch = "wasm32"))] + pub fn version(mut self, version: reqwest::Version) -> Self { + self.builder = self.builder.version(version); + self + } + + #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] + pub fn multipart(mut self, multipart: multipart::Form) -> Self { + self.builder = self.builder.multipart(multipart); + self + } + + /// Build a `Request`. + pub fn build(self) -> reqwest::Result { + self.builder.build() + } + + /// Constructs the Request and sends it to the target URL, returning a + /// future Response. + pub async fn send(self) -> reqwest_middleware::Result { + self.client.execute(self.build()?).await + } + + pub fn raw_builder(&self) -> &reqwest_middleware::RequestBuilder { + &self.builder + } +} + /// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases. pub struct UvRetryableStrategy; @@ -528,3 +922,165 @@ fn find_source(orig: &dyn Error) -> Option<&E> { fn find_sources(orig: &dyn Error) -> impl Iterator { iter::successors(find_source::(orig), |&err| find_source(err)) } + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + + use reqwest::{Client, Method}; + use wiremock::matchers::method; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::base_client::request_into_redirect; + + #[tokio::test] + async fn test_redirect_preserves_authorization_header_on_same_origin() -> Result<()> { + for status in &[301, 302, 303, 307, 308] { + let server = MockServer::start().await; + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(*status) + .insert_header("location", format!("{}/redirect", server.uri())), + ) + .mount(&server) + .await; + + let request = Client::new() + .get(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert!(request.headers().contains_key(AUTHORIZATION)); + + let response 
= Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? + .unwrap(); + assert!(redirect_request.headers().contains_key(AUTHORIZATION)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_removes_authorization_header_on_cross_origin() -> Result<()> { + for status in &[301, 302, 303, 307, 308] { + let server = MockServer::start().await; + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(*status) + .insert_header("location", "https://cross-origin.com/simple"), + ) + .mount(&server) + .await; + + let request = Client::new() + .get(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert!(request.headers().contains_key(AUTHORIZATION)); + + let response = Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? 
+ .unwrap(); + assert!(!redirect_request.headers().contains_key(AUTHORIZATION)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_303_changes_post_to_get() -> Result<()> { + let server = MockServer::start().await; + Mock::given(method("POST")) + .respond_with( + ResponseTemplate::new(303) + .insert_header("location", format!("{}/redirect", server.uri())), + ) + .mount(&server) + .await; + + let request = Client::new() + .post(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert_eq!(request.method(), Method::POST); + + let response = Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? + .unwrap(); + assert_eq!(redirect_request.method(), Method::GET); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_no_referer_if_disabled() -> Result<()> { + for status in &[301, 302, 303, 307, 308] { + let server = MockServer::start().await; + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(*status) + .insert_header("location", format!("{}/redirect", server.uri())), + ) + .mount(&server) + .await; + + let request = Client::builder() + .referer(false) + .build() + .unwrap() + .get(server.uri()) + .basic_auth("username", Some("password")) + .build() + .unwrap(); + + assert!(!request.headers().contains_key(REFERER)); + + let response = Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() + .execute(request.try_clone().unwrap()) + .await + .unwrap(); + + let redirect_request = + request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? 
+ .unwrap(); + + assert!(!redirect_request.headers().contains_key(REFERER)); + } + + Ok(()) + } +} diff --git a/crates/uv-client/src/cached_client.rs b/crates/uv-client/src/cached_client.rs index d19f95ec7..ee3314d1c 100644 --- a/crates/uv-client/src/cached_client.rs +++ b/crates/uv-client/src/cached_client.rs @@ -523,7 +523,6 @@ impl CachedClient { debug!("Sending revalidation request for: {url}"); let response = self .0 - .for_host(&url) .execute(req) .instrument(info_span!("revalidation_request", url = url.as_str())) .await @@ -564,7 +563,6 @@ impl CachedClient { let cache_policy_builder = CachePolicyBuilder::new(&req); let response = self .0 - .for_host(&url) .execute(req) .await .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?; diff --git a/crates/uv-client/src/lib.rs b/crates/uv-client/src/lib.rs index 3ea33204c..e42c86620 100644 --- a/crates/uv-client/src/lib.rs +++ b/crates/uv-client/src/lib.rs @@ -1,6 +1,6 @@ pub use base_client::{ AuthIntegration, BaseClient, BaseClientBuilder, DEFAULT_RETRIES, ExtraMiddleware, - UvRetryableStrategy, is_extended_transient_error, + RedirectClientWithMiddleware, RequestBuilder, UvRetryableStrategy, is_extended_transient_error, }; pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use error::{Error, ErrorKind, WrappedReqwestError}; diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 6271b7d20..b53e1ed9a 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -10,7 +10,6 @@ use futures::{FutureExt, StreamExt, TryStreamExt}; use http::{HeaderMap, StatusCode}; use itertools::Either; use reqwest::{Proxy, Response}; -use reqwest_middleware::ClientWithMiddleware; use rustc_hash::FxHashMap; use tokio::sync::{Mutex, Semaphore}; use tracing::{Instrument, debug, info_span, instrument, trace, warn}; @@ -35,13 +34,16 @@ use uv_redacted::DisplaySafeUrl; use 
uv_small_str::SmallString; use uv_torch::TorchStrategy; -use crate::base_client::{BaseClientBuilder, ExtraMiddleware}; +use crate::base_client::{BaseClientBuilder, ExtraMiddleware, RedirectPolicy}; use crate::cached_client::CacheControl; use crate::flat_index::FlatIndexEntry; use crate::html::SimpleHtml; use crate::remote_metadata::wheel_metadata_from_remote_zip; use crate::rkyvutil::OwnedArchive; -use crate::{BaseClient, CachedClient, Error, ErrorKind, FlatIndexClient, FlatIndexEntries}; +use crate::{ + BaseClient, CachedClient, Error, ErrorKind, FlatIndexClient, FlatIndexEntries, + RedirectClientWithMiddleware, +}; /// A builder for an [`RegistryClient`]. #[derive(Debug, Clone)] @@ -149,9 +151,23 @@ impl<'a> RegistryClientBuilder<'a> { self } + /// Allows credentials to be propagated on cross-origin redirects. + /// + /// WARNING: This should only be available for tests. In production code, propagating credentials + /// during cross-origin redirects can lead to security vulnerabilities including credential + /// leakage to untrusted domains. + #[cfg(test)] + #[must_use] + pub fn allow_cross_origin_credentials(mut self) -> Self { + self.base_client_builder = self.base_client_builder.allow_cross_origin_credentials(); + self + } + pub fn build(self) -> RegistryClient { // Build a base client - let builder = self.base_client_builder; + let builder = self + .base_client_builder + .redirect(RedirectPolicy::RetriggerMiddleware); let client = builder.build(); @@ -248,7 +264,7 @@ impl RegistryClient { } /// Return the [`BaseClient`] used by this client. 
- pub fn uncached_client(&self, url: &DisplaySafeUrl) -> &ClientWithMiddleware { + pub fn uncached_client(&self, url: &DisplaySafeUrl) -> &RedirectClientWithMiddleware { self.client.uncached().for_host(url) } @@ -1215,12 +1231,229 @@ impl Connectivity { mod tests { use std::str::FromStr; + use url::Url; use uv_normalize::PackageName; use uv_pypi_types::{JoinRelativeError, SimpleJson}; use uv_redacted::DisplaySafeUrl; use crate::{SimpleMetadata, SimpleMetadatum, html::SimpleHtml}; + use uv_cache::Cache; + use wiremock::matchers::{basic_auth, method, path_regex}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::RegistryClientBuilder; + + type Error = Box; + + async fn start_test_server(username: &'static str, password: &'static str) -> MockServer { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(401)) + .mount(&server) + .await; + + server + } + + #[tokio::test] + async fn test_redirect_to_server_with_credentials() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let auth_server = start_test_server(username, password).await; + let auth_base_url = DisplaySafeUrl::parse(&auth_server.uri())?; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 302 to the auth server + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", format!("{auth_base_url}")), + ) + .mount(&redirect_server) + .await; + + let redirect_server_url = DisplaySafeUrl::parse(&redirect_server.uri())?; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache) + .allow_cross_origin_credentials() + .build(); + let client = registry_client.cached_client().uncached(); + + assert_eq!( + client + .for_host(&redirect_server_url) + 
.get(redirect_server.uri()) + .send() + .await? + .status(), + 401, + "Requests should fail if credentials are missing" + ); + + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&redirect_server_url) + .get(Url::from(url)) + .send() + .await? + .status(), + 200, + "Requests should succeed if credentials are present" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_root_relative_url() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .and(path_regex("/foo/")) + .respond_with( + ResponseTemplate::new(307).insert_header("Location", "/bar/baz/".to_string()), + ) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/bar/baz/")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + let redirect_server_url = DisplaySafeUrl::parse(&redirect_server.uri())?.join("foo/")?; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache) + .allow_cross_origin_credentials() + .build(); + let client = registry_client.cached_client().uncached(); + + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&url) + .get(Url::from(url)) + .send() + .await? + .status(), + 200, + "Requests should succeed for relative URL" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_relative_url() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. 
+ Mock::given(method("GET")) + .and(path_regex("/foo/bar/baz/")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/foo/")) + .and(basic_auth(username, password)) + .respond_with( + ResponseTemplate::new(307).insert_header("Location", "bar/baz/".to_string()), + ) + .mount(&redirect_server) + .await; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache) + .allow_cross_origin_credentials() + .build(); + let client = registry_client.cached_client().uncached(); + + let redirect_server_url = DisplaySafeUrl::parse(&redirect_server.uri())?.join("foo/")?; + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&url) + .get(Url::from(url)) + .send() + .await? + .status(), + 200, + "Requests should succeed for relative URL" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_preserve_fragment() -> Result<(), Error> { + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(307).insert_header("Location", "/foo".to_string())) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/foo")) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache).build(); + let client = registry_client.cached_client().uncached(); + + let mut url = DisplaySafeUrl::parse(&redirect_server.uri())?; + url.set_fragment(Some("fragment")); + + assert_eq!( + client + .for_host(&url) + .get(Url::from(url.clone())) + .send() + .await? 
+ .url() + .to_string(), + format!("{}/foo#fragment", redirect_server.uri()), + "Requests should preserve fragment" + ); + + Ok(()) + } + #[test] fn ignore_failing_files() { // 1.7.7 has an invalid requires-python field (double comma), 1.7.8 is valid diff --git a/crates/uv-distribution/src/source/mod.rs b/crates/uv-distribution/src/source/mod.rs index 2cf148f2b..90d77bd90 100644 --- a/crates/uv-distribution/src/source/mod.rs +++ b/crates/uv-distribution/src/source/mod.rs @@ -1583,7 +1583,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { client .unmanaged .uncached_client(resource.git.repository()) - .clone(), + .raw_client(), ) .await { @@ -1866,7 +1866,10 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { .git() .github_fast_path( git, - client.unmanaged.uncached_client(git.repository()).clone(), + client + .unmanaged + .uncached_client(git.repository()) + .raw_client(), ) .await? .is_some() diff --git a/crates/uv-git/src/resolver.rs b/crates/uv-git/src/resolver.rs index fd90ff587..d404390f3 100644 --- a/crates/uv-git/src/resolver.rs +++ b/crates/uv-git/src/resolver.rs @@ -53,7 +53,7 @@ impl GitResolver { pub async fn github_fast_path( &self, url: &GitUrl, - client: ClientWithMiddleware, + client: &ClientWithMiddleware, ) -> Result<Option<GitOid>, GitResolverError> { if std::env::var_os(EnvVars::UV_NO_GITHUB_FAST_PATH).is_some() { return Ok(None); @@ -117,7 +117,7 @@ impl GitResolver { pub async fn fetch( &self, url: &GitUrl, - client: ClientWithMiddleware, + client: impl Into<ClientWithMiddleware>, disable_ssl: bool, offline: bool, cache: PathBuf, diff --git a/crates/uv-publish/src/lib.rs b/crates/uv-publish/src/lib.rs index dd8358439..ec19713cc 100644 --- a/crates/uv-publish/src/lib.rs +++ b/crates/uv-publish/src/lib.rs @@ -12,7 +12,6 @@ use itertools::Itertools; use reqwest::header::AUTHORIZATION; use reqwest::multipart::Part; use reqwest::{Body, Response, StatusCode}; -use reqwest_middleware::RequestBuilder; use reqwest_retry::policies::ExponentialBackoff; use
reqwest_retry::{RetryPolicy, Retryable, RetryableStrategy}; use rustc_hash::FxHashSet; @@ -29,7 +28,7 @@ use uv_auth::Credentials; use uv_cache::{Cache, Refresh}; use uv_client::{ BaseClient, DEFAULT_RETRIES, MetadataFormat, OwnedArchive, RegistryClientBuilder, - UvRetryableStrategy, + RequestBuilder, UvRetryableStrategy, }; use uv_configuration::{KeyringProviderType, TrustedPublishing}; use uv_distribution_filename::{DistFilename, SourceDistExtension, SourceDistFilename}; @@ -330,7 +329,9 @@ pub async fn check_trusted_publishing( debug!( "Running on GitHub Actions without explicit credentials, checking for trusted publishing" ); - match trusted_publishing::get_token(registry, client.for_host(registry)).await { + match trusted_publishing::get_token(registry, client.for_host(registry).raw_client()) + .await + { Ok(token) => Ok(TrustedPublishResult::Configured(token)), Err(err) => { // TODO(konsti): It would be useful if we could differentiate between actual errors @@ -364,7 +365,9 @@ pub async fn check_trusted_publishing( ); } - let token = trusted_publishing::get_token(registry, client.for_host(registry)).await?; + let token = + trusted_publishing::get_token(registry, client.for_host(registry).raw_client()) + .await?; Ok(TrustedPublishResult::Configured(token)) } TrustedPublishing::Never => Ok(TrustedPublishResult::Skipped), @@ -748,16 +751,16 @@ async fn form_metadata( /// Build the upload request. /// /// Returns the request and the reporter progress bar id. 
-async fn build_request( +async fn build_request<'a>( file: &Path, raw_filename: &str, filename: &DistFilename, registry: &DisplaySafeUrl, - client: &BaseClient, + client: &'a BaseClient, credentials: &Credentials, form_metadata: &[(&'static str, String)], reporter: Arc<impl Reporter>, -) -> Result<(RequestBuilder, usize), PublishPrepareError> { +) -> Result<(RequestBuilder<'a>, usize), PublishPrepareError> { let mut form = reqwest::multipart::Form::new(); for (key, value) in form_metadata { form = form.text(*key, value.clone()); @@ -969,12 +972,13 @@ mod tests { project_urls: Source, https://github.com/unknown/tqdm "###); + let client = BaseClientBuilder::new().build(); let (request, _) = build_request( &file, raw_filename, &filename, &DisplaySafeUrl::parse("https://example.org/upload").unwrap(), - &BaseClientBuilder::new().build(), + &client, &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &form_metadata, Arc::new(DummyReporter), @@ -985,7 +989,7 @@ mod tests { insta::with_settings!({ filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], }, { - assert_debug_snapshot!(&request, @r#" + assert_debug_snapshot!(&request.raw_builder(), @r#" RequestBuilder { inner: RequestBuilder { method: POST, @@ -1118,12 +1122,13 @@ mod tests { requires_dist: requests ; extra == 'telegram' "###); + let client = BaseClientBuilder::new().build(); let (request, _) = build_request( &file, raw_filename, &filename, &DisplaySafeUrl::parse("https://example.org/upload").unwrap(), - &BaseClientBuilder::new().build(), + &client, &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &form_metadata, Arc::new(DummyReporter), @@ -1134,7 +1139,7 @@ mod tests { insta::with_settings!({ filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], }, { - assert_debug_snapshot!(&request, @r#" + assert_debug_snapshot!(&request.raw_builder(), @r#" RequestBuilder { inner: RequestBuilder { method: POST, diff --git a/crates/uv/tests/it/common/mod.rs
b/crates/uv/tests/it/common/mod.rs index f997561a9..7ef7cfff6 100644 --- a/crates/uv/tests/it/common/mod.rs +++ b/crates/uv/tests/it/common/mod.rs @@ -1668,9 +1668,9 @@ pub async fn download_to_disk(url: &str, path: &Path) { .allow_insecure_host(trusted_hosts) .build(); let url = url.parse().unwrap(); - let client = client.for_host(&url); let response = client - .request(http::Method::GET, reqwest::Url::from(url)) + .for_host(&url) + .get(reqwest::Url::from(url)) .send() .await .unwrap(); diff --git a/crates/uv/tests/it/edit.rs b/crates/uv/tests/it/edit.rs index 68eae5110..d117b1e8c 100644 --- a/crates/uv/tests/it/edit.rs +++ b/crates/uv/tests/it/edit.rs @@ -14,6 +14,7 @@ use assert_fs::prelude::*; use indoc::{formatdoc, indoc}; use insta::assert_snapshot; use std::path::Path; +use url::Url; use uv_fs::Simplified; use wiremock::{Mock, MockServer, ResponseTemplate, matchers::method}; @@ -11838,6 +11839,196 @@ fn add_auth_policy_never_without_credentials() -> Result<()> { Ok(()) } +/// If uv receives a 302 redirect to a cross-origin server, it should not forward +/// credentials. In the absence of a netrc entry for the new location, +/// it should fail. +#[tokio::test] +async fn add_redirect_cross_origin() -> Result<()> { + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1:\d*", "[LOCALHOST]")]) + .collect::<Vec<_>>(); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc!
{ r#" + [project] + name = "foo" + version = "1.0.0" + requires-python = ">=3.12" + dependencies = [] + "# + })?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + let _ = redirect_url.set_password(Some("heron")); + + uv_snapshot!(filters, context.add().arg("--default-index").arg(redirect_url.as_str()).arg("anyio"), @r" + success: false + exit_code: 1 + ----- stdout ----- + + ----- stderr ----- + × No solution found when resolving dependencies: + ╰─▶ Because anyio was not found in the package registry and your project depends on anyio, we can conclude that your project's requirements are unsatisfiable. + + hint: An index URL (http://[LOCALHOST]/) could not be queried due to a lack of valid authentication credentials (401 Unauthorized). + help: If you want to add the package regardless of the failed resolution, provide the `--frozen` flag to skip locking and syncing. + " + ); + + Ok(()) +} + +/// uv currently fails to look up keyring credentials on a cross-origin redirect. +#[tokio::test] +async fn add_redirect_with_keyring_cross_origin() -> Result<()> { + let keyring_context = TestContext::new("3.12"); + + // Install our keyring plugin + keyring_context + .pip_install() + .arg( + keyring_context + .workspace_root + .join("scripts") + .join("packages") + .join("keyring_test_plugin"), + ) + .assert() + .success(); + + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1:\d*", "[LOCALHOST]")]) + .collect::<Vec<_>>(); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc!
{ r#" + [project] + name = "foo" + version = "1.0.0" + requires-python = ">=3.12" + dependencies = [] + + [tool.uv] + keyring-provider = "subprocess" + "#, + })?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + + uv_snapshot!(filters, context.add().arg("--default-index") + .arg(redirect_url.as_str()) + .arg("anyio") + .env(EnvVars::KEYRING_TEST_CREDENTIALS, r#"{"pypi-proxy.fly.dev": {"public": "heron"}}"#) + .env(EnvVars::PATH, venv_bin_path(&keyring_context.venv)), @r" + success: false + exit_code: 1 + ----- stdout ----- + + ----- stderr ----- + Keyring request for public@http://[LOCALHOST]/ + Keyring request for public@[LOCALHOST] + × No solution found when resolving dependencies: + ╰─▶ Because anyio was not found in the package registry and your project depends on anyio, we can conclude that your project's requirements are unsatisfiable. + + hint: An index URL (http://[LOCALHOST]/) could not be queried due to a lack of valid authentication credentials (401 Unauthorized). + help: If you want to add the package regardless of the failed resolution, provide the `--frozen` flag to skip locking and syncing. + " + ); + + Ok(()) +} + +/// If uv receives a cross-origin 302 redirect, it should use credentials from netrc +/// for the new location. 
+#[tokio::test] +async fn pip_install_redirect_with_netrc_cross_origin() -> Result<()> { + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1:\d*", "[LOCALHOST]")]) + .collect::<Vec<_>>(); + + let netrc = context.temp_dir.child(".netrc"); + netrc.write_str("machine pypi-proxy.fly.dev login public password heron")?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + + uv_snapshot!(filters, context.pip_install() + .arg("anyio") + .arg("--index-url") + .arg(redirect_url.as_str()) + .env(EnvVars::NETRC, netrc.to_str().unwrap()) + .arg("--strict"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + + ----- stderr ----- + Resolved 3 packages in [TIME] + Prepared 3 packages in [TIME] + Installed 3 packages in [TIME] + + anyio==4.3.0 + + idna==3.6 + + sniffio==1.3.1 + "### + ); + + context.assert_command("import anyio").success(); + + Ok(()) +} + +fn redirect_url_to_pypi_proxy(req: &wiremock::Request) -> String { + let last_path_segment = req + .url + .path_segments() + .expect("path has segments") + .filter(|segment| !segment.is_empty()) // Filter out empty segments + .next_back() + .expect("path has a package segment"); + format!("https://pypi-proxy.fly.dev/basic-auth/simple/{last_path_segment}/") +} + /// Test the error message when adding a package with multiple existing references in /// `pyproject.toml`. #[test]