Add auto-detection for AMD GPUs (#14176)

## Summary

Allows `--torch-backend=auto` to detect AMD GPUs. The approach is fairly
well-documented inline, but I opted for `rocm_agent_enumerator` over
(e.g.) `rocminfo` since it seems to be the recommended approach for
scripting:
https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html.
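
As a rough illustration of the parsing step, here is a minimal, self-contained sketch of how enumerator-style output could be reduced to a single architecture. The `Arch` enum is a cut-down stand-in for `AmdGpuArchitecture`, `pick_architecture` is a hypothetical helper name, and the sample input is made up rather than captured from a real machine.

```rust
use std::str::FromStr;

/// Reduced stand-in for `AmdGpuArchitecture` (hypothetical, for illustration).
/// The derived `Ord` follows declaration order, so earlier variants compare lower.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
enum Arch {
    Gfx906,
    Gfx90a,
    Gfx942,
    Gfx1100,
}

impl FromStr for Arch {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "gfx906" => Ok(Self::Gfx906),
            "gfx90a" => Ok(Self::Gfx90a),
            "gfx942" => Ok(Self::Gfx942),
            "gfx1100" => Ok(Self::Gfx1100),
            other => Err(format!("Unknown AMD GPU architecture: {other}")),
        }
    }
}

/// Trim each line, skip anything that does not parse, and keep the
/// lowest-ordered architecture that remains.
fn pick_architecture(stdout: &str) -> Option<Arch> {
    stdout
        .lines()
        .map(str::trim)
        .filter_map(|line| line.parse::<Arch>().ok())
        .min()
}

fn main() {
    // Made-up sample: one unrecognized entry and two known architectures.
    let sample = "gfx000\ngfx942\n  gfx90a  \n";
    println!("{:?}", pick_architecture(sample));
}
```

Unrecognized lines are skipped rather than treated as errors, so an unknown agent in the output does not abort detection. Taking the minimum means that on a machine with mixed GPUs, the architecture declared earliest in the enum is the one chosen.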

Closes https://github.com/astral-sh/uv/issues/14086.

## Test Plan

```
root@rocm-jupyter-gpu-mi300x1-192gb-devcloud-atl1:~# ./uv-linux-libc-11fb582c5c046bae09766ceddd276dcc5bb41218/uv pip install torch --torch-backend=auto
Resolved 11 packages in 251ms
Prepared 2 packages in 6ms
Installed 11 packages in 257ms
 + filelock==3.18.0
 + fsspec==2025.5.1
 + jinja2==3.1.6
 + markupsafe==3.0.2
 + mpmath==1.3.0
 + networkx==3.5
 + pytorch-triton-rocm==3.3.1
 + setuptools==80.9.0
 + sympy==1.14.0
 + torch==2.7.1+rocm6.3
 + typing-extensions==4.14.0
```
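
To show why a ROCm 6.3 build is preferred here, the following sketch mimics the lookup under simplifying assumptions. The `Backend` and `Arch` enums, the `TABLE` constant, and the `candidates` function are hypothetical names, and the table is a small subset of the rows in `LINUX_AMD_GPU_DRIVERS`. Backends that list the detected architecture are returned in table order, with the CPU backend appended as a fallback.

```rust
/// Reduced stand-ins for the real types (hypothetical, for illustration).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Arch {
    Gfx942,
    Gfx1101,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Backend {
    Rocm63,
    Rocm62,
    Rocm61,
    Cpu,
}

/// A subset of the (backend, architecture) pairs from this change.
/// Newer ROCm versions come first, so they are tried first.
const TABLE: [(Backend, Arch); 5] = [
    (Backend::Rocm63, Arch::Gfx942),
    (Backend::Rocm63, Arch::Gfx1101),
    (Backend::Rocm62, Arch::Gfx942),
    (Backend::Rocm61, Arch::Gfx942),
    (Backend::Rocm61, Arch::Gfx1101),
];

/// Collect every backend that lists `arch`, then fall back to the CPU backend.
fn candidates(arch: Arch) -> Vec<Backend> {
    TABLE
        .iter()
        .filter_map(|(backend, supported)| {
            if *supported == arch {
                Some(*backend)
            } else {
                None
            }
        })
        .chain(std::iter::once(Backend::Cpu))
        .collect()
}

fn main() {
    println!("{:?}", candidates(Arch::Gfx942));
    println!("{:?}", candidates(Arch::Gfx1101));
}
```

In this subset, `Gfx1101` has no ROCm 6.2 row, so that backend is skipped for it. That mirrors the full table in the diff, where support varies by ROCm version.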

---------

Co-authored-by: Zanie Blue <contact@zanie.dev>
Charlie Marsh 2025-06-21 11:21:06 -04:00 committed by GitHub
parent f0407e4b6f
commit a82c210cab
5 changed files with 257 additions and 40 deletions


@ -1814,7 +1814,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
.options
.torch_backend
.as_ref()
.filter(|torch_backend| matches!(torch_backend, TorchStrategy::Auto { .. }))
.filter(|torch_backend| matches!(torch_backend, TorchStrategy::Cuda { .. }))
.and_then(|_| pins.get(name, version).and_then(ResolvedDist::index))
.map(IndexUrl::url)
.and_then(SystemDependency::from_index)


@ -718,10 +718,14 @@ impl EnvVars {
/// This is a quasi-standard variable, described, e.g., in `ncurses(3x)`.
pub const COLUMNS: &'static str = "COLUMNS";
/// The CUDA driver version to assume when inferring the PyTorch backend.
/// The CUDA driver version to assume when inferring the PyTorch backend (e.g., `550.144.03`).
#[attr_hidden]
pub const UV_CUDA_DRIVER_VERSION: &'static str = "UV_CUDA_DRIVER_VERSION";
/// The AMD GPU architecture to assume when inferring the PyTorch backend (e.g., `gfx1100`).
#[attr_hidden]
pub const UV_AMD_GPU_ARCHITECTURE: &'static str = "UV_AMD_GPU_ARCHITECTURE";
/// Equivalent to the `--torch-backend` command-line argument (e.g., `cpu`, `cu126`, or `auto`).
pub const UV_TORCH_BACKEND: &'static str = "UV_TORCH_BACKEND";


@ -13,17 +13,30 @@ pub enum AcceleratorError {
Version(#[from] uv_pep440::VersionParseError),
#[error(transparent)]
Utf8(#[from] std::string::FromUtf8Error),
#[error("Unknown AMD GPU architecture: {0}")]
UnknownAmdGpuArchitecture(String),
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Accelerator {
/// The CUDA driver version (e.g., `550.144.03`).
///
/// This is in contrast to the CUDA toolkit version (e.g., `12.8.0`).
Cuda { driver_version: Version },
/// The AMD GPU architecture (e.g., `gfx906`).
///
/// This is in contrast to the user-space ROCm version (e.g., `6.4.0-47`) or the kernel-mode
/// driver version (e.g., `6.12.12`).
Amd {
gpu_architecture: AmdGpuArchitecture,
},
}
impl std::fmt::Display for Accelerator {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
Self::Amd { gpu_architecture } => write!(f, "AMD {gpu_architecture}"),
}
}
}
@ -33,9 +46,11 @@ impl Accelerator {
///
/// Query, in order:
/// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
/// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
/// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
/// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
/// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
/// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
pub fn detect() -> Result<Option<Self>, AcceleratorError> {
// Read from `UV_CUDA_DRIVER_VERSION`.
if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
@ -44,6 +59,15 @@ impl Accelerator {
return Ok(Some(Self::Cuda { driver_version }));
}
// Read from `UV_AMD_GPU_ARCHITECTURE`.
if let Ok(gpu_architecture) = std::env::var(EnvVars::UV_AMD_GPU_ARCHITECTURE) {
let gpu_architecture = AmdGpuArchitecture::from_str(&gpu_architecture)?;
debug!(
"Detected AMD GPU architecture from `UV_AMD_GPU_ARCHITECTURE`: {gpu_architecture}"
);
return Ok(Some(Self::Amd { gpu_architecture }));
}
// Read from `/sys/module/nvidia/version`.
match fs_err::read_to_string("/sys/module/nvidia/version") {
Ok(content) => {
@ -100,7 +124,34 @@ impl Accelerator {
);
}
debug!("Failed to detect CUDA driver version");
// Query `rocm_agent_enumerator` to detect the AMD GPU architecture.
//
// See: https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html
if let Ok(output) = std::process::Command::new("rocm_agent_enumerator").output() {
if output.status.success() {
let stdout = String::from_utf8(output.stdout)?;
if let Some(gpu_architecture) = stdout
.lines()
.map(str::trim)
.filter_map(|line| AmdGpuArchitecture::from_str(line).ok())
.min()
{
debug!(
"Detected AMD GPU architecture from `rocm_agent_enumerator`: {gpu_architecture}"
);
return Ok(Some(Self::Amd { gpu_architecture }));
}
} else {
debug!(
"Failed to query AMD GPU architecture with `rocm_agent_enumerator` with status `{}`: {}",
output.status,
String::from_utf8_lossy(&output.stderr)
);
}
}
debug!("Failed to detect GPU driver version");
Ok(None)
}
}
@ -129,6 +180,63 @@ fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, Ac
Ok(Some(driver_version))
}
/// A GPU architecture for AMD GPUs.
///
/// See: <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html>
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum AmdGpuArchitecture {
Gfx900,
Gfx906,
Gfx908,
Gfx90a,
Gfx942,
Gfx1030,
Gfx1100,
Gfx1101,
Gfx1102,
Gfx1200,
Gfx1201,
}
impl FromStr for AmdGpuArchitecture {
type Err = AcceleratorError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gfx900" => Ok(Self::Gfx900),
"gfx906" => Ok(Self::Gfx906),
"gfx908" => Ok(Self::Gfx908),
"gfx90a" => Ok(Self::Gfx90a),
"gfx942" => Ok(Self::Gfx942),
"gfx1030" => Ok(Self::Gfx1030),
"gfx1100" => Ok(Self::Gfx1100),
"gfx1101" => Ok(Self::Gfx1101),
"gfx1102" => Ok(Self::Gfx1102),
"gfx1200" => Ok(Self::Gfx1200),
"gfx1201" => Ok(Self::Gfx1201),
_ => Err(AcceleratorError::UnknownAmdGpuArchitecture(s.to_string())),
}
}
}
impl std::fmt::Display for AmdGpuArchitecture {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Gfx900 => write!(f, "gfx900"),
Self::Gfx906 => write!(f, "gfx906"),
Self::Gfx908 => write!(f, "gfx908"),
Self::Gfx90a => write!(f, "gfx90a"),
Self::Gfx942 => write!(f, "gfx942"),
Self::Gfx1030 => write!(f, "gfx1030"),
Self::Gfx1100 => write!(f, "gfx1100"),
Self::Gfx1101 => write!(f, "gfx1101"),
Self::Gfx1102 => write!(f, "gfx1102"),
Self::Gfx1200 => write!(f, "gfx1200"),
Self::Gfx1201 => write!(f, "gfx1201"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;


@ -47,7 +47,7 @@ use uv_normalize::PackageName;
use uv_pep440::Version;
use uv_platform_tags::Os;
use crate::{Accelerator, AcceleratorError};
use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
/// The strategy to use when determining the appropriate PyTorch index.
#[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
@ -178,8 +178,13 @@ pub enum TorchMode {
/// The strategy to use when determining the appropriate PyTorch index.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TorchStrategy {
/// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
Auto { os: Os, driver_version: Version },
/// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`).
Cuda { os: Os, driver_version: Version },
/// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`).
Amd {
os: Os,
gpu_architecture: AmdGpuArchitecture,
},
/// Use the specified PyTorch index.
Backend(TorchBackend),
}
@ -188,16 +193,17 @@ impl TorchStrategy {
/// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
pub fn from_mode(mode: TorchMode, os: &Os) -> Result<Self, AcceleratorError> {
match mode {
TorchMode::Auto => {
if let Some(Accelerator::Cuda { driver_version }) = Accelerator::detect()? {
Ok(Self::Auto {
os: os.clone(),
driver_version: driver_version.clone(),
})
} else {
Ok(Self::Backend(TorchBackend::Cpu))
}
}
TorchMode::Auto => match Accelerator::detect()? {
Some(Accelerator::Cuda { driver_version }) => Ok(Self::Cuda {
os: os.clone(),
driver_version: driver_version.clone(),
}),
Some(Accelerator::Amd { gpu_architecture }) => Ok(Self::Amd {
os: os.clone(),
gpu_architecture,
}),
None => Ok(Self::Backend(TorchBackend::Cpu)),
},
TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)),
TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)),
TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)),
@ -267,25 +273,27 @@ impl TorchStrategy {
/// Return the appropriate index URLs for the given [`TorchStrategy`].
pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
match self {
TorchStrategy::Auto { os, driver_version } => {
TorchStrategy::Cuda { os, driver_version } => {
// If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
// indexes.
//
// See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
match os {
Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Left(
LINUX_DRIVERS
.iter()
.filter_map(move |(backend, version)| {
if driver_version >= version {
Some(backend.index_url())
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
)),
Os::Windows => Either::Left(Either::Right(
Os::Manylinux { .. } | Os::Musllinux { .. } => {
Either::Left(Either::Left(Either::Left(
LINUX_CUDA_DRIVERS
.iter()
.filter_map(move |(backend, version)| {
if driver_version >= version {
Some(backend.index_url())
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
)))
}
Os::Windows => Either::Left(Either::Left(Either::Right(
WINDOWS_CUDA_VERSIONS
.iter()
.filter_map(move |(backend, version)| {
@ -296,7 +304,7 @@ impl TorchStrategy {
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
)),
))),
Os::Macos { .. }
| Os::FreeBsd { .. }
| Os::NetBsd { .. }
@ -306,11 +314,42 @@ impl TorchStrategy {
| Os::Haiku { .. }
| Os::Android { .. }
| Os::Pyodide { .. } => {
Either::Right(std::iter::once(TorchBackend::Cpu.index_url()))
Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
}
}
}
TorchStrategy::Backend(backend) => Either::Right(std::iter::once(backend.index_url())),
TorchStrategy::Amd {
os,
gpu_architecture,
} => match os {
Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
LINUX_AMD_GPU_DRIVERS
.iter()
.filter_map(move |(backend, architecture)| {
if gpu_architecture == architecture {
Some(backend.index_url())
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
)),
Os::Windows
| Os::Macos { .. }
| Os::FreeBsd { .. }
| Os::NetBsd { .. }
| Os::OpenBsd { .. }
| Os::Dragonfly { .. }
| Os::Illumos { .. }
| Os::Haiku { .. }
| Os::Android { .. }
| Os::Pyodide { .. } => {
Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
}
},
TorchStrategy::Backend(backend) => {
Either::Right(Either::Right(std::iter::once(backend.index_url())))
}
}
}
}
@ -578,7 +617,7 @@ impl FromStr for TorchBackend {
/// Linux CUDA driver versions and the corresponding CUDA versions.
///
/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
[
// Table 2 from
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
@ -651,6 +690,73 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock
]
});
/// Linux AMD GPU architectures and the corresponding PyTorch backends.
///
/// These were inferred by running the following snippet for each ROCm version:
///
/// ```python
/// import torch
///
/// print(torch.cuda.get_arch_list())
/// ```
///
/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
/// however, this list includes a broader array of GPUs than those in the matrix.
static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> =
LazyLock::new(|| {
[
// ROCm 6.3
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx942),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1101),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1102),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1200),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx1201),
// ROCm 6.2.4
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx942),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1101),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1102),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1200),
(TorchBackend::Rocm624, AmdGpuArchitecture::Gfx1201),
// ROCm 6.2
(TorchBackend::Rocm62, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm62, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm62, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm62, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm62, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm62, AmdGpuArchitecture::Gfx942),
// ROCm 6.1
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx942),
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm61, AmdGpuArchitecture::Gfx1101),
// ROCm 6.0
(TorchBackend::Rocm60, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm60, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm60, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm60, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm60, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm60, AmdGpuArchitecture::Gfx942),
]
});
static CPU_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
static CU128_INDEX_URL: LazyLock<IndexUrl> =


@ -444,10 +444,10 @@ $ # With an environment variable.
$ UV_TORCH_BACKEND=auto uv pip install torch
```
When enabled, uv will query for the installed CUDA driver version and use the most-compatible
PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, etc.). If no such CUDA driver
is found, uv will fall back to the CPU-only index. uv will continue to respect existing index
configuration for any packages outside the PyTorch ecosystem.
When enabled, uv will query for the installed CUDA driver version and AMD GPU architecture, then use
the most-compatible PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, etc.). If
no such GPU is found, uv will fall back to the CPU-only index. uv will continue to respect existing
index configuration for any packages outside the PyTorch ecosystem.
You can also select a specific backend (e.g., CUDA 12.6) with `--torch-backend=cu126` (or
`UV_TORCH_BACKEND=cu126`):
@ -460,5 +460,4 @@ $ # With an environment variable.
$ UV_TORCH_BACKEND=cu126 uv pip install torch torchvision
```
At present, `--torch-backend` is only available in the `uv pip` interface, and only supports
detection of CUDA drivers (as opposed to other accelerators like ROCm or Intel GPUs).
At present, `--torch-backend` is only available in the `uv pip` interface.
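
The fallback behavior described for `--torch-backend=auto` can be sketched as a single match over the detection result. This is a simplified illustration, not the code in this change: `Detected`, `Strategy`, and `strategy_for_auto` are hypothetical names, and versions are plain strings rather than the parsed types used in the diff.

```rust
/// What detection found, if anything (hypothetical stand-in for `Accelerator`).
#[derive(Debug, Clone, Eq, PartialEq)]
enum Detected {
    Cuda { driver_version: String },
    Amd { gpu_architecture: String },
}

/// The resulting index-selection strategy (hypothetical stand-in for `TorchStrategy`).
#[derive(Debug, Clone, Eq, PartialEq)]
enum Strategy {
    Cuda { driver_version: String },
    Amd { gpu_architecture: String },
    Cpu,
}

/// Map a detection result to a strategy, falling back to CPU when nothing was found.
fn strategy_for_auto(detected: Option<Detected>) -> Strategy {
    match detected {
        Some(Detected::Cuda { driver_version }) => Strategy::Cuda { driver_version },
        Some(Detected::Amd { gpu_architecture }) => Strategy::Amd { gpu_architecture },
        None => Strategy::Cpu,
    }
}

fn main() {
    let amd = Detected::Amd {
        gpu_architecture: "gfx1100".to_string(),
    };
    println!("{:?}", strategy_for_auto(Some(amd)));
    println!("{:?}", strategy_for_auto(None));
}
```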