[ty] don't assume that deferred type inference means deferred name resolution (#20160)

## Summary

We have the ability to defer type inference of some parts of
definitions, so as to allow us to create a type that may need to be
recursively referenced in those other parts of the definition.

We also have the ability to do type inference in a context where all
name resolution should be deferred (that is, names should be looked up
from all-reachable-definitions rather than from the location of use.)
This is used for all annotations in stubs, or if `from __future__ import
annotations` is active.

Previous to this PR, these two concepts were linked: deferred-inference
always implied deferred-name-resolution, though we also supported
deferred-name-resolution without deferred-inference, via
`DeferredExpressionState`.

For the upcoming `typing.TypeAlias` support, I will defer inference of
the entire RHS of the alias (so as to support cycles), but that doesn't
imply deferred name resolution; at runtime, the RHS of a name annotated
as `typing.TypeAlias` is executed eagerly.

So this PR fully de-couples the two concepts, instead explicitly setting
the `DeferredExpressionState` in those cases where we should defer name
resolution.

It also fixes a long-standing related bug, where we were deferring name
resolution of all names in class bases, if any of the class bases
contained a stringified annotation.

## Test Plan

Added test that failed before this PR.
This commit is contained in:
Carl Meyer 2025-08-29 16:19:45 -07:00 committed by GitHub
parent 694e7ed52e
commit 17dc2e4d80
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 48 additions and 7 deletions

View file

@ -0,0 +1,25 @@
# Class definitions
## Deferred resolution of bases
### Only the stringified name is deferred
If a class base contains a stringified name, only that name is deferred. Other names are resolved
normally.
```toml
[environment]
python-version = "3.12"
```
```py
A = int
class G[T]: ...
class C(A, G["B"]): ...
A = str
B = bytes
reveal_type(C.__mro__) # revealed: tuple[<class 'C'>, <class 'int'>, <class 'G[bytes]'>, typing.Generic, <class 'object'>]
```

View file

@ -999,9 +999,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.index.has_future_annotations() || self.in_stub() self.index.has_future_annotations() || self.in_stub()
} }
/// Are we currently inferring deferred types? /// Are we currently in a context where name resolution should be deferred
/// (`__future__.annotations`, stub file, or stringified annotation)?
fn is_deferred(&self) -> bool { fn is_deferred(&self) -> bool {
matches!(self.region, InferenceRegion::Deferred(_)) || self.deferred_state.is_deferred() self.deferred_state.is_deferred()
} }
/// Return the node key of the given AST node, or the key of the outermost enclosing string /// Return the node key of the given AST node, or the key of the outermost enclosing string
@ -3173,10 +3174,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_expression(&keyword.value); self.infer_expression(&keyword.value);
} }
// Inference of bases deferred in stubs // Inference of bases deferred in stubs, or if any are string literals.
// TODO: Only defer the references that are actually string literals, instead of
// deferring the entire class definition if a string literal occurs anywhere in the
// base class list.
if self.in_stub() || class_node.bases().iter().any(contains_string_literal) { if self.in_stub() || class_node.bases().iter().any(contains_string_literal) {
self.deferred.insert(definition); self.deferred.insert(definition);
} else { } else {
@ -3207,8 +3205,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_class_deferred(&mut self, definition: Definition<'db>, class: &ast::StmtClassDef) { fn infer_class_deferred(&mut self, definition: Definition<'db>, class: &ast::StmtClassDef) {
let previous_typevar_binding_context = self.typevar_binding_context.replace(definition); let previous_typevar_binding_context = self.typevar_binding_context.replace(definition);
for base in class.bases() { for base in class.bases() {
if self.in_stub() {
self.infer_expression_with_state(base, DeferredExpressionState::Deferred);
} else {
self.infer_expression(base); self.infer_expression(base);
} }
}
self.typevar_binding_context = previous_typevar_binding_context; self.typevar_binding_context = previous_typevar_binding_context;
} }
@ -3561,6 +3563,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
bound, bound,
default, default,
} = node; } = node;
let previous_deferred_state =
std::mem::replace(&mut self.deferred_state, DeferredExpressionState::Deferred);
match bound.as_deref() { match bound.as_deref() {
Some(expr @ ast::Expr::Tuple(ast::ExprTuple { elts, .. })) => { Some(expr @ ast::Expr::Tuple(ast::ExprTuple { elts, .. })) => {
// We don't use UnionType::from_elements or UnionBuilder here, because we don't // We don't use UnionType::from_elements or UnionBuilder here, because we don't
@ -3582,6 +3586,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
None => {} None => {}
} }
self.infer_optional_type_expression(default.as_deref()); self.infer_optional_type_expression(default.as_deref());
self.deferred_state = previous_deferred_state;
} }
fn infer_paramspec_definition( fn infer_paramspec_definition(
@ -5600,6 +5605,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_expression_impl(expression) self.infer_expression_impl(expression)
} }
fn infer_expression_with_state(
&mut self,
expression: &ast::Expr,
state: DeferredExpressionState,
) -> Type<'db> {
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, state);
let ty = self.infer_expression(expression);
self.deferred_state = previous_deferred_state;
ty
}
fn infer_maybe_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> { fn infer_maybe_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
if let Some(standalone_expression) = self.index.try_expression(expression) { if let Some(standalone_expression) = self.index.try_expression(expression) {
self.infer_standalone_expression_impl(expression, standalone_expression) self.infer_standalone_expression_impl(expression, standalone_expression)