mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-29 03:02:27 +00:00
[ty] bidirectional type inference using function return type annotations (#20528)
Some checks failed
CI / cargo fmt (push) Has been cancelled
CI / mkdocs (push) Has been cancelled
CI / Determine changes (push) Has been cancelled
CI / cargo build (release) (push) Has been cancelled
CI / python package (push) Has been cancelled
CI / pre-commit (push) Has been cancelled
[ty Playground] Release / publish (push) Has been cancelled
CI / cargo clippy (push) Has been cancelled
CI / cargo test (linux) (push) Has been cancelled
CI / cargo test (linux, release) (push) Has been cancelled
CI / cargo test (windows) (push) Has been cancelled
CI / cargo test (wasm) (push) Has been cancelled
CI / cargo build (msrv) (push) Has been cancelled
CI / cargo fuzz build (push) Has been cancelled
CI / fuzz parser (push) Has been cancelled
CI / test scripts (push) Has been cancelled
CI / ecosystem (push) Has been cancelled
CI / Fuzz for new ty panics (push) Has been cancelled
CI / cargo shear (push) Has been cancelled
CI / ty completion evaluation (push) Has been cancelled
CI / formatter instabilities and black similarity (push) Has been cancelled
CI / test ruff-lsp (push) Has been cancelled
CI / check playground (push) Has been cancelled
CI / benchmarks instrumented (ruff) (push) Has been cancelled
CI / benchmarks instrumented (ty) (push) Has been cancelled
CI / benchmarks walltime (medium|multithreaded) (push) Has been cancelled
CI / benchmarks walltime (small|large) (push) Has been cancelled
## Summary
Implements bidirectional type inference using function return type
annotations.
This PR was originally proposed to solve astral-sh/ty#1167, but it does not fully
resolve that issue on its own.
Additionally, I believe we need to allow dataclasses to generate their own
`__new__` methods, to [use constructor return types for
inference](5844c0103d/crates/ty_python_semantic/src/types.rs (L5326-L5328)),
and to add a mechanism for discarding type narrowing such as `& ~AlwaysFalsy` where
necessary (at a more general level than this PR).
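As a rough, editorial illustration of the behavior described above (not part of the diff), the sketch below mirrors the examples in the new `bidirectional.md` test file; it assumes Python 3.12 for the PEP 695 generic syntax and uses `typing.reveal_type` so it also runs as plain Python:

```py
from typing import reveal_type

def list1[T](x: T) -> list[T]:
    return [x]

def wrap() -> list[int]:
    # Inferred "inside out", `list1(1)` would be `list[Literal[1]]`, which is not
    # assignable to `list[int]` because `list` is invariant in its element type.
    # With the annotated return type propagated into the call, `T` specializes to
    # `int` and the return type check passes.
    return list1(1)

l: list[int] = list1(1)
reveal_type(l)  # with this PR, ty reveals `list[int]` rather than `list[Literal[1]]`
```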
## Test Plan
`mdtest/bidirectional.md` is added.
---------
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ibraheem Ahmed <ibraheem@ibraheem.ca>
parent 11a9e7ee44
commit dc64c08633

11 changed files with 442 additions and 58 deletions
@@ -146,7 +146,75 @@ r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str]
```

## Incorrect collection literal assignments are complained aobut
## Optional collection literal annotations are understood

```toml
[environment]
python-version = "3.12"
```

```py
import typing

a: list[int] | None = [1, 2, 3]
reveal_type(a) # revealed: list[int]

b: list[int | str] | None = [1, 2, 3]
reveal_type(b) # revealed: list[int | str]

c: typing.List[int] | None = [1, 2, 3]
reveal_type(c) # revealed: list[int]

d: list[typing.Any] | None = []
reveal_type(d) # revealed: list[Any]

e: set[int] | None = {1, 2, 3}
reveal_type(e) # revealed: set[int]

f: set[int | str] | None = {1, 2, 3}
reveal_type(f) # revealed: set[int | str]

g: typing.Set[int] | None = {1, 2, 3}
reveal_type(g) # revealed: set[int]

h: list[list[int]] | None = [[], [42]]
reveal_type(h) # revealed: list[list[int]]

i: list[typing.Any] | None = [1, 2, "3", ([4],)]
reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]]

j: list[tuple[str | int, ...]] | None = [(1, 2), ("foo", "bar"), ()]
reveal_type(j) # revealed: list[tuple[str | int, ...]]

k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7])]
reveal_type(k) # revealed: list[tuple[list[int], ...]]

l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]

type IntList = list[int]

m: IntList | None = [1, 2, 3]
reveal_type(m) # revealed: list[int]

n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3]
reveal_type(n) # revealed: list[Literal[1, 2, 3]]

o: list[typing.LiteralString] | None = ["a", "b", "c"]
reveal_type(o) # revealed: list[LiteralString]

p: dict[int, int] | None = {}
reveal_type(p) # revealed: dict[int, int]

q: dict[int | str, int] | None = {1: 1, 2: 2, 3: 3}
reveal_type(q) # revealed: dict[int | str, int]

r: dict[int | str, int | str] | None = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str]
```

## Incorrect collection literal assignments are complained about

```py
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`"

crates/ty_python_semantic/resources/mdtest/bidirectional.md (new file, 147 lines)
@@ -0,0 +1,147 @@
# Bidirectional type inference

ty partially supports bidirectional type inference. This is a mechanism for inferring the type of an
expression "from the outside in". Normally, type inference proceeds "from the inside out". That is,
in order to infer the type of an expression, the types of all sub-expressions must first be
inferred. There is no reverse dependency. However, when performing complex type inference, such as
when generics are involved, the type of an outer expression can sometimes be useful in inferring
inner expressions. Bidirectional type inference is a mechanism that propagates such "expected types"
to the inference of inner expressions.

## Propagating target type annotation

```toml
[environment]
python-version = "3.12"
```

```py
def list1[T](x: T) -> list[T]:
    return [x]

l1 = list1(1)
reveal_type(l1) # revealed: list[Literal[1]]
l2: list[int] = list1(1)
reveal_type(l2) # revealed: list[int]

# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l2 = l1

intermediate = list1(1)
# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]`
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l3: list[int] = intermediate
# TODO: it would be nice if this were `list[int]`
reveal_type(intermediate) # revealed: list[Literal[1]]
reveal_type(l3) # revealed: list[int]

l4: list[int | str] | None = list1(1)
reveal_type(l4) # revealed: list[int | str]

def _(l: list[int] | None = None):
    l1 = l or list()
    reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]

    l2: list[int] = l or list()
    # it would be better if this were `list[int]`? (https://github.com/astral-sh/ty/issues/136)
    reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]

def f[T](x: T, cond: bool) -> T | list[T]:
    return x if cond else [x]

# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True)
```

`typed_dict.py`:

```py
from typing import TypedDict

class TD(TypedDict):
    x: int

d1 = {"x": 1}
d2: TD = {"x": 1}
d3: dict[str, int] = {"x": 1}

reveal_type(d1) # revealed: dict[Unknown | str, Unknown | int]
reveal_type(d2) # revealed: TD
reveal_type(d3) # revealed: dict[str, int]

def _() -> TD:
    return {"x": 1}

def _() -> TD:
    # error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
    return {}
```

## Propagating return type annotation

```toml
[environment]
python-version = "3.12"
```

```py
from typing import overload, Callable

def list1[T](x: T) -> list[T]:
    return [x]

def get_data() -> dict | None:
    return {}

def wrap_data() -> list[dict]:
    if not (res := get_data()):
        return list1({})
    reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
    # `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible,
    # but the return type check passes here because the type of `list1(res)` is inferred
    # by bidirectional type inference using the annotated return type, and the type of `res` is not used.
    return list1(res)

def wrap_data2() -> list[dict] | None:
    if not (res := get_data()):
        return None
    reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
    return list1(res)

def deco[T](func: Callable[[], T]) -> Callable[[], T]:
    return func

def outer() -> Callable[[], list[dict]]:
    @deco
    def inner() -> list[dict]:
        if not (res := get_data()):
            return list1({})
        reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
        return list1(res)
    return inner

@overload
def f(x: int) -> list[int]: ...
@overload
def f(x: str) -> list[str]: ...
def f(x: int | str) -> list[int] | list[str]:
    # `list[int] | list[str]` is disjoint from `list[int | str]`.
    if isinstance(x, int):
        return list1(x)
    else:
        return list1(x)

reveal_type(f(1)) # revealed: list[int]
reveal_type(f("a")) # revealed: list[str]

async def g() -> list[int | str]:
    return list1(1)

def h[T](x: T, cond: bool) -> T | list[T]:
    return i(x, cond)

def i[T](x: T, cond: bool) -> T | list[T]:
    return x if cond else [x]
```
@@ -323,6 +323,9 @@ def union_param(x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown

def _(x: int | None):
    reveal_type(union_param(x)) # revealed: int
```

```py

@@ -286,6 +286,9 @@ def union_param[T](x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown

def _(x: int | None):
    reveal_type(union_param(x)) # revealed: int
```

```py
@@ -125,9 +125,10 @@ def homogeneous_list[T](*args: T) -> list[T]:
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
reveal_type(plot2["y"]) # revealed: list[int]
# TODO: no error
# error: [invalid-argument-type]

plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
reveal_type(plot3["y"]) # revealed: list[int]
reveal_type(plot3["x"]) # revealed: list[int] | None

Y = "y"
X = "x"

@@ -362,7 +363,7 @@ qualifiers override the class-level `total` setting, which sets the default (`to
all keys are required by default, `total=False` means that all keys are non-required by default):

```py
from typing_extensions import TypedDict, Required, NotRequired
from typing_extensions import TypedDict, Required, NotRequired, Final

# total=False by default, but id is explicitly Required
class Message(TypedDict, total=False):

@@ -376,10 +377,17 @@ class User(TypedDict):
    email: Required[str] # Explicitly required (redundant here)
    bio: NotRequired[str] # Optional despite total=True

ID: Final = "id"

# Valid Message constructions
msg1 = Message(id=1) # id required, content optional
msg2 = Message(id=2, content="Hello") # both provided
msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional
msg4: Message = {"id": 4} # id required, content optional
msg5: Message = {ID: 5} # id required, content optional

def msg() -> Message:
    return {ID: 1}

# Valid User constructions
user1 = User(name="Alice", email="alice@example.com") # required fields
@@ -977,6 +977,10 @@ impl<'db> Type<'db> {
        }
    }

    pub(crate) fn has_type_var(self, db: &'db dyn Db) -> bool {
        any_over_type(db, self, &|ty| matches!(ty, Type::TypeVar(_)), false)
    }

    pub(crate) const fn into_class_literal(self) -> Option<ClassLiteral<'db>> {
        match self {
            Type::ClassLiteral(class_type) => Some(class_type),

@@ -1167,6 +1171,15 @@ impl<'db> Type<'db> {
        if yes { self.negate(db) } else { *self }
    }

    /// Remove the union elements that are not related to `target`.
    pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
        if let Type::Union(union) = self {
            union.filter(db, |elem| !elem.is_disjoint_from(db, target))
        } else {
            self
        }
    }

    /// Returns the fallback instance type that a literal is an instance of, or `None` if the type
    /// is not a literal.
    pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
@@ -341,6 +341,48 @@ impl<'db> OverloadLiteral<'db> {
    /// a cross-module dependency directly on the full AST which will lead to cache
    /// over-invalidation.
    pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
        let mut signature = self.raw_signature(db);

        let scope = self.body_scope(db);
        let module = parsed_module(db, self.file(db)).load(db);
        let function_node = scope.node(db).expect_function().node(&module);
        let index = semantic_index(db, scope.file(db));
        let file_scope_id = scope.file_scope_id(db);
        let is_generator = file_scope_id.is_generator_function(index);

        if function_node.is_async && !is_generator {
            signature = signature.wrap_coroutine_return_type(db);
        }
        signature = signature.mark_typevars_inferable(db);

        let pep695_ctx = function_node.type_params.as_ref().map(|type_params| {
            GenericContext::from_type_params(db, index, self.definition(db), type_params)
        });
        let legacy_ctx = GenericContext::from_function_params(
            db,
            self.definition(db),
            signature.parameters(),
            signature.return_ty,
        );
        // We need to update `signature.generic_context` here,
        // because type variables in `GenericContext::variables` are still non-inferable.
        signature.generic_context =
            GenericContext::merge_pep695_and_legacy(db, pep695_ctx, legacy_ctx);

        signature
    }

    /// Typed internally-visible "raw" signature for this function.
    /// That is, type variables in parameter types and the return type remain non-inferable,
    /// and the return types of async functions are not wrapped in `CoroutineType[...]`.
    ///
    /// ## Warning
    ///
    /// This uses the semantic index to find the definition of the function. This means that if the
    /// calling query is not in the same file as this function is defined in, then this will create
    /// a cross-module dependency directly on the full AST which will lead to cache
    /// over-invalidation.
    fn raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
        /// `self` or `cls` can be implicitly positional-only if:
        /// - It is a method AND
        /// - No parameters in the method use PEP-570 syntax AND
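The distinction drawn above between the "raw" signature and the externally visible one is easiest to see with an async function. The following plain-Python sketch is editorial commentary, not part of the diff; it only restates the standard coroutine behaviour that the `wrap_coroutine_return_type`/`raw_signature` split models:

```py
import asyncio

async def fetch() -> int:
    # Inside the body, `return` values are checked against the annotated `int`
    # (the "raw", unwrapped return type).
    return 1

# From the caller's side, calling `fetch()` produces a coroutine, so the
# externally visible return type is effectively `CoroutineType[Any, Any, int]`.
coro = fetch()
print(asyncio.run(coro))  # prints 1
```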
@@ -402,11 +444,11 @@ impl<'db> OverloadLiteral<'db> {
        let function_stmt_node = scope.node(db).expect_function().node(&module);
        let definition = self.definition(db);
        let index = semantic_index(db, scope.file(db));
        let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
        let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| {
            GenericContext::from_type_params(db, index, definition, type_params)
        });
        let file_scope_id = scope.file_scope_id(db);
        let is_generator = file_scope_id.is_generator_function(index);

        let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param(
            db,
            self,

@@ -417,10 +459,9 @@

        Signature::from_function(
            db,
            generic_context,
            pep695_ctx,
            definition,
            function_stmt_node,
            is_generator,
            has_implicitly_positional_first_parameter,
        )
    }

@@ -599,6 +640,18 @@ impl<'db> FunctionLiteral<'db> {
    fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> {
        self.last_definition(db).signature(db)
    }

    /// Typed externally-visible "raw" signature of the last overload or implementation of this function.
    ///
    /// ## Warning
    ///
    /// This uses the semantic index to find the definition of the function. This means that if the
    /// calling query is not in the same file as this function is defined in, then this will create
    /// a cross-module dependency directly on the full AST which will lead to cache
    /// over-invalidation.
    fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
        self.last_definition(db).raw_signature(db)
    }
}

/// Represents a function type, which might be a non-generic function, or a specialization of a

@@ -877,6 +930,17 @@ impl<'db> FunctionType<'db> {
            .unwrap_or_else(|| self.literal(db).last_definition_signature(db))
    }

    /// Typed externally-visible "raw" signature of the last overload or implementation of this function.
    #[salsa::tracked(
        returns(ref),
        cycle_fn=last_definition_signature_cycle_recover,
        cycle_initial=last_definition_signature_cycle_initial,
        heap_size=ruff_memory_usage::heap_size,
    )]
    pub(crate) fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
        self.literal(db).last_definition_raw_signature(db)
    }

    /// Convert the `FunctionType` into a [`CallableType`].
    pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
        CallableType::new(db, self.signature(db), false)
@@ -291,6 +291,28 @@ impl<'db> GenericContext<'db> {
        Some(Self::from_typevar_instances(db, variables))
    }

    pub(crate) fn merge_pep695_and_legacy(
        db: &'db dyn Db,
        pep695_generic_context: Option<Self>,
        legacy_generic_context: Option<Self>,
    ) -> Option<Self> {
        match (legacy_generic_context, pep695_generic_context) {
            (Some(legacy_ctx), Some(ctx)) => {
                if legacy_ctx
                    .variables(db)
                    .exactly_one()
                    .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
                {
                    Some(legacy_ctx.merge(db, ctx))
                } else {
                    // TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
                    Some(ctx)
                }
            }
            (left, right) => left.or(right),
        }
    }

    /// Creates a generic context from the legacy `TypeVar`s that appear in class's base class
    /// list.
    pub(crate) fn from_base_classes(
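As a hedged illustration (editorial, not from the diff) of when a PEP 695 context and a "legacy" context coexist: the `is_self` special case above covers methods whose only legacy type variable is the implicit `Self`, while genuinely mixing a legacy `TypeVar` with PEP 695 syntax is the situation the TODO diagnostic refers to. The class and method names below are made up:

```py
from typing import Self, TypeVar

U = TypeVar("U")

class Builder:
    # The implicit `Self` type variable plus the PEP 695 parameter `T`: per the
    # special case above, the two contexts are merged into one.
    def tag[T](self, value: T) -> tuple[Self, T]:
        return (self, value)

    # Mixing a legacy `TypeVar` with PEP 695 syntax is not allowed; per the TODO
    # above, only the PEP 695 context is kept for now (a diagnostic is planned).
    def mixed[T](self, value: T, extra: U) -> T:
        return value
```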
@@ -1174,7 +1196,7 @@ impl<'db> SpecializationBuilder<'db> {
    pub(crate) fn infer(
        &mut self,
        formal: Type<'db>,
        actual: Type<'db>,
        mut actual: Type<'db>,
    ) -> Result<(), SpecializationError<'db>> {
        if formal == actual {
            return Ok(());

@@ -1203,6 +1225,10 @@
            return Ok(());
        }

        // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
        // So, here we remove the union elements that are not related to `formal`.
        actual = actual.filter_disjoint_elements(self.db, formal);

        match (formal, actual) {
            // TODO: We haven't implemented a full unification solver yet. If typevars appear in
            // multiple union elements, we ideally want to express that _only one_ of them needs to

@@ -1228,9 +1254,15 @@
                // def _(y: str | int | None):
                //     reveal_type(g(x)) # revealed: str | int
                // ```
                let formal_bound_typevars =
                    (formal_union.elements(self.db).iter()).filter_map(|ty| ty.into_type_var());
                let Ok(formal_bound_typevar) = formal_bound_typevars.exactly_one() else {
                // We do not handle cases where the `formal` types contain other types that contain type variables
                // to prevent incorrect specialization: e.g. `T = int | list[int]` for `formal: T | list[T], actual: int | list[int]`
                // (the correct specialization is `T = int`).
                let types_have_typevars = formal_union
                    .elements(self.db)
                    .iter()
                    .filter(|ty| ty.has_type_var(self.db));
                let Ok(Type::TypeVar(formal_bound_typevar)) = types_have_typevars.exactly_one()
                else {
                    return Ok(());
                };
                if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) {

@@ -1241,7 +1273,7 @@
                if remaining_actual.is_never() {
                    return Ok(());
                }
                self.add_type_mapping(formal_bound_typevar, remaining_actual);
                self.add_type_mapping(*formal_bound_typevar, remaining_actual);
            }
            (Type::Union(formal), _) => {
                // Second, if the formal is a union, and precisely one union element _is_ a typevar (not
@@ -50,6 +50,7 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::diagnostic::TypeCheckDiagnostics;
use crate::types::function::FunctionType;
use crate::types::generics::Specialization;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers};

@@ -389,6 +390,12 @@ impl<'db> TypeContext<'db> {
        self.annotation
            .and_then(|ty| ty.known_specialization(db, known_class))
    }

    pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
        Self {
            annotation: self.annotation.map(f),
        }
    }
}

/// Returns the statically-known truthiness of a given expression.

@@ -487,6 +494,30 @@ pub(crate) fn nearest_enclosing_class<'db>(
        })
}

/// Returns the type of the nearest enclosing function for the given scope.
///
/// This function walks up the ancestor scopes starting from the given scope,
/// and finds the closest (non-lambda) function definition.
///
/// Returns `None` if no enclosing function is found.
pub(crate) fn nearest_enclosing_function<'db>(
    db: &'db dyn Db,
    semantic: &SemanticIndex<'db>,
    scope: ScopeId,
) -> Option<FunctionType<'db>> {
    semantic
        .ancestor_scopes(scope.file_scope_id(db))
        .find_map(|(_, ancestor_scope)| {
            let func = ancestor_scope.node().as_function()?;
            let definition = semantic.expect_single_definition(func);
            let inference = infer_definition_types(db, definition);
            inference
                .undecorated_type()
                .unwrap_or_else(|| inference.declaration_type(definition).inner_type())
                .into_function_literal()
        })
}

/// A region within which we can infer types.
#[derive(Copy, Clone, Debug)]
pub(crate) enum InferenceRegion<'db> {
@@ -79,6 +79,7 @@ use crate::types::function::{
};
use crate::types::generics::{GenericContext, bind_typevar};
use crate::types::generics::{LegacyGenericBase, SpecializationBuilder};
use crate::types::infer::nearest_enclosing_function;
use crate::types::instance::SliceLiteral;
use crate::types::mro::MroErrorKind;
use crate::types::signatures::Signature;

@@ -5101,9 +5102,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
    }

    fn infer_return_statement(&mut self, ret: &ast::StmtReturn) {
        if let Some(ty) =
            self.infer_optional_expression(ret.value.as_deref(), TypeContext::default())
        {
        let tcx = if ret.value.is_some() {
            nearest_enclosing_function(self.db(), self.index, self.scope())
                .map(|func| {
                    // When inferring expressions within a function body,
                    // the expected type passed should be the "raw" type,
                    // i.e. type variables in the return type are non-inferable,
                    // and the return types of async functions are not wrapped in `CoroutineType[...]`.
                    TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty)
                })
                .unwrap_or_default()
        } else {
            TypeContext::default()
        };
        if let Some(ty) = self.infer_optional_expression(ret.value.as_deref(), tcx) {
            let range = ret
                .value
                .as_ref()

@@ -5900,6 +5912,13 @@
            return None;
        };

        let tcx = tcx.map_annotation(|annotation| {
            // Remove any union elements of `annotation` that are not related to `collection_ty`.
            // e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
            let collection_ty = collection_class.to_instance(self.db());
            annotation.filter_disjoint_elements(self.db(), collection_ty)
        });

        // Extract the annotated type of `T`, if provided.
        let annotated_elt_tys = tcx
            .known_specialization(self.db(), collection_class)
@@ -26,9 +26,10 @@ use crate::types::function::FunctionType;
use crate::types::generics::{GenericContext, typing_self, walk_generic_context};
use crate::types::infer::nearest_enclosing_class;
use crate::types::{
    ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor,
    HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind,
    NormalizedVisitor, TypeContext, TypeMapping, TypeRelation, VarianceInferable, todo_type,
    ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassLiteral,
    FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor,
    KnownClass, MaterializationKind, NormalizedVisitor, TypeContext, TypeMapping, TypeRelation,
    VarianceInferable, todo_type,
};
use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name};

@@ -419,10 +420,9 @@ impl<'db> Signature<'db> {
    /// Return a typed signature from a function definition.
    pub(super) fn from_function(
        db: &'db dyn Db,
        generic_context: Option<GenericContext<'db>>,
        pep695_generic_context: Option<GenericContext<'db>>,
        definition: Definition<'db>,
        function_node: &ast::StmtFunctionDef,
        is_generator: bool,
        has_implicitly_positional_first_parameter: bool,
    ) -> Self {
        let parameters = Parameters::from_parameters(

@@ -431,38 +431,17 @@
            function_node.parameters.as_ref(),
            has_implicitly_positional_first_parameter,
        );
        let return_ty = function_node.returns.as_ref().map(|returns| {
            let plain_return_ty = definition_expression_type(db, definition, returns.as_ref())
                .apply_type_mapping(
                    db,
                    &TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
                    TypeContext::default(),
                );
            if function_node.is_async && !is_generator {
                KnownClass::CoroutineType
                    .to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
            } else {
                plain_return_ty
            }
        });
        let return_ty = function_node
            .returns
            .as_ref()
            .map(|returns| definition_expression_type(db, definition, returns.as_ref()));
        let legacy_generic_context =
            GenericContext::from_function_params(db, definition, &parameters, return_ty);

        let full_generic_context = match (legacy_generic_context, generic_context) {
            (Some(legacy_ctx), Some(ctx)) => {
                if legacy_ctx
                    .variables(db)
                    .exactly_one()
                    .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
                {
                    Some(legacy_ctx.merge(db, ctx))
                } else {
                    // TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
                    Some(ctx)
                }
            }
            (left, right) => left.or(right),
        };
        let full_generic_context = GenericContext::merge_pep695_and_legacy(
            db,
            pep695_generic_context,
            legacy_generic_context,
        );

        Self {
            generic_context: full_generic_context,

@@ -472,6 +451,27 @@
        }
    }

    pub(super) fn mark_typevars_inferable(self, db: &'db dyn Db) -> Self {
        if let Some(definition) = self.definition {
            self.apply_type_mapping_impl(
                db,
                &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(definition))),
                TypeContext::default(),
                &ApplyTypeMappingVisitor::default(),
            )
        } else {
            self
        }
    }

    pub(super) fn wrap_coroutine_return_type(self, db: &'db dyn Db) -> Self {
        let return_ty = self.return_ty.map(|return_ty| {
            KnownClass::CoroutineType
                .to_specialized_instance(db, [Type::any(), Type::any(), return_ty])
        });
        Self { return_ty, ..self }
    }

    /// Returns the signature which accepts any parameters and returns an `Unknown` type.
    pub(crate) fn unknown() -> Self {
        Self::new(Parameters::unknown(), Some(Type::unknown()))

@@ -1728,13 +1728,9 @@ impl<'db> Parameter<'db> {
        kind: ParameterKind<'db>,
    ) -> Self {
        Self {
            annotated_type: parameter.annotation().map(|annotation| {
                definition_expression_type(db, definition, annotation).apply_type_mapping(
                    db,
                    &TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
                    TypeContext::default(),
                )
            }),
            annotated_type: parameter
                .annotation()
                .map(|annotation| definition_expression_type(db, definition, annotation)),
            kind,
            form: ParameterForm::Value,
            inferred_annotation: false,