diff --git a/crates/uv-torch/src/backend.rs b/crates/uv-torch/src/backend.rs index a60ff0cad..9db252c77 100644 --- a/crates/uv-torch/src/backend.rs +++ b/crates/uv-torch/src/backend.rs @@ -59,6 +59,8 @@ pub enum TorchMode { Auto, /// Use the CPU-only PyTorch index. Cpu, + /// Use the PyTorch index for CUDA 12.8. + Cu128, /// Use the PyTorch index for CUDA 12.6. Cu126, /// Use the PyTorch index for CUDA 12.5. @@ -131,6 +133,7 @@ impl TorchStrategy { } } TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)), + TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)), TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)), TorchMode::Cu125 => Ok(Self::Backend(TorchBackend::Cu125)), TorchMode::Cu124 => Ok(Self::Backend(TorchBackend::Cu124)), @@ -230,6 +233,7 @@ impl TorchStrategy { #[derive(Debug, Clone, Eq, PartialEq)] pub enum TorchBackend { Cpu, + Cu128, Cu126, Cu125, Cu124, @@ -260,6 +264,7 @@ impl TorchBackend { fn index_url(&self) -> &'static IndexUrl { match self { Self::Cpu => &CPU_INDEX_URL, + Self::Cu128 => &CU128_INDEX_URL, Self::Cu126 => &CU126_INDEX_URL, Self::Cu125 => &CU125_INDEX_URL, Self::Cu124 => &CU124_INDEX_URL, @@ -290,10 +295,11 @@ impl TorchBackend { /// Linux CUDA driver versions and the corresponding CUDA versions. /// /// See: -static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(|| { +static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| { [ // Table 2 from // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html + (TorchBackend::Cu128, Version::new([525, 60, 13])), (TorchBackend::Cu126, Version::new([525, 60, 13])), (TorchBackend::Cu125, Version::new([525, 60, 13])), (TorchBackend::Cu124, Version::new([525, 60, 13])), @@ -327,10 +333,11 @@ static LINUX_DRIVERS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(|| /// Windows CUDA driver versions and the corresponding CUDA versions. /// /// See: -static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock::new(|| { +static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| { [ // Table 2 from // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html + (TorchBackend::Cu128, Version::new([528, 33])), (TorchBackend::Cu126, Version::new([528, 33])), (TorchBackend::Cu125, Version::new([528, 33])), (TorchBackend::Cu124, Version::new([528, 33])), @@ -363,6 +370,8 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 23]> = LazyLock static CPU_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap()); +static CU128_INDEX_URL: LazyLock = + LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu128").unwrap()); static CU126_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap()); static CU125_INDEX_URL: LazyLock = diff --git a/docs/reference/cli.md b/docs/reference/cli.md index 8851ca2c0..f8399bcd2 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -6073,6 +6073,8 @@ uv pip compile [OPTIONS] >
  • cpu: Use the CPU-only PyTorch index
  • +
  • cu128: Use the PyTorch index for CUDA 12.8
  • +
  • cu126: Use the PyTorch index for CUDA 12.6
  • cu125: Use the PyTorch index for CUDA 12.5
  • @@ -6548,6 +6550,8 @@ uv pip sync [OPTIONS] ...
  • cpu: Use the CPU-only PyTorch index
  • +
  • cu128: Use the PyTorch index for CUDA 12.8
  • +
  • cu126: Use the PyTorch index for CUDA 12.6
  • cu125: Use the PyTorch index for CUDA 12.5
  • @@ -7095,6 +7099,8 @@ uv pip install [OPTIONS] |--editable cpu: Use the CPU-only PyTorch index +
  • cu128: Use the PyTorch index for CUDA 12.8
  • +
  • cu126: Use the PyTorch index for CUDA 12.6
  • cu125: Use the PyTorch index for CUDA 12.5
  • diff --git a/uv.schema.json b/uv.schema.json index edaf3c9c4..aab9c9e45 100644 --- a/uv.schema.json +++ b/uv.schema.json @@ -2321,6 +2321,13 @@ "cpu" ] }, + { + "description": "Use the PyTorch index for CUDA 12.8.", + "type": "string", + "enum": [ + "cu128" + ] + }, { "description": "Use the PyTorch index for CUDA 12.6.", "type": "string",