[red-knot] Prevent cross-module query dependencies in own_instance_member (#16268)

This commit is contained in:
Micha Reiser 2025-02-20 17:46:45 +00:00 committed by GitHub
parent b385c7d22a
commit 470f852f04
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 150 additions and 61 deletions

View file

@ -346,12 +346,14 @@ impl<'db> SemanticIndexBuilder<'db> {
// SAFETY: `definition_node` is guaranteed to be a child of `self.module` // SAFETY: `definition_node` is guaranteed to be a child of `self.module`
let kind = unsafe { definition_node.into_owned(self.module.clone()) }; let kind = unsafe { definition_node.into_owned(self.module.clone()) };
let category = kind.category(); let category = kind.category();
let is_reexported = kind.is_reexported();
let definition = Definition::new( let definition = Definition::new(
self.db, self.db,
self.file, self.file,
self.current_scope(), self.current_scope(),
symbol, symbol,
kind, kind,
is_reexported,
countme::Count::default(), countme::Count::default(),
); );

View file

@ -33,11 +33,16 @@ pub struct Definition<'db> {
/// The symbol defined. /// The symbol defined.
pub(crate) symbol: ScopedSymbolId, pub(crate) symbol: ScopedSymbolId,
/// WARNING: Only access this field when doing type inference for the same
/// file as where `Definition` is defined to avoid cross-file query dependencies.
#[no_eq] #[no_eq]
#[return_ref] #[return_ref]
#[tracked] #[tracked]
pub(crate) kind: DefinitionKind<'db>, pub(crate) kind: DefinitionKind<'db>,
/// This is a dedicated field to avoid accessing `kind` to compute this value.
pub(crate) is_reexported: bool,
count: countme::Count<Definition<'static>>, count: countme::Count<Definition<'static>>,
} }
@ -45,22 +50,6 @@ impl<'db> Definition<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db)) self.file_scope(db).to_scope_id(db, self.file(db))
} }
pub(crate) fn category(self, db: &'db dyn Db) -> DefinitionCategory {
self.kind(db).category()
}
pub(crate) fn is_declaration(self, db: &'db dyn Db) -> bool {
self.kind(db).category().is_declaration()
}
pub(crate) fn is_binding(self, db: &'db dyn Db) -> bool {
self.kind(db).category().is_binding()
}
pub(crate) fn is_reexported(self, db: &'db dyn Db) -> bool {
self.kind(db).is_reexported()
}
} }
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]

View file

@ -293,7 +293,7 @@ fn core_module_scope(db: &dyn Db, core_module: KnownModule) -> Option<ScopeId<'_
/// together with boundness information in a [`Symbol`]. /// together with boundness information in a [`Symbol`].
/// ///
/// The type will be a union if there are multiple bindings with different types. /// The type will be a union if there are multiple bindings with different types.
pub(crate) fn symbol_from_bindings<'db>( pub(super) fn symbol_from_bindings<'db>(
db: &'db dyn Db, db: &'db dyn Db,
bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>, bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>,
) -> Symbol<'db> { ) -> Symbol<'db> {
@ -479,6 +479,10 @@ fn symbol_impl<'db>(
} }
/// Implementation of [`symbol_from_bindings`]. /// Implementation of [`symbol_from_bindings`].
///
/// ## Implementation Note
/// This function gets called cross-module. It, therefore, shouldn't
/// access any AST nodes from the file containing the declarations.
fn symbol_from_bindings_impl<'db>( fn symbol_from_bindings_impl<'db>(
db: &'db dyn Db, db: &'db dyn Db,
bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>, bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>,
@ -562,6 +566,10 @@ fn symbol_from_bindings_impl<'db>(
} }
/// Implementation of [`symbol_from_declarations`]. /// Implementation of [`symbol_from_declarations`].
///
/// ## Implementation Note
/// This function gets called cross-module. It, therefore, shouldn't
/// access any AST nodes from the file containing the declarations.
fn symbol_from_declarations_impl<'db>( fn symbol_from_declarations_impl<'db>(
db: &'db dyn Db, db: &'db dyn Db,
declarations: DeclarationsIterator<'_, 'db>, declarations: DeclarationsIterator<'_, 'db>,

View file

@ -16,7 +16,8 @@ pub(crate) use self::diagnostic::register_lints;
pub use self::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; pub use self::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
pub(crate) use self::display::TypeArrayDisplay; pub(crate) use self::display::TypeArrayDisplay;
pub(crate) use self::infer::{ pub(crate) use self::infer::{
infer_deferred_types, infer_definition_types, infer_expression_types, infer_scope_types, infer_deferred_types, infer_definition_types, infer_expression_type, infer_expression_types,
infer_scope_types,
}; };
pub use self::narrow::KnownConstraintFunction; pub use self::narrow::KnownConstraintFunction;
pub(crate) use self::signatures::Signature; pub(crate) use self::signatures::Signature;
@ -26,7 +27,6 @@ use crate::module_resolver::{file_to_module, resolve_module, KnownModule};
use crate::semantic_index::ast_ids::HasScopedExpressionId; use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::attribute_assignment::AttributeAssignment; use crate::semantic_index::attribute_assignment::AttributeAssignment;
use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::ScopeId; use crate::semantic_index::symbol::ScopeId;
use crate::semantic_index::{ use crate::semantic_index::{
attribute_assignments, imported_modules, semantic_index, symbol_table, use_def_map, attribute_assignments, imported_modules, semantic_index, symbol_table, use_def_map,
@ -3818,16 +3818,6 @@ impl<'db> Class<'db> {
name: &str, name: &str,
inferred_type_from_class_body: Option<Type<'db>>, inferred_type_from_class_body: Option<Type<'db>>,
) -> Symbol<'db> { ) -> Symbol<'db> {
// We use a separate salsa query here to prevent unrelated changes in the AST of an external
// file from triggering re-evaluations of downstream queries.
// See the `dependency_implicit_instance_attribute` test for more information.
#[salsa::tracked]
fn infer_expression_type<'db>(db: &'db dyn Db, expression: Expression<'db>) -> Type<'db> {
let inference = infer_expression_types(db, expression);
let expr_scope = expression.scope(db);
inference.expression_type(expression.node_ref(db).scoped_expression_id(db, expr_scope))
}
// If we do not see any declarations of an attribute, neither in the class body nor in // If we do not see any declarations of an attribute, neither in the class body nor in
// any method, we build a union of `Unknown` with the inferred types of all bindings of // any method, we build a union of `Unknown` with the inferred types of all bindings of
// that attribute. We include `Unknown` in that union to account for the fact that the // that attribute. We include `Unknown` in that union to account for the fact that the

View file

@ -118,7 +118,7 @@ fn infer_definition_types_cycle_recovery<'db>(
) -> TypeInference<'db> { ) -> TypeInference<'db> {
tracing::trace!("infer_definition_types_cycle_recovery"); tracing::trace!("infer_definition_types_cycle_recovery");
let mut inference = TypeInference::empty(input.scope(db)); let mut inference = TypeInference::empty(input.scope(db));
let category = input.category(db); let category = input.kind(db).category();
if category.is_declaration() { if category.is_declaration() {
inference inference
.declarations .declarations
@ -198,6 +198,36 @@ pub(crate) fn infer_expression_types<'db>(
TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index).finish() TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index).finish()
} }
/// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query.
///
/// This is a small helper around [`infer_expression_types()`] to reduce the boilerplate.
/// Use [`infer_expression_type()`] if it isn't guaranteed that `expression` is in the same file to
/// avoid cross-file query dependencies.
pub(super) fn infer_same_file_expression_type<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Type<'db> {
let inference = infer_expression_types(db, expression);
let scope = expression.scope(db);
inference.expression_type(expression.node_ref(db).scoped_expression_id(db, scope))
}
/// Infers the type of an expression where the expression might come from another file.
///
/// Use this over [`infer_expression_types`] if the expression might come from another file than the
/// enclosing query to avoid cross-file query dependencies.
///
/// Use [`infer_same_file_expression_type`] if it is guaranteed that `expression` is in the same
/// to avoid unnecessary salsa ingredients. This is normally the case inside the `TypeInferenceBuilder`.
#[salsa::tracked]
pub(crate) fn infer_expression_type<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Type<'db> {
// It's okay to call the "same file" version here because we're inside a salsa query.
infer_same_file_expression_type(db, expression)
}
/// Infer the types for an [`Unpack`] operation. /// Infer the types for an [`Unpack`] operation.
/// ///
/// This infers the expression type and performs structural match against the target expression /// This infers the expression type and performs structural match against the target expression
@ -870,7 +900,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) { fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) {
debug_assert!(binding.is_binding(self.db())); debug_assert!(binding.kind(self.db()).category().is_binding());
let use_def = self.index.use_def_map(binding.file_scope(self.db())); let use_def = self.index.use_def_map(binding.file_scope(self.db()));
let declarations = use_def.declarations_at_binding(binding); let declarations = use_def.declarations_at_binding(binding);
let mut bound_ty = ty; let mut bound_ty = ty;
@ -905,7 +935,7 @@ impl<'db> TypeInferenceBuilder<'db> {
declaration: Definition<'db>, declaration: Definition<'db>,
ty: TypeAndQualifiers<'db>, ty: TypeAndQualifiers<'db>,
) { ) {
debug_assert!(declaration.is_declaration(self.db())); debug_assert!(declaration.kind(self.db()).category().is_declaration());
let use_def = self.index.use_def_map(declaration.file_scope(self.db())); let use_def = self.index.use_def_map(declaration.file_scope(self.db()));
let prior_bindings = use_def.bindings_at_declaration(declaration); let prior_bindings = use_def.bindings_at_declaration(declaration);
// unbound_ty is Never because for this check we don't care about unbound // unbound_ty is Never because for this check we don't care about unbound
@ -935,8 +965,8 @@ impl<'db> TypeInferenceBuilder<'db> {
definition: Definition<'db>, definition: Definition<'db>,
declared_and_inferred_ty: &DeclaredAndInferredType<'db>, declared_and_inferred_ty: &DeclaredAndInferredType<'db>,
) { ) {
debug_assert!(definition.is_binding(self.db())); debug_assert!(definition.kind(self.db()).category().is_binding());
debug_assert!(definition.is_declaration(self.db())); debug_assert!(definition.kind(self.db()).category().is_declaration());
let (declared_ty, inferred_ty) = match *declared_and_inferred_ty { let (declared_ty, inferred_ty) = match *declared_and_inferred_ty {
DeclaredAndInferredType::AreTheSame(ty) => (ty.into(), ty), DeclaredAndInferredType::AreTheSame(ty) => (ty.into(), ty),
@ -6626,4 +6656,93 @@ mod tests {
Ok(()) Ok(())
} }
/// This test verifies that changing a class's declaration in a non-meaningful way (e.g. by adding a comment)
/// doesn't trigger type inference for expressions that depend on the class's members.
#[test]
fn dependency_own_instance_member() -> anyhow::Result<()> {
fn x_rhs_expression(db: &TestDb) -> Expression<'_> {
let file_main = system_path_to_file(db, "/src/main.py").unwrap();
let ast = parsed_module(db, file_main);
// Get the second statement in `main.py` (x = …) and extract the expression
// node on the right-hand side:
let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value;
let index = semantic_index(db, file_main);
index.expression(x_rhs_node.as_ref())
}
let mut db = setup_db();
db.write_dedented(
"/src/mod.py",
r#"
class C:
if random.choice([True, False]):
attr: int = 42
else:
attr: None = None
"#,
)?;
db.write_dedented(
"/src/main.py",
r#"
from mod import C
x = C().attr
"#,
)?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int | None");
// Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred
db.write_dedented(
"/src/mod.py",
r#"
class C:
if random.choice([True, False]):
attr: str = "42"
else:
attr: None = None
"#,
)?;
let events = {
db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
// Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented(
"/src/mod.py",
r#"
class C:
# comment
if random.choice([True, False]):
attr: str = "42"
else:
attr: None = None
"#,
)?;
let events = {
db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};
assert_function_query_was_not_run(
&db,
infer_expression_types,
x_rhs_expression(&db),
&events,
);
Ok(())
}
} }

View file

@ -6,6 +6,7 @@ use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression; use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table; use crate::semantic_index::symbol_table;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{ use crate::types::{
infer_expression_types, ClassLiteralType, IntersectionBuilder, KnownClass, KnownFunction, infer_expression_types, ClassLiteralType, IntersectionBuilder, KnownClass, KnownFunction,
SubclassOfType, Truthiness, Type, UnionBuilder, SubclassOfType, Truthiness, Type, UnionBuilder,
@ -497,11 +498,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() { if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
// SAFETY: we should always have a symbol for every Name node. // SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap(); let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let scope = self.scope(); let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db);
let inference = infer_expression_types(self.db, cls);
let ty = inference
.expression_type(cls.node_ref(self.db).scoped_expression_id(self.db, scope))
.to_instance(self.db);
let mut constraints = NarrowingConstraints::default(); let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, ty); constraints.insert(symbol, ty);
Some(constraints) Some(constraints)

View file

@ -178,11 +178,8 @@ use std::cmp::Ordering;
use ruff_index::{Idx, IndexVec}; use ruff_index::{Idx, IndexVec};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use crate::semantic_index::{ use crate::semantic_index::constraint::{Constraint, ConstraintNode, PatternConstraintKind};
ast_ids::HasScopedExpressionId, use crate::types::{infer_expression_type, Truthiness};
constraint::{Constraint, ConstraintNode, PatternConstraintKind},
};
use crate::types::{infer_expression_types, Truthiness};
use crate::Db; use crate::Db;
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@ -617,28 +614,14 @@ impl<'db> VisibilityConstraints<'db> {
fn analyze_single(db: &dyn Db, constraint: &Constraint) -> Truthiness { fn analyze_single(db: &dyn Db, constraint: &Constraint) -> Truthiness {
match constraint.node { match constraint.node {
ConstraintNode::Expression(test_expr) => { ConstraintNode::Expression(test_expr) => {
let inference = infer_expression_types(db, test_expr); let ty = infer_expression_type(db, test_expr);
let scope = test_expr.scope(db);
let ty = inference
.expression_type(test_expr.node_ref(db).scoped_expression_id(db, scope));
ty.bool(db).negate_if(!constraint.is_positive) ty.bool(db).negate_if(!constraint.is_positive)
} }
ConstraintNode::Pattern(inner) => match inner.kind(db) { ConstraintNode::Pattern(inner) => match inner.kind(db) {
PatternConstraintKind::Value(value, guard) => { PatternConstraintKind::Value(value, guard) => {
let subject_expression = inner.subject(db); let subject_expression = inner.subject(db);
let inference = infer_expression_types(db, subject_expression); let subject_ty = infer_expression_type(db, subject_expression);
let scope = subject_expression.scope(db); let value_ty = infer_expression_type(db, *value);
let subject_ty = inference.expression_type(
subject_expression
.node_ref(db)
.scoped_expression_id(db, scope),
);
let inference = infer_expression_types(db, *value);
let scope = value.scope(db);
let value_ty = inference
.expression_type(value.node_ref(db).scoped_expression_id(db, scope));
if subject_ty.is_single_valued(db) { if subject_ty.is_single_valued(db) {
let truthiness = let truthiness =