[ty] bidirectional type inference using function return type annotations (#20528)
Some checks failed
CI / cargo fmt (push) Has been cancelled
CI / mkdocs (push) Has been cancelled
CI / Determine changes (push) Has been cancelled
CI / cargo build (release) (push) Has been cancelled
CI / python package (push) Has been cancelled
CI / pre-commit (push) Has been cancelled
[ty Playground] Release / publish (push) Has been cancelled
CI / cargo clippy (push) Has been cancelled
CI / cargo test (linux) (push) Has been cancelled
CI / cargo test (linux, release) (push) Has been cancelled
CI / cargo test (windows) (push) Has been cancelled
CI / cargo test (wasm) (push) Has been cancelled
CI / cargo build (msrv) (push) Has been cancelled
CI / cargo fuzz build (push) Has been cancelled
CI / fuzz parser (push) Has been cancelled
CI / test scripts (push) Has been cancelled
CI / ecosystem (push) Has been cancelled
CI / Fuzz for new ty panics (push) Has been cancelled
CI / cargo shear (push) Has been cancelled
CI / ty completion evaluation (push) Has been cancelled
CI / formatter instabilities and black similarity (push) Has been cancelled
CI / test ruff-lsp (push) Has been cancelled
CI / check playground (push) Has been cancelled
CI / benchmarks instrumented (ruff) (push) Has been cancelled
CI / benchmarks instrumented (ty) (push) Has been cancelled
CI / benchmarks walltime (medium|multithreaded) (push) Has been cancelled
CI / benchmarks walltime (small|large) (push) Has been cancelled

## Summary

Implements bidirectional type inference using function return type
annotations.

This PR was originally proposed to solve astral-sh/ty#1167, but this
does not fully resolve it on its own.
Additionally, I believe we need to allow dataclasses to generate their
own `__new__` methods, [use constructor return types ​​for
inference](5844c0103d/crates/ty_python_semantic/src/types.rs (L5326-L5328)),
and a mechanism to discard type narrowing like `& ~AlwaysFalsy` if
necessary (at a more general level than this PR).

## Test Plan

`mdtest/bidirectional.md` is added.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ibraheem Ahmed <ibraheem@ibraheem.ca>
This commit is contained in:
Shunsuke Shibayama 2025-10-11 09:38:35 +09:00 committed by GitHub
parent 11a9e7ee44
commit dc64c08633
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 442 additions and 58 deletions

View file

@ -146,7 +146,75 @@ r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str]
```
## Incorrect collection literal assignments are complained aobut
## Optional collection literal annotations are understood
```toml
[environment]
python-version = "3.12"
```
```py
import typing
a: list[int] | None = [1, 2, 3]
reveal_type(a) # revealed: list[int]
b: list[int | str] | None = [1, 2, 3]
reveal_type(b) # revealed: list[int | str]
c: typing.List[int] | None = [1, 2, 3]
reveal_type(c) # revealed: list[int]
d: list[typing.Any] | None = []
reveal_type(d) # revealed: list[Any]
e: set[int] | None = {1, 2, 3}
reveal_type(e) # revealed: set[int]
f: set[int | str] | None = {1, 2, 3}
reveal_type(f) # revealed: set[int | str]
g: typing.Set[int] | None = {1, 2, 3}
reveal_type(g) # revealed: set[int]
h: list[list[int]] | None = [[], [42]]
reveal_type(h) # revealed: list[list[int]]
i: list[typing.Any] | None = [1, 2, "3", ([4],)]
reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]]
j: list[tuple[str | int, ...]] | None = [(1, 2), ("foo", "bar"), ()]
reveal_type(j) # revealed: list[tuple[str | int, ...]]
k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7])]
reveal_type(k) # revealed: list[tuple[list[int], ...]]
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
type IntList = list[int]
m: IntList | None = [1, 2, 3]
reveal_type(m) # revealed: list[int]
n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3]
reveal_type(n) # revealed: list[Literal[1, 2, 3]]
o: list[typing.LiteralString] | None = ["a", "b", "c"]
reveal_type(o) # revealed: list[LiteralString]
p: dict[int, int] | None = {}
reveal_type(p) # revealed: dict[int, int]
q: dict[int | str, int] | None = {1: 1, 2: 2, 3: 3}
reveal_type(q) # revealed: dict[int | str, int]
r: dict[int | str, int | str] | None = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str]
```
## Incorrect collection literal assignments are complained about
```py
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`"

View file

@ -0,0 +1,147 @@
# Bidirectional type inference
ty partially supports bidirectional type inference. This is a mechanism for inferring the type of an
expression "from the outside in". Normally, type inference proceeds "from the inside out". That is,
in order to infer the type of an expression, the types of all sub-expressions must first be
inferred. There is no reverse dependency. However, when performing complex type inference, such as
when generics are involved, the type of an outer expression can sometimes be useful in inferring
inner expressions. Bidirectional type inference is a mechanism that propagates such "expected types"
to the inference of inner expressions.
## Propagating target type annotation
```toml
[environment]
python-version = "3.12"
```
```py
def list1[T](x: T) -> list[T]:
return [x]
l1 = list1(1)
reveal_type(l1) # revealed: list[Literal[1]]
l2: list[int] = list1(1)
reveal_type(l2) # revealed: list[int]
# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l2 = l1
intermediate = list1(1)
# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]`
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l3: list[int] = intermediate
# TODO: it would be nice if this were `list[int]`
reveal_type(intermediate) # revealed: list[Literal[1]]
reveal_type(l3) # revealed: list[int]
l4: list[int | str] | None = list1(1)
reveal_type(l4) # revealed: list[int | str]
def _(l: list[int] | None = None):
l1 = l or list()
reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]
l2: list[int] = l or list()
# it would be better if this were `list[int]`? (https://github.com/astral-sh/ty/issues/136)
reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]
def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True)
```
`typed_dict.py`:
```py
from typing import TypedDict
class TD(TypedDict):
x: int
d1 = {"x": 1}
d2: TD = {"x": 1}
d3: dict[str, int] = {"x": 1}
reveal_type(d1) # revealed: dict[Unknown | str, Unknown | int]
reveal_type(d2) # revealed: TD
reveal_type(d3) # revealed: dict[str, int]
def _() -> TD:
return {"x": 1}
def _() -> TD:
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
return {}
```
## Propagating return type annotation
```toml
[environment]
python-version = "3.12"
```
```py
from typing import overload, Callable
def list1[T](x: T) -> list[T]:
return [x]
def get_data() -> dict | None:
return {}
def wrap_data() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
# `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible,
# but the return type check passes here because the type of `list1(res)` is inferred
# by bidirectional type inference using the annotated return type, and the type of `res` is not used.
return list1(res)
def wrap_data2() -> list[dict] | None:
if not (res := get_data()):
return None
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
return list1(res)
def deco[T](func: Callable[[], T]) -> Callable[[], T]:
return func
def outer() -> Callable[[], list[dict]]:
@deco
def inner() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
return list1(res)
return inner
@overload
def f(x: int) -> list[int]: ...
@overload
def f(x: str) -> list[str]: ...
def f(x: int | str) -> list[int] | list[str]:
# `list[int] | list[str]` is disjoint from `list[int | str]`.
if isinstance(x, int):
return list1(x)
else:
return list1(x)
reveal_type(f(1)) # revealed: list[int]
reveal_type(f("a")) # revealed: list[str]
async def g() -> list[int | str]:
return list1(1)
def h[T](x: T, cond: bool) -> T | list[T]:
return i(x, cond)
def i[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
```

View file

@ -323,6 +323,9 @@ def union_param(x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown
def _(x: int | None):
reveal_type(union_param(x)) # revealed: int
```
```py

View file

@ -286,6 +286,9 @@ def union_param[T](x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown
def _(x: int | None):
reveal_type(union_param(x)) # revealed: int
```
```py

View file

@ -125,9 +125,10 @@ def homogeneous_list[T](*args: T) -> list[T]:
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
reveal_type(plot2["y"]) # revealed: list[int]
# TODO: no error
# error: [invalid-argument-type]
plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
reveal_type(plot3["y"]) # revealed: list[int]
reveal_type(plot3["x"]) # revealed: list[int] | None
Y = "y"
X = "x"
@ -362,7 +363,7 @@ qualifiers override the class-level `total` setting, which sets the default (`to
all keys are required by default, `total=False` means that all keys are non-required by default):
```py
from typing_extensions import TypedDict, Required, NotRequired
from typing_extensions import TypedDict, Required, NotRequired, Final
# total=False by default, but id is explicitly Required
class Message(TypedDict, total=False):
@ -376,10 +377,17 @@ class User(TypedDict):
email: Required[str] # Explicitly required (redundant here)
bio: NotRequired[str] # Optional despite total=True
ID: Final = "id"
# Valid Message constructions
msg1 = Message(id=1) # id required, content optional
msg2 = Message(id=2, content="Hello") # both provided
msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional
msg4: Message = {"id": 4} # id required, content optional
msg5: Message = {ID: 5} # id required, content optional
def msg() -> Message:
return {ID: 1}
# Valid User constructions
user1 = User(name="Alice", email="alice@example.com") # required fields

View file

@ -977,6 +977,10 @@ impl<'db> Type<'db> {
}
}
pub(crate) fn has_type_var(self, db: &'db dyn Db) -> bool {
any_over_type(db, self, &|ty| matches!(ty, Type::TypeVar(_)), false)
}
pub(crate) const fn into_class_literal(self) -> Option<ClassLiteral<'db>> {
match self {
Type::ClassLiteral(class_type) => Some(class_type),
@ -1167,6 +1171,15 @@ impl<'db> Type<'db> {
if yes { self.negate(db) } else { *self }
}
/// Remove the union elements that are not related to `target`.
pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, |elem| !elem.is_disjoint_from(db, target))
} else {
self
}
}
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
/// is not a literal.
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {

View file

@ -341,6 +341,48 @@ impl<'db> OverloadLiteral<'db> {
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
let mut signature = self.raw_signature(db);
let scope = self.body_scope(db);
let module = parsed_module(db, self.file(db)).load(db);
let function_node = scope.node(db).expect_function().node(&module);
let index = semantic_index(db, scope.file(db));
let file_scope_id = scope.file_scope_id(db);
let is_generator = file_scope_id.is_generator_function(index);
if function_node.is_async && !is_generator {
signature = signature.wrap_coroutine_return_type(db);
}
signature = signature.mark_typevars_inferable(db);
let pep695_ctx = function_node.type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(db, index, self.definition(db), type_params)
});
let legacy_ctx = GenericContext::from_function_params(
db,
self.definition(db),
signature.parameters(),
signature.return_ty,
);
// We need to update `signature.generic_context` here,
// because type variables in `GenericContext::variables` are still non-inferable.
signature.generic_context =
GenericContext::merge_pep695_and_legacy(db, pep695_ctx, legacy_ctx);
signature
}
/// Typed internally-visible "raw" signature for this function.
/// That is, type variables in parameter types and the return type remain non-inferable,
/// and the return types of async functions are not wrapped in `CoroutineType[...]`.
///
/// ## Warning
///
/// This uses the semantic index to find the definition of the function. This means that if the
/// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
fn raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
/// `self` or `cls` can be implicitly positional-only if:
/// - It is a method AND
/// - No parameters in the method use PEP-570 syntax AND
@ -402,11 +444,11 @@ impl<'db> OverloadLiteral<'db> {
let function_stmt_node = scope.node(db).expect_function().node(&module);
let definition = self.definition(db);
let index = semantic_index(db, scope.file(db));
let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(db, index, definition, type_params)
});
let file_scope_id = scope.file_scope_id(db);
let is_generator = file_scope_id.is_generator_function(index);
let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param(
db,
self,
@ -417,10 +459,9 @@ impl<'db> OverloadLiteral<'db> {
Signature::from_function(
db,
generic_context,
pep695_ctx,
definition,
function_stmt_node,
is_generator,
has_implicitly_positional_first_parameter,
)
}
@ -599,6 +640,18 @@ impl<'db> FunctionLiteral<'db> {
fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.last_definition(db).signature(db)
}
/// Typed externally-visible "raw" signature of the last overload or implementation of this function.
///
/// ## Warning
///
/// This uses the semantic index to find the definition of the function. This means that if the
/// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.last_definition(db).raw_signature(db)
}
}
/// Represents a function type, which might be a non-generic function, or a specialization of a
@ -877,6 +930,17 @@ impl<'db> FunctionType<'db> {
.unwrap_or_else(|| self.literal(db).last_definition_signature(db))
}
/// Typed externally-visible "raw" signature of the last overload or implementation of this function.
#[salsa::tracked(
returns(ref),
cycle_fn=last_definition_signature_cycle_recover,
cycle_initial=last_definition_signature_cycle_initial,
heap_size=ruff_memory_usage::heap_size,
)]
pub(crate) fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.literal(db).last_definition_raw_signature(db)
}
/// Convert the `FunctionType` into a [`CallableType`].
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
CallableType::new(db, self.signature(db), false)

View file

@ -291,6 +291,28 @@ impl<'db> GenericContext<'db> {
Some(Self::from_typevar_instances(db, variables))
}
pub(crate) fn merge_pep695_and_legacy(
db: &'db dyn Db,
pep695_generic_context: Option<Self>,
legacy_generic_context: Option<Self>,
) -> Option<Self> {
match (legacy_generic_context, pep695_generic_context) {
(Some(legacy_ctx), Some(ctx)) => {
if legacy_ctx
.variables(db)
.exactly_one()
.is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
{
Some(legacy_ctx.merge(db, ctx))
} else {
// TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
Some(ctx)
}
}
(left, right) => left.or(right),
}
}
/// Creates a generic context from the legacy `TypeVar`s that appear in class's base class
/// list.
pub(crate) fn from_base_classes(
@ -1174,7 +1196,7 @@ impl<'db> SpecializationBuilder<'db> {
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
mut actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
if formal == actual {
return Ok(());
@ -1203,6 +1225,10 @@ impl<'db> SpecializationBuilder<'db> {
return Ok(());
}
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
// So, here we remove the union elements that are not related to `formal`.
actual = actual.filter_disjoint_elements(self.db, formal);
match (formal, actual) {
// TODO: We haven't implemented a full unification solver yet. If typevars appear in
// multiple union elements, we ideally want to express that _only one_ of them needs to
@ -1228,9 +1254,15 @@ impl<'db> SpecializationBuilder<'db> {
// def _(y: str | int | None):
// reveal_type(g(x)) # revealed: str | int
// ```
let formal_bound_typevars =
(formal_union.elements(self.db).iter()).filter_map(|ty| ty.into_type_var());
let Ok(formal_bound_typevar) = formal_bound_typevars.exactly_one() else {
// We do not handle cases where the `formal` types contain other types that contain type variables
// to prevent incorrect specialization: e.g. `T = int | list[int]` for `formal: T | list[T], actual: int | list[int]`
// (the correct specialization is `T = int`).
let types_have_typevars = formal_union
.elements(self.db)
.iter()
.filter(|ty| ty.has_type_var(self.db));
let Ok(Type::TypeVar(formal_bound_typevar)) = types_have_typevars.exactly_one()
else {
return Ok(());
};
if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) {
@ -1241,7 +1273,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() {
return Ok(());
}
self.add_type_mapping(formal_bound_typevar, remaining_actual);
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
}
(Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not

View file

@ -50,6 +50,7 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::diagnostic::TypeCheckDiagnostics;
use crate::types::function::FunctionType;
use crate::types::generics::Specialization;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers};
@ -389,6 +390,12 @@ impl<'db> TypeContext<'db> {
self.annotation
.and_then(|ty| ty.known_specialization(db, known_class))
}
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
Self {
annotation: self.annotation.map(f),
}
}
}
/// Returns the statically-known truthiness of a given expression.
@ -487,6 +494,30 @@ pub(crate) fn nearest_enclosing_class<'db>(
})
}
/// Returns the type of the nearest enclosing function for the given scope.
///
/// This function walks up the ancestor scopes starting from the given scope,
/// and finds the closest (non-lambda) function definition.
///
/// Returns `None` if no enclosing function is found.
pub(crate) fn nearest_enclosing_function<'db>(
db: &'db dyn Db,
semantic: &SemanticIndex<'db>,
scope: ScopeId,
) -> Option<FunctionType<'db>> {
semantic
.ancestor_scopes(scope.file_scope_id(db))
.find_map(|(_, ancestor_scope)| {
let func = ancestor_scope.node().as_function()?;
let definition = semantic.expect_single_definition(func);
let inference = infer_definition_types(db, definition);
inference
.undecorated_type()
.unwrap_or_else(|| inference.declaration_type(definition).inner_type())
.into_function_literal()
})
}
/// A region within which we can infer types.
#[derive(Copy, Clone, Debug)]
pub(crate) enum InferenceRegion<'db> {

