[ty] Garbage-collect reachability constraints (#19414)

This is a follow-on to #19410 that further reduces the memory usage of
our reachability constraints. When we finish building a use-def map, we
walk through all of the "final" states and mark the reachability
constraints that they refer to as "used". We then throw away the interior
TDD nodes of any reachability constraint that wasn't marked as used.

(This helps because we build up quite a few intermediate TDD nodes when
constructing complex reachability constraints. These nodes can never be
accessed again if they were _only_ ever used as intermediate results. The
marking step ensures that we keep any nodes that ended up being referred
to in some accessible use-def map state.)
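
As a rough illustration of the idea, here is a simplified standalone sketch (the types and names below are illustrative, not the actual ty data structures, and the real TDD nodes are ternary rather than binary):

#[derive(Clone, Copy)]
enum NodeId {
    Terminal(bool),
    Interior(usize),
}

#[derive(Clone, Copy)]
struct Interior {
    if_true: NodeId,
    if_false: NodeId,
}

/// Recursively mark every interior node reachable from a root that survives
/// into the final use-def map.
fn mark_used(nodes: &[Interior], used: &mut [bool], id: NodeId) {
    if let NodeId::Interior(index) = id {
        if !used[index] {
            used[index] = true;
            mark_used(nodes, used, nodes[index].if_true);
            mark_used(nodes, used, nodes[index].if_false);
        }
    }
}

fn main() {
    // Node 1 was only ever an intermediate result; nodes 0 and 2 are reachable
    // from the root that the final use-def map state refers to.
    let nodes = [
        Interior { if_true: NodeId::Interior(2), if_false: NodeId::Terminal(false) },
        Interior { if_true: NodeId::Terminal(true), if_false: NodeId::Terminal(false) },
        Interior { if_true: NodeId::Terminal(true), if_false: NodeId::Terminal(false) },
    ];
    let mut used = vec![false; nodes.len()];
    mark_used(&nodes, &mut used, NodeId::Interior(0));
    assert_eq!(used, [true, false, true]);

    // "Throw away" the unmarked nodes by keeping only the marked ones.
    let kept: Vec<Interior> = nodes
        .iter()
        .zip(&used)
        .filter(|(_, used)| **used)
        .map(|(node, _)| *node)
        .collect();
    assert_eq!(kept.len(), 2);
}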
Douglas Creager 2025-07-21 14:16:27 -04:00 committed by GitHub
parent b8dec79182
commit 88de5727df
9 changed files with 252 additions and 10 deletions


@@ -34,6 +34,7 @@ mod node_key;
pub(crate) mod place;
mod program;
mod python_platform;
mod rank;
pub mod semantic_index;
mod semantic_model;
pub(crate) mod site_packages;


@@ -0,0 +1,83 @@
//! A boxed bit slice that supports a constant-time `rank` operation.

use bitvec::prelude::{BitBox, Msb0};
use get_size2::GetSize;

/// A boxed bit slice that supports a constant-time `rank` operation.
///
/// This can be used to "shrink" a large vector, where you only need to keep certain elements, and
/// you want to continue to use the index in the large vector to identify each element.
///
/// First you create a new smaller vector, keeping only the elements of the large vector that you
/// care about. Now you need a way to translate an index into the large vector (which no longer
/// exists) into the corresponding index into the smaller vector. To do that, you create a bit
/// slice, containing a bit for every element of the original large vector. Each bit in the bit
/// slice indicates whether that element of the large vector was kept in the smaller vector. And
/// the `rank` of the bit gives us the index of the element in the smaller vector.
///
/// However, the naive implementation of `rank` is O(n) in the size of the bit slice. To address
/// that, we use a standard trick: we divide the bit slice into 64-bit chunks, and when
/// constructing the bit slice, precalculate the rank of the first bit in each chunk. Then, to
/// calculate the rank of an arbitrary bit, we first grab the precalculated rank of the chunk that
/// bit belongs to, and add the rank of the bit within its (fixed-sized) chunk.
///
/// This trick adds O(1.5) bits of overhead per large vector element on 64-bit platforms, and O(2)
/// bits of overhead on 32-bit platforms.
#[derive(Clone, Debug, Eq, PartialEq, GetSize)]
pub(crate) struct RankBitBox {
    #[get_size(size_fn = bit_box_size)]
    bits: BitBox<Chunk, Msb0>,
    chunk_ranks: Box<[u32]>,
}

fn bit_box_size(bits: &BitBox<Chunk, Msb0>) -> usize {
    bits.as_raw_slice().get_heap_size()
}

// bitvec does not support `u64` as a Store type on 32-bit platforms
#[cfg(target_pointer_width = "64")]
type Chunk = u64;
#[cfg(not(target_pointer_width = "64"))]
type Chunk = u32;

const CHUNK_SIZE: usize = Chunk::BITS as usize;

impl RankBitBox {
    pub(crate) fn from_bits(iter: impl Iterator<Item = bool>) -> Self {
        let bits: BitBox<Chunk, Msb0> = iter.collect();
        let chunk_ranks = bits
            .as_raw_slice()
            .iter()
            .scan(0u32, |rank, chunk| {
                let result = *rank;
                *rank += chunk.count_ones();
                Some(result)
            })
            .collect();
        Self { bits, chunk_ranks }
    }

    #[inline]
    pub(crate) fn get_bit(&self, index: usize) -> Option<bool> {
        self.bits.get(index).map(|bit| *bit)
    }

    /// Returns the number of bits _before_ (and not including) the given index that are set.
    #[inline]
    pub(crate) fn rank(&self, index: usize) -> u32 {
        let chunk_index = index / CHUNK_SIZE;
        let index_within_chunk = index % CHUNK_SIZE;
        let chunk_rank = self.chunk_ranks[chunk_index];
        if index_within_chunk == 0 {
            return chunk_rank;
        }
        // To calculate the rank within the bit's chunk, we zero out the requested bit and every
        // bit to the right, then count the number of 1s remaining (i.e., to the left of the
        // requested bit).
        let chunk = self.bits.as_raw_slice()[chunk_index];
        let chunk_mask = Chunk::MAX << (CHUNK_SIZE - index_within_chunk);
        let rank_within_chunk = (chunk & chunk_mask).count_ones();
        chunk_rank + rank_within_chunk
    }
}
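
For intuition, here is roughly how the new type behaves on a small input. This is a hypothetical unit test, not part of the commit; it assumes it lives in the same crate, since `RankBitBox` is `pub(crate)`:

#[test]
fn rank_translates_large_indices_to_small_ones() {
    // Suppose we kept elements 0, 3, and 4 of an original 6-element vector.
    let kept = [true, false, false, true, true, false];
    let bits = RankBitBox::from_bits(kept.iter().copied());

    assert_eq!(bits.get_bit(3), Some(true));

    // `rank(i)` counts the set bits before index `i`, which is exactly the
    // index of element `i` in the smaller vector.
    assert_eq!(bits.rank(0), 0); // element 0 lands at index 0
    assert_eq!(bits.rank(3), 1); // only element 0 precedes element 3
    assert_eq!(bits.rank(4), 2); // elements 0 and 3 precede element 4
}

The overhead figure in the doc comment follows from storing one bit per element plus one precalculated `u32` rank per chunk: 1 + 32/64 = 1.5 bits per element with 64-bit chunks, and 1 + 32/32 = 2 bits with 32-bit chunks.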


@@ -1021,6 +1021,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
        assert_eq!(&self.current_assignments, &[]);

        for scope in &self.scopes {
            if let Some(parent) = scope.parent() {
                self.use_def_maps[parent]
                    .reachability_constraints
                    .mark_used(scope.reachability());
            }
        }

        let mut place_tables: IndexVec<_, _> = self
            .place_tables
            .into_iter()


@@ -201,6 +201,7 @@ use rustc_hash::FxHashMap;
use crate::Db;
use crate::dunder_all::dunder_all_names;
use crate::place::{RequiresExplicitReExport, imported_symbol};
use crate::rank::RankBitBox;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::place_table;
use crate::semantic_index::predicate::{
@@ -283,6 +284,10 @@ impl ScopedReachabilityConstraintId {
    fn is_terminal(self) -> bool {
        self.0 >= SMALLEST_TERMINAL.0
    }

    fn as_u32(self) -> u32 {
        self.0
    }
}

impl Idx for ScopedReachabilityConstraintId {
@@ -309,12 +314,18 @@ const SMALLEST_TERMINAL: ScopedReachabilityConstraintId = ALWAYS_FALSE;
/// A collection of reachability constraints for a given scope.
#[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) struct ReachabilityConstraints {
    interiors: IndexVec<ScopedReachabilityConstraintId, InteriorNode>,
    /// The interior TDD nodes that were marked as used when being built.
    used_interiors: Box<[InteriorNode]>,
    /// A bit vector indicating which interior TDD nodes were marked as used. This is indexed by
    /// the node's [`ScopedReachabilityConstraintId`]. The rank of the corresponding bit gives the
    /// index of that node in the `used_interiors` vector.
    used_indices: RankBitBox,
}

#[derive(Debug, Default, PartialEq, Eq)]
pub(crate) struct ReachabilityConstraintsBuilder {
    interiors: IndexVec<ScopedReachabilityConstraintId, InteriorNode>,
    interior_used: IndexVec<ScopedReachabilityConstraintId, bool>,
    interior_cache: FxHashMap<InteriorNode, ScopedReachabilityConstraintId>,
    not_cache: FxHashMap<ScopedReachabilityConstraintId, ScopedReachabilityConstraintId>,
    and_cache: FxHashMap<
@@ -334,11 +345,28 @@ pub(crate) struct ReachabilityConstraintsBuilder {
}

impl ReachabilityConstraintsBuilder {
    pub(crate) fn build(mut self) -> ReachabilityConstraints {
        self.interiors.shrink_to_fit();
    pub(crate) fn build(self) -> ReachabilityConstraints {
        let used_indices = RankBitBox::from_bits(self.interior_used.iter().copied());
        let used_interiors = (self.interiors.into_iter())
            .zip(self.interior_used)
            .filter_map(|(interior, used)| used.then_some(interior))
            .collect();
        ReachabilityConstraints {
            interiors: self.interiors,
            used_interiors,
            used_indices,
        }
    }

    /// Marks that a particular TDD node is used. This lets us throw away interior nodes that were
    /// only calculated for intermediate values, and which don't need to be included in the final
    /// built result.
    pub(crate) fn mark_used(&mut self, node: ScopedReachabilityConstraintId) {
        if !node.is_terminal() && !self.interior_used[node] {
            self.interior_used[node] = true;
            let node = self.interiors[node];
            self.mark_used(node.if_true);
            self.mark_used(node.if_ambiguous);
            self.mark_used(node.if_false);
        }
    }
@@ -370,10 +398,10 @@ impl ReachabilityConstraintsBuilder {
            return node.if_true;
        }

        *self
            .interior_cache
            .entry(node)
            .or_insert_with(|| self.interiors.push(node))
        *self.interior_cache.entry(node).or_insert_with(|| {
            self.interior_used.push(false);
            self.interiors.push(node)
        })
    }

    /// Adds a new reachability constraint that checks a single [`Predicate`].
@@ -581,7 +609,21 @@ impl ReachabilityConstraints {
            ALWAYS_TRUE => return Truthiness::AlwaysTrue,
            AMBIGUOUS => return Truthiness::Ambiguous,
            ALWAYS_FALSE => return Truthiness::AlwaysFalse,
            _ => self.interiors[id],
            _ => {
                // `id` gives us the index of this node in the IndexVec that we used when
                // constructing this BDD. When finalizing the builder, we threw away any
                // interior nodes that weren't marked as used. The `used_indices` bit vector
                // lets us verify that this node was marked as used, and the rank of that bit
                // in the bit vector tells us where this node lives in the "condensed"
                // `used_interiors` vector.
                let raw_index = id.as_u32() as usize;
                debug_assert!(
                    self.used_indices.get_bit(raw_index).unwrap_or(false),
                    "all used reachability constraints should have been marked as used",
                );
                let index = self.used_indices.rank(raw_index) as usize;
                self.used_interiors[index]
            }
        };
        let predicate = &predicates[node.atom];
        match Self::analyze_single(db, predicate) {


@@ -1118,7 +1118,41 @@ impl<'db> UseDefMapBuilder<'db> {
            .add_or_constraint(self.reachability, snapshot.reachability);
    }

    fn mark_reachability_constraints(&mut self) {
        // We only walk the fields that are copied through to the UseDefMap when we finish building
        // it.
        for bindings in &mut self.bindings_by_use {
            bindings.finish(&mut self.reachability_constraints);
        }
        for constraint in self.node_reachability.values() {
            self.reachability_constraints.mark_used(*constraint);
        }
        for place_state in &mut self.place_states {
            place_state.finish(&mut self.reachability_constraints);
        }
        for reachable_definition in &mut self.reachable_definitions {
            reachable_definition
                .bindings
                .finish(&mut self.reachability_constraints);
            reachable_definition
                .declarations
                .finish(&mut self.reachability_constraints);
        }
        for declarations in self.declarations_by_binding.values_mut() {
            declarations.finish(&mut self.reachability_constraints);
        }
        for bindings in self.bindings_by_definition.values_mut() {
            bindings.finish(&mut self.reachability_constraints);
        }
        for eager_snapshot in &mut self.eager_snapshots {
            eager_snapshot.finish(&mut self.reachability_constraints);
        }
        self.reachability_constraints.mark_used(self.reachability);
    }

    pub(super) fn finish(mut self) -> UseDefMap<'db> {
        self.mark_reachability_constraints();

        self.all_definitions.shrink_to_fit();
        self.place_states.shrink_to_fit();
        self.reachable_definitions.shrink_to_fit();


@@ -172,6 +172,13 @@ impl Declarations {
            }
        }
    }

    pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) {
        self.live_declarations.shrink_to_fit();
        for declaration in &self.live_declarations {
            reachability_constraints.mark_used(declaration.reachability_constraint);
        }
    }
}

/// A snapshot of a place state that can be used to resolve a reference in a nested eager scope.
@@ -185,6 +192,17 @@ pub(super) enum EagerSnapshot {
    Bindings(Bindings),
}

impl EagerSnapshot {
    pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) {
        match self {
            EagerSnapshot::Constraint(_) => {}
            EagerSnapshot::Bindings(bindings) => {
                bindings.finish(reachability_constraints);
            }
        }
    }
}

/// Live bindings for a single place at some point in control flow. Each live binding comes
/// with a set of narrowing constraints and a reachability constraint.
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
@@ -203,6 +221,13 @@ impl Bindings {
        self.unbound_narrowing_constraint
            .unwrap_or(self.live_bindings[0].narrowing_constraint)
    }

    pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) {
        self.live_bindings.shrink_to_fit();
        for binding in &self.live_bindings {
            reachability_constraints.mark_used(binding.reachability_constraint);
        }
    }
}

/// One of the live bindings for a single place at some point in control flow.
@@ -422,6 +447,11 @@ impl PlaceState {
    pub(super) fn declarations(&self) -> &Declarations {
        &self.declarations
    }

    pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) {
        self.declarations.finish(reachability_constraints);
        self.bindings.finish(reachability_constraints);
    }
}

#[cfg(test)]