Implement RFC 7231 compliant relative URI and fragment handling in redirects (#13050)

This PR restores #13041 and integrates two PRs from @zanieb:
* #13038
* #13040

It also adds tests for relative URI and fragment handling.

Closes #13037.

---------

Co-authored-by: Zanie Blue <contact@zanie.dev>
This commit is contained in:
John Mumm 2025-04-28 09:07:06 +02:00 committed by GitHub
parent 576a4ae3a7
commit 4ee4a8861e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 696 additions and 37 deletions

1
Cargo.lock generated
View file

@ -4934,6 +4934,7 @@ dependencies = [
"uv-torch", "uv-torch",
"uv-version", "uv-version",
"uv-warnings", "uv-warnings",
"wiremock",
] ]
[[package]] [[package]]

View file

@ -64,3 +64,4 @@ hyper = { version = "1.4.1", features = ["server", "http1"] }
hyper-util = { version = "0.1.8", features = ["tokio"] } hyper-util = { version = "0.1.8", features = ["tokio"] }
insta = { version = "1.40.0", features = ["filters", "json", "redactions"] } insta = { version = "1.40.0", features = ["filters", "json", "redactions"] }
tokio = { workspace = true } tokio = { workspace = true }
wiremock = { workspace = true }

View file

@ -6,14 +6,17 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::{env, iter}; use std::{env, iter};
use anyhow::anyhow;
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use itertools::Itertools; use itertools::Itertools;
use reqwest::{Client, ClientBuilder, Proxy, Response}; use reqwest::{multipart, Client, ClientBuilder, IntoUrl, Proxy, Request, Response};
use reqwest_middleware::{ClientWithMiddleware, Middleware}; use reqwest_middleware::{ClientWithMiddleware, Middleware};
use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{ use reqwest_retry::{
DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy, DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy,
}; };
use tracing::{debug, trace}; use tracing::{debug, trace};
use url::ParseError;
use url::Url; use url::Url;
use uv_auth::{AuthMiddleware, UrlAuthPolicies}; use uv_auth::{AuthMiddleware, UrlAuthPolicies};
@ -60,6 +63,24 @@ pub struct BaseClientBuilder<'a> {
default_timeout: Duration, default_timeout: Duration,
extra_middleware: Option<ExtraMiddleware>, extra_middleware: Option<ExtraMiddleware>,
proxies: Vec<Proxy>, proxies: Vec<Proxy>,
redirect_policy: RedirectPolicy,
}
/// The policy for handling redirects.
#[derive(Debug, Default, Clone, Copy)]
pub enum RedirectPolicy {
#[default]
BypassMiddleware,
RetriggerMiddleware,
}
impl RedirectPolicy {
pub fn reqwest_policy(self) -> reqwest::redirect::Policy {
match self {
RedirectPolicy::BypassMiddleware => reqwest::redirect::Policy::default(),
RedirectPolicy::RetriggerMiddleware => reqwest::redirect::Policy::none(),
}
}
} }
/// A list of user-defined middlewares to be applied to the client. /// A list of user-defined middlewares to be applied to the client.
@ -95,6 +116,7 @@ impl BaseClientBuilder<'_> {
default_timeout: Duration::from_secs(30), default_timeout: Duration::from_secs(30),
extra_middleware: None, extra_middleware: None,
proxies: vec![], proxies: vec![],
redirect_policy: RedirectPolicy::default(),
} }
} }
} }
@ -172,6 +194,12 @@ impl<'a> BaseClientBuilder<'a> {
self self
} }
#[must_use]
pub fn redirect(mut self, policy: RedirectPolicy) -> Self {
self.redirect_policy = policy;
self
}
pub fn is_offline(&self) -> bool { pub fn is_offline(&self) -> bool {
matches!(self.connectivity, Connectivity::Offline) matches!(self.connectivity, Connectivity::Offline)
} }
@ -228,6 +256,7 @@ impl<'a> BaseClientBuilder<'a> {
timeout, timeout,
ssl_cert_file_exists, ssl_cert_file_exists,
Security::Secure, Security::Secure,
self.redirect_policy,
); );
// Create an insecure client that accepts invalid certificates. // Create an insecure client that accepts invalid certificates.
@ -236,11 +265,18 @@ impl<'a> BaseClientBuilder<'a> {
timeout, timeout,
ssl_cert_file_exists, ssl_cert_file_exists,
Security::Insecure, Security::Insecure,
self.redirect_policy,
); );
// Wrap in any relevant middleware and handle connectivity. // Wrap in any relevant middleware and handle connectivity.
let client = self.apply_middleware(raw_client.clone()); let client = RedirectClientWithMiddleware {
let dangerous_client = self.apply_middleware(raw_dangerous_client.clone()); client: self.apply_middleware(raw_client.clone()),
redirect_policy: self.redirect_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
};
BaseClient { BaseClient {
connectivity: self.connectivity, connectivity: self.connectivity,
@ -257,8 +293,14 @@ impl<'a> BaseClientBuilder<'a> {
/// Share the underlying client between two different middleware configurations. /// Share the underlying client between two different middleware configurations.
pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient { pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient {
// Wrap in any relevant middleware and handle connectivity. // Wrap in any relevant middleware and handle connectivity.
let client = self.apply_middleware(existing.raw_client.clone()); let client = RedirectClientWithMiddleware {
let dangerous_client = self.apply_middleware(existing.raw_dangerous_client.clone()); client: self.apply_middleware(existing.raw_client.clone()),
redirect_policy: self.redirect_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(existing.raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
};
BaseClient { BaseClient {
connectivity: self.connectivity, connectivity: self.connectivity,
@ -278,6 +320,7 @@ impl<'a> BaseClientBuilder<'a> {
timeout: Duration, timeout: Duration,
ssl_cert_file_exists: bool, ssl_cert_file_exists: bool,
security: Security, security: Security,
redirect_policy: RedirectPolicy,
) -> Client { ) -> Client {
// Configure the builder. // Configure the builder.
let client_builder = ClientBuilder::new() let client_builder = ClientBuilder::new()
@ -285,7 +328,8 @@ impl<'a> BaseClientBuilder<'a> {
.user_agent(user_agent) .user_agent(user_agent)
.pool_max_idle_per_host(20) .pool_max_idle_per_host(20)
.read_timeout(timeout) .read_timeout(timeout)
.tls_built_in_root_certs(false); .tls_built_in_root_certs(false)
.redirect(redirect_policy.reqwest_policy());
// If necessary, accept invalid certificates. // If necessary, accept invalid certificates.
let client_builder = match security { let client_builder = match security {
@ -382,9 +426,9 @@ impl<'a> BaseClientBuilder<'a> {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BaseClient { pub struct BaseClient {
/// The underlying HTTP client that enforces valid certificates. /// The underlying HTTP client that enforces valid certificates.
client: ClientWithMiddleware, client: RedirectClientWithMiddleware,
/// The underlying HTTP client that accepts invalid certificates. /// The underlying HTTP client that accepts invalid certificates.
dangerous_client: ClientWithMiddleware, dangerous_client: RedirectClientWithMiddleware,
/// The HTTP client without middleware. /// The HTTP client without middleware.
raw_client: Client, raw_client: Client,
/// The HTTP client that accepts invalid certificates without middleware. /// The HTTP client that accepts invalid certificates without middleware.
@ -409,7 +453,7 @@ enum Security {
impl BaseClient { impl BaseClient {
/// Selects the appropriate client based on the host's trustworthiness. /// Selects the appropriate client based on the host's trustworthiness.
pub fn for_host(&self, url: &Url) -> &ClientWithMiddleware { pub fn for_host(&self, url: &Url) -> &RedirectClientWithMiddleware {
if self.disable_ssl(url) { if self.disable_ssl(url) {
&self.dangerous_client &self.dangerous_client
} else { } else {
@ -417,6 +461,12 @@ impl BaseClient {
} }
} }
/// Executes a request, applying redirect policy.
pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
let client = self.for_host(req.url());
client.execute(req).await
}
/// Returns `true` if the host is trusted to use the insecure client. /// Returns `true` if the host is trusted to use the insecure client.
pub fn disable_ssl(&self, url: &Url) -> bool { pub fn disable_ssl(&self, url: &Url) -> bool {
self.allow_insecure_host self.allow_insecure_host
@ -440,6 +490,205 @@ impl BaseClient {
} }
} }
/// Wrapper around [`ClientWithMiddleware`] that manages redirects.
#[derive(Debug, Clone)]
pub struct RedirectClientWithMiddleware {
client: ClientWithMiddleware,
redirect_policy: RedirectPolicy,
}
impl RedirectClientWithMiddleware {
/// Convenience method to make a `GET` request to a URL.
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.get(url), self)
}
/// Convenience method to make a `POST` request to a URL.
pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.post(url), self)
}
/// Convenience method to make a `HEAD` request to a URL.
pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.head(url), self)
}
/// Executes a request, applying the redirect policy.
pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
match self.redirect_policy {
RedirectPolicy::BypassMiddleware => self.client.execute(req).await,
RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await,
}
}
/// Executes a request. If the response is a redirect (one of HTTP 301, 302, 307, or 308), the
/// request is executed again with the redirect location URL (up to a maximum number of
/// redirects).
///
/// Unlike the built-in reqwest redirect policies, this sends the redirect request through the
/// entire middleware pipeline again.
///
/// See RFC 7231 7.1.2 <https://www.rfc-editor.org/rfc/rfc7231#section-7.1.2> for details on
/// redirect semantics.
async fn execute_with_redirect_handling(
&self,
req: Request,
) -> reqwest_middleware::Result<Response> {
let mut request = req;
let mut redirects = 0;
// This is the default used by reqwest.
let max_redirects = 10;
loop {
let request_url = request.url().clone();
let result = self
.client
.execute(request.try_clone().expect("HTTP request must be cloneable"))
.await;
if redirects == max_redirects {
return result;
}
let Ok(response) = result else {
return result;
};
// Handle redirect if we receive a 301, 302, 307, or 308.
let status = response.status();
if matches!(
status,
StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
) {
let location = response
.headers()
.get("location")
.ok_or(reqwest_middleware::Error::Middleware(anyhow!(
"Missing expected HTTP {status} 'Location' header"
)))?
.to_str()
.map_err(|_| {
reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value: must only contain visible ascii characters"
))
})?;
let mut redirect_url = match Url::parse(location) {
Ok(url) => url,
// Per RFC 7231, URLs should be resolved against the request URL.
Err(ParseError::RelativeUrlWithoutBase) => request_url.join(location).map_err(|err| {
reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value `{location}` relative to `{request_url}`: {err}"
))
})?,
Err(err) => {
return Err(reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value `{location}`: {err}"
)));
}
};
// Ensure the URL is a valid HTTP URI.
if let Err(err) = redirect_url.as_str().parse::<http::Uri>() {
return Err(reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value `{location}`: {err}"
)));
}
// Per RFC 7231, fragments must be propagated
if let Some(fragment) = request_url.fragment() {
redirect_url.set_fragment(Some(fragment));
}
debug!("Received HTTP {status} to {redirect_url}");
*request.url_mut() = redirect_url;
redirects += 1;
continue;
}
return Ok(response);
}
}
pub fn raw_client(&self) -> &ClientWithMiddleware {
&self.client
}
}
impl From<RedirectClientWithMiddleware> for ClientWithMiddleware {
fn from(item: RedirectClientWithMiddleware) -> ClientWithMiddleware {
item.client
}
}
/// A builder to construct the properties of a `Request`.
///
/// This wraps [`reqwest_middleware::RequestBuilder`] to ensure that the [`BaseClient`]
/// redirect policy is respected if `send()` is called.
#[derive(Debug)]
#[must_use]
pub struct RequestBuilder<'a> {
builder: reqwest_middleware::RequestBuilder,
client: &'a RedirectClientWithMiddleware,
}
impl<'a> RequestBuilder<'a> {
pub fn new(
builder: reqwest_middleware::RequestBuilder,
client: &'a RedirectClientWithMiddleware,
) -> Self {
Self { builder, client }
}
/// Add a `Header` to this Request.
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.builder = self.builder.header(key, value);
self
}
/// Add a set of Headers to the existing ones on this Request.
///
/// The headers will be merged in to any already set.
pub fn headers(mut self, headers: HeaderMap) -> Self {
self.builder = self.builder.headers(headers);
self
}
#[cfg(not(target_arch = "wasm32"))]
pub fn version(mut self, version: reqwest::Version) -> Self {
self.builder = self.builder.version(version);
self
}
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
pub fn multipart(mut self, multipart: multipart::Form) -> Self {
self.builder = self.builder.multipart(multipart);
self
}
/// Build a `Request`.
pub fn build(self) -> reqwest::Result<Request> {
self.builder.build()
}
/// Constructs the Request and sends it to the target URL, returning a
/// future Response.
pub async fn send(self) -> reqwest_middleware::Result<Response> {
self.client.execute(self.build()?).await
}
pub fn raw_builder(&self) -> &reqwest_middleware::RequestBuilder {
&self.builder
}
}
/// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases. /// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases.
pub struct UvRetryableStrategy; pub struct UvRetryableStrategy;

