[red-knot] Implicit instance attributes (#15811)

## Summary

Add support for implicitly-defined instance attributes, i.e. support
type inference for cases like this:
```py
class C:
    def __init__(self) -> None:
        self.x: int = 1
        self.y = None

reveal_type(C().x)  # int
reveal_type(C().y)  # Unknown | None
```

## Benchmarks

Codspeed reports no change in a cold-cache benchmark, and a -1%
regression in the incremental benchmark. On `black`'s `src` folder, I
don't see a statistically significant difference between the branches:

| Command | Mean [ms] | Min [ms] | Max [ms] | Relative |
|:---|---:|---:|---:|---:|
| `./red_knot_main check --project /home/shark/black/src` | 133.7 ± 9.5 | 126.7 | 164.7 | 1.01 ± 0.08 |
| `./red_knot_feature check --project /home/shark/black/src` | 132.2 ± 5.1 | 118.1 | 140.9 | 1.00 |

## Test Plan

Updated and new Markdown tests
This commit is contained in:
David Peter 2025-02-03 19:34:23 +01:00 committed by GitHub
parent dc5e922221
commit 102c2eec12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 657 additions and 53 deletions

View file

@ -11,6 +11,7 @@ use ruff_index::{IndexSlice, IndexVec};
use crate::module_name::ModuleName;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIds;
use crate::semantic_index::attribute_assignment::AttributeAssignments;
use crate::semantic_index::builder::SemanticIndexBuilder;
use crate::semantic_index::definition::{Definition, DefinitionNodeKey};
use crate::semantic_index::expression::Expression;
@ -21,6 +22,7 @@ use crate::semantic_index::use_def::UseDefMap;
use crate::Db;
pub mod ast_ids;
pub mod attribute_assignment;
mod builder;
pub(crate) mod constraint;
pub mod definition;
@ -93,6 +95,25 @@ pub(crate) fn use_def_map<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<UseD
index.use_def_map(scope.file_scope_id(db))
}
/// Returns all attribute assignments for a specific class body scope.
///
/// Using [`attribute_assignments`] over [`semantic_index`] has the advantage that
/// Salsa can avoid invalidating dependent queries if this scope's instance attributes
/// are unchanged.
#[salsa::tracked]
pub(crate) fn attribute_assignments<'db>(
db: &'db dyn Db,
class_body_scope: ScopeId<'db>,
) -> Option<Arc<AttributeAssignments<'db>>> {
let file = class_body_scope.file(db);
let index = semantic_index(db, file);
index
.attribute_assignments
.get(&class_body_scope.file_scope_id(db))
.cloned()
}
/// Returns the module global scope of `file`.
#[salsa::tracked]
pub(crate) fn global_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
@ -139,6 +160,10 @@ pub(crate) struct SemanticIndex<'db> {
/// Flags about the global scope (code usage impacting inference)
has_future_annotations: bool,
/// Maps from class body scopes to attribute assignments that were found
/// in methods of that class.
attribute_assignments: FxHashMap<FileScopeId, Arc<AttributeAssignments<'db>>>,
}
impl<'db> SemanticIndex<'db> {

View file

@ -0,0 +1,19 @@
use crate::semantic_index::expression::Expression;
use ruff_python_ast::name::Name;
use rustc_hash::FxHashMap;
/// Describes an (annotated) attribute assignment that we discovered in a method
/// body, typically of the form `self.x: int`, `self.x: int = …` or `self.x = …`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum AttributeAssignment<'db> {
/// An attribute assignment with an explicit type annotation, either
/// `self.x: <annotation>` or `self.x: <annotation> = …`.
Annotated { annotation: Expression<'db> },
/// An attribute assignment without a type annotation, e.g. `self.x = <value>`.
Unannotated { value: Expression<'db> },
}
pub(crate) type AttributeAssignments<'db> = FxHashMap<Name, Vec<AttributeAssignment<'db>>>;

View file

@ -14,14 +14,15 @@ use crate::ast_node_ref::AstNodeRef;
use crate::module_name::ModuleName;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIdsBuilder;
use crate::semantic_index::attribute_assignment::{AttributeAssignment, AttributeAssignments};
use crate::semantic_index::constraint::PatternConstraintKind;
use crate::semantic_index::definition::{
AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey,
DefinitionNodeRef, ForStmtDefinitionNodeRef, ImportFromDefinitionNodeRef,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId,
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopeKind, ScopedSymbolId,
SymbolTableBuilder,
};
use crate::semantic_index::use_def::{
@ -53,17 +54,24 @@ impl LoopState {
}
}
struct ScopeInfo {
file_scope_id: FileScopeId,
loop_state: LoopState,
}
pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
db: &'db dyn Db,
file: File,
module: &'db ParsedModule,
scope_stack: Vec<(FileScopeId, LoopState)>,
scope_stack: Vec<ScopeInfo>,
/// The assignments we're currently visiting, with
/// the most recent visit at the end of the Vec
current_assignments: Vec<CurrentAssignment<'db>>,
/// The match case we're currently visiting.
current_match_case: Option<CurrentMatchCase<'db>>,
/// The name of the first function parameter of the innermost function that we're currently visiting.
current_first_parameter_name: Option<&'db str>,
/// Flow states at each `break` in the current loop.
loop_break_states: Vec<FlowSnapshot>,
@ -84,6 +92,7 @@ pub(super) struct SemanticIndexBuilder<'db> {
definitions_by_node: FxHashMap<DefinitionNodeKey, Definition<'db>>,
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
imported_modules: FxHashSet<ModuleName>,
attribute_assignments: FxHashMap<FileScopeId, AttributeAssignments<'db>>,
}
impl<'db> SemanticIndexBuilder<'db> {
@ -95,6 +104,7 @@ impl<'db> SemanticIndexBuilder<'db> {
scope_stack: Vec::new(),
current_assignments: vec![],
current_match_case: None,
current_first_parameter_name: None,
loop_break_states: vec![],
try_node_context_stack_manager: TryNodeContextStackManager::default(),
@ -112,6 +122,8 @@ impl<'db> SemanticIndexBuilder<'db> {
expressions_by_node: FxHashMap::default(),
imported_modules: FxHashSet::default(),
attribute_assignments: FxHashMap::default(),
};
builder.push_scope_with_parent(NodeWithScopeRef::Module, None);
@ -123,7 +135,7 @@ impl<'db> SemanticIndexBuilder<'db> {
*self
.scope_stack
.last()
.map(|(scope, _)| scope)
.map(|ScopeInfo { file_scope_id, .. }| file_scope_id)
.expect("Always to have a root scope")
}
@ -131,14 +143,32 @@ impl<'db> SemanticIndexBuilder<'db> {
self.scope_stack
.last()
.expect("Always to have a root scope")
.1
.loop_state
}
/// Returns the scope ID of the surrounding class body scope if the current scope
/// is a method inside a class body. Returns `None` otherwise, e.g. if the current
/// scope is a function body outside of a class, or if the current scope is not a
/// function body.
fn is_method_of_class(&self) -> Option<FileScopeId> {
let mut scopes_rev = self.scope_stack.iter().rev();
let current = scopes_rev.next()?;
let parent = scopes_rev.next()?;
match (
self.scopes[current.file_scope_id].kind(),
self.scopes[parent.file_scope_id].kind(),
) {
(ScopeKind::Function, ScopeKind::Class) => Some(parent.file_scope_id),
_ => None,
}
}
fn set_inside_loop(&mut self, state: LoopState) {
self.scope_stack
.last_mut()
.expect("Always to have a root scope")
.1 = state;
.loop_state = state;
}
fn push_scope(&mut self, node: NodeWithScopeRef) {
@ -171,16 +201,20 @@ impl<'db> SemanticIndexBuilder<'db> {
debug_assert_eq!(ast_id_scope, file_scope_id);
self.scope_stack.push((file_scope_id, LoopState::NotInLoop));
self.scope_stack.push(ScopeInfo {
file_scope_id,
loop_state: LoopState::NotInLoop,
});
}
fn pop_scope(&mut self) -> FileScopeId {
let (id, _) = self.scope_stack.pop().expect("Root scope to be present");
let ScopeInfo { file_scope_id, .. } =
self.scope_stack.pop().expect("Root scope to be present");
let children_end = self.scopes.next_index();
let scope = &mut self.scopes[id];
let scope = &mut self.scopes[file_scope_id];
scope.descendents = scope.descendents.start..children_end;
self.try_node_context_stack_manager.exit_scope();
id
file_scope_id
}
fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder {
@ -404,6 +438,32 @@ impl<'db> SemanticIndexBuilder<'db> {
self.current_assignments.last_mut()
}
/// Records the fact that we saw an attribute assignment of the form
/// `object.attr: <annotation>( = …)` or `object.attr = <value>`.
fn register_attribute_assignment(
&mut self,
object: &ast::Expr,
attr: &'db ast::Identifier,
attribute_assignment: AttributeAssignment<'db>,
) {
if let Some(class_body_scope) = self.is_method_of_class() {
// We only care about attribute assignments to the first parameter of a method,
// i.e. typically `self` or `cls`.
let accessed_object_refers_to_first_parameter =
object.as_name_expr().map(|name| name.id.as_str())
== self.current_first_parameter_name;
if accessed_object_refers_to_first_parameter {
self.attribute_assignments
.entry(class_body_scope)
.or_default()
.entry(attr.id().clone())
.or_default()
.push(attribute_assignment);
}
}
}
fn add_pattern_constraint(
&mut self,
subject: Expression<'db>,
@ -457,6 +517,20 @@ impl<'db> SemanticIndexBuilder<'db> {
/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
/// standalone (type narrowing tests, RHS of an assignment.)
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> {
self.add_standalone_expression_impl(expression_node, ExpressionKind::Normal)
}
/// Same as [`SemanticIndexBuilder::add_standalone_expression`], but marks the expression as a
/// *type* expression, which makes sure that it will later be inferred as such.
fn add_standalone_type_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> {
self.add_standalone_expression_impl(expression_node, ExpressionKind::TypeExpression)
}
fn add_standalone_expression_impl(
&mut self,
expression_node: &ast::Expr,
expression_kind: ExpressionKind,
) -> Expression<'db> {
let expression = Expression::new(
self.db,
self.file,
@ -465,6 +539,7 @@ impl<'db> SemanticIndexBuilder<'db> {
unsafe {
AstNodeRef::new(self.module.clone(), expression_node)
},
expression_kind,
countme::Count::default(),
);
self.expressions_by_node
@ -668,6 +743,11 @@ impl<'db> SemanticIndexBuilder<'db> {
use_def_maps,
imported_modules: Arc::new(self.imported_modules),
has_future_annotations: self.has_future_annotations,
attribute_assignments: self
.attribute_assignments
.into_iter()
.map(|(k, v)| (k, Arc::new(v)))
.collect(),
}
}
}
@ -706,7 +786,17 @@ where
builder.declare_parameters(parameters);
let mut first_parameter_name = parameters
.iter_non_variadic_params()
.next()
.map(|first_param| first_param.parameter.name.id().as_str());
std::mem::swap(
&mut builder.current_first_parameter_name,
&mut first_parameter_name,
);
builder.visit_body(body);
builder.current_first_parameter_name = first_parameter_name;
builder.pop_scope()
},
);
@ -840,6 +930,19 @@ where
unpack: None,
first: false,
}),
ast::Expr::Attribute(ast::ExprAttribute {
value: object,
attr,
..
}) => {
self.register_attribute_assignment(
object,
attr,
AttributeAssignment::Unannotated { value },
);
None
}
_ => None,
};
@ -858,6 +961,7 @@ where
ast::Stmt::AnnAssign(node) => {
debug_assert_eq!(&self.current_assignments, &[]);
self.visit_expr(&node.annotation);
let annotation = self.add_standalone_type_expression(&node.annotation);
if let Some(value) = &node.value {
self.visit_expr(value);
}
@ -869,6 +973,20 @@ where
) {
self.push_assignment(node.into());
self.visit_expr(&node.target);
if let ast::Expr::Attribute(ast::ExprAttribute {
value: object,
attr,
..
}) = &*node.target
{
self.register_attribute_assignment(
object,
attr,
AttributeAssignment::Annotated { annotation },
);
}
self.pop_assignment();
} else {
self.visit_expr(&node.target);

View file

@ -5,6 +5,16 @@ use ruff_db::files::File;
use ruff_python_ast as ast;
use salsa;
/// Whether or not this expression should be inferred as a normal expression or
/// a type expression. For example, in `self.x: <annotation> = <value>`, the
/// `<annotation>` is inferred as a type expression, while `<value>` is inferred
/// as a normal expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ExpressionKind {
Normal,
TypeExpression,
}
/// An independently type-inferable expression.
///
/// Includes constraint expressions (e.g. if tests) and the RHS of an unpacking assignment.
@ -35,6 +45,10 @@ pub(crate) struct Expression<'db> {
#[return_ref]
pub(crate) node_ref: AstNodeRef<ast::Expr>,
/// Should this expression be inferred as a normal expression or a type expression?
#[id]
pub(crate) kind: ExpressionKind,
#[no_eq]
count: countme::Count<Expression<'static>>,
}

