[ty] bidirectional type inference using function return type annotations (#20528)
Some checks failed
[ty Playground] Release / publish (push) Has been cancelled
CI / Determine changes (push) Has been cancelled
CI / cargo fmt (push) Has been cancelled
CI / cargo build (release) (push) Has been cancelled
CI / python package (push) Has been cancelled
CI / pre-commit (push) Has been cancelled
CI / mkdocs (push) Has been cancelled
CI / cargo clippy (push) Has been cancelled
CI / cargo test (linux) (push) Has been cancelled
CI / cargo test (linux, release) (push) Has been cancelled
CI / cargo test (windows) (push) Has been cancelled
CI / cargo test (wasm) (push) Has been cancelled
CI / cargo build (msrv) (push) Has been cancelled
CI / cargo fuzz build (push) Has been cancelled
CI / fuzz parser (push) Has been cancelled
CI / test scripts (push) Has been cancelled
CI / ecosystem (push) Has been cancelled
CI / Fuzz for new ty panics (push) Has been cancelled
CI / cargo shear (push) Has been cancelled
CI / ty completion evaluation (push) Has been cancelled
CI / formatter instabilities and black similarity (push) Has been cancelled
CI / test ruff-lsp (push) Has been cancelled
CI / check playground (push) Has been cancelled
CI / benchmarks instrumented (ruff) (push) Has been cancelled
CI / benchmarks instrumented (ty) (push) Has been cancelled
CI / benchmarks walltime (medium|multithreaded) (push) Has been cancelled
CI / benchmarks walltime (small|large) (push) Has been cancelled

## Summary

Implements bidirectional type inference using function return type
annotations.

This PR was originally proposed to solve astral-sh/ty#1167, but this
does not fully resolve it on its own.
Additionally, I believe we need to allow dataclasses to generate their
own `__new__` methods, [use constructor return types ​​for
inference](5844c0103d/crates/ty_python_semantic/src/types.rs (L5326-L5328)),
and a mechanism to discard type narrowing like `& ~AlwaysFalsy` if
necessary (at a more general level than this PR).

## Test Plan

`mdtest/bidirectional.md` is added.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ibraheem Ahmed <ibraheem@ibraheem.ca>
This commit is contained in:
Shunsuke Shibayama 2025-10-11 09:38:35 +09:00 committed by GitHub
parent 11a9e7ee44
commit dc64c08633
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 442 additions and 58 deletions

View file

@ -146,7 +146,75 @@ r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str] reveal_type(r) # revealed: dict[int | str, int | str]
``` ```
## Incorrect collection literal assignments are complained aobut ## Optional collection literal annotations are understood
```toml
[environment]
python-version = "3.12"
```
```py
import typing
a: list[int] | None = [1, 2, 3]
reveal_type(a) # revealed: list[int]
b: list[int | str] | None = [1, 2, 3]
reveal_type(b) # revealed: list[int | str]
c: typing.List[int] | None = [1, 2, 3]
reveal_type(c) # revealed: list[int]
d: list[typing.Any] | None = []
reveal_type(d) # revealed: list[Any]
e: set[int] | None = {1, 2, 3}
reveal_type(e) # revealed: set[int]
f: set[int | str] | None = {1, 2, 3}
reveal_type(f) # revealed: set[int | str]
g: typing.Set[int] | None = {1, 2, 3}
reveal_type(g) # revealed: set[int]
h: list[list[int]] | None = [[], [42]]
reveal_type(h) # revealed: list[list[int]]
i: list[typing.Any] | None = [1, 2, "3", ([4],)]
reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]]
j: list[tuple[str | int, ...]] | None = [(1, 2), ("foo", "bar"), ()]
reveal_type(j) # revealed: list[tuple[str | int, ...]]
k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7])]
reveal_type(k) # revealed: list[tuple[list[int], ...]]
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
type IntList = list[int]
m: IntList | None = [1, 2, 3]
reveal_type(m) # revealed: list[int]
n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3]
reveal_type(n) # revealed: list[Literal[1, 2, 3]]
o: list[typing.LiteralString] | None = ["a", "b", "c"]
reveal_type(o) # revealed: list[LiteralString]
p: dict[int, int] | None = {}
reveal_type(p) # revealed: dict[int, int]
q: dict[int | str, int] | None = {1: 1, 2: 2, 3: 3}
reveal_type(q) # revealed: dict[int | str, int]
r: dict[int | str, int | str] | None = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str]
```
## Incorrect collection literal assignments are complained about
```py ```py
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`" # error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`"

