mirror of
https://github.com/astral-sh/uv.git
synced 2025-08-04 19:08:04 +00:00
Automatically infer the PyTorch index via --torch-backend=auto
(#12070)
## Summary This is a prototype that I'm considering shipping under `--preview`, based on [`light-the-torch`](https://github.com/pmeier/light-the-torch). `light-the-torch` patches pip to pull PyTorch packages from the PyTorch indexes automatically. And, in particular, `light-the-torch` will query the installed CUDA drivers to determine which indexes are compatible with your system. This PR implements equivalent behavior under `--torch-backend auto`, though you can also set `--torch-backend cpu`, etc. for convenience. When enabled, the registry client will fetch from the appropriate PyTorch index when it sees a package from the PyTorch ecosystem (and ignore any other configured indexes, _unless_ the package is explicitly pinned to a different index). Right now, this is only implemented in the `uv pip` CLI, since it doesn't quite fit into the lockfile APIs given that it relies on feature detection on the currently-running machine. ## Test Plan On macOS, you can test this with (e.g.): ```shell UV_TORCH_BACKEND=auto UV_CUDA_DRIVER_VERSION=450.80.2 cargo run \ pip install torch --python-platform linux --python-version 3.12 ``` On a GPU-enabled EC2 machine: ```shell ubuntu@ip-172-31-47-149:~/uv$ UV_TORCH_BACKEND=auto cargo run pip install torch -v Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.31s Running `target/debug/uv pip install torch -v` DEBUG uv 0.6.6 (e95ca063b 2025-03-14) DEBUG Searching for default Python interpreter in virtual environments DEBUG Found `cpython-3.13.0-linux-x86_64-gnu` at `/home/ubuntu/uv/.venv/bin/python3` (virtual environment) DEBUG Using Python 3.13.0 environment at: .venv DEBUG Acquired lock for `.venv` DEBUG At least one requirement is not satisfied: torch warning: The `--torch-backend` setting is experimental and may change without warning. Pass `--preview` to disable this warning. DEBUG Detected CUDA driver version from `/sys/module/nvidia/version`: 550.144.3 ... ```
This commit is contained in:
parent
e40c551b80
commit
5173b59b50
31 changed files with 1289 additions and 29 deletions
28
crates/uv-torch/Cargo.toml
Normal file
28
crates/uv-torch/Cargo.toml
Normal file
|
@@ -0,0 +1,28 @@
|
|||
# Crate manifest for `uv-torch`: selects the appropriate PyTorch index based on the
# operating system and detected CUDA driver version.
[package]
name = "uv-torch"
version = "0.1.0"
edition.workspace = true
rust-version.workspace = true
homepage.workspace = true
documentation.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true

[dependencies]
# Workspace-internal crates.
uv-distribution-types = { workspace = true }
uv-normalize = { workspace = true }
uv-pep440 = { workspace = true }
uv-platform-tags = { workspace = true }
uv-static = { workspace = true }

# External crates. `clap` and `schemars` are optional: they back the feature-gated
# `ValueEnum`/`JsonSchema` derives on `TorchMode`.
clap = { workspace = true, optional = true }
either = { workspace = true }
fs-err = { workspace = true }
schemars = { workspace = true, optional = true }
serde = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }

