From bea8bc6c61d93c41e9bb7ffaf13c80fe2541d4c5 Mon Sep 17 00:00:00 2001 From: konsti Date: Tue, 23 Jul 2024 16:46:32 +0200 Subject: [PATCH] Stable sorting of requirements.txt in universal mode (#5334) The `RequirementsTxtComparator` was written assuming there is one distribution per package name. This changed with the universal resolution, which allows multiple versions or urls for the same package name. The sorting we emitted for these new entries was incidental. With this change, we properly sort these entries by name, version and then url in universal mode. This is an output format change for `--universal` users. --- crates/uv-resolver/src/resolution/graph.rs | 1 + crates/uv-resolver/src/resolution/mod.rs | 2 ++ .../src/resolution/requirements_txt.rs | 25 +++++++++++++++++-- crates/uv/tests/pip_compile.rs | 16 ++++++------ 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/crates/uv-resolver/src/resolution/graph.rs b/crates/uv-resolver/src/resolution/graph.rs index aa0dbd01b..37b8bc018 100644 --- a/crates/uv-resolver/src/resolution/graph.rs +++ b/crates/uv-resolver/src/resolution/graph.rs @@ -241,6 +241,7 @@ impl ResolutionGraph { // Add the distribution to the graph. let index = petgraph.add_node(ResolutionGraphNode::Dist(AnnotatedDist { dist, + version: version.clone(), extra: extra.clone(), dev: dev.clone(), hashes, diff --git a/crates/uv-resolver/src/resolution/mod.rs b/crates/uv-resolver/src/resolution/mod.rs index 22c8ec910..4b14fb803 100644 --- a/crates/uv-resolver/src/resolution/mod.rs +++ b/crates/uv-resolver/src/resolution/mod.rs @@ -1,6 +1,7 @@ use std::fmt::Display; use distribution_types::{DistributionMetadata, Name, ResolvedDist, VersionOrUrlRef}; +use pep440_rs::Version; use pypi_types::HashDigest; use uv_distribution::Metadata; use uv_normalize::{ExtraName, GroupName, PackageName}; @@ -20,6 +21,7 @@ mod requirements_txt; #[derive(Debug, Clone)] pub(crate) struct AnnotatedDist { pub(crate) dist: ResolvedDist, + pub(crate) version: Version, pub(crate) extra: Option, pub(crate) dev: Option, pub(crate) hashes: Vec, diff --git a/crates/uv-resolver/src/resolution/requirements_txt.rs b/crates/uv-resolver/src/resolution/requirements_txt.rs index be468da69..35049bbd0 100644 --- a/crates/uv-resolver/src/resolution/requirements_txt.rs +++ b/crates/uv-resolver/src/resolution/requirements_txt.rs @@ -5,6 +5,7 @@ use std::path::Path; use itertools::Itertools; use distribution_types::{DistributionMetadata, Name, ResolvedDist, Verbatim, VersionOrUrlRef}; +use pep440_rs::Version; use pep508_rs::{split_scheme, MarkerTree, Scheme}; use pypi_types::HashDigest; use uv_normalize::{ExtraName, PackageName}; @@ -15,6 +16,7 @@ use crate::resolution::AnnotatedDist; /// A pinned package with its resolved distribution and all the extras that were pinned for it. pub(crate) struct RequirementsTxtDist { pub(crate) dist: ResolvedDist, + pub(crate) version: Version, pub(crate) extras: Vec, pub(crate) hashes: Vec, pub(crate) markers: Option, @@ -133,7 +135,19 @@ impl RequirementsTxtDist { } } - RequirementsTxtComparator::Name(self.name()) + if let VersionOrUrlRef::Url(url) = self.version_or_url() { + RequirementsTxtComparator::Name { + name: self.name(), + version: &self.version, + url: Some(url.verbatim()), + } + } else { + RequirementsTxtComparator::Name { + name: self.name(), + version: &self.version, + url: None, + } + } } } @@ -141,6 +155,7 @@ impl From<&AnnotatedDist> for RequirementsTxtDist { fn from(annotated: &AnnotatedDist) -> Self { Self { dist: annotated.dist.clone(), + version: annotated.version.clone(), extras: if let Some(extra) = annotated.extra.clone() { vec![extra] } else { @@ -155,7 +170,13 @@ impl From<&AnnotatedDist> for RequirementsTxtDist { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub(crate) enum RequirementsTxtComparator<'a> { Url(Cow<'a, str>), - Name(&'a PackageName), + /// In universal mode, we can have multiple versions for a package, so we track the version and + /// the URL (for non-index packages) to have a stable sort for those, too. + Name { + name: &'a PackageName, + version: &'a Version, + url: Option>, + }, } impl Name for RequirementsTxtDist { diff --git a/crates/uv/tests/pip_compile.rs b/crates/uv/tests/pip_compile.rs index 2cccc7c0c..db7ec9d17 100644 --- a/crates/uv/tests/pip_compile.rs +++ b/crates/uv/tests/pip_compile.rs @@ -7250,13 +7250,13 @@ fn universal_nested_overlapping_local_requirement() -> Result<()> { # via torch tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' # via mkl - torch==2.3.0 ; platform_machine != 'x86_64' - # via -r requirements.in torch==2.0.0+cu118 ; platform_machine == 'x86_64' # via # -r requirements.in # example # triton + torch==2.3.0 ; platform_machine != 'x86_64' + # via -r requirements.in triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' # via torch typing-extensions==4.10.0 @@ -7324,13 +7324,13 @@ fn universal_nested_overlapping_local_requirement() -> Result<()> { # via torch tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' # via mkl - torch==2.3.0 ; platform_machine != 'x86_64' - # via -r requirements.in torch==2.0.0+cu118 ; platform_machine == 'x86_64' # via # -r requirements.in # example # triton + torch==2.3.0 ; platform_machine != 'x86_64' + # via -r requirements.in triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' # via torch typing-extensions==4.10.0 @@ -7440,6 +7440,10 @@ fn universal_nested_disjoint_local_requirement() -> Result<()> { # via torch tbb==2021.11.0 ; os_name != 'Linux' and platform_system == 'Windows' # via mkl + torch==2.0.0+cpu ; os_name == 'Linux' + # via + # -r requirements.in + # example torch==2.0.0+cu118 ; os_name == 'Linux' # via # -r requirements.in @@ -7447,10 +7451,6 @@ fn universal_nested_disjoint_local_requirement() -> Result<()> { # triton torch==2.3.0 ; os_name != 'Linux' # via -r requirements.in - torch==2.0.0+cpu ; os_name == 'Linux' - # via - # -r requirements.in - # example triton==2.0.0 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' # via torch typing-extensions==4.10.0