Allow conflicting locals when forking (#5104)

## Summary

Currently, the `Locals` type relies on there being a single local
version for a given package. With marker expressions this may not be
true, a similar problem to https://github.com/astral-sh/uv/pull/4435.
This changes the `Locals` type to `ForkLocals`, which tracks locals for
a given fork. Local versions are now tracked on `PubGrubRequirement`
before forking.

Resolves https://github.com/astral-sh/uv/issues/4580.
This commit is contained in:
Ibraheem Ahmed 2024-07-16 12:57:30 -04:00 committed by GitHub
parent 048ae8f7f3
commit d583847f8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 269 additions and 116 deletions

View file

@ -12,7 +12,7 @@ use pypi_types::{
use uv_normalize::{ExtraName, PackageName}; use uv_normalize::{ExtraName, PackageName};
use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner}; use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner};
use crate::resolver::Locals; use crate::resolver::ForkLocals;
use crate::{PubGrubSpecifier, ResolveError}; use crate::{PubGrubSpecifier, ResolveError};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -23,23 +23,26 @@ pub(crate) struct PubGrubDependency {
/// even if this field is None where there is an override with a URL or there is a different /// even if this field is None where there is an override with a URL or there is a different
/// requirement or constraint for the same package that has a URL. /// requirement or constraint for the same package that has a URL.
pub(crate) url: Option<VerbatimParsedUrl>, pub(crate) url: Option<VerbatimParsedUrl>,
/// The local version for this requirement, if specified.
pub(crate) local: Option<Version>,
} }
impl PubGrubDependency { impl PubGrubDependency {
pub(crate) fn from_requirement<'a>( pub(crate) fn from_requirement<'a>(
requirement: &'a Requirement, requirement: &'a Requirement,
source_name: Option<&'a PackageName>, source_name: Option<&'a PackageName>,
locals: &'a Locals, fork_locals: &'a ForkLocals,
) -> impl Iterator<Item = Result<Self, ResolveError>> + 'a { ) -> impl Iterator<Item = Result<Self, ResolveError>> + 'a {
// Add the package, plus any extra variants. // Add the package, plus any extra variants.
iter::once(None) iter::once(None)
.chain(requirement.extras.clone().into_iter().map(Some)) .chain(requirement.extras.clone().into_iter().map(Some))
.map(|extra| PubGrubRequirement::from_requirement(requirement, extra, locals)) .map(|extra| PubGrubRequirement::from_requirement(requirement, extra, fork_locals))
.filter_map_ok(move |requirement| { .filter_map_ok(move |requirement| {
let PubGrubRequirement { let PubGrubRequirement {
package, package,
version, version,
url, url,
local,
} = requirement; } = requirement;
match &*package { match &*package {
PubGrubPackageInner::Package { name, .. } => { PubGrubPackageInner::Package { name, .. } => {
@ -53,12 +56,14 @@ impl PubGrubDependency {
package: package.clone(), package: package.clone(),
version: version.clone(), version: version.clone(),
url, url,
local,
}) })
} }
PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency { PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency {
package: package.clone(), package: package.clone(),
version: version.clone(), version: version.clone(),
url, url,
local,
}), }),
PubGrubPackageInner::Extra { name, .. } => { PubGrubPackageInner::Extra { name, .. } => {
debug_assert!( debug_assert!(
@ -69,6 +74,7 @@ impl PubGrubDependency {
package: package.clone(), package: package.clone(),
version: version.clone(), version: version.clone(),
url: None, url: None,
local: None,
}) })
} }
_ => None, _ => None,
@ -83,6 +89,7 @@ pub(crate) struct PubGrubRequirement {
pub(crate) package: PubGrubPackage, pub(crate) package: PubGrubPackage,
pub(crate) version: Range<Version>, pub(crate) version: Range<Version>,
pub(crate) url: Option<VerbatimParsedUrl>, pub(crate) url: Option<VerbatimParsedUrl>,
pub(crate) local: Option<Version>,
} }
impl PubGrubRequirement { impl PubGrubRequirement {
@ -91,11 +98,11 @@ impl PubGrubRequirement {
pub(crate) fn from_requirement( pub(crate) fn from_requirement(
requirement: &Requirement, requirement: &Requirement,
extra: Option<ExtraName>, extra: Option<ExtraName>,
locals: &Locals, fork_locals: &ForkLocals,
) -> Result<Self, ResolveError> { ) -> Result<Self, ResolveError> {
let (verbatim_url, parsed_url) = match &requirement.source { let (verbatim_url, parsed_url) = match &requirement.source {
RequirementSource::Registry { specifier, .. } => { RequirementSource::Registry { specifier, .. } => {
return Self::from_registry_requirement(specifier, extra, requirement, locals); return Self::from_registry_requirement(specifier, extra, requirement, fork_locals);
} }
RequirementSource::Url { RequirementSource::Url {
subdirectory, subdirectory,
@ -162,6 +169,7 @@ impl PubGrubRequirement {
parsed_url, parsed_url,
verbatim: verbatim_url.clone(), verbatim: verbatim_url.clone(),
}), }),
local: None,
}) })
} }
@ -169,15 +177,15 @@ impl PubGrubRequirement {
specifier: &VersionSpecifiers, specifier: &VersionSpecifiers,
extra: Option<ExtraName>, extra: Option<ExtraName>,
requirement: &Requirement, requirement: &Requirement,
locals: &Locals, fork_locals: &ForkLocals,
) -> Result<PubGrubRequirement, ResolveError> { ) -> Result<PubGrubRequirement, ResolveError> {
// If the specifier is an exact version, and the user requested a local version that's // If the specifier is an exact version and the user requested a local version for this
// more precise than the specifier, use the local version instead. // fork that's more precise than the specifier, use the local version instead.
let version = if let Some(expected) = locals.get(&requirement.name) { let version = if let Some(local) = fork_locals.get(&requirement.name) {
specifier specifier
.iter() .iter()
.map(|specifier| { .map(|specifier| {
Locals::map(expected, specifier) ForkLocals::map(local, specifier)
.map_err(ResolveError::InvalidVersion) .map_err(ResolveError::InvalidVersion)
.and_then(|specifier| { .and_then(|specifier| {
Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?) Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?)
@ -198,7 +206,9 @@ impl PubGrubRequirement {
), ),
version, version,
url: None, url: None,
local: None,
}; };
Ok(requirement) Ok(requirement)
} }
} }

View file

@ -1,46 +1,26 @@
use std::iter;
use std::str::FromStr; use std::str::FromStr;
use rustc_hash::FxHashMap;
use distribution_filename::{SourceDistFilename, WheelFilename}; use distribution_filename::{SourceDistFilename, WheelFilename};
use distribution_types::RemoteSource; use distribution_types::RemoteSource;
use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError}; use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError};
use pep508_rs::MarkerEnvironment; use pep508_rs::PackageName;
use pypi_types::RequirementSource; use pypi_types::RequirementSource;
use uv_normalize::PackageName; use rustc_hash::FxHashMap;
use crate::{DependencyMode, Manifest}; /// A map of package names to their associated, required local versions in a given fork.
#[derive(Debug, Default, Clone)]
pub(crate) struct ForkLocals(FxHashMap<PackageName, Version>);
#[derive(Debug, Default)] impl ForkLocals {
pub(crate) struct Locals { /// Insert the local [`Version`] to which a package is pinned for this fork.
/// A map of package names to their associated, required local versions. pub(crate) fn insert(&mut self, package_name: PackageName, local: Version) {
required: FxHashMap<PackageName, Version>, assert!(local.is_local());
} self.0.insert(package_name, local);
impl Locals {
/// Determine the set of permitted local versions in the [`Manifest`].
pub(crate) fn from_manifest(
manifest: &Manifest,
markers: Option<&MarkerEnvironment>,
dependencies: DependencyMode,
) -> Self {
let mut required: FxHashMap<PackageName, Version> = FxHashMap::default();
// Add all direct requirements and constraints. There's no need to look for conflicts,
// since conflicts will be enforced by the solver.
for requirement in manifest.requirements(markers, dependencies) {
for local in iter_locals(&requirement.source) {
required.insert(requirement.name.clone(), local);
}
}
Self { required }
} }
/// Return the local [`Version`] to which a package is pinned, if any. /// Return the local [`Version`] to which a package is pinned in this fork, if any.
pub(crate) fn get(&self, package: &PackageName) -> Option<&Version> { pub(crate) fn get(&self, package_name: &PackageName) -> Option<&Version> {
self.required.get(package) self.0.get(package_name)
} }
/// Given a specifier that may include the version _without_ a local segment, return a specifier /// Given a specifier that may include the version _without_ a local segment, return a specifier
@ -140,63 +120,61 @@ fn is_compatible(expected: &Version, provided: &Version) -> bool {
} }
} }
/// If a [`VersionSpecifier`] contains exact equality specifiers for a local version, returns an /// If a [`VersionSpecifier`] contains an exact equality specifier for a local version,
/// iterator over the local versions. /// returns the local version.
fn iter_locals(source: &RequirementSource) -> Box<dyn Iterator<Item = Version> + '_> { pub(crate) fn from_source(source: &RequirementSource) -> Option<Version> {
match source { match source {
// Extract all local versions from specifiers that require an exact version (e.g., // Extract all local versions from specifiers that require an exact version (e.g.,
// `==1.0.0+local`). // `==1.0.0+local`).
RequirementSource::Registry { RequirementSource::Registry {
specifier: version, .. specifier: version, ..
} => Box::new( } => version
version .iter()
.iter() .filter(|specifier| {
.filter(|specifier| { matches!(specifier.operator(), Operator::Equal | Operator::ExactEqual)
matches!(specifier.operator(), Operator::Equal | Operator::ExactEqual) })
}) .filter(|specifier| !specifier.version().local().is_empty())
.filter(|specifier| !specifier.version().local().is_empty()) .map(|specifier| specifier.version().clone())
.map(|specifier| specifier.version().clone()), // It's technically possible for there to be multiple local segments here.
), // For example, `a==1.0+foo,==1.0+bar`. However, in that case resolution
// will fail later.
.next(),
// Exact a local version from a URL, if it includes a fully-qualified filename (e.g., // Exact a local version from a URL, if it includes a fully-qualified filename (e.g.,
// `torch-2.2.1%2Bcu118-cp311-cp311-linux_x86_64.whl`). // `torch-2.2.1%2Bcu118-cp311-cp311-linux_x86_64.whl`).
RequirementSource::Url { url, .. } => Box::new( RequirementSource::Url { url, .. } => url
url.filename() .filename()
.ok() .ok()
.and_then(|filename| { .and_then(|filename| {
if let Ok(filename) = WheelFilename::from_str(&filename) { if let Ok(filename) = WheelFilename::from_str(&filename) {
Some(filename.version) Some(filename.version)
} else if let Ok(filename) = } else if let Ok(filename) =
SourceDistFilename::parsed_normalized_filename(&filename) SourceDistFilename::parsed_normalized_filename(&filename)
{ {
Some(filename.version) Some(filename.version)
} else { } else {
None None
} }
}) })
.into_iter() .filter(pep440_rs::Version::is_local),
.filter(pep440_rs::Version::is_local), RequirementSource::Git { .. } => None,
),
RequirementSource::Git { .. } => Box::new(iter::empty()),
RequirementSource::Path { RequirementSource::Path {
install_path: path, .. install_path: path, ..
} => Box::new( } => path
path.file_name() .file_name()
.and_then(|filename| { .and_then(|filename| {
let filename = filename.to_string_lossy(); let filename = filename.to_string_lossy();
if let Ok(filename) = WheelFilename::from_str(&filename) { if let Ok(filename) = WheelFilename::from_str(&filename) {
Some(filename.version) Some(filename.version)
} else if let Ok(filename) = } else if let Ok(filename) =
SourceDistFilename::parsed_normalized_filename(&filename) SourceDistFilename::parsed_normalized_filename(&filename)
{ {
Some(filename.version) Some(filename.version)
} else { } else {
None None
} }
}) })
.into_iter() .filter(pep440_rs::Version::is_local),
.filter(pep440_rs::Version::is_local), RequirementSource::Directory { .. } => None,
),
RequirementSource::Directory { .. } => Box::new(iter::empty()),
} }
} }
@ -212,7 +190,7 @@ mod tests {
use pypi_types::ParsedUrl; use pypi_types::ParsedUrl;
use pypi_types::RequirementSource; use pypi_types::RequirementSource;
use crate::resolver::locals::{iter_locals, Locals}; use super::{from_source, ForkLocals};
#[test] #[test]
fn extract_locals() -> Result<()> { fn extract_locals() -> Result<()> {
@ -220,7 +198,7 @@ mod tests {
let url = VerbatimUrl::from_url(Url::parse("https://example.com/foo-1.0.0+local.tar.gz")?); let url = VerbatimUrl::from_url(Url::parse("https://example.com/foo-1.0.0+local.tar.gz")?);
let source = let source =
RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url); RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url);
let locals: Vec<_> = iter_locals(&source).collect(); let locals: Vec<_> = from_source(&source).into_iter().collect();
assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]); assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]);
// Extract from a wheel in a URL. // Extract from a wheel in a URL.
@ -229,14 +207,14 @@ mod tests {
)?); )?);
let source = let source =
RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url); RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url);
let locals: Vec<_> = iter_locals(&source).collect(); let locals: Vec<_> = from_source(&source).into_iter().collect();
assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]); assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]);
// Don't extract anything if the URL is opaque. // Don't extract anything if the URL is opaque.
let url = VerbatimUrl::from_url(Url::parse("git+https://example.com/foo/bar")?); let url = VerbatimUrl::from_url(Url::parse("git+https://example.com/foo/bar")?);
let source = let source =
RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url); RequirementSource::from_parsed_url(ParsedUrl::try_from(url.to_url()).unwrap(), url);
let locals: Vec<_> = iter_locals(&source).collect(); let locals: Vec<_> = from_source(&source).into_iter().collect();
assert!(locals.is_empty()); assert!(locals.is_empty());
// Extract from `==` specifiers. // Extract from `==` specifiers.
@ -248,7 +226,7 @@ mod tests {
specifier: version, specifier: version,
index: None, index: None,
}; };
let locals: Vec<_> = iter_locals(&source).collect(); let locals: Vec<_> = from_source(&source).into_iter().collect();
assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]); assert_eq!(locals, vec![Version::from_str("1.0.0+local")?]);
// Ignore other specifiers. // Ignore other specifiers.
@ -260,7 +238,7 @@ mod tests {
specifier: version, specifier: version,
index: None, index: None,
}; };
let locals: Vec<_> = iter_locals(&source).collect(); let locals: Vec<_> = from_source(&source).into_iter().collect();
assert!(locals.is_empty()); assert!(locals.is_empty());
Ok(()) Ok(())
@ -273,7 +251,7 @@ mod tests {
let specifier = let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?; VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?;
assert_eq!( assert_eq!(
Locals::map(&local, &specifier)?, ForkLocals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
); );
@ -282,7 +260,7 @@ mod tests {
let specifier = let specifier =
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?; VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?;
assert_eq!( assert_eq!(
Locals::map(&local, &specifier)?, ForkLocals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)? VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)?
); );
@ -291,7 +269,7 @@ mod tests {
let specifier = let specifier =
VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?; VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?;
assert_eq!( assert_eq!(
Locals::map(&local, &specifier)?, ForkLocals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
); );
@ -300,7 +278,7 @@ mod tests {
let specifier = let specifier =
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?; VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?;
assert_eq!( assert_eq!(
Locals::map(&local, &specifier)?, ForkLocals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)? VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?
); );
@ -309,7 +287,7 @@ mod tests {
let specifier = let specifier =
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?; VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?;
assert_eq!( assert_eq!(
Locals::map(&local, &specifier)?, ForkLocals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)? VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?
); );
@ -318,7 +296,7 @@ mod tests {
let specifier = let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?; VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?;
assert_eq!( assert_eq!(
Locals::map(&local, &specifier)?, ForkLocals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?
); );
@ -327,7 +305,7 @@ mod tests {
let specifier = let specifier =
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?; VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?;
assert_eq!( assert_eq!(
Locals::map(&local, &specifier)?, ForkLocals::map(&local, &specifier)?,
VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)? VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?
); );

View file

@ -26,7 +26,7 @@ use distribution_types::{
IncompatibleWheel, IndexLocations, InstalledDist, PythonRequirementKind, RemoteSource, IncompatibleWheel, IndexLocations, InstalledDist, PythonRequirementKind, RemoteSource,
ResolvedDist, ResolvedDistRef, SourceDist, VersionOrUrlRef, ResolvedDist, ResolvedDistRef, SourceDist, VersionOrUrlRef,
}; };
pub(crate) use locals::Locals; pub(crate) use locals::ForkLocals;
use pep440_rs::{Version, MIN_VERSION}; use pep440_rs::{Version, MIN_VERSION};
use pep508_rs::{MarkerEnvironment, MarkerTree}; use pep508_rs::{MarkerEnvironment, MarkerTree};
use platform_tags::Tags; use platform_tags::Tags;
@ -92,7 +92,6 @@ struct ResolverState<InstalledPackages: InstalledPackagesProvider> {
git: GitResolver, git: GitResolver,
exclusions: Exclusions, exclusions: Exclusions,
urls: Urls, urls: Urls,
locals: Locals,
dependency_mode: DependencyMode, dependency_mode: DependencyMode,
hasher: HashStrategy, hasher: HashStrategy,
/// When not set, the resolver is in "universal" mode. /// When not set, the resolver is in "universal" mode.
@ -200,7 +199,6 @@ impl<Provider: ResolverProvider, InstalledPackages: InstalledPackagesProvider>
selector: CandidateSelector::for_resolution(options, &manifest, markers), selector: CandidateSelector::for_resolution(options, &manifest, markers),
dependency_mode: options.dependency_mode, dependency_mode: options.dependency_mode,
urls: Urls::from_manifest(&manifest, markers, git, options.dependency_mode)?, urls: Urls::from_manifest(&manifest, markers, git, options.dependency_mode)?,
locals: Locals::from_manifest(&manifest, markers, options.dependency_mode),
project: manifest.project, project: manifest.project,
requirements: manifest.requirements, requirements: manifest.requirements,
constraints: manifest.constraints, constraints: manifest.constraints,
@ -295,6 +293,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
next: root, next: root,
pins: FilePins::default(), pins: FilePins::default(),
fork_urls: ForkUrls::default(), fork_urls: ForkUrls::default(),
fork_locals: ForkLocals::default(),
priorities: PubGrubPriorities::default(), priorities: PubGrubPriorities::default(),
added_dependencies: FxHashMap::default(), added_dependencies: FxHashMap::default(),
markers: MarkerTree::And(vec![]), markers: MarkerTree::And(vec![]),
@ -476,6 +475,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
&state.next, &state.next,
&version, &version,
&state.fork_urls, &state.fork_urls,
&state.fork_locals,
&state.markers, &state.markers,
state.requires_python.as_ref(), state.requires_python.as_ref(),
)?; )?;
@ -504,6 +504,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
package, package,
version: _, version: _,
url: _, url: _,
local: _,
} = dependency; } = dependency;
let url = package.name().and_then(|name| state.fork_urls.get(name)); let url = package.name().and_then(|name| state.fork_urls.get(name));
self.visit_package(package, url, &request_sink)?; self.visit_package(package, url, &request_sink)?;
@ -574,6 +575,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
package, package,
version: _, version: _,
url: _, url: _,
local: _,
} = dependency; } = dependency;
let url = package let url = package
.name() .name()
@ -1080,10 +1082,18 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
package: &PubGrubPackage, package: &PubGrubPackage,
version: &Version, version: &Version,
fork_urls: &ForkUrls, fork_urls: &ForkUrls,
fork_locals: &ForkLocals,
markers: &MarkerTree, markers: &MarkerTree,
requires_python: Option<&MarkerTree>, requires_python: Option<&MarkerTree>,
) -> Result<ForkedDependencies, ResolveError> { ) -> Result<ForkedDependencies, ResolveError> {
let result = self.get_dependencies(package, version, fork_urls, markers, requires_python); let result = self.get_dependencies(
package,
version,
fork_urls,
fork_locals,
markers,
requires_python,
);
if self.markers.is_some() { if self.markers.is_some() {
return result.map(|deps| match deps { return result.map(|deps| match deps {
Dependencies::Available(deps) => ForkedDependencies::Unforked(deps), Dependencies::Available(deps) => ForkedDependencies::Unforked(deps),
@ -1100,6 +1110,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
package: &PubGrubPackage, package: &PubGrubPackage,
version: &Version, version: &Version,
fork_urls: &ForkUrls, fork_urls: &ForkUrls,
fork_locals: &ForkLocals,
markers: &MarkerTree, markers: &MarkerTree,
requires_python: Option<&MarkerTree>, requires_python: Option<&MarkerTree>,
) -> Result<Dependencies, ResolveError> { ) -> Result<Dependencies, ResolveError> {
@ -1120,7 +1131,12 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
requirements requirements
.iter() .iter()
.flat_map(|requirement| { .flat_map(|requirement| {
PubGrubDependency::from_requirement(requirement, None, &self.locals) PubGrubDependency::from_requirement(requirement, None, fork_locals)
// Keep track of local versions to propagate to transitive dependencies.
.map_ok(|dependency| PubGrubDependency {
local: locals::from_source(&requirement.source),
..dependency
})
}) })
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
} }
@ -1249,13 +1265,13 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
let mut dependencies = requirements let mut dependencies = requirements
.iter() .iter()
.flat_map(|requirement| { .flat_map(|requirement| {
PubGrubDependency::from_requirement(requirement, Some(name), &self.locals) PubGrubDependency::from_requirement(requirement, Some(name), fork_locals)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
// If a package has metadata for an enabled dependency group, // If a package has metadata for an enabled dependency group,
// add a dependency from it to the same package with the group // add a dependency from it to the same package with the group
// enabled. // enabled.
if extra.is_none() && dev.is_none() { if extra.is_none() && dev.is_none() {
for group in &self.dev { for group in &self.dev {
if !metadata.dev_dependencies.contains_key(group) { if !metadata.dev_dependencies.contains_key(group) {
@ -1269,6 +1285,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
}), }),
version: Range::singleton(version.clone()), version: Range::singleton(version.clone()),
url: None, url: None,
local: None,
}); });
} }
} }
@ -1291,6 +1308,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
}), }),
version: Range::singleton(version.clone()), version: Range::singleton(version.clone()),
url: None, url: None,
local: None,
}) })
.collect(), .collect(),
)) ))
@ -1318,6 +1336,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
}), }),
version: Range::singleton(version.clone()), version: Range::singleton(version.clone()),
url: None, url: None,
local: None,
}) })
}) })
.collect(), .collect(),
@ -1343,6 +1362,7 @@ impl<InstalledPackages: InstalledPackagesProvider> ResolverState<InstalledPackag
}), }),
version: Range::singleton(version.clone()), version: Range::singleton(version.clone()),
url: None, url: None,
local: None,
}) })
}) })
.collect(), .collect(),
@ -1931,6 +1951,8 @@ struct ForkState {
/// one URL per package. By prioritizing direct URL dependencies over registry dependencies, /// one URL per package. By prioritizing direct URL dependencies over registry dependencies,
/// this map is populated for all direct URL packages before we look at any registry packages. /// this map is populated for all direct URL packages before we look at any registry packages.
fork_urls: ForkUrls, fork_urls: ForkUrls,
/// The local versions required by packages in this fork.
fork_locals: ForkLocals,
/// When dependencies for a package are retrieved, this map of priorities /// When dependencies for a package are retrieved, this map of priorities
/// is updated based on how each dependency was specified. Certain types /// is updated based on how each dependency was specified. Certain types
/// of dependencies have more "priority" than others (like direct URL /// of dependencies have more "priority" than others (like direct URL
@ -1974,7 +1996,7 @@ struct ForkState {
impl ForkState { impl ForkState {
/// Add the dependencies for the selected version of the current package, checking for /// Add the dependencies for the selected version of the current package, checking for
/// self-dependencies, and handling URLs. /// self-dependencies, and handling URLs and locals.
fn add_package_version_dependencies( fn add_package_version_dependencies(
&mut self, &mut self,
for_package: Option<&str>, for_package: Option<&str>,
@ -1988,16 +2010,24 @@ impl ForkState {
package, package,
version, version,
url, url,
local,
} = dependency; } = dependency;
// From the [`Requirement`] to [`PubGrubDependency`] conversion, we get a URL if the
// requirement was a URL requirement. `Urls` applies canonicalization to this and
// override URLs to both URL and registry requirements, which we then check for
// conflicts using [`ForkUrl`].
if let Some(name) = package.name() { if let Some(name) = package.name() {
// From the [`Requirement`] to [`PubGrubDependency`] conversion, we get a URL if the
// requirement was a URL requirement. `Urls` applies canonicalization to this and
// override URLs to both URL and registry requirements, which we then check for
// conflicts using [`ForkUrl`].
if let Some(url) = urls.get_url(name, url.as_ref(), git)? { if let Some(url) = urls.get_url(name, url.as_ref(), git)? {
self.fork_urls.insert(name, url, &self.markers)?; self.fork_urls.insert(name, url, &self.markers)?;
}; };
// `PubGrubDependency` also gives us a local version if specified by the user.
// Keep track of which local version we will be using in this fork for transitive
// dependencies.
if let Some(local) = local {
self.fork_locals.insert(name.clone(), local.clone());
}
} }
if let Some(for_package) = for_package { if let Some(for_package) = for_package {
@ -2019,6 +2049,7 @@ impl ForkState {
package, package,
version, version,
url: _, url: _,
local: _,
} = dependency; } = dependency;
(package, version) (package, version)
}), }),

View file

@ -6708,6 +6708,140 @@ fn universal_multi_version() -> Result<()> {
Ok(()) Ok(())
} }
#[test]
fn universal_disjoint_locals() -> Result<()> {
let context = TestContext::new("3.12");
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc::indoc! {r"
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==2.0.0+cu118 ; platform_machine == 'x86_64'
torch==2.0.0+cpu ; platform_machine != 'x86_64'
"})?;
uv_snapshot!(context.filters(), windows_filters=false, context.pip_compile()
.arg("requirements.in")
.arg("--universal"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
filelock==3.13.1
# via
# torch
# triton
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
sympy==1.12
# via torch
torch==2.0.0+cpu ; platform_machine != 'x86_64'
# via -r requirements.in
torch==2.0.0+cu118 ; platform_machine == 'x86_64'
# via
# -r requirements.in
# triton
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
----- stderr -----
Resolved 12 packages in [TIME]
"###
);
Ok(())
}
#[test]
fn universal_transitive_disjoint_locals() -> Result<()> {
let context = TestContext::new("3.12");
let requirements_in = context.temp_dir.child("requirements.in");
requirements_in.write_str(indoc::indoc! {r"
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==2.0.0+cu118 ; platform_machine == 'x86_64'
torch==2.0.0+cpu ; platform_machine != 'x86_64'
torchvision==0.15.1
"})?;
// The marker expressions on the output here are incorrect due to https://github.com/astral-sh/uv/issues/5086,
// but the local versions are still respected correctly.
uv_snapshot!(context.filters(), windows_filters=false, context.pip_compile()
.arg("requirements.in")
.arg("--universal"), @r###"
success: true
exit_code: 0
----- stdout -----
# This file was autogenerated by uv via the following command:
# uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal
certifi==2024.2.2
# via requests
charset-normalizer==3.3.2
# via requests
cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
filelock==3.13.1
# via
# torch
# triton
idna==3.6
# via requests
jinja2==3.1.3
# via torch
lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via triton
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.2.1
# via torch
numpy==1.26.4
# via torchvision
pillow==10.2.0
# via torchvision
requests==2.31.0
# via torchvision
sympy==1.12
# via torch
torch==2.0.0+cpu
# via
# -r requirements.in
# torchvision
torch==2.0.0+cu118
# via
# -r requirements.in
# torchvision
# triton
torchvision==0.15.1
# via -r requirements.in
triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux'
# via torch
typing-extensions==4.10.0
# via torch
urllib3==2.2.1
# via requests
----- stderr -----
Resolved 20 packages in [TIME]
"###
);
Ok(())
}
/// Perform a universal resolution that requires narrowing the supported Python range in one of the /// Perform a universal resolution that requires narrowing the supported Python range in one of the
/// fork branches. /// fork branches.
/// ///