[red-knot] Use itertools to clean up SymbolState::merge (#15702)

[`merge_join_by`](https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.merge_join_by)
handles the "merge two sorted iterators" bit, and `zip` handles
iterating through the bindings/definitions along with their associated
constraints.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
Douglas Creager 2025-01-24 10:21:29 -05:00 committed by GitHub
parent 4e3982cf95
commit 716b246cf3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -43,12 +43,14 @@
//!
//! Tracking live declarations is simpler, since constraints are not involved, but otherwise very
//! similar to tracking live bindings.
use crate::semantic_index::use_def::VisibilityConstraints;
use super::bitset::{BitSet, BitSetIterator};
use itertools::{EitherOrBoth, Itertools};
use ruff_index::newtype_index;
use smallvec::SmallVec;
use crate::semantic_index::use_def::bitset::{BitSet, BitSetIterator};
use crate::semantic_index::use_def::VisibilityConstraints;
/// A newtype-index for a definition in a particular scope.
#[newtype_index]
pub(super) struct ScopedDefinitionId;
@ -96,7 +98,6 @@ type ConstraintsPerBinding = SmallVec<InlineConstraintArray>;
/// Iterate over all constraints for a single binding.
type ConstraintsIterator<'a> = std::slice::Iter<'a, Constraints>;
type ConstraintsIntoIterator = smallvec::IntoIter<InlineConstraintArray>;
/// A newtype-index for a visibility constraint in a particular scope.
#[newtype_index]
@ -123,13 +124,18 @@ type VisibilityConstraintPerBinding = SmallVec<InlineVisibilityConstraintsArray>
/// Iterator over the visibility constraints for all live bindings/declarations.
type VisibilityConstraintsIterator<'a> = std::slice::Iter<'a, ScopedVisibilityConstraintId>;
type VisibilityConstraintsIntoIterator = smallvec::IntoIter<InlineVisibilityConstraintsArray>;
/// Live declarations for a single symbol at some point in control flow, with their
/// corresponding visibility constraints.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub(super) struct SymbolDeclarations {
/// [`BitSet`]: which declarations (as [`ScopedDefinitionId`]) can reach the current location?
///
/// Invariant: Because this is a `BitSet`, it can be viewed as a _sorted_ set of definition
/// IDs. The `visibility_constraints` field stores constraints for each definition. Therefore
/// those fields must always have the same `len()` as `live_declarations`, and the elements
/// must appear in the same order. Effectively, this means that elements must always be added
/// in sorted order, or via a binary search that determines the correct place to insert new
/// constraints.
pub(crate) live_declarations: Declarations,
/// For each live declaration, which visibility constraint applies to it?
@ -173,13 +179,51 @@ impl SymbolDeclarations {
visibility_constraints: self.visibility_constraints.iter(),
}
}
fn merge(&mut self, b: Self, visibility_constraints: &mut VisibilityConstraints) {
let a = std::mem::take(self);
self.live_declarations = a.live_declarations.clone();
self.live_declarations.union(&b.live_declarations);
// Invariant: These zips are well-formed since we maintain an invariant that all of our
// fields are sets/vecs with the same length.
let a = (a.live_declarations.iter()).zip(a.visibility_constraints);
let b = (b.live_declarations.iter()).zip(b.visibility_constraints);
// Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that
// the definition IDs and constraints line up correctly in the merged result. If a
// definition is found in both `a` and `b`, we compose the constraints from the two paths
// in an appropriate way (intersection for narrowing constraints; ternary OR for visibility
// constraints). If a definition is found in only one path, it is used as-is.
for zipped in a.merge_join_by(b, |(a_decl, _), (b_decl, _)| a_decl.cmp(b_decl)) {
match zipped {
EitherOrBoth::Both((_, a_vis_constraint), (_, b_vis_constraint)) => {
let vis_constraint = visibility_constraints
.add_or_constraint(a_vis_constraint, b_vis_constraint);
self.visibility_constraints.push(vis_constraint);
}
EitherOrBoth::Left((_, vis_constraint))
| EitherOrBoth::Right((_, vis_constraint)) => {
self.visibility_constraints.push(vis_constraint);
}
}
}
}
}
/// Live bindings for a single symbol at some point in control flow. Each live binding comes
/// with a set of narrowing constraints and a visibility constraint.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub(super) struct SymbolBindings {
/// [`BitSet`]: which bindings (as [`ScopedDefinitionId`]) can reach the current location?
///
/// Invariant: Because this is a `BitSet`, it can be viewed as a _sorted_ set of definition
/// IDs. The `constraints` and `visibility_constraints` field stores constraints for each
/// definition. Therefore those fields must always have the same `len()` as
/// `live_bindings`, and the elements must appear in the same order. Effectively, this means
/// that elements must always be added in sorted order, or via a binary search that determines
/// the correct place to insert new constraints.
live_bindings: Bindings,
/// For each live binding, which [`ScopedConstraintId`] apply?
@ -242,6 +286,53 @@ impl SymbolBindings {
visibility_constraints: self.visibility_constraints.iter(),
}
}
fn merge(&mut self, b: Self, visibility_constraints: &mut VisibilityConstraints) {
let a = std::mem::take(self);
self.live_bindings = a.live_bindings.clone();
self.live_bindings.union(&b.live_bindings);
// Invariant: These zips are well-formed since we maintain an invariant that all of our
// fields are sets/vecs with the same length.
let a = (a.live_bindings.iter())
.zip(a.constraints)
.zip(a.visibility_constraints);
let b = (b.live_bindings.iter())
.zip(b.constraints)
.zip(b.visibility_constraints);
// Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that
// the definition IDs and constraints line up correctly in the merged result. If a
// definition is found in both `a` and `b`, we compose the constraints from the two paths
// in an appropriate way (intersection for narrowing constraints; ternary OR for visibility
// constraints). If a definition is found in only one path, it is used as-is.
for zipped in a.merge_join_by(b, |((a_def, _), _), ((b_def, _), _)| a_def.cmp(b_def)) {
match zipped {
EitherOrBoth::Both(
((_, a_constraints), a_vis_constraint),
((_, b_constraints), b_vis_constraint),
) => {
// If the same definition is visible through both paths, any constraint
// that applies on only one path is irrelevant to the resulting type from
// unioning the two paths, so we intersect the constraints.
let mut constraints = a_constraints;
constraints.intersect(&b_constraints);
self.constraints.push(constraints);
// For visibility constraints, we merge them using a ternary OR operation:
let vis_constraint = visibility_constraints
.add_or_constraint(a_vis_constraint, b_vis_constraint);
self.visibility_constraints.push(vis_constraint);
}
EitherOrBoth::Left(((_, constraints), vis_constraint))
| EitherOrBoth::Right(((_, constraints), vis_constraint)) => {
self.constraints.push(constraints);
self.visibility_constraints.push(vis_constraint);
}
}
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
@ -303,202 +394,9 @@ impl SymbolState {
b: SymbolState,
visibility_constraints: &mut VisibilityConstraints,
) {
let mut a = Self {
bindings: SymbolBindings {
live_bindings: Bindings::default(),
constraints: ConstraintsPerBinding::default(),
visibility_constraints: VisibilityConstraintPerBinding::default(),
},
declarations: SymbolDeclarations {
live_declarations: self.declarations.live_declarations.clone(),
visibility_constraints: VisibilityConstraintPerDeclaration::default(),
},
};
std::mem::swap(&mut a, self);
self.bindings.merge(b.bindings, visibility_constraints);
self.declarations
.live_declarations
.union(&b.declarations.live_declarations);
let mut a_defs_iter = a.bindings.live_bindings.iter();
let mut b_defs_iter = b.bindings.live_bindings.iter();
let mut a_constraints_iter = a.bindings.constraints.into_iter();
let mut b_constraints_iter = b.bindings.constraints.into_iter();
let mut a_vis_constraints_iter = a.bindings.visibility_constraints.into_iter();
let mut b_vis_constraints_iter = b.bindings.visibility_constraints.into_iter();
let mut opt_a_def: Option<u32> = a_defs_iter.next();
let mut opt_b_def: Option<u32> = b_defs_iter.next();
// Iterate through the definitions from `a` and `b`, always processing the lower definition
// ID first, and pushing each definition onto the merged `SymbolState` with its
// constraints. If a definition is found in both `a` and `b`, push it with the intersection
// of the constraints from the two paths; a constraint that applies from only one possible
// path is irrelevant.
// Helper to push `def`, with constraints in `constraints_iter`, onto `self`.
let push = |def,
constraints_iter: &mut ConstraintsIntoIterator,
visibility_constraints_iter: &mut VisibilityConstraintsIntoIterator,
merged: &mut Self| {
merged.bindings.live_bindings.insert(def);
// SAFETY: we only ever create SymbolState using [`SymbolState::undefined`], which adds
// one "unbound" definition with corresponding narrowing and visibility constraints, or
// using [`SymbolState::record_binding`] or [`SymbolState::record_declaration`], which
// similarly add one definition with corresponding constraints. [`SymbolState::merge`]
// always pushes one definition and one constraint bitset and one visibility constraint
// together (just below), so the number of definitions and the number of constraints can
// never get out of sync.
// get out of sync.
let constraints = constraints_iter
.next()
.expect("definitions and constraints length mismatch");
let visibility_constraints = visibility_constraints_iter
.next()
.expect("definitions and visibility_constraints length mismatch");
merged.bindings.constraints.push(constraints);
merged
.bindings
.visibility_constraints
.push(visibility_constraints);
};
loop {
match (opt_a_def, opt_b_def) {
(Some(a_def), Some(b_def)) => match a_def.cmp(&b_def) {
std::cmp::Ordering::Less => {
// Next definition ID is only in `a`, push it to `self` and advance `a`.
push(
a_def,
&mut a_constraints_iter,
&mut a_vis_constraints_iter,
self,
);
opt_a_def = a_defs_iter.next();
}
std::cmp::Ordering::Greater => {
// Next definition ID is only in `b`, push it to `self` and advance `b`.
push(
b_def,
&mut b_constraints_iter,
&mut b_vis_constraints_iter,
self,
);
opt_b_def = b_defs_iter.next();
}
std::cmp::Ordering::Equal => {
// Next definition is in both; push to `self` and intersect constraints.
push(
a_def,
&mut b_constraints_iter,
&mut b_vis_constraints_iter,
self,
);
// SAFETY: see comment in `push` above.
let a_constraints = a_constraints_iter
.next()
.expect("definitions and constraints length mismatch");
let current_constraints = self.bindings.constraints.last_mut().unwrap();
// If the same definition is visible through both paths, any constraint
// that applies on only one path is irrelevant to the resulting type from
// unioning the two paths, so we intersect the constraints.
current_constraints.intersect(&a_constraints);
// For visibility constraints, we merge them using a ternary OR operation:
let a_vis_constraint = a_vis_constraints_iter
.next()
.expect("visibility_constraints length mismatch");
let current_vis_constraint =
self.bindings.visibility_constraints.last_mut().unwrap();
*current_vis_constraint = visibility_constraints
.add_or_constraint(*current_vis_constraint, a_vis_constraint);
opt_a_def = a_defs_iter.next();
opt_b_def = b_defs_iter.next();
}
},
(Some(a_def), None) => {
// We've exhausted `b`, just push the def from `a` and move on to the next.
push(
a_def,
&mut a_constraints_iter,
&mut a_vis_constraints_iter,
self,
);
opt_a_def = a_defs_iter.next();
}
(None, Some(b_def)) => {
// We've exhausted `a`, just push the def from `b` and move on to the next.
push(
b_def,
&mut b_constraints_iter,
&mut b_vis_constraints_iter,
self,
);
opt_b_def = b_defs_iter.next();
}
(None, None) => break,
}
}
// Same as above, but for declarations.
let mut a_decls_iter = a.declarations.live_declarations.iter();
let mut b_decls_iter = b.declarations.live_declarations.iter();
let mut a_vis_constraints_iter = a.declarations.visibility_constraints.into_iter();
let mut b_vis_constraints_iter = b.declarations.visibility_constraints.into_iter();
let mut opt_a_decl: Option<u32> = a_decls_iter.next();
let mut opt_b_decl: Option<u32> = b_decls_iter.next();
let push = |vis_constraints_iter: &mut VisibilityConstraintsIntoIterator,
merged: &mut Self| {
let vis_constraints = vis_constraints_iter
.next()
.expect("declarations and visibility_constraints length mismatch");
merged
.declarations
.visibility_constraints
.push(vis_constraints);
};
loop {
match (opt_a_decl, opt_b_decl) {
(Some(a_decl), Some(b_decl)) => match a_decl.cmp(&b_decl) {
std::cmp::Ordering::Less => {
push(&mut a_vis_constraints_iter, self);
opt_a_decl = a_decls_iter.next();
}
std::cmp::Ordering::Greater => {
push(&mut b_vis_constraints_iter, self);
opt_b_decl = b_decls_iter.next();
}
std::cmp::Ordering::Equal => {
push(&mut b_vis_constraints_iter, self);
let a_vis_constraint = a_vis_constraints_iter
.next()
.expect("declarations and visibility_constraints length mismatch");
let current = self.declarations.visibility_constraints.last_mut().unwrap();
*current =
visibility_constraints.add_or_constraint(*current, a_vis_constraint);
opt_a_decl = a_decls_iter.next();
opt_b_decl = b_decls_iter.next();
}
},
(Some(_), None) => {
push(&mut a_vis_constraints_iter, self);
opt_a_decl = a_decls_iter.next();
}
(None, Some(_)) => {
push(&mut b_vis_constraints_iter, self);
opt_b_decl = b_decls_iter.next();
}
(None, None) => break,
}
}
.merge(b.declarations, visibility_constraints);
}
pub(super) fn bindings(&self) -> &SymbolBindings {