View file

@ -0,0 +1,147 @@
# Bidirectional type inference
ty partially supports bidirectional type inference. This is a mechanism for inferring the type of an
expression "from the outside in". Normally, type inference proceeds "from the inside out". That is,
in order to infer the type of an expression, the types of all sub-expressions must first be
inferred. There is no reverse dependency. However, when performing complex type inference, such as
when generics are involved, the type of an outer expression can sometimes be useful in inferring
inner expressions. Bidirectional type inference is a mechanism that propagates such "expected types"
to the inference of inner expressions.
## Propagating target type annotation
```toml
[environment]
python-version = "3.12"
```
```py
def list1[T](x: T) -> list[T]:
return [x]
l1 = list1(1)
reveal_type(l1) # revealed: list[Literal[1]]
l2: list[int] = list1(1)
reveal_type(l2) # revealed: list[int]
# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l2 = l1
intermediate = list1(1)
# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]`
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l3: list[int] = intermediate
# TODO: it would be nice if this were `list[int]`
reveal_type(intermediate) # revealed: list[Literal[1]]
reveal_type(l3) # revealed: list[int]
l4: list[int | str] | None = list1(1)
reveal_type(l4) # revealed: list[int | str]
def _(l: list[int] | None = None):
l1 = l or list()
reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]
l2: list[int] = l or list()
# it would be better if this were `list[int]`? (https://github.com/astral-sh/ty/issues/136)
reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]
def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True)
```
`typed_dict.py`:
```py
from typing import TypedDict
class TD(TypedDict):
x: int
d1 = {"x": 1}
d2: TD = {"x": 1}
d3: dict[str, int] = {"x": 1}
reveal_type(d1) # revealed: dict[Unknown | str, Unknown | int]
reveal_type(d2) # revealed: TD
reveal_type(d3) # revealed: dict[str, int]
def _() -> TD:
return {"x": 1}
def _() -> TD:
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
return {}
```
## Propagating return type annotation
```toml
[environment]
python-version = "3.12"
```
```py
from typing import overload, Callable
def list1[T](x: T) -> list[T]:
return [x]
def get_data() -> dict | None:
return {}
def wrap_data() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
# `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible,
# but the return type check passes here because the type of `list1(res)` is inferred
# by bidirectional type inference using the annotated return type, and the type of `res` is not used.
return list1(res)
def wrap_data2() -> list[dict] | None:
if not (res := get_data()):
return None
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
return list1(res)
def deco[T](func: Callable[[], T]) -> Callable[[], T]:
return func
def outer() -> Callable[[], list[dict]]:
@deco
def inner() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
return list1(res)
return inner
@overload
def f(x: int) -> list[int]: ...
@overload
def f(x: str) -> list[str]: ...
def f(x: int | str) -> list[int] | list[str]:
# `list[int] | list[str]` is disjoint from `list[int | str]`.
if isinstance(x, int):
return list1(x)
else:
return list1(x)
reveal_type(f(1)) # revealed: list[int]
reveal_type(f("a")) # revealed: list[str]
async def g() -> list[int | str]:
return list1(1)
def h[T](x: T, cond: bool) -> T | list[T]:
return i(x, cond)
def i[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
```

View file

@ -323,6 +323,9 @@ def union_param(x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1] reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown reveal_type(union_param(None)) # revealed: Unknown
def _(x: int | None):
reveal_type(union_param(x)) # revealed: int
``` ```
```py ```py

View file

@ -286,6 +286,9 @@ def union_param[T](x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1] reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown reveal_type(union_param(None)) # revealed: Unknown
def _(x: int | None):
reveal_type(union_param(x)) # revealed: int
``` ```
```py ```py

View file

@ -125,9 +125,10 @@ def homogeneous_list[T](*args: T) -> list[T]:
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]] reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None} plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
reveal_type(plot2["y"]) # revealed: list[int] reveal_type(plot2["y"]) # revealed: list[int]
# TODO: no error
# error: [invalid-argument-type]
plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)} plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
reveal_type(plot3["y"]) # revealed: list[int]
reveal_type(plot3["x"]) # revealed: list[int] | None
Y = "y" Y = "y"
X = "x" X = "x"
@ -362,7 +363,7 @@ qualifiers override the class-level `total` setting, which sets the default (`to
all keys are required by default, `total=False` means that all keys are non-required by default): all keys are required by default, `total=False` means that all keys are non-required by default):
```py ```py
from typing_extensions import TypedDict, Required, NotRequired from typing_extensions import TypedDict, Required, NotRequired, Final
# total=False by default, but id is explicitly Required # total=False by default, but id is explicitly Required
class Message(TypedDict, total=False): class Message(TypedDict, total=False):
@ -376,10 +377,17 @@ class User(TypedDict):
email: Required[str] # Explicitly required (redundant here) email: Required[str] # Explicitly required (redundant here)
bio: NotRequired[str] # Optional despite total=True bio: NotRequired[str] # Optional despite total=True
ID: Final = "id"
# Valid Message constructions # Valid Message constructions
msg1 = Message(id=1) # id required, content optional msg1 = Message(id=1) # id required, content optional
msg2 = Message(id=2, content="Hello") # both provided msg2 = Message(id=2, content="Hello") # both provided
msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional
msg4: Message = {"id": 4} # id required, content optional
msg5: Message = {ID: 5} # id required, content optional
def msg() -> Message:
return {ID: 1}
# Valid User constructions # Valid User constructions
user1 = User(name="Alice", email="alice@example.com") # required fields user1 = User(name="Alice", email="alice@example.com") # required fields

