[ty] Improve generic call expression inference (#21210)
## Summary

Implements https://github.com/astral-sh/ty/issues/1356 and https://github.com/astral-sh/ty/issues/136#issuecomment-3413669994.
parent d258302b08
commit 98869f0307

8 changed files with 655 additions and 153 deletions
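To make the change concrete, here is a minimal sketch of the behavior this PR targets (a hypothetical snippet that mirrors the mdtest examples in the diff below, not code from the diff itself): when a generic call appears on the right-hand side of an annotated assignment, the declared type is now used as type context both for solving the call's type variables and for inferring its arguments.

```py
from typing import Literal

def f[T](x: T) -> list[T]:
    return [x]

# Without any type context, `T` solves to the narrow literal type.
a = f("a")
reveal_type(a)  # revealed: list[Literal["a"]]

# With the declared annotation as type context, the solved specialization
# follows the annotation and preserves the order of its union elements.
b: list[int | Literal["a"]] = f("a")
reveal_type(b)  # revealed: list[int | Literal["a"]]
```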
@@ -417,6 +417,8 @@ reveal_type(x) # revealed: Literal[1]
python-version = "3.12"
```

`generic_list.py`:

```py
from typing import Literal

@@ -427,14 +429,13 @@ a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]

b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[Literal["a"] | int]
reveal_type(b) # revealed: list[int | Literal["a"]]

c: list[int | str] = f("a")
reveal_type(c) # revealed: list[str | int]
reveal_type(c) # revealed: list[int | str]

d: list[int | tuple[int, int]] = f((1, 2))
# TODO: We could avoid reordering the union elements here.
reveal_type(d) # revealed: list[tuple[int, int] | int]
reveal_type(d) # revealed: list[int | tuple[int, int]]

e: list[int] = f(True)
reveal_type(e) # revealed: list[int]

@@ -455,10 +456,218 @@ j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True]
```

Types are not widened unnecessarily:
A function's arguments are also inferred using the type context:

`typed_dict.py`:

```py
def id[T](x: T) -> T:
from typing import TypedDict

class TD(TypedDict):
    x: int

def f[T](x: list[T]) -> T:
    return x[0]

a: TD = f([{"x": 0}, {"x": 1}])
reveal_type(a) # revealed: TD

b: TD | None = f([{"x": 0}, {"x": 1}])
reveal_type(b) # revealed: TD

# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y""
# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD`"
c: TD = f([{"y": 0}, {"x": 1}])

# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y""
# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD | None`"
c: TD | None = f([{"y": 0}, {"x": 1}])
```

But not in a way that leads to assignability errors:

`dict_any.py`:

```py
from typing import TypedDict, Any

class TD(TypedDict, total=False):
    x: str

class TD2(TypedDict):
    x: str

def f(self, dt: dict[str, Any], key: str):
    # TODO: This should not error once typed dict assignability is implemented.
    # error: [invalid-assignment]
    x1: TD = dt.get(key, {})
    reveal_type(x1) # revealed: TD

    x2: TD = dt.get(key, {"x": 0})
    reveal_type(x2) # revealed: Any

    x3: TD | None = dt.get(key, {})
    # TODO: This should reveal `Any` once typed dict assignability is implemented.
    reveal_type(x3) # revealed: Any | None

    x4: TD | None = dt.get(key, {"x": 0})
    reveal_type(x4) # revealed: Any

    x5: TD2 = dt.get(key, {})
    reveal_type(x5) # revealed: Any

    x6: TD2 = dt.get(key, {"x": 0})
    reveal_type(x6) # revealed: Any

    x7: TD2 | None = dt.get(key, {})
    reveal_type(x7) # revealed: Any

    x8: TD2 | None = dt.get(key, {"x": 0})
    reveal_type(x8) # revealed: Any
```

## Prefer the declared type of generic classes

```toml
[environment]
python-version = "3.12"
```

```py
from typing import Any

def f[T](x: T) -> list[T]:
    return [x]

def f2[T](x: T) -> list[T] | None:
    return [x]

def f3[T](x: T) -> list[T] | dict[T, T]:
    return [x]

a = f(1)
reveal_type(a) # revealed: list[Literal[1]]

b: list[Any] = f(1)
reveal_type(b) # revealed: list[Any]

c: list[Any] = [1]
reveal_type(c) # revealed: list[Any]

d: list[Any] | None = f(1)
reveal_type(d) # revealed: list[Any]

e: list[Any] | None = [1]
reveal_type(e) # revealed: list[Any]

f: list[Any] | None = f2(1)
# TODO: Better constraint solver.
reveal_type(f) # revealed: list[Literal[1]] | None

g: list[Any] | dict[Any, Any] = f3(1)
# TODO: Better constraint solver.
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
```

We currently prefer the generic declared type regardless of its variance:

```py
class Bivariant[T]:
    pass

class Covariant[T]:
    def pop(self) -> T:
        raise NotImplementedError

class Contravariant[T]:
    def push(self, value: T) -> None:
        pass

class Invariant[T]:
    x: T

def bivariant[T](x: T) -> Bivariant[T]:
    return Bivariant()

def covariant[T](x: T) -> Covariant[T]:
    return Covariant()

def contravariant[T](x: T) -> Contravariant[T]:
    return Contravariant()

def invariant[T](x: T) -> Invariant[T]:
    return Invariant()

x1 = bivariant(1)
x2 = covariant(1)
x3 = contravariant(1)
x4 = invariant(1)

reveal_type(x1) # revealed: Bivariant[Literal[1]]
reveal_type(x2) # revealed: Covariant[Literal[1]]
reveal_type(x3) # revealed: Contravariant[Literal[1]]
reveal_type(x4) # revealed: Invariant[Literal[1]]

x5: Bivariant[Any] = bivariant(1)
x6: Covariant[Any] = covariant(1)
x7: Contravariant[Any] = contravariant(1)
x8: Invariant[Any] = invariant(1)

# TODO: This could reveal `Bivariant[Any]`.
reveal_type(x5) # revealed: Bivariant[Literal[1]]
reveal_type(x6) # revealed: Covariant[Any]
reveal_type(x7) # revealed: Contravariant[Any]
reveal_type(x8) # revealed: Invariant[Any]
```

## Narrow generic unions

```toml
[environment]
python-version = "3.12"
```

```py
from typing import reveal_type, TypedDict

def identity[T](x: T) -> T:
    return x

def _(narrow: dict[str, str], target: list[str] | dict[str, str] | None):
    target = identity(narrow)
    reveal_type(target) # revealed: dict[str, str]

def _(narrow: list[str], target: list[str] | dict[str, str] | None):
    target = identity(narrow)
    reveal_type(target) # revealed: list[str]

def _(narrow: list[str] | dict[str, str], target: list[str] | dict[str, str] | None):
    target = identity(narrow)
    reveal_type(target) # revealed: list[str] | dict[str, str]

class TD(TypedDict):
    x: int

def _(target: list[TD] | dict[str, TD] | None):
    target = identity([{"x": 1}])
    reveal_type(target) # revealed: list[TD]

def _(target: list[TD] | dict[str, TD] | None):
    target = identity({"x": {"x": 1}})
    reveal_type(target) # revealed: dict[str, TD]
```

## Prefer the inferred type of non-generic classes

```toml
[environment]
python-version = "3.12"
```

```py
def identity[T](x: T) -> T:
    return x

def lst[T](x: T) -> list[T]:

@@ -466,20 +675,18 @@ def lst[T](x: T) -> list[T]:

def _(i: int):
    a: int | None = i
    b: int | None = id(i)
    c: int | str | None = id(i)
    b: int | None = identity(i)
    c: int | str | None = identity(i)
    reveal_type(a) # revealed: int
    reveal_type(b) # revealed: int
    reveal_type(c) # revealed: int

    a: list[int | None] | None = [i]
    b: list[int | None] | None = id([i])
    c: list[int | None] | int | None = id([i])
    b: list[int | None] | None = identity([i])
    c: list[int | None] | int | None = identity([i])
    reveal_type(a) # revealed: list[int | None]
    # TODO: these should reveal `list[int | None]`
    # we currently do not use the call expression annotation as type context for argument inference
    reveal_type(b) # revealed: list[Unknown | int]
    reveal_type(c) # revealed: list[Unknown | int]
    reveal_type(b) # revealed: list[int | None]
    reveal_type(c) # revealed: list[int | None]

    a: list[int | None] | None = [i]
    b: list[int | None] | None = lst(i)

@@ -489,9 +696,44 @@ def _(i: int):
    reveal_type(c) # revealed: list[int | None]

    a: list | None = []
    b: list | None = id([])
    c: list | int | None = id([])
    b: list | None = identity([])
    c: list | int | None = identity([])
    reveal_type(a) # revealed: list[Unknown]
    reveal_type(b) # revealed: list[Unknown]
    reveal_type(c) # revealed: list[Unknown]

def f[T](x: list[T]) -> T:
    return x[0]

def _(a: int, b: str, c: int | str):
    x1: int = f(lst(a))
    reveal_type(x1) # revealed: int

    x2: int | str = f(lst(a))
    reveal_type(x2) # revealed: int

    x3: int | None = f(lst(a))
    reveal_type(x3) # revealed: int

    x4: str = f(lst(b))
    reveal_type(x4) # revealed: str

    x5: int | str = f(lst(b))
    reveal_type(x5) # revealed: str

    x6: str | None = f(lst(b))
    reveal_type(x6) # revealed: str

    x7: int | str = f(lst(c))
    reveal_type(x7) # revealed: int | str

    x8: int | str = f(lst(c))
    reveal_type(x8) # revealed: int | str

    # TODO: Ideally this would reveal `int | str`. This is a known limitation of our
    # call inference solver, and would require an extra inference attempt without type
    # context, or with type context of subsets of the union, both of which are impractical
    # for performance reasons.
    x9: int | str | None = f(lst(c))
    reveal_type(x9) # revealed: int | str | None
```

@@ -50,8 +50,8 @@ def _(l: list[int] | None = None):
    def f[T](x: T, cond: bool) -> T | list[T]:
        return x if cond else [x]

    # TODO: no error
    # error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
    # TODO: Better constraint solver.
    # error: [invalid-assignment]
    l5: int | list[int] = f(1, True)
```

@@ -37,7 +37,7 @@ class Data:
    content: list[int] = field(default_factory=list)
    timestamp: datetime = field(default_factory=datetime.now, init=False)

# revealed: (self: Data, content: list[int] = Unknown) -> None
# revealed: (self: Data, content: list[int] = list[int]) -> None
reveal_type(Data.__init__)

data = Data([1, 2, 3])

@@ -63,7 +63,6 @@ class Person:
    age: int | None = field(default=None, kw_only=True)
    role: str = field(default="user", kw_only=True)

# TODO: this would ideally show a default value of `None` for `age`
# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None
reveal_type(Person.__init__)

@@ -885,20 +885,31 @@ impl<'db> Type<'db> {
}
}

// If the type is a specialized instance of the given `KnownClass`, returns the specialization.
/// If the type is a specialized instance of the given `KnownClass`, returns the specialization.
pub(crate) fn known_specialization(
&self,
db: &'db dyn Db,
known_class: KnownClass,
) -> Option<Specialization<'db>> {
let class_literal = known_class.try_to_class_literal(db)?;
self.specialization_of(db, Some(class_literal))
self.specialization_of(db, class_literal)
}

// If the type is a specialized instance of the given class, returns the specialization.
//
// If no class is provided, returns the specialization of any class instance.
/// If this type is a class instance, returns its specialization.
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, None)
}

/// If the type is a specialized instance of the given class, returns the specialization.
pub(crate) fn specialization_of(
self,
db: &'db dyn Db,
expected_class: ClassLiteral<'_>,
) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, Some(expected_class))
}

fn specialization_of_optional(
self,
db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>,

@@ -5588,7 +5599,7 @@ impl<'db> Type<'db> {
) -> Result<Bindings<'db>, CallError<'db>> {
self.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &TypeContext::default(), &[])
.check_types(db, argument_types, TypeContext::default(), &[])
}

/// Look up a dunder method on the meta-type of `self` and call it.

@@ -5640,7 +5651,8 @@ impl<'db> Type<'db> {
let bindings = dunder_callable
.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &tcx, &[])?;
.check_types(db, argument_types, tcx, &[])?;

if boundness == Definedness::PossiblyUndefined {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
}

@@ -35,11 +35,11 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType,
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext,
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression,
todo_type,
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};

@@ -48,7 +48,7 @@ use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
/// compatible with _all_ of the types in the union for the call to be valid.
///
/// It's guaranteed that the wrapped bindings have no errors.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) struct Bindings<'db> {
/// The type that is (hopefully) callable.
callable_type: Type<'db>,

@ -150,9 +150,27 @@ impl<'db> Bindings<'db> {
|
|||
mut self,
|
||||
db: &'db dyn Db,
|
||||
argument_types: &CallArguments<'_, 'db>,
|
||||
call_expression_tcx: &TypeContext<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
dataclass_field_specifiers: &[Type<'db>],
|
||||
) -> Result<Self, CallError<'db>> {
|
||||
match self.check_types_impl(
|
||||
db,
|
||||
argument_types,
|
||||
call_expression_tcx,
|
||||
dataclass_field_specifiers,
|
||||
) {
|
||||
Ok(()) => Ok(self),
|
||||
Err(err) => Err(CallError(err, Box::new(self))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn check_types_impl(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
argument_types: &CallArguments<'_, 'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
dataclass_field_specifiers: &[Type<'db>],
|
||||
) -> Result<(), CallErrorKind> {
|
||||
for element in &mut self.elements {
|
||||
if let Some(mut updated_argument_forms) =
|
||||
element.check_types(db, argument_types, call_expression_tcx)
|
||||
|
|
@ -197,16 +215,13 @@ impl<'db> Bindings<'db> {
|
|||
}
|
||||
|
||||
if all_ok {
|
||||
Ok(self)
|
||||
Ok(())
|
||||
} else if any_binding_error {
|
||||
Err(CallError(CallErrorKind::BindingError, Box::new(self)))
|
||||
Err(CallErrorKind::BindingError)
|
||||
} else if all_not_callable {
|
||||
Err(CallError(CallErrorKind::NotCallable, Box::new(self)))
|
||||
Err(CallErrorKind::NotCallable)
|
||||
} else {
|
||||
Err(CallError(
|
||||
CallErrorKind::PossiblyNotCallable,
|
||||
Box::new(self),
|
||||
))
|
||||
Err(CallErrorKind::PossiblyNotCallable)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1365,7 +1380,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
|
|||
/// If the arguments cannot be matched to formal parameters, we store information about the
|
||||
/// specific errors that occurred when trying to match them up. If the callable has multiple
|
||||
/// overloads, we store this error information for each overload.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct CallableBinding<'db> {
|
||||
/// The type that is (hopefully) callable.
|
||||
pub(crate) callable_type: Type<'db>,
|
||||
|
|
@ -1486,7 +1501,7 @@ impl<'db> CallableBinding<'db> {
|
|||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
argument_types: &CallArguments<'_, 'db>,
|
||||
call_expression_tcx: &TypeContext<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
) -> Option<ArgumentForms> {
|
||||
// If this callable is a bound method, prepend the self instance onto the arguments list
|
||||
// before checking.
|
||||
|
|
@ -2267,7 +2282,7 @@ pub(crate) enum MatchingOverloadIndex {
|
|||
Multiple(Vec<usize>),
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Default, Debug, Clone)]
|
||||
struct ArgumentForms {
|
||||
values: Vec<Option<ParameterForm>>,
|
||||
conflicting: Vec<bool>,
|
||||
|
|
@ -2672,7 +2687,7 @@ struct ArgumentTypeChecker<'a, 'db> {
|
|||
arguments: &'a CallArguments<'a, 'db>,
|
||||
argument_matches: &'a [MatchedArgument<'db>],
|
||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||
call_expression_tcx: &'a TypeContext<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
return_ty: Type<'db>,
|
||||
errors: &'a mut Vec<BindingError<'db>>,
|
||||
|
||||
|
|
@ -2688,7 +2703,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
arguments: &'a CallArguments<'a, 'db>,
|
||||
argument_matches: &'a [MatchedArgument<'db>],
|
||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||
call_expression_tcx: &'a TypeContext<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
return_ty: Type<'db>,
|
||||
errors: &'a mut Vec<BindingError<'db>>,
|
||||
) -> Self {
|
||||
|
|
@ -2738,9 +2753,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
return;
|
||||
};
|
||||
|
||||
let return_with_tcx = self
|
||||
.signature
|
||||
.return_ty
|
||||
.zip(self.call_expression_tcx.annotation);
|
||||
|
||||
self.inferable_typevars = generic_context.inferable_typevars(self.db);
|
||||
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
|
||||
|
||||
// Prefer the declared type of generic classes.
|
||||
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
|
||||
tcx.class_specialization(self.db)?;
|
||||
builder.infer(return_ty, tcx).ok()?;
|
||||
Some(builder.type_mappings().clone())
|
||||
});
|
||||
|
||||
let parameters = self.signature.parameters();
|
||||
for (argument_index, adjusted_argument_index, _, argument_type) in
|
||||
self.enumerate_argument_types()
|
||||
|
|
@ -2753,9 +2780,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
continue;
|
||||
};
|
||||
|
||||
if let Err(error) = builder.infer(
|
||||
let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
|
||||
// Avoid widening the inferred type if it is already assignable to the
|
||||
// preferred declared type.
|
||||
preferred_type_mappings
|
||||
.as_ref()
|
||||
.and_then(|types| types.get(&declared_ty))
|
||||
.is_none_or(|preferred_ty| {
|
||||
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
|
||||
})
|
||||
};
|
||||
|
||||
if let Err(error) = builder.infer_filter(
|
||||
expected_type,
|
||||
variadic_argument_type.unwrap_or(argument_type),
|
||||
filter,
|
||||
) {
|
||||
self.errors.push(BindingError::SpecializationError {
|
||||
error,
|
||||
|
|
@ -2765,15 +2804,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
}
|
||||
}
|
||||
|
||||
// Build the specialization first without inferring the type context.
|
||||
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
|
||||
// Build the specialization first without inferring the complete type context.
|
||||
let isolated_specialization = builder.build(generic_context, self.call_expression_tcx);
|
||||
let isolated_return_ty = self
|
||||
.return_ty
|
||||
.apply_specialization(self.db, isolated_specialization);
|
||||
|
||||
let mut try_infer_tcx = || {
|
||||
let return_ty = self.signature.return_ty?;
|
||||
let call_expression_tcx = self.call_expression_tcx.annotation?;
|
||||
let (return_ty, call_expression_tcx) = return_with_tcx?;
|
||||
|
||||
// A type variable is not a useful type-context for expression inference, and applying it
|
||||
// to the return type can lead to confusing unions in nested generic calls.
|
||||
|
|
@ -2781,8 +2819,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
return None;
|
||||
}
|
||||
|
||||
// If the return type is already assignable to the annotated type, we can ignore the
|
||||
// type context and prefer the narrower inferred type.
|
||||
// If the return type is already assignable to the annotated type, we ignore the rest of
|
||||
// the type context and prefer the narrower inferred type.
|
||||
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -2791,8 +2829,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
// annotated assignment, to closer match the order of any unions written in the type annotation.
|
||||
builder.infer(return_ty, call_expression_tcx).ok()?;
|
||||
|
||||
// Otherwise, build the specialization again after inferring the type context.
|
||||
let specialization = builder.build(generic_context, *self.call_expression_tcx);
|
||||
// Otherwise, build the specialization again after inferring the complete type context.
|
||||
let specialization = builder.build(generic_context, self.call_expression_tcx);
|
||||
let return_ty = return_ty.apply_specialization(self.db, specialization);
|
||||
|
||||
Some((Some(specialization), return_ty))
|
||||
|
|
@ -3051,7 +3089,7 @@ impl<'db> MatchedArgument<'db> {
|
|||
pub(crate) struct UnknownParameterNameError;
|
||||
|
||||
/// Binding information for one of the overloads of a callable.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Binding<'db> {
|
||||
pub(crate) signature: Signature<'db>,
|
||||
|
||||
|
|
@ -3150,7 +3188,7 @@ impl<'db> Binding<'db> {
|
|||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
arguments: &CallArguments<'_, 'db>,
|
||||
call_expression_tcx: &TypeContext<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
) {
|
||||
let mut checker = ArgumentTypeChecker::new(
|
||||
db,
|
||||
|
|
|
|||
|
|
@@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> {
) -> Self {
let tcx = tcx
.annotation
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db))))
.and_then(|ty| ty.specialization_of(db, self.origin(db)))
.map(|specialization| specialization.types(db))
.unwrap_or(&[]);

@@ -1,4 +1,5 @@
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Display;

use itertools::Itertools;

@@ -1315,6 +1316,11 @@ impl<'db> SpecializationBuilder<'db> {
}
}

/// Returns the current set of type mappings for this specialization.
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
&self.types
}

pub(crate) fn build(
&mut self,
generic_context: GenericContext<'db>,

@@ -1322,7 +1328,7 @@ impl<'db> SpecializationBuilder<'db> {
) -> Specialization<'db> {
let tcx_specialization = tcx
.annotation
.and_then(|annotation| annotation.specialization_of(self.db, None));
.and_then(|annotation| annotation.class_specialization(self.db));

let types =
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {

@ -1345,19 +1351,43 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
generic_context.specialize_partial(self.db, types)
|
||||
}
|
||||
|
||||
fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) {
|
||||
self.types
|
||||
.entry(bound_typevar.identity(self.db))
|
||||
.and_modify(|existing| {
|
||||
*existing = UnionType::from_elements(self.db, [*existing, ty]);
|
||||
})
|
||||
.or_insert(ty);
|
||||
fn add_type_mapping(
|
||||
&mut self,
|
||||
bound_typevar: BoundTypeVarInstance<'db>,
|
||||
ty: Type<'db>,
|
||||
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
|
||||
) {
|
||||
let identity = bound_typevar.identity(self.db);
|
||||
match self.types.entry(identity) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
if filter(identity, ty) {
|
||||
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
|
||||
}
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer type mappings for the specialization based on a given type and its declared type.
|
||||
pub(crate) fn infer(
|
||||
&mut self,
|
||||
formal: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
self.infer_filter(formal, actual, |_, _| true)
|
||||
}
|
||||
|
||||
/// Infer type mappings for the specialization based on a given type and its declared type.
|
||||
///
|
||||
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
|
||||
/// mappings to which the predicate returns `false` will be ignored.
|
||||
pub(crate) fn infer_filter(
|
||||
&mut self,
|
||||
formal: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
if formal == actual {
|
||||
return Ok(());
|
||||
|
|
@ -1391,8 +1421,8 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
// Remove the union elements from `actual` that are not related to `formal`, and vice
|
||||
// versa.
|
||||
//
|
||||
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
|
||||
// to `int`, and so ignore the `None`.
|
||||
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to
|
||||
// specialize `T` to `int`, and so ignore the `None`.
|
||||
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
|
||||
let formal = formal.filter_disjoint_elements(self.db, actual, self.inferable);
|
||||
|
||||
|
|
@ -1440,7 +1470,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
if remaining_actual.is_never() {
|
||||
return Ok(());
|
||||
}
|
||||
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
|
||||
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
|
||||
}
|
||||
(Type::Union(formal), _) => {
|
||||
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
|
||||
|
|
@ -1450,7 +1480,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
let bound_typevars =
|
||||
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
|
||||
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
|
||||
self.add_type_mapping(bound_typevar, actual);
|
||||
self.add_type_mapping(bound_typevar, actual, filter);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1478,13 +1508,13 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
argument: ty,
|
||||
});
|
||||
}
|
||||
self.add_type_mapping(bound_typevar, ty);
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
}
|
||||
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
|
||||
// Prefer an exact match first.
|
||||
for constraint in constraints.elements(self.db) {
|
||||
if ty == *constraint {
|
||||
self.add_type_mapping(bound_typevar, ty);
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
|
@ -1494,7 +1524,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
.when_assignable_to(self.db, *constraint, self.inferable)
|
||||
.is_always_satisfied(self.db)
|
||||
{
|
||||
self.add_type_mapping(bound_typevar, *constraint);
|
||||
self.add_type_mapping(bound_typevar, *constraint, filter);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
|
@ -1504,7 +1534,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
});
|
||||
}
|
||||
_ => {
|
||||
self.add_type_mapping(bound_typevar, ty);
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,4 @@
use std::{iter, mem};
use std::iter;

use itertools::{Either, Itertools};
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};

@ -211,6 +211,7 @@ const NUM_FIELD_SPECIFIERS_INLINE: usize = 1;
|
|||
/// don't infer its types more than once.
|
||||
pub(super) struct TypeInferenceBuilder<'db, 'ast> {
|
||||
context: InferContext<'db, 'ast>,
|
||||
|
||||
index: &'db SemanticIndex<'db>,
|
||||
region: InferenceRegion<'db>,
|
||||
|
||||
|
|
@ -349,16 +350,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
assert_eq!(self.scope, inference.scope);
|
||||
|
||||
self.expressions.extend(inference.expressions.iter());
|
||||
self.declarations.extend(inference.declarations());
|
||||
self.declarations
|
||||
.extend(inference.declarations(), self.multi_inference_state);
|
||||
|
||||
if !matches!(self.region, InferenceRegion::Scope(..)) {
|
||||
self.bindings.extend(inference.bindings());
|
||||
self.bindings
|
||||
.extend(inference.bindings(), self.multi_inference_state);
|
||||
}
|
||||
|
||||
if let Some(extra) = &inference.extra {
|
||||
self.extend_cycle_recovery(extra.cycle_recovery);
|
||||
self.context.extend(&extra.diagnostics);
|
||||
self.deferred.extend(extra.deferred.iter().copied());
|
||||
self.deferred
|
||||
.extend(extra.deferred.iter().copied(), self.multi_inference_state);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -377,7 +381,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
self.extend_cycle_recovery(extra.cycle_recovery);
|
||||
|
||||
if !matches!(self.region, InferenceRegion::Scope(..)) {
|
||||
self.bindings.extend(extra.bindings.iter().copied());
|
||||
self.bindings
|
||||
.extend(extra.bindings.iter().copied(), self.multi_inference_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -398,6 +403,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
self.scope
|
||||
}
|
||||
|
||||
/// Set the multi-inference state, returning the previous value.
|
||||
fn set_multi_inference_state(&mut self, state: MultiInferenceState) -> MultiInferenceState {
|
||||
std::mem::replace(&mut self.multi_inference_state, state)
|
||||
}
|
||||
|
||||
/// Are we currently inferring types in file with deferred types?
|
||||
/// This is true for stub files, for files with `__future__.annotations`, and
|
||||
/// by default for all source files in Python 3.14 and later.
|
||||
|
|
@ -1637,7 +1647,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
self.bindings.insert(binding, bound_ty);
|
||||
self.bindings
|
||||
.insert(binding, bound_ty, self.multi_inference_state);
|
||||
|
||||
inferred_ty
|
||||
}
|
||||
|
|
@ -1704,7 +1715,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
TypeAndQualifiers::declared(Type::unknown())
|
||||
};
|
||||
self.declarations.insert(declaration, ty);
|
||||
self.declarations
|
||||
.insert(declaration, ty, self.multi_inference_state);
|
||||
}
|
||||
|
||||
fn add_declaration_with_binding(
|
||||
|
|
@ -1778,8 +1790,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
};
|
||||
self.declarations.insert(definition, declared_ty);
|
||||
self.bindings.insert(definition, inferred_ty);
|
||||
self.declarations
|
||||
.insert(definition, declared_ty, self.multi_inference_state);
|
||||
self.bindings
|
||||
.insert(definition, inferred_ty, self.multi_inference_state);
|
||||
}
|
||||
|
||||
fn add_unknown_declaration_with_binding(
|
||||
|
|
@ -2198,7 +2212,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// `infer_function_type_params`, rather than here.
|
||||
if type_params.is_none() {
|
||||
if self.defer_annotations() {
|
||||
self.deferred.insert(definition);
|
||||
self.deferred.insert(definition, self.multi_inference_state);
|
||||
} else {
|
||||
let previous_typevar_binding_context =
|
||||
self.typevar_binding_context.replace(definition);
|
||||
|
|
@ -2756,7 +2770,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
// Inference of bases deferred in stubs, or if any are string literals.
|
||||
if self.in_stub() || class_node.bases().iter().any(contains_string_literal) {
|
||||
self.deferred.insert(definition);
|
||||
self.deferred.insert(definition, self.multi_inference_state);
|
||||
} else {
|
||||
let previous_typevar_binding_context =
|
||||
self.typevar_binding_context.replace(definition);
|
||||
|
|
@ -3126,7 +3140,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
None => None,
|
||||
};
|
||||
if bound_or_constraint.is_some() || default.is_some() {
|
||||
self.deferred.insert(definition);
|
||||
self.deferred.insert(definition, self.multi_inference_state);
|
||||
}
|
||||
let identity =
|
||||
TypeVarIdentity::new(self.db(), &name.id, Some(definition), TypeVarKind::Pep695);
|
||||
|
|
@ -3190,7 +3204,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
default,
|
||||
} = node;
|
||||
if default.is_some() {
|
||||
self.deferred.insert(definition);
|
||||
self.deferred.insert(definition, self.multi_inference_state);
|
||||
}
|
||||
let identity = TypeVarIdentity::new(
|
||||
self.db(),
|
||||
|
|
@ -3680,10 +3694,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// Overwrite the previously inferred value, preferring later inferences, which are
|
||||
// likely more precise. Note that we still ensure each inference is assignable to
|
||||
// its declared type, so this mainly affects the IDE hover type.
|
||||
let prev_multi_inference_state = mem::replace(
|
||||
&mut builder.multi_inference_state,
|
||||
MultiInferenceState::Overwrite,
|
||||
);
|
||||
let prev_multi_inference_state =
|
||||
builder.set_multi_inference_state(MultiInferenceState::Overwrite);
|
||||
|
||||
// If we are inferring the argument multiple times, silence diagnostics to avoid duplicated warnings.
|
||||
let was_in_multi_inference = if let Some(first_tcx) = first_tcx {
|
||||
|
|
@ -4625,7 +4637,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
|
||||
if default.is_some() {
|
||||
self.deferred.insert(definition);
|
||||
self.deferred.insert(definition, self.multi_inference_state);
|
||||
}
|
||||
|
||||
let identity =
|
||||
|
|
@ -4867,7 +4879,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
};
|
||||
|
||||
if bound_or_constraints.is_some() || default.is_some() {
|
||||
self.deferred.insert(definition);
|
||||
self.deferred.insert(definition, self.multi_inference_state);
|
||||
}
|
||||
|
||||
let identity = TypeVarIdentity::new(db, target_name, Some(definition), TypeVarKind::Legacy);
|
||||
|
|
@ -5961,27 +5973,156 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Infer the argument types for multiple potential bindings and overloads.
|
||||
fn infer_all_argument_types<'a>(
|
||||
fn infer_and_check_argument_types(
|
||||
&mut self,
|
||||
ast_arguments: &ast::Arguments,
|
||||
arguments: &mut CallArguments<'a, 'db>,
|
||||
bindings: &Bindings<'db>,
|
||||
) {
|
||||
debug_assert!(
|
||||
ast_arguments.len() == arguments.len()
|
||||
&& arguments.len() == bindings.argument_forms().len()
|
||||
argument_types: &mut CallArguments<'_, 'db>,
|
||||
bindings: &mut Bindings<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
) -> Result<(), CallErrorKind> {
|
||||
let db = self.db();
|
||||
|
||||
// If the type context is a union, attempt to narrow to a specific element.
|
||||
let narrow_targets: &[_] = match call_expression_tcx.annotation {
|
||||
// TODO: We could theoretically attempt to narrow to every element of
|
||||
// the power set of this union. However, this leads to an exponential
|
||||
// explosion of inference attempts, and is rarely needed in practice.
|
||||
Some(Type::Union(union)) => union.elements(db),
|
||||
_ => &[],
|
||||
};
|
||||
|
||||
// We silence diagnostics until we successfully narrow to a specific type.
|
||||
let mut speculated_bindings = bindings.clone();
|
||||
let was_in_multi_inference = self.context.set_multi_inference(true);
|
||||
|
||||
let mut try_narrow = |narrowed_ty| {
|
||||
let narrowed_tcx = TypeContext::new(Some(narrowed_ty));
|
||||
|
||||
// Attempt to infer the argument types using the narrowed type context.
|
||||
self.infer_all_argument_types(
|
||||
ast_arguments,
|
||||
argument_types,
|
||||
bindings,
|
||||
narrowed_tcx,
|
||||
MultiInferenceState::Ignore,
|
||||
);
|
||||
|
||||
// Ensure the argument types match their annotated types.
|
||||
if speculated_bindings
|
||||
.check_types_impl(
|
||||
db,
|
||||
argument_types,
|
||||
narrowed_tcx,
|
||||
&self.dataclass_field_specifiers,
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Ensure the inferred return type is assignable to the (narrowed) declared type.
|
||||
//
|
||||
// TODO: Checking assignability against the full declared type could help avoid
|
||||
// cases where the constraint solver is not smart enough to solve complex unions.
|
||||
// We should see revisit this after the new constraint solver is implemented.
|
||||
if !speculated_bindings
|
||||
.return_type(db)
|
||||
.is_assignable_to(db, narrowed_ty)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Successfully narrowed to an element of the union.
|
||||
//
|
||||
// If necessary, infer the argument types again with diagnostics enabled.
|
||||
if !was_in_multi_inference {
|
||||
self.context.set_multi_inference(was_in_multi_inference);
|
||||
|
||||
self.infer_all_argument_types(
|
||||
ast_arguments,
|
||||
argument_types,
|
||||
bindings,
|
||||
narrowed_tcx,
|
||||
MultiInferenceState::Intersect,
|
||||
);
|
||||
}
|
||||
|
||||
Some(bindings.check_types_impl(
|
||||
db,
|
||||
argument_types,
|
||||
narrowed_tcx,
|
||||
&self.dataclass_field_specifiers,
|
||||
))
|
||||
};
|
||||
|
||||
// Prefer the declared type of generic classes.
|
||||
for narrowed_ty in narrow_targets
|
||||
.iter()
|
||||
.filter(|ty| ty.class_specialization(db).is_some())
|
||||
{
|
||||
if let Some(result) = try_narrow(*narrowed_ty) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Try the remaining elements of the union.
|
||||
//
|
||||
// TODO: We could also attempt an inference without type context, but this
|
||||
// leads to similar performance issues.
|
||||
for narrowed_ty in narrow_targets
|
||||
.iter()
|
||||
.filter(|ty| ty.class_specialization(db).is_none())
|
||||
{
|
||||
if let Some(result) = try_narrow(*narrowed_ty) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enable diagnostics, and infer against the entire union as a fallback.
|
||||
self.context.set_multi_inference(was_in_multi_inference);
|
||||
|
||||
self.infer_all_argument_types(
|
||||
ast_arguments,
|
||||
argument_types,
|
||||
bindings,
|
||||
call_expression_tcx,
|
||||
MultiInferenceState::Intersect,
|
||||
);
|
||||
|
||||
bindings.check_types_impl(
|
||||
db,
|
||||
argument_types,
|
||||
call_expression_tcx,
|
||||
&self.dataclass_field_specifiers,
|
||||
)
|
||||
}
|
||||
|
||||
/// Infer the argument types for all bindings.
|
||||
///
|
||||
/// Note that this method may infer the type of a given argument expression multiple times with
|
||||
/// distinct type context. The provided `MultiInferenceState` can be used to dictate multi-inference
|
||||
/// behavior.
|
||||
fn infer_all_argument_types(
|
||||
&mut self,
|
||||
ast_arguments: &ast::Arguments,
|
||||
arguments_types: &mut CallArguments<'_, 'db>,
|
||||
bindings: &Bindings<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
multi_inference_state: MultiInferenceState,
|
||||
) {
|
||||
debug_assert_eq!(ast_arguments.len(), arguments_types.len());
|
||||
debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len());
|
||||
|
||||
let db = self.db();
|
||||
let iter = itertools::izip!(
|
||||
0..,
|
||||
arguments.iter_mut(),
|
||||
arguments_types.iter_mut(),
|
||||
bindings.argument_forms().iter().copied(),
|
||||
ast_arguments.arguments_source_order()
|
||||
);
|
||||
|
||||
let overloads_with_binding = bindings
|
||||
.into_iter()
|
||||
.iter()
|
||||
.filter_map(|binding| {
|
||||
match binding.matching_overload_index() {
|
||||
MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => {
|
||||
|
|
@ -6000,7 +6141,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
},
|
||||
}
|
||||
})
|
||||
.flatten();
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let old_multi_inference_state = self.set_multi_inference_state(multi_inference_state);
|
||||
|
||||
for (argument_index, (_, argument_type), argument_form, ast_argument) in iter {
|
||||
let ast_argument = match ast_argument {
|
||||
|
|
@ -6022,7 +6166,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
|
||||
// Retrieve the parameter type for the current argument in a given overload and its binding.
|
||||
let db = self.db();
|
||||
let parameter_type = |overload: &Binding<'db>, binding: &CallableBinding<'db>| {
|
||||
let argument_index = if binding.bound_type.is_some() {
|
||||
argument_index + 1
|
||||
|
|
@ -6035,10 +6178,25 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
return None;
|
||||
};
|
||||
|
||||
let parameter_type =
|
||||
let mut parameter_type =
|
||||
overload.signature.parameters()[*parameter_index].annotated_type()?;
|
||||
|
||||
// TODO: For now, skip any parameter annotations that mention any typevars. There
|
||||
// If this is a generic call, attempt to specialize the parameter type using the
|
||||
// declared type context, if provided.
|
||||
if let Some(generic_context) = overload.signature.generic_context
|
||||
&& let Some(return_ty) = overload.signature.return_ty
|
||||
&& let Some(declared_return_ty) = call_expression_tcx.annotation
|
||||
{
|
||||
let mut builder =
|
||||
SpecializationBuilder::new(db, generic_context.inferable_typevars(db));
|
||||
|
||||
let _ = builder.infer(return_ty, declared_return_ty);
|
||||
let specialization = builder.build(generic_context, call_expression_tcx);
|
||||
|
||||
parameter_type = parameter_type.apply_specialization(db, specialization);
|
||||
}
|
||||
|
||||
// TODO: For now, skip any parameter annotations that still mention any typevars. There
|
||||
// are two issues:
|
||||
//
|
||||
// First, if we include those typevars in the type context that we use to infer the
|
||||
|
|
@ -6069,26 +6227,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
// If there is only a single binding and overload, we can infer the argument directly with
|
||||
// the unique parameter type annotation.
|
||||
if let Ok((overload, binding)) = overloads_with_binding.clone().exactly_one() {
|
||||
self.infer_expression_impl(
|
||||
if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() {
|
||||
*argument_type = Some(self.infer_expression(
|
||||
ast_argument,
|
||||
TypeContext::new(parameter_type(overload, binding)),
|
||||
);
|
||||
));
|
||||
} else {
|
||||
// Otherwise, each type is a valid independent inference of the given argument, and we may
|
||||
// require different permutations of argument types to correctly perform argument expansion
|
||||
// during overload evaluation, so we take the intersection of all the types we inferred for
|
||||
// each argument.
|
||||
//
|
||||
// Note that this applies to all nested expressions within each argument.
|
||||
let old_multi_inference_state = mem::replace(
|
||||
&mut self.multi_inference_state,
|
||||
MultiInferenceState::Intersect,
|
||||
);
|
||||
|
||||
// We perform inference once without any type context, emitting any diagnostics that are unrelated
|
||||
// to bidirectional type inference.
|
||||
self.infer_expression_impl(ast_argument, TypeContext::default());
|
||||
*argument_type = Some(self.infer_expression(ast_argument, TypeContext::default()));
|
||||
|
||||
// We then silence any diagnostics emitted during multi-inference, as the type context is only
|
||||
// used as a hint to infer a more assignable argument type, and should not lead to diagnostics
|
||||
|
|
@ -6097,24 +6244,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
// Infer the type of each argument once with each distinct parameter type as type context.
|
||||
let parameter_types = overloads_with_binding
|
||||
.clone()
|
||||
.iter()
|
||||
.filter_map(|(overload, binding)| parameter_type(overload, binding))
|
||||
.collect::<FxHashSet<_>>();
|
||||
|
||||
for parameter_type in parameter_types {
|
||||
self.infer_expression_impl(
|
||||
ast_argument,
|
||||
TypeContext::new(Some(parameter_type)),
|
||||
);
|
||||
let inferred_ty =
|
||||
self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type)));
|
||||
|
||||
// Each type is a valid independent inference of the given argument, and we may require different
|
||||
// permutations of argument types to correctly perform argument expansion during overload evaluation,
|
||||
// so we take the intersection of all the types we inferred for each argument.
|
||||
*argument_type = argument_type
|
||||
.map(|current| IntersectionType::from_elements(db, [inferred_ty, current]))
|
||||
.or(Some(inferred_ty));
|
||||
}
|
||||
|
||||
// Restore the multi-inference state.
|
||||
self.multi_inference_state = old_multi_inference_state;
|
||||
// Re-enable diagnostics.
|
||||
self.context.set_multi_inference(was_in_multi_inference);
|
||||
}
|
||||
|
||||
*argument_type = self.try_expression_type(ast_argument);
|
||||
}
|
||||
|
||||
self.set_multi_inference_state(old_multi_inference_state);
|
||||
}
|
||||
|
||||
fn infer_argument_type(
|
||||
|
|
@ -6275,6 +6426,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
let db = self.db();
|
||||
|
||||
match self.multi_inference_state {
|
||||
MultiInferenceState::Ignore => {}
|
||||
|
||||
MultiInferenceState::Panic => {
|
||||
let previous = self.expressions.insert(expression.into(), ty);
|
||||
assert_eq!(previous, None);
|
||||
|
|
@ -6593,7 +6746,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
|
||||
validate_typed_dict_dict_literal(&self.context, typed_dict, dict, dict.into(), |expr| {
|
||||
self.expression_type(expr)
|
||||
item_types
|
||||
.get(&expr.node_index().load())
|
||||
.copied()
|
||||
.unwrap_or(Type::unknown())
|
||||
})
|
||||
.ok()
|
||||
.map(|_| Type::TypedDict(typed_dict))
|
||||
|
|
@ -7356,7 +7512,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
|
||||
if let Some(bindings) = bindings {
|
||||
let bindings = bindings.match_parameters(self.db(), &call_arguments);
|
||||
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
|
||||
self.infer_all_argument_types(
|
||||
arguments,
|
||||
&mut call_arguments,
|
||||
&bindings,
|
||||
tcx,
|
||||
MultiInferenceState::Intersect,
|
||||
);
|
||||
} else {
|
||||
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
|
||||
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
|
||||
|
|
@ -7374,10 +7536,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
let bindings = callable_type
|
||||
let mut bindings = callable_type
|
||||
.bindings(self.db())
|
||||
.match_parameters(self.db(), &call_arguments);
|
||||
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
|
||||
|
||||
let bindings_result =
|
||||
self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx);
|
||||
|
||||
// Validate `TypedDict` constructor calls after argument type inference
|
||||
if let Some(class_literal) = callable_type.as_class_literal() {
|
||||
|
|
@ -7395,14 +7559,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
let mut bindings = match bindings.check_types(
|
||||
self.db(),
|
||||
&call_arguments,
|
||||
&tcx,
|
||||
&self.dataclass_field_specifiers[..],
|
||||
) {
|
||||
Ok(bindings) => bindings,
|
||||
Err(CallError(_, bindings)) => {
|
||||
let mut bindings = match bindings_result {
|
||||
Ok(()) => bindings,
|
||||
Err(_) => {
|
||||
bindings.report_diagnostics(&self.context, call_expression.into());
|
||||
return bindings.return_type(self.db());
|
||||
}
|
||||
|
|
@ -10100,8 +10259,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
.check_types(
|
||||
self.db(),
|
||||
&call_argument_types,
|
||||
&TypeContext::default(),
|
||||
&self.dataclass_field_specifiers[..],
|
||||
TypeContext::default(),
|
||||
&self.dataclass_field_specifiers,
|
||||
) {
|
||||
Ok(bindings) => bindings,
|
||||
Err(CallError(_, bindings)) => {
|
||||
|
|
@@ -10833,8 +10992,14 @@ enum MultiInferenceState {
Panic,

/// Overwrite the previously inferred value.
///
/// Note that `Overwrite` does not interact well with nested inferences:
/// it overwrites values that were written with `MultiInferenceState::Intersect`.
Overwrite,

/// Ignore the newly inferred value.
Ignore,

/// Store the intersection of all types inferred for the expression.
Intersect,
}

@ -11078,7 +11243,11 @@ where
|
|||
self.0.iter().map(|(k, v)| (k, v))
|
||||
}
|
||||
|
||||
fn insert(&mut self, key: K, value: V) {
|
||||
fn insert(&mut self, key: K, value: V, multi_inference_state: MultiInferenceState) {
|
||||
if matches!(multi_inference_state, MultiInferenceState::Ignore) {
|
||||
return;
|
||||
}
|
||||
|
||||
debug_assert!(
|
||||
!self.0.iter().any(|(existing, _)| existing == &key),
|
||||
"An existing entry already exists for key {key:?}",
|
||||
|
|
@ -11092,17 +11261,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<K, V> Extend<(K, V)> for VecMap<K, V>
|
||||
impl<K, V> VecMap<K, V>
|
||||
where
|
||||
K: Eq,
|
||||
K: std::fmt::Debug,
|
||||
V: std::fmt::Debug,
|
||||
{
|
||||
#[inline]
|
||||
fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
|
||||
fn extend<T: IntoIterator<Item = (K, V)>>(
|
||||
&mut self,
|
||||
iter: T,
|
||||
multi_inference_state: MultiInferenceState,
|
||||
) {
|
||||
if cfg!(debug_assertions) {
|
||||
for (key, value) in iter {
|
||||
self.insert(key, value);
|
||||
self.insert(key, value, multi_inference_state);
|
||||
}
|
||||
} else {
|
||||
self.0.extend(iter);
|
||||
|
|
@ -11140,7 +11313,11 @@ where
|
|||
V: Eq,
|
||||
V: std::fmt::Debug,
|
||||
{
|
||||
fn insert(&mut self, value: V) {
|
||||
fn insert(&mut self, value: V, multi_inference_state: MultiInferenceState) {
|
||||
if matches!(multi_inference_state, MultiInferenceState::Ignore) {
|
||||
return;
|
||||
}
|
||||
|
||||
debug_assert!(
|
||||
!self.0.iter().any(|existing| existing == &value),
|
||||
"An existing entry already exists for {value:?}",
|
||||
|
|
@ -11150,16 +11327,20 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<V> Extend<V> for VecSet<V>
|
||||
impl<V> VecSet<V>
|
||||
where
|
||||
V: Eq,
|
||||
V: std::fmt::Debug,
|
||||
{
|
||||
#[inline]
|
||||
fn extend<T: IntoIterator<Item = V>>(&mut self, iter: T) {
|
||||
fn extend<T: IntoIterator<Item = V>>(
|
||||
&mut self,
|
||||
iter: T,
|
||||
multi_inference_state: MultiInferenceState,
|
||||
) {
|
||||
if cfg!(debug_assertions) {
|
||||
for value in iter {
|
||||
self.insert(value);
|
||||
self.insert(value, multi_inference_state);
|
||||
}
|
||||
} else {
|
||||
self.0.extend(iter);
|
||||
|
|
|
|||