Disallow mixing requirements across PyTorch indexes (#13179)
Some checks are pending
CI / cargo dev generate-all (push) Blocked by required conditions
CI / cargo shear (push) Waiting to run
CI / cargo test | ubuntu (push) Blocked by required conditions
CI / check system | pyston (push) Blocked by required conditions
CI / Determine changes (push) Waiting to run
CI / lint (push) Waiting to run
CI / cargo clippy | ubuntu (push) Blocked by required conditions
CI / cargo clippy | windows (push) Blocked by required conditions
CI / cargo test | macos (push) Blocked by required conditions
CI / cargo test | windows (push) Blocked by required conditions
CI / check windows trampoline | aarch64 (push) Blocked by required conditions
CI / check windows trampoline | i686 (push) Blocked by required conditions
CI / build binary | linux libc (push) Blocked by required conditions
CI / check windows trampoline | x86_64 (push) Blocked by required conditions
CI / test windows trampoline | i686 (push) Blocked by required conditions
CI / test windows trampoline | x86_64 (push) Blocked by required conditions
CI / typos (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / integration test | pypy on windows (push) Blocked by required conditions
CI / build binary | linux musl (push) Blocked by required conditions
CI / build binary | macos aarch64 (push) Blocked by required conditions
CI / build binary | macos x86_64 (push) Blocked by required conditions
CI / build binary | windows x86_64 (push) Blocked by required conditions
CI / build binary | windows aarch64 (push) Blocked by required conditions
CI / integration test | uv publish (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / build binary | freebsd (push) Blocked by required conditions
CI / ecosystem test | pydantic/pydantic-core (push) Blocked by required conditions
CI / ecosystem test | prefecthq/prefect (push) Blocked by required conditions
CI / integration test | pypy on ubuntu (push) Blocked by required conditions
CI / integration test | free-threaded on windows (push) Blocked by required conditions
CI / check system | python on debian (push) Blocked by required conditions
CI / check system | python on fedora (push) Blocked by required conditions
CI / check system | python on ubuntu (push) Blocked by required conditions
CI / check system | python on opensuse (push) Blocked by required conditions
CI / check system | python on rocky linux 8 (push) Blocked by required conditions
CI / check system | python on rocky linux 9 (push) Blocked by required conditions
CI / check system | pypy on ubuntu (push) Blocked by required conditions
CI / check system | python on macos aarch64 (push) Blocked by required conditions
CI / check system | homebrew python on macos aarch64 (push) Blocked by required conditions
CI / ecosystem test | pallets/flask (push) Blocked by required conditions
CI / smoke test | linux (push) Blocked by required conditions
CI / check system | alpine (push) Blocked by required conditions
CI / smoke test | macos (push) Blocked by required conditions
CI / smoke test | windows x86_64 (push) Blocked by required conditions
CI / smoke test | windows aarch64 (push) Blocked by required conditions
CI / integration test | conda on ubuntu (push) Blocked by required conditions
CI / integration test | deadsnakes python3.9 on ubuntu (push) Blocked by required conditions
CI / integration test | free-threaded on linux (push) Blocked by required conditions
CI / integration test | graalpy on ubuntu (push) Blocked by required conditions
CI / integration test | graalpy on windows (push) Blocked by required conditions
CI / integration test | github actions (push) Blocked by required conditions
CI / integration test | free-threaded python on github actions (push) Blocked by required conditions
CI / integration test | determine publish changes (push) Blocked by required conditions
CI / integration test | uv_build (push) Blocked by required conditions
CI / check cache | ubuntu (push) Blocked by required conditions
CI / check cache | macos aarch64 (push) Blocked by required conditions
CI / check system | python on macos x86-64 (push) Blocked by required conditions
CI / check system | python3.10 on windows x86-64 (push) Blocked by required conditions
CI / check system | python3.10 on windows x86 (push) Blocked by required conditions
CI / check system | python3.13 on windows x86-64 (push) Blocked by required conditions
CI / check system | x86-64 python3.13 on windows aarch64 (push) Blocked by required conditions
CI / check system | windows registry (push) Blocked by required conditions
CI / check system | python3.12 via chocolatey (push) Blocked by required conditions
CI / check system | python3.9 via pyenv (push) Blocked by required conditions
CI / check system | python3.13 (push) Blocked by required conditions
CI / check system | conda3.11 on macos aarch64 (push) Blocked by required conditions
CI / check system | conda3.8 on macos aarch64 (push) Blocked by required conditions
CI / check system | conda3.11 on linux x86-64 (push) Blocked by required conditions
CI / check system | conda3.8 on linux x86-64 (push) Blocked by required conditions
CI / check system | conda3.11 on windows x86-64 (push) Blocked by required conditions
CI / check system | conda3.8 on windows x86-64 (push) Blocked by required conditions
CI / check system | amazonlinux (push) Blocked by required conditions
CI / check system | embedded python3.10 on windows x86-64 (push) Blocked by required conditions
CI / benchmarks (push) Blocked by required conditions

## Summary

If you use `--torch-backend=auto`, we want to avoid selecting (e.g.) a
`+cu124` build of `torch` alongside a `+cu126` build of `torchvision`.
This commit is contained in:
Charlie Marsh 2025-04-28 16:06:18 -04:00 committed by GitHub
parent 6292748371
commit a3dae2512c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 289 additions and 18 deletions

2
Cargo.lock generated
View file

@ -5689,6 +5689,7 @@ dependencies = [
"uv-requirements-txt",
"uv-small-str",
"uv-static",
"uv-torch",
"uv-types",
"uv-warnings",
"uv-workspace",
@ -5823,6 +5824,7 @@ dependencies = [
"serde",
"thiserror 2.0.12",
"tracing",
"url",
"uv-distribution-types",
"uv-normalize",
"uv-pep440",

View file

@ -36,6 +36,7 @@ uv-python = { workspace = true }
uv-requirements-txt = { workspace = true }
uv-small-str = { workspace = true }
uv-static = { workspace = true }
uv-torch = { workspace = true }
uv-types = { workspace = true }
uv-warnings = { workspace = true }
uv-workspace = { workspace = true }

View file

@ -1,7 +1,9 @@
use crate::fork_strategy::ForkStrategy;
use crate::{DependencyMode, ExcludeNewer, PrereleaseMode, ResolutionMode};
use uv_configuration::{BuildOptions, IndexStrategy};
use uv_pypi_types::SupportedEnvironments;
use uv_torch::TorchStrategy;
use crate::fork_strategy::ForkStrategy;
use crate::{DependencyMode, ExcludeNewer, PrereleaseMode, ResolutionMode};
/// Options for resolving a manifest.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
@ -15,6 +17,7 @@ pub struct Options {
pub required_environments: SupportedEnvironments,
pub flexibility: Flexibility,
pub build_options: BuildOptions,
pub torch_backend: Option<TorchStrategy>,
}
/// Builder for [`Options`].
@ -29,6 +32,7 @@ pub struct OptionsBuilder {
required_environments: SupportedEnvironments,
flexibility: Flexibility,
build_options: BuildOptions,
torch_backend: Option<TorchStrategy>,
}
impl OptionsBuilder {
@ -100,6 +104,13 @@ impl OptionsBuilder {
self
}
/// Sets the [`TorchStrategy`].
#[must_use]
pub fn torch_backend(mut self, torch_backend: Option<TorchStrategy>) -> Self {
self.torch_backend = torch_backend;
self
}
/// Builds the options.
pub fn build(self) -> Options {
Options {
@ -112,6 +123,7 @@ impl OptionsBuilder {
required_environments: self.required_environments,
flexibility: self.flexibility,
build_options: self.build_options,
torch_backend: self.torch_backend,
}
}
}

View file

@ -127,10 +127,11 @@ impl PubGrubDependency {
url,
}
}
PubGrubPackageInner::Root(_) => unreachable!("root package in dependencies"),
PubGrubPackageInner::Root(_) => unreachable!("Root package in dependencies"),
PubGrubPackageInner::Python(_) => {
unreachable!("python package in dependencies")
unreachable!("Python package in dependencies")
}
PubGrubPackageInner::System(_) => unreachable!("System package in dependencies"),
}
})
}

View file

@ -44,6 +44,8 @@ pub(crate) enum PubGrubPackageInner {
Root(Option<PackageName>),
/// A Python version.
Python(PubGrubPython),
/// A system package, which is used to represent a non-Python package.
System(PackageName),
/// A Python package.
///
/// Note that it is guaranteed that `extra` and `dev` are never both
@ -134,6 +136,7 @@ impl PubGrubPackage {
// package is never returned by `get_dependencies`. So these cases never occur.
PubGrubPackageInner::Root(None) | PubGrubPackageInner::Python(_) => None,
PubGrubPackageInner::Root(Some(name))
| PubGrubPackageInner::System(name)
| PubGrubPackageInner::Package { name, .. }
| PubGrubPackageInner::Extra { name, .. }
| PubGrubPackageInner::Dev { name, .. }
@ -141,11 +144,13 @@ impl PubGrubPackage {
}
}
/// Returns the name of this PubGrub package, if it is not the root package or a Python version
/// constraint.
/// Returns the name of this PubGrub package, if it is not the root package, a Python version
/// constraint, or a system package.
pub(crate) fn name_no_root(&self) -> Option<&PackageName> {
match &**self {
PubGrubPackageInner::Root(_) | PubGrubPackageInner::Python(_) => None,
PubGrubPackageInner::Root(_)
| PubGrubPackageInner::Python(_)
| PubGrubPackageInner::System(_) => None,
PubGrubPackageInner::Package { name, .. }
| PubGrubPackageInner::Extra { name, .. }
| PubGrubPackageInner::Dev { name, .. }
@ -159,7 +164,9 @@ impl PubGrubPackage {
match &**self {
// A root can never be a dependency of another package, and a `Python` pubgrub
// package is never returned by `get_dependencies`. So these cases never occur.
PubGrubPackageInner::Root(_) | PubGrubPackageInner::Python(_) => MarkerTree::TRUE,
PubGrubPackageInner::Root(_)
| PubGrubPackageInner::Python(_)
| PubGrubPackageInner::System(_) => MarkerTree::TRUE,
PubGrubPackageInner::Package { marker, .. }
| PubGrubPackageInner::Extra { marker, .. }
| PubGrubPackageInner::Dev { marker, .. } => *marker,
@ -177,6 +184,7 @@ impl PubGrubPackage {
// package is never returned by `get_dependencies`. So these cases never occur.
PubGrubPackageInner::Root(_)
| PubGrubPackageInner::Python(_)
| PubGrubPackageInner::System(_)
| PubGrubPackageInner::Package { extra: None, .. }
| PubGrubPackageInner::Dev { .. }
| PubGrubPackageInner::Marker { .. } => None,
@ -198,6 +206,7 @@ impl PubGrubPackage {
// package is never returned by `get_dependencies`. So these cases never occur.
PubGrubPackageInner::Root(_)
| PubGrubPackageInner::Python(_)
| PubGrubPackageInner::System(_)
| PubGrubPackageInner::Package { dev: None, .. }
| PubGrubPackageInner::Extra { .. }
| PubGrubPackageInner::Marker { .. } => None,
@ -256,7 +265,9 @@ impl PubGrubPackage {
/// reporting where this routine is used.
pub(crate) fn simplify_markers(&mut self, python_requirement: &PythonRequirement) {
match *Arc::make_mut(&mut self.0) {
PubGrubPackageInner::Root(_) | PubGrubPackageInner::Python(_) => {}
PubGrubPackageInner::Root(_)
| PubGrubPackageInner::Python(_)
| PubGrubPackageInner::System(_) => {}
PubGrubPackageInner::Package { ref mut marker, .. }
| PubGrubPackageInner::Extra { ref mut marker, .. }
| PubGrubPackageInner::Dev { ref mut marker, .. }
@ -272,6 +283,7 @@ impl PubGrubPackage {
match &**self {
PubGrubPackageInner::Root(_) => "root",
PubGrubPackageInner::Python(_) => "python",
PubGrubPackageInner::System(_) => "system",
PubGrubPackageInner::Package { .. } => "package",
PubGrubPackageInner::Extra { .. } => "extra",
PubGrubPackageInner::Dev { .. } => "dev",
@ -304,6 +316,7 @@ impl std::fmt::Display for PubGrubPackageInner {
}
}
Self::Python(_) => write!(f, "Python"),
Self::System(name) => write!(f, "system:{name}"),
Self::Package {
name,
extra: None,

View file

@ -129,6 +129,7 @@ impl PubGrubPriorities {
PubGrubPackageInner::Python(PubGrubPython::Target) => {
(PubGrubPriority::Root, PubGrubTiebreaker::from(2))
}
PubGrubPackageInner::System(_) => (PubGrubPriority::Root, PubGrubTiebreaker::from(3)),
PubGrubPackageInner::Marker { name, .. }
| PubGrubPackageInner::Extra { name, .. }
| PubGrubPackageInner::Dev { name, .. }

View file

@ -63,10 +63,6 @@ use crate::resolver::environment::{
fork_version_by_marker, fork_version_by_python_requirement, ForkingPossibility,
};
pub(crate) use crate::resolver::fork_map::{ForkMap, ForkSet};
pub(crate) use crate::resolver::urls::Urls;
use crate::universal_marker::{ConflictMarker, UniversalMarker};
pub(crate) use provider::MetadataUnavailable;
pub use crate::resolver::index::InMemoryIndex;
use crate::resolver::indexes::Indexes;
pub use crate::resolver::provider::{
@ -74,8 +70,13 @@ pub use crate::resolver::provider::{
VersionsResponse, WheelMetadataResult,
};
pub use crate::resolver::reporter::{BuildId, Reporter};
use crate::resolver::system::SystemDependency;
pub(crate) use crate::resolver::urls::Urls;
use crate::universal_marker::{ConflictMarker, UniversalMarker};
use crate::yanks::AllowedYanks;
use crate::{marker, DependencyMode, Exclusions, FlatIndex, Options, ResolutionMode, VersionMap};
pub(crate) use provider::MetadataUnavailable;
use uv_torch::TorchStrategy;
mod availability;
mod batch_prefetch;
@ -86,6 +87,7 @@ mod index;
mod indexes;
mod provider;
mod reporter;
mod system;
mod urls;
/// The number of conflicts a package may accumulate before we re-prioritize and backtrack.
@ -598,6 +600,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
next_id,
next_package,
&version,
&state.pins,
&state.fork_urls,
&state.env,
&state.python_requirement,
@ -1055,6 +1058,15 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
Ok(None)
}
PubGrubPackageInner::System(_) => {
// We don't care what the actual version is here, just that it's consistent across
// the dependency graph.
let Some(version) = range.as_singleton() else {
return Ok(None);
};
Ok(Some(ResolverVersion::Unforked(version.clone())))
}
PubGrubPackageInner::Marker { name, .. }
| PubGrubPackageInner::Extra { name, .. }
| PubGrubPackageInner::Dev { name, .. }
@ -1641,6 +1653,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
id: Id<PubGrubPackage>,
package: &PubGrubPackage,
version: &Version,
pins: &FilePins,
fork_urls: &ForkUrls,
env: &ResolverEnvironment,
python_requirement: &PythonRequirement,
@ -1650,6 +1663,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
id,
package,
version,
pins,
fork_urls,
env,
python_requirement,
@ -1674,6 +1688,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
id: Id<PubGrubPackage>,
package: &PubGrubPackage,
version: &Version,
pins: &FilePins,
fork_urls: &ForkUrls,
env: &ResolverEnvironment,
python_requirement: &PythonRequirement,
@ -1781,6 +1796,24 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
}
}
// Identify any system dependencies based on the index URL.
let system_dependencies = self
.options
.torch_backend
.as_ref()
.filter(|torch_backend| matches!(torch_backend, TorchStrategy::Auto { .. }))
.and_then(|_| pins.get(name, version).and_then(ResolvedDist::index))
.map(IndexUrl::url)
.and_then(SystemDependency::from_index)
.into_iter()
.inspect(|system_dependency| {
debug!(
"Adding system dependency `{}` for `{package}@{version}`",
system_dependency
);
})
.map(PubGrubDependency::from);
let requirements = self.flatten_requirements(
&metadata.requires_dist,
&metadata.dependency_groups,
@ -1800,11 +1833,14 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
Some(name),
)
})
.chain(system_dependencies)
.collect()
}
PubGrubPackageInner::Python(_) => return Ok(Dependencies::Unforkable(Vec::default())),
PubGrubPackageInner::System(_) => return Ok(Dependencies::Unforkable(Vec::default())),
// Add a dependency on both the marker and base package.
PubGrubPackageInner::Marker { name, marker } => {
return Ok(Dependencies::Unforkable(
@ -2562,6 +2598,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
match &**package {
PubGrubPackageInner::Root(_) => {}
PubGrubPackageInner::Python(_) => {}
PubGrubPackageInner::System(_) => {}
PubGrubPackageInner::Marker { .. } => {}
PubGrubPackageInner::Extra { .. } => {}
PubGrubPackageInner::Dev { .. } => {}

View file

@ -0,0 +1,84 @@
use std::str::FromStr;
use pubgrub::Ranges;
use url::Url;
use uv_normalize::PackageName;
use uv_pep440::Version;
use uv_torch::TorchBackend;
use crate::pubgrub::{PubGrubDependency, PubGrubPackage, PubGrubPackageInner};
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct SystemDependency {
/// The name of the system dependency (e.g., `cuda`).
name: PackageName,
/// The version of the system dependency (e.g., `12.4`).
version: Version,
}
impl SystemDependency {
/// Extract a [`SystemDependency`] from an index URL.
///
/// For example, given `https://download.pytorch.org/whl/cu124`, returns CUDA 12.4.
pub(super) fn from_index(index: &Url) -> Option<Self> {
let backend = TorchBackend::from_index(index)?;
let cuda_version = backend.cuda_version()?;
Some(Self {
name: PackageName::from_str("cuda").unwrap(),
version: cuda_version,
})
}
}
impl std::fmt::Display for SystemDependency {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}@{}", self.name, self.version)
}
}
impl From<SystemDependency> for PubGrubDependency {
fn from(value: SystemDependency) -> Self {
PubGrubDependency {
package: PubGrubPackage::from(PubGrubPackageInner::System(value.name)),
version: Ranges::singleton(value.version),
url: None,
}
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use url::Url;
use uv_normalize::PackageName;
use uv_pep440::Version;
use crate::resolver::system::SystemDependency;
#[test]
fn pypi() {
let url = Url::parse("https://pypi.org/simple").unwrap();
assert_eq!(SystemDependency::from_index(&url), None);
}
#[test]
fn pytorch_cuda_12_4() {
let url = Url::parse("https://download.pytorch.org/whl/cu124").unwrap();
assert_eq!(
SystemDependency::from_index(&url),
Some(SystemDependency {
name: PackageName::from_str("cuda").unwrap(),
version: Version::new([12, 4]),
})
);
}
#[test]
fn pytorch_cpu() {
let url = Url::parse("https://download.pytorch.org/whl/cpu").unwrap();
assert_eq!(SystemDependency::from_index(&url), None);
}
}

View file

@ -23,6 +23,7 @@ schemars = { workspace = true, optional = true }
serde = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
url = { workspace = true }
[lints]
workspace = true

View file

@ -41,6 +41,7 @@ use std::str::FromStr;
use std::sync::LazyLock;
use either::Either;
use url::Url;
use uv_distribution_types::IndexUrl;
use uv_normalize::PackageName;
@ -230,7 +231,7 @@ impl TorchStrategy {
}
/// The available backends for PyTorch.
#[derive(Debug, Clone, Eq, PartialEq)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum TorchBackend {
Cpu,
Cu128,
@ -261,7 +262,7 @@ pub enum TorchBackend {
impl TorchBackend {
/// Return the appropriate index URL for the given [`TorchBackend`].
fn index_url(&self) -> &'static IndexUrl {
fn index_url(self) -> &'static IndexUrl {
match self {
Self::Cpu => &CPU_INDEX_URL,
Self::Cu128 => &CU128_INDEX_URL,
@ -290,6 +291,87 @@ impl TorchBackend {
Self::Cu80 => &CU80_INDEX_URL,
}
}
/// Extract a [`TorchBackend`] from an index URL.
pub fn from_index(index: &Url) -> Option<Self> {
let backend_identifier = if index.host_str() == Some("download.pytorch.org") {
// E.g., `https://download.pytorch.org/whl/cu124`
let mut path_segments = index.path_segments()?;
if path_segments.next() != Some("whl") {
return None;
}
path_segments.next()?
} else {
return None;
};
Self::from_str(backend_identifier).ok()
}
/// Returns the CUDA [`Version`] for the given [`TorchBackend`].
pub fn cuda_version(&self) -> Option<Version> {
match self {
TorchBackend::Cpu => None,
TorchBackend::Cu128 => Some(Version::new([12, 8])),
TorchBackend::Cu126 => Some(Version::new([12, 6])),
TorchBackend::Cu125 => Some(Version::new([12, 5])),
TorchBackend::Cu124 => Some(Version::new([12, 4])),
TorchBackend::Cu123 => Some(Version::new([12, 3])),
TorchBackend::Cu122 => Some(Version::new([12, 2])),
TorchBackend::Cu121 => Some(Version::new([12, 1])),
TorchBackend::Cu120 => Some(Version::new([12, 0])),
TorchBackend::Cu118 => Some(Version::new([11, 8])),
TorchBackend::Cu117 => Some(Version::new([11, 7])),
TorchBackend::Cu116 => Some(Version::new([11, 6])),
TorchBackend::Cu115 => Some(Version::new([11, 5])),
TorchBackend::Cu114 => Some(Version::new([11, 4])),
TorchBackend::Cu113 => Some(Version::new([11, 3])),
TorchBackend::Cu112 => Some(Version::new([11, 2])),
TorchBackend::Cu111 => Some(Version::new([11, 1])),
TorchBackend::Cu110 => Some(Version::new([11, 0])),
TorchBackend::Cu102 => Some(Version::new([10, 2])),
TorchBackend::Cu101 => Some(Version::new([10, 1])),
TorchBackend::Cu100 => Some(Version::new([10, 0])),
TorchBackend::Cu92 => Some(Version::new([9, 2])),
TorchBackend::Cu91 => Some(Version::new([9, 1])),
TorchBackend::Cu90 => Some(Version::new([9, 0])),
TorchBackend::Cu80 => Some(Version::new([8, 0])),
}
}
}
impl FromStr for TorchBackend {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"cpu" => Ok(TorchBackend::Cpu),
"cu128" => Ok(TorchBackend::Cu128),
"cu126" => Ok(TorchBackend::Cu126),
"cu125" => Ok(TorchBackend::Cu125),
"cu124" => Ok(TorchBackend::Cu124),
"cu123" => Ok(TorchBackend::Cu123),
"cu122" => Ok(TorchBackend::Cu122),
"cu121" => Ok(TorchBackend::Cu121),
"cu120" => Ok(TorchBackend::Cu120),
"cu118" => Ok(TorchBackend::Cu118),
"cu117" => Ok(TorchBackend::Cu117),
"cu116" => Ok(TorchBackend::Cu116),
"cu115" => Ok(TorchBackend::Cu115),
"cu114" => Ok(TorchBackend::Cu114),
"cu113" => Ok(TorchBackend::Cu113),
"cu112" => Ok(TorchBackend::Cu112),
"cu111" => Ok(TorchBackend::Cu111),
"cu110" => Ok(TorchBackend::Cu110),
"cu102" => Ok(TorchBackend::Cu102),
"cu101" => Ok(TorchBackend::Cu101),
"cu100" => Ok(TorchBackend::Cu100),
"cu92" => Ok(TorchBackend::Cu92),
"cu91" => Ok(TorchBackend::Cu91),
"cu90" => Ok(TorchBackend::Cu90),
"cu80" => Ok(TorchBackend::Cu80),
_ => Err(format!("Unknown PyTorch backend: {s}")),
}
}
}
/// Linux CUDA driver versions and the corresponding CUDA versions.

View file

@ -395,7 +395,7 @@ pub(crate) async fn pip_compile(
.index_urls(index_locations.index_urls())
.index_strategy(index_strategy)
.url_auth_policies(UrlAuthPolicies::from(&index_locations))
.torch_backend(torch_backend)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
@ -482,6 +482,7 @@ pub(crate) async fn pip_compile(
.dependency_mode(dependency_mode)
.exclude_newer(exclude_newer)
.index_strategy(index_strategy)
.torch_backend(torch_backend)
.build_options(build_options.clone())
.build();

View file

@ -362,7 +362,7 @@ pub(crate) async fn pip_install(
.index_urls(index_locations.index_urls())
.index_strategy(index_strategy)
.url_auth_policies(UrlAuthPolicies::from(&index_locations))
.torch_backend(torch_backend)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
@ -456,6 +456,7 @@ pub(crate) async fn pip_install(
.dependency_mode(dependency_mode)
.exclude_newer(exclude_newer)
.index_strategy(index_strategy)
.torch_backend(torch_backend)
.build_options(build_options.clone())
.build();

View file

@ -294,7 +294,7 @@ pub(crate) async fn pip_sync(
.index_urls(index_locations.index_urls())
.index_strategy(index_strategy)
.url_auth_policies(UrlAuthPolicies::from(&index_locations))
.torch_backend(torch_backend)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
@ -391,6 +391,7 @@ pub(crate) async fn pip_sync(
.dependency_mode(dependency_mode)
.exclude_newer(exclude_newer)
.index_strategy(index_strategy)
.torch_backend(torch_backend)
.build_options(build_options.clone())
.build();

View file

@ -17300,3 +17300,37 @@ async fn index_has_no_requires_python() -> Result<()> {
Ok(())
}
/// Disallow resolving to multiple different PyTorch indexes.
#[test]
fn incompatible_cuda() -> Result<()> {
let context = TestContext::new("3.11");
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {r"
torch==2.6.0+cu126
torchvision==0.16.0+cu121
"})?;
uv_snapshot!(context
.pip_compile()
.env_remove(EnvVars::UV_EXCLUDE_NEWER)
.env(EnvVars::UV_TORCH_BACKEND, "auto")
.env(EnvVars::UV_CUDA_DRIVER_VERSION, "525.60.13")
.arg("--preview")
.arg("requirements.in")
.arg("--python-platform")
.arg("x86_64-manylinux_2_28")
.arg("--python-version")
.arg("3.11"), @r"
success: false
exit_code: 1
----- stdout -----
----- stderr -----
× No solution found when resolving dependencies:
Because torchvision==0.16.0+cu121 depends on system:cuda==12.1 and torch==2.6.0+cu126 depends on system:cuda==12.6, we can conclude that torch==2.6.0+cu126 and torchvision==0.16.0+cu121 are incompatible.
And because you require torch==2.6.0+cu126 and torchvision==0.16.0+cu121, we can conclude that your requirements are unsatisfiable.
");
Ok(())
}