[red-knot] fix eager nested scopes handling (#16916)

## Summary

From #16861, and the continuation of #16915.

This PR fixes the incorrect behavior of
`TypeInferenceBuilder::infer_name_load` in eager nested scopes.

And this PR closes #16341.

## Test Plan

New test cases are added in `annotations/deferred.md`.
This commit is contained in:
Shunsuke Shibayama 2025-03-29 00:11:56 +09:00 committed by GitHub
parent 64171744dc
commit aca6254e82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 242 additions and 64 deletions

View file

@ -45,3 +45,104 @@ class Foo: ...
reveal_type(get_foo()) # revealed: Foo
```
## Deferred self-reference annotations in a class definition
```py
from __future__ import annotations
class Foo:
this: Foo
# error: [unresolved-reference]
_ = Foo()
# error: [unresolved-reference]
[Foo for _ in range(1)]
a = int
def f(self, x: Foo):
reveal_type(x) # revealed: Foo
def g(self) -> Foo:
_: Foo = self
return self
class Bar:
foo: Foo
b = int
def f(self, x: Foo):
return self
# error: [unresolved-reference]
def g(self) -> Bar:
return self
# error: [unresolved-reference]
def h[T: Bar](self):
pass
class Baz[T: Foo]:
pass
# error: [unresolved-reference]
type S = a
type T = b
def h[T: Bar]():
# error: [unresolved-reference]
return Bar()
type Baz = Foo
```
## Non-deferred self-reference annotations in a class definition
```py
class Foo:
# error: [unresolved-reference]
this: Foo
ok: "Foo"
# error: [unresolved-reference]
_ = Foo()
# error: [unresolved-reference]
[Foo for _ in range(1)]
a = int
# error: [unresolved-reference]
def f(self, x: Foo):
reveal_type(x) # revealed: Unknown
# error: [unresolved-reference]
def g(self) -> Foo:
_: Foo = self
return self
class Bar:
# error: [unresolved-reference]
foo: Foo
b = int
# error: [unresolved-reference]
def f(self, x: Foo):
return self
# error: [unresolved-reference]
def g(self) -> Bar:
return self
# error: [unresolved-reference]
def h[T: Bar](self):
pass
class Baz[T: Foo]:
pass
# error: [unresolved-reference]
type S = a
type T = b
def h[T: Bar]():
# error: [unresolved-reference]
return Bar()
type Qux = Foo
def _():
class C:
# error: [unresolved-reference]
def f(self) -> C:
return self
```

View file

@ -14,43 +14,43 @@ We support inference for all Python's binary operators: `+`, `-`, `*`, `@`, `/`,
```py
class A:
def __add__(self, other) -> A:
def __add__(self, other) -> "A":
return self
def __sub__(self, other) -> A:
def __sub__(self, other) -> "A":
return self
def __mul__(self, other) -> A:
def __mul__(self, other) -> "A":
return self
def __matmul__(self, other) -> A:
def __matmul__(self, other) -> "A":
return self
def __truediv__(self, other) -> A:
def __truediv__(self, other) -> "A":
return self
def __floordiv__(self, other) -> A:
def __floordiv__(self, other) -> "A":
return self
def __mod__(self, other) -> A:
def __mod__(self, other) -> "A":
return self
def __pow__(self, other) -> A:
def __pow__(self, other) -> "A":
return self
def __lshift__(self, other) -> A:
def __lshift__(self, other) -> "A":
return self
def __rshift__(self, other) -> A:
def __rshift__(self, other) -> "A":
return self
def __and__(self, other) -> A:
def __and__(self, other) -> "A":
return self
def __xor__(self, other) -> A:
def __xor__(self, other) -> "A":
return self
def __or__(self, other) -> A:
def __or__(self, other) -> "A":
return self
class B: ...
@ -76,43 +76,43 @@ We also support inference for reflected operations:
```py
class A:
def __radd__(self, other) -> A:
def __radd__(self, other) -> "A":
return self
def __rsub__(self, other) -> A:
def __rsub__(self, other) -> "A":
return self
def __rmul__(self, other) -> A:
def __rmul__(self, other) -> "A":
return self
def __rmatmul__(self, other) -> A:
def __rmatmul__(self, other) -> "A":
return self
def __rtruediv__(self, other) -> A:
def __rtruediv__(self, other) -> "A":
return self
def __rfloordiv__(self, other) -> A:
def __rfloordiv__(self, other) -> "A":
return self
def __rmod__(self, other) -> A:
def __rmod__(self, other) -> "A":
return self
def __rpow__(self, other) -> A:
def __rpow__(self, other) -> "A":
return self
def __rlshift__(self, other) -> A:
def __rlshift__(self, other) -> "A":
return self
def __rrshift__(self, other) -> A:
def __rrshift__(self, other) -> "A":
return self
def __rand__(self, other) -> A:
def __rand__(self, other) -> "A":
return self
def __rxor__(self, other) -> A:
def __rxor__(self, other) -> "A":
return self
def __ror__(self, other) -> A:
def __ror__(self, other) -> "A":
return self
class B: ...
@ -157,11 +157,11 @@ the right-hand side is not a subtype of the left-hand side, `lhs.__add__` will t
```py
class A:
def __add__(self, other: B) -> int:
def __add__(self, other: "B") -> int:
return 42
class B:
def __radd__(self, other: A) -> str:
def __radd__(self, other: "A") -> str:
return "foo"
reveal_type(A() + B()) # revealed: int
@ -169,10 +169,10 @@ reveal_type(A() + B()) # revealed: int
# Edge case: C is a subtype of C, *but* if the two sides are of *equal* types,
# the lhs *still* takes precedence
class C:
def __add__(self, other: C) -> int:
def __add__(self, other: "C") -> int:
return 42
def __radd__(self, other: C) -> str:
def __radd__(self, other: "C") -> str:
return "foo"
reveal_type(C() + C()) # revealed: int
@ -237,11 +237,11 @@ well.
```py
class A:
def __sub__(self, other: A) -> A:
def __sub__(self, other: "A") -> "A":
return A()
class B:
def __rsub__(self, other: A) -> B:
def __rsub__(self, other: A) -> "B":
return B()
reveal_type(A() - B()) # revealed: B
@ -300,10 +300,10 @@ its instance super-type.
```py
class A:
def __add__(self, other) -> A:
def __add__(self, other) -> "A":
return self
def __radd__(self, other) -> A:
def __radd__(self, other) -> "A":
return self
reveal_type(A() + 1) # revealed: A
@ -433,7 +433,7 @@ the unreflected dunder of the left-hand operand. For context, see
```py
class Foo:
def __radd__(self, other: Foo) -> Foo:
def __radd__(self, other: "Foo") -> "Foo":
return self
# error: [unsupported-operator]

View file

@ -382,13 +382,13 @@ arbitrary objects to a `bool`, but a comparison of tuples will fail if the resul
pair of elements at equivalent positions cannot be converted to a `bool`:
```py
class NotBoolable:
__bool__: None = None
class A:
def __eq__(self, other) -> NotBoolable:
return NotBoolable()
class NotBoolable:
__bool__: None = None
# error: [unsupported-bool-conversion]
(A(),) == (A(),)
```

View file

@ -154,6 +154,10 @@ x = 1
[reveal_type(x) for a in range(1)]
x = 2
# error: [unresolved-reference]
[y for a in range(1)]
y = 1
```
### Set comprehensions
@ -165,6 +169,10 @@ x = 1
{reveal_type(x) for a in range(1)}
x = 2
# error: [unresolved-reference]
{y for a in range(1)}
y = 1
```
### Dict comprehensions
@ -176,6 +184,10 @@ x = 1
{a: reveal_type(x) for a in range(1)}
x = 2
# error: [unresolved-reference]
{a: y for a in range(1)}
y = 1
```
### Generator expressions
@ -187,6 +199,10 @@ x = 1
list(reveal_type(x) for a in range(1))
x = 2
# error: [unresolved-reference]
list(y for a in range(1))
y = 1
```
`evaluated_later.py`:
@ -262,6 +278,14 @@ def _():
[reveal_type(x) for a in range(1)]
x = 2
x = 1
def _():
class C:
# revealed: Unknown | Literal[1]
[reveal_type(x) for _ in [1]]
x = 2
```
### Eager scope within a lazy scope

View file

@ -12,12 +12,12 @@ mdtest path: crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.
## mdtest_snippet.py
```
1 | class A:
2 | def __eq__(self, other) -> NotBoolable:
3 | return NotBoolable()
4 |
5 | class NotBoolable:
6 | __bool__: None = None
1 | class NotBoolable:
2 | __bool__: None = None
3 |
4 | class A:
5 | def __eq__(self, other) -> NotBoolable:
6 | return NotBoolable()
7 |
8 | # error: [unsupported-bool-conversion]
9 | (A(),) == (A(),)

View file

@ -124,6 +124,12 @@ pub(crate) fn global_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
FileScopeId::global().to_scope_id(db, file)
}
pub(crate) enum EagerBindingsResult<'map, 'db> {
Found(BindingWithConstraintsIterator<'map, 'db>),
NotFound,
NoLongerInEagerContext,
}
/// The symbol tables and use-def maps for all scopes in a file.
#[derive(Debug, Update)]
pub(crate) struct SemanticIndex<'db> {
@ -319,21 +325,40 @@ impl<'db> SemanticIndex<'db> {
self.has_future_annotations
}
/// Returns an iterator of bindings for a particular nested eager scope reference.
/// Returns
/// * `NoLongerInEagerContext` if the nested scope is no longer in an eager context
/// (that is, not every scope that will be traversed is eager).
/// * an iterator of bindings for a particular nested eager scope reference if the bindings exist.
/// * `NotFound` if the bindings do not exist in the nested eager scope.
pub(crate) fn eager_bindings(
&self,
enclosing_scope: FileScopeId,
symbol: &str,
nested_scope: FileScopeId,
) -> Option<BindingWithConstraintsIterator<'_, 'db>> {
let symbol_id = self.symbol_tables[enclosing_scope].symbol_id_by_name(symbol)?;
) -> EagerBindingsResult<'_, 'db> {
for (ancestor_scope_id, ancestor_scope) in self.ancestor_scopes(nested_scope) {
if ancestor_scope_id == enclosing_scope {
break;
}
if !ancestor_scope.is_eager() {
return EagerBindingsResult::NoLongerInEagerContext;
}
}
let Some(symbol_id) = self.symbol_tables[enclosing_scope].symbol_id_by_name(symbol) else {
return EagerBindingsResult::NotFound;
};
let key = EagerBindingsKey {
enclosing_scope,
enclosing_symbol: symbol_id,
nested_scope,
};
let id = self.eager_bindings.get(&key)?;
self.use_def_maps[enclosing_scope].eager_bindings(*id)
let Some(id) = self.eager_bindings.get(&key) else {
return EagerBindingsResult::NotFound;
};
match self.use_def_maps[enclosing_scope].eager_bindings(*id) {
Some(bindings) => EagerBindingsResult::Found(bindings),
None => EagerBindingsResult::NotFound,
}
}
}

View file

@ -256,11 +256,11 @@ impl<'db> SemanticIndexBuilder<'db> {
}
for nested_symbol in self.symbol_tables[popped_scope_id].symbols() {
// Skip this symbol if this enclosing scope doesn't contain any bindings for
// it, or if the nested scope _does_.
if nested_symbol.is_bound() {
continue;
}
// Skip this symbol if this enclosing scope doesn't contain any bindings for it.
// Note that even if this symbol is bound in the popped scope,
// it may refer to the enclosing scope bindings
// so we also need to snapshot the bindings of the enclosing scope.
let Some(enclosing_symbol_id) =
enclosing_symbol_table.symbol_id_by_name(nested_symbol.name())
else {

View file

@ -114,6 +114,10 @@ impl<'db> ScopeId<'db> {
self.node(db).scope_kind().is_function_like()
}
pub(crate) fn is_type_parameter(self, db: &'db dyn Db) -> bool {
self.node(db).scope_kind().is_type_parameter()
}
pub(crate) fn node(self, db: &dyn Db) -> &NodeWithScopeKind {
self.scope(db).node()
}
@ -226,9 +230,8 @@ pub enum ScopeKind {
impl ScopeKind {
pub(crate) fn is_eager(self) -> bool {
match self {
ScopeKind::Class | ScopeKind::Comprehension => true,
ScopeKind::Module
| ScopeKind::Annotation
ScopeKind::Module | ScopeKind::Class | ScopeKind::Comprehension => true,
ScopeKind::Annotation
| ScopeKind::Function
| ScopeKind::Lambda
| ScopeKind::TypeAlias => false,
@ -251,6 +254,10 @@ impl ScopeKind {
pub(crate) fn is_class(self) -> bool {
matches!(self, ScopeKind::Class)
}
pub(crate) fn is_type_parameter(self) -> bool {
matches!(self, ScopeKind::Annotation | ScopeKind::TypeAlias)
}
}
/// Symbol table for a specific [`Scope`].

View file

@ -51,11 +51,10 @@ use crate::semantic_index::definition::{
ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind,
};
use crate::semantic_index::SemanticIndex;
use crate::semantic_index::{semantic_index, EagerBindingsResult, SemanticIndex};
use crate::symbol::{
builtins_module_scope, builtins_symbol, explicit_global_symbol,
module_type_implicit_global_symbol, symbol, symbol_from_bindings, symbol_from_declarations,
@ -4131,8 +4130,15 @@ impl<'db> TypeInferenceBuilder<'db> {
// Class scopes are not visible to nested scopes, and we need to handle global
// scope differently (because an unbound name there falls back to builtins), so
// check only function-like scopes.
// There is one exception to this rule: type parameter scopes can see
// names defined in an immediately-enclosing class scope.
let enclosing_scope_id = enclosing_scope_file_id.to_scope_id(db, current_file);
if !enclosing_scope_id.is_function_like(db) {
let is_immediately_enclosing_scope = scope.is_type_parameter(db)
&& scope
.scope(db)
.parent()
.is_some_and(|parent| parent == enclosing_scope_file_id);
if !enclosing_scope_id.is_function_like(db) && !is_immediately_enclosing_scope {
continue;
}
@ -4143,12 +4149,20 @@ impl<'db> TypeInferenceBuilder<'db> {
// enclosing scopes that actually contain bindings that we should use when
// resolving the reference.)
if !self.is_deferred() {
if let Some(bindings) = self.index.eager_bindings(
match self.index.eager_bindings(
enclosing_scope_file_id,
symbol_name,
file_scope_id,
) {
return symbol_from_bindings(db, bindings).into();
EagerBindingsResult::Found(bindings) => {
return symbol_from_bindings(db, bindings).into();
}
// There are no visible bindings here.
// Don't fall back to non-eager symbol resolution.
EagerBindingsResult::NotFound => {
continue;
}
EagerBindingsResult::NoLongerInEagerContext => {}
}
}
@ -4176,12 +4190,19 @@ impl<'db> TypeInferenceBuilder<'db> {
}
if !self.is_deferred() {
if let Some(bindings) = self.index.eager_bindings(
match self.index.eager_bindings(
FileScopeId::global(),
symbol_name,
file_scope_id,
) {
return symbol_from_bindings(db, bindings).into();
EagerBindingsResult::Found(bindings) => {
return symbol_from_bindings(db, bindings).into();
}
// There are no visible bindings here.
EagerBindingsResult::NotFound => {
return Symbol::Unbound.into();
}
EagerBindingsResult::NoLongerInEagerContext => {}
}
}