diff --git a/crates/red_knot_python_semantic/src/lib.rs b/crates/red_knot_python_semantic/src/lib.rs index ab8a622d8a..18cd697363 100644 --- a/crates/red_knot_python_semantic/src/lib.rs +++ b/crates/red_knot_python_semantic/src/lib.rs @@ -27,7 +27,6 @@ pub(crate) mod symbol; pub mod types; mod unpack; mod util; -mod visibility_constraints; type FxOrderSet = ordermap::set::OrderSet>; diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index a2c766cccc..0e9bd9b202 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -28,8 +28,10 @@ mod builder; pub(crate) mod constraint; pub mod definition; pub mod expression; +mod narrowing_constraints; pub mod symbol; mod use_def; +mod visibility_constraints; pub(crate) use self::use_def::{ BindingWithConstraints, BindingWithConstraintsIterator, DeclarationWithConstraint, diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 3f77df38cd..6eb45aad2d 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -15,10 +15,14 @@ use crate::module_name::ModuleName; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::AstIdsBuilder; use crate::semantic_index::attribute_assignment::{AttributeAssignment, AttributeAssignments}; -use crate::semantic_index::constraint::{PatternConstraintKind, ScopedConstraintId}; +use crate::semantic_index::constraint::{ + Constraint, ConstraintNode, PatternConstraint, PatternConstraintKind, ScopedConstraintId, +}; use crate::semantic_index::definition::{ - AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey, - DefinitionNodeRef, ForStmtDefinitionNodeRef, ImportFromDefinitionNodeRef, + AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, + DefinitionNodeKey, DefinitionNodeRef, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionNodeRef, + ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef, + WithItemDefinitionNodeRef, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::symbol::{ @@ -28,17 +32,13 @@ use crate::semantic_index::symbol::{ use crate::semantic_index::use_def::{ EagerBindingsKey, FlowSnapshot, ScopedEagerBindingsId, UseDefMapBuilder, }; +use crate::semantic_index::visibility_constraints::{ + ScopedVisibilityConstraintId, VisibilityConstraintsBuilder, +}; use crate::semantic_index::SemanticIndex; use crate::unpack::{Unpack, UnpackValue}; -use crate::visibility_constraints::{ScopedVisibilityConstraintId, VisibilityConstraintsBuilder}; use crate::Db; -use super::constraint::{Constraint, ConstraintNode, PatternConstraint}; -use super::definition::{ - DefinitionCategory, ExceptHandlerDefinitionNodeRef, ImportDefinitionNodeRef, - MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef, -}; - mod except_handlers; /// Are we in a state where a `break` statement is allowed? diff --git a/crates/red_knot_python_semantic/src/semantic_index/narrowing_constraints.rs b/crates/red_knot_python_semantic/src/semantic_index/narrowing_constraints.rs new file mode 100644 index 0000000000..645c775a79 --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/narrowing_constraints.rs @@ -0,0 +1,151 @@ +//! 
# Narrowing constraints +//! +//! When building a semantic index for a file, we associate each binding with _narrowing +//! constraints_. The narrowing constraint is used to constrain the type of the binding's symbol. +//! Note that a binding can be associated with a different narrowing constraint at different points +//! in a file. See the [`use_def`][crate::semantic_index::use_def] module for more details. +//! +//! This module defines how narrowing constraints are stored internally. +//! +//! A _narrowing constraint_ consists of a list of _clauses_, each of which corresponds with an +//! expression in the source file (represented by a [`Constraint`]). We need to support the +//! following operations on narrowing constraints: +//! +//! - Adding a new clause to an existing constraint +//! - Merging two constraints together, which produces the _intersection_ of their clauses +//! - Iterating through the clauses in a constraint +//! +//! In particular, note that we do not need random access to the clauses in a constraint. That +//! means that we can use a simple [_sorted association list_][ruff_index::list] as our data +//! structure. That lets us use a single 32-bit integer to store each narrowing constraint, no +//! matter how many clauses it contains. It also makes merging two narrowing constraints fast, +//! since alists support fast intersection. +//! +//! Because we visit the contents of each scope in source-file order, and assign scoped IDs in +//! source-file order, that means that we will tend to visit narrowing constraints in order by +//! their IDs. This is exactly how to get the best performance from our alist implementation. +//! +//! [`Constraint`]: crate::semantic_index::constraint::Constraint + +use ruff_index::list::{ListBuilder, ListSetReverseIterator, ListStorage}; +use ruff_index::newtype_index; + +use crate::semantic_index::constraint::ScopedConstraintId; + +/// A narrowing constraint associated with a live binding. +/// +/// A constraint is a list of clauses, each of which is a [`Constraint`] that constrains the type +/// of the binding's symbol. +/// +/// An instance of this type represents a _non-empty_ narrowing constraint. You will often wrap +/// this in `Option` and use `None` to represent an empty narrowing constraint. +/// +/// [`Constraint`]: crate::semantic_index::constraint::Constraint +#[newtype_index] +pub(crate) struct ScopedNarrowingConstraintId; + +/// One of the clauses in a narrowing constraint, which is a [`Constraint`] that constrains the +/// type of the binding's symbol. +/// +/// Note that those [`Constraint`]s are stored in [their own per-scope +/// arena][crate::semantic_index::constraint::Constraints], so internally we use a +/// [`ScopedConstraintId`] to refer to the underlying constraint. +/// +/// [`Constraint`]: crate::semantic_index::constraint::Constraint +#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub(crate) struct ScopedNarrowingConstraintClause(ScopedConstraintId); + +impl ScopedNarrowingConstraintClause { + /// Returns (the ID of) the `Constraint` for this clause + pub(crate) fn constraint(self) -> ScopedConstraintId { + self.0 + } +} + +impl From for ScopedNarrowingConstraintClause { + fn from(constraint: ScopedConstraintId) -> ScopedNarrowingConstraintClause { + ScopedNarrowingConstraintClause(constraint) + } +} + +/// A collection of narrowing constraints for a given scope. 
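As a concrete illustration of the three operations listed in the module docs above, here is a minimal usage sketch of the `NarrowingConstraintsBuilder` introduced below. It is written as in-crate pseudocode (the types are `pub(crate)`), and the `c0`/`c1` clause values are hypothetical stand-ins for the `ScopedConstraintId`s of two narrowing expressions:

```rust
// Hypothetical usage sketch; `ScopedConstraintId::from_u32` comes from the
// `newtype_index` machinery, and the clause type converts from it via `From`.
let mut builder = NarrowingConstraintsBuilder::default();
let c0 = ScopedNarrowingConstraintClause::from(ScopedConstraintId::from_u32(0));
let c1 = ScopedNarrowingConstraintClause::from(ScopedConstraintId::from_u32(1));

// `None` is the empty narrowing constraint; adding a clause yields a new id.
let just_c0 = builder.add(None, c0);      // clauses {c0}
let c0_and_c1 = builder.add(just_c0, c1); // clauses {c0, c1}

// Merging two control-flow paths keeps only the clauses common to both.
let merged = builder.intersect(just_c0, c0_and_c1); // clauses {c0}

// Once the scope is fully built, freeze the storage and iterate the clauses.
let constraints = builder.build();
let clauses: Vec<_> = constraints.iter_clauses(merged).collect();
assert_eq!(clauses.len(), 1);
```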
+#[derive(Debug, Eq, PartialEq)] +pub(crate) struct NarrowingConstraints { + lists: ListStorage, +} + +// Building constraints +// -------------------- + +/// A builder for creating narrowing constraints. +#[derive(Debug, Default, Eq, PartialEq)] +pub(crate) struct NarrowingConstraintsBuilder { + lists: ListBuilder, +} + +impl NarrowingConstraintsBuilder { + pub(crate) fn build(self) -> NarrowingConstraints { + NarrowingConstraints { + lists: self.lists.build(), + } + } + + /// Adds a clause to an existing narrowing constraint. + pub(crate) fn add( + &mut self, + constraint: Option, + clause: ScopedNarrowingConstraintClause, + ) -> Option { + self.lists.insert(constraint, clause) + } + + /// Returns the intersection of two narrowing constraints. The result contains the clauses that + /// appear in both inputs. + pub(crate) fn intersect( + &mut self, + a: Option, + b: Option, + ) -> Option { + self.lists.intersect(a, b) + } +} + +// Iteration +// --------- + +pub(crate) type NarrowingConstraintsIterator<'a> = std::iter::Copied< + ListSetReverseIterator<'a, ScopedNarrowingConstraintId, ScopedNarrowingConstraintClause>, +>; + +impl NarrowingConstraints { + /// Iterates over the clauses in a narrowing constraint. + pub(crate) fn iter_clauses( + &self, + set: Option, + ) -> NarrowingConstraintsIterator<'_> { + self.lists.iter_set_reverse(set).copied() + } +} + +// Test support +// ------------ + +#[cfg(test)] +mod tests { + use super::*; + + impl ScopedNarrowingConstraintClause { + pub(crate) fn as_u32(self) -> u32 { + self.0.as_u32() + } + } + + impl NarrowingConstraintsBuilder { + pub(crate) fn iter_constraints( + &self, + set: Option, + ) -> NarrowingConstraintsIterator<'_> { + self.lists.iter_set_reverse(set).copied() + } + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index bc58c79a74..bf7a29ecd5 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -260,20 +260,22 @@ use ruff_index::{newtype_index, IndexVec}; use rustc_hash::FxHashMap; use self::symbol_state::{ - ConstraintIndexIterator, LiveBindingsIterator, LiveDeclaration, LiveDeclarationsIterator, - ScopedDefinitionId, SymbolBindings, SymbolDeclarations, SymbolState, + LiveBindingsIterator, LiveDeclaration, LiveDeclarationsIterator, ScopedDefinitionId, + SymbolBindings, SymbolDeclarations, SymbolState, }; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::constraint::{ Constraint, Constraints, ConstraintsBuilder, ScopedConstraintId, }; use crate::semantic_index::definition::Definition; +use crate::semantic_index::narrowing_constraints::{ + NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator, +}; use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId}; -use crate::visibility_constraints::{ +use crate::semantic_index::visibility_constraints::{ ScopedVisibilityConstraintId, VisibilityConstraints, VisibilityConstraintsBuilder, }; -mod bitset; mod symbol_state; /// Applicable definitions and constraints for every use of a name. @@ -286,6 +288,9 @@ pub(crate) struct UseDefMap<'db> { /// Array of [`Constraint`] in this scope. constraints: Constraints<'db>, + /// Array of narrowing constraints in this scope. + narrowing_constraints: NarrowingConstraints, + /// Array of visibility constraints in this scope. 
visibility_constraints: VisibilityConstraints, @@ -370,6 +375,7 @@ impl<'db> UseDefMap<'db> { BindingWithConstraintsIterator { all_definitions: &self.all_definitions, constraints: &self.constraints, + narrowing_constraints: &self.narrowing_constraints, visibility_constraints: &self.visibility_constraints, inner: bindings.iter(), } @@ -416,6 +422,7 @@ type EagerBindings = IndexVec; pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { all_definitions: &'map IndexVec>>, pub(crate) constraints: &'map Constraints<'db>, + pub(crate) narrowing_constraints: &'map NarrowingConstraints, pub(crate) visibility_constraints: &'map VisibilityConstraints, inner: LiveBindingsIterator<'map>, } @@ -425,14 +432,16 @@ impl<'map, 'db> Iterator for BindingWithConstraintsIterator<'map, 'db> { fn next(&mut self) -> Option { let constraints = self.constraints; + let narrowing_constraints = self.narrowing_constraints; self.inner .next() .map(|live_binding| BindingWithConstraints { binding: self.all_definitions[live_binding.binding], - narrowing_constraints: ConstraintsIterator { + narrowing_constraint: ConstraintsIterator { constraints, - constraint_ids: live_binding.narrowing_constraints.iter(), + constraint_ids: narrowing_constraints + .iter_clauses(live_binding.narrowing_constraint), }, visibility_constraint: live_binding.visibility_constraint, }) @@ -443,13 +452,13 @@ impl std::iter::FusedIterator for BindingWithConstraintsIterator<'_, '_> {} pub(crate) struct BindingWithConstraints<'map, 'db> { pub(crate) binding: Option>, - pub(crate) narrowing_constraints: ConstraintsIterator<'map, 'db>, + pub(crate) narrowing_constraint: ConstraintsIterator<'map, 'db>, pub(crate) visibility_constraint: ScopedVisibilityConstraintId, } pub(crate) struct ConstraintsIterator<'map, 'db> { constraints: &'map Constraints<'db>, - constraint_ids: ConstraintIndexIterator<'map>, + constraint_ids: NarrowingConstraintsIterator<'map>, } impl<'db> Iterator for ConstraintsIterator<'_, 'db> { @@ -458,7 +467,7 @@ impl<'db> Iterator for ConstraintsIterator<'_, 'db> { fn next(&mut self) -> Option { self.constraint_ids .next() - .map(|constraint_id| self.constraints[ScopedConstraintId::from_u32(constraint_id)]) + .map(|narrowing_constraint| self.constraints[narrowing_constraint.constraint()]) } } @@ -509,7 +518,10 @@ pub(super) struct UseDefMapBuilder<'db> { all_definitions: IndexVec>>, /// Builder of constraints. - constraints: ConstraintsBuilder<'db>, + pub(super) constraints: ConstraintsBuilder<'db>, + + /// Builder of narrowing constraints. + pub(super) narrowing_constraints: NarrowingConstraintsBuilder, /// Builder of visibility constraints. 
pub(super) visibility_constraints: VisibilityConstraintsBuilder, @@ -542,6 +554,7 @@ impl Default for UseDefMapBuilder<'_> { Self { all_definitions: IndexVec::from_iter([None]), constraints: ConstraintsBuilder::default(), + narrowing_constraints: NarrowingConstraintsBuilder::default(), visibility_constraints: VisibilityConstraintsBuilder::default(), scope_start_visibility: ScopedVisibilityConstraintId::ALWAYS_TRUE, bindings_by_use: IndexVec::new(), @@ -578,8 +591,9 @@ impl<'db> UseDefMapBuilder<'db> { } pub(super) fn record_constraint_id(&mut self, constraint: ScopedConstraintId) { + let narrowing_constraint = constraint.into(); for state in &mut self.symbol_states { - state.record_constraint(constraint); + state.record_constraint(&mut self.narrowing_constraints, narrowing_constraint); } } @@ -737,10 +751,15 @@ impl<'db> UseDefMapBuilder<'db> { let mut snapshot_definitions_iter = snapshot.symbol_states.into_iter(); for current in &mut self.symbol_states { if let Some(snapshot) = snapshot_definitions_iter.next() { - current.merge(snapshot, &mut self.visibility_constraints); + current.merge( + snapshot, + &mut self.narrowing_constraints, + &mut self.visibility_constraints, + ); } else { current.merge( SymbolState::undefined(snapshot.scope_start_visibility), + &mut self.narrowing_constraints, &mut self.visibility_constraints, ); // Symbol not present in snapshot, so it's unbound/undeclared from that path. @@ -763,6 +782,7 @@ impl<'db> UseDefMapBuilder<'db> { UseDefMap { all_definitions: self.all_definitions, constraints: self.constraints.build(), + narrowing_constraints: self.narrowing_constraints.build(), visibility_constraints: self.visibility_constraints.build(), bindings_by_use: self.bindings_by_use, public_symbols: self.symbol_states, diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs deleted file mode 100644 index da44fd0bf0..0000000000 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs +++ /dev/null @@ -1,234 +0,0 @@ -/// Ordered set of `u32`. -/// -/// Uses an inline bit-set for small values (up to 64 * B), falls back to heap allocated vector of -/// blocks for larger values. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(super) enum BitSet { - /// Bit-set (in 64-bit blocks) for the first 64 * B entries. - Inline([u64; B]), - - /// Overflow beyond 64 * B. - Heap(Vec), -} - -impl Default for BitSet { - fn default() -> Self { - // B * 64 must fit in a u32, or else we have unusable bits; this assertion makes the - // truncating casts to u32 below safe. This would be better as a const assertion, but - // that's not possible on stable with const generic params. (B should never really be - // anywhere close to this large.) - assert!(B * 64 < (u32::MAX as usize)); - // This implementation requires usize >= 32 bits. - static_assertions::const_assert!(usize::BITS >= 32); - Self::Inline([0; B]) - } -} - -impl BitSet { - /// Convert from Inline to Heap, if needed, and resize the Heap vector, if needed. 
- fn resize(&mut self, value: u32) { - let num_blocks_needed = (value / 64) + 1; - self.resize_blocks(num_blocks_needed as usize); - } - - fn resize_blocks(&mut self, num_blocks_needed: usize) { - match self { - Self::Inline(blocks) => { - let mut vec = blocks.to_vec(); - vec.resize(num_blocks_needed, 0); - *self = Self::Heap(vec); - } - Self::Heap(vec) => { - vec.resize(num_blocks_needed, 0); - } - } - } - - fn blocks_mut(&mut self) -> &mut [u64] { - match self { - Self::Inline(blocks) => blocks.as_mut_slice(), - Self::Heap(blocks) => blocks.as_mut_slice(), - } - } - - fn blocks(&self) -> &[u64] { - match self { - Self::Inline(blocks) => blocks.as_slice(), - Self::Heap(blocks) => blocks.as_slice(), - } - } - - /// Insert a value into the [`BitSet`]. - /// - /// Return true if the value was newly inserted, false if already present. - pub(super) fn insert(&mut self, value: u32) -> bool { - let value_usize = value as usize; - let (block, index) = (value_usize / 64, value_usize % 64); - if block >= self.blocks().len() { - self.resize(value); - } - let blocks = self.blocks_mut(); - let missing = blocks[block] & (1 << index) == 0; - blocks[block] |= 1 << index; - missing - } - - /// Intersect in-place with another [`BitSet`]. - pub(super) fn intersect(&mut self, other: &BitSet) { - let my_blocks = self.blocks_mut(); - let other_blocks = other.blocks(); - let min_len = my_blocks.len().min(other_blocks.len()); - for i in 0..min_len { - my_blocks[i] &= other_blocks[i]; - } - for block in my_blocks.iter_mut().skip(min_len) { - *block = 0; - } - } - - /// Return an iterator over the values (in ascending order) in this [`BitSet`]. - pub(super) fn iter(&self) -> BitSetIterator<'_, B> { - let blocks = self.blocks(); - BitSetIterator { - blocks, - current_block_index: 0, - current_block: blocks[0], - } - } -} - -/// Iterator over values in a [`BitSet`]. -#[derive(Debug)] -pub(super) struct BitSetIterator<'a, const B: usize> { - /// The blocks we are iterating over. - blocks: &'a [u64], - - /// The index of the block we are currently iterating through. - current_block_index: usize, - - /// The block we are currently iterating through (and zeroing as we go.) - current_block: u64, -} - -impl Iterator for BitSetIterator<'_, B> { - type Item = u32; - - fn next(&mut self) -> Option { - while self.current_block == 0 { - if self.current_block_index + 1 >= self.blocks.len() { - return None; - } - self.current_block_index += 1; - self.current_block = self.blocks[self.current_block_index]; - } - let lowest_bit_set = self.current_block.trailing_zeros(); - // reset the lowest set bit, without a data dependency on `lowest_bit_set` - self.current_block &= self.current_block.wrapping_sub(1); - // SAFETY: `lowest_bit_set` cannot be more than 64, `current_block_index` cannot be more - // than `B - 1`, and we check above that `B * 64 < u32::MAX`. So both `64 * - // current_block_index` and the final value here must fit in u32. - #[allow(clippy::cast_possible_truncation)] - Some(lowest_bit_set + (64 * self.current_block_index) as u32) - } -} - -impl std::iter::FusedIterator for BitSetIterator<'_, B> {} - -#[cfg(test)] -mod tests { - use super::BitSet; - - impl BitSet { - /// Create and return a new [`BitSet`] with a single `value` inserted. 
- pub(super) fn with(value: u32) -> Self { - let mut bitset = Self::default(); - bitset.insert(value); - bitset - } - } - - fn assert_bitset(bitset: &BitSet, contents: &[u32]) { - assert_eq!(bitset.iter().collect::>(), contents); - } - - #[test] - fn iter() { - let mut b = BitSet::<1>::with(3); - b.insert(27); - b.insert(6); - assert!(matches!(b, BitSet::Inline(_))); - assert_bitset(&b, &[3, 6, 27]); - } - - #[test] - fn iter_overflow() { - let mut b = BitSet::<1>::with(140); - b.insert(100); - b.insert(129); - assert!(matches!(b, BitSet::Heap(_))); - assert_bitset(&b, &[100, 129, 140]); - } - - #[test] - fn intersect() { - let mut b1 = BitSet::<1>::with(4); - let mut b2 = BitSet::<1>::with(4); - b1.insert(23); - b2.insert(5); - - b1.intersect(&b2); - assert_bitset(&b1, &[4]); - } - - #[test] - fn intersect_mixed_1() { - let mut b1 = BitSet::<1>::with(4); - let mut b2 = BitSet::<1>::with(4); - b1.insert(89); - b2.insert(5); - - b1.intersect(&b2); - assert_bitset(&b1, &[4]); - } - - #[test] - fn intersect_mixed_2() { - let mut b1 = BitSet::<1>::with(4); - let mut b2 = BitSet::<1>::with(4); - b1.insert(23); - b2.insert(89); - - b1.intersect(&b2); - assert_bitset(&b1, &[4]); - } - - #[test] - fn intersect_heap() { - let mut b1 = BitSet::<1>::with(4); - let mut b2 = BitSet::<1>::with(4); - b1.insert(89); - b2.insert(90); - - b1.intersect(&b2); - assert_bitset(&b1, &[4]); - } - - #[test] - fn intersect_heap_2() { - let mut b1 = BitSet::<1>::with(89); - let mut b2 = BitSet::<1>::with(89); - b1.insert(91); - b2.insert(90); - - b1.intersect(&b2); - assert_bitset(&b1, &[89]); - } - - #[test] - fn multiple_blocks() { - let mut b = BitSet::<2>::with(120); - b.insert(45); - assert!(matches!(b, BitSet::Inline(_))); - assert_bitset(&b, &[45, 120]); - } -} diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs index 0dabb54902..f29b628308 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs @@ -36,10 +36,8 @@ //! dominates, but it does dominate the `x = 1 if flag2 else None` binding, so we have to keep //! track of that. //! -//! The data structures used here ([`BitSet`] and [`smallvec::SmallVec`]) optimize for keeping all -//! data inline (avoiding lots of scattered allocations) in small-to-medium cases, and falling back -//! to heap allocation to be able to scale to arbitrary numbers of live bindings and constraints -//! when needed. +//! The data structures use `IndexVec` arenas to store all data compactly and contiguously, while +//! supporting very cheap clones. //! //! Tracking live declarations is simpler, since constraints are not involved, but otherwise very //! similar to tracking live bindings. 
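The "very cheap clones" claim in the updated module docs above comes from replacing the owned `BitSet` in each live binding with a small `Copy` handle into per-scope arena storage. The following is a simplified, self-contained illustration; the types are invented for the example and are not the crate's actual definitions:

```rust
// Simplified stand-ins for the real types: the per-binding state now stores a
// small Copy-able handle into shared per-scope storage instead of owning a set.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct NarrowingConstraintHandle(u32);

#[derive(Clone, Debug)]
struct LiveBindingSketch {
    binding: u32,
    // `None` means "no narrowing constraint applies to this binding".
    narrowing_constraint: Option<NarrowingConstraintHandle>,
}

fn main() {
    let live_bindings = vec![
        LiveBindingSketch { binding: 1, narrowing_constraint: Some(NarrowingConstraintHandle(0)) },
        LiveBindingSketch { binding: 2, narrowing_constraint: None },
    ];
    // Taking a flow snapshot only copies the handles; the constraint lists
    // themselves live in the builder's arena and are never duplicated.
    let snapshot = live_bindings.clone();
    assert_eq!(snapshot.len(), live_bindings.len());
}
```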
@@ -48,10 +46,12 @@ use itertools::{EitherOrBoth, Itertools}; use ruff_index::newtype_index; use smallvec::{smallvec, SmallVec}; -use crate::semantic_index::constraint::ScopedConstraintId; -use crate::semantic_index::use_def::bitset::{BitSet, BitSetIterator}; -use crate::semantic_index::use_def::VisibilityConstraintsBuilder; -use crate::visibility_constraints::ScopedVisibilityConstraintId; +use crate::semantic_index::narrowing_constraints::{ + NarrowingConstraintsBuilder, ScopedNarrowingConstraintClause, ScopedNarrowingConstraintId, +}; +use crate::semantic_index::visibility_constraints::{ + ScopedVisibilityConstraintId, VisibilityConstraintsBuilder, +}; /// A newtype-index for a definition in a particular scope. #[newtype_index] @@ -67,18 +67,10 @@ impl ScopedDefinitionId { pub(super) const UNBOUND: ScopedDefinitionId = ScopedDefinitionId::from_u32(0); } -/// Can reference this * 64 total constraints inline; more will fall back to the heap. -const INLINE_CONSTRAINT_BLOCKS: usize = 2; - /// Can keep inline this many live bindings or declarations per symbol at a given time; more will /// go to heap. const INLINE_DEFINITIONS_PER_SYMBOL: usize = 4; -/// Which constraints apply to a given binding? -type Constraints = BitSet; - -pub(super) type ConstraintIndexIterator<'a> = BitSetIterator<'a, INLINE_CONSTRAINT_BLOCKS>; - /// Live declarations for a single symbol at some point in control flow, with their /// corresponding visibility constraints. #[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update)] @@ -197,7 +189,7 @@ pub(super) struct SymbolBindings { #[derive(Clone, Debug, PartialEq, Eq)] pub(super) struct LiveBinding { pub(super) binding: ScopedDefinitionId, - pub(super) narrowing_constraints: Constraints, + pub(super) narrowing_constraint: Option, pub(super) visibility_constraint: ScopedVisibilityConstraintId, } @@ -207,7 +199,7 @@ impl SymbolBindings { fn unbound(scope_start_visibility: ScopedVisibilityConstraintId) -> Self { let initial_binding = LiveBinding { binding: ScopedDefinitionId::UNBOUND, - narrowing_constraints: Constraints::default(), + narrowing_constraint: None, visibility_constraint: scope_start_visibility, }; Self { @@ -226,15 +218,20 @@ impl SymbolBindings { self.live_bindings.clear(); self.live_bindings.push(LiveBinding { binding, - narrowing_constraints: Constraints::default(), + narrowing_constraint: None, visibility_constraint, }); } /// Add given constraint to all live bindings. 
- pub(super) fn record_constraint(&mut self, constraint_id: ScopedConstraintId) { + pub(super) fn record_constraint( + &mut self, + narrowing_constraints: &mut NarrowingConstraintsBuilder, + constraint: ScopedNarrowingConstraintClause, + ) { for binding in &mut self.live_bindings { - binding.narrowing_constraints.insert(constraint_id.into()); + binding.narrowing_constraint = + narrowing_constraints.add(binding.narrowing_constraint, constraint); } } @@ -273,7 +270,12 @@ impl SymbolBindings { } } - fn merge(&mut self, b: Self, visibility_constraints: &mut VisibilityConstraintsBuilder) { + fn merge( + &mut self, + b: Self, + narrowing_constraints: &mut NarrowingConstraintsBuilder, + visibility_constraints: &mut VisibilityConstraintsBuilder, + ) { let a = std::mem::take(self); // Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that @@ -289,8 +291,8 @@ impl SymbolBindings { // If the same definition is visible through both paths, any constraint // that applies on only one path is irrelevant to the resulting type from // unioning the two paths, so we intersect the constraints. - let mut narrowing_constraints = a.narrowing_constraints; - narrowing_constraints.intersect(&b.narrowing_constraints); + let narrowing_constraint = narrowing_constraints + .intersect(a.narrowing_constraint, b.narrowing_constraint); // For visibility constraints, we merge them using a ternary OR operation: let visibility_constraint = visibility_constraints @@ -298,7 +300,7 @@ impl SymbolBindings { self.live_bindings.push(LiveBinding { binding: a.binding, - narrowing_constraints, + narrowing_constraint, visibility_constraint, }); } @@ -338,8 +340,13 @@ impl SymbolState { } /// Add given constraint to all live bindings. - pub(super) fn record_constraint(&mut self, constraint_id: ScopedConstraintId) { - self.bindings.record_constraint(constraint_id); + pub(super) fn record_constraint( + &mut self, + narrowing_constraints: &mut NarrowingConstraintsBuilder, + constraint: ScopedNarrowingConstraintClause, + ) { + self.bindings + .record_constraint(narrowing_constraints, constraint); } /// Add given visibility constraint to all live bindings. 
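The merge rule in the hunk above ("intersect the constraints" when the same definition is visible through both paths) can be modeled with ordinary sets. This sketch uses `std::collections::BTreeSet` in place of the arena-backed lists, purely to illustrate the semantics:

```rust
use std::collections::BTreeSet;

// When the *same* binding is live on both branches, only the narrowing
// clauses that hold on both paths survive the merge.
fn merge_same_binding(a: &BTreeSet<u32>, b: &BTreeSet<u32>) -> BTreeSet<u32> {
    a.intersection(b).copied().collect()
}

fn main() {
    // One branch narrowed the binding with clauses {0, 1}, the other with {0, 2}.
    let after_then_branch = BTreeSet::from([0, 1]);
    let after_else_branch = BTreeSet::from([0, 2]);
    // Only clause 0 is known to hold regardless of which branch executed.
    assert_eq!(
        merge_same_binding(&after_then_branch, &after_else_branch),
        BTreeSet::from([0])
    );
}
```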
@@ -373,9 +380,11 @@ impl SymbolState { pub(super) fn merge( &mut self, b: SymbolState, + narrowing_constraints: &mut NarrowingConstraintsBuilder, visibility_constraints: &mut VisibilityConstraintsBuilder, ) { - self.bindings.merge(b.bindings, visibility_constraints); + self.bindings + .merge(b.bindings, narrowing_constraints, visibility_constraints); self.declarations .merge(b.declarations, visibility_constraints); } @@ -393,8 +402,14 @@ impl SymbolState { mod tests { use super::*; + use crate::semantic_index::constraint::ScopedConstraintId; + #[track_caller] - fn assert_bindings(symbol: &SymbolState, expected: &[&str]) { + fn assert_bindings( + narrowing_constraints: &NarrowingConstraintsBuilder, + symbol: &SymbolState, + expected: &[&str], + ) { let actual = symbol .bindings() .iter() @@ -405,10 +420,9 @@ mod tests { } else { def_id.as_u32().to_string() }; - let constraints = live_binding - .narrowing_constraints - .iter() - .map(|idx| idx.to_string()) + let constraints = narrowing_constraints + .iter_constraints(live_binding.narrowing_constraint) + .map(|idx| idx.as_u32().to_string()) .collect::>() .join(", "); format!("{def}<{constraints}>") @@ -440,36 +454,41 @@ mod tests { #[test] fn unbound() { + let narrowing_constraints = NarrowingConstraintsBuilder::default(); let sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); - assert_bindings(&sym, &["unbound<>"]); + assert_bindings(&narrowing_constraints, &sym, &["unbound<>"]); } #[test] fn with() { + let narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); sym.record_binding( ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, ); - assert_bindings(&sym, &["1<>"]); + assert_bindings(&narrowing_constraints, &sym, &["1<>"]); } #[test] fn record_constraint() { + let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); sym.record_binding( ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, ); - sym.record_constraint(ScopedConstraintId::from_u32(0)); + let constraint = ScopedConstraintId::from_u32(0).into(); + sym.record_constraint(&mut narrowing_constraints, constraint); - assert_bindings(&sym, &["1<0>"]); + assert_bindings(&narrowing_constraints, &sym, &["1<0>"]); } #[test] fn merge() { + let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut visibility_constraints = VisibilityConstraintsBuilder::default(); // merging the same definition with the same constraint keeps the constraint @@ -478,18 +497,24 @@ mod tests { ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, ); - sym1a.record_constraint(ScopedConstraintId::from_u32(0)); + let constraint = ScopedConstraintId::from_u32(0).into(); + sym1a.record_constraint(&mut narrowing_constraints, constraint); let mut sym1b = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); sym1b.record_binding( ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, ); - sym1b.record_constraint(ScopedConstraintId::from_u32(0)); + let constraint = ScopedConstraintId::from_u32(0).into(); + sym1b.record_constraint(&mut narrowing_constraints, constraint); - sym1a.merge(sym1b, &mut visibility_constraints); + sym1a.merge( + sym1b, + &mut narrowing_constraints, + &mut visibility_constraints, + ); let mut sym1 = sym1a; - assert_bindings(&sym1, &["1<0>"]); + 
assert_bindings(&narrowing_constraints, &sym1, &["1<0>"]); // merging the same definition with differing constraints drops all constraints let mut sym2a = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); @@ -497,18 +522,24 @@ mod tests { ScopedDefinitionId::from_u32(2), ScopedVisibilityConstraintId::ALWAYS_TRUE, ); - sym2a.record_constraint(ScopedConstraintId::from_u32(1)); + let constraint = ScopedConstraintId::from_u32(1).into(); + sym2a.record_constraint(&mut narrowing_constraints, constraint); let mut sym1b = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); sym1b.record_binding( ScopedDefinitionId::from_u32(2), ScopedVisibilityConstraintId::ALWAYS_TRUE, ); - sym1b.record_constraint(ScopedConstraintId::from_u32(2)); + let constraint = ScopedConstraintId::from_u32(2).into(); + sym1b.record_constraint(&mut narrowing_constraints, constraint); - sym2a.merge(sym1b, &mut visibility_constraints); + sym2a.merge( + sym1b, + &mut narrowing_constraints, + &mut visibility_constraints, + ); let sym2 = sym2a; - assert_bindings(&sym2, &["2<>"]); + assert_bindings(&narrowing_constraints, &sym2, &["2<>"]); // merging a constrained definition with unbound keeps both let mut sym3a = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); @@ -516,18 +547,27 @@ mod tests { ScopedDefinitionId::from_u32(3), ScopedVisibilityConstraintId::ALWAYS_TRUE, ); - sym3a.record_constraint(ScopedConstraintId::from_u32(3)); + let constraint = ScopedConstraintId::from_u32(3).into(); + sym3a.record_constraint(&mut narrowing_constraints, constraint); let sym2b = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); - sym3a.merge(sym2b, &mut visibility_constraints); + sym3a.merge( + sym2b, + &mut narrowing_constraints, + &mut visibility_constraints, + ); let sym3 = sym3a; - assert_bindings(&sym3, &["unbound<>", "3<3>"]); + assert_bindings(&narrowing_constraints, &sym3, &["unbound<>", "3<3>"]); // merging different definitions keeps them each with their existing constraints - sym1.merge(sym3, &mut visibility_constraints); + sym1.merge( + sym3, + &mut narrowing_constraints, + &mut visibility_constraints, + ); let sym = sym1; - assert_bindings(&sym, &["unbound<>", "1<0>", "3<3>"]); + assert_bindings(&narrowing_constraints, &sym, &["unbound<>", "1<0>", "3<3>"]); } #[test] @@ -556,6 +596,7 @@ mod tests { #[test] fn record_declaration_merge() { + let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut visibility_constraints = VisibilityConstraintsBuilder::default(); let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); sym.record_declaration(ScopedDefinitionId::from_u32(1)); @@ -563,20 +604,29 @@ mod tests { let mut sym2 = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); sym2.record_declaration(ScopedDefinitionId::from_u32(2)); - sym.merge(sym2, &mut visibility_constraints); + sym.merge( + sym2, + &mut narrowing_constraints, + &mut visibility_constraints, + ); assert_declarations(&sym, &["1", "2"]); } #[test] fn record_declaration_merge_partial_undeclared() { + let mut narrowing_constraints = NarrowingConstraintsBuilder::default(); let mut visibility_constraints = VisibilityConstraintsBuilder::default(); let mut sym = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); sym.record_declaration(ScopedDefinitionId::from_u32(1)); let sym2 = SymbolState::undefined(ScopedVisibilityConstraintId::ALWAYS_TRUE); - sym.merge(sym2, &mut visibility_constraints); + sym.merge( + sym2, + &mut 
narrowing_constraints, + &mut visibility_constraints, + ); assert_declarations(&sym, &["undeclared", "1"]); } diff --git a/crates/red_knot_python_semantic/src/visibility_constraints.rs b/crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs similarity index 99% rename from crates/red_knot_python_semantic/src/visibility_constraints.rs rename to crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs index 97cfe06369..af1c873800 100644 --- a/crates/red_knot_python_semantic/src/visibility_constraints.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs @@ -281,8 +281,7 @@ const AMBIGUOUS: ScopedVisibilityConstraintId = ScopedVisibilityConstraintId::AM const ALWAYS_FALSE: ScopedVisibilityConstraintId = ScopedVisibilityConstraintId::ALWAYS_FALSE; const SMALLEST_TERMINAL: ScopedVisibilityConstraintId = ALWAYS_FALSE; -/// A collection of visibility constraints. This is currently stored in `UseDefMap`, which means we -/// maintain a separate set of visibility constraints for each scope in file. +/// A collection of visibility constraints for a given scope. #[derive(Debug, PartialEq, Eq, salsa::Update)] pub(crate) struct VisibilityConstraints { interiors: IndexVec, diff --git a/crates/red_knot_python_semantic/src/symbol.rs b/crates/red_knot_python_semantic/src/symbol.rs index 0f696d3d98..49cde8986a 100644 --- a/crates/red_knot_python_semantic/src/symbol.rs +++ b/crates/red_knot_python_semantic/src/symbol.rs @@ -8,7 +8,7 @@ use crate::semantic_index::{ symbol_table, BindingWithConstraints, BindingWithConstraintsIterator, DeclarationsIterator, }; use crate::types::{ - binding_type, declaration_type, narrowing_constraint, todo_type, IntersectionBuilder, + binding_type, declaration_type, infer_narrowing_constraint, todo_type, IntersectionBuilder, KnownClass, Truthiness, Type, TypeAndQualifiers, TypeQualifiers, UnionBuilder, UnionType, }; use crate::{resolve_module, Db, KnownModule, Module, Program}; @@ -550,7 +550,7 @@ fn symbol_from_bindings_impl<'db>( Some(BindingWithConstraints { binding, visibility_constraint, - narrowing_constraints: _, + narrowing_constraint: _, }) if binding.map_or(true, is_non_exported) => { visibility_constraints.evaluate(db, constraints, *visibility_constraint) } @@ -560,7 +560,7 @@ fn symbol_from_bindings_impl<'db>( let mut types = bindings_with_constraints.filter_map( |BindingWithConstraints { binding, - narrowing_constraints, + narrowing_constraint, visibility_constraint, }| { let binding = binding?; @@ -576,21 +576,23 @@ fn symbol_from_bindings_impl<'db>( return None; } - let mut constraint_tys = narrowing_constraints - .filter_map(|constraint| narrowing_constraint(db, constraint, binding)) - .peekable(); + let constraint_tys: Vec<_> = narrowing_constraint + .filter_map(|constraint| infer_narrowing_constraint(db, constraint, binding)) + .collect(); let binding_ty = binding_type(db, binding); - if constraint_tys.peek().is_some() { + if constraint_tys.is_empty() { + Some(binding_ty) + } else { let intersection_ty = constraint_tys + .into_iter() + .rev() .fold( IntersectionBuilder::new(db).add_positive(binding_ty), IntersectionBuilder::add_positive, ) .build(); Some(intersection_ty) - } else { - Some(binding_ty) } }, ); diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 29df6db3fa..18103cddbb 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -34,7 +34,7 @@ use 
crate::types::class_base::ClassBase; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; -pub(crate) use crate::types::narrow::narrowing_constraint; +pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterKind, Parameters}; use crate::{Db, FxOrderSet, Module, Program}; pub(crate) use class::{Class, ClassLiteralType, InstanceType, KnownClass, KnownInstanceType}; diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index bd1d975011..adbd9fb214 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -35,7 +35,7 @@ use std::sync::Arc; /// /// But if we called this with the same `test` expression, but the `definition` of `y`, no /// constraint is applied to that definition, so we'd just return `None`. -pub(crate) fn narrowing_constraint<'db>( +pub(crate) fn infer_narrowing_constraint<'db>( db: &'db dyn Db, constraint: Constraint<'db>, definition: Definition<'db>, diff --git a/crates/ruff_index/src/lib.rs b/crates/ruff_index/src/lib.rs index 6f7ac59c41..33dc83d342 100644 --- a/crates/ruff_index/src/lib.rs +++ b/crates/ruff_index/src/lib.rs @@ -4,6 +4,7 @@ //! Inspired by [rustc_index](https://github.com/rust-lang/rust/blob/master/compiler/rustc_index/src/lib.rs). mod idx; +pub mod list; mod slice; mod vec; diff --git a/crates/ruff_index/src/list.rs b/crates/ruff_index/src/list.rs new file mode 100644 index 0000000000..72fdba105d --- /dev/null +++ b/crates/ruff_index/src/list.rs @@ -0,0 +1,761 @@ +use std::cmp::Ordering; +use std::ops::Deref; + +use crate::vec::IndexVec; +use crate::Idx; + +/// Stores one or more _association lists_, which are linked lists of key/value pairs. We +/// additionally guarantee that the elements of an association list are sorted (by their keys), and +/// that they do not contain any entries with duplicate keys. +/// +/// Association lists have fallen out of favor in recent decades, since you often need operations +/// that are inefficient on them. In particular, looking up a random element by index is O(n), just +/// like a linked list; and looking up an element by key is also O(n), since you must do a linear +/// scan of the list to find the matching element. The typical implementation also suffers from +/// poor cache locality and high memory allocation overhead, since individual list cells are +/// typically allocated separately from the heap. +/// +/// We solve that last problem by storing the cells of an association list in an [`IndexVec`] +/// arena. You provide the index type (`I`) that you want to use with this arena. That means that +/// an individual association list is represented by an `Option`, with `None` representing an +/// empty list. +/// +/// We exploit structural sharing where possible, reusing cells across multiple lists when we can. +/// That said, we don't guarantee that lists are canonical — it's entirely possible for two lists +/// with identical contents to use different list cells and have different identifiers. +/// +/// Given all of this, association lists have the following benefits: +/// +/// - Lists can be represented by a single 32-bit integer (the index into the arena of the head of +/// the list). +/// - Lists can be cloned in constant time, since the underlying cells are immutable. 
+/// - Lists can be combined quickly (for both intersection and union), especially when you already +/// have to zip through both input lists to combine each key's values in some way. +/// +/// There is one remaining caveat: +/// +/// - You should construct lists in key order; doing this lets you insert each value in constant time. +/// Inserting entries in reverse order results in _quadratic_ overall time to construct the list. +/// +/// This type provides read-only access to the lists. Use a [`ListBuilder`] to create lists. +#[derive(Debug, Eq, PartialEq)] +pub struct ListStorage { + cells: IndexVec>, +} + +/// Each association list is represented by a sequence of snoc cells. A snoc cell is like the more +/// familiar cons cell `(a : (b : (c : nil)))`, but in reverse `(((nil : a) : b) : c)`. +/// +/// **Terminology**: The elements of a cons cell are usually called `head` and `tail` (assuming +/// you're not in Lisp-land, where they're called `car` and `cdr`). The elements of a snoc cell +/// are usually called `rest` and `last`. +/// +/// We use a tuple struct instead of named fields because we always unpack a cell into local +/// variables: +/// +/// ```ignore +/// let ListCell(rest, last_key, last_value) = /* ... */; +/// ``` +#[derive(Debug, Eq, PartialEq)] +struct ListCell(Option, K, V); + +impl ListStorage { + /// Iterates through the entries in a list _in reverse order by key_. + pub fn iter_reverse(&self, list: Option) -> ListReverseIterator<'_, I, K, V> { + ListReverseIterator { + storage: self, + curr: list, + } + } +} + +pub struct ListReverseIterator<'a, I, K, V> { + storage: &'a ListStorage, + curr: Option, +} + +impl<'a, I: Idx, K, V> Iterator for ListReverseIterator<'a, I, K, V> { + type Item = (&'a K, &'a V); + + fn next(&mut self) -> Option { + let ListCell(rest, key, value) = &self.storage.cells[self.curr?]; + self.curr = *rest; + Some((key, value)) + } +} + +/// Constructs one or more association lists. +#[derive(Debug, Eq, PartialEq)] +pub struct ListBuilder { + storage: ListStorage, + + /// Scratch space that lets us implement our list operations iteratively instead of + /// recursively. + /// + /// The snoc-list representation that we use for alists is very common in functional + /// programming, and the simplest implementations of most of the operations are defined + /// recursively on that data structure. However, they are not _tail_ recursive, which means + /// that the call stack grows linearly with the size of the input, which can be a problem for + /// large lists. + /// + /// You can often rework those recursive implementations into iterative ones using an + /// _accumulator_, but that comes at the cost of reversing the list. If we didn't care about + /// ordering, that wouldn't be a problem. Since we want our lists to be sorted, we can't rely + /// on that on its own. + /// + /// The next standard trick is to use an accumulator, and use a fix-up step at the end to + /// reverse the (reversed) result in the accumulator, restoring the correct order. + /// + /// So, that's what we do! However, as one last optimization, we don't build up alist cells in + /// our accumulator, since that would add wasteful cruft to our list storage. Instead, we use a + /// normal Vec as our accumulator, holding the key/value pairs that should be stitched onto the + /// end of whatever result list we are creating. For our fix-up step, we can consume a Vec in + /// reverse order by `pop`ping the elements off one by one. 
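The accumulator trick described above can be condensed into a small self-contained sketch. It uses plain `Vec`-backed cells and `usize` ids rather than `IndexVec` and the crate's index types, and it assumes the key being inserted is not already present:

```rust
// Invented miniature of the pattern: not the ListBuilder API, just the shape
// of "walk and stash in scratch, then pop to stitch the result back together".
struct Cell {
    rest: Option<usize>,
    key: u32,
}

fn insert_sorted(cells: &mut Vec<Cell>, list: Option<usize>, new_key: u32) -> Option<usize> {
    // Walk the snoc list (largest key first), stashing every key greater than
    // `new_key`; those cells must be re-created on top of the new one.
    // (Assumes `new_key` is not already in the list.)
    let mut scratch = Vec::new();
    let mut curr = list;
    while let Some(id) = curr {
        if cells[id].key < new_key {
            break;
        }
        scratch.push(cells[id].key);
        curr = cells[id].rest;
    }
    // Stitch up: the new cell goes on first, then the stashed keys are popped
    // back off, which restores their original (ascending) order.
    cells.push(Cell { rest: curr, key: new_key });
    let mut result = Some(cells.len() - 1);
    while let Some(key) = scratch.pop() {
        cells.push(Cell { rest: result, key });
        result = Some(cells.len() - 1);
    }
    result
}

fn main() {
    let mut cells = Vec::new();
    let mut list = None;
    for key in [1, 3, 5, 2] {
        list = insert_sorted(&mut cells, list, key);
    }
    // Walking from the head yields the keys in descending order: 5, 3, 2, 1.
    let mut keys = Vec::new();
    let mut curr = list;
    while let Some(id) = curr {
        keys.push(cells[id].key);
        curr = cells[id].rest;
    }
    assert_eq!(keys, [5, 3, 2, 1]);
}
```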
+ scratch: Vec<(K, V)>, +} + +impl Default for ListBuilder { + fn default() -> Self { + ListBuilder { + storage: ListStorage { + cells: IndexVec::default(), + }, + scratch: Vec::default(), + } + } +} + +impl Deref for ListBuilder { + type Target = ListStorage; + fn deref(&self) -> &ListStorage { + &self.storage + } +} + +impl ListBuilder { + /// Finalizes a `ListBuilder`. After calling this, you cannot create any new lists managed by + /// this storage. + pub fn build(mut self) -> ListStorage { + self.storage.cells.shrink_to_fit(); + self.storage + } + + /// Adds a new cell to the list. + /// + /// Adding an element always returns a non-empty list, which means we could technically use `I` + /// as our return type, since we never return `None`. However, for consistency with our other + /// methods, we always use `Option` as the return type for any method that can return a + /// list. + #[allow(clippy::unnecessary_wraps)] + fn add_cell(&mut self, rest: Option, key: K, value: V) -> Option { + Some(self.storage.cells.push(ListCell(rest, key, value))) + } + + /// Returns an entry pointing at where `key` would be inserted into a list. + /// + /// Note that when we add a new element to a list, we might have to clone the keys and values + /// of some existing elements. This is because list cells are immutable once created, since + /// they might be shared across multiple lists. We must therefore create new cells for every + /// element that appears after the new element. + /// + /// That means that you should construct lists in key order, since that means that there are no + /// entries to duplicate for each insertion. If you construct the list in reverse order, we + /// will have to duplicate O(n) entries for each insertion, making it _quadratic_ to construct + /// the entire list. + pub fn entry(&mut self, list: Option, key: K) -> ListEntry + where + K: Clone + Ord, + V: Clone, + { + self.scratch.clear(); + + // Iterate through the input list, looking for the position where the key should be + // inserted. We will need to create new list cells for any elements that appear after the + // new key. Stash those away in our scratch accumulator as we step through the input. The + // result of the loop is that "rest" of the result list, which we will stitch the new key + // (and any succeeding keys) onto. + let mut curr = list; + while let Some(curr_id) = curr { + let ListCell(rest, curr_key, curr_value) = &self.storage.cells[curr_id]; + match key.cmp(curr_key) { + // We found an existing entry in the input list with the desired key. + Ordering::Equal => { + return ListEntry { + builder: self, + list, + key, + rest: ListTail::Occupied(curr_id), + }; + } + // The input list does not already contain this key, and this is where we should + // add it. + Ordering::Greater => { + return ListEntry { + builder: self, + list, + key, + rest: ListTail::Vacant(curr_id), + }; + } + // If this key is in the list, it's further along. We'll need to create a new cell + // for this entry in the result list, so add its contents to the scratch + // accumulator. + Ordering::Less => { + let new_key = curr_key.clone(); + let new_value = curr_value.clone(); + self.scratch.push((new_key, new_value)); + curr = *rest; + } + } + } + + // We made it all the way through the list without finding the desired key, so it belongs + // at the beginning. (And we will unfortunately have to duplicate every existing cell if + // the caller proceeds with inserting the new key!) 
+ ListEntry { + builder: self, + list, + key, + rest: ListTail::Beginning, + } + } +} + +/// A view into a list, indicating where a key would be inserted. +pub struct ListEntry<'a, I, K, V> { + builder: &'a mut ListBuilder, + list: Option, + key: K, + /// Points at the element that already contains `key`, if there is one, or the element + /// immediately before where it would go, if not. + rest: ListTail, +} + +enum ListTail { + /// The list does not already contain `key`, and it would go at the beginning of the list. + Beginning, + /// The list already contains `key` + Occupied(I), + /// The list does not already contain key, and it would go immediately after the given element + Vacant(I), +} + +impl ListEntry<'_, I, K, V> +where + K: Clone + Ord, + V: Clone, +{ + fn stitch_up(self, rest: Option, value: V) -> Option { + let mut result = rest; + result = self.builder.add_cell(result, self.key, value); + while let Some((key, value)) = self.builder.scratch.pop() { + result = self.builder.add_cell(result, key, value); + } + result + } + + /// Inserts a new key/value into the list if the key is not already present. If the list + /// already contains `key`, we return the original list as-is, and do not invoke your closure. + pub fn or_insert_with(self, f: F) -> Option + where + F: FnOnce() -> V, + { + let rest = match self.rest { + // If the list already contains `key`, we don't need to replace anything, and can + // return the original list unmodified. + ListTail::Occupied(_) => return self.list, + // Otherwise we have to create a new entry and stitch it onto the list. + ListTail::Beginning => None, + ListTail::Vacant(index) => Some(index), + }; + self.stitch_up(rest, f()) + } + + /// Inserts a new key/value into the list if the key is not already present. If the list + /// already contains `key`, we return the original list as-is. + pub fn or_insert(self, value: V) -> Option { + self.or_insert_with(|| value) + } + + /// Inserts a new key and the default value into the list if the key is not already present. If + /// the list already contains `key`, we return the original list as-is. + pub fn or_insert_default(self) -> Option + where + V: Default, + { + self.or_insert_with(V::default) + } + + /// Ensures that the list contains an entry mapping the key to `value`, returning the resulting + /// list. Overwrites any existing entry with the same key. As an optimization, if the existing + /// entry has an equal _value_, as well, we return the original list as-is. + pub fn replace(self, value: V) -> Option + where + V: Eq, + { + // If the list already contains `key`, skip past its entry before we add its replacement. + let rest = match self.rest { + ListTail::Beginning => None, + ListTail::Occupied(index) => { + let ListCell(rest, _, existing_value) = &self.builder.cells[index]; + if value == *existing_value { + // As an optimization, if value isn't changed, there's no need to stitch up a + // new list. + return self.list; + } + *rest + } + ListTail::Vacant(index) => Some(index), + }; + self.stitch_up(rest, value) + } + + /// Ensures that the list contains an entry mapping the key to the default, returning the + /// resulting list. Overwrites any existing entry with the same key. As an optimization, if the + /// existing entry has an equal _value_, as well, we return the original list as-is. + pub fn replace_with_default(self) -> Option + where + V: Default + Eq, + { + self.replace(V::default()) + } +} + +impl ListBuilder { + /// Returns the intersection of two lists. 
The result will contain an entry for any key that + /// appears in both lists. The corresponding values will be combined using the `combine` + /// function that you provide. + pub fn intersect_with( + &mut self, + mut a: Option, + mut b: Option, + mut combine: F, + ) -> Option + where + K: Clone + Ord, + V: Clone, + F: FnMut(&V, &V) -> V, + { + self.scratch.clear(); + + // Zip through the lists, building up the keys/values of the new entries into our scratch + // vector. Continue until we run out of elements in either list. (Any remaining elements in + // the other list cannot possibly be in the intersection.) + while let (Some(a_id), Some(b_id)) = (a, b) { + let ListCell(a_rest, a_key, a_value) = &self.storage.cells[a_id]; + let ListCell(b_rest, b_key, b_value) = &self.storage.cells[b_id]; + match a_key.cmp(b_key) { + // Both lists contain this key; combine their values + Ordering::Equal => { + let new_key = a_key.clone(); + let new_value = combine(a_value, b_value); + self.scratch.push((new_key, new_value)); + a = *a_rest; + b = *b_rest; + } + // a's key is only present in a, so it's not included in the result. + Ordering::Greater => a = *a_rest, + // b's key is only present in b, so it's not included in the result. + Ordering::Less => b = *b_rest, + } + } + + // Once the iteration loop terminates, we stitch the new entries back together into proper + // alist cells. + let mut result = None; + while let Some((key, value)) = self.scratch.pop() { + result = self.add_cell(result, key, value); + } + result + } + + /// Returns the union of two lists. The result will contain an entry for any key that appears + /// in either list. For keys that appear in both lists, the corresponding values will be + /// combined using the `combine` function that you provide. + pub fn union_with(&mut self, mut a: Option, mut b: Option, mut combine: F) -> Option + where + K: Clone + Ord, + V: Clone, + F: FnMut(&V, &V) -> V, + { + self.scratch.clear(); + + // Zip through the lists, building up the keys/values of the new entries into our scratch + // vector. Continue until we run out of elements in either list. (Any remaining elements in + // the other list will be added to the result, but won't need to be combined with + // anything.) + let mut result = loop { + let (a_id, b_id) = match (a, b) { + // If we run out of elements in one of the lists, the non-empty list will appear in + // the output unchanged. + (None, other) | (other, None) => break other, + (Some(a_id), Some(b_id)) => (a_id, b_id), + }; + + let ListCell(a_rest, a_key, a_value) = &self.storage.cells[a_id]; + let ListCell(b_rest, b_key, b_value) = &self.storage.cells[b_id]; + match a_key.cmp(b_key) { + // Both lists contain this key; combine their values + Ordering::Equal => { + let new_key = a_key.clone(); + let new_value = combine(a_value, b_value); + self.scratch.push((new_key, new_value)); + a = *a_rest; + b = *b_rest; + } + // a's key goes into the result next + Ordering::Greater => { + let new_key = a_key.clone(); + let new_value = a_value.clone(); + self.scratch.push((new_key, new_value)); + a = *a_rest; + } + // b's key goes into the result next + Ordering::Less => { + let new_key = b_key.clone(); + let new_value = b_value.clone(); + self.scratch.push((new_key, new_value)); + b = *b_rest; + } + } + }; + + // Once the iteration loop terminates, we stitch the new entries back together into proper + // alist cells. 
+ while let Some((key, value)) = self.scratch.pop() { + result = self.add_cell(result, key, value); + } + result + } +} + +// ---- +// Sets + +impl ListStorage { + /// Iterates through the elements in a set _in reverse order_. + pub fn iter_set_reverse(&self, set: Option) -> ListSetReverseIterator<'_, I, K> { + ListSetReverseIterator { + storage: self, + curr: set, + } + } +} + +pub struct ListSetReverseIterator<'a, I, K> { + storage: &'a ListStorage, + curr: Option, +} + +impl<'a, I: Idx, K> Iterator for ListSetReverseIterator<'a, I, K> { + type Item = &'a K; + + fn next(&mut self) -> Option { + let ListCell(rest, key, ()) = &self.storage.cells[self.curr?]; + self.curr = *rest; + Some(key) + } +} + +impl ListBuilder { + /// Adds an element to a set. + pub fn insert(&mut self, set: Option, element: K) -> Option + where + K: Clone + Ord, + { + self.entry(set, element).or_insert_default() + } + + /// Returns the intersection of two sets. The result will contain any value that appears in + /// both sets. + pub fn intersect(&mut self, a: Option, b: Option) -> Option + where + K: Clone + Ord, + { + self.intersect_with(a, b, |(), ()| ()) + } + + /// Returns the intersection of two sets. The result will contain any value that appears in + /// either set. + pub fn union(&mut self, a: Option, b: Option) -> Option + where + K: Clone + Ord, + { + self.union_with(a, b, |(), ()| ()) + } +} + +// ----- +// Tests + +#[cfg(test)] +mod tests { + use super::*; + + use std::fmt::Display; + use std::fmt::Write; + + use crate::newtype_index; + + // Allows the macro invocation below to work + use crate as ruff_index; + + #[newtype_index] + struct TestIndex; + + // ---- + // Sets + + impl ListStorage + where + I: Idx, + K: Display, + { + fn display_set(&self, list: Option) -> String { + let elements: Vec<_> = self.iter_set_reverse(list).collect(); + let mut result = String::new(); + result.push('['); + for element in elements.into_iter().rev() { + if result.len() > 1 { + result.push_str(", "); + } + write!(&mut result, "{element}").unwrap(); + } + result.push(']'); + result + } + } + + #[test] + fn can_insert_into_set() { + let mut builder = ListBuilder::::default(); + + // Build up the set in order + let set1 = builder.insert(None, 1); + let set12 = builder.insert(set1, 2); + let set123 = builder.insert(set12, 3); + let set1232 = builder.insert(set123, 2); + assert_eq!(builder.display_set(None), "[]"); + assert_eq!(builder.display_set(set1), "[1]"); + assert_eq!(builder.display_set(set12), "[1, 2]"); + assert_eq!(builder.display_set(set123), "[1, 2, 3]"); + assert_eq!(builder.display_set(set1232), "[1, 2, 3]"); + + // And in reverse order + let set3 = builder.insert(None, 3); + let set32 = builder.insert(set3, 2); + let set321 = builder.insert(set32, 1); + let set3212 = builder.insert(set321, 2); + assert_eq!(builder.display_set(None), "[]"); + assert_eq!(builder.display_set(set3), "[3]"); + assert_eq!(builder.display_set(set32), "[2, 3]"); + assert_eq!(builder.display_set(set321), "[1, 2, 3]"); + assert_eq!(builder.display_set(set3212), "[1, 2, 3]"); + } + + #[test] + fn can_intersect_sets() { + let mut builder = ListBuilder::::default(); + + let set1 = builder.entry(None, 1).or_insert_default(); + let set12 = builder.entry(set1, 2).or_insert_default(); + let set123 = builder.entry(set12, 3).or_insert_default(); + let set1234 = builder.entry(set123, 4).or_insert_default(); + + let set2 = builder.entry(None, 2).or_insert_default(); + let set24 = builder.entry(set2, 4).or_insert_default(); + let set245 = 
+// -----
+// Tests
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::fmt::Display;
+    use std::fmt::Write;
+
+    use crate::newtype_index;
+
+    // Allows the macro invocation below to work
+    use crate as ruff_index;
+
+    #[newtype_index]
+    struct TestIndex;
+
+    // ----
+    // Sets
+
+    impl<I, K> ListStorage<I, K, ()>
+    where
+        I: Idx,
+        K: Display,
+    {
+        fn display_set(&self, list: Option<I>) -> String {
+            let elements: Vec<_> = self.iter_set_reverse(list).collect();
+            let mut result = String::new();
+            result.push('[');
+            for element in elements.into_iter().rev() {
+                if result.len() > 1 {
+                    result.push_str(", ");
+                }
+                write!(&mut result, "{element}").unwrap();
+            }
+            result.push(']');
+            result
+        }
+    }
+
+    #[test]
+    fn can_insert_into_set() {
+        let mut builder = ListBuilder::<TestIndex, u16, ()>::default();
+
+        // Build up the set in order
+        let set1 = builder.insert(None, 1);
+        let set12 = builder.insert(set1, 2);
+        let set123 = builder.insert(set12, 3);
+        let set1232 = builder.insert(set123, 2);
+        assert_eq!(builder.display_set(None), "[]");
+        assert_eq!(builder.display_set(set1), "[1]");
+        assert_eq!(builder.display_set(set12), "[1, 2]");
+        assert_eq!(builder.display_set(set123), "[1, 2, 3]");
+        assert_eq!(builder.display_set(set1232), "[1, 2, 3]");
+
+        // And in reverse order
+        let set3 = builder.insert(None, 3);
+        let set32 = builder.insert(set3, 2);
+        let set321 = builder.insert(set32, 1);
+        let set3212 = builder.insert(set321, 2);
+        assert_eq!(builder.display_set(None), "[]");
+        assert_eq!(builder.display_set(set3), "[3]");
+        assert_eq!(builder.display_set(set32), "[2, 3]");
+        assert_eq!(builder.display_set(set321), "[1, 2, 3]");
+        assert_eq!(builder.display_set(set3212), "[1, 2, 3]");
+    }
+
+    #[test]
+    fn can_intersect_sets() {
+        let mut builder = ListBuilder::<TestIndex, u16, ()>::default();
+
+        let set1 = builder.entry(None, 1).or_insert_default();
+        let set12 = builder.entry(set1, 2).or_insert_default();
+        let set123 = builder.entry(set12, 3).or_insert_default();
+        let set1234 = builder.entry(set123, 4).or_insert_default();
+
+        let set2 = builder.entry(None, 2).or_insert_default();
+        let set24 = builder.entry(set2, 4).or_insert_default();
+        let set245 = builder.entry(set24, 5).or_insert_default();
+        let set2457 = builder.entry(set245, 7).or_insert_default();
+
+        let intersection = builder.intersect(None, None);
+        assert_eq!(builder.display_set(intersection), "[]");
+        let intersection = builder.intersect(None, set1234);
+        assert_eq!(builder.display_set(intersection), "[]");
+        let intersection = builder.intersect(None, set2457);
+        assert_eq!(builder.display_set(intersection), "[]");
+        let intersection = builder.intersect(set1, set1234);
+        assert_eq!(builder.display_set(intersection), "[1]");
+        let intersection = builder.intersect(set1, set2457);
+        assert_eq!(builder.display_set(intersection), "[]");
+        let intersection = builder.intersect(set2, set1234);
+        assert_eq!(builder.display_set(intersection), "[2]");
+        let intersection = builder.intersect(set2, set2457);
+        assert_eq!(builder.display_set(intersection), "[2]");
+        let intersection = builder.intersect(set1234, set2457);
+        assert_eq!(builder.display_set(intersection), "[2, 4]");
+    }
+
+    #[test]
+    fn can_union_sets() {
+        let mut builder = ListBuilder::<TestIndex, u16, ()>::default();
+
+        let set1 = builder.entry(None, 1).or_insert_default();
+        let set12 = builder.entry(set1, 2).or_insert_default();
+        let set123 = builder.entry(set12, 3).or_insert_default();
+        let set1234 = builder.entry(set123, 4).or_insert_default();
+
+        let set2 = builder.entry(None, 2).or_insert_default();
+        let set24 = builder.entry(set2, 4).or_insert_default();
+        let set245 = builder.entry(set24, 5).or_insert_default();
+        let set2457 = builder.entry(set245, 7).or_insert_default();
+
+        let union = builder.union(None, None);
+        assert_eq!(builder.display_set(union), "[]");
+        let union = builder.union(None, set1234);
+        assert_eq!(builder.display_set(union), "[1, 2, 3, 4]");
+        let union = builder.union(None, set2457);
+        assert_eq!(builder.display_set(union), "[2, 4, 5, 7]");
+        let union = builder.union(set1, set1234);
+        assert_eq!(builder.display_set(union), "[1, 2, 3, 4]");
+        let union = builder.union(set1, set2457);
+        assert_eq!(builder.display_set(union), "[1, 2, 4, 5, 7]");
+        let union = builder.union(set2, set1234);
+        assert_eq!(builder.display_set(union), "[1, 2, 3, 4]");
+        let union = builder.union(set2, set2457);
+        assert_eq!(builder.display_set(union), "[2, 4, 5, 7]");
+        let union = builder.union(set1234, set2457);
+        assert_eq!(builder.display_set(union), "[1, 2, 3, 4, 5, 7]");
+    }
+
+    // ----
+    // Maps
+
+    impl<I, K, V> ListStorage<I, K, V>
+    where
+        I: Idx,
+        K: Display,
+        V: Display,
+    {
+        fn display(&self, list: Option<I>) -> String {
+            let entries: Vec<_> = self.iter_reverse(list).collect();
+            let mut result = String::new();
+            result.push('[');
+            for (key, value) in entries.into_iter().rev() {
+                if result.len() > 1 {
+                    result.push_str(", ");
+                }
+                write!(&mut result, "{key}:{value}").unwrap();
+            }
+            result.push(']');
+            result
+        }
+    }
+
+    #[test]
+    fn can_insert_into_map() {
+        let mut builder = ListBuilder::<TestIndex, u16, u16>::default();
+
+        // Build up the map in order
+        let map1 = builder.entry(None, 1).replace(1);
+        let map12 = builder.entry(map1, 2).replace(2);
+        let map123 = builder.entry(map12, 3).replace(3);
+        let map1232 = builder.entry(map123, 2).replace(4);
+        assert_eq!(builder.display(None), "[]");
+        assert_eq!(builder.display(map1), "[1:1]");
+        assert_eq!(builder.display(map12), "[1:1, 2:2]");
+        assert_eq!(builder.display(map123), "[1:1, 2:2, 3:3]");
+        assert_eq!(builder.display(map1232), "[1:1, 2:4, 3:3]");
+
+        // And in reverse order
+        let map3 = builder.entry(None, 3).replace(3);
+        let map32 = builder.entry(map3, 2).replace(2);
+        let map321 = builder.entry(map32, 1).replace(1);
+        let map3212 = builder.entry(map321, 2).replace(4);
+        assert_eq!(builder.display(None), "[]");
+        assert_eq!(builder.display(map3), "[3:3]");
+        assert_eq!(builder.display(map32), "[2:2, 3:3]");
+        assert_eq!(builder.display(map321), "[1:1, 2:2, 3:3]");
+        assert_eq!(builder.display(map3212), "[1:1, 2:4, 3:3]");
+    }
+
+    #[test]
+    fn can_insert_if_needed_into_map() {
+        let mut builder = ListBuilder::<TestIndex, u16, u16>::default();
+
+        // Build up the map in order
+        let map1 = builder.entry(None, 1).or_insert(1);
+        let map12 = builder.entry(map1, 2).or_insert(2);
+        let map123 = builder.entry(map12, 3).or_insert(3);
+        let map1232 = builder.entry(map123, 2).or_insert(4);
+        assert_eq!(builder.display(None), "[]");
+        assert_eq!(builder.display(map1), "[1:1]");
+        assert_eq!(builder.display(map12), "[1:1, 2:2]");
+        assert_eq!(builder.display(map123), "[1:1, 2:2, 3:3]");
+        assert_eq!(builder.display(map1232), "[1:1, 2:2, 3:3]");
+
+        // And in reverse order
+        let map3 = builder.entry(None, 3).or_insert(3);
+        let map32 = builder.entry(map3, 2).or_insert(2);
+        let map321 = builder.entry(map32, 1).or_insert(1);
+        let map3212 = builder.entry(map321, 2).or_insert(4);
+        assert_eq!(builder.display(None), "[]");
+        assert_eq!(builder.display(map3), "[3:3]");
+        assert_eq!(builder.display(map32), "[2:2, 3:3]");
+        assert_eq!(builder.display(map321), "[1:1, 2:2, 3:3]");
+        assert_eq!(builder.display(map3212), "[1:1, 2:2, 3:3]");
+    }
+
+    #[test]
+    fn can_intersect_maps() {
+        let mut builder = ListBuilder::<TestIndex, u16, u16>::default();
+
+        let map1 = builder.entry(None, 1).or_insert(1);
+        let map12 = builder.entry(map1, 2).or_insert(2);
+        let map123 = builder.entry(map12, 3).or_insert(3);
+        let map1234 = builder.entry(map123, 4).or_insert(4);
+
+        let map2 = builder.entry(None, 2).or_insert(20);
+        let map24 = builder.entry(map2, 4).or_insert(40);
+        let map245 = builder.entry(map24, 5).or_insert(50);
+        let map2457 = builder.entry(map245, 7).or_insert(70);
+
+        let intersection = builder.intersect_with(None, None, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[]");
+        let intersection = builder.intersect_with(None, map1234, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[]");
+        let intersection = builder.intersect_with(None, map2457, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[]");
+        let intersection = builder.intersect_with(map1, map1234, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[1:2]");
+        let intersection = builder.intersect_with(map1, map2457, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[]");
+        let intersection = builder.intersect_with(map2, map1234, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[2:22]");
+        let intersection = builder.intersect_with(map2, map2457, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[2:40]");
+        let intersection = builder.intersect_with(map1234, map2457, |a, b| a + b);
+        assert_eq!(builder.display(intersection), "[2:22, 4:44]");
+    }
+
+    #[test]
+    fn can_union_maps() {
+        let mut builder = ListBuilder::<TestIndex, u16, u16>::default();
+
+        let map1 = builder.entry(None, 1).or_insert(1);
+        let map12 = builder.entry(map1, 2).or_insert(2);
+        let map123 = builder.entry(map12, 3).or_insert(3);
+        let map1234 = builder.entry(map123, 4).or_insert(4);
+
+        let map2 = builder.entry(None, 2).or_insert(20);
+        let map24 = builder.entry(map2, 4).or_insert(40);
+        let map245 = builder.entry(map24, 5).or_insert(50);
+        let map2457 = builder.entry(map245, 7).or_insert(70);
+
+        let union = builder.union_with(None, None, |a, b| a + b);
+        assert_eq!(builder.display(union), "[]");
+        let union = builder.union_with(None, map1234, |a, b| a + b);
+        assert_eq!(builder.display(union), "[1:1, 2:2, 3:3, 4:4]");
+        let union = builder.union_with(None, map2457, |a, b| a + b);
+        assert_eq!(builder.display(union), "[2:20, 4:40, 5:50, 7:70]");
+        let union = builder.union_with(map1, map1234, |a, b| a + b);
+        assert_eq!(builder.display(union), "[1:2, 2:2, 3:3, 4:4]");
+        let union = builder.union_with(map1, map2457, |a, b| a + b);
+        assert_eq!(builder.display(union), "[1:1, 2:20, 4:40, 5:50, 7:70]");
+        let union = builder.union_with(map2, map1234, |a, b| a + b);
+        assert_eq!(builder.display(union), "[1:1, 2:22, 3:3, 4:4]");
+        let union = builder.union_with(map2, map2457, |a, b| a + b);
+        assert_eq!(builder.display(union), "[2:40, 4:40, 5:50, 7:70]");
+        let union = builder.union_with(map1234, map2457, |a, b| a + b);
+        assert_eq!(builder.display(union), "[1:1, 2:22, 3:3, 4:44, 5:50, 7:70]");
+    }
+}