From d0efe1ed9c4bec806e8449a471bd97c9be10ba14 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Fri, 18 Jul 2025 16:32:29 -0400 Subject: [PATCH] Apply Cache-Control overrides to response, not request headers (#14736) ## Summary This was just an oversight on my part in the initial implementation. Closes https://github.com/astral-sh/uv/issues/14719. ## Test Plan With: ```toml [project] name = "foo" version = "0.1.0" description = "Add your description here" readme = "README.md" requires-python = ">=3.13.2" dependencies = [ ] [[tool.uv.index]] url = "https://download.pytorch.org/whl/cpu" cache-control = { api = "max-age=600" } ``` Ran `cargo run lock -vvv` and verified that the PyTorch index response was cached (whereas it typically returns `cache-control: no-cache,no-store,must-revalidate`). --- crates/uv-client/src/cached_client.rs | 83 ++++++++++++++----- crates/uv-distribution-types/src/index_url.rs | 20 +++++ .../src/distribution_database.rs | 68 +++++++++++---- crates/uv-distribution/src/source/mod.rs | 80 +++++++++++++++--- 4 files changed, 201 insertions(+), 50 deletions(-) diff --git a/crates/uv-client/src/cached_client.rs b/crates/uv-client/src/cached_client.rs index f888ea5f1..4219decd5 100644 --- a/crates/uv-client/src/cached_client.rs +++ b/crates/uv-client/src/cached_client.rs @@ -304,7 +304,7 @@ impl CachedClient { .await? } else { debug!("No cache entry for: {}", req.url()); - let (response, cache_policy) = self.fresh_request(req).await?; + let (response, cache_policy) = self.fresh_request(req, cache_control).await?; CachedResponse::ModifiedOrNew { response, cache_policy, @@ -318,8 +318,13 @@ impl CachedClient { "Broken fresh cache entry (for payload) at {}, removing: {err}", cache_entry.path().display() ); - self.resend_and_heal_cache(fresh_req, cache_entry, response_callback) - .await + self.resend_and_heal_cache( + fresh_req, + cache_entry, + cache_control, + response_callback, + ) + .await } }, CachedResponse::NotModified { cached, new_policy } => { @@ -339,8 +344,13 @@ impl CachedClient { (for payload) at {}, removing: {err}", cache_entry.path().display() ); - self.resend_and_heal_cache(fresh_req, cache_entry, response_callback) - .await + self.resend_and_heal_cache( + fresh_req, + cache_entry, + cache_control, + response_callback, + ) + .await } } } @@ -355,8 +365,13 @@ impl CachedClient { // ETag didn't match). We need to make a fresh request. if response.status() == http::StatusCode::NOT_MODIFIED { warn!("Server returned unusable 304 for: {}", fresh_req.url()); - self.resend_and_heal_cache(fresh_req, cache_entry, response_callback) - .await + self.resend_and_heal_cache( + fresh_req, + cache_entry, + cache_control, + response_callback, + ) + .await } else { self.run_response_callback( cache_entry, @@ -379,9 +394,10 @@ impl CachedClient { &self, req: Request, cache_entry: &CacheEntry, + cache_control: CacheControl<'_>, response_callback: Callback, ) -> Result> { - let (response, cache_policy) = self.fresh_request(req).await?; + let (response, cache_policy) = self.fresh_request(req, cache_control).await?; let payload = self .run_response_callback(cache_entry, cache_policy, response, async |resp| { @@ -401,10 +417,11 @@ impl CachedClient { &self, req: Request, cache_entry: &CacheEntry, + cache_control: CacheControl<'_>, response_callback: Callback, ) -> Result> { let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; - let (response, cache_policy) = self.fresh_request(req).await?; + let (response, cache_policy) = self.fresh_request(req, cache_control).await?; self.run_response_callback(cache_entry, cache_policy, response, response_callback) .await } @@ -476,20 +493,13 @@ impl CachedClient { ) -> Result { // Apply the cache control header, if necessary. match cache_control { - CacheControl::None | CacheControl::AllowStale => {} + CacheControl::None | CacheControl::AllowStale | CacheControl::Override(..) => {} CacheControl::MustRevalidate => { req.headers_mut().insert( http::header::CACHE_CONTROL, http::HeaderValue::from_static("no-cache"), ); } - CacheControl::Override(value) => { - req.headers_mut().insert( - http::header::CACHE_CONTROL, - http::HeaderValue::from_str(value) - .map_err(|_| ErrorKind::InvalidCacheControl(value.to_string()))?, - ); - } } Ok(match cached.cache_policy.before_request(&mut req) { BeforeRequest::Fresh => { @@ -499,8 +509,13 @@ impl CachedClient { BeforeRequest::Stale(new_cache_policy_builder) => match cache_control { CacheControl::None | CacheControl::MustRevalidate | CacheControl::Override(_) => { debug!("Found stale response for: {}", req.url()); - self.send_cached_handle_stale(req, cached, new_cache_policy_builder) - .await? + self.send_cached_handle_stale( + req, + cache_control, + cached, + new_cache_policy_builder, + ) + .await? } CacheControl::AllowStale => { debug!("Found stale (but allowed) response for: {}", req.url()); @@ -513,7 +528,7 @@ impl CachedClient { "Cached request doesn't match current request for: {}", req.url() ); - let (response, cache_policy) = self.fresh_request(req).await?; + let (response, cache_policy) = self.fresh_request(req, cache_control).await?; CachedResponse::ModifiedOrNew { response, cache_policy, @@ -525,12 +540,13 @@ impl CachedClient { async fn send_cached_handle_stale( &self, req: Request, + cache_control: CacheControl<'_>, cached: DataWithCachePolicy, new_cache_policy_builder: CachePolicyBuilder, ) -> Result { let url = DisplaySafeUrl::from(req.url().clone()); debug!("Sending revalidation request for: {url}"); - let response = self + let mut response = self .0 .execute(req) .instrument(info_span!("revalidation_request", url = url.as_str())) @@ -538,6 +554,16 @@ impl CachedClient { .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))? .error_for_status() .map_err(|err| ErrorKind::from_reqwest(url.clone(), err))?; + + // If the user set a custom `Cache-Control` header, override it. + if let CacheControl::Override(header) = cache_control { + response.headers_mut().insert( + http::header::CACHE_CONTROL, + http::HeaderValue::from_str(header) + .expect("Cache-Control header must be valid UTF-8"), + ); + } + match cached .cache_policy .after_response(new_cache_policy_builder, &response) @@ -566,16 +592,26 @@ impl CachedClient { async fn fresh_request( &self, req: Request, + cache_control: CacheControl<'_>, ) -> Result<(Response, Option>), Error> { let url = DisplaySafeUrl::from(req.url().clone()); trace!("Sending fresh {} request for {}", req.method(), url); let cache_policy_builder = CachePolicyBuilder::new(&req); - let response = self + let mut response = self .0 .execute(req) .await .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?; + // If the user set a custom `Cache-Control` header, override it. + if let CacheControl::Override(header) = cache_control { + response.headers_mut().insert( + http::header::CACHE_CONTROL, + http::HeaderValue::from_str(header) + .expect("Cache-Control header must be valid UTF-8"), + ); + } + let retry_count = response .extensions() .get::() @@ -690,6 +726,7 @@ impl CachedClient { &self, req: Request, cache_entry: &CacheEntry, + cache_control: CacheControl<'_>, response_callback: Callback, ) -> Result> { let mut past_retries = 0; @@ -698,7 +735,7 @@ impl CachedClient { loop { let fresh_req = req.try_clone().expect("HTTP request must be cloneable"); let result = self - .skip_cache(fresh_req, cache_entry, &response_callback) + .skip_cache(fresh_req, cache_entry, cache_control, &response_callback) .await; // Check if the middleware already performed retries diff --git a/crates/uv-distribution-types/src/index_url.rs b/crates/uv-distribution-types/src/index_url.rs index cbc1a4eb1..6baca1c1f 100644 --- a/crates/uv-distribution-types/src/index_url.rs +++ b/crates/uv-distribution-types/src/index_url.rs @@ -441,6 +441,26 @@ impl<'a> IndexLocations { } } } + + /// Return the Simple API cache control header for an [`IndexUrl`], if configured. + pub fn simple_api_cache_control_for(&self, url: &IndexUrl) -> Option<&str> { + for index in &self.indexes { + if index.url() == url { + return index.cache_control.as_ref()?.api.as_deref(); + } + } + None + } + + /// Return the artifact cache control header for an [`IndexUrl`], if configured. + pub fn artifact_cache_control_for(&self, url: &IndexUrl) -> Option<&str> { + for index in &self.indexes { + if index.url() == url { + return index.cache_control.as_ref()?.files.as_deref(); + } + } + None + } } impl From<&IndexLocations> for uv_auth::Indexes { diff --git a/crates/uv-distribution/src/distribution_database.rs b/crates/uv-distribution/src/distribution_database.rs index d18269730..30f3a243c 100644 --- a/crates/uv-distribution/src/distribution_database.rs +++ b/crates/uv-distribution/src/distribution_database.rs @@ -20,7 +20,7 @@ use uv_client::{ }; use uv_distribution_filename::WheelFilename; use uv_distribution_types::{ - BuildableSource, BuiltDist, Dist, HashPolicy, Hashed, InstalledDist, Name, SourceDist, + BuildableSource, BuiltDist, Dist, HashPolicy, Hashed, IndexUrl, InstalledDist, Name, SourceDist, }; use uv_extract::hash::Hasher; use uv_fs::write_atomic; @@ -201,6 +201,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { match self .stream_wheel( url.clone(), + dist.index(), &wheel.filename, wheel.file.size, &wheel_entry, @@ -236,6 +237,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { let archive = self .download_wheel( url, + dist.index(), &wheel.filename, wheel.file.size, &wheel_entry, @@ -272,6 +274,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { match self .stream_wheel( wheel.url.raw().clone(), + None, &wheel.filename, None, &wheel_entry, @@ -301,6 +304,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { let archive = self .download_wheel( wheel.url.raw().clone(), + None, &wheel.filename, None, &wheel_entry, @@ -534,6 +538,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { async fn stream_wheel( &self, url: DisplaySafeUrl, + index: Option<&IndexUrl>, filename: &WheelFilename, size: Option, wheel_entry: &CacheEntry, @@ -616,13 +621,24 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { // Fetch the archive from the cache, or download it if necessary. let req = self.request(url.clone())?; + // Determine the cache control policy for the URL. let cache_control = match self.client.unmanaged.connectivity() { - Connectivity::Online => CacheControl::from( - self.build_context - .cache() - .freshness(&http_entry, Some(&filename.name), None) - .map_err(Error::CacheRead)?, - ), + Connectivity::Online => { + if let Some(header) = index.and_then(|index| { + self.build_context + .locations() + .artifact_cache_control_for(index) + }) { + CacheControl::Override(header) + } else { + CacheControl::from( + self.build_context + .cache() + .freshness(&http_entry, Some(&filename.name), None) + .map_err(Error::CacheRead)?, + ) + } + } Connectivity::Offline => CacheControl::AllowStale, }; @@ -654,7 +670,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { .managed(async |client| { client .cached_client() - .skip_cache_with_retry(self.request(url)?, &http_entry, download) + .skip_cache_with_retry( + self.request(url)?, + &http_entry, + cache_control, + download, + ) .await .map_err(|err| match err { CachedClientError::Callback { err, .. } => err, @@ -671,6 +692,7 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { async fn download_wheel( &self, url: DisplaySafeUrl, + index: Option<&IndexUrl>, filename: &WheelFilename, size: Option, wheel_entry: &CacheEntry, @@ -783,13 +805,24 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { // Fetch the archive from the cache, or download it if necessary. let req = self.request(url.clone())?; + // Determine the cache control policy for the URL. let cache_control = match self.client.unmanaged.connectivity() { - Connectivity::Online => CacheControl::from( - self.build_context - .cache() - .freshness(&http_entry, Some(&filename.name), None) - .map_err(Error::CacheRead)?, - ), + Connectivity::Online => { + if let Some(header) = index.and_then(|index| { + self.build_context + .locations() + .artifact_cache_control_for(index) + }) { + CacheControl::Override(header) + } else { + CacheControl::from( + self.build_context + .cache() + .freshness(&http_entry, Some(&filename.name), None) + .map_err(Error::CacheRead)?, + ) + } + } Connectivity::Offline => CacheControl::AllowStale, }; @@ -821,7 +854,12 @@ impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> { .managed(async |client| { client .cached_client() - .skip_cache_with_retry(self.request(url)?, &http_entry, download) + .skip_cache_with_retry( + self.request(url)?, + &http_entry, + cache_control, + download, + ) .await .map_err(|err| match err { CachedClientError::Callback { err, .. } => err, diff --git a/crates/uv-distribution/src/source/mod.rs b/crates/uv-distribution/src/source/mod.rs index 080a1e52d..66b6122e0 100644 --- a/crates/uv-distribution/src/source/mod.rs +++ b/crates/uv-distribution/src/source/mod.rs @@ -32,7 +32,7 @@ use uv_client::{ use uv_configuration::{BuildKind, BuildOutput, ConfigSettings, SourceStrategy}; use uv_distribution_filename::{SourceDistExtension, WheelFilename}; use uv_distribution_types::{ - BuildableSource, DirectorySourceUrl, GitSourceUrl, HashPolicy, Hashed, PathSourceUrl, + BuildableSource, DirectorySourceUrl, GitSourceUrl, HashPolicy, Hashed, IndexUrl, PathSourceUrl, SourceDist, SourceUrl, }; use uv_extract::hash::Hasher; @@ -148,6 +148,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { self.url( source, &url, + Some(&dist.index), &cache_shard, None, dist.ext, @@ -168,6 +169,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { self.url( source, &dist.url, + None, &cache_shard, dist.subdirectory.as_deref(), dist.ext, @@ -213,6 +215,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { self.url( source, resource.url, + None, &cache_shard, resource.subdirectory, resource.ext, @@ -288,9 +291,18 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { .await; } - self.url_metadata(source, &url, &cache_shard, None, dist.ext, hashes, client) - .boxed_local() - .await? + self.url_metadata( + source, + &url, + Some(&dist.index), + &cache_shard, + None, + dist.ext, + hashes, + client, + ) + .boxed_local() + .await? } BuildableSource::Dist(SourceDist::DirectUrl(dist)) => { // For direct URLs, cache directly under the hash of the URL itself. @@ -302,6 +314,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { self.url_metadata( source, &dist.url, + None, &cache_shard, dist.subdirectory.as_deref(), dist.ext, @@ -340,6 +353,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { self.url_metadata( source, resource.url, + None, &cache_shard, resource.subdirectory, resource.ext, @@ -395,6 +409,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { &self, source: &BuildableSource<'data>, url: &'data DisplaySafeUrl, + index: Option<&'data IndexUrl>, cache_shard: &CacheShard, subdirectory: Option<&'data Path>, ext: SourceDistExtension, @@ -406,7 +421,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { // Fetch the revision for the source distribution. let revision = self - .url_revision(source, ext, url, cache_shard, hashes, client) + .url_revision(source, ext, url, index, cache_shard, hashes, client) .await?; // Before running the build, check that the hashes match. @@ -448,6 +463,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { source, ext, url, + index, &source_dist_entry, revision, hashes, @@ -511,6 +527,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { &self, source: &BuildableSource<'data>, url: &'data Url, + index: Option<&'data IndexUrl>, cache_shard: &CacheShard, subdirectory: Option<&'data Path>, ext: SourceDistExtension, @@ -521,7 +538,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { // Fetch the revision for the source distribution. let revision = self - .url_revision(source, ext, url, cache_shard, hashes, client) + .url_revision(source, ext, url, index, cache_shard, hashes, client) .await?; // Before running the build, check that the hashes match. @@ -578,6 +595,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { source, ext, url, + index, &source_dist_entry, revision, hashes, @@ -689,18 +707,31 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { source: &BuildableSource<'_>, ext: SourceDistExtension, url: &Url, + index: Option<&IndexUrl>, cache_shard: &CacheShard, hashes: HashPolicy<'_>, client: &ManagedClient<'_>, ) -> Result { let cache_entry = cache_shard.entry(HTTP_REVISION); + + // Determine the cache control policy for the request. let cache_control = match client.unmanaged.connectivity() { - Connectivity::Online => CacheControl::from( - self.build_context - .cache() - .freshness(&cache_entry, source.name(), source.source_tree()) - .map_err(Error::CacheRead)?, - ), + Connectivity::Online => { + if let Some(header) = index.and_then(|index| { + self.build_context + .locations() + .artifact_cache_control_for(index) + }) { + CacheControl::Override(header) + } else { + CacheControl::from( + self.build_context + .cache() + .freshness(&cache_entry, source.name(), source.source_tree()) + .map_err(Error::CacheRead)?, + ) + } + } Connectivity::Offline => CacheControl::AllowStale, }; @@ -750,6 +781,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { .skip_cache_with_retry( Self::request(DisplaySafeUrl::from(url.clone()), client)?, &cache_entry, + cache_control, download, ) .await @@ -2056,6 +2088,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { source: &BuildableSource<'_>, ext: SourceDistExtension, url: &Url, + index: Option<&IndexUrl>, entry: &CacheEntry, revision: Revision, hashes: HashPolicy<'_>, @@ -2063,6 +2096,28 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { ) -> Result { warn!("Re-downloading missing source distribution: {source}"); let cache_entry = entry.shard().entry(HTTP_REVISION); + + // Determine the cache control policy for the request. + let cache_control = match client.unmanaged.connectivity() { + Connectivity::Online => { + if let Some(header) = index.and_then(|index| { + self.build_context + .locations() + .artifact_cache_control_for(index) + }) { + CacheControl::Override(header) + } else { + CacheControl::from( + self.build_context + .cache() + .freshness(&cache_entry, source.name(), source.source_tree()) + .map_err(Error::CacheRead)?, + ) + } + } + Connectivity::Offline => CacheControl::AllowStale, + }; + let download = |response| { async { // Take the union of the requested and existing hash algorithms. @@ -2096,6 +2151,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { .skip_cache_with_retry( Self::request(DisplaySafeUrl::from(url.clone()), client)?, &cache_entry, + cache_control, download, ) .await