View file

@ -510,7 +510,6 @@ impl CachedClient {
debug!("Sending revalidation request for: {url}"); debug!("Sending revalidation request for: {url}");
let response = self let response = self
.0 .0
.for_host(req.url())
.execute(req) .execute(req)
.instrument(info_span!("revalidation_request", url = url.as_str())) .instrument(info_span!("revalidation_request", url = url.as_str()))
.await .await
@ -551,7 +550,6 @@ impl CachedClient {
let cache_policy_builder = CachePolicyBuilder::new(&req); let cache_policy_builder = CachePolicyBuilder::new(&req);
let response = self let response = self
.0 .0
.for_host(&url)
.execute(req) .execute(req)
.await .await
.map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))? .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?

View file

@ -1,6 +1,6 @@
pub use base_client::{ pub use base_client::{
is_extended_transient_error, AuthIntegration, BaseClient, BaseClientBuilder, ExtraMiddleware, is_extended_transient_error, AuthIntegration, BaseClient, BaseClientBuilder, ExtraMiddleware,
UvRetryableStrategy, DEFAULT_RETRIES, RedirectClientWithMiddleware, RequestBuilder, UvRetryableStrategy, DEFAULT_RETRIES,
}; };
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy};
pub use error::{Error, ErrorKind, WrappedReqwestError}; pub use error::{Error, ErrorKind, WrappedReqwestError};