View file

@ -977,6 +977,10 @@ impl<'db> Type<'db> {
} }
} }
pub(crate) fn has_type_var(self, db: &'db dyn Db) -> bool {
any_over_type(db, self, &|ty| matches!(ty, Type::TypeVar(_)), false)
}
pub(crate) const fn into_class_literal(self) -> Option<ClassLiteral<'db>> { pub(crate) const fn into_class_literal(self) -> Option<ClassLiteral<'db>> {
match self { match self {
Type::ClassLiteral(class_type) => Some(class_type), Type::ClassLiteral(class_type) => Some(class_type),
@ -1167,6 +1171,15 @@ impl<'db> Type<'db> {
if yes { self.negate(db) } else { *self } if yes { self.negate(db) } else { *self }
} }
/// Remove the union elements that are not related to `target`.
pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, |elem| !elem.is_disjoint_from(db, target))
} else {
self
}
}
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type /// Returns the fallback instance type that a literal is an instance of, or `None` if the type
/// is not a literal. /// is not a literal.
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> { pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {

View file

@ -341,6 +341,48 @@ impl<'db> OverloadLiteral<'db> {
/// a cross-module dependency directly on the full AST which will lead to cache /// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation. /// over-invalidation.
pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
let mut signature = self.raw_signature(db);
let scope = self.body_scope(db);
let module = parsed_module(db, self.file(db)).load(db);
let function_node = scope.node(db).expect_function().node(&module);
let index = semantic_index(db, scope.file(db));
let file_scope_id = scope.file_scope_id(db);
let is_generator = file_scope_id.is_generator_function(index);
if function_node.is_async && !is_generator {
signature = signature.wrap_coroutine_return_type(db);
}
signature = signature.mark_typevars_inferable(db);
let pep695_ctx = function_node.type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(db, index, self.definition(db), type_params)
});
let legacy_ctx = GenericContext::from_function_params(
db,
self.definition(db),
signature.parameters(),
signature.return_ty,
);
// We need to update `signature.generic_context` here,
// because type variables in `GenericContext::variables` are still non-inferable.
signature.generic_context =
GenericContext::merge_pep695_and_legacy(db, pep695_ctx, legacy_ctx);
signature
}
/// Typed internally-visible "raw" signature for this function.
/// That is, type variables in parameter types and the return type remain non-inferable,
/// and the return types of async functions are not wrapped in `CoroutineType[...]`.
///
/// ## Warning
///
/// This uses the semantic index to find the definition of the function. This means that if the
/// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
fn raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
/// `self` or `cls` can be implicitly positional-only if: /// `self` or `cls` can be implicitly positional-only if:
/// - It is a method AND /// - It is a method AND
/// - No parameters in the method use PEP-570 syntax AND /// - No parameters in the method use PEP-570 syntax AND
@ -402,11 +444,11 @@ impl<'db> OverloadLiteral<'db> {
let function_stmt_node = scope.node(db).expect_function().node(&module); let function_stmt_node = scope.node(db).expect_function().node(&module);
let definition = self.definition(db); let definition = self.definition(db);
let index = semantic_index(db, scope.file(db)); let index = semantic_index(db, scope.file(db));
let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| { let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(db, index, definition, type_params) GenericContext::from_type_params(db, index, definition, type_params)
}); });
let file_scope_id = scope.file_scope_id(db); let file_scope_id = scope.file_scope_id(db);
let is_generator = file_scope_id.is_generator_function(index);
let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param( let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param(
db, db,
self, self,
@ -417,10 +459,9 @@ impl<'db> OverloadLiteral<'db> {
Signature::from_function( Signature::from_function(
db, db,
generic_context, pep695_ctx,
definition, definition,
function_stmt_node, function_stmt_node,
is_generator,
has_implicitly_positional_first_parameter, has_implicitly_positional_first_parameter,
) )
} }
@ -599,6 +640,18 @@ impl<'db> FunctionLiteral<'db> {
fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> { fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.last_definition(db).signature(db) self.last_definition(db).signature(db)
} }
/// Typed externally-visible "raw" signature of the last overload or implementation of this function.
///
/// ## Warning
///
/// This uses the semantic index to find the definition of the function. This means that if the
/// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.last_definition(db).raw_signature(db)
}
} }
/// Represents a function type, which might be a non-generic function, or a specialization of a /// Represents a function type, which might be a non-generic function, or a specialization of a
@ -877,6 +930,17 @@ impl<'db> FunctionType<'db> {
.unwrap_or_else(|| self.literal(db).last_definition_signature(db)) .unwrap_or_else(|| self.literal(db).last_definition_signature(db))
} }
/// Typed externally-visible "raw" signature of the last overload or implementation of this function.
#[salsa::tracked(
returns(ref),
cycle_fn=last_definition_signature_cycle_recover,
cycle_initial=last_definition_signature_cycle_initial,
heap_size=ruff_memory_usage::heap_size,
)]
pub(crate) fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.literal(db).last_definition_raw_signature(db)
}
/// Convert the `FunctionType` into a [`CallableType`]. /// Convert the `FunctionType` into a [`CallableType`].
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> { pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
CallableType::new(db, self.signature(db), false) CallableType::new(db, self.signature(db), false)

View file

@ -291,6 +291,28 @@ impl<'db> GenericContext<'db> {
Some(Self::from_typevar_instances(db, variables)) Some(Self::from_typevar_instances(db, variables))
} }
pub(crate) fn merge_pep695_and_legacy(
db: &'db dyn Db,
pep695_generic_context: Option<Self>,
legacy_generic_context: Option<Self>,
) -> Option<Self> {
match (legacy_generic_context, pep695_generic_context) {
(Some(legacy_ctx), Some(ctx)) => {
if legacy_ctx
.variables(db)
.exactly_one()
.is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
{
Some(legacy_ctx.merge(db, ctx))
} else {
// TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
Some(ctx)
}
}
(left, right) => left.or(right),
}
}
/// Creates a generic context from the legacy `TypeVar`s that appear in class's base class /// Creates a generic context from the legacy `TypeVar`s that appear in class's base class
/// list. /// list.
pub(crate) fn from_base_classes( pub(crate) fn from_base_classes(
@ -1174,7 +1196,7 @@ impl<'db> SpecializationBuilder<'db> {
pub(crate) fn infer( pub(crate) fn infer(
&mut self, &mut self,
formal: Type<'db>, formal: Type<'db>,
actual: Type<'db>, mut actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> { ) -> Result<(), SpecializationError<'db>> {
if formal == actual { if formal == actual {
return Ok(()); return Ok(());
@ -1203,6 +1225,10 @@ impl<'db> SpecializationBuilder<'db> {
return Ok(()); return Ok(());
} }
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
// So, here we remove the union elements that are not related to `formal`.
actual = actual.filter_disjoint_elements(self.db, formal);
match (formal, actual) { match (formal, actual) {
// TODO: We haven't implemented a full unification solver yet. If typevars appear in // TODO: We haven't implemented a full unification solver yet. If typevars appear in
// multiple union elements, we ideally want to express that _only one_ of them needs to // multiple union elements, we ideally want to express that _only one_ of them needs to
@ -1228,9 +1254,15 @@ impl<'db> SpecializationBuilder<'db> {
// def _(y: str | int | None): // def _(y: str | int | None):
// reveal_type(g(x)) # revealed: str | int // reveal_type(g(x)) # revealed: str | int
// ``` // ```
let formal_bound_typevars = // We do not handle cases where the `formal` types contain other types that contain type variables
(formal_union.elements(self.db).iter()).filter_map(|ty| ty.into_type_var()); // to prevent incorrect specialization: e.g. `T = int | list[int]` for `formal: T | list[T], actual: int | list[int]`
let Ok(formal_bound_typevar) = formal_bound_typevars.exactly_one() else { // (the correct specialization is `T = int`).
let types_have_typevars = formal_union
.elements(self.db)
.iter()
.filter(|ty| ty.has_type_var(self.db));
let Ok(Type::TypeVar(formal_bound_typevar)) = types_have_typevars.exactly_one()
else {
return Ok(()); return Ok(());
}; };
if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) { if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) {
@ -1241,7 +1273,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() { if remaining_actual.is_never() {
return Ok(()); return Ok(());
} }
self.add_type_mapping(formal_bound_typevar, remaining_actual); self.add_type_mapping(*formal_bound_typevar, remaining_actual);
} }
(Type::Union(formal), _) => { (Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not // Second, if the formal is a union, and precisely one union element _is_ a typevar (not

View file

@ -50,6 +50,7 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::scope::ScopeId; use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::diagnostic::TypeCheckDiagnostics;
use crate::types::function::FunctionType;
use crate::types::generics::Specialization; use crate::types::generics::Specialization;
use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers}; use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers};
@ -389,6 +390,12 @@ impl<'db> TypeContext<'db> {
self.annotation self.annotation
.and_then(|ty| ty.known_specialization(db, known_class)) .and_then(|ty| ty.known_specialization(db, known_class))
} }
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
Self {
annotation: self.annotation.map(f),
}
}
} }
/// Returns the statically-known truthiness of a given expression. /// Returns the statically-known truthiness of a given expression.
@ -487,6 +494,30 @@ pub(crate) fn nearest_enclosing_class<'db>(
}) })
} }
/// Returns the type of the nearest enclosing function for the given scope.
///
/// This function walks up the ancestor scopes starting from the given scope,
/// and finds the closest (non-lambda) function definition.
///
/// Returns `None` if no enclosing function is found.
pub(crate) fn nearest_enclosing_function<'db>(
db: &'db dyn Db,
semantic: &SemanticIndex<'db>,
scope: ScopeId,
) -> Option<FunctionType<'db>> {
semantic
.ancestor_scopes(scope.file_scope_id(db))
.find_map(|(_, ancestor_scope)| {
let func = ancestor_scope.node().as_function()?;
let definition = semantic.expect_single_definition(func);
let inference = infer_definition_types(db, definition);
inference
.undecorated_type()
.unwrap_or_else(|| inference.declaration_type(definition).inner_type())
.into_function_literal()
})
}
/// A region within which we can infer types. /// A region within which we can infer types.
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub(crate) enum InferenceRegion<'db> { pub(crate) enum InferenceRegion<'db> {

View file

@ -79,6 +79,7 @@ use crate::types::function::{
}; };
use crate::types::generics::{GenericContext, bind_typevar}; use crate::types::generics::{GenericContext, bind_typevar};
use crate::types::generics::{LegacyGenericBase, SpecializationBuilder}; use crate::types::generics::{LegacyGenericBase, SpecializationBuilder};
use crate::types::infer::nearest_enclosing_function;
use crate::types::instance::SliceLiteral; use crate::types::instance::SliceLiteral;
use crate::types::mro::MroErrorKind; use crate::types::mro::MroErrorKind;
use crate::types::signatures::Signature; use crate::types::signatures::Signature;
@ -5101,9 +5102,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { fn infer_return_statement(&mut self, ret: &ast::StmtReturn) {
if let Some(ty) = let tcx = if ret.value.is_some() {
self.infer_optional_expression(ret.value.as_deref(), TypeContext::default()) nearest_enclosing_function(self.db(), self.index, self.scope())
{ .map(|func| {
// When inferring expressions within a function body,
// the expected type passed should be the "raw" type,
// i.e. type variables in the return type are non-inferable,
// and the return types of async functions are not wrapped in `CoroutineType[...]`.
TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty)
})
.unwrap_or_default()
} else {
TypeContext::default()
};
if let Some(ty) = self.infer_optional_expression(ret.value.as_deref(), tcx) {
let range = ret let range = ret
.value .value
.as_ref() .as_ref()
@ -5900,6 +5912,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return None; return None;
}; };
let tcx = tcx.map_annotation(|annotation| {
// Remove any union elements of `annotation` that are not related to `collection_ty`.
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
let collection_ty = collection_class.to_instance(self.db());
annotation.filter_disjoint_elements(self.db(), collection_ty)
});
// Extract the annotated type of `T`, if provided. // Extract the annotated type of `T`, if provided.
let annotated_elt_tys = tcx let annotated_elt_tys = tcx
.known_specialization(self.db(), collection_class) .known_specialization(self.db(), collection_class)

View file

@ -26,9 +26,10 @@ use crate::types::function::FunctionType;
use crate::types::generics::{GenericContext, typing_self, walk_generic_context}; use crate::types::generics::{GenericContext, typing_self, walk_generic_context};
use crate::types::infer::nearest_enclosing_class; use crate::types::infer::nearest_enclosing_class;
use crate::types::{ use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor, ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassLiteral,
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor,
NormalizedVisitor, TypeContext, TypeMapping, TypeRelation, VarianceInferable, todo_type, KnownClass, MaterializationKind, NormalizedVisitor, TypeContext, TypeMapping, TypeRelation,
VarianceInferable, todo_type,
}; };
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name}; use ruff_python_ast::{self as ast, name::Name};
@ -419,10 +420,9 @@ impl<'db> Signature<'db> {
/// Return a typed signature from a function definition. /// Return a typed signature from a function definition.
pub(super) fn from_function( pub(super) fn from_function(
db: &'db dyn Db, db: &'db dyn Db,
generic_context: Option<GenericContext<'db>>, pep695_generic_context: Option<GenericContext<'db>>,
definition: Definition<'db>, definition: Definition<'db>,
function_node: &ast::StmtFunctionDef, function_node: &ast::StmtFunctionDef,
is_generator: bool,
has_implicitly_positional_first_parameter: bool, has_implicitly_positional_first_parameter: bool,
) -> Self { ) -> Self {
let parameters = Parameters::from_parameters( let parameters = Parameters::from_parameters(
@ -431,38 +431,17 @@ impl<'db> Signature<'db> {
function_node.parameters.as_ref(), function_node.parameters.as_ref(),
has_implicitly_positional_first_parameter, has_implicitly_positional_first_parameter,
); );
let return_ty = function_node.returns.as_ref().map(|returns| { let return_ty = function_node
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref()) .returns
.apply_type_mapping( .as_ref()
db, .map(|returns| definition_expression_type(db, definition, returns.as_ref()));
&TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
TypeContext::default(),
);
if function_node.is_async && !is_generator {
KnownClass::CoroutineType
.to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
} else {
plain_return_ty
}
});
let legacy_generic_context = let legacy_generic_context =
GenericContext::from_function_params(db, definition, &parameters, return_ty); GenericContext::from_function_params(db, definition, &parameters, return_ty);
let full_generic_context = GenericContext::merge_pep695_and_legacy(
let full_generic_context = match (legacy_generic_context, generic_context) { db,
(Some(legacy_ctx), Some(ctx)) => { pep695_generic_context,
if legacy_ctx legacy_generic_context,
.variables(db) );
.exactly_one()
.is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
{
Some(legacy_ctx.merge(db, ctx))
} else {
// TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
Some(ctx)
}
}
(left, right) => left.or(right),
};
Self { Self {
generic_context: full_generic_context, generic_context: full_generic_context,
@ -472,6 +451,27 @@ impl<'db> Signature<'db> {
} }
} }
pub(super) fn mark_typevars_inferable(self, db: &'db dyn Db) -> Self {
if let Some(definition) = self.definition {
self.apply_type_mapping_impl(
db,
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(definition))),
TypeContext::default(),
&ApplyTypeMappingVisitor::default(),
)
} else {
self
}
}
pub(super) fn wrap_coroutine_return_type(self, db: &'db dyn Db) -> Self {
let return_ty = self.return_ty.map(|return_ty| {
KnownClass::CoroutineType
.to_specialized_instance(db, [Type::any(), Type::any(), return_ty])
});
Self { return_ty, ..self }
}
/// Returns the signature which accepts any parameters and returns an `Unknown` type. /// Returns the signature which accepts any parameters and returns an `Unknown` type.
pub(crate) fn unknown() -> Self { pub(crate) fn unknown() -> Self {
Self::new(Parameters::unknown(), Some(Type::unknown())) Self::new(Parameters::unknown(), Some(Type::unknown()))
@ -1728,13 +1728,9 @@ impl<'db> Parameter<'db> {
kind: ParameterKind<'db>, kind: ParameterKind<'db>,
) -> Self { ) -> Self {
Self { Self {
annotated_type: parameter.annotation().map(|annotation| { annotated_type: parameter
definition_expression_type(db, definition, annotation).apply_type_mapping( .annotation()
db, .map(|annotation| definition_expression_type(db, definition, annotation)),
&TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
TypeContext::default(),
)
}),
kind, kind,
form: ParameterForm::Value, form: ParameterForm::Value,
inferred_annotation: false, inferred_annotation: false,