diff --git a/crates/uv-resolver/src/resolver/mod.rs b/crates/uv-resolver/src/resolver/mod.rs
index 1384ce4f7..ed1cd48af 100644
--- a/crates/uv-resolver/src/resolver/mod.rs
+++ b/crates/uv-resolver/src/resolver/mod.rs
@@ -1814,7 +1814,7 @@ impl ResolverState
diff --git a/crates/uv-torch/src/accelerator.rs b/crates/uv-torch/src/accelerator.rs
--- a/crates/uv-torch/src/accelerator.rs
+++ b/crates/uv-torch/src/accelerator.rs
@@ ... @@ impl std::fmt::Display for Accelerator
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
+            Self::Amd { gpu_architecture } => write!(f, "AMD {gpu_architecture}"),
         }
     }
 }
@@ -33,9 +46,11 @@ impl Accelerator {
     ///
     /// Query, in order:
     /// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
-    /// 2. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
-    /// 3. `/proc/driver/nvidia/version`, which contains the driver version among other information.
-    /// 4. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
+    /// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
+    /// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
+    /// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
+    /// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
+    /// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
     pub fn detect() -> Result<Option<Self>, AcceleratorError> {
         // Read from `UV_CUDA_DRIVER_VERSION`.
         if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
@@ -44,6 +59,15 @@ impl Accelerator {
             return Ok(Some(Self::Cuda { driver_version }));
         }
 
+        // Read from `UV_AMD_GPU_ARCHITECTURE`.
+        if let Ok(gpu_architecture) = std::env::var(EnvVars::UV_AMD_GPU_ARCHITECTURE) {
+            let gpu_architecture = AmdGpuArchitecture::from_str(&gpu_architecture)?;
+            debug!(
+                "Detected AMD GPU architecture from `UV_AMD_GPU_ARCHITECTURE`: {gpu_architecture}"
+            );
+            return Ok(Some(Self::Amd { gpu_architecture }));
+        }
+
         // Read from `/sys/module/nvidia/version`.
         match fs_err::read_to_string("/sys/module/nvidia/version") {
             Ok(content) => {
@@ -100,7 +124,34 @@ impl Accelerator {
             );
         }
 
-        debug!("Failed to detect CUDA driver version");
+        // Query `rocm_agent_enumerator` to detect the AMD GPU architecture.
+        //
+        // See: https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html
+        if let Ok(output) = std::process::Command::new("rocm_agent_enumerator").output() {
+            if output.status.success() {
+                let stdout = String::from_utf8(output.stdout)?;
+                if let Some(gpu_architecture) = stdout
+                    .lines()
+                    .map(str::trim)
+                    .filter_map(|line| AmdGpuArchitecture::from_str(line).ok())
+                    .min()
+                {
+                    debug!(
+                        "Detected AMD GPU architecture from `rocm_agent_enumerator`: {gpu_architecture}"
+                    );
+                    return Ok(Some(Self::Amd { gpu_architecture }));
+                }
+            } else {
+                debug!(
+                    "Failed to query AMD GPU architecture with `rocm_agent_enumerator` with status `{}`: {}",
+                    output.status,
+                    String::from_utf8_lossy(&output.stderr)
+                );
+            }
+        }
+
+        debug!("Failed to detect GPU driver version");
+
         Ok(None)
     }
 }
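The probe order above matters: the ROCm query runs only after every CUDA probe has missed, so NVIDIA detection keeps precedence. As a standalone illustration of that final step, here is a minimal sketch (a hypothetical `main`, std-only, not part of the patch) of shelling out to `rocm_agent_enumerator` and tolerating its absence:

```rust
use std::process::Command;

fn main() {
    // Mirror the detection above: invoke `rocm_agent_enumerator`, which
    // prints one agent per line (typically a `gfx000` CPU entry followed
    // by GPU architectures such as `gfx1100`).
    match Command::new("rocm_agent_enumerator").output() {
        Ok(output) if output.status.success() => {
            for agent in String::from_utf8_lossy(&output.stdout).lines() {
                println!("agent: {}", agent.trim());
            }
        }
        // A non-zero exit is logged and treated as "no AMD GPU".
        Ok(output) => eprintln!("rocm_agent_enumerator failed: {}", output.status),
        // A missing binary (no ROCm install) is likewise non-fatal.
        Err(err) => eprintln!("rocm_agent_enumerator unavailable: {err}"),
    }
}
```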
@@ -129,6 +180,63 @@ fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, AcceleratorError>
     Ok(Some(driver_version))
 }
 
+/// A GPU architecture for AMD GPUs.
+///
+/// See: <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html>
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
+pub enum AmdGpuArchitecture {
+    Gfx900,
+    Gfx906,
+    Gfx908,
+    Gfx90a,
+    Gfx942,
+    Gfx1030,
+    Gfx1100,
+    Gfx1101,
+    Gfx1102,
+    Gfx1200,
+    Gfx1201,
+}
+
+impl FromStr for AmdGpuArchitecture {
+    type Err = AcceleratorError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "gfx900" => Ok(Self::Gfx900),
+            "gfx906" => Ok(Self::Gfx906),
+            "gfx908" => Ok(Self::Gfx908),
+            "gfx90a" => Ok(Self::Gfx90a),
+            "gfx942" => Ok(Self::Gfx942),
+            "gfx1030" => Ok(Self::Gfx1030),
+            "gfx1100" => Ok(Self::Gfx1100),
+            "gfx1101" => Ok(Self::Gfx1101),
+            "gfx1102" => Ok(Self::Gfx1102),
+            "gfx1200" => Ok(Self::Gfx1200),
+            "gfx1201" => Ok(Self::Gfx1201),
+            _ => Err(AcceleratorError::UnknownAmdGpuArchitecture(s.to_string())),
+        }
+    }
+}
+
+impl std::fmt::Display for AmdGpuArchitecture {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Gfx900 => write!(f, "gfx900"),
+            Self::Gfx906 => write!(f, "gfx906"),
+            Self::Gfx908 => write!(f, "gfx908"),
+            Self::Gfx90a => write!(f, "gfx90a"),
+            Self::Gfx942 => write!(f, "gfx942"),
+            Self::Gfx1030 => write!(f, "gfx1030"),
+            Self::Gfx1100 => write!(f, "gfx1100"),
+            Self::Gfx1101 => write!(f, "gfx1101"),
+            Self::Gfx1102 => write!(f, "gfx1102"),
+            Self::Gfx1200 => write!(f, "gfx1200"),
+            Self::Gfx1201 => write!(f, "gfx1201"),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
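A note on the `.min()` call in `detect()`: because `AmdGpuArchitecture` derives `Ord`, the derived ordering follows declaration order, so on a machine reporting several architectures the earliest-declared (oldest) one wins, and unrecognized agents are silently skipped. A self-contained sketch of that selection, using a pared-down stand-in enum rather than uv's full type:

```rust
use std::str::FromStr;

// Stand-in for `AmdGpuArchitecture`: derived `Ord` follows declaration
// order, so earlier variants compare as smaller.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
enum Arch {
    Gfx900,
    Gfx1030,
    Gfx1100,
}

impl FromStr for Arch {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "gfx900" => Ok(Self::Gfx900),
            "gfx1030" => Ok(Self::Gfx1030),
            "gfx1100" => Ok(Self::Gfx1100),
            // Unknown agents (e.g., the `gfx000` CPU entry) are skipped.
            _ => Err(()),
        }
    }
}

fn main() {
    // Simulated `rocm_agent_enumerator` output for a mixed-GPU machine.
    let stdout = "gfx000\ngfx1100\ngfx900\n";
    let min = stdout
        .lines()
        .map(str::trim)
        .filter_map(|line| Arch::from_str(line).ok())
        .min();
    assert_eq!(min, Some(Arch::Gfx900));
}
```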
diff --git a/crates/uv-torch/src/backend.rs b/crates/uv-torch/src/backend.rs
index 60d43f3d7..0f2b72077 100644
--- a/crates/uv-torch/src/backend.rs
+++ b/crates/uv-torch/src/backend.rs
@@ -47,7 +47,7 @@ use uv_normalize::PackageName;
 use uv_pep440::Version;
 use uv_platform_tags::Os;
 
-use crate::{Accelerator, AcceleratorError};
+use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
 
 /// The strategy to use when determining the appropriate PyTorch index.
 #[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
@@ -178,8 +178,13 @@ pub enum TorchMode {
 /// The strategy to use when determining the appropriate PyTorch index.
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum TorchStrategy {
-    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
-    Auto { os: Os, driver_version: Version },
+    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`).
+    Cuda { os: Os, driver_version: Version },
+    /// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`).
+    Amd {
+        os: Os,
+        gpu_architecture: AmdGpuArchitecture,
+    },
     /// Use the specified PyTorch index.
     Backend(TorchBackend),
 }
@@ -188,16 +193,17 @@ impl TorchStrategy {
     /// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
     pub fn from_mode(mode: TorchMode, os: &Os) -> Result<Self, AcceleratorError> {
         match mode {
-            TorchMode::Auto => {
-                if let Some(Accelerator::Cuda { driver_version }) = Accelerator::detect()? {
-                    Ok(Self::Auto {
-                        os: os.clone(),
-                        driver_version: driver_version.clone(),
-                    })
-                } else {
-                    Ok(Self::Backend(TorchBackend::Cpu))
-                }
-            }
+            TorchMode::Auto => match Accelerator::detect()? {
+                Some(Accelerator::Cuda { driver_version }) => Ok(Self::Cuda {
+                    os: os.clone(),
+                    driver_version: driver_version.clone(),
+                }),
+                Some(Accelerator::Amd { gpu_architecture }) => Ok(Self::Amd {
+                    os: os.clone(),
+                    gpu_architecture,
+                }),
+                None => Ok(Self::Backend(TorchBackend::Cpu)),
+            },
             TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)),
             TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)),
             TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)),
@@ -267,25 +273,27 @@ impl TorchStrategy {
     /// Return the appropriate index URLs for the given [`TorchStrategy`].
     pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
         match self {
-            TorchStrategy::Auto { os, driver_version } => {
+            TorchStrategy::Cuda { os, driver_version } => {
                 // If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
                 // indexes.
                 //
                 // See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
                 match os {
-                    Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Left(
-                        LINUX_DRIVERS
-                            .iter()
-                            .filter_map(move |(backend, version)| {
-                                if driver_version >= version {
-                                    Some(backend.index_url())
-                                } else {
-                                    None
-                                }
-                            })
-                            .chain(std::iter::once(TorchBackend::Cpu.index_url())),
-                    )),
-                    Os::Windows => Either::Left(Either::Right(
+                    Os::Manylinux { .. } | Os::Musllinux { .. } => {
+                        Either::Left(Either::Left(Either::Left(
+                            LINUX_CUDA_DRIVERS
+                                .iter()
+                                .filter_map(move |(backend, version)| {
+                                    if driver_version >= version {
+                                        Some(backend.index_url())
+                                    } else {
+                                        None
+                                    }
+                                })
+                                .chain(std::iter::once(TorchBackend::Cpu.index_url())),
+                        )))
+                    }
+                    Os::Windows => Either::Left(Either::Left(Either::Right(
                         WINDOWS_CUDA_VERSIONS
                             .iter()
                             .filter_map(move |(backend, version)| {
@@ -296,7 +304,7 @@ impl TorchStrategy {
                                 }
                             })
                             .chain(std::iter::once(TorchBackend::Cpu.index_url())),
-                    )),
+                    ))),
                     Os::Macos { .. }
                     | Os::FreeBsd { .. }
                     | Os::NetBsd { .. }
@@ -306,11 +314,42 @@ impl TorchStrategy {
                     | Os::OpenBsd { .. }
                     | Os::Dragonfly { .. }
                     | Os::Illumos { .. }
                     | Os::Haiku { .. }
                     | Os::Android { .. }
                     | Os::Pyodide { .. } => {
-                        Either::Right(std::iter::once(TorchBackend::Cpu.index_url()))
+                        Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
                     }
                 }
             }
-            TorchStrategy::Backend(backend) => Either::Right(std::iter::once(backend.index_url())),
+            TorchStrategy::Amd {
+                os,
+                gpu_architecture,
+            } => match os {
+                Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
+                    LINUX_AMD_GPU_DRIVERS
+                        .iter()
+                        .filter_map(move |(backend, architecture)| {
+                            if gpu_architecture == architecture {
+                                Some(backend.index_url())
+                            } else {
+                                None
+                            }
+                        })
+                        .chain(std::iter::once(TorchBackend::Cpu.index_url())),
+                )),
+                Os::Windows
+                | Os::Macos { .. }
+                | Os::FreeBsd { .. }
+                | Os::NetBsd { .. }
+                | Os::OpenBsd { .. }
+                | Os::Dragonfly { .. }
+                | Os::Illumos { .. }
+                | Os::Haiku { .. }
+                | Os::Android { .. }
+                | Os::Pyodide { .. } => {
+                    Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
+                }
+            },
+            TorchStrategy::Backend(backend) => {
+                Either::Right(Either::Right(std::iter::once(backend.index_url())))
+            }
         }
     }
 }
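Why the extra `Either` wrapping: `index_urls` returns `impl Iterator`, so every `match` arm must produce the same concrete type, and the new AMD arm forces another `Either` layer to unify yet another iterator type. A minimal sketch of the pattern, assuming the `either` crate that the `Either::Left`/`Either::Right` calls above appear to come from:

```rust
use either::Either;

// Three arms, three different iterator types, one concrete return type:
// nested `Either`s implement `Iterator` whenever both sides do.
fn numbers(kind: u8) -> impl Iterator<Item = u32> {
    match kind {
        // A `Range<u32>`.
        0 => Either::Left(Either::Left(0..3)),
        // An array iterator.
        1 => Either::Left(Either::Right([7u32, 8].into_iter())),
        // A chained iterator.
        _ => Either::Right(std::iter::once(42u32).chain(1..2)),
    }
}

fn main() {
    assert_eq!(numbers(0).collect::<Vec<_>>(), [0, 1, 2]);
    assert_eq!(numbers(1).collect::<Vec<_>>(), [7, 8]);
    assert_eq!(numbers(2).collect::<Vec<_>>(), [42, 1]);
}
```

The patch needs the deeper nesting for the same reason: the Linux CUDA, Windows CUDA, Linux AMD, CPU-fallback, and fixed-backend arms each produce a distinct iterator type.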
@@ -578,7 +617,7 @@ impl FromStr for TorchBackend {
 /// Linux CUDA driver versions and the corresponding CUDA versions.
 ///
 /// See: <https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html>
-static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
+static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
     [
         // Table 2 from
         // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
@@ -651,6 +690,73 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock
     ]
 });
 
+/// Linux AMD GPU architectures and the corresponding PyTorch backends.
+///
+/// These were inferred by running the following snippet for each ROCm version:
+///
+/// ```python
+/// import torch
+///
+/// print(torch.cuda.get_arch_list())
+/// ```
+///
+/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html>;
+/// however, this list includes a broader array of GPUs than those in the matrix.
+static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> =
+    LazyLock::new(|| {
+        [
+            // ROCm 6.3
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1101),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1102),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1200),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1201),
+            // ROCm 6.2.4
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1101),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1102),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1200),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1201),
+            // ROCm 6.2
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx942),
+            // ROCm 6.1
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1101),
+            // ROCm 6.0
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx942),
+        ]
+    });
+
 static CPU_INDEX_URL: LazyLock<IndexUrl> =
     LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
 static CU128_INDEX_URL: LazyLock<IndexUrl> =
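For illustration, the `LINUX_AMD_GPU_DRIVERS` lookup differs from the CUDA one: it filters on architecture equality rather than a `>=` version comparison, and because the table lists newer ROCm releases first, matching indexes come out newest-first with the CPU index chained on as the final fallback. A reduced sketch with hypothetical string stand-ins for the backends and architectures:

```rust
fn main() {
    // Stand-in for `LINUX_AMD_GPU_DRIVERS`: (backend, architecture) pairs,
    // newer ROCm releases listed first, as in the table above.
    let table = [
        ("rocm6.3", "gfx1100"),
        ("rocm6.2.4", "gfx1100"),
        ("rocm6.2", "gfx1100"),
        ("rocm6.3", "gfx1201"),
    ];

    let detected = "gfx1100";
    let candidates: Vec<&str> = table
        .iter()
        // Equality match, unlike the `driver_version >= version` CUDA check.
        .filter_map(|(backend, arch)| (*arch == detected).then_some(*backend))
        // The CPU index is always appended as the last resort.
        .chain(std::iter::once("cpu"))
        .collect();

    // Candidates are tried in table order: newest compatible ROCm first.
    assert_eq!(candidates, ["rocm6.3", "rocm6.2.4", "rocm6.2", "cpu"]);
}
```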
diff --git a/docs/guides/integration/pytorch.md b/docs/guides/integration/pytorch.md
index 7a2500ec5..a90ebeb6b 100644
--- a/docs/guides/integration/pytorch.md
+++ b/docs/guides/integration/pytorch.md
@@ -444,10 +444,10 @@ $ # With an environment variable.
 $ UV_TORCH_BACKEND=auto uv pip install torch
 ```
 
-When enabled, uv will query for the installed CUDA driver version and use the most-compatible
-PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, etc.). If no such CUDA driver
-is found, uv will fall back to the CPU-only index. uv will continue to respect existing index
-configuration for any packages outside the PyTorch ecosystem.
+When enabled, uv will query for the installed CUDA driver version or AMD GPU architecture, then use
+the most-compatible PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, etc.). If
+no such GPU is found, uv will fall back to the CPU-only index. uv will continue to respect existing
+index configuration for any packages outside the PyTorch ecosystem.
 
 You can also select a specific backend (e.g., CUDA 12.6) with `--torch-backend=cu126` (or
 `UV_TORCH_BACKEND=cu126`):
@@ -460,5 +460,4 @@ $ # With an environment variable.
 $ UV_TORCH_BACKEND=cu126 uv pip install torch torchvision
 ```
 
-At present, `--torch-backend` is only available in the `uv pip` interface, and only supports
-detection of CUDA drivers (as opposed to other accelerators like ROCm or Intel GPUs).
+At present, `--torch-backend` is only available in the `uv pip` interface.
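One usage note that follows from the accelerator change: the patch reads the new `UV_AMD_GPU_ARCHITECTURE` environment variable before falling back to `rocm_agent_enumerator`, so the AMD path can be pinned explicitly on machines where probing is unavailable. A hypothetical invocation (assuming a `gfx1100` card; the variable name comes from the diff above, the rest mirrors the page's existing examples):

```console
$ UV_TORCH_BACKEND=auto UV_AMD_GPU_ARCHITECTURE=gfx1100 uv pip install torch
```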