Add PyTorch v2.7.0 to GPU backend (#13072)

## Summary

PyTorch v2.7.0 is the first PyTorch release to support CUDA 12.8.
This commit is contained in:
Charlie Marsh 2025-04-23 16:59:41 -04:00 committed by GitHub
parent 473d7c75a4
commit 4bef9fadbb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 24 additions and 2 deletions

View file

@ -59,6 +59,8 @@ pub enum TorchMode {
Auto,
/// Use the CPU-only PyTorch index.
Cpu,
/// Use the PyTorch index for CUDA 12.8.
Cu128,
/// Use the PyTorch index for CUDA 12.6.
Cu126,
/// Use the PyTorch index for CUDA 12.5.
@ -131,6 +133,7 @@ impl TorchStrategy {
}
}
TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)),
TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)),
TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)),
TorchMode::Cu125 => Ok(Self::Backend(TorchBackend::Cu125)),
TorchMode::Cu124 => Ok(Self::Backend(TorchBackend::Cu124)),
@ -230,6 +233,7 @@ impl TorchStrategy {
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TorchBackend {
Cpu,
Cu128,
Cu126,
Cu125,
Cu124,
@ -260,6 +264,7 @@ impl TorchBackend {
fn index_url(&self) -> &'static IndexUrl {
match self {
Self::Cpu => &CPU_INDEX_URL,
Self::Cu128 => &CU128_INDEX_URL,
Self::Cu126 => &CU126_INDEX_URL,
Self::Cu125 => &CU125_INDEX_URL,
Self::Cu124 => &CU124_INDEX_URL,
@ -290,10 +295,11 @@ impl TorchBackend {
/// Linux CUDA driver versions and the corresponding CUDA versions.
///
/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(|| {
static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
[
// Table 2 from
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
(TorchBackend::Cu128, Version::new([525, 60, 13])),
(TorchBackend::Cu126, Version::new([525, 60, 13])),
(TorchBackend::Cu125, Version::new([525, 60, 13])),
(TorchBackend::Cu124, Version::new([525, 60, 13])),
@ -327,10 +333,11 @@ static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(||
/// Windows CUDA driver versions and the corresponding CUDA versions.
///
/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(|| {
static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
[
// Table 2 from
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
(TorchBackend::Cu128, Version::new([528, 33])),
(TorchBackend::Cu126, Version::new([528, 33])),
(TorchBackend::Cu125, Version::new([528, 33])),
(TorchBackend::Cu124, Version::new([528, 33])),
@ -363,6 +370,8 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock
/// Index URL for the CPU-only PyTorch wheel index, parsed lazily on first use.
static CPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
    // The literal is a known-valid URL, so parsing cannot fail at runtime.
    IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap()
});
/// Index URL for the PyTorch wheel index targeting CUDA 12.8, parsed lazily on first use.
static CU128_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
    // The literal is a known-valid URL, so parsing cannot fail at runtime.
    IndexUrl::from_str("https://download.pytorch.org/whl/cu128").unwrap()
});
/// Index URL for the PyTorch wheel index targeting CUDA 12.6, parsed lazily on first use.
static CU126_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
    // The literal is a known-valid URL, so parsing cannot fail at runtime.
    IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap()
});
static CU125_INDEX_URL: LazyLock<IndexUrl> =