# Add auto-detection for AMD GPUs (#14176)

## Summary

Allows `--torch-backend=auto` to detect AMD GPUs. The approach is fairly
well-documented inline, but I opted for `rocm_agent_enumerator` over
(e.g.) `rocminfo` since it seems to be the recommended approach for
scripting:
https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html.
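
As a rough sketch of what detection via `rocm_agent_enumerator` looks like (hypothetical helper name and simplified types, not the PR's actual `Accelerator` code): the tool prints one LLVM GPU target per line (e.g. `gfx1100`), and `gfx000` denotes the CPU agent, which should be skipped:

```rust
/// Parse the output of `rocm_agent_enumerator`, which prints one LLVM
/// GPU target per line (e.g. `gfx1100`). `gfx000` is the CPU agent and
/// is skipped. (Hypothetical sketch; names differ from the PR.)
fn parse_rocm_agents(output: &str) -> Vec<String> {
    output
        .lines()
        .map(str::trim)
        .filter(|line| line.starts_with("gfx") && *line != "gfx000")
        .map(|line| line.to_string())
        .collect()
}

fn main() {
    // Sample output as it might appear on a single-GPU MI300X machine.
    let sample = "gfx000\ngfx942\n";
    let agents = parse_rocm_agents(sample);
    assert_eq!(agents, vec!["gfx942".to_string()]);
    println!("{agents:?}");
}
```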

Closes https://github.com/astral-sh/uv/issues/14086.

## Test Plan

```console
root@rocm-jupyter-gpu-mi300x1-192gb-devcloud-atl1:~# ./uv-linux-libc-11fb582c5c046bae09766ceddd276dcc5bb41218/uv pip install torch --torch-backend=auto
Resolved 11 packages in 251ms
Prepared 2 packages in 6ms
Installed 11 packages in 257ms
 + filelock==3.18.0
 + fsspec==2025.5.1
 + jinja2==3.1.6
 + markupsafe==3.0.2
 + mpmath==1.3.0
 + networkx==3.5
 + pytorch-triton-rocm==3.3.1
 + setuptools==80.9.0
 + sympy==1.14.0
 + torch==2.7.1+rocm6.3
 + typing-extensions==4.14.0
```

---------

Co-authored-by: Zanie Blue <contact@zanie.dev>
Charlie Marsh, 2025-06-21 11:21:06 -04:00, committed by GitHub (commit a82c210cab, parent f0407e4b6f).

5 changed files with 257 additions and 40 deletions

```diff
@@ -47,7 +47,7 @@ use uv_normalize::PackageName;
 use uv_pep440::Version;
 use uv_platform_tags::Os;
 
-use crate::{Accelerator, AcceleratorError};
+use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
 
 /// The strategy to use when determining the appropriate PyTorch index.
 #[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
@@ -178,8 +178,13 @@ pub enum TorchMode {
 /// The strategy to use when determining the appropriate PyTorch index.
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum TorchStrategy {
-    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
-    Auto { os: Os, driver_version: Version },
+    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`).
+    Cuda { os: Os, driver_version: Version },
+    /// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`).
+    Amd {
+        os: Os,
+        gpu_architecture: AmdGpuArchitecture,
+    },
     /// Use the specified PyTorch index.
     Backend(TorchBackend),
 }
@@ -188,16 +193,17 @@ impl TorchStrategy {
     /// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
     pub fn from_mode(mode: TorchMode, os: &Os) -> Result<Self, AcceleratorError> {
         match mode {
-            TorchMode::Auto => {
-                if let Some(Accelerator::Cuda { driver_version }) = Accelerator::detect()? {
-                    Ok(Self::Auto {
-                        os: os.clone(),
-                        driver_version: driver_version.clone(),
-                    })
-                } else {
-                    Ok(Self::Backend(TorchBackend::Cpu))
-                }
-            }
+            TorchMode::Auto => match Accelerator::detect()? {
+                Some(Accelerator::Cuda { driver_version }) => Ok(Self::Cuda {
+                    os: os.clone(),
+                    driver_version: driver_version.clone(),
+                }),
+                Some(Accelerator::Amd { gpu_architecture }) => Ok(Self::Amd {
+                    os: os.clone(),
+                    gpu_architecture,
+                }),
+                None => Ok(Self::Backend(TorchBackend::Cpu)),
+            },
             TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)),
             TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)),
             TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)),
@@ -267,25 +273,27 @@ impl TorchStrategy {
     /// Return the appropriate index URLs for the given [`TorchStrategy`].
     pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
         match self {
-            TorchStrategy::Auto { os, driver_version } => {
+            TorchStrategy::Cuda { os, driver_version } => {
                 // If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
                 // indexes.
                 //
                 // See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
                 match os {
-                    Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Left(
-                        LINUX_DRIVERS
-                            .iter()
-                            .filter_map(move |(backend, version)| {
-                                if driver_version >= version {
-                                    Some(backend.index_url())
-                                } else {
-                                    None
-                                }
-                            })
-                            .chain(std::iter::once(TorchBackend::Cpu.index_url())),
-                    )),
-                    Os::Windows => Either::Left(Either::Right(
+                    Os::Manylinux { .. } | Os::Musllinux { .. } => {
+                        Either::Left(Either::Left(Either::Left(
+                            LINUX_CUDA_DRIVERS
+                                .iter()
+                                .filter_map(move |(backend, version)| {
+                                    if driver_version >= version {
+                                        Some(backend.index_url())
+                                    } else {
+                                        None
+                                    }
+                                })
+                                .chain(std::iter::once(TorchBackend::Cpu.index_url())),
+                        )))
+                    }
+                    Os::Windows => Either::Left(Either::Left(Either::Right(
                         WINDOWS_CUDA_VERSIONS
                             .iter()
                             .filter_map(move |(backend, version)| {
@@ -296,7 +304,7 @@ impl TorchStrategy {
                                 }
                             })
                             .chain(std::iter::once(TorchBackend::Cpu.index_url())),
-                    )),
+                    ))),
                     Os::Macos { .. }
                     | Os::FreeBsd { .. }
                     | Os::NetBsd { .. }
@@ -306,11 +314,42 @@ impl TorchStrategy {
                     | Os::Haiku { .. }
                     | Os::Android { .. }
                     | Os::Pyodide { .. } => {
-                        Either::Right(std::iter::once(TorchBackend::Cpu.index_url()))
+                        Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
                     }
                 }
             }
-            TorchStrategy::Backend(backend) => Either::Right(std::iter::once(backend.index_url())),
+            TorchStrategy::Amd {
+                os,
+                gpu_architecture,
+            } => match os {
+                Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
+                    LINUX_AMD_GPU_DRIVERS
+                        .iter()
+                        .filter_map(move |(backend, architecture)| {
+                            if gpu_architecture == architecture {
+                                Some(backend.index_url())
+                            } else {
+                                None
+                            }
+                        })
+                        .chain(std::iter::once(TorchBackend::Cpu.index_url())),
+                )),
+                Os::Windows
+                | Os::Macos { .. }
+                | Os::FreeBsd { .. }
+                | Os::NetBsd { .. }
+                | Os::OpenBsd { .. }
+                | Os::Dragonfly { .. }
+                | Os::Illumos { .. }
+                | Os::Haiku { .. }
+                | Os::Android { .. }
+                | Os::Pyodide { .. } => {
+                    Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
+                }
+            },
+            TorchStrategy::Backend(backend) => {
+                Either::Right(Either::Right(std::iter::once(backend.index_url())))
+            }
         }
     }
 }
```
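
The fallback behavior in `index_urls` is worth noting: for an AMD GPU, every ROCm index whose table entry matches the detected architecture is yielded (newest ROCm first, given the table's ordering), with the generic CPU index always appended as a last resort. A standalone sketch of that selection, with plain strings standing in for `TorchBackend`/`IndexUrl` (a hypothetical simplification, not the PR's types):

```rust
/// Yield the candidate index URLs for an AMD GPU: every table entry
/// matching the detected architecture, then the CPU index as a fallback.
/// (Sketch only; the real code iterates `LINUX_AMD_GPU_DRIVERS`.)
fn amd_index_urls<'a>(
    // (ROCm index URL, gfx architecture), ordered newest ROCm first.
    table: &'a [(&'a str, &'a str)],
    gpu_architecture: &'a str,
) -> impl Iterator<Item = &'a str> {
    table
        .iter()
        .filter_map(move |(url, arch)| (*arch == gpu_architecture).then_some(*url))
        .chain(std::iter::once("https://download.pytorch.org/whl/cpu"))
}

fn main() {
    let table = [
        ("https://download.pytorch.org/whl/rocm6.3", "gfx1100"),
        ("https://download.pytorch.org/whl/rocm6.3", "gfx942"),
        ("https://download.pytorch.org/whl/rocm6.2.4", "gfx942"),
    ];
    let urls: Vec<_> = amd_index_urls(&table, "gfx942").collect();
    assert_eq!(
        urls,
        [
            "https://download.pytorch.org/whl/rocm6.3",
            "https://download.pytorch.org/whl/rocm6.2.4",
            "https://download.pytorch.org/whl/cpu",
        ]
    );
    println!("{urls:?}");
}
```

Because the table is ordered newest-to-oldest, the most recent compatible ROCm index is tried first, and an unrecognized architecture degrades gracefully to the CPU wheels.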
```diff
@@ -578,7 +617,7 @@ impl FromStr for TorchBackend {
 /// Linux CUDA driver versions and the corresponding CUDA versions.
 ///
 /// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
-static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
+static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
     [
         // Table 2 from
         // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
@@ -651,6 +690,73 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock
     ]
 });
 
+/// Linux AMD GPU architectures and the corresponding PyTorch backends.
+///
+/// These were inferred by running the following snippet for each ROCm version:
+///
+/// ```python
+/// import torch
+///
+/// print(torch.cuda.get_arch_list())
+/// ```
+///
+/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
+/// however, this list includes a broader array of GPUs than those in the matrix.
+static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> =
+    LazyLock::new(|| {
+        [
+            // ROCm 6.3
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1101),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1102),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1200),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1201),
+            // ROCm 6.2.4
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1101),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1102),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1200),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1201),
+            // ROCm 6.2
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx942),
+            // ROCm 6.1
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1101),
+            // ROCm 6.0
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx942),
+        ]
+    });
+
 static CPU_INDEX_URL: LazyLock<IndexUrl> =
     LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
 static CU128_INDEX_URL: LazyLock<IndexUrl> =
```