Respect local versions for all user requirements (#5232)

## Summary

This fixes a few bugs introduced by
https://github.com/astral-sh/uv/pull/5104. I previously thought we could
track conflicting locals the same way we track conflicting URLs in
forks, but that turns out to be very tricky. URL forks work because we
prioritize direct URL requirements. We can't prioritize locals in the
same way without conflicting with the URL prioritization (this may be
possible, but it's not trivial), so we run into issues where a correct
resolution depends on the order in which dependencies are traversed.

Instead, we track local versions across all forks in `Locals`. When
applying a local version, we apply every local whose markers intersect
with the current fork. This way, we can apply a local version without
creating a fork. For example, given:
```
// pyproject.toml
dependencies = [
    "torch==2.0.0+cu118 ; platform_machine == 'x86_64'",
]

// requirements.in
torch==2.0.0
.
```

We choose `2.0.0+cu118` in all cases. However, if a disjoint fork is
created based on local versions, the resolver will choose the most
compatible local when it narrows to a specific fork. Thus we correctly
respect local versions when forking:
```
// pyproject.toml
dependencies = [
    "torch==2.0.0+cu118 ; platform_machine == 'x86_64'",
    "torch==2.0.0+cpu ; platform_machine != 'x86_64'"
]

// requirements.in
torch==2.0.0
.
```
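
For illustration, here's a minimal, self-contained sketch of that
marker-compatibility filter. The `Marker`, `is_disjoint`, and string-based
versions below are simplified stand-ins, not uv's actual `MarkerTree` and
`Version` types; the real logic lives in `Locals::get` in the diff below.
```
use std::collections::HashMap;

/// Simplified stand-in for a marker expression: `Some(("x86_64", true))` means
/// `platform_machine == 'x86_64'`, `Some(("x86_64", false))` means
/// `platform_machine != 'x86_64'`, and `None` means "no marker".
type Marker = Option<(&'static str, bool)>;

/// Two simplified markers are disjoint when they constrain the same key to
/// equal and not-equal the same value.
fn is_disjoint(a: &(&'static str, bool), b: &(&'static str, bool)) -> bool {
    a.0 == b.0 && a.1 != b.1
}

/// All requested local versions per package, with the marker each was requested under.
struct Locals(HashMap<&'static str, Vec<(Marker, &'static str)>>);

impl Locals {
    /// Return every requested local whose marker is not provably disjoint from
    /// the current fork's marker. Before any fork exists (`fork == None`),
    /// every requested local is a candidate.
    fn get(&self, package: &str, fork: &Marker) -> Vec<&'static str> {
        let Some(locals) = self.0.get(package) else {
            return Vec::new();
        };
        locals
            .iter()
            .filter(|(marker, _)| match (fork, marker) {
                (Some(fork), Some(marker)) => !is_disjoint(fork, marker),
                // No fork yet, or an unconditional local request: compatible.
                _ => true,
            })
            .map(|&(_, local)| local)
            .collect()
    }
}

fn main() {
    let locals = Locals(HashMap::from([(
        "torch",
        vec![
            (Some(("x86_64", true)), "2.0.0+cu118"),
            (Some(("x86_64", false)), "2.0.0+cpu"),
        ],
    )]));

    // Before forking, both locals are candidates for `torch==2.0.0`.
    assert_eq!(locals.get("torch", &None), ["2.0.0+cu118", "2.0.0+cpu"]);
    // Inside the `platform_machine == 'x86_64'` fork, only +cu118 remains.
    assert_eq!(locals.get("torch", &Some(("x86_64", true))), ["2.0.0+cu118"]);
}
```
In the real resolver, each surviving local is then intersected with the
original `==2.0.0` specifier (via `Locals::map`) and unioned into the
dependency's version range, so `==2.0.0` effectively becomes `==2.0.0+cu118`
inside the matching fork.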

We should also be able to use a similar strategy for
https://github.com/astral-sh/uv/pull/5150.

## Test Plan

This fixes https://github.com/astral-sh/uv/issues/5220 locally for me,
as well as a few other bugs that had not yet been reported.
Ibraheem Ahmed 2024-07-19 17:56:09 -04:00 committed by GitHub
parent 92e11022e7
commit bb73edb03b
4 changed files with 718 additions and 95 deletions


@ -12,37 +12,37 @@ use pypi_types::{
use uv_normalize::{ExtraName, PackageName};
use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner};
use crate::resolver::ForkLocals;
use crate::{PubGrubSpecifier, ResolveError};
#[derive(Clone, Debug)]
pub(crate) struct PubGrubDependency {
pub(crate) package: PubGrubPackage,
pub(crate) version: Range<Version>,
/// The original version specifiers from the requirement.
pub(crate) specifier: Option<VersionSpecifiers>,
/// This field is set if the [`Requirement`] had a URL. We still use a URL from [`Urls`]
/// even if this field is None where there is an override with a URL or there is a different
/// requirement or constraint for the same package that has a URL.
pub(crate) url: Option<VerbatimParsedUrl>,
/// The local version for this requirement, if specified.
pub(crate) local: Option<Version>,
}
impl PubGrubDependency {
pub(crate) fn from_requirement<'a>(
requirement: &'a Requirement,
source_name: Option<&'a PackageName>,
fork_locals: &'a ForkLocals,
) -> impl Iterator<Item = Result<Self, ResolveError>> + 'a {
// Add the package, plus any extra variants.
iter::once(None)
.chain(requirement.extras.clone().into_iter().map(Some))
.map(|extra| PubGrubRequirement::from_requirement(requirement, extra, fork_locals))
.map(|extra| PubGrubRequirement::from_requirement(requirement, extra))
.filter_map_ok(move |requirement| {
let PubGrubRequirement {
package,
version,
specifier,
url,
local,
} = requirement;
match &*package {
PubGrubPackageInner::Package { name, .. } => {
@ -55,15 +55,15 @@ impl PubGrubDependency {
Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url,
local,
})
}
PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url,
local,
}),
PubGrubPackageInner::Extra { name, .. } => {
debug_assert!(
@ -73,8 +73,8 @@ impl PubGrubDependency {
Some(PubGrubDependency {
package: package.clone(),
version: version.clone(),
specifier,
url: None,
local: None,
})
}
_ => None,
@ -88,8 +88,8 @@ impl PubGrubDependency {
pub(crate) struct PubGrubRequirement {
pub(crate) package: PubGrubPackage,
pub(crate) version: Range<Version>,
pub(crate) specifier: Option<VersionSpecifiers>,
pub(crate) url: Option<VerbatimParsedUrl>,
pub(crate) local: Option<Version>,
}
impl PubGrubRequirement {
@ -98,11 +98,10 @@ impl PubGrubRequirement {
pub(crate) fn from_requirement(
requirement: &Requirement,
extra: Option<ExtraName>,
fork_locals: &ForkLocals,
) -> Result<Self, ResolveError> {
let (verbatim_url, parsed_url) = match &requirement.source {
RequirementSource::Registry { specifier, .. } => {
return Self::from_registry_requirement(specifier, extra, requirement, fork_locals);
return Self::from_registry_requirement(specifier, extra, requirement);
}
RequirementSource::Url {
subdirectory,
@ -165,11 +164,11 @@ impl PubGrubRequirement {
requirement.marker.clone(),
),
version: Range::full(),
specifier: None,
url: Some(VerbatimParsedUrl {
parsed_url,
verbatim: verbatim_url.clone(),
}),
local: None,
})
}
@ -177,26 +176,8 @@ impl PubGrubRequirement {
specifier: &VersionSpecifiers,
extra: Option<ExtraName>,
requirement: &Requirement,
fork_locals: &ForkLocals,
) -> Result<PubGrubRequirement, ResolveError> {
// If the specifier is an exact version and the user requested a local version for this
// fork that's more precise than the specifier, use the local version instead.
let version = if let Some(local) = fork_locals.get(&requirement.name) {
specifier
.iter()
.map(|specifier| {
ForkLocals::map(local, specifier)
.map_err(ResolveError::InvalidVersion)
.and_then(|specifier| {
Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?)
})
})
.fold_ok(Range::full(), |range, specifier| {
range.intersection(&specifier.into())
})?
} else {
PubGrubSpecifier::from_pep440_specifiers(specifier)?.into()
};
let version = PubGrubSpecifier::from_pep440_specifiers(specifier)?.into();
let requirement = Self {
package: PubGrubPackage::from_package(
@ -204,9 +185,9 @@ impl PubGrubRequirement {
extra,
requirement.marker.clone(),
),
version,
specifier: Some(specifier.clone()),
url: None,
local: None,
version,
};
Ok(requirement)


@ -3,24 +3,74 @@ use std::str::FromStr;
use distribution_filename::{SourceDistFilename, WheelFilename};
use distribution_types::RemoteSource;
use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError};
use pep508_rs::PackageName;
use pep508_rs::{MarkerEnvironment, MarkerTree, PackageName};
use pypi_types::RequirementSource;
use rustc_hash::FxHashMap;
/// A map of package names to their associated, required local versions in a given fork.
#[derive(Debug, Default, Clone)]
pub(crate) struct ForkLocals(FxHashMap<PackageName, Version>);
use crate::{marker::is_disjoint, DependencyMode, Manifest, ResolverMarkers};
impl ForkLocals {
/// Insert the local [`Version`] to which a package is pinned for this fork.
pub(crate) fn insert(&mut self, package_name: PackageName, local: Version) {
assert!(local.is_local());
self.0.insert(package_name, local);
/// A map of package names to their associated, required local versions across all forks.
#[derive(Debug, Default, Clone)]
pub(crate) struct Locals(FxHashMap<PackageName, Vec<(Option<MarkerTree>, Version)>>);
impl Locals {
/// Determine the set of permitted local versions in the [`Manifest`].
pub(crate) fn from_manifest(
manifest: &Manifest,
markers: Option<&MarkerEnvironment>,
dependencies: DependencyMode,
) -> Self {
let mut required: FxHashMap<PackageName, Vec<_>> = FxHashMap::default();
// Add all direct requirements and constraints. There's no need to look for conflicts,
// since conflicts will be enforced by the solver.
for requirement in manifest.requirements(markers, dependencies) {
if let Some(local) = from_source(&requirement.source) {
required
.entry(requirement.name.clone())
.or_default()
.push((requirement.marker.clone(), local));
}
}
/// Return the local [`Version`] to which a package is pinned in this fork, if any.
pub(crate) fn get(&self, package_name: &PackageName) -> Option<&Version> {
self.0.get(package_name)
Self(required)
}
/// Return a list of local versions that are compatible with a package in the given fork.
pub(crate) fn get(
&self,
package_name: &PackageName,
markers: &ResolverMarkers,
) -> Vec<&Version> {
let Some(locals) = self.0.get(package_name) else {
return Vec::new();
};
match markers {
// If we are solving for a specific environment we already filtered
// compatible requirements `from_manifest`.
ResolverMarkers::SpecificEnvironment(_) => {
locals.first().map(|(_, local)| local).into_iter().collect()
}
// Return all locals that were requested with markers that are compatible
// with the current fork.
//
// Compatibility implies that the markers are not disjoint. The resolver will
// choose the most compatible local when it narrows to the specific fork.
ResolverMarkers::Fork(fork) => locals
.iter()
.filter(|(marker, _)| {
!marker
.as_ref()
.is_some_and(|marker| is_disjoint(fork, marker))
})
.map(|(_, local)| local)
.collect(),
// If we haven't forked yet, all locals are potentially compatible.
ResolverMarkers::Universal => locals.iter().map(|(_, local)| local).collect(),
}
}
/// Given a specifier that may include the version _without_ a local segment, return a specifier
@ -190,7 +240,7 @@ mod tests {
use pypi_types::ParsedUrl;
use pypi_types::RequirementSource;
use super::{from_source, ForkLocals};
use super::{from_source, Locals};
#[test]
fn extract_locals() -> Result<()> {
@ -251,7 +301,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);
@ -260,7 +310,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)?
);
@ -269,7 +319,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);
@ -278,7 +328,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?
);
@ -287,7 +337,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?
);
@ -296,7 +346,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
);
@ -305,7 +355,7 @@ mod tests {
let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?;
assert_eq!(
ForkLocals::map(&local, &specifier)?,
Locals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?
);


@ -26,7 +26,7 @@ use distribution_types::{
IncompatibleWheel, IndexLocations, InstalledDist, PythonRequirementKind, RemoteSource,
ResolvedDist, ResolvedDistRef, SourceDist, VersionOrUrlRef,
};
pub(crate) use locals::ForkLocals;
pub(crate) use locals::Locals;
use pep440_rs::{Version, MIN_VERSION};
use pep508_rs::MarkerTree;
use platform_tags::Tags;
@ -96,6 +96,7 @@ struct ResolverState<InstalledPackages: InstalledPackagesProvider> {
git: GitResolver,
exclusions: Exclusions,
urls: Urls,
locals: Locals,
dependency_mode: DependencyMode,
hasher: HashStrategy,
markers: ResolverMarkers,
@ -215,6 +216,11 @@ impl<Provider: ResolverProvider, InstalledPackages: InstalledPackagesProvider>
git,
options.dependency_mode,
)?,
locals: Locals::from_manifest(
&manifest,
markers.marker_environment(),
options.dependency_mode,
),
project: manifest.project,
requirements: manifest.requirements,
constraints: manifest.constraints,
@ -309,7 +315,6 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
next: root,
pins: FilePins::default(),
fork_urls: ForkUrls::default(),
fork_locals: ForkLocals::default(),
priorities: PubGrubPriorities::default(),
added_dependencies: FxHashMap::default(),
markers: self.markers.clone(),
@ -492,7 +497,6 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
&state.next,
&version,
&state.fork_urls,
&state.fork_locals,
&state.markers,
state.requires_python.as_ref(),
)?;
@ -512,17 +516,19 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
for_package.as_deref(),
&version,
&self.urls,
&self.locals,
dependencies.clone(),
&self.git,
self.selector.resolution_strategy(),
)?;
// Emit a request to fetch the metadata for each registry package.
for dependency in &dependencies {
let PubGrubDependency {
package,
version: _,
specifier: _,
url: _,
local: _,
} = dependency;
let url = package.name().and_then(|name| state.fork_urls.get(name));
self.visit_package(package, url, &request_sink)?;
@ -581,11 +587,11 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
}
forked_state.markers = ResolverMarkers::Fork(combined_markers);
forked_state.add_package_version_dependencies(
for_package.as_deref(),
&version,
&self.urls,
&self.locals,
fork.dependencies.clone(),
&self.git,
self.selector.resolution_strategy(),
@ -595,8 +601,8 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
let PubGrubDependency {
package,
version: _,
specifier: _,
url: _,
local: _,
} = dependency;
let url = package
.name()
@ -1100,18 +1106,10 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
package: &PubGrubPackage,
version: &Version,
fork_urls: &ForkUrls,
fork_locals: &ForkLocals,
markers: &ResolverMarkers,
requires_python: Option<&MarkerTree>,
) -> Result<ForkedDependencies, ResolveError> {
let result = self.get_dependencies(
package,
version,
fork_urls,
fork_locals,
markers,
requires_python,
);
let result = self.get_dependencies(package, version, fork_urls, markers, requires_python);
match markers {
ResolverMarkers::SpecificEnvironment(_) => result.map(|deps| match deps {
Dependencies::Available(deps) => ForkedDependencies::Unforked(deps),
@ -1128,7 +1126,6 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
package: &PubGrubPackage,
version: &Version,
fork_urls: &ForkUrls,
fork_locals: &ForkLocals,
markers: &ResolverMarkers,
requires_python: Option<&MarkerTree>,
) -> Result<Dependencies, ResolveError> {
@ -1148,14 +1145,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
requirements
.iter()
.flat_map(|requirement| {
PubGrubDependency::from_requirement(requirement, None, fork_locals)
// Keep track of local versions to propagate to transitive dependencies.
.map_ok(|dependency| PubGrubDependency {
local: locals::from_source(&requirement.source),
..dependency
})
})
.flat_map(|requirement| PubGrubDependency::from_requirement(requirement, None))
.collect::<Result<Vec<_>, _>>()?
}
PubGrubPackageInner::Package {
@ -1283,7 +1273,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
let mut dependencies = requirements
.iter()
.flat_map(|requirement| {
PubGrubDependency::from_requirement(requirement, Some(name), fork_locals)
PubGrubDependency::from_requirement(requirement, Some(name))
})
.collect::<Result<Vec<_>, _>>()?;
@ -1302,8 +1292,8 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
marker: marker.clone(),
}),
version: Range::singleton(version.clone()),
specifier: None,
url: None,
local: None,
});
}
}
@ -1325,8 +1315,8 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
marker: marker.cloned(),
}),
version: Range::singleton(version.clone()),
specifier: None,
url: None,
local: None,
})
.collect(),
))
@ -1353,8 +1343,8 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
marker: marker.cloned(),
}),
version: Range::singleton(version.clone()),
specifier: None,
url: None,
local: None,
})
})
.collect(),
@ -1379,8 +1369,8 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
marker: marker.cloned(),
}),
version: Range::singleton(version.clone()),
specifier: None,
url: None,
local: None,
})
})
.collect(),
@ -1983,8 +1973,6 @@ struct ForkState {
/// one URL per package. By prioritizing direct URL dependencies over registry dependencies,
/// this map is populated for all direct URL packages before we look at any registry packages.
fork_urls: ForkUrls,
/// The local versions required by packages in this fork.
fork_locals: ForkLocals,
/// When dependencies for a package are retrieved, this map of priorities
/// is updated based on how each dependency was specified. Certain types
/// of dependencies have more "priority" than others (like direct URL
@ -2034,16 +2022,17 @@ impl ForkState {
for_package: Option<&str>,
version: &Version,
urls: &Urls,
dependencies: Vec<PubGrubDependency>,
locals: &Locals,
mut dependencies: Vec<PubGrubDependency>,
git: &GitResolver,
resolution_strategy: &ResolutionStrategy,
) -> Result<(), ResolveError> {
for dependency in &dependencies {
for dependency in &mut dependencies {
let PubGrubDependency {
package,
version,
specifier,
url,
local,
} = dependency;
let mut has_url = false;
@ -2057,11 +2046,36 @@ impl ForkState {
has_url = true;
};
// `PubGrubDependency` also gives us a local version if specified by the user.
// Keep track of which local version we will be using in this fork for transitive
// dependencies.
if let Some(local) = local {
self.fork_locals.insert(name.clone(), local.clone());
// If the specifier is an exact version and the user requested a local version for this
// fork that's more precise than the specifier, use the local version instead.
if let Some(specifier) = specifier {
let locals = locals.get(name, &self.markers);
// Prioritize local versions over the original version range.
if !locals.is_empty() {
*version = Range::empty();
}
// It's possible that there are multiple matching local versions requested with
// different marker expressions. All of these are potentially compatible until we
// narrow to a specific fork.
for local in locals {
let local = specifier
.iter()
.map(|specifier| {
Locals::map(local, specifier)
.map_err(ResolveError::InvalidVersion)
.and_then(|specifier| {
Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?)
})
})
.fold_ok(Range::full(), |range, specifier| {
range.intersection(&specifier.into())
})?;
// Add the local version.
*version = version.union(&local);
}
}
}
@ -2096,8 +2110,8 @@ impl ForkState {
let PubGrubDependency {
package,
version,
specifier: _,
url: _,
local: _,
} = dependency;
(package, version)
}),


@ -6708,6 +6708,7 @@ fn universal_multi_version() -> Result<()> {
Ok(())
}
// Requested distinct local versions with disjoint markers.
#[test]
fn universal_disjoint_locals() -> Result<()> {
let context = TestContext::new("3.12");
@ -6764,6 +6765,8 @@ fn universal_disjoint_locals() -> Result<()> {
Ok(())
}
// Requested distinct local versions with disjoint markers of a package
// that is also present as a transitive dependency.
#[test]
fn universal_transitive_disjoint_locals() -> Result<()> {
let context = TestContext::new("3.12");
@ -6776,7 +6779,7 @@ fn universal_transitive_disjoint_locals() -> Result<()> {
torchvision==0.15.1
"})?;
// The marker expressions on the output here are incorrect due to https://github.com/astral-sh/uv/issues/5086,
// Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086,
// but the local versions are still respected correctly.
uv_snapshot!(context.filters(), windows_filters=false, context.pip_compile()
.arg("requirements.in")
@ -6842,6 +6845,581 @@ fn universal_transitive_disjoint_locals() -> Result<()> {
Ok(())
}
/// Prefer local versions for dependencies of path requirements.
#[test]
fn universal_local_path_requirement() -> Result<()> {
let context = TestContext::new("3.12");
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! {r#"
[project]
name = "example"
version = "0.0.0"
dependencies = [
"torch==2.0.0+cu118"
]
requires-python = ">=3.11"
"#})?;
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {"
torch==2.0.0
.
"})?;
uv_snapshot!(context.pip_compile()
.arg("requirements.in")
.arg("--universal")
.arg("--find-links")
.arg("https://download.pytorch.org/whl/torch_stable.html"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
.
# via -r requirements.in
filelock==3.13.1
# via
# torch
# triton
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
sympy==1.12
# via torch
torch==2.0.0+cu118
# via
# -r requirements.in
# example
# triton
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 12 packages in [TIME]
"###
);
Ok(())
}
/// If a dependency requests a local version with an overlapping marker expression,
/// we should prefer the local in all cases.
#[test]
fn universal_overlapping_local_requirement() -> Result<()> {
let context = TestContext::new("3.12");
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! {r#"
[project]
name = "example"
version = "0.0.0"
dependencies = [
"torch==2.0.0+cu118 ; platform_machine == 'x86_64'"
]
requires-python = ">=3.11"
"#})?;
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {"
torch==2.0.0
.
"})?;
uv_snapshot!(context.pip_compile()
.arg("requirements.in")
.arg("--universal")
.arg("--find-links")
.arg("https://download.pytorch.org/whl/torch_stable.html"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
.
# via -r requirements.in
filelock==3.13.1
# via
# torch
# triton
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
sympy==1.12
# via torch
torch==2.0.0+cu118
# via
# -r requirements.in
# example
# triton
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 12 packages in [TIME]
"###
);
Ok(())
}
/// If a dependency requests distinct local versions with disjoint marker expressions,
/// we should fork the root requirement.
#[test]
fn universal_disjoint_local_requirement() -> Result<()> {
let context = TestContext::new("3.12");
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! {r#"
[project]
name = "example"
version = "0.0.0"
dependencies = [
"torch==2.0.0+cu118 ; platform_machine == 'x86_64'",
"torch==2.0.0+cpu ; platform_machine != 'x86_64'"
]
requires-python = ">=3.11"
"#})?;
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {"
torch==2.0.0
.
"})?;
// Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086,
// but the local versions are still respected correctly.
uv_snapshot!(context.pip_compile()
.arg("requirements.in")
.arg("--universal")
.arg("--find-links")
.arg("https://download.pytorch.org/whl/torch_stable.html"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
.
# via -r requirements.in
filelock==3.13.1
# via
# torch
# triton
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
sympy==1.12
# via torch
torch==2.0.0+cpu
# via
# -r requirements.in
# example
torch==2.0.0+cu118
# via
# -r requirements.in
# example
# triton
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 13 packages in [TIME]
"###
);
Ok(())
}
/// If a dependency requests distinct local versions and non-local versions with disjoint marker
/// expressions, we should fork the root requirement.
#[test]
fn universal_disjoint_base_or_local_requirement() -> Result<()> {
let context = TestContext::new("3.12");
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! {r#"
[project]
name = "example"
version = "0.0.0"
dependencies = [
"torch==2.0.0; python_version < '3.10'",
"torch==2.0.0+cu118 ; python_version >= '3.10' and python_version <= '3.12'",
"torch==2.0.0+cpu ; python_version > '3.12'"
]
requires-python = ">=3.11"
"#})?;
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {"
torch==2.0.0
.
"})?;
// Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086,
// but the local versions are still respected correctly.
uv_snapshot!(context.pip_compile()
.arg("requirements.in")
.arg("--universal")
.arg("--find-links")
.arg("https://download.pytorch.org/whl/torch_stable.html"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
.
# via -r requirements.in
filelock==3.13.1
# via
# torch
# triton
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
sympy==1.12
# via torch
torch==2.0.0+cpu
# via
# -r requirements.in
# example
torch==2.0.0+cu118
# via
# -r requirements.in
# example
# triton
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 13 packages in [TIME]
"###
);
Ok(())
}
/// If a dependency requests a local version with an overlapping marker expression
/// that form a nested fork, we should prefer the local in both children of the outer
/// fork.
#[test]
fn universal_nested_overlapping_local_requirement() -> Result<()> {
let context = TestContext::new("3.12");
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! {r#"
[project]
name = "example"
version = "0.0.0"
dependencies = [
"torch==2.0.0+cu118 ; platform_machine == 'x86_64' and os_name == 'Linux'"
]
requires-python = ">=3.11"
"#})?;
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {"
torch==2.0.0 ; platform_machine == 'x86_64'
torch==2.3.0 ; platform_machine != 'x86_64'
.
"})?;
uv_snapshot!(context.pip_compile()
.arg("requirements.in")
.arg("--universal")
.arg("--find-links")
.arg("https://download.pytorch.org/whl/torch_stable.html"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
.
# via -r requirements.in
filelock==3.13.1
# via
# torch
# triton
fsspec==2024.3.1 ; platform_machine != 'x86_64'
# via torch
intel-openmp==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows'
# via mkl
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mkl==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows'
# via torch
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
sympy==1.12
# via torch
tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows'
# via mkl
torch==2.3.0 ; platform_machine != 'x86_64'
# via -r requirements.in
torch==2.0.0+cu118 ; platform_machine == 'x86_64'
# via
# -r requirements.in
# example
# triton
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 17 packages in [TIME]
"###
);
// A similar case, except the nested marker is now on the path requirement.
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {"
torch==2.0.0 ; platform_machine == 'x86_64'
torch==2.3.0 ; platform_machine != 'x86_64'
. ; os_name == 'Linux'
"})?;
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! {r#"
[project]
name = "example"
version = "0.0.0"
dependencies = [
"torch==2.0.0+cu118 ; platform_machine == 'x86_64'",
]
requires-python = ">=3.11"
"#})?;
uv_snapshot!(context.pip_compile()
.arg("requirements.in")
.arg("--universal")
.arg("--find-links")
.arg("https://download.pytorch.org/whl/torch_stable.html"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
. ; os_name == 'Linux'
# via -r requirements.in
filelock==3.13.1
# via
# torch
# triton
fsspec==2024.3.1 ; platform_machine != 'x86_64'
# via torch
intel-openmp==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows'
# via mkl
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mkl==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows'
# via torch
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
sympy==1.12
# via torch
tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows'
# via mkl
torch==2.3.0 ; platform_machine != 'x86_64'
# via -r requirements.in
torch==2.0.0+cu118 ; platform_machine == 'x86_64'
# via
# -r requirements.in
# example
# triton
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 17 packages in [TIME]
"###
);
Ok(())
}
/// If a dependency requests distinct local versions with disjoint marker expressions
/// that form a nested fork, we should create a nested fork.
#[test]
fn universal_nested_disjoint_local_requirement() -> Result<()> {
let context = TestContext::new("3.12");
let pyproject_toml = context.temp_dir.child("pyproject.toml");
pyproject_toml.write_str(indoc! {r#"
[project]
name = "example"
version = "0.0.0"
dependencies = [
"torch==2.0.0+cu118 ; platform_machine == 'x86_64'",
"torch==2.0.0+cpu ; platform_machine != 'x86_64'"
]
requires-python = ">=3.11"
"#})?;
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc! {"
torch==2.0.0 ; os_name == 'Linux'
torch==2.3.0 ; os_name != 'Linux'
. ; os_name == 'Linux'
"})?;
// Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086,
// but the local versions are still respected correctly.
uv_snapshot!(context.pip_compile()
.arg("requirements.in")
.arg("--universal")
.arg("--find-links")
.arg("https://download.pytorch.org/whl/torch_stable.html"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
. ; os_name == 'Linux'
# via -r requirements.in
filelock==3.13.1
# via
# torch
# triton
fsspec==2024.3.1 ; os_name != 'Linux'
# via torch
intel-openmp==2021.4.0 ; os_name != 'Linux' and platform_system == 'Windows'
# via mkl
jinja2==3.1.3
# via torch
lit==18.1.2 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mkl==2021.4.0 ; os_name != 'Linux' and platform_system == 'Windows'
# via torch
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
nvidia-cublas-cu12==12.1.3.1 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-cuda-runtime-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-cudnn-cu12==8.9.2.26 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-cufft-cu12==11.0.2.54 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-curand-cu12==10.3.2.106 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-cusolver-cu12==11.4.5.107 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-cusparse-cu12==12.1.0.106 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
nvidia-nvjitlink-cu12==12.4.99 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
sympy==1.12
# via torch
tbb==2021.11.0 ; os_name != 'Linux' and platform_system == 'Windows'
# via mkl
torch==2.0.0+cu118 ; os_name == 'Linux'
# via
# -r requirements.in
# example
# triton
torch==2.3.0 ; os_name != 'Linux'
# via -r requirements.in
torch==2.0.0+cpu ; os_name == 'Linux'
# via
# -r requirements.in
# example
triton==2.0.0 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 30 packages in [TIME]
"###
);
Ok(())
}
/// Perform a universal resolution that requires narrowing the supported Python range in one of the
/// fork branches.
///