[red-knot] fix narrowing in nested scopes (#17630)

## Summary

This PR fixes #17595.

## Test Plan

New test cases are added to `mdtest/narrow/conditionals/nested.md`.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Shunsuke Shibayama 2025-05-06 08:28:42 +09:00 committed by GitHub
parent a4c8e43c5f
commit fd76d70a31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 414 additions and 124 deletions

View file

@ -42,7 +42,7 @@ use crate::semantic_index::symbol::{
ScopedSymbolId, SymbolTableBuilder,
};
use crate::semantic_index::use_def::{
EagerBindingsKey, FlowSnapshot, ScopedEagerBindingsId, UseDefMapBuilder,
EagerSnapshotKey, FlowSnapshot, ScopedEagerSnapshotId, UseDefMapBuilder,
};
use crate::semantic_index::visibility_constraints::{
ScopedVisibilityConstraintId, VisibilityConstraintsBuilder,
@ -113,7 +113,7 @@ pub(super) struct SemanticIndexBuilder<'db> {
///
/// [generator functions]: https://docs.python.org/3/glossary.html#term-generator
generator_functions: FxHashSet<FileScopeId>,
eager_bindings: FxHashMap<EagerBindingsKey, ScopedEagerBindingsId>,
eager_snapshots: FxHashMap<EagerSnapshotKey, ScopedEagerSnapshotId>,
/// Errors collected by the `semantic_checker`.
semantic_syntax_errors: RefCell<Vec<SemanticSyntaxError>>,
}
@ -148,7 +148,7 @@ impl<'db> SemanticIndexBuilder<'db> {
imported_modules: FxHashSet::default(),
generator_functions: FxHashSet::default(),
eager_bindings: FxHashMap::default(),
eager_snapshots: FxHashMap::default(),
python_version: Program::get(db).python_version(db),
source_text: OnceCell::new(),
@ -253,13 +253,15 @@ impl<'db> SemanticIndexBuilder<'db> {
children_start..children_start,
reachability,
);
let is_class_scope = scope.kind().is_class();
self.try_node_context_stack_manager.enter_nested_scope();
let file_scope_id = self.scopes.push(scope);
self.symbol_tables.push(SymbolTableBuilder::default());
self.instance_attribute_tables
.push(SymbolTableBuilder::default());
self.use_def_maps.push(UseDefMapBuilder::default());
self.use_def_maps
.push(UseDefMapBuilder::new(is_class_scope));
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::default());
let scope_id = ScopeId::new(self.db, self.file, file_scope_id, countme::Count::default());
@ -303,12 +305,6 @@ impl<'db> SemanticIndexBuilder<'db> {
let enclosing_scope_kind = self.scopes[enclosing_scope_id].kind();
let enclosing_symbol_table = &self.symbol_tables[enclosing_scope_id];
// Names bound in class scopes are never visible to nested scopes, so we never need to
// save eager scope bindings in a class scope.
if enclosing_scope_kind.is_class() {
continue;
}
for nested_symbol in self.symbol_tables[popped_scope_id].symbols() {
// Skip this symbol if this enclosing scope doesn't contain any bindings for it.
// Note that even if this symbol is bound in the popped scope,
@ -321,24 +317,26 @@ impl<'db> SemanticIndexBuilder<'db> {
continue;
};
let enclosing_symbol = enclosing_symbol_table.symbol(enclosing_symbol_id);
if !enclosing_symbol.is_bound() {
continue;
}
// Snapshot the bindings of this symbol that are visible at this point in this
// Snapshot the state of this symbol that are visible at this point in this
// enclosing scope.
let key = EagerBindingsKey {
let key = EagerSnapshotKey {
enclosing_scope: enclosing_scope_id,
enclosing_symbol: enclosing_symbol_id,
nested_scope: popped_scope_id,
};
let eager_bindings = self.use_def_maps[enclosing_scope_id]
.snapshot_eager_bindings(enclosing_symbol_id);
self.eager_bindings.insert(key, eager_bindings);
let eager_snapshot = self.use_def_maps[enclosing_scope_id].snapshot_eager_state(
enclosing_symbol_id,
enclosing_scope_kind,
enclosing_symbol.is_bound(),
);
self.eager_snapshots.insert(key, eager_snapshot);
}
// Lazy scopes are "sticky": once we see a lazy scope we stop doing lookups
// eagerly, even if we would encounter another eager enclosing scope later on.
// Also, narrowing constraints outside a lazy scope are not applicable.
// TODO: If the symbol has never been rewritten, they are applicable.
if !enclosing_scope_kind.is_eager() {
break;
}
@ -1085,8 +1083,8 @@ impl<'db> SemanticIndexBuilder<'db> {
self.scope_ids_by_scope.shrink_to_fit();
self.scopes_by_node.shrink_to_fit();
self.eager_bindings.shrink_to_fit();
self.generator_functions.shrink_to_fit();
self.eager_snapshots.shrink_to_fit();
SemanticIndex {
symbol_tables,
@ -1101,7 +1099,7 @@ impl<'db> SemanticIndexBuilder<'db> {
use_def_maps,
imported_modules: Arc::new(self.imported_modules),
has_future_annotations: self.has_future_annotations,
eager_bindings: self.eager_bindings,
eager_snapshots: self.eager_snapshots,
semantic_syntax_errors: self.semantic_syntax_errors.into_inner(),
generator_functions: self.generator_functions,
}

View file

@ -29,6 +29,7 @@
//! [`Predicate`]: crate::semantic_index::predicate::Predicate
use crate::list::{List, ListBuilder, ListSetReverseIterator, ListStorage};
use crate::semantic_index::ast_ids::ScopedUseId;
use crate::semantic_index::predicate::ScopedPredicateId;
/// A narrowing constraint associated with a live binding.
@ -38,6 +39,12 @@ use crate::semantic_index::predicate::ScopedPredicateId;
/// [`Predicate`]: crate::semantic_index::predicate::Predicate
pub(crate) type ScopedNarrowingConstraint = List<ScopedNarrowingConstraintPredicate>;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ConstraintKey {
NarrowingConstraint(ScopedNarrowingConstraint),
UseId(ScopedUseId),
}
/// One of the [`Predicate`]s in a narrowing constraint, which constraints the type of the
/// binding's symbol.
///

View file

@ -259,25 +259,25 @@
use ruff_index::{newtype_index, IndexVec};
use rustc_hash::FxHashMap;
use self::symbol_state::ScopedDefinitionId;
use self::symbol_state::{
LiveBindingsIterator, LiveDeclaration, LiveDeclarationsIterator, SymbolBindings,
SymbolDeclarations, SymbolState,
EagerSnapshot, LiveBindingsIterator, LiveDeclaration, LiveDeclarationsIterator,
ScopedDefinitionId, SymbolBindings, SymbolDeclarations, SymbolState,
};
use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::ScopedUseId;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::narrowing_constraints::{
NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator,
ConstraintKey, NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator,
};
use crate::semantic_index::predicate::{
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate,
};
use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId};
use crate::semantic_index::symbol::{FileScopeId, ScopeKind, ScopedSymbolId};
use crate::semantic_index::visibility_constraints::{
ScopedVisibilityConstraintId, VisibilityConstraints, VisibilityConstraintsBuilder,
};
use crate::types::Truthiness;
use crate::semantic_index::EagerSnapshotResult;
use crate::types::{infer_narrowing_constraint, IntersectionBuilder, Truthiness, Type};
mod symbol_state;
@ -328,7 +328,7 @@ pub(crate) struct UseDefMap<'db> {
/// Snapshot of bindings in this scope that can be used to resolve a reference in a nested
/// eager scope.
eager_bindings: EagerBindings,
eager_snapshots: EagerSnapshots,
/// Whether or not the start of the scope is visible.
/// This is used to check if the function can implicitly return `None`.
@ -354,6 +354,22 @@ impl<'db> UseDefMap<'db> {
self.bindings_iterator(&self.bindings_by_use[use_id])
}
pub(crate) fn narrowing_constraints_at_use(
&self,
constraint_key: ConstraintKey,
) -> ConstraintsIterator<'_, 'db> {
let constraint = match constraint_key {
ConstraintKey::NarrowingConstraint(constraint) => constraint,
ConstraintKey::UseId(use_id) => {
self.bindings_by_use[use_id].unbound_narrowing_constraint()
}
};
ConstraintsIterator {
predicates: &self.predicates,
constraint_ids: self.narrowing_constraints.iter_predicates(constraint),
}
}
pub(super) fn is_reachable(
&self,
db: &dyn crate::Db,
@ -398,13 +414,19 @@ impl<'db> UseDefMap<'db> {
self.bindings_iterator(self.instance_attributes[symbol].bindings())
}
pub(crate) fn eager_bindings(
pub(crate) fn eager_snapshot(
&self,
eager_bindings: ScopedEagerBindingsId,
) -> Option<BindingWithConstraintsIterator<'_, 'db>> {
self.eager_bindings
.get(eager_bindings)
.map(|symbol_bindings| self.bindings_iterator(symbol_bindings))
eager_bindings: ScopedEagerSnapshotId,
) -> EagerSnapshotResult<'_, 'db> {
match self.eager_snapshots.get(eager_bindings) {
Some(EagerSnapshot::Constraint(constraint)) => {
EagerSnapshotResult::FoundConstraint(*constraint)
}
Some(EagerSnapshot::Bindings(symbol_bindings)) => {
EagerSnapshotResult::FoundBindings(self.bindings_iterator(symbol_bindings))
}
None => EagerSnapshotResult::NotFound,
}
}
pub(crate) fn bindings_at_declaration(
@ -489,19 +511,19 @@ impl<'db> UseDefMap<'db> {
}
}
/// Uniquely identifies a snapshot of bindings that can be used to resolve a reference in a nested
/// eager scope.
/// Uniquely identifies a snapshot of a symbol state that can be used to resolve a reference in a
/// nested eager scope.
///
/// An eager scope has its entire body executed immediately at the location where it is defined.
/// For any free references in the nested scope, we use the bindings that are visible at the point
/// where the nested scope is defined, instead of using the public type of the symbol.
///
/// There is a unique ID for each distinct [`EagerBindingsKey`] in the file.
/// There is a unique ID for each distinct [`EagerSnapshotKey`] in the file.
#[newtype_index]
pub(crate) struct ScopedEagerBindingsId;
pub(crate) struct ScopedEagerSnapshotId;
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) struct EagerBindingsKey {
pub(crate) struct EagerSnapshotKey {
/// The enclosing scope containing the bindings
pub(crate) enclosing_scope: FileScopeId,
/// The referenced symbol (in the enclosing scope)
@ -510,8 +532,8 @@ pub(crate) struct EagerBindingsKey {
pub(crate) nested_scope: FileScopeId,
}
/// A snapshot of bindings that can be used to resolve a reference in a nested eager scope.
type EagerBindings = IndexVec<ScopedEagerBindingsId, SymbolBindings>;
/// A snapshot of symbol states that can be used to resolve a reference in a nested eager scope.
type EagerSnapshots = IndexVec<ScopedEagerSnapshotId, EagerSnapshot>;
#[derive(Debug)]
pub(crate) struct BindingWithConstraintsIterator<'map, 'db> {
@ -568,6 +590,33 @@ impl<'db> Iterator for ConstraintsIterator<'_, 'db> {
impl std::iter::FusedIterator for ConstraintsIterator<'_, '_> {}
impl<'db> ConstraintsIterator<'_, 'db> {
pub(crate) fn narrow(
self,
db: &'db dyn crate::Db,
base_ty: Type<'db>,
symbol: ScopedSymbolId,
) -> Type<'db> {
let constraint_tys: Vec<_> = self
.filter_map(|constraint| infer_narrowing_constraint(db, constraint, symbol))
.collect();
if constraint_tys.is_empty() {
base_ty
} else {
let intersection_ty = constraint_tys
.into_iter()
.rev()
.fold(
IntersectionBuilder::new(db).add_positive(base_ty),
IntersectionBuilder::add_positive,
)
.build();
intersection_ty
}
}
}
#[derive(Clone)]
pub(crate) struct DeclarationsIterator<'map, 'db> {
all_definitions: &'map IndexVec<ScopedDefinitionId, Option<Definition<'db>>>,
@ -688,13 +737,16 @@ pub(super) struct UseDefMapBuilder<'db> {
/// Currently live bindings for each instance attribute.
instance_attribute_states: IndexVec<ScopedSymbolId, SymbolState>,
/// Snapshot of bindings in this scope that can be used to resolve a reference in a nested
/// eager scope.
eager_bindings: EagerBindings,
/// Snapshots of symbol states in this scope that can be used to resolve a reference in a
/// nested eager scope.
eager_snapshots: EagerSnapshots,
/// Is this a class scope?
is_class_scope: bool,
}
impl Default for UseDefMapBuilder<'_> {
fn default() -> Self {
impl<'db> UseDefMapBuilder<'db> {
pub(super) fn new(is_class_scope: bool) -> Self {
Self {
all_definitions: IndexVec::from_iter([None]),
predicates: PredicatesBuilder::default(),
@ -707,13 +759,11 @@ impl Default for UseDefMapBuilder<'_> {
declarations_by_binding: FxHashMap::default(),
bindings_by_declaration: FxHashMap::default(),
symbol_states: IndexVec::new(),
eager_bindings: EagerBindings::default(),
eager_snapshots: EagerSnapshots::default(),
instance_attribute_states: IndexVec::new(),
is_class_scope,
}
}
}
impl<'db> UseDefMapBuilder<'db> {
pub(super) fn mark_unreachable(&mut self) {
self.record_visibility_constraint(ScopedVisibilityConstraintId::ALWAYS_FALSE);
self.reachability = ScopedVisibilityConstraintId::ALWAYS_FALSE;
@ -738,7 +788,7 @@ impl<'db> UseDefMapBuilder<'db> {
let symbol_state = &mut self.symbol_states[symbol];
self.declarations_by_binding
.insert(binding, symbol_state.declarations().clone());
symbol_state.record_binding(def_id, self.scope_start_visibility);
symbol_state.record_binding(def_id, self.scope_start_visibility, self.is_class_scope);
}
pub(super) fn record_attribute_binding(
@ -750,7 +800,7 @@ impl<'db> UseDefMapBuilder<'db> {
let attribute_state = &mut self.instance_attribute_states[symbol];
self.declarations_by_binding
.insert(binding, attribute_state.declarations().clone());
attribute_state.record_binding(def_id, self.scope_start_visibility);
attribute_state.record_binding(def_id, self.scope_start_visibility, self.is_class_scope);
}
pub(super) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
@ -936,7 +986,7 @@ impl<'db> UseDefMapBuilder<'db> {
let def_id = self.all_definitions.push(Some(definition));
let symbol_state = &mut self.symbol_states[symbol];
symbol_state.record_declaration(def_id);
symbol_state.record_binding(def_id, self.scope_start_visibility);
symbol_state.record_binding(def_id, self.scope_start_visibility, self.is_class_scope);
}
pub(super) fn record_use(
@ -961,12 +1011,25 @@ impl<'db> UseDefMapBuilder<'db> {
self.node_reachability.insert(node_key, self.reachability);
}
pub(super) fn snapshot_eager_bindings(
pub(super) fn snapshot_eager_state(
&mut self,
enclosing_symbol: ScopedSymbolId,
) -> ScopedEagerBindingsId {
self.eager_bindings
.push(self.symbol_states[enclosing_symbol].bindings().clone())
scope: ScopeKind,
is_bound: bool,
) -> ScopedEagerSnapshotId {
// Names bound in class scopes are never visible to nested scopes, so we never need to
// save eager scope bindings in a class scope.
if scope.is_class() || !is_bound {
self.eager_snapshots.push(EagerSnapshot::Constraint(
self.symbol_states[enclosing_symbol]
.bindings()
.unbound_narrowing_constraint(),
))
} else {
self.eager_snapshots.push(EagerSnapshot::Bindings(
self.symbol_states[enclosing_symbol].bindings().clone(),
))
}
}
/// Take a snapshot of the current visible-symbols state.
@ -1086,7 +1149,7 @@ impl<'db> UseDefMapBuilder<'db> {
self.node_reachability.shrink_to_fit();
self.declarations_by_binding.shrink_to_fit();
self.bindings_by_declaration.shrink_to_fit();
self.eager_bindings.shrink_to_fit();
self.eager_snapshots.shrink_to_fit();
UseDefMap {
all_definitions: self.all_definitions,
@ -1099,7 +1162,7 @@ impl<'db> UseDefMapBuilder<'db> {
instance_attributes: self.instance_attribute_states,
declarations_by_binding: self.declarations_by_binding,
bindings_by_declaration: self.bindings_by_declaration,
eager_bindings: self.eager_bindings,
eager_snapshots: self.eager_snapshots,
scope_start_visibility: self.scope_start_visibility,
}
}

View file

@ -65,6 +65,10 @@ impl ScopedDefinitionId {
/// When creating a use-def-map builder, we always add an empty `None` definition
/// at index 0, so this ID is always present.
pub(super) const UNBOUND: ScopedDefinitionId = ScopedDefinitionId::from_u32(0);
fn is_unbound(self) -> bool {
self == Self::UNBOUND
}
}
/// Can keep inline this many live bindings or declarations per symbol at a given time; more will
@ -177,14 +181,41 @@ impl SymbolDeclarations {
}
}
/// A snapshot of a symbol state that can be used to resolve a reference in a nested eager scope.
/// If there are bindings in a (non-class) scope , they are stored in `Bindings`.
/// Even if it's a class scope (class variables are not visible to nested scopes) or there are no
/// bindings, the current narrowing constraint is necessary for narrowing, so it's stored in
/// `Constraint`.
#[derive(Clone, Debug, PartialEq, Eq, salsa::Update)]
pub(super) enum EagerSnapshot {
Constraint(ScopedNarrowingConstraint),
Bindings(SymbolBindings),
}
/// Live bindings for a single symbol at some point in control flow. Each live binding comes
/// with a set of narrowing constraints and a visibility constraint.
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update)]
pub(super) struct SymbolBindings {
/// The narrowing constraint applicable to the "unbound" binding, if we need access to it even
/// when it's not visible. This happens in class scopes, where local bindings are not visible
/// to nested scopes, but we still need to know what narrowing constraints were applied to the
/// "unbound" binding.
unbound_narrowing_constraint: Option<ScopedNarrowingConstraint>,
/// A list of live bindings for this symbol, sorted by their `ScopedDefinitionId`
live_bindings: SmallVec<[LiveBinding; INLINE_DEFINITIONS_PER_SYMBOL]>,
}
impl SymbolBindings {
pub(super) fn unbound_narrowing_constraint(&self) -> ScopedNarrowingConstraint {
debug_assert!(
self.unbound_narrowing_constraint.is_some()
|| self.live_bindings[0].binding.is_unbound()
);
self.unbound_narrowing_constraint
.unwrap_or(self.live_bindings[0].narrowing_constraint)
}
}
/// One of the live bindings for a single symbol at some point in control flow.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct LiveBinding {
@ -203,6 +234,7 @@ impl SymbolBindings {
visibility_constraint: scope_start_visibility,
};
Self {
unbound_narrowing_constraint: None,
live_bindings: smallvec![initial_binding],
}
}
@ -212,7 +244,13 @@ impl SymbolBindings {
&mut self,
binding: ScopedDefinitionId,
visibility_constraint: ScopedVisibilityConstraintId,
is_class_scope: bool,
) {
// If we are in a class scope, and the unbound binding was previously visible, but we will
// now replace it, record the narrowing constraints on it:
if is_class_scope && self.live_bindings[0].binding.is_unbound() {
self.unbound_narrowing_constraint = Some(self.live_bindings[0].narrowing_constraint);
}
// The new binding replaces all previous live bindings in this path, and has no
// constraints.
self.live_bindings.clear();
@ -278,6 +316,14 @@ impl SymbolBindings {
) {
let a = std::mem::take(self);
if let Some((a, b)) = a
.unbound_narrowing_constraint
.zip(b.unbound_narrowing_constraint)
{
self.unbound_narrowing_constraint =
Some(narrowing_constraints.intersect_constraints(a, b));
}
// Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that
// the merged `live_bindings` vec remains sorted. If a definition is found in both `a` and
// `b`, we compose the constraints from the two paths in an appropriate way (intersection
@ -333,10 +379,11 @@ impl SymbolState {
&mut self,
binding_id: ScopedDefinitionId,
visibility_constraint: ScopedVisibilityConstraintId,
is_class_scope: bool,
) {
debug_assert_ne!(binding_id, ScopedDefinitionId::UNBOUND);
self.bindings
.record_binding(binding_id, visibility_constraint);
.record_binding(binding_id, visibility_constraint, is_class_scope);
}
/// Add given constraint to all live bindings.
@ -467,6 +514,7 @@ mod tests {
sym.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
false,
);
assert_bindings(&narrowing_constraints, &sym, &["1<>"]);
@ -479,6 +527,7 @@ mod tests {
sym.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
false,
);
let predicate = ScopedPredicateId::from_u32(0).into();
sym.record_narrowing_constraint(&mut narrowing_constraints, predicate);
@ -496,6 +545,7 @@ mod tests {
sym1a.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
false,
);
let predicate = ScopedPredicateId::from_u32(0).into();
sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
@ -504,6 +554,7 @@ mod tests {
sym1b.record_binding(
ScopedDefinitionId::from_u32(1),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
false,
);
let predicate = ScopedPredicateId::from_u32(0).into();
sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
@ -521,6 +572,7 @@ mod tests {
sym2a.record_binding(
ScopedDefinitionId::from_u32(2),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
false,
);
let predicate = ScopedPredicateId::from_u32(1).into();
sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
@ -529,6 +581,7 @@ mod tests {
sym1b.record_binding(
ScopedDefinitionId::from_u32(2),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
false,
);
let predicate = ScopedPredicateId::from_u32(2).into();
sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
@ -546,6 +599,7 @@ mod tests {
sym3a.record_binding(
ScopedDefinitionId::from_u32(3),
ScopedVisibilityConstraintId::ALWAYS_TRUE,
false,
);
let predicate = ScopedPredicateId::from_u32(3).into();
sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate);