diff --git a/Cargo.lock b/Cargo.lock index bc209c8f6..312b1d11a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2882,6 +2882,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "reqwest-netrc" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca0c58cd4b2978f9697dea94302e772399f559cd175356eb631cb6daaa0b6db" +dependencies = [ + "reqwest-middleware", + "rust-netrc", +] + [[package]] name = "reqwest-retry" version = "0.3.0" @@ -3037,6 +3047,15 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f" +[[package]] +name = "rust-netrc" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32662f97cbfdbad9d5f78f1338116f06871e7dae4fd37e9f59a0f57cf2044868" +dependencies = [ + "thiserror", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -4316,6 +4335,7 @@ dependencies = [ "pypi-types", "reqwest", "reqwest-middleware", + "reqwest-netrc", "reqwest-retry", "rkyv", "rmp-serde", diff --git a/Cargo.toml b/Cargo.toml index 07724e332..b3ff6fb50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ reflink-copy = { version = "0.1.14" } regex = { version = "1.10.2" } reqwest = { version = "0.11.23", default-features = false, features = ["json", "gzip", "brotli", "stream", "rustls-tls-native-roots"] } reqwest-middleware = { version = "0.2.4" } +reqwest-netrc = { version = "0.1.1" } reqwest-retry = { version = "0.3.0" } rkyv = { version = "0.7.43", features = ["strict", "validation"] } rmp-serde = { version = "1.1.2" } diff --git a/crates/uv-client/Cargo.toml b/crates/uv-client/Cargo.toml index c1bb9f877..1303df3eb 100644 --- a/crates/uv-client/Cargo.toml +++ b/crates/uv-client/Cargo.toml @@ -30,6 +30,7 @@ html-escape = { workspace = true } http = { workspace = true } reqwest = { workspace = true } reqwest-middleware = { workspace = true } +reqwest-netrc = { workspace = true } reqwest-retry = { workspace = true } rkyv = { workspace = true, features = ["strict", "validation"] } rmp-serde = { workspace = true } diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 0a19698f2..994b83185 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -8,6 +8,7 @@ use async_http_range_reader::AsyncHttpRangeReader; use futures::{FutureExt, TryStreamExt}; use http::HeaderMap; use reqwest::{Client, ClientBuilder, Response, StatusCode}; +use reqwest_netrc::NetrcMiddleware; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::RetryTransientMiddleware; use serde::{Deserialize, Serialize}; @@ -122,12 +123,22 @@ impl RegistryClientBuilder { // Wrap in any relevant middleware. let client = match self.connectivity { Connectivity::Online => { + let client = reqwest_middleware::ClientBuilder::new(client.clone()); + + // Initialize the retry strategy. let retry_policy = ExponentialBackoff::builder().build_with_max_retries(self.retries); let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy); - reqwest_middleware::ClientBuilder::new(client.clone()) - .with(retry_strategy) - .build() + let client = client.with(retry_strategy); + + // Initialize the netrc middleware. + let client = if let Ok(netrc) = NetrcMiddleware::new() { + client.with_init(netrc) + } else { + client + }; + + client.build() } Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone()) .with(OfflineMiddleware) diff --git a/crates/uv-client/tests/netrc_auth.rs b/crates/uv-client/tests/netrc_auth.rs new file mode 100644 index 000000000..d23103e40 --- /dev/null +++ b/crates/uv-client/tests/netrc_auth.rs @@ -0,0 +1,68 @@ +use std::env; +use std::io::Write; + +use anyhow::Result; +use futures::future; +use hyper::header::AUTHORIZATION; +use hyper::server::conn::Http; +use hyper::service::service_fn; +use hyper::{Body, Request, Response}; +use tempfile::NamedTempFile; +use tokio::net::TcpListener; + +use uv_cache::Cache; +use uv_client::RegistryClientBuilder; + +#[tokio::test] +async fn test_client_with_netrc_credentials() -> Result<()> { + // Set up the TCP listener on a random available port + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + // Spawn the server loop in a background task + tokio::spawn(async move { + let svc = service_fn(move |req: Request| { + // Get User Agent Header and send it back in the response + let auth = req + .headers() + .get(AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_default(); // Empty Default + future::ok::<_, hyper::Error>(Response::new(Body::from(auth))) + }); + // Start Hyper Server + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http1_keep_alive(false) + .serve_connection(socket, svc) + .with_upgrades() + .await + .expect("Server Started"); + }); + + // Create a netrc file + let mut netrc_file = NamedTempFile::new()?; + env::set_var("NETRC", netrc_file.path()); + writeln!(netrc_file, "machine 127.0.0.1 login user password 1234")?; + + // Initialize uv-client + let cache = Cache::temp()?; + let client = RegistryClientBuilder::new(cache).build(); + + // Send request to our dummy server + let res = client + .cached_client() + .uncached() + .get(format!("http://{addr}")) + .send() + .await?; + + // Check the HTTP status + assert!(res.status().is_success()); + + // Verify auth header + assert_eq!(res.text().await?, "Basic dXNlcjoxMjM0"); + + Ok(()) +}