Check all compatible torch indexes when --torch-backend is enabled (#12385)

## Summary

It's possible that the PyTorch version the user depends on isn't in the
latest index. These indexes are equally trusted, so we should override
the policy.

Closes #12357.
This commit is contained in:
Charlie Marsh 2025-03-22 08:53:23 -07:00 committed by GitHub
parent 59c6d34b59
commit 4215d0e16b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 33 deletions

View file

@ -157,12 +157,9 @@ impl TorchStrategy {
}
}
/// Return the appropriate index URLs for the given [`TorchStrategy`] and [`PackageName`].
pub fn index_urls(
&self,
package_name: &PackageName,
) -> Option<impl Iterator<Item = &IndexUrl>> {
if !matches!(
/// Returns `true` if the [`TorchStrategy`] applies to the given [`PackageName`].
pub fn applies_to(&self, package_name: &PackageName) -> bool {
matches!(
package_name.as_str(),
"torch"
| "torch-model-archiver"
@ -176,10 +173,11 @@ impl TorchStrategy {
| "torchtext"
| "torchvision"
| "pytorch-triton"
) {
return None;
}
)
}
/// Return the appropriate index URLs for the given [`TorchStrategy`].
pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
match self {
TorchStrategy::Auto { os, driver_version } => {
// If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
@ -187,21 +185,19 @@ impl TorchStrategy {
//
// See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
match os {
Os::Manylinux { .. } | Os::Musllinux { .. } => {
Some(Either::Left(Either::Left(
LINUX_DRIVERS
.iter()
.filter_map(move |(backend, version)| {
if driver_version >= version {
Some(backend.index_url())
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
)))
}
Os::Windows => Some(Either::Left(Either::Right(
Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Left(
LINUX_DRIVERS
.iter()
.filter_map(move |(backend, version)| {
if driver_version >= version {
Some(backend.index_url())
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
)),
Os::Windows => Either::Left(Either::Right(
WINDOWS_CUDA_VERSIONS
.iter()
.filter_map(move |(backend, version)| {
@ -212,7 +208,7 @@ impl TorchStrategy {
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
))),
)),
Os::Macos { .. }
| Os::FreeBsd { .. }
| Os::NetBsd { .. }
@ -220,14 +216,12 @@ impl TorchStrategy {
| Os::Dragonfly { .. }
| Os::Illumos { .. }
| Os::Haiku { .. }
| Os::Android { .. } => Some(Either::Right(std::iter::once(
TorchBackend::Cpu.index_url(),
))),
| Os::Android { .. } => {
Either::Right(std::iter::once(TorchBackend::Cpu.index_url()))
}
}
}
TorchStrategy::Backend(backend) => {
Some(Either::Right(std::iter::once(backend.index_url())))
}
TorchStrategy::Backend(backend) => Either::Right(std::iter::once(backend.index_url())),
}
}
}