Simplify marker expressions in lockfile (#4066)

## Summary

Simplify and normalize marker expressions in the lockfile. Right now
this does a simple analysis by only looking at related operators at the
same level of precedence. I think anything more complex would be out of
scope.

Resolves https://github.com/astral-sh/uv/issues/4002.
This commit is contained in:
Ibraheem Ahmed 2024-06-07 16:14:24 -04:00 committed by GitHub
parent bcfe88dfdc
commit 7232c53718
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 225 additions and 51 deletions

View file

@ -1,7 +1,9 @@
#[cfg(feature = "pyo3")]
use std::hash::{Hash, Hasher};
use std::cmp::Ordering;
use std::ops::Bound;
use std::{cmp::Ordering, str::FromStr};
use std::str::FromStr;
#[cfg(feature = "pyo3")]
use pyo3::{

View file

@ -1881,30 +1881,6 @@ impl MarkerTree {
exprs.push(tree);
}
}
/// Puts this marker tree into a canonical order by sorting the operands of
/// every conjunction and disjunction, at every depth.
///
/// Conjunctions and disjunctions may be assembled in a non-deterministic
/// order; after this call two trees that differ only in operand order
/// compare (and print) the same.
pub fn normalize(&mut self) {
    match self {
        // A single expression has no operands to reorder.
        MarkerTree::Expression(_) => {}
        MarkerTree::And(children) | MarkerTree::Or(children) => {
            // Children are normalized before the sort so that the (recursive)
            // `Ord` comparison sees them in their canonical form.
            //
            // A "smart constructor" that normalizes on construction would
            // make this two-pass approach unnecessary.
            children.iter_mut().for_each(MarkerTree::normalize);
            children.sort();
        }
    }
}
}
impl Display for MarkerTree {

View file

@ -511,7 +511,7 @@ impl Distribution {
// Markers can be combined in an unpredictable order, so normalize them
// such that the lock file output is consistent and deterministic.
if let Some(ref mut marker) = marker {
marker.normalize();
crate::marker::normalize(marker);
}
let sdist = SourceDist::from_annotated_dist(annotated_dist)?;
let wheels = Wheel::from_annotated_dist(annotated_dist)?;

View file

@ -1,5 +1,7 @@
#![allow(clippy::enum_glob_use)]
use std::collections::HashMap;
use std::mem;
use std::ops::Bound::{self, *};
use std::ops::RangeBounds;
@ -10,6 +12,7 @@ use pep508_rs::{
};
use crate::pubgrub::PubGrubSpecifier;
use pubgrub::range::Range as PubGrubRange;
/// Returns `true` if there is no environment in which both marker trees can apply, i.e.
/// the expression `first and second` is always false.
@ -79,6 +82,111 @@ fn string_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool
true
}
/// Normalizes this marker tree.
///
/// This function does a number of operations to normalize a marker tree recursively:
/// - Sort all nested expressions.
/// - Simplify expressions. This includes combining overlapping version ranges and removing duplicate
///   expressions at the same level of precedence. For example, `(a == 'a' and a == 'a') or b == 'b'` can
///   be reduced, but `a == 'a' and (a == 'a' or b == 'b')` cannot.
/// - Normalize the order of version expressions to the form `<version key> <version op> <version>`
///   (i.e. not the reverse).
/// - Collapse a conjunction or disjunction that ends up with a single operand into that operand.
///
/// This is useful in cases where creating conjunctions or disjunctions might occur in a non-deterministic
/// order. This routine will attempt to erase the distinction created by such a construction.
pub(crate) fn normalize(tree: &mut MarkerTree) {
    match tree {
        MarkerTree::And(trees) | MarkerTree::Or(trees) => {
            // Operands that are kept as-is (after being normalized themselves).
            let mut reduced = Vec::new();
            // Version ranges grouped by marker key, to be merged per key below.
            let mut versions: HashMap<_, Vec<_>> = HashMap::new();

            for mut tree in mem::take(trees) {
                // Simplify nested expressions as much as possible first.
                normalize(&mut tree);

                // Extract expressions we may be able to simplify more.
                if let MarkerTree::Expression(ref expr) = tree {
                    if let Some((key, range)) = keyed_range(expr) {
                        versions.entry(key.clone()).or_default().push(range);
                        continue;
                    }
                }

                reduced.push(tree);
            }

            match tree {
                MarkerTree::And(_) => {
                    // All ranges for a key must hold at once: intersect them.
                    simplify_ranges(&mut reduced, versions, |ranges| {
                        ranges
                            .iter()
                            .fold(PubGrubRange::full(), |acc, range| acc.intersection(range))
                    });

                    // Sort *before* deduplicating: `Vec::dedup` only removes consecutive
                    // duplicates, so equal operands must be adjacent for it to catch them.
                    // Sorting also hides the unspecified `HashMap` iteration order in which
                    // the version expressions were appended.
                    reduced.sort();
                    reduced.dedup();

                    *tree = match reduced.len() {
                        1 => reduced.remove(0),
                        _ => MarkerTree::And(reduced),
                    };
                }
                MarkerTree::Or(_) => {
                    // Any range for a key may hold: take their union.
                    simplify_ranges(&mut reduced, versions, |ranges| {
                        ranges
                            .iter()
                            .fold(PubGrubRange::empty(), |acc, range| acc.union(range))
                    });

                    // See the conjunction case: sort first so that `dedup` removes every
                    // duplicate, not only those that happened to be adjacent.
                    reduced.sort();
                    reduced.dedup();

                    *tree = match reduced.len() {
                        1 => reduced.remove(0),
                        _ => MarkerTree::Or(reduced),
                    };
                }
                // The outer match arm only admits `And` and `Or`.
                MarkerTree::Expression(_) => unreachable!(),
            }
        }
        // A lone expression is left untouched.
        MarkerTree::Expression(_) => {}
    }
}
// Simplify version expressions.
fn simplify_ranges(
reduced: &mut Vec<MarkerTree>,
versions: HashMap<MarkerValueVersion, Vec<PubGrubRange<Version>>>,
combine: impl Fn(&Vec<PubGrubRange<Version>>) -> PubGrubRange<Version>,
) {
for (key, ranges) in versions {
let simplified = combine(&ranges);
// If this is a meaningless expressions with no valid intersection, add back
// the original ranges.
if simplified.is_empty() {
for specifier in ranges
.iter()
.flat_map(PubGrubRange::iter)
.flat_map(VersionSpecifier::from_bounds)
{
reduced.push(MarkerTree::Expression(MarkerExpression::Version {
specifier,
key: key.clone(),
}));
}
}
// Add back the simplified segments.
for specifier in simplified.iter().flat_map(VersionSpecifier::from_bounds) {
reduced.push(MarkerTree::Expression(MarkerExpression::Version {
key: key.clone(),
specifier,
}));
}
}
}
/// Extracts the key, value, and string from a string expression, reversing the operator if necessary.
fn extract_string_expression(
expr: &MarkerExpression,
@ -145,12 +253,12 @@ fn extra_is_disjoint(operator: &ExtraOperator, name: &ExtraName, other: &MarkerE
/// Returns `true` if this version expression does not intersect with the given expression.
fn version_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool {
let Some((key, range)) = keyed_range(this).unwrap() else {
let Some((key, range)) = keyed_range(this) else {
return false;
};
// if this is not a version expression it may intersect
let Ok(Some((key2, range2))) = keyed_range(other) else {
let Some((key2, range2)) = keyed_range(other) else {
return false;
};
@ -164,9 +272,7 @@ fn version_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> boo
}
/// Returns the key and version range for a version expression.
fn keyed_range(
expr: &MarkerExpression,
) -> Result<Option<(&MarkerValueVersion, pubgrub::range::Range<Version>)>, ()> {
fn keyed_range(expr: &MarkerExpression) -> Option<(&MarkerValueVersion, PubGrubRange<Version>)> {
let (key, specifier) = match expr {
MarkerExpression::Version { key, specifier } => (key, specifier.clone()),
MarkerExpression::VersionInverted {
@ -178,19 +284,19 @@ fn keyed_range(
// a version specifier
let operator = reverse_operator(*operator);
let Ok(specifier) = VersionSpecifier::from_version(operator, version.clone()) else {
return Ok(None);
return None;
};
(key, specifier)
}
_ => return Err(()),
_ => return None,
};
let Ok(pubgrub_specifier) = PubGrubSpecifier::try_from(&specifier) else {
return Ok(None);
return None;
};
Ok(Some((key, pubgrub_specifier.into())))
Some((key, pubgrub_specifier.into()))
}
/// Reverses a binary operator.
@ -223,14 +329,85 @@ mod tests {
use super::*;
fn is_disjoint(one: impl AsRef<str>, two: impl AsRef<str>) -> bool {
let one = MarkerTree::parse_reporter(one.as_ref(), &mut TracingReporter).unwrap();
let two = MarkerTree::parse_reporter(two.as_ref(), &mut TracingReporter).unwrap();
super::is_disjoint(&one, &two) && super::is_disjoint(&two, &one)
/// Checks that `normalize` rewrites each input marker into the expected
/// simplified, canonically ordered form.
#[test]
fn simplify() {
    // Each entry is `(input marker, expected marker after normalization)`.
    let cases = [
        (
            "python_version == '3.1' or python_version == '3.1'",
            "python_version == '3.1'",
        ),
        (
            "python_version < '3.17' or python_version < '3.18'",
            "python_version < '3.18'",
        ),
        (
            "python_version > '3.17' or python_version > '3.18' or python_version > '3.12'",
            "python_version > '3.12'",
        ),
        // a quirk of how pubgrub works, but this is considered part of normalization
        (
            "python_version > '3.17.post4' or python_version > '3.18.post4'",
            "python_version >= '3.17.post5'",
        ),
        (
            "python_version < '3.17' and python_version < '3.18'",
            "python_version < '3.17'",
        ),
        (
            "python_version <= '3.18' and python_version == '3.18'",
            "python_version == '3.18'",
        ),
        (
            "python_version <= '3.18' or python_version == '3.18'",
            "python_version <= '3.18'",
        ),
        (
            "python_version <= '3.15' or (python_version <= '3.17' and python_version < '3.16')",
            "python_version < '3.16'",
        ),
        (
            "(python_version > '3.17' or python_version > '3.16') and python_version > '3.15'",
            "python_version > '3.16'",
        ),
        (
            "(python_version > '3.17' or python_version > '3.16') and python_version > '3.15' and implementation_version == '1'",
            "implementation_version == '1' and python_version > '3.16'",
        ),
        // inverted (`<version> <op> <key>`) expressions are normalized too
        (
            "('3.17' < python_version or '3.16' < python_version) and '3.15' < python_version and implementation_version == '1'",
            "implementation_version == '1' and python_version > '3.16'",
        ),
        ("extra == 'a' or extra == 'a'", "extra == 'a'"),
        (
            "extra == 'a' and extra == 'a' or extra == 'b'",
            "extra == 'a' or extra == 'b'",
        ),
        // bogus expressions are retained but still normalized
        (
            "python_version < '3.17' and '3.18' == python_version",
            "python_version == '3.18' and python_version < '3.17'",
        ),
        // cannot simplify nested complex expressions
        (
            "extra == 'a' and (extra == 'a' or extra == 'b')",
            "extra == 'a' and (extra == 'a' or extra == 'b')",
        ),
    ];

    for (input, expected) in cases {
        assert_marker_equal(input, expected);
    }
}
#[test]
fn extra() {
fn extra_disjointness() {
assert!(!is_disjoint("extra == 'a'", "python_version == '1'"));
assert!(!is_disjoint("extra == 'a'", "extra == 'a'"));
@ -243,7 +420,7 @@ mod tests {
}
#[test]
fn arbitrary() {
fn arbitrary_disjointness() {
assert!(is_disjoint(
"python_version == 'Linux'",
"python_version == '3.7.1'"
@ -251,13 +428,13 @@ mod tests {
}
#[test]
fn version() {
fn version_disjointness() {
assert!(!is_disjoint(
"os_name == 'Linux'",
"python_version == '3.7.1'"
));
test_version_bounds("python_version");
test_version_bounds_disjointness("python_version");
assert!(!is_disjoint(
"python_version == '3.7.*'",
@ -266,7 +443,7 @@ mod tests {
}
#[test]
fn string() {
fn string_disjointness() {
assert!(!is_disjoint(
"os_name == 'Linux'",
"platform_version == '3.7.1'"
@ -277,7 +454,7 @@ mod tests {
));
// basic version bounds checking should still work with lexicographical comparisons
test_version_bounds("platform_version");
test_version_bounds_disjointness("platform_version");
assert!(is_disjoint("os_name == 'Linux'", "os_name == 'OSX'"));
assert!(is_disjoint("os_name <= 'Linux'", "os_name == 'OSX'"));
@ -303,7 +480,7 @@ mod tests {
}
#[test]
fn combined() {
fn combined_disjointness() {
assert!(!is_disjoint(
"os_name == 'a' and platform_version == '1'",
"os_name == 'a'"
@ -327,7 +504,7 @@ mod tests {
));
}
fn test_version_bounds(version: &str) {
fn test_version_bounds_disjointness(version: &str) {
assert!(!is_disjoint(
format!("{version} > '2.7.0'"),
format!("{version} == '3.6.0'")
@ -372,4 +549,17 @@ mod tests {
format!("{version} != '3.7.0'")
));
}
/// Parses both marker strings and returns `true` only if `super::is_disjoint`
/// reports them disjoint in both argument orders.
///
/// Panics if either string fails to parse.
fn is_disjoint(one: impl AsRef<str>, two: impl AsRef<str>) -> bool {
    let parse =
        |marker: &str| MarkerTree::parse_reporter(marker, &mut TracingReporter).unwrap();
    let first = parse(one.as_ref());
    let second = parse(two.as_ref());

    // The check is run in both directions; the second is skipped if the first fails.
    let forward = super::is_disjoint(&first, &second);
    forward && super::is_disjoint(&second, &first)
}
/// Asserts that `one`, once parsed and normalized, prints the same marker
/// string as `two` parsed without normalization.
///
/// Panics if either string fails to parse.
fn assert_marker_equal(one: impl AsRef<str>, two: impl AsRef<str>) {
    let normalized = {
        let mut tree = MarkerTree::parse_reporter(one.as_ref(), &mut TracingReporter).unwrap();
        super::normalize(&mut tree);
        tree.to_string()
    };
    // `two` is deliberately not normalized: it must already be in canonical form.
    let expected = MarkerTree::parse_reporter(two.as_ref(), &mut TracingReporter)
        .unwrap()
        .to_string();
    assert_eq!(normalized, expected);
}
}

View file

@ -22,6 +22,12 @@ impl PubGrubSpecifier {
}
}
impl From<Range<Version>> for PubGrubSpecifier {
fn from(range: Range<Version>) -> Self {
PubGrubSpecifier(range)
}
}
impl From<PubGrubSpecifier> for Range<Version> {
/// Convert a PubGrub specifier to a range of versions.
fn from(specifier: PubGrubSpecifier) -> Self {

View file

@ -788,7 +788,7 @@ fn lock_dependency_extra() -> Result<()> {
name = "importlib-metadata"
version = "7.1.0"
source = "registry+https://pypi.org/simple"
marker = "python_version < '3.8' or python_version < '3.10'"
marker = "python_version < '3.10'"
sdist = { url = "https://files.pythonhosted.org/packages/a0/fc/c4e6078d21fc4fa56300a241b87eae76766aa380a23fc450fc85bb7bf547/importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2", size = 52120 }
wheels = [{ url = "https://files.pythonhosted.org/packages/2d/0a/679461c511447ffaf176567d5c496d1de27cbe34a87df6677d7171b2fbd4/importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570", size = 24409 }]
@ -1473,7 +1473,7 @@ fn lock_requires_python() -> Result<()> {
name = "typing-extensions"
version = "4.7.1"
source = "registry+https://pypi.org/simple"
marker = "python_version < '3.8' or python_version < '3.11'"
marker = "python_version < '3.11'"
sdist = { url = "https://files.pythonhosted.org/packages/3c/8b/0111dd7d6c1478bf83baa1cab85c686426c7a6274119aceb2bd9d35395ad/typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2", size = 72876 }
wheels = [{ url = "https://files.pythonhosted.org/packages/ec/6b/63cc3df74987c36fe26157ee12e09e8f9db4de771e0f3404263117e75b95/typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36", size = 33232 }]
@ -1625,7 +1625,7 @@ fn lock_requires_python() -> Result<()> {
name = "typing-extensions"
version = "4.7.1"
source = "registry+https://pypi.org/simple"
marker = "python_version < '3.8' or python_version < '3.11'"
marker = "python_version < '3.11'"
sdist = { url = "https://files.pythonhosted.org/packages/3c/8b/0111dd7d6c1478bf83baa1cab85c686426c7a6274119aceb2bd9d35395ad/typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2", size = 72876 }
wheels = [{ url = "https://files.pythonhosted.org/packages/ec/6b/63cc3df74987c36fe26157ee12e09e8f9db4de771e0f3404263117e75b95/typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36", size = 33232 }]
@ -1782,7 +1782,7 @@ fn lock_requires_python() -> Result<()> {
name = "typing-extensions"
version = "4.10.0"
source = "registry+https://pypi.org/simple"
marker = "python_version < '3.8' or python_version < '3.11'"
marker = "python_version < '3.11'"
sdist = { url = "https://files.pythonhosted.org/packages/16/3a/0d26ce356c7465a19c9ea8814b960f8a36c3b0d07c323176620b7b483e44/typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb", size = 77558 }
wheels = [{ url = "https://files.pythonhosted.org/packages/f9/de/dc04a3ea60b22624b51c703a84bbe0184abcd1d0b9bc8074b5d6b7ab90bb/typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475", size = 33926 }]