mirror of
https://github.com/astral-sh/uv.git
synced 2025-07-07 13:25:00 +00:00
Add auto-detection for AMD GPUs (#14176)
Some checks failed
CI / Determine changes (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / cargo shear (push) Has been cancelled
CI / mkdocs (push) Has been cancelled
CI / typos (push) Has been cancelled
CI / cargo clippy | ubuntu (push) Has been cancelled
CI / cargo clippy | windows (push) Has been cancelled
CI / cargo dev generate-all (push) Has been cancelled
CI / cargo test | ubuntu (push) Has been cancelled
CI / cargo test | macos (push) Has been cancelled
CI / cargo test | windows (push) Has been cancelled
CI / check windows trampoline | aarch64 (push) Has been cancelled
CI / build binary | windows aarch64 (push) Has been cancelled
CI / check windows trampoline | i686 (push) Has been cancelled
CI / check windows trampoline | x86_64 (push) Has been cancelled
CI / test windows trampoline | i686 (push) Has been cancelled
CI / test windows trampoline | x86_64 (push) Has been cancelled
CI / build binary | linux libc (push) Has been cancelled
CI / build binary | linux musl (push) Has been cancelled
CI / build binary | macos aarch64 (push) Has been cancelled
CI / build binary | macos x86_64 (push) Has been cancelled
CI / build binary | windows x86_64 (push) Has been cancelled
CI / cargo build (msrv) (push) Has been cancelled
CI / build binary | freebsd (push) Has been cancelled
CI / ecosystem test | pydantic/pydantic-core (push) Has been cancelled
CI / ecosystem test | prefecthq/prefect (push) Has been cancelled
CI / ecosystem test | pallets/flask (push) Has been cancelled
CI / smoke test | linux (push) Has been cancelled
CI / check system | alpine (push) Has been cancelled
CI / smoke test | macos (push) Has been cancelled
CI / smoke test | windows x86_64 (push) Has been cancelled
CI / smoke test | windows aarch64 (push) Has been cancelled
CI / integration test | conda on ubuntu (push) Has been cancelled
CI / integration test | deadsnakes python3.9 on ubuntu (push) Has been cancelled
CI / integration test | free-threaded on windows (push) Has been cancelled
CI / integration test | pypy on ubuntu (push) Has been cancelled
CI / integration test | pypy on windows (push) Has been cancelled
CI / integration test | graalpy on ubuntu (push) Has been cancelled
CI / integration test | graalpy on windows (push) Has been cancelled
CI / integration test | pyodide on ubuntu (push) Has been cancelled
CI / integration test | github actions (push) Has been cancelled
CI / integration test | free-threaded python on github actions (push) Has been cancelled
CI / integration test | determine publish changes (push) Has been cancelled
CI / integration test | registries (push) Has been cancelled
CI / integration test | uv publish (push) Has been cancelled
CI / integration test | uv_build (push) Has been cancelled
CI / check cache | ubuntu (push) Has been cancelled
CI / check cache | macos aarch64 (push) Has been cancelled
CI / check system | python on debian (push) Has been cancelled
CI / check system | python on fedora (push) Has been cancelled
CI / check system | python on ubuntu (push) Has been cancelled
CI / check system | python on rocky linux 8 (push) Has been cancelled
CI / check system | python on rocky linux 9 (push) Has been cancelled
CI / check system | graalpy on ubuntu (push) Has been cancelled
CI / check system | pypy on ubuntu (push) Has been cancelled
CI / check system | pyston (push) Has been cancelled
CI / check system | python on macos aarch64 (push) Has been cancelled
CI / check system | homebrew python on macos aarch64 (push) Has been cancelled
CI / check system | python on macos x86-64 (push) Has been cancelled
CI / check system | python3.10 on windows x86-64 (push) Has been cancelled
CI / check system | python3.10 on windows x86 (push) Has been cancelled
CI / check system | python3.13 on windows x86-64 (push) Has been cancelled
CI / check system | x86-64 python3.13 on windows aarch64 (push) Has been cancelled
CI / check system | windows registry (push) Has been cancelled
CI / check system | python3.12 via chocolatey (push) Has been cancelled
CI / check system | python3.9 via pyenv (push) Has been cancelled
CI / check system | python3.13 (push) Has been cancelled
CI / check system | conda3.11 on macos aarch64 (push) Has been cancelled
CI / check system | conda3.8 on macos aarch64 (push) Has been cancelled
CI / check system | conda3.11 on linux x86-64 (push) Has been cancelled
CI / check system | conda3.8 on linux x86-64 (push) Has been cancelled
CI / check system | conda3.11 on windows x86-64 (push) Has been cancelled
CI / check system | conda3.8 on windows x86-64 (push) Has been cancelled
CI / check system | amazonlinux (push) Has been cancelled
CI / check system | embedded python3.10 on windows x86-64 (push) Has been cancelled
CI / benchmarks | walltime aarch64 linux (push) Has been cancelled
CI / benchmarks | instrumented (push) Has been cancelled
## Summary

Allows `--torch-backend=auto` to detect AMD GPUs. The approach is fairly well-documented inline, but I opted for `rocm_agent_enumerator` over (e.g.) `rocminfo`, since it seems to be the recommended approach for scripting:
https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html

Closes https://github.com/astral-sh/uv/issues/14086.

## Test Plan

```
root@rocm-jupyter-gpu-mi300x1-192gb-devcloud-atl1:~# ./uv-linux-libc-11fb582c5c046bae09766ceddd276dcc5bb41218/uv pip install torch --torch-backend=auto
Resolved 11 packages in 251ms
Prepared 2 packages in 6ms
Installed 11 packages in 257ms
 + filelock==3.18.0
 + fsspec==2025.5.1
 + jinja2==3.1.6
 + markupsafe==3.0.2
 + mpmath==1.3.0
 + networkx==3.5
 + pytorch-triton-rocm==3.3.1
 + setuptools==80.9.0
 + sympy==1.14.0
 + torch==2.7.1+rocm6.3
 + typing-extensions==4.14.0
```

---------

Co-authored-by: Zanie Blue <contact@zanie.dev>
This commit is contained in:
parent
f0407e4b6f
commit
a82c210cab
5 changed files with 257 additions and 40 deletions
```diff
@@ -1814,7 +1814,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
                 .options
                 .torch_backend
                 .as_ref()
-                .filter(|torch_backend| matches!(torch_backend, TorchStrategy::Auto { .. }))
+                .filter(|torch_backend| matches!(torch_backend, TorchStrategy::Cuda { .. }))
                 .and_then(|_| pins.get(name, version).and_then(ResolvedDist::index))
                 .map(IndexUrl::url)
                 .and_then(SystemDependency::from_index)
```
```diff
@@ -718,10 +718,14 @@ impl EnvVars {
     /// This is a quasi-standard variable, described, e.g., in `ncurses(3x)`.
     pub const COLUMNS: &'static str = "COLUMNS";
 
-    /// The CUDA driver version to assume when inferring the PyTorch backend.
+    /// The CUDA driver version to assume when inferring the PyTorch backend (e.g., `550.144.03`).
     #[attr_hidden]
     pub const UV_CUDA_DRIVER_VERSION: &'static str = "UV_CUDA_DRIVER_VERSION";
 
+    /// The AMD GPU architecture to assume when inferring the PyTorch backend (e.g., `gfx1100`).
+    #[attr_hidden]
+    pub const UV_AMD_GPU_ARCHITECTURE: &'static str = "UV_AMD_GPU_ARCHITECTURE";
+
     /// Equivalent to the `--torch-backend` command-line argument (e.g., `cpu`, `cu126`, or `auto`).
     pub const UV_TORCH_BACKEND: &'static str = "UV_TORCH_BACKEND";
```
```diff
@@ -13,17 +13,30 @@ pub enum AcceleratorError {
     Version(#[from] uv_pep440::VersionParseError),
     #[error(transparent)]
     Utf8(#[from] std::string::FromUtf8Error),
+    #[error("Unknown AMD GPU architecture: {0}")]
+    UnknownAmdGpuArchitecture(String),
 }
 
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum Accelerator {
+    /// The CUDA driver version (e.g., `550.144.03`).
+    ///
+    /// This is in contrast to the CUDA toolkit version (e.g., `12.8.0`).
     Cuda { driver_version: Version },
+    /// The AMD GPU architecture (e.g., `gfx906`).
+    ///
+    /// This is in contrast to the user-space ROCm version (e.g., `6.4.0-47`) or the kernel-mode
+    /// driver version (e.g., `6.12.12`).
+    Amd {
+        gpu_architecture: AmdGpuArchitecture,
+    },
 }
 
 impl std::fmt::Display for Accelerator {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
             Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
+            Self::Amd { gpu_architecture } => write!(f, "AMD {gpu_architecture}"),
         }
     }
 }
```
```diff
@@ -33,9 +46,11 @@ impl Accelerator {
     ///
     /// Query, in order:
     /// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
+    /// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
     /// 2. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
     /// 3. `/proc/driver/nvidia/version`, which contains the driver version among other information.
     /// 4. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
+    /// 5. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
     pub fn detect() -> Result<Option<Self>, AcceleratorError> {
         // Read from `UV_CUDA_DRIVER_VERSION`.
         if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
```
```diff
@@ -44,6 +59,15 @@ impl Accelerator {
             return Ok(Some(Self::Cuda { driver_version }));
         }
 
+        // Read from `UV_AMD_GPU_ARCHITECTURE`.
+        if let Ok(gpu_architecture) = std::env::var(EnvVars::UV_AMD_GPU_ARCHITECTURE) {
+            let gpu_architecture = AmdGpuArchitecture::from_str(&gpu_architecture)?;
+            debug!(
+                "Detected AMD GPU architecture from `UV_AMD_GPU_ARCHITECTURE`: {gpu_architecture}"
+            );
+            return Ok(Some(Self::Amd { gpu_architecture }));
+        }
+
         // Read from `/sys/module/nvidia/version`.
         match fs_err::read_to_string("/sys/module/nvidia/version") {
             Ok(content) => {
```
```diff
@@ -100,7 +124,34 @@ impl Accelerator {
             );
         }
 
-        debug!("Failed to detect CUDA driver version");
+        // Query `rocm_agent_enumerator` to detect the AMD GPU architecture.
+        //
+        // See: https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html
+        if let Ok(output) = std::process::Command::new("rocm_agent_enumerator").output() {
+            if output.status.success() {
+                let stdout = String::from_utf8(output.stdout)?;
+                if let Some(gpu_architecture) = stdout
+                    .lines()
+                    .map(str::trim)
+                    .filter_map(|line| AmdGpuArchitecture::from_str(line).ok())
+                    .min()
+                {
+                    debug!(
+                        "Detected AMD GPU architecture from `rocm_agent_enumerator`: {gpu_architecture}"
+                    );
+                    return Ok(Some(Self::Amd { gpu_architecture }));
+                }
+            } else {
+                debug!(
+                    "Failed to query AMD GPU architecture with `rocm_agent_enumerator` with status `{}`: {}",
+                    output.status,
+                    String::from_utf8_lossy(&output.stderr)
+                );
+            }
+        }
+
+        debug!("Failed to detect GPU driver version");
 
         Ok(None)
     }
 }
```
```diff
@@ -129,6 +180,63 @@ fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, Ac
     Ok(Some(driver_version))
 }
 
+/// A GPU architecture for AMD GPUs.
+///
+/// See: <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html>
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
+pub enum AmdGpuArchitecture {
+    Gfx900,
+    Gfx906,
+    Gfx908,
+    Gfx90a,
+    Gfx942,
+    Gfx1030,
+    Gfx1100,
+    Gfx1101,
+    Gfx1102,
+    Gfx1200,
+    Gfx1201,
+}
+
+impl FromStr for AmdGpuArchitecture {
+    type Err = AcceleratorError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "gfx900" => Ok(Self::Gfx900),
+            "gfx906" => Ok(Self::Gfx906),
+            "gfx908" => Ok(Self::Gfx908),
+            "gfx90a" => Ok(Self::Gfx90a),
+            "gfx942" => Ok(Self::Gfx942),
+            "gfx1030" => Ok(Self::Gfx1030),
+            "gfx1100" => Ok(Self::Gfx1100),
+            "gfx1101" => Ok(Self::Gfx1101),
+            "gfx1102" => Ok(Self::Gfx1102),
+            "gfx1200" => Ok(Self::Gfx1200),
+            "gfx1201" => Ok(Self::Gfx1201),
+            _ => Err(AcceleratorError::UnknownAmdGpuArchitecture(s.to_string())),
+        }
+    }
+}
+
+impl std::fmt::Display for AmdGpuArchitecture {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Gfx900 => write!(f, "gfx900"),
+            Self::Gfx906 => write!(f, "gfx906"),
+            Self::Gfx908 => write!(f, "gfx908"),
+            Self::Gfx90a => write!(f, "gfx90a"),
+            Self::Gfx942 => write!(f, "gfx942"),
+            Self::Gfx1030 => write!(f, "gfx1030"),
+            Self::Gfx1100 => write!(f, "gfx1100"),
+            Self::Gfx1101 => write!(f, "gfx1101"),
+            Self::Gfx1102 => write!(f, "gfx1102"),
+            Self::Gfx1200 => write!(f, "gfx1200"),
+            Self::Gfx1201 => write!(f, "gfx1201"),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
```
```diff
@@ -47,7 +47,7 @@ use uv_normalize::PackageName;
 use uv_pep440::Version;
 use uv_platform_tags::Os;
 
-use crate::{Accelerator, AcceleratorError};
+use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
 
 /// The strategy to use when determining the appropriate PyTorch index.
 #[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
```
```diff
@@ -178,8 +178,13 @@ pub enum TorchMode {
 /// The strategy to use when determining the appropriate PyTorch index.
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum TorchStrategy {
-    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
-    Auto { os: Os, driver_version: Version },
+    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`).
+    Cuda { os: Os, driver_version: Version },
+    /// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`).
+    Amd {
+        os: Os,
+        gpu_architecture: AmdGpuArchitecture,
+    },
     /// Use the specified PyTorch index.
     Backend(TorchBackend),
 }
```
```diff
@@ -188,16 +193,17 @@ impl TorchStrategy {
     /// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
     pub fn from_mode(mode: TorchMode, os: &Os) -> Result<Self, AcceleratorError> {
         match mode {
-            TorchMode::Auto => {
-                if let Some(Accelerator::Cuda { driver_version }) = Accelerator::detect()? {
-                    Ok(Self::Auto {
-                        os: os.clone(),
-                        driver_version: driver_version.clone(),
-                    })
-                } else {
-                    Ok(Self::Backend(TorchBackend::Cpu))
-                }
-            }
+            TorchMode::Auto => match Accelerator::detect()? {
+                Some(Accelerator::Cuda { driver_version }) => Ok(Self::Cuda {
+                    os: os.clone(),
+                    driver_version: driver_version.clone(),
+                }),
+                Some(Accelerator::Amd { gpu_architecture }) => Ok(Self::Amd {
+                    os: os.clone(),
+                    gpu_architecture,
+                }),
+                None => Ok(Self::Backend(TorchBackend::Cpu)),
+            },
             TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)),
             TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)),
             TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)),
```
```diff
@@ -267,25 +273,27 @@ impl TorchStrategy {
     /// Return the appropriate index URLs for the given [`TorchStrategy`].
     pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
         match self {
-            TorchStrategy::Auto { os, driver_version } => {
+            TorchStrategy::Cuda { os, driver_version } => {
                 // If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
                 // indexes.
                 //
                 // See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
                 match os {
-                    Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Left(
-                        LINUX_DRIVERS
-                            .iter()
-                            .filter_map(move |(backend, version)| {
-                                if driver_version >= version {
-                                    Some(backend.index_url())
-                                } else {
-                                    None
-                                }
-                            })
-                            .chain(std::iter::once(TorchBackend::Cpu.index_url())),
-                    )),
-                    Os::Windows => Either::Left(Either::Right(
+                    Os::Manylinux { .. } | Os::Musllinux { .. } => {
+                        Either::Left(Either::Left(Either::Left(
+                            LINUX_CUDA_DRIVERS
+                                .iter()
+                                .filter_map(move |(backend, version)| {
+                                    if driver_version >= version {
+                                        Some(backend.index_url())
+                                    } else {
+                                        None
+                                    }
+                                })
+                                .chain(std::iter::once(TorchBackend::Cpu.index_url())),
+                        )))
+                    }
+                    Os::Windows => Either::Left(Either::Left(Either::Right(
                         WINDOWS_CUDA_VERSIONS
                             .iter()
                             .filter_map(move |(backend, version)| {
```
```diff
@@ -296,7 +304,7 @@ impl TorchStrategy {
                                 }
                             })
                             .chain(std::iter::once(TorchBackend::Cpu.index_url())),
-                    )),
+                    ))),
                     Os::Macos { .. }
                     | Os::FreeBsd { .. }
                     | Os::NetBsd { .. }
```
```diff
@@ -306,11 +314,42 @@ impl TorchStrategy {
                     | Os::Haiku { .. }
                     | Os::Android { .. }
                     | Os::Pyodide { .. } => {
-                        Either::Right(std::iter::once(TorchBackend::Cpu.index_url()))
+                        Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
                     }
                 }
             }
-            TorchStrategy::Backend(backend) => Either::Right(std::iter::once(backend.index_url())),
+            TorchStrategy::Amd {
+                os,
+                gpu_architecture,
+            } => match os {
+                Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
+                    LINUX_AMD_GPU_DRIVERS
+                        .iter()
+                        .filter_map(move |(backend, architecture)| {
+                            if gpu_architecture == architecture {
+                                Some(backend.index_url())
+                            } else {
+                                None
+                            }
+                        })
+                        .chain(std::iter::once(TorchBackend::Cpu.index_url())),
+                )),
+                Os::Windows
+                | Os::Macos { .. }
+                | Os::FreeBsd { .. }
+                | Os::NetBsd { .. }
+                | Os::OpenBsd { .. }
+                | Os::Dragonfly { .. }
+                | Os::Illumos { .. }
+                | Os::Haiku { .. }
+                | Os::Android { .. }
+                | Os::Pyodide { .. } => {
+                    Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
+                }
+            },
+            TorchStrategy::Backend(backend) => {
+                Either::Right(Either::Right(std::iter::once(backend.index_url())))
+            }
         }
     }
 }
```
```diff
@@ -578,7 +617,7 @@ impl FromStr for TorchBackend {
 /// Linux CUDA driver versions and the corresponding CUDA versions.
 ///
 /// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
-static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
+static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
     [
         // Table 2 from
         // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
```
```diff
@@ -651,6 +690,73 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock
     ]
 });
 
+/// Linux AMD GPU architectures and the corresponding PyTorch backends.
+///
+/// These were inferred by running the following snippet for each ROCm version:
+///
+/// ```python
+/// import torch
+///
+/// print(torch.cuda.get_arch_list())
+/// ```
+///
+/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
+/// however, this list includes a broader array of GPUs than those in the matrix.
+static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> =
+    LazyLock::new(|| {
+        [
+            // ROCm 6.3
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1101),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1102),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1200),
+            (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1201),
+            // ROCm 6.2.4
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1101),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1102),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1200),
+            (TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1201),
+            // ROCm 6.2
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm62, AmdGpuArchitecture::Gfx942),
+            // ROCm 6.1
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx942),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1101),
+            // ROCm 6.0
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx900),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx906),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx908),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx90a),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1030),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1100),
+            (TorchBackend::Rocm60, AmdGpuArchitecture::Gfx942),
+        ]
+    });
+
 static CPU_INDEX_URL: LazyLock<IndexUrl> =
     LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
 static CU128_INDEX_URL: LazyLock<IndexUrl> =
```
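The preference order this table encodes — scan newest ROCm first, keep the backends whose support list contains the detected architecture, then chain the CPU index as the final fallback — can be sketched in Python (a simplified model of the Rust table and `index_urls` logic; the `candidate_backends` helper and the backend labels are illustrative, not uv's API):

```python
# Simplified model of the LINUX_AMD_GPU_DRIVERS table above.
# Entries are (backend label, supported gfx architectures), newest ROCm first.
ROCM_SUPPORT = [
    ("rocm6.3", {"gfx900", "gfx906", "gfx908", "gfx90a", "gfx942",
                 "gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"}),
    ("rocm6.2.4", {"gfx900", "gfx906", "gfx908", "gfx90a", "gfx942",
                   "gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"}),
    ("rocm6.2", {"gfx900", "gfx906", "gfx908", "gfx90a",
                 "gfx1030", "gfx1100", "gfx942"}),
    ("rocm6.1", {"gfx900", "gfx906", "gfx908", "gfx90a", "gfx942",
                 "gfx1030", "gfx1100", "gfx1101"}),
    ("rocm6.0", {"gfx900", "gfx906", "gfx908", "gfx90a",
                 "gfx1030", "gfx1100", "gfx942"}),
]

def candidate_backends(gpu_architecture: str) -> list[str]:
    """Return candidate backends for an architecture, best first, mirroring
    the filter-then-chain-CPU pattern in `index_urls`."""
    matches = [backend for backend, arches in ROCM_SUPPORT
               if gpu_architecture in arches]
    return matches + ["cpu"]  # the CPU index is always the final fallback

print(candidate_backends("gfx1101"))  # ['rocm6.3', 'rocm6.2.4', 'rocm6.1', 'cpu']
print(candidate_backends("gfx803"))   # unsupported architecture -> ['cpu']
```

Because the table is iterated in declaration order, an architecture supported by several ROCm releases always resolves to the newest matching index first.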
````diff
@@ -444,10 +444,10 @@ $ # With an environment variable.
 $ UV_TORCH_BACKEND=auto uv pip install torch
 ```
 
-When enabled, uv will query for the installed CUDA driver version and use the most-compatible
-PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, etc.). If no such CUDA driver
-is found, uv will fall back to the CPU-only index. uv will continue to respect existing index
-configuration for any packages outside the PyTorch ecosystem.
+When enabled, uv will query for the installed CUDA driver and AMD GPU versions then use the
+most-compatible PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, etc.). If no
+such GPU is found, uv will fall back to the CPU-only index. uv will continue to respect existing
+index configuration for any packages outside the PyTorch ecosystem.
 
 You can also select a specific backend (e.g., CUDA 12.6) with `--torch-backend=cu126` (or
 `UV_TORCH_BACKEND=cu126`):
````
````diff
@@ -460,5 +460,4 @@ $ # With an environment variable.
 $ UV_TORCH_BACKEND=cu126 uv pip install torch torchvision
 ```
 
-At present, `--torch-backend` is only available in the `uv pip` interface, and only supports
-detection of CUDA drivers (as opposed to other accelerators like ROCm or Intel GPUs).
+At present, `--torch-backend` is only available in the `uv pip` interface.
````