diff --git a/crates/puffin-client/src/cached_client.rs b/crates/puffin-client/src/cached_client.rs index f98e3fdbb..2cb34f5c4 100644 --- a/crates/puffin-client/src/cached_client.rs +++ b/crates/puffin-client/src/cached_client.rs @@ -4,7 +4,7 @@ use std::time::SystemTime; use futures::FutureExt; use http::request::Parts; use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy}; -use reqwest::{Body, Request, Response}; +use reqwest::{Request, Response}; use reqwest_middleware::ClientWithMiddleware; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -206,18 +206,10 @@ impl CachedClient { /// `http-cache-semantics` to `reqwest` wrapper async fn send_cached( &self, - req: Request, + mut req: Request, cache_control: CacheControl, cached: Option>, ) -> Result, Error> { - // The converted types are from the specific `reqwest` types to the more generic `http` - // types. - let mut converted_req = http::Request::try_from( - req.try_clone() - .expect("You can't use streaming request bodies with this function"), - ) - .map_err(ErrorKind::RequestError)?; - let url = req.url().clone(); let cached_response = if let Some(cached) = cached { // Avoid sending revalidation requests for immutable responses. @@ -230,7 +222,7 @@ impl CachedClient { match cache_control { CacheControl::None => {} CacheControl::MustRevalidate => { - converted_req.headers_mut().insert( + req.headers_mut().insert( http::header::CACHE_CONTROL, http::HeaderValue::from_static("max-age=0, must-revalidate"), ); @@ -239,27 +231,20 @@ impl CachedClient { match cached .cache_policy - .before_request(&converted_req, SystemTime::now()) + .before_request(&RequestLikeReqwest(&req), SystemTime::now()) { BeforeRequest::Fresh(_) => { debug!("Found fresh response for: {url}"); CachedResponse::FreshCache(cached.data) } BeforeRequest::Stale { request, matches } => { - self.send_cached_handle_stale( - req, - converted_req, - url, - cached, - &request, - matches, - ) - .await? + self.send_cached_handle_stale(req, url, cached, &request, matches) + .await? } } } else { debug!("No cache entry for: {url}"); - self.fresh_request(req, converted_req).await? + self.fresh_request(req).await? }; Ok(cached_response) } @@ -267,7 +252,6 @@ impl CachedClient { async fn send_cached_handle_stale( &self, mut req: Request, - mut converted_req: http::Request, url: Url, cached: DataWithCachePolicy, request: &Parts, @@ -276,36 +260,26 @@ impl CachedClient { if !matches { // This shouldn't happen; if it does, we'll override the cache. warn!("Cached request doesn't match current request for: {url}"); - return self.fresh_request(req, converted_req).await; + return self.fresh_request(req).await; } debug!("Sending revalidation request for: {url}"); for header in &request.headers { req.headers_mut().insert(header.0.clone(), header.1.clone()); - converted_req - .headers_mut() - .insert(header.0.clone(), header.1.clone()); } let res = self .0 - .execute(req) + .execute(req.try_clone().expect("streaming requests not supported")) .instrument(info_span!("revalidation_request", url = url.as_str())) .await .map_err(ErrorKind::RequestMiddlewareError)? .error_for_status() .map_err(ErrorKind::RequestError)?; - let mut converted_res = http::Response::new(()); - *converted_res.status_mut() = res.status(); - for header in res.headers() { - converted_res.headers_mut().insert( - http::HeaderName::from(header.0), - http::HeaderValue::from(header.1), - ); - } - let after_response = - cached - .cache_policy - .after_response(&converted_req, &converted_res, SystemTime::now()); + let after_response = cached.cache_policy.after_response( + &RequestLikeReqwest(&req), + &ResponseLikeReqwest(&res), + SystemTime::now(), + ); match after_response { AfterResponse::NotModified(new_policy, _parts) => { debug!("Found not-modified response for: {url}"); @@ -328,29 +302,16 @@ impl CachedClient { } #[instrument(skip_all, fields(url = req.url().as_str()))] - async fn fresh_request( - &self, - req: Request, - converted_req: http::Request, - ) -> Result, Error> { + async fn fresh_request(&self, req: Request) -> Result, Error> { trace!("{} {}", req.method(), req.url()); let res = self .0 - .execute(req) + .execute(req.try_clone().expect("streaming requests not supported")) .await .map_err(ErrorKind::RequestMiddlewareError)? .error_for_status() .map_err(ErrorKind::RequestError)?; - let mut converted_res = http::Response::new(()); - *converted_res.status_mut() = res.status(); - for header in res.headers() { - converted_res.headers_mut().insert( - http::HeaderName::from(header.0), - http::HeaderValue::from(header.1), - ); - } - let cache_policy = - CachePolicy::new(&converted_req.into_parts().0, &converted_res.into_parts().0); + let cache_policy = CachePolicy::new(&RequestLikeReqwest(&req), &ResponseLikeReqwest(&res)); Ok(CachedResponse::ModifiedOrNew( res, cache_policy.is_storable().then(|| Box::new(cache_policy)), @@ -375,3 +336,45 @@ impl From for CacheControl { } } } + +#[derive(Debug)] +struct RequestLikeReqwest<'a>(&'a Request); + +impl<'a> http_cache_semantics::RequestLike for RequestLikeReqwest<'a> { + fn uri(&self) -> http::uri::Uri { + // This converts from a url::Url (as returned by reqwest::Request::url) + // to a http::uri::Uri. The conversion requires parsing, but this is + // only called ~once per HTTP request. We can afford it. + self.0 + .url() + .as_str() + .parse() + .expect("reqwest::Request::url always returns a valid URL") + } + fn is_same_uri(&self, other: &http::uri::Uri) -> bool { + // At time of writing, I saw no way to cheaply compare a http::uri::Uri + // with a url::Url. We can at least avoid parsing anything, and + // Url::as_str() is free. In practice though, this routine is called + // ~once per HTTP request. We can afford it. (And it looks like + // http::uri::Uri's PartialEq implementation has been tuned.) + self.0.url().as_str() == *other + } + fn method(&self) -> &http::method::Method { + self.0.method() + } + fn headers(&self) -> &http::header::HeaderMap { + self.0.headers() + } +} + +#[derive(Debug)] +struct ResponseLikeReqwest<'a>(&'a Response); + +impl<'a> http_cache_semantics::ResponseLike for ResponseLikeReqwest<'a> { + fn status(&self) -> http::status::StatusCode { + self.0.status() + } + fn headers(&self) -> &http::header::HeaderMap { + self.0.headers() + } +}