[lints]
workspace = true
|
142
crates/uv-torch/src/accelerator.rs
Normal file
142
crates/uv-torch/src/accelerator.rs
Normal file
|
@@ -0,0 +1,142 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
use uv_pep440::Version;
|
||||
use uv_static::EnvVars;
|
||||
|
||||
/// Errors that can occur while detecting an [`Accelerator`].
#[derive(Debug, thiserror::Error)]
pub enum AcceleratorError {
    /// An I/O error while reading driver metadata (e.g., from `/sys` or `/proc`).
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// A detected driver version could not be parsed as a version.
    #[error(transparent)]
    Version(#[from] uv_pep440::VersionParseError),
    /// Subprocess output (e.g., from `nvidia-smi`) was not valid UTF-8.
    #[error(transparent)]
    Utf8(#[from] std::string::FromUtf8Error),
}
|
||||
|
||||
/// A hardware accelerator detected on the current machine.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Accelerator {
    /// An NVIDIA GPU, identified by the installed CUDA driver version.
    Cuda { driver_version: Version },
}
|
||||
|
||||
impl std::fmt::Display for Accelerator {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Accelerator {
|
||||
/// Detect the CUDA driver version from the system.
|
||||
///
|
||||
/// Query, in order:
|
||||
/// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
|
||||
/// 2. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
|
||||
/// 3. `/proc/driver/nvidia/version`, which contains the driver version among other information.
|
||||
/// 4. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
|
||||
pub fn detect() -> Result<Option<Self>, AcceleratorError> {
|
||||
// Read from `UV_CUDA_DRIVER_VERSION`.
|
||||
if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
|
||||
let driver_version = Version::from_str(&driver_version)?;
|
||||
debug!("Detected CUDA driver version from `UV_CUDA_DRIVER_VERSION`: {driver_version}");
|
||||
return Ok(Some(Self::Cuda { driver_version }));
|
||||
}
|
||||
|
||||
// Read from `/sys/module/nvidia/version`.
|
||||
match fs_err::read_to_string("/sys/module/nvidia/version") {
|
||||
Ok(content) => {
|
||||
return match parse_sys_module_nvidia_version(&content) {
|
||||
Ok(driver_version) => {
|
||||
debug!("Detected CUDA driver version from `/sys/module/nvidia/version`: {driver_version}");
|
||||
Ok(Some(Self::Cuda { driver_version }))
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
|
||||
// Read from `/proc/driver/nvidia/version`
|
||||
match fs_err::read_to_string("/proc/driver/nvidia/version") {
|
||||
Ok(content) => {
|
||||
match parse_proc_driver_nvidia_version(&content) {
|
||||
Ok(Some(driver_version)) => {
|
||||
debug!("Detected CUDA driver version from `/proc/driver/nvidia/version`: {driver_version}");
|
||||
return Ok(Some(Self::Cuda { driver_version }));
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!("Failed to parse CUDA driver version from `/proc/driver/nvidia/version`");
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
|
||||
// Query `nvidia-smi`.
|
||||
if let Ok(output) = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=driver_version")
|
||||
.arg("--format=csv,noheader")
|
||||
.output()
|
||||
{
|
||||
if output.status.success() {
|
||||
let driver_version = Version::from_str(&String::from_utf8(output.stdout)?)?;
|
||||
debug!("Detected CUDA driver version from `nvidia-smi`: {driver_version}");
|
||||
return Ok(Some(Self::Cuda { driver_version }));
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Failed to query CUDA driver version with `nvidia-smi` with status `{}`: {}",
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
}
|
||||
|
||||
debug!("Failed to detect CUDA driver version");
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the CUDA driver version from the content of `/sys/module/nvidia/version`.
|
||||
fn parse_sys_module_nvidia_version(content: &str) -> Result<Version, AcceleratorError> {
|
||||
// Parse, e.g.:
|
||||
// ```text
|
||||
// 550.144.03
|
||||
// ```
|
||||
let driver_version = Version::from_str(content.trim())?;
|
||||
Ok(driver_version)
|
||||
}
|
||||
|
||||
/// Parse the CUDA driver version from the content of `/proc/driver/nvidia/version`.
|
||||
fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, AcceleratorError> {
|
||||
// Parse, e.g.:
|
||||
// ```text
|
||||
// NVRM version: NVIDIA UNIX Open Kernel Module for x86_64 550.144.03 Release Build (dvs-builder@U16-I3-D08-1-2) Mon Dec 30 17:26:13 UTC 2024
|
||||
// GCC version: gcc version 12.3.0 (Ubuntu 12.3.0-1ubuntu1~22.04)
|
||||
// ```
|
||||
let Some(version) = content.split(" ").nth(1) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let driver_version = Version::from_str(version.trim())?;
|
||||
Ok(Some(driver_version))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // `/proc/driver/nvidia/version` parsing should extract the driver version from both
    // newer ("Open Kernel Module") and older driver banner formats.
    #[test]
    fn proc_driver_nvidia_version() {
        // Newer open-kernel-module banner (driver 550.x).
        let content = "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64 550.144.03 Release Build (dvs-builder@U16-I3-D08-1-2) Mon Dec 30 17:26:13 UTC 2024\nGCC version: gcc version 12.3.0 (Ubuntu 12.3.0-1ubuntu1~22.04)";
        let result = parse_proc_driver_nvidia_version(content).unwrap();
        assert_eq!(result, Some(Version::from_str("550.144.03").unwrap()));

        // Older proprietary-module banner (driver 375.x), with a different word order.
        let content = "NVRM version: NVIDIA UNIX x86_64 Kernel Module 375.74 Wed Jun 14 01:39:39 PDT 2017\nGCC version: gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.4)";
        let result = parse_proc_driver_nvidia_version(content).unwrap();
        assert_eq!(result, Some(Version::from_str("375.74").unwrap()));
    }
}
|
417
crates/uv-torch/src/backend.rs
Normal file
417
crates/uv-torch/src/backend.rs
Normal file
|
@@ -0,0 +1,417 @@
|
|||
//! `uv-torch` is a library for determining the appropriate PyTorch index based on the operating
|
||||
//! system and CUDA driver version.
|
||||
//!
|
||||
//! This library is derived from `light-the-torch` by Philipp Meier, which is available under the
|
||||
//! following BSD-3 Clause license:
|
||||
//!
|
||||
//! ```text
|
||||
//! BSD 3-Clause License
|
||||
//!
|
||||
//! Copyright (c) 2020, Philip Meier
|
||||
//! All rights reserved.
|
||||
//!
|
||||
//! Redistribution and use in source and binary forms, with or without
|
||||
//! modification, are permitted provided that the following conditions are met:
|
||||
//!
|
||||
//! 1. Redistributions of source code must retain the above copyright notice, this
|
||||
//! list of conditions and the following disclaimer.
|
||||
//!
|
||||
//! 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
//! this list of conditions and the following disclaimer in the documentation
|
||||
//! and/or other materials provided with the distribution.
|
||||
//!
|
||||
//! 3. Neither the name of the copyright holder nor the names of its
|
||||
//! contributors may be used to endorse or promote products derived from
|
||||
//! this software without specific prior written permission.
|
||||
//!
|
||||
//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
//! AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
//! IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
//! DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
//! FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
//! DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
//! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
//! ```
|
||||
//!
|
||||
|
||||
use std::str::FromStr;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use either::Either;
|
||||
|
||||
use uv_distribution_types::IndexUrl;
|
||||
use uv_normalize::PackageName;
|
||||
use uv_pep440::Version;
|
||||
use uv_platform_tags::Os;
|
||||
|
||||
use crate::{Accelerator, AcceleratorError};
|
||||
|
||||
/// The user-selectable mode for choosing a PyTorch index (accepted on the CLI and in
/// configuration, serialized in kebab-case: `auto`, `cpu`, `cu126`, ...).
#[derive(Debug, Copy, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "kebab-case")]
pub enum TorchMode {
    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
    Auto,
    /// Use the CPU-only PyTorch index.
    Cpu,
    /// Use the PyTorch index for CUDA 12.6.
    Cu126,
    /// Use the PyTorch index for CUDA 12.5.
    Cu125,
    /// Use the PyTorch index for CUDA 12.4.
    Cu124,
    /// Use the PyTorch index for CUDA 12.3.
    Cu123,
    /// Use the PyTorch index for CUDA 12.2.
    Cu122,
    /// Use the PyTorch index for CUDA 12.1.
    Cu121,
    /// Use the PyTorch index for CUDA 12.0.
    Cu120,
    /// Use the PyTorch index for CUDA 11.8.
    Cu118,
    /// Use the PyTorch index for CUDA 11.7.
    Cu117,
    /// Use the PyTorch index for CUDA 11.6.
    Cu116,
    /// Use the PyTorch index for CUDA 11.5.
    Cu115,
    /// Use the PyTorch index for CUDA 11.4.
    Cu114,
    /// Use the PyTorch index for CUDA 11.3.
    Cu113,
    /// Use the PyTorch index for CUDA 11.2.
    Cu112,
    /// Use the PyTorch index for CUDA 11.1.
    Cu111,
    /// Use the PyTorch index for CUDA 11.0.
    Cu110,
    /// Use the PyTorch index for CUDA 10.2.
    Cu102,
    /// Use the PyTorch index for CUDA 10.1.
    Cu101,
    /// Use the PyTorch index for CUDA 10.0.
    Cu100,
    /// Use the PyTorch index for CUDA 9.2.
    Cu92,
    /// Use the PyTorch index for CUDA 9.1.
    Cu91,
    /// Use the PyTorch index for CUDA 9.0.
    Cu90,
    /// Use the PyTorch index for CUDA 8.0.
    Cu80,
}
|
||||
|
||||
/// The resolved strategy to use when determining the appropriate PyTorch index
/// (derived from a [`TorchMode`] via [`TorchStrategy::from_mode`]).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TorchStrategy {
    /// Select the appropriate PyTorch index based on the operating system and CUDA driver version.
    Auto { os: Os, driver_version: Version },
    /// Use the specified PyTorch index.
    Backend(TorchBackend),
}
|
||||
|
||||
impl TorchStrategy {
|
||||
/// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
|
||||
pub fn from_mode(mode: TorchMode, os: &Os) -> Result<Self, AcceleratorError> {
|
||||
match mode {
|
||||
TorchMode::Auto => {
|
||||
if let Some(Accelerator::Cuda { driver_version }) = Accelerator::detect()? {
|
||||
Ok(Self::Auto {
|
||||
os: os.clone(),
|
||||
driver_version: driver_version.clone(),
|
||||
})
|
||||
} else {
|
||||
Ok(Self::Backend(TorchBackend::Cpu))
|
||||
}
|
||||
}
|
||||
TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)),
|
||||
TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)),
|
||||
TorchMode::Cu125 => Ok(Self::Backend(TorchBackend::Cu125)),
|
||||
TorchMode::Cu124 => Ok(Self::Backend(TorchBackend::Cu124)),
|
||||
TorchMode::Cu123 => Ok(Self::Backend(TorchBackend::Cu123)),
|
||||
TorchMode::Cu122 => Ok(Self::Backend(TorchBackend::Cu122)),
|
||||
TorchMode::Cu121 => Ok(Self::Backend(TorchBackend::Cu121)),
|
||||
TorchMode::Cu120 => Ok(Self::Backend(TorchBackend::Cu120)),
|
||||
TorchMode::Cu118 => Ok(Self::Backend(TorchBackend::Cu118)),
|
||||
TorchMode::Cu117 => Ok(Self::Backend(TorchBackend::Cu117)),
|
||||
TorchMode::Cu116 => Ok(Self::Backend(TorchBackend::Cu116)),
|
||||
TorchMode::Cu115 => Ok(Self::Backend(TorchBackend::Cu115)),
|
||||
TorchMode::Cu114 => Ok(Self::Backend(TorchBackend::Cu114)),
|
||||
TorchMode::Cu113 => Ok(Self::Backend(TorchBackend::Cu113)),
|
||||
TorchMode::Cu112 => Ok(Self::Backend(TorchBackend::Cu112)),
|
||||
TorchMode::Cu111 => Ok(Self::Backend(TorchBackend::Cu111)),
|
||||
TorchMode::Cu110 => Ok(Self::Backend(TorchBackend::Cu110)),
|
||||
TorchMode::Cu102 => Ok(Self::Backend(TorchBackend::Cu102)),
|
||||
TorchMode::Cu101 => Ok(Self::Backend(TorchBackend::Cu101)),
|
||||
TorchMode::Cu100 => Ok(Self::Backend(TorchBackend::Cu100)),
|
||||
TorchMode::Cu92 => Ok(Self::Backend(TorchBackend::Cu92)),
|
||||
TorchMode::Cu91 => Ok(Self::Backend(TorchBackend::Cu91)),
|
||||
TorchMode::Cu90 => Ok(Self::Backend(TorchBackend::Cu90)),
|
||||
TorchMode::Cu80 => Ok(Self::Backend(TorchBackend::Cu80)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the appropriate index URLs for the given [`TorchStrategy`] and [`PackageName`].
|
||||
pub fn index_urls(
|
||||
&self,
|
||||
package_name: &PackageName,
|
||||
) -> Option<impl Iterator<Item = &IndexUrl>> {
|
||||
if !matches!(
|
||||
package_name.as_str(),
|
||||
"torch"
|
||||
| "torch-model-archiver"
|
||||
| "torch-tb-profiler"
|
||||
| "torcharrow"
|
||||
| "torchaudio"
|
||||
| "torchcsprng"
|
||||
| "torchdata"
|
||||
| "torchdistx"
|
||||
| "torchserve"
|
||||
| "torchtext"
|
||||
| "torchvision"
|
||||
| "pytorch-triton"
|
||||
) {
|
||||
return None;
|
||||
}
|
||||
|
||||
match self {
|
||||
TorchStrategy::Auto { os, driver_version } => {
|
||||
// If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
|
||||
// indexes.
|
||||
//
|
||||
// See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
|
||||
match os {
|
||||
Os::Manylinux { .. } | Os::Musllinux { .. } => {
|
||||
Some(Either::Left(Either::Left(
|
||||
LINUX_DRIVERS
|
||||
.iter()
|
||||
.filter_map(move |(backend, version)| {
|
||||
if driver_version >= version {
|
||||
Some(backend.index_url())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
|
||||
)))
|
||||
}
|
||||
Os::Windows => Some(Either::Left(Either::Right(
|
||||
WINDOWS_CUDA_VERSIONS
|
||||
.iter()
|
||||
.filter_map(move |(backend, version)| {
|
||||
if driver_version >= version {
|
||||
Some(backend.index_url())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
|
||||
))),
|
||||
Os::Macos { .. }
|
||||
| Os::FreeBsd { .. }
|
||||
| Os::NetBsd { .. }
|
||||
| Os::OpenBsd { .. }
|
||||
| Os::Dragonfly { .. }
|
||||
| Os::Illumos { .. }
|
||||
| Os::Haiku { .. }
|
||||
| Os::Android { .. } => Some(Either::Right(std::iter::once(
|
||||
TorchBackend::Cpu.index_url(),
|
||||
))),
|
||||
}
|
||||
}
|
||||
TorchStrategy::Backend(backend) => {
|
||||
Some(Either::Right(std::iter::once(backend.index_url())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The available backends for PyTorch.
///
/// Each variant corresponds to a dedicated wheel index at
/// `https://download.pytorch.org/whl/<backend>` (see [`TorchBackend::index_url`]).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TorchBackend {
    /// The CPU-only backend.
    Cpu,
    /// The CUDA 12.6 backend.
    Cu126,
    /// The CUDA 12.5 backend.
    Cu125,
    /// The CUDA 12.4 backend.
    Cu124,
    /// The CUDA 12.3 backend.
    Cu123,
    /// The CUDA 12.2 backend.
    Cu122,
    /// The CUDA 12.1 backend.
    Cu121,
    /// The CUDA 12.0 backend.
    Cu120,
    /// The CUDA 11.8 backend.
    Cu118,
    /// The CUDA 11.7 backend.
    Cu117,
    /// The CUDA 11.6 backend.
    Cu116,
    /// The CUDA 11.5 backend.
    Cu115,
    /// The CUDA 11.4 backend.
    Cu114,
    /// The CUDA 11.3 backend.
    Cu113,
    /// The CUDA 11.2 backend.
    Cu112,
    /// The CUDA 11.1 backend.
    Cu111,
    /// The CUDA 11.0 backend.
    Cu110,
    /// The CUDA 10.2 backend.
    Cu102,
    /// The CUDA 10.1 backend.
    Cu101,
    /// The CUDA 10.0 backend.
    Cu100,
    /// The CUDA 9.2 backend.
    Cu92,
    /// The CUDA 9.1 backend.
    Cu91,
    /// The CUDA 9.0 backend.
    Cu90,
    /// The CUDA 8.0 backend.
    Cu80,
}
|
||||
|
||||
impl TorchBackend {
    /// Return the appropriate index URL for the given [`TorchBackend`].
    ///
    /// The URLs are lazily-initialized module statics, so repeated lookups are cheap and
    /// return the same `&'static` reference.
    fn index_url(&self) -> &'static IndexUrl {
        match self {
            Self::Cpu => &CPU_INDEX_URL,
            Self::Cu126 => &CU126_INDEX_URL,
            Self::Cu125 => &CU125_INDEX_URL,
            Self::Cu124 => &CU124_INDEX_URL,
            Self::Cu123 => &CU123_INDEX_URL,
            Self::Cu122 => &CU122_INDEX_URL,
            Self::Cu121 => &CU121_INDEX_URL,
            Self::Cu120 => &CU120_INDEX_URL,
            Self::Cu118 => &CU118_INDEX_URL,
            Self::Cu117 => &CU117_INDEX_URL,
            Self::Cu116 => &CU116_INDEX_URL,
            Self::Cu115 => &CU115_INDEX_URL,
            Self::Cu114 => &CU114_INDEX_URL,
            Self::Cu113 => &CU113_INDEX_URL,
            Self::Cu112 => &CU112_INDEX_URL,
            Self::Cu111 => &CU111_INDEX_URL,
            Self::Cu110 => &CU110_INDEX_URL,
            Self::Cu102 => &CU102_INDEX_URL,
            Self::Cu101 => &CU101_INDEX_URL,
            Self::Cu100 => &CU100_INDEX_URL,
            Self::Cu92 => &CU92_INDEX_URL,
            Self::Cu91 => &CU91_INDEX_URL,
            Self::Cu90 => &CU90_INDEX_URL,
            Self::Cu80 => &CU80_INDEX_URL,
        }
    }
}
|
||||
|
||||
/// Linux CUDA driver versions and the corresponding CUDA versions.
///
/// Each entry maps a CUDA backend to the *minimum* Linux driver version that supports it
/// (compared with `driver_version >= version` in `TorchStrategy::index_urls`). Entries are
/// ordered from newest to oldest CUDA release; `index_urls` relies on that ordering to
/// prefer the newest compatible index.
///
/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(|| {
    [
        // Table 2 from
        // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
        (TorchBackend::Cu126, Version::new([525, 60, 13])),
        (TorchBackend::Cu125, Version::new([525, 60, 13])),
        (TorchBackend::Cu124, Version::new([525, 60, 13])),
        (TorchBackend::Cu123, Version::new([525, 60, 13])),
        (TorchBackend::Cu122, Version::new([525, 60, 13])),
        (TorchBackend::Cu121, Version::new([525, 60, 13])),
        (TorchBackend::Cu120, Version::new([525, 60, 13])),
        // Table 2 from
        // https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html
        (TorchBackend::Cu118, Version::new([450, 80, 2])),
        (TorchBackend::Cu117, Version::new([450, 80, 2])),
        (TorchBackend::Cu116, Version::new([450, 80, 2])),
        (TorchBackend::Cu115, Version::new([450, 80, 2])),
        (TorchBackend::Cu114, Version::new([450, 80, 2])),
        (TorchBackend::Cu113, Version::new([450, 80, 2])),
        (TorchBackend::Cu112, Version::new([450, 80, 2])),
        (TorchBackend::Cu111, Version::new([450, 80, 2])),
        (TorchBackend::Cu110, Version::new([450, 36, 6])),
        // Table 1 from
        // https://docs.nvidia.com/cuda/archive/10.2/cuda-toolkit-release-notes/index.html
        (TorchBackend::Cu102, Version::new([440, 33])),
        (TorchBackend::Cu101, Version::new([418, 39])),
        (TorchBackend::Cu100, Version::new([410, 48])),
        (TorchBackend::Cu92, Version::new([396, 26])),
        (TorchBackend::Cu91, Version::new([390, 46])),
        (TorchBackend::Cu90, Version::new([384, 81])),
        (TorchBackend::Cu80, Version::new([375, 26])),
    ]
});
|
||||
|
||||
/// Windows CUDA driver versions and the corresponding CUDA versions.
///
/// Each entry maps a CUDA backend to the *minimum* Windows driver version that supports it
/// (compared with `driver_version >= version` in `TorchStrategy::index_urls`). Entries are
/// ordered from newest to oldest CUDA release, mirroring `LINUX_DRIVERS`.
///
/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(|| {
    [
        // Table 2 from
        // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
        (TorchBackend::Cu126, Version::new([528, 33])),
        (TorchBackend::Cu125, Version::new([528, 33])),
        (TorchBackend::Cu124, Version::new([528, 33])),
        (TorchBackend::Cu123, Version::new([528, 33])),
        (TorchBackend::Cu122, Version::new([528, 33])),
        (TorchBackend::Cu121, Version::new([528, 33])),
        (TorchBackend::Cu120, Version::new([528, 33])),
        // Table 2 from
        // https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html
        (TorchBackend::Cu118, Version::new([452, 39])),
        (TorchBackend::Cu117, Version::new([452, 39])),
        (TorchBackend::Cu116, Version::new([452, 39])),
        (TorchBackend::Cu115, Version::new([452, 39])),
        (TorchBackend::Cu114, Version::new([452, 39])),
        (TorchBackend::Cu113, Version::new([452, 39])),
        (TorchBackend::Cu112, Version::new([452, 39])),
        (TorchBackend::Cu111, Version::new([452, 39])),
        (TorchBackend::Cu110, Version::new([451, 22])),
        // Table 1 from
        // https://docs.nvidia.com/cuda/archive/10.2/cuda-toolkit-release-notes/index.html
        (TorchBackend::Cu102, Version::new([441, 22])),
        (TorchBackend::Cu101, Version::new([418, 96])),
        (TorchBackend::Cu100, Version::new([411, 31])),
        (TorchBackend::Cu92, Version::new([398, 26])),
        (TorchBackend::Cu91, Version::new([391, 29])),
        (TorchBackend::Cu90, Version::new([385, 54])),
        (TorchBackend::Cu80, Version::new([376, 51])),
    ]
});
|
||||
|
||||
static CPU_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
|
||||
static CU126_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap());
|
||||
static CU125_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu125").unwrap());
|
||||
static CU124_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu124").unwrap());
|
||||
static CU123_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu123").unwrap());
|
||||
static CU122_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu122").unwrap());
|
||||
static CU121_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu121").unwrap());
|
||||
static CU120_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu120").unwrap());
|
||||
static CU118_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap());
|
||||
static CU117_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu117").unwrap());
|
||||
static CU116_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu116").unwrap());
|
||||
static CU115_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu115").unwrap());
|
||||
static CU114_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu114").unwrap());
|
||||
static CU113_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu113").unwrap());
|
||||
static CU112_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu112").unwrap());
|
||||
static CU111_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu111").unwrap());
|
||||
static CU110_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu110").unwrap());
|
||||
static CU102_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu102").unwrap());
|
||||
static CU101_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu101").unwrap());
|
||||
static CU100_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu100").unwrap());
|
||||
static CU92_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu92").unwrap());
|
||||
static CU91_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu91").unwrap());
|
||||
static CU90_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
|
||||
static CU80_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
|
5
crates/uv-torch/src/lib.rs
Normal file
5
crates/uv-torch/src/lib.rs
Normal file
|
@@ -0,0 +1,5 @@
|
|||
//! Determine the appropriate PyTorch index for the current system, based on the operating
//! system and detected CUDA driver version.

// Accelerator (CUDA driver) detection.
mod accelerator;
// PyTorch index-selection strategy and backend definitions.
mod backend;

pub use accelerator::*;
pub use backend::*;
|
Loading…
Add table
Add a link
Reference in a new issue