mirror of
https://github.com/astral-sh/uv.git
synced 2025-11-12 00:45:35 +00:00
Check all compatible torch indexes when --torch-backend is enabled (#12385)
## Summary It's possible that the PyTorch version the user depends on isn't in the latest index. These indexes are equally trusted, so we should override the policy. Closes #12357.
This commit is contained in:
parent
59c6d34b59
commit
4215d0e16b
2 changed files with 43 additions and 33 deletions
|
|
@ -157,12 +157,9 @@ impl TorchStrategy {
|
|||
}
|
||||
}
|
||||
|
||||
/// Return the appropriate index URLs for the given [`TorchStrategy`] and [`PackageName`].
|
||||
pub fn index_urls(
|
||||
&self,
|
||||
package_name: &PackageName,
|
||||
) -> Option<impl Iterator<Item = &IndexUrl>> {
|
||||
if !matches!(
|
||||
/// Returns `true` if the [`TorchStrategy`] applies to the given [`PackageName`].
|
||||
pub fn applies_to(&self, package_name: &PackageName) -> bool {
|
||||
matches!(
|
||||
package_name.as_str(),
|
||||
"torch"
|
||||
| "torch-model-archiver"
|
||||
|
|
@ -176,10 +173,11 @@ impl TorchStrategy {
|
|||
| "torchtext"
|
||||
| "torchvision"
|
||||
| "pytorch-triton"
|
||||
) {
|
||||
return None;
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Return the appropriate index URLs for the given [`TorchStrategy`].
|
||||
pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
|
||||
match self {
|
||||
TorchStrategy::Auto { os, driver_version } => {
|
||||
// If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
|
||||
|
|
@ -187,21 +185,19 @@ impl TorchStrategy {
|
|||
//
|
||||
// See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
|
||||
match os {
|
||||
Os::Manylinux { .. } | Os::Musllinux { .. } => {
|
||||
Some(Either::Left(Either::Left(
|
||||
LINUX_DRIVERS
|
||||
.iter()
|
||||
.filter_map(move |(backend, version)| {
|
||||
if driver_version >= version {
|
||||
Some(backend.index_url())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
|
||||
)))
|
||||
}
|
||||
Os::Windows => Some(Either::Left(Either::Right(
|
||||
Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Left(
|
||||
LINUX_DRIVERS
|
||||
.iter()
|
||||
.filter_map(move |(backend, version)| {
|
||||
if driver_version >= version {
|
||||
Some(backend.index_url())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
|
||||
)),
|
||||
Os::Windows => Either::Left(Either::Right(
|
||||
WINDOWS_CUDA_VERSIONS
|
||||
.iter()
|
||||
.filter_map(move |(backend, version)| {
|
||||
|
|
@ -212,7 +208,7 @@ impl TorchStrategy {
|
|||
}
|
||||
})
|
||||
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
|
||||
))),
|
||||
)),
|
||||
Os::Macos { .. }
|
||||
| Os::FreeBsd { .. }
|
||||
| Os::NetBsd { .. }
|
||||
|
|
@ -220,14 +216,12 @@ impl TorchStrategy {
|
|||
| Os::Dragonfly { .. }
|
||||
| Os::Illumos { .. }
|
||||
| Os::Haiku { .. }
|
||||
| Os::Android { .. } => Some(Either::Right(std::iter::once(
|
||||
TorchBackend::Cpu.index_url(),
|
||||
))),
|
||||
| Os::Android { .. } => {
|
||||
Either::Right(std::iter::once(TorchBackend::Cpu.index_url()))
|
||||
}
|
||||
}
|
||||
}
|
||||
TorchStrategy::Backend(backend) => {
|
||||
Some(Either::Right(std::iter::once(backend.index_url())))
|
||||
}
|
||||
TorchStrategy::Backend(backend) => Either::Right(std::iter::once(backend.index_url())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue