[red-knot] Optimise visibility constraints for *-import definitions (#17317)

This commit is contained in:
Alex Waygood 2025-04-09 17:53:26 +01:00 committed by GitHub
parent ff376fc262
commit 73399029b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 103 additions and 25 deletions

View file

@ -331,12 +331,15 @@ impl<'db> SemanticIndexBuilder<'db> {
self.current_use_def_map_mut().merge(state);
}
fn add_symbol(&mut self, name: Name) -> ScopedSymbolId {
/// Return a 2-element tuple, where the first element is the [`ScopedSymbolId`] of the
/// symbol added, and the second element is a boolean indicating whether the symbol was *newly*
/// added or not
fn add_symbol(&mut self, name: Name) -> (ScopedSymbolId, bool) {
let (symbol_id, added) = self.current_symbol_table().add_symbol(name);
if added {
self.current_use_def_map_mut().add_symbol(symbol_id);
}
symbol_id
(symbol_id, added)
}
fn mark_symbol_bound(&mut self, id: ScopedSymbolId) {
@ -516,6 +519,7 @@ impl<'db> SemanticIndexBuilder<'db> {
}
/// Records a visibility constraint by applying it to all live bindings and declarations.
#[must_use = "A visibility constraint must always be negated after it is added"]
fn record_visibility_constraint(
&mut self,
predicate: Predicate<'db>,
@ -747,7 +751,7 @@ impl<'db> SemanticIndexBuilder<'db> {
..
}) => (name, &None, default),
};
let symbol = self.add_symbol(name.id.clone());
let (symbol, _) = self.add_symbol(name.id.clone());
// TODO create Definition for PEP 695 typevars
// note that the "bound" on the typevar is a totally different thing than whether
// or not a name is "bound" by a typevar declaration; the latter is always true.
@ -841,20 +845,20 @@ impl<'db> SemanticIndexBuilder<'db> {
self.declare_parameter(parameter);
}
if let Some(vararg) = parameters.vararg.as_ref() {
let symbol = self.add_symbol(vararg.name.id().clone());
let (symbol, _) = self.add_symbol(vararg.name.id().clone());
self.add_definition(
symbol,
DefinitionNodeRef::VariadicPositionalParameter(vararg),
);
}
if let Some(kwarg) = parameters.kwarg.as_ref() {
let symbol = self.add_symbol(kwarg.name.id().clone());
let (symbol, _) = self.add_symbol(kwarg.name.id().clone());
self.add_definition(symbol, DefinitionNodeRef::VariadicKeywordParameter(kwarg));
}
}
fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) {
let symbol = self.add_symbol(parameter.name().id().clone());
let (symbol, _) = self.add_symbol(parameter.name().id().clone());
let definition = self.add_definition(symbol, parameter);
@ -1071,7 +1075,7 @@ where
// The symbol for the function name itself has to be evaluated
// at the end to match the runtime evaluation of parameter defaults
// and return-type annotations.
let symbol = self.add_symbol(name.id.clone());
let (symbol, _) = self.add_symbol(name.id.clone());
self.add_definition(symbol, function_def);
}
ast::Stmt::ClassDef(class) => {
@ -1095,11 +1099,11 @@ where
);
// In Python runtime semantics, a class is registered after its scope is evaluated.
let symbol = self.add_symbol(class.name.id.clone());
let (symbol, _) = self.add_symbol(class.name.id.clone());
self.add_definition(symbol, class);
}
ast::Stmt::TypeAlias(type_alias) => {
let symbol = self.add_symbol(
let (symbol, _) = self.add_symbol(
type_alias
.name
.as_name_expr()
@ -1133,7 +1137,7 @@ where
(Name::new(alias.name.id.split('.').next().unwrap()), false)
};
let symbol = self.add_symbol(symbol_name);
let (symbol, _) = self.add_symbol(symbol_name);
self.add_definition(
symbol,
ImportDefinitionNodeRef {
@ -1200,7 +1204,7 @@ where
//
// For more details, see the doc-comment on `StarImportPlaceholderPredicate`.
for export in exported_names(self.db, referenced_module) {
let symbol_id = self.add_symbol(export.clone());
let (symbol_id, newly_added) = self.add_symbol(export.clone());
let node_ref = StarImportDefinitionNodeRef { node, symbol_id };
let star_import = StarImportPlaceholderPredicate::new(
self.db,
@ -1210,13 +1214,38 @@ where
);
let pre_definition = self.flow_snapshot();
self.push_additional_definition(symbol_id, node_ref);
let constraint_id =
self.record_visibility_constraint(star_import.into());
let post_definition = self.flow_snapshot();
self.flow_restore(pre_definition.clone());
self.record_negated_visibility_constraint(constraint_id);
self.flow_merge(post_definition);
self.simplify_visibility_constraints(pre_definition);
// Fast path for if there were no previous definitions
// of the symbol defined through the `*` import:
// we can apply the visibility constraint to *only* the added definition,
// rather than all definitions
if newly_added {
let constraint_id = self
.current_use_def_map_mut()
.record_star_import_visibility_constraint(
star_import,
symbol_id,
);
let post_definition = self.flow_snapshot();
self.flow_restore(pre_definition);
self.current_use_def_map_mut()
.negate_star_import_visibility_constraint(
symbol_id,
constraint_id,
);
self.flow_merge(post_definition);
} else {
let constraint_id =
self.record_visibility_constraint(star_import.into());
let post_definition = self.flow_snapshot();
self.flow_restore(pre_definition.clone());
self.record_negated_visibility_constraint(constraint_id);
self.flow_merge(post_definition);
self.simplify_visibility_constraints(pre_definition);
}
}
continue;
@ -1236,7 +1265,7 @@ where
self.has_future_annotations |= alias.name.id == "annotations"
&& node.module.as_deref() == Some("__future__");
let symbol = self.add_symbol(symbol_name.clone());
let (symbol, _) = self.add_symbol(symbol_name.clone());
self.add_definition(
symbol,
@ -1636,7 +1665,7 @@ where
// which is invalid syntax. However, it's still pretty obvious here that the user
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
if let Some(symbol_name) = symbol_name {
let symbol = self.add_symbol(symbol_name.id.clone());
let (symbol, _) = self.add_symbol(symbol_name.id.clone());
self.add_definition(
symbol,
@ -1721,7 +1750,7 @@ where
(ast::ExprContext::Del, _) => (false, true),
(ast::ExprContext::Invalid, _) => (false, false),
};
let symbol = self.add_symbol(id.clone());
let (symbol, _) = self.add_symbol(id.clone());
if is_use {
self.mark_symbol_used(symbol);
@ -2007,7 +2036,7 @@ where
range: _,
}) = pattern
{
let symbol = self.add_symbol(name.id().clone());
let (symbol, _) = self.add_symbol(name.id().clone());
let state = self.current_match_case.as_ref().unwrap();
self.add_definition(
symbol,
@ -2028,7 +2057,7 @@ where
rest: Some(name), ..
}) = pattern
{
let symbol = self.add_symbol(name.id().clone());
let (symbol, _) = self.add_symbol(name.id().clone());
let state = self.current_match_case.as_ref().unwrap();
self.add_definition(
symbol,

View file

@ -269,7 +269,7 @@ use crate::semantic_index::narrowing_constraints::{
NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator,
};
use crate::semantic_index::predicate::{
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId,
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate,
};
use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId};
use crate::semantic_index::visibility_constraints::{
@ -603,7 +603,7 @@ pub(super) struct UseDefMapBuilder<'db> {
/// x # we store a reachability constraint of [test] for this use of `x`
///
/// y = 2
///
///
/// # we record a visibility constraint of [test] here, which retroactively affects
/// # the `y = 1` and the `y = 2` binding.
/// else:
@ -701,6 +701,34 @@ impl<'db> UseDefMapBuilder<'db> {
.add_and_constraint(self.scope_start_visibility, constraint);
}
#[must_use = "A `*`-import visibility constraint must always be negated after it is added"]
pub(super) fn record_star_import_visibility_constraint(
&mut self,
star_import: StarImportPlaceholderPredicate<'db>,
symbol: ScopedSymbolId,
) -> StarImportVisibilityConstraintId {
let predicate_id = self.add_predicate(star_import.into());
let visibility_id = self.visibility_constraints.add_atom(predicate_id);
self.symbol_states[symbol]
.record_visibility_constraint(&mut self.visibility_constraints, visibility_id);
StarImportVisibilityConstraintId(visibility_id)
}
pub(super) fn negate_star_import_visibility_constraint(
&mut self,
symbol_id: ScopedSymbolId,
constraint: StarImportVisibilityConstraintId,
) {
let negated_constraint = self
.visibility_constraints
.add_not_constraint(constraint.into_scoped_constraint_id());
self.symbol_states[symbol_id]
.record_visibility_constraint(&mut self.visibility_constraints, negated_constraint);
self.scope_start_visibility = self
.visibility_constraints
.add_and_constraint(self.scope_start_visibility, negated_constraint);
}
/// This method resets the visibility constraints for all symbols to a previous state
/// *if* there have been no new declarations or bindings since then. Consider the
/// following example:
@ -900,3 +928,24 @@ impl<'db> UseDefMapBuilder<'db> {
}
}
}
/// Newtype wrapper over [`ScopedVisibilityConstraintId`] to improve type safety.
///
/// By returning this type from [`UseDefMapBuilder::record_star_import_visibility_constraint`]
/// rather than [`ScopedVisibilityConstraintId`] directly, we ensure that
/// [`UseDefMapBuilder::negate_star_import_visibility_constraint`] must be called after the
/// visibility constraint has been added, and we ensure that
/// [`super::SemanticIndexBuilder::record_negated_visibility_constraint`] *cannot* be called with
/// the narrowing constraint (which would lead to incorrect behaviour).
///
/// This type is defined here rather than in the [`super::visibility_constraints`] module
/// because it should only ever be constructed and deconstructed from methods in the
/// [`UseDefMapBuilder`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) struct StarImportVisibilityConstraintId(ScopedVisibilityConstraintId);
impl StarImportVisibilityConstraintId {
fn into_scoped_constraint_id(self) -> ScopedVisibilityConstraintId {
self.0
}
}