View file

@ -79,6 +79,7 @@ use crate::types::function::{
};
use crate::types::generics::{GenericContext, bind_typevar};
use crate::types::generics::{LegacyGenericBase, SpecializationBuilder};
use crate::types::infer::nearest_enclosing_function;
use crate::types::instance::SliceLiteral;
use crate::types::mro::MroErrorKind;
use crate::types::signatures::Signature;
@ -5101,9 +5102,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
fn infer_return_statement(&mut self, ret: &ast::StmtReturn) {
if let Some(ty) =
self.infer_optional_expression(ret.value.as_deref(), TypeContext::default())
{
let tcx = if ret.value.is_some() {
nearest_enclosing_function(self.db(), self.index, self.scope())
.map(|func| {
// When inferring expressions within a function body,
// the expected type passed should be the "raw" type,
// i.e. type variables in the return type are non-inferable,
// and the return types of async functions are not wrapped in `CoroutineType[...]`.
TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty)
})
.unwrap_or_default()
} else {
TypeContext::default()
};
if let Some(ty) = self.infer_optional_expression(ret.value.as_deref(), tcx) {
let range = ret
.value
.as_ref()
@ -5900,6 +5912,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return None;
};
let tcx = tcx.map_annotation(|annotation| {
// Remove any union elements of `annotation` that are not related to `collection_ty`.
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
let collection_ty = collection_class.to_instance(self.db());
annotation.filter_disjoint_elements(self.db(), collection_ty)
});
// Extract the annotated type of `T`, if provided.
let annotated_elt_tys = tcx
.known_specialization(self.db(), collection_class)

View file

@ -26,9 +26,10 @@ use crate::types::function::FunctionType;
use crate::types::generics::{GenericContext, typing_self, walk_generic_context};
use crate::types::infer::nearest_enclosing_class;
use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor,
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind,
NormalizedVisitor, TypeContext, TypeMapping, TypeRelation, VarianceInferable, todo_type,
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassLiteral,
FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor,
KnownClass, MaterializationKind, NormalizedVisitor, TypeContext, TypeMapping, TypeRelation,
VarianceInferable, todo_type,
};
use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name};
@ -419,10 +420,9 @@ impl<'db> Signature<'db> {
/// Return a typed signature from a function definition.
pub(super) fn from_function(
db: &'db dyn Db,
generic_context: Option<GenericContext<'db>>,
pep695_generic_context: Option<GenericContext<'db>>,
definition: Definition<'db>,
function_node: &ast::StmtFunctionDef,
is_generator: bool,
has_implicitly_positional_first_parameter: bool,
) -> Self {
let parameters = Parameters::from_parameters(
@ -431,38 +431,17 @@ impl<'db> Signature<'db> {
function_node.parameters.as_ref(),
has_implicitly_positional_first_parameter,
);
let return_ty = function_node.returns.as_ref().map(|returns| {
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref())
.apply_type_mapping(
db,
&TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
TypeContext::default(),
);
if function_node.is_async && !is_generator {
KnownClass::CoroutineType
.to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
} else {
plain_return_ty
}
});
let return_ty = function_node
.returns
.as_ref()
.map(|returns| definition_expression_type(db, definition, returns.as_ref()));
let legacy_generic_context =
GenericContext::from_function_params(db, definition, &parameters, return_ty);
let full_generic_context = match (legacy_generic_context, generic_context) {
(Some(legacy_ctx), Some(ctx)) => {
if legacy_ctx
.variables(db)
.exactly_one()
.is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
{
Some(legacy_ctx.merge(db, ctx))
} else {
// TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
Some(ctx)
}
}
(left, right) => left.or(right),
};
let full_generic_context = GenericContext::merge_pep695_and_legacy(
db,
pep695_generic_context,
legacy_generic_context,
);
Self {
generic_context: full_generic_context,
@ -472,6 +451,27 @@ impl<'db> Signature<'db> {
}
}
pub(super) fn mark_typevars_inferable(self, db: &'db dyn Db) -> Self {
if let Some(definition) = self.definition {
self.apply_type_mapping_impl(
db,
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(definition))),
TypeContext::default(),
&ApplyTypeMappingVisitor::default(),
)
} else {
self
}
}
pub(super) fn wrap_coroutine_return_type(self, db: &'db dyn Db) -> Self {
let return_ty = self.return_ty.map(|return_ty| {
KnownClass::CoroutineType
.to_specialized_instance(db, [Type::any(), Type::any(), return_ty])
});
Self { return_ty, ..self }
}
/// Returns the signature which accepts any parameters and returns an `Unknown` type.
pub(crate) fn unknown() -> Self {
Self::new(Parameters::unknown(), Some(Type::unknown()))
@ -1728,13 +1728,9 @@ impl<'db> Parameter<'db> {
kind: ParameterKind<'db>,
) -> Self {
Self {
annotated_type: parameter.annotation().map(|annotation| {
definition_expression_type(db, definition, annotation).apply_type_mapping(
db,
&TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
TypeContext::default(),
)
}),
annotated_type: parameter
.annotation()
.map(|annotation| definition_expression_type(db, definition, annotation)),
kind,
form: ParameterForm::Value,
inferred_annotation: false,