View file

@ -10,7 +10,6 @@ use futures::{FutureExt, StreamExt, TryStreamExt};
use http::HeaderMap; use http::HeaderMap;
use itertools::Either; use itertools::Either;
use reqwest::{Proxy, Response, StatusCode}; use reqwest::{Proxy, Response, StatusCode};
use reqwest_middleware::ClientWithMiddleware;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use tokio::sync::{Mutex, Semaphore}; use tokio::sync::{Mutex, Semaphore};
use tracing::{info_span, instrument, trace, warn, Instrument}; use tracing::{info_span, instrument, trace, warn, Instrument};
@ -34,7 +33,7 @@ use uv_pypi_types::{ResolutionMetadata, SimpleJson};
use uv_small_str::SmallString; use uv_small_str::SmallString;
use uv_torch::TorchStrategy; use uv_torch::TorchStrategy;
use crate::base_client::{BaseClientBuilder, ExtraMiddleware}; use crate::base_client::{BaseClientBuilder, ExtraMiddleware, RedirectPolicy};
use crate::cached_client::CacheControl; use crate::cached_client::CacheControl;
use crate::flat_index::FlatIndexEntry; use crate::flat_index::FlatIndexEntry;
use crate::html::SimpleHtml; use crate::html::SimpleHtml;
@ -42,7 +41,7 @@ use crate::remote_metadata::wheel_metadata_from_remote_zip;
use crate::rkyvutil::OwnedArchive; use crate::rkyvutil::OwnedArchive;
use crate::{ use crate::{
BaseClient, CachedClient, CachedClientError, Error, ErrorKind, FlatIndexClient, BaseClient, CachedClient, CachedClientError, Error, ErrorKind, FlatIndexClient,
FlatIndexEntries, FlatIndexEntries, RedirectClientWithMiddleware,
}; };
/// A builder for an [`RegistryClient`]. /// A builder for an [`RegistryClient`].
@ -158,7 +157,9 @@ impl<'a> RegistryClientBuilder<'a> {
pub fn build(self) -> RegistryClient { pub fn build(self) -> RegistryClient {
// Build a base client // Build a base client
let builder = self.base_client_builder; let builder = self
.base_client_builder
.redirect(RedirectPolicy::RetriggerMiddleware);
let client = builder.build(); let client = builder.build();
@ -255,7 +256,7 @@ impl RegistryClient {
} }
/// Return the [`BaseClient`] used by this client. /// Return the [`BaseClient`] used by this client.
pub fn uncached_client(&self, url: &Url) -> &ClientWithMiddleware { pub fn uncached_client(&self, url: &Url) -> &RedirectClientWithMiddleware {
self.client.uncached().for_host(url) self.client.uncached().for_host(url)
} }
@ -1175,6 +1176,215 @@ mod tests {
use crate::{html::SimpleHtml, SimpleMetadata, SimpleMetadatum}; use crate::{html::SimpleHtml, SimpleMetadata, SimpleMetadatum};
use uv_cache::Cache;
use wiremock::matchers::{basic_auth, method, path_regex};
use wiremock::{Mock, MockServer, ResponseTemplate};
use crate::RegistryClientBuilder;
type Error = Box<dyn std::error::Error>;
async fn start_test_server(username: &'static str, password: &'static str) -> MockServer {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(basic_auth(username, password))
.respond_with(ResponseTemplate::new(200))
.mount(&server)
.await;
Mock::given(method("GET"))
.respond_with(ResponseTemplate::new(401))
.mount(&server)
.await;
server
}
#[tokio::test]
async fn test_redirect_to_server_with_credentials() -> Result<(), Error> {
let username = "user";
let password = "password";
let auth_server = start_test_server(username, password).await;
let auth_base_url = Url::parse(&auth_server.uri())?;
let redirect_server = MockServer::start().await;
// Configure the redirect server to respond with a 302 to the auth server
Mock::given(method("GET"))
.respond_with(
ResponseTemplate::new(302).insert_header("Location", format!("{}", &auth_base_url)),
)
.mount(&redirect_server)
.await;
let redirect_server_url = Url::parse(&redirect_server.uri())?;
let cache = Cache::temp()?;
let registry_client = RegistryClientBuilder::new(cache).build();
let client = registry_client.cached_client().uncached();
assert_eq!(
client
.for_host(&redirect_server_url)
.get(redirect_server.uri())
.send()
.await?
.status(),
401,
"Requests should fail if credentials are missing"
);
let mut url = redirect_server_url.clone();
let _ = url.set_username(username);
let _ = url.set_password(Some(password));
assert_eq!(
client
.for_host(&redirect_server_url)
.get(format!("{url}"))
.send()
.await?
.status(),
200,
"Requests should succeed if credentials are present"
);
Ok(())
}
#[tokio::test]
async fn test_redirect_root_relative_url() -> Result<(), Error> {
let username = "user";
let password = "password";
let redirect_server = MockServer::start().await;
// Configure the redirect server to respond with a 307 with a relative URL.
Mock::given(method("GET"))
.and(path_regex("/foo/"))
.respond_with(
ResponseTemplate::new(307).insert_header("Location", "/bar/baz/".to_string()),
)
.mount(&redirect_server)
.await;
Mock::given(method("GET"))
.and(path_regex("/bar/baz/"))
.and(basic_auth(username, password))
.respond_with(ResponseTemplate::new(200))
.mount(&redirect_server)
.await;
let redirect_server_url = Url::parse(&redirect_server.uri())?.join("foo/")?;
let cache = Cache::temp()?;
let registry_client = RegistryClientBuilder::new(cache).build();
let client = registry_client.cached_client().uncached();
let mut url = redirect_server_url.clone();
let _ = url.set_username(username);
let _ = url.set_password(Some(password));
assert_eq!(
client
.for_host(&url)
.get(format!("{url}"))
.send()
.await?
.status(),
200,
"Requests should succeed for relative URL"
);
Ok(())
}
#[tokio::test]
async fn test_redirect_relative_url() -> Result<(), Error> {
let username = "user";
let password = "password";
let redirect_server = MockServer::start().await;
// Configure the redirect server to respond with a 307 with a relative URL.
Mock::given(method("GET"))
.and(path_regex("/foo/bar/baz/"))
.and(basic_auth(username, password))
.respond_with(ResponseTemplate::new(200))
.mount(&redirect_server)
.await;
Mock::given(method("GET"))
.and(path_regex("/foo/"))
.respond_with(
ResponseTemplate::new(307).insert_header("Location", "bar/baz/".to_string()),
)
.mount(&redirect_server)
.await;
let cache = Cache::temp()?;
let registry_client = RegistryClientBuilder::new(cache).build();
let client = registry_client.cached_client().uncached();
let redirect_server_url = Url::parse(&redirect_server.uri())?.join("foo/")?;
let mut url = redirect_server_url.clone();
let _ = url.set_username(username);
let _ = url.set_password(Some(password));
assert_eq!(
client
.for_host(&url)
.get(format!("{url}"))
.send()
.await?
.status(),
200,
"Requests should succeed for relative URL"
);
Ok(())
}
#[tokio::test]
async fn test_redirect_preserve_fragment() -> Result<(), Error> {
let redirect_server = MockServer::start().await;
// Configure the redirect server to respond with a 307 with a relative URL.
Mock::given(method("GET"))
.respond_with(ResponseTemplate::new(307).insert_header("Location", "/foo".to_string()))
.mount(&redirect_server)
.await;
Mock::given(method("GET"))
.and(path_regex("/foo"))
.respond_with(ResponseTemplate::new(200))
.mount(&redirect_server)
.await;
let cache = Cache::temp()?;
let registry_client = RegistryClientBuilder::new(cache).build();
let client = registry_client.cached_client().uncached();
let mut url = Url::parse(&redirect_server.uri())?;
url.set_fragment(Some("fragment"));
assert_eq!(
client
.for_host(&url)
.get(format!("{}", url.clone()))
.send()
.await?
.url()
.to_string(),
format!("{}/foo#fragment", redirect_server.uri()),
"Requests should preserve fragment"
);
Ok(())
}
#[test] #[test]
fn ignore_failing_files() { fn ignore_failing_files() {
// 1.7.7 has an invalid requires-python field (double comma), 1.7.8 is valid // 1.7.7 has an invalid requires-python field (double comma), 1.7.8 is valid

View file

@ -1582,7 +1582,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> {
client client
.unmanaged .unmanaged
.uncached_client(resource.git.repository()) .uncached_client(resource.git.repository())
.clone(), .raw_client(),
) )
.await .await
{ {
@ -1863,7 +1863,10 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> {
.git() .git()
.github_fast_path( .github_fast_path(
git, git,
client.unmanaged.uncached_client(git.repository()).clone(), client
.unmanaged
.uncached_client(git.repository())
.raw_client(),
) )
.await? .await?
.is_some() .is_some()

View file

@ -52,7 +52,7 @@ impl GitResolver {
pub async fn github_fast_path( pub async fn github_fast_path(
&self, &self,
url: &GitUrl, url: &GitUrl,
client: ClientWithMiddleware, client: &ClientWithMiddleware,
) -> Result<Option<GitOid>, GitResolverError> { ) -> Result<Option<GitOid>, GitResolverError> {
let reference = RepositoryReference::from(url); let reference = RepositoryReference::from(url);
@ -112,7 +112,7 @@ impl GitResolver {
pub async fn fetch( pub async fn fetch(
&self, &self,
url: &GitUrl, url: &GitUrl,
client: ClientWithMiddleware, client: impl Into<ClientWithMiddleware>,
disable_ssl: bool, disable_ssl: bool,
offline: bool, offline: bool,
cache: PathBuf, cache: PathBuf,

View file

@ -12,7 +12,6 @@ use itertools::Itertools;
use reqwest::header::AUTHORIZATION; use reqwest::header::AUTHORIZATION;
use reqwest::multipart::Part; use reqwest::multipart::Part;
use reqwest::{Body, Response, StatusCode}; use reqwest::{Body, Response, StatusCode};
use reqwest_middleware::RequestBuilder;
use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{RetryPolicy, Retryable, RetryableStrategy}; use reqwest_retry::{RetryPolicy, Retryable, RetryableStrategy};
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
@ -28,8 +27,8 @@ use url::Url;
use uv_auth::Credentials; use uv_auth::Credentials;
use uv_cache::{Cache, Refresh}; use uv_cache::{Cache, Refresh};
use uv_client::{ use uv_client::{
BaseClient, MetadataFormat, OwnedArchive, RegistryClientBuilder, UvRetryableStrategy, BaseClient, MetadataFormat, OwnedArchive, RegistryClientBuilder, RequestBuilder,
DEFAULT_RETRIES, UvRetryableStrategy, DEFAULT_RETRIES,
}; };
use uv_configuration::{KeyringProviderType, TrustedPublishing}; use uv_configuration::{KeyringProviderType, TrustedPublishing};
use uv_distribution_filename::{DistFilename, SourceDistExtension, SourceDistFilename}; use uv_distribution_filename::{DistFilename, SourceDistExtension, SourceDistFilename};
@ -320,7 +319,9 @@ pub async fn check_trusted_publishing(
// We could check for credentials from the keyring or netrc the auth middleware first, but // We could check for credentials from the keyring or netrc the auth middleware first, but
// given that we are in GitHub Actions we check for trusted publishing first. // given that we are in GitHub Actions we check for trusted publishing first.
debug!("Running on GitHub Actions without explicit credentials, checking for trusted publishing"); debug!("Running on GitHub Actions without explicit credentials, checking for trusted publishing");
match trusted_publishing::get_token(registry, client.for_host(registry)).await { match trusted_publishing::get_token(registry, client.for_host(registry).raw_client())
.await
{
Ok(token) => Ok(TrustedPublishResult::Configured(token)), Ok(token) => Ok(TrustedPublishResult::Configured(token)),
Err(err) => { Err(err) => {
// TODO(konsti): It would be useful if we could differentiate between actual errors // TODO(konsti): It would be useful if we could differentiate between actual errors
@ -354,7 +355,9 @@ pub async fn check_trusted_publishing(
); );
} }
let token = trusted_publishing::get_token(registry, client.for_host(registry)).await?; let token =
trusted_publishing::get_token(registry, client.for_host(registry).raw_client())
.await?;
Ok(TrustedPublishResult::Configured(token)) Ok(TrustedPublishResult::Configured(token))
} }
TrustedPublishing::Never => Ok(TrustedPublishResult::Skipped), TrustedPublishing::Never => Ok(TrustedPublishResult::Skipped),
@ -738,16 +741,16 @@ async fn form_metadata(
/// Build the upload request. /// Build the upload request.
/// ///
/// Returns the request and the reporter progress bar id. /// Returns the request and the reporter progress bar id.
async fn build_request( async fn build_request<'a>(
file: &Path, file: &Path,
raw_filename: &str, raw_filename: &str,
filename: &DistFilename, filename: &DistFilename,
registry: &Url, registry: &Url,
client: &BaseClient, client: &'a BaseClient,
credentials: &Credentials, credentials: &Credentials,
form_metadata: &[(&'static str, String)], form_metadata: &[(&'static str, String)],
reporter: Arc<impl Reporter>, reporter: Arc<impl Reporter>,
) -> Result<(RequestBuilder, usize), PublishPrepareError> { ) -> Result<(RequestBuilder<'a>, usize), PublishPrepareError> {
let mut form = reqwest::multipart::Form::new(); let mut form = reqwest::multipart::Form::new();
for (key, value) in form_metadata { for (key, value) in form_metadata {
form = form.text(*key, value.clone()); form = form.text(*key, value.clone());
@ -959,12 +962,13 @@ mod tests {
project_urls: Source, https://github.com/unknown/tqdm project_urls: Source, https://github.com/unknown/tqdm
"###); "###);
let client = BaseClientBuilder::new().build();
let (request, _) = build_request( let (request, _) = build_request(
&file, &file,
raw_filename, raw_filename,
&filename, &filename,
&Url::parse("https://example.org/upload").unwrap(), &Url::parse("https://example.org/upload").unwrap(),
&BaseClientBuilder::new().build(), &client,
&Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())),
&form_metadata, &form_metadata,
Arc::new(DummyReporter), Arc::new(DummyReporter),
@ -975,7 +979,7 @@ mod tests {
insta::with_settings!({ insta::with_settings!({
filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], filters => [("boundary=[0-9a-f-]+", "boundary=[...]")],
}, { }, {
assert_debug_snapshot!(&request, @r#" assert_debug_snapshot!(&request.raw_builder(), @r#"
RequestBuilder { RequestBuilder {
inner: RequestBuilder { inner: RequestBuilder {
method: POST, method: POST,
@ -1109,12 +1113,13 @@ mod tests {
requires_dist: requests ; extra == 'telegram' requires_dist: requests ; extra == 'telegram'
"###); "###);
let client = BaseClientBuilder::new().build();
let (request, _) = build_request( let (request, _) = build_request(
&file, &file,
raw_filename, raw_filename,
&filename, &filename,
&Url::parse("https://example.org/upload").unwrap(), &Url::parse("https://example.org/upload").unwrap(),
&BaseClientBuilder::new().build(), &client,
&Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())),
&form_metadata, &form_metadata,
Arc::new(DummyReporter), Arc::new(DummyReporter),
@ -1125,7 +1130,7 @@ mod tests {
insta::with_settings!({ insta::with_settings!({
filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], filters => [("boundary=[0-9a-f-]+", "boundary=[...]")],
}, { }, {
assert_debug_snapshot!(&request, @r#" assert_debug_snapshot!(&request.raw_builder(), @r#"
RequestBuilder { RequestBuilder {
inner: RequestBuilder { inner: RequestBuilder {
method: POST, method: POST,

View file

@ -1605,8 +1605,7 @@ pub async fn download_to_disk(url: &str, path: &Path) {
.allow_insecure_host(trusted_hosts) .allow_insecure_host(trusted_hosts)
.build(); .build();
let url: reqwest::Url = url.parse().unwrap(); let url: reqwest::Url = url.parse().unwrap();
let client = client.for_host(&url); let response = client.for_host(&url).get(url).send().await.unwrap();
let response = client.request(http::Method::GET, url).send().await.unwrap();
let mut file = tokio::fs::File::create(path).await.unwrap(); let mut file = tokio::fs::File::create(path).await.unwrap();
let mut stream = response.bytes_stream(); let mut stream = response.bytes_stream();

View file

@ -3,22 +3,24 @@
#[cfg(feature = "git")] #[cfg(feature = "git")]
mod conditional_imports { mod conditional_imports {
pub(crate) use crate::common::{decode_token, READ_ONLY_GITHUB_TOKEN}; pub(crate) use crate::common::{decode_token, READ_ONLY_GITHUB_TOKEN};
pub(crate) use assert_cmd::assert::OutputAssertExt;
} }
#[cfg(feature = "git")] #[cfg(feature = "git")]
use conditional_imports::*; use conditional_imports::*;
use anyhow::Result; use anyhow::Result;
use assert_cmd::assert::OutputAssertExt;
use assert_fs::prelude::*; use assert_fs::prelude::*;
use indoc::{formatdoc, indoc}; use indoc::{formatdoc, indoc};
use insta::assert_snapshot; use insta::assert_snapshot;
use std::path::Path; use std::path::Path;
use url::Url;
use uv_fs::Simplified; use uv_fs::Simplified;
use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};
use uv_static::EnvVars; use uv_static::EnvVars;
use crate::common::{packse_index_url, uv_snapshot, TestContext}; use crate::common::{packse_index_url, uv_snapshot, venv_bin_path, TestContext};
/// Add a PyPI requirement. /// Add a PyPI requirement.
#[test] #[test]
@ -10748,6 +10750,197 @@ fn add_auth_policy_never_without_credentials() -> Result<()> {
Ok(()) Ok(())
} }
/// If uv receives a 302 redirect, it should use supplied credentials for the
/// new location.
#[tokio::test]
async fn add_redirect() -> Result<()> {
let context = TestContext::new("3.12");
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! { r#"
[project]
name = "foo"
version = "1.0.0"
requires-python = ">=3.12"
dependencies = []
"#
})?;
let redirect_server = MockServer::start().await;
Mock::given(method("GET"))
.respond_with(|req: &wiremock::Request| {
let redirect_url = redirect_url_to_pypi_proxy(req);
ResponseTemplate::new(302).insert_header("Location", &redirect_url)
})
.mount(&redirect_server)
.await;
let mut redirect_url = Url::parse(&redirect_server.uri())?;
let _ = redirect_url.set_username("public");
let _ = redirect_url.set_password(Some("heron"));
uv_snapshot!(context.add().arg("--default-index").arg(redirect_url.as_str()).arg("anyio"), @r"
success: true
exit_code: 0
----- stdout -----
----- stderr -----
Resolved 4 packages in [TIME]
Prepared 3 packages in [TIME]
Installed 3 packages in [TIME]
+ anyio==4.3.0
+ idna==3.6
+ sniffio==1.3.1
"
);
context.assert_command("import anyio").success();
Ok(())
}
/// If uv receives a 302 redirect, it should use credentials from the keyring
/// for the new location.
#[tokio::test]
async fn add_redirect_with_keyring() -> Result<()> {
let keyring_context = TestContext::new("3.12");
// Install our keyring plugin
keyring_context
.pip_install()
.arg(
keyring_context
.workspace_root
.join("scripts")
.join("packages")
.join("keyring_test_plugin"),
)
.assert()
.success();
let context = TestContext::new("3.12");
let filters = context
.filters()
.into_iter()
.chain([(r"127\.0\.0\.1[^\r\n]*", "[LOCALHOST]")])
.collect::<Vec<_>>();
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! { r#"
[project]
name = "foo"
version = "1.0.0"
requires-python = ">=3.12"
dependencies = []
[tool.uv]
keyring-provider = "subprocess"
"#,
})?;
let redirect_server = MockServer::start().await;
Mock::given(method("GET"))
.respond_with(|req: &wiremock::Request| {
let redirect_url = redirect_url_to_pypi_proxy(req);
ResponseTemplate::new(302).insert_header("Location", &redirect_url)
})
.mount(&redirect_server)
.await;
let mut redirect_url = Url::parse(&redirect_server.uri())?;
let _ = redirect_url.set_username("public");
uv_snapshot!(filters, context.add().arg("--default-index")
.arg(redirect_url.as_str())
.arg("anyio")
.env(EnvVars::KEYRING_TEST_CREDENTIALS, r#"{"pypi-proxy.fly.dev": {"public": "heron"}}"#)
.env(EnvVars::PATH, venv_bin_path(&keyring_context.venv)), @r"
success: true
exit_code: 0
----- stdout -----
----- stderr -----
Request for public@http://[LOCALHOST]
Request for public@[LOCALHOST]
Request for public@https://pypi-proxy.fly.dev/basic-auth/simple/anyio/
Request for public@pypi-proxy.fly.dev
Resolved 4 packages in [TIME]
Prepared 3 packages in [TIME]
Installed 3 packages in [TIME]
+ anyio==4.3.0
+ idna==3.6
+ sniffio==1.3.1
"
);
context.assert_command("import anyio").success();
Ok(())
}
/// If uv receives a 302 redirect, it should use credentials from netrc
/// for the new location.
#[tokio::test]
async fn add_redirect_with_netrc() -> Result<()> {
let context = TestContext::new("3.12");
let filters = context
.filters()
.into_iter()
.chain([(r"127\.0\.0\.1[^\r\n]*", "[LOCALHOST]")])
.collect::<Vec<_>>();
let netrc = context.temp_dir.child(".netrc");
netrc.write_str("machine pypi-proxy.fly.dev login public password heron")?;
let redirect_server = MockServer::start().await;
Mock::given(method("GET"))
.respond_with(|req: &wiremock::Request| {
let redirect_url = redirect_url_to_pypi_proxy(req);
ResponseTemplate::new(302).insert_header("Location", &redirect_url)
})
.mount(&redirect_server)
.await;
let mut redirect_url = Url::parse(&redirect_server.uri())?;
let _ = redirect_url.set_username("public");
uv_snapshot!(filters, context.pip_install()
.arg("anyio")
.arg("--index-url")
.arg(redirect_url.as_str())
.env(EnvVars::NETRC, netrc.to_str().unwrap())
.arg("--strict"), @r###"
success: true
exit_code: 0
----- stdout -----
----- stderr -----
Resolved 3 packages in [TIME]
Prepared 3 packages in [TIME]
Installed 3 packages in [TIME]
+ anyio==4.3.0
+ idna==3.6
+ sniffio==1.3.1
"###
);
context.assert_command("import anyio").success();
Ok(())
}
fn redirect_url_to_pypi_proxy(req: &wiremock::Request) -> String {
let last_path_segment = req
.url
.path_segments()
.expect("path has segments")
.filter(|segment| !segment.is_empty()) // Filter out empty segments
.next_back()
.expect("path has a package segment");
format!("https://pypi-proxy.fly.dev/basic-auth/simple/{last_path_segment}/")
}
/// Test the error message when adding a package with multiple existing references in /// Test the error message when adding a package with multiple existing references in
/// `pyproject.toml`. /// `pyproject.toml`.
#[test] #[test]