View file

@ -119,6 +119,7 @@ impl<'db> ScopeId<'db> {
self.node(db).scope_kind(),
ScopeKind::Annotation
| ScopeKind::Function
| ScopeKind::Lambda
| ScopeKind::TypeAlias
| ScopeKind::Comprehension
)
@ -203,6 +204,7 @@ pub enum ScopeKind {
Annotation,
Class,
Function,
Lambda,
Comprehension,
TypeAlias,
}
@ -443,7 +445,8 @@ impl NodeWithScopeKind {
match self {
Self::Module => ScopeKind::Module,
Self::Class(_) => ScopeKind::Class,
Self::Function(_) | Self::Lambda(_) => ScopeKind::Function,
Self::Function(_) => ScopeKind::Function,
Self::Lambda(_) => ScopeKind::Lambda,
Self::FunctionTypeParameters(_)
| Self::ClassTypeParameters(_)
| Self::TypeAliasTypeParameters(_) => ScopeKind::Annotation,

View file

@ -23,11 +23,13 @@ pub use self::subclass_of::SubclassOfType;
use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module, KnownModule};
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::attribute_assignment::AttributeAssignment;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{self as symbol, ScopeId, ScopedSymbolId};
use crate::semantic_index::{
global_scope, imported_modules, semantic_index, symbol_table, use_def_map,
BindingWithConstraints, BindingWithConstraintsIterator, DeclarationWithConstraint,
attribute_assignments, global_scope, imported_modules, semantic_index, symbol_table,
use_def_map, BindingWithConstraints, BindingWithConstraintsIterator, DeclarationWithConstraint,
DeclarationsIterator,
};
use crate::stdlib::{builtins_symbol, known_module_symbol, typing_extensions_symbol};
@ -4134,9 +4136,76 @@ impl<'db> Class<'db> {
}
}
// TODO: The symbol is not present in any class body, but it could be implicitly
// defined in `__init__` or other methods anywhere in the MRO.
todo_type!("implicit instance attribute").into()
SymbolAndQualifiers(Symbol::Unbound, TypeQualifiers::empty())
}
/// Tries to find declarations/bindings of an instance attribute named `name` that are only
/// "implicitly" defined in a method of the class that corresponds to `class_body_scope`.
fn implicit_instance_attribute(
db: &'db dyn Db,
class_body_scope: ScopeId<'db>,
name: &str,
inferred_type_from_class_body: Option<Type<'db>>,
) -> Symbol<'db> {
// We use a separate salsa query here to prevent unrelated changes in the AST of an external
// file from triggering re-evaluations of downstream queries.
// See the `dependency_implicit_instance_attribute` test for more information.
#[salsa::tracked]
fn infer_expression_type<'db>(db: &'db dyn Db, expression: Expression<'db>) -> Type<'db> {
let inference = infer_expression_types(db, expression);
let expr_scope = expression.scope(db);
inference.expression_type(expression.node_ref(db).scoped_expression_id(db, expr_scope))
}
// If we do not see any declarations of an attribute, neither in the class body nor in
// any method, we build a union of `Unknown` with the inferred types of all bindings of
// that attribute. We include `Unknown` in that union to account for the fact that the
// attribute might be externally modified.
let mut union_of_inferred_types = UnionBuilder::new(db).add(Type::unknown());
if let Some(ty) = inferred_type_from_class_body {
union_of_inferred_types = union_of_inferred_types.add(ty);
}
let attribute_assignments = attribute_assignments(db, class_body_scope);
let Some(attribute_assignments) = attribute_assignments
.as_deref()
.and_then(|assignments| assignments.get(name))
else {
if inferred_type_from_class_body.is_some() {
return union_of_inferred_types.build().into();
}
return Symbol::Unbound;
};
for attribute_assignment in attribute_assignments {
match attribute_assignment {
AttributeAssignment::Annotated { annotation } => {
// We found an annotated assignment of one of the following forms (using 'self' in these
// examples, but we support arbitrary names for the first parameters of methods):
//
// self.name: <annotation>
// self.name: <annotation> = …
let annotation_ty = infer_expression_type(db, *annotation);
// TODO: check if there are conflicting declarations
return annotation_ty.into();
}
AttributeAssignment::Unannotated { value } => {
// We found an un-annotated attribute assignment of the form:
//
// self.name = <value>
let inferred_ty = infer_expression_type(db, *value);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
}
}
union_of_inferred_types.build().into()
}
/// A helper function for `instance_member` that looks up the `name` attribute only on
@ -4158,6 +4227,8 @@ impl<'db> Class<'db> {
match symbol_from_declarations(db, declarations) {
Ok(SymbolAndQualifiers(Symbol::Type(declared_ty, _), qualifiers)) => {
// The attribute is declared in the class body.
if let Some(function) = declared_ty.into_function_literal() {
// TODO: Eventually, we are going to process all decorators correctly. This is
// just a temporary heuristic to provide a broad categorization into properties
@ -4171,22 +4242,26 @@ impl<'db> Class<'db> {
SymbolAndQualifiers(Symbol::Type(declared_ty, Boundness::Bound), qualifiers)
}
}
Ok(symbol @ SymbolAndQualifiers(Symbol::Unbound, qualifiers)) => {
Ok(SymbolAndQualifiers(Symbol::Unbound, _)) => {
// The attribute is not *declared* in the class body. It could still be declared
// in a method, and it could also be *bound* in the class body (and/or in a method).
let bindings = use_def.public_bindings(symbol_id);
let inferred = symbol_from_bindings(db, bindings);
let inferred_ty = inferred.ignore_possibly_unbound();
SymbolAndQualifiers(
widen_type_for_undeclared_public_symbol(db, inferred, symbol.is_final()),
qualifiers,
)
Self::implicit_instance_attribute(db, body_scope, name, inferred_ty).into()
}
Err((declared_ty, _conflicting_declarations)) => {
// Ignore conflicting declarations
// There are conflicting declarations for this attribute in the class body.
SymbolAndQualifiers(declared_ty.inner_type().into(), declared_ty.qualifiers())
}
}
} else {
Symbol::Unbound.into()
// This attribute is neither declared nor bound in the class body.
// It could still be implicitly defined in a method.
Self::implicit_instance_attribute(db, body_scope, name, None).into()
}
}

View file

@ -44,7 +44,7 @@ use crate::semantic_index::definition::{
AssignmentDefinitionKind, Definition, DefinitionKind, DefinitionNodeKey,
ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId};
use crate::semantic_index::SemanticIndex;
@ -832,7 +832,14 @@ impl<'db> TypeInferenceBuilder<'db> {
}
fn infer_region_expression(&mut self, expression: Expression<'db>) {
self.infer_expression_impl(expression.node_ref(self.db()));
match expression.kind(self.db()) {
ExpressionKind::Normal => {
self.infer_expression_impl(expression.node_ref(self.db()));
}
ExpressionKind::TypeExpression => {
self.infer_type_expression(expression.node_ref(self.db()));
}
}
}
/// Raise a diagnostic if the given type cannot be divided by zero.
@ -6019,7 +6026,7 @@ mod tests {
use crate::types::check_types;
use ruff_db::files::{system_path_to_file, File};
use ruff_db::system::DbWithTestSystem;
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_db::testing::{assert_function_query_was_not_run, assert_function_query_was_run};
use super::*;
@ -6346,4 +6353,84 @@ mod tests {
);
Ok(())
}
#[test]
fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
fn x_rhs_expression(db: &TestDb) -> Expression<'_> {
let file_main = system_path_to_file(db, "/src/main.py").unwrap();
let ast = parsed_module(db, file_main);
// Get the second statement in `main.py` (x = …) and extract the expression
// node on the right-hand side:
let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value;
let index = semantic_index(db, file_main);
index.expression(x_rhs_node.as_ref())
}
let mut db = setup_db();
db.write_dedented(
"/src/mod.py",
r#"
class C:
def f(self):
self.attr: int | None = None
"#,
)?;
db.write_dedented(
"/src/main.py",
r#"
from mod import C
x = C().attr
"#,
)?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int | None");
// Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred
db.write_dedented(
"/src/mod.py",
r#"
class C:
def f(self):
self.attr: str | None = None
"#,
)?;
let events = {
db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
// Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented(
"/src/mod.py",
r#"
class C:
def f(self):
# a comment!
self.attr: str | None = None
"#,
)?;
let events = {
db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};
assert_function_query_was_not_run(
&db,
infer_expression_types,
x_rhs_expression(&db),
&events,
);
Ok(())
}
}