revset: add optimization pass to flatten intersections

This change makes it possible to extract a helper function to sort the
elements of an intersection by a key, which simplifies the logic of
`internalize_filter`, and it will make it easier to add other passes
which rearrange intersections.
This commit is contained in:
Scott Taylor 2025-06-06 16:52:26 -05:00 committed by Scott Taylor
parent e72f161189
commit 9a7ca8edb5

View file

@ -1506,6 +1506,66 @@ where
Ok(expression)
}
/// Rewrites every chain of intersections into left-recursive form, so that the
/// right operand of an intersection is never itself an intersection. For
/// example, `(a & b) & (c & d)` becomes `((a & b) & c) & d`.
///
/// Returns `None` if no intersection had to be rewritten.
fn flatten_intersections<St: ExpressionState>(
    expression: &Rc<RevsetExpression<St>>,
) -> TransformedExpression<St> {
    // Rotates `lhs & rhs` to the left for as long as the right operand is an
    // intersection. Returns `None` if `rhs` isn't an intersection, in which
    // case `lhs & rhs` can be kept as is.
    fn rotate_left<St: ExpressionState>(
        lhs: &Rc<RevsetExpression<St>>,
        rhs: &Rc<RevsetExpression<St>>,
    ) -> TransformedExpression<St> {
        let RevsetExpression::Intersection(rhs_head, rhs_tail) = rhs.as_ref() else {
            return None;
        };
        // lhs & (rhs_head & rhs_tail) -> rotate_left(lhs & rhs_head) & rhs_tail
        let new_lhs = rotate_left(lhs, rhs_head).unwrap_or_else(|| lhs.intersection(rhs_head));
        Some(new_lhs.intersection(rhs_tail))
    }

    transform_expression_bottom_up(expression, |node| {
        if let RevsetExpression::Intersection(lhs, rhs) = node.as_ref() {
            rotate_left(lhs, rhs)
        } else {
            None
        }
    })
}
/// Builds the intersection of `base` and `expression` with `expression`
/// inserted at the position given by its key, so that the operands stay
/// sorted by `get_key`.
///
/// If `base` is an intersection, it must be left-recursive and its operands
/// must already be sorted by the same key. Returns `None` if `expression`
/// doesn't sort before the rightmost operand of `base`, i.e. when
/// `base & expression` is already in order.
fn sort_intersection_by_key<St: ExpressionState, T: Ord>(
    base: &Rc<RevsetExpression<St>>,
    expression: &Rc<RevsetExpression<St>>,
    mut get_key: impl FnMut(&RevsetExpression<St>) -> T,
) -> TransformedExpression<St> {
    // The key of `expression` is passed down explicitly so that it is computed
    // only once rather than at every level of the recursion.
    fn insert_sorted<St: ExpressionState, T: Ord>(
        base: &Rc<RevsetExpression<St>>,
        expression: &Rc<RevsetExpression<St>>,
        expression_key: T,
        mut get_key: impl FnMut(&RevsetExpression<St>) -> T,
    ) -> TransformedExpression<St> {
        match base.as_ref() {
            RevsetExpression::Intersection(rest, last) => {
                // insert_sorted(rest & last, e) -> insert_sorted(rest, e) & last
                if expression_key < get_key(last) {
                    let sorted_rest = insert_sorted(rest, expression, expression_key, get_key)
                        .unwrap_or_else(|| rest.intersection(expression));
                    Some(sorted_rest.intersection(last))
                } else {
                    None
                }
            }
            _ => {
                // base & e -> e & base
                if expression_key < get_key(base) {
                    Some(expression.intersection(base))
                } else {
                    None
                }
            }
        }
    }

    let expression_key = get_key(expression);
    insert_sorted(base, expression, expression_key, get_key)
}
/// Transforms filter expressions, by applying the following rules.
///
/// a. Moves as many sets to left of filter intersection as possible, to
@ -1534,49 +1594,6 @@ fn internalize_filter<St: ExpressionState>(
}
}
// Splits an intersect_down()-ed node of the form 'c & f' into its two
// operands. Yields None unless the node is an intersection whose right
// operand is a filter.
#[expect(clippy::type_complexity)]
fn as_filter_intersection<St: ExpressionState>(
    expression: &RevsetExpression<St>,
) -> Option<(&Rc<RevsetExpression<St>>, &Rc<RevsetExpression<St>>)> {
    match expression {
        RevsetExpression::Intersection(lhs, rhs) if is_filter(rhs) => Some((lhs, rhs)),
        _ => None,
    }
}
// Both operands are expected to have been intersect_down()-ed already, so the
// full bottom-up pass doesn't have to be re-applied to the new intersection
// node. It is enough to push a new 'c & (d & g)' down-left to '(c & d) & g'
// for as long as either side is an intersection of filter node.
//
// Returns None if 'expression1 & expression2' can be kept as is.
fn intersect_down<St: ExpressionState>(
    expression1: &Rc<RevsetExpression<St>>,
    expression2: &Rc<RevsetExpression<St>>,
) -> TransformedExpression<St> {
    let push_down = |lhs, rhs| intersect_down(lhs, rhs).unwrap_or_else(|| lhs.intersection(rhs));

    // Don't reorder 'f1 & f2'
    if is_filter(expression2) {
        return None;
    }
    // f1 & e2 -> e2 & f1
    if is_filter(expression1) {
        return Some(expression2.intersection(expression1));
    }
    // e1 & (c2 & f2) -> (e1 & c2) & f2
    // (c1 & f1) & (c2 & f2) -> ((c1 & f1) & c2) & f2 -> ((c1 & c2) & f1) & f2
    if let Some((c2, f2)) = as_filter_intersection(expression2) {
        return Some(push_down(expression1, c2).intersection(f2));
    }
    // (c1 & f1) & e2 -> (c1 & e2) & f1
    // ((c1 & f1) & g1) & e2 -> ((c1 & f1) & e2) & g1 -> ((c1 & e2) & f1) & g1
    if let Some((c1, f1)) = as_filter_intersection(expression1) {
        return Some(push_down(c1, expression2).intersection(f1));
    }
    None
}
// Bottom-up pass pulls up-right filter node from leaf '(c & f) & e' ->
// '(c & e) & f', so that an intersection of filter node can be found as
// a direct child of another intersection node. However, the rewritten
// intersection node 'c & e' can also be a rewrite target if 'e' contains
// a filter node. That's why intersect_down() is also recursive.
transform_expression_bottom_up(expression, |expression| match expression.as_ref() {
RevsetExpression::Present(e) => {
is_filter_tree(e).then(|| Rc::new(RevsetExpression::AsFilter(expression.clone())))
@ -1586,8 +1603,11 @@ fn internalize_filter<St: ExpressionState>(
}
RevsetExpression::Union(e1, e2) => (is_filter_tree(e1) || is_filter_tree(e2))
.then(|| Rc::new(RevsetExpression::AsFilter(expression.clone()))),
// Bottom-up pass pulls up-right filter node from leaf '(c & f) & e' ->
// '(c & e) & f', so that an intersection of filter node can be found as
// a direct child of another intersection node.
RevsetExpression::Intersection(expression1, expression2) => {
intersect_down(expression1, expression2)
sort_intersection_by_key(expression1, expression2, is_filter)
}
// Difference(e1, e2) should have been unfolded to Intersection(e1, NotIn(e2)).
_ => None,
@ -1811,6 +1831,7 @@ pub fn optimize<St: ExpressionState>(
let expression = unfold_difference(&expression).unwrap_or(expression);
let expression = fold_redundant_expression(&expression).unwrap_or(expression);
let expression = fold_generation(&expression).unwrap_or(expression);
let expression = flatten_intersections(&expression).unwrap_or(expression);
let expression = internalize_filter(&expression).unwrap_or(expression);
let expression = fold_difference(&expression).unwrap_or(expression);
fold_not_in_ancestors(&expression).unwrap_or(expression)
@ -4045,11 +4066,11 @@ mod tests {
Intersection(
Intersection(
Intersection(
CommitRef(Symbol("a")),
Intersection(
CommitRef(Symbol("a")),
CommitRef(Symbol("b")),
CommitRef(Symbol("c")),
),
CommitRef(Symbol("c")),
),
CommitRef(Symbol("d")),
),
@ -4379,6 +4400,29 @@ mod tests {
"#);
}
// Snapshot test for the intersection-flattening pass: runs `optimize()` on a
// parsed expression with nested intersections on the right-hand side and
// compares the debug representation of the result with the inline snapshot.
#[test]
fn test_optimize_flatten_intersection() {
let settings = insta_settings();
// Keep the guard alive until the end of the test (presumably the settings
// stay bound to this scope only while it exists -- confirm against insta).
let _guard = settings.bind_to_scope();
// Nested intersections should be flattened: the expected tree below is
// left-recursive, i.e. `(((a & b) & c) & d) & e`.
insta::assert_debug_snapshot!(optimize(parse("a & ((b & c) & (d & e))").unwrap()), @r#"
Intersection(
Intersection(
Intersection(
Intersection(
CommitRef(Symbol("a")),
CommitRef(Symbol("b")),
),
CommitRef(Symbol("c")),
),
CommitRef(Symbol("d")),
),
CommitRef(Symbol("e")),
)
"#);
}
#[test]
fn test_escape_string_literal() {
// Valid identifiers don't need quoting