Implement basic request caching for compilation server (#1253)

* Implement basic request caching for compilation server

* Fix formatting
This commit is contained in:
Dennis Kobert 2023-05-28 00:52:10 +02:00 committed by Keavon Chambers
parent 6289d92e02
commit 9da83d3280
3 changed files with 27 additions and 8 deletions

View file

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc, sync::RwLock};
use gpu_compiler_bin_wrapper::CompileRequest;
use tower_http::cors::CorsLayer;
@ -12,12 +12,14 @@ use axum::{
struct AppState {
compile_dir: tempfile::TempDir,
cache: RwLock<HashMap<CompileRequest, Result<Vec<u8>, StatusCode>>>,
}
#[tokio::main]
async fn main() {
let shared_state = Arc::new(AppState {
compile_dir: tempfile::tempdir().expect("failed to create tempdir"),
cache: Default::default(),
});
// build our application with a single route
@ -33,9 +35,15 @@ async fn main() {
}
async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
if let Some(result) = state.cache.read().unwrap().get(&compile_request) {
return result.clone();
}
let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml";
compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
let result = compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
eprintln!("compilation failed: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})
});
state.cache.write().unwrap().insert(compile_request, result.clone());
result
}