[ty] Improve generic call expression inference (#21210)

## Summary

Implements https://github.com/astral-sh/ty/issues/1356 and https://github.com/astral-sh/ty/issues/136#issuecomment-3413669994.
This commit is contained in:
Ibraheem Ahmed 2025-11-10 16:29:05 -05:00 committed by GitHub
parent d258302b08
commit 98869f0307
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 655 additions and 153 deletions

View file

@ -417,6 +417,8 @@ reveal_type(x) # revealed: Literal[1]
python-version = "3.12"
```
`generic_list.py`:
```py
from typing import Literal
@ -427,14 +429,13 @@ a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]
b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[Literal["a"] | int]
reveal_type(b) # revealed: list[int | Literal["a"]]
c: list[int | str] = f("a")
reveal_type(c) # revealed: list[str | int]
reveal_type(c) # revealed: list[int | str]
d: list[int | tuple[int, int]] = f((1, 2))
# TODO: We could avoid reordering the union elements here.
reveal_type(d) # revealed: list[tuple[int, int] | int]
reveal_type(d) # revealed: list[int | tuple[int, int]]
e: list[int] = f(True)
reveal_type(e) # revealed: list[int]
@ -455,10 +456,218 @@ j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True]
```
Types are not widened unnecessarily:
A function's arguments are also inferred using the type context:
`typed_dict.py`:
```py
def id[T](x: T) -> T:
from typing import TypedDict
class TD(TypedDict):
x: int
def f[T](x: list[T]) -> T:
return x[0]
a: TD = f([{"x": 0}, {"x": 1}])
reveal_type(a) # revealed: TD
b: TD | None = f([{"x": 0}, {"x": 1}])
reveal_type(b) # revealed: TD
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y""
# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD`"
c: TD = f([{"y": 0}, {"x": 1}])
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y""
# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD | None`"
c: TD | None = f([{"y": 0}, {"x": 1}])
```
But not in a way that leads to assignability errors:
`dict_any.py`:
```py
from typing import TypedDict, Any
class TD(TypedDict, total=False):
x: str
class TD2(TypedDict):
x: str
def f(self, dt: dict[str, Any], key: str):
# TODO: This should not error once typed dict assignability is implemented.
# error: [invalid-assignment]
x1: TD = dt.get(key, {})
reveal_type(x1) # revealed: TD
x2: TD = dt.get(key, {"x": 0})
reveal_type(x2) # revealed: Any
x3: TD | None = dt.get(key, {})
# TODO: This should reveal `Any` once typed dict assignability is implemented.
reveal_type(x3) # revealed: Any | None
x4: TD | None = dt.get(key, {"x": 0})
reveal_type(x4) # revealed: Any
x5: TD2 = dt.get(key, {})
reveal_type(x5) # revealed: Any
x6: TD2 = dt.get(key, {"x": 0})
reveal_type(x6) # revealed: Any
x7: TD2 | None = dt.get(key, {})
reveal_type(x7) # revealed: Any
x8: TD2 | None = dt.get(key, {"x": 0})
reveal_type(x8) # revealed: Any
```
## Prefer the declared type of generic classes
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Any
def f[T](x: T) -> list[T]:
return [x]
def f2[T](x: T) -> list[T] | None:
return [x]
def f3[T](x: T) -> list[T] | dict[T, T]:
return [x]
a = f(1)
reveal_type(a) # revealed: list[Literal[1]]
b: list[Any] = f(1)
reveal_type(b) # revealed: list[Any]
c: list[Any] = [1]
reveal_type(c) # revealed: list[Any]
d: list[Any] | None = f(1)
reveal_type(d) # revealed: list[Any]
e: list[Any] | None = [1]
reveal_type(e) # revealed: list[Any]
f: list[Any] | None = f2(1)
# TODO: Better constraint solver.
reveal_type(f) # revealed: list[Literal[1]] | None
g: list[Any] | dict[Any, Any] = f3(1)
# TODO: Better constraint solver.
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
```
We currently prefer the generic declared type regardless of its variance:
```py
class Bivariant[T]:
pass
class Covariant[T]:
def pop(self) -> T:
raise NotImplementedError
class Contravariant[T]:
def push(self, value: T) -> None:
pass
class Invariant[T]:
x: T
def bivariant[T](x: T) -> Bivariant[T]:
return Bivariant()
def covariant[T](x: T) -> Covariant[T]:
return Covariant()
def contravariant[T](x: T) -> Contravariant[T]:
return Contravariant()
def invariant[T](x: T) -> Invariant[T]:
return Invariant()
x1 = bivariant(1)
x2 = covariant(1)
x3 = contravariant(1)
x4 = invariant(1)
reveal_type(x1) # revealed: Bivariant[Literal[1]]
reveal_type(x2) # revealed: Covariant[Literal[1]]
reveal_type(x3) # revealed: Contravariant[Literal[1]]
reveal_type(x4) # revealed: Invariant[Literal[1]]
x5: Bivariant[Any] = bivariant(1)
x6: Covariant[Any] = covariant(1)
x7: Contravariant[Any] = contravariant(1)
x8: Invariant[Any] = invariant(1)
# TODO: This could reveal `Bivariant[Any]`.
reveal_type(x5) # revealed: Bivariant[Literal[1]]
reveal_type(x6) # revealed: Covariant[Any]
reveal_type(x7) # revealed: Contravariant[Any]
reveal_type(x8) # revealed: Invariant[Any]
```
## Narrow generic unions
```toml
[environment]
python-version = "3.12"
```
```py
from typing import reveal_type, TypedDict
def identity[T](x: T) -> T:
return x
def _(narrow: dict[str, str], target: list[str] | dict[str, str] | None):
target = identity(narrow)
reveal_type(target) # revealed: dict[str, str]
def _(narrow: list[str], target: list[str] | dict[str, str] | None):
target = identity(narrow)
reveal_type(target) # revealed: list[str]
def _(narrow: list[str] | dict[str, str], target: list[str] | dict[str, str] | None):
target = identity(narrow)
reveal_type(target) # revealed: list[str] | dict[str, str]
class TD(TypedDict):
x: int
def _(target: list[TD] | dict[str, TD] | None):
target = identity([{"x": 1}])
reveal_type(target) # revealed: list[TD]
def _(target: list[TD] | dict[str, TD] | None):
target = identity({"x": {"x": 1}})
reveal_type(target) # revealed: dict[str, TD]
```
## Prefer the inferred type of non-generic classes
```toml
[environment]
python-version = "3.12"
```
```py
def identity[T](x: T) -> T:
return x
def lst[T](x: T) -> list[T]:
@ -466,20 +675,18 @@ def lst[T](x: T) -> list[T]:
def _(i: int):
a: int | None = i
b: int | None = id(i)
c: int | str | None = id(i)
b: int | None = identity(i)
c: int | str | None = identity(i)
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
reveal_type(c) # revealed: int
a: list[int | None] | None = [i]
b: list[int | None] | None = id([i])
c: list[int | None] | int | None = id([i])
b: list[int | None] | None = identity([i])
c: list[int | None] | int | None = identity([i])
reveal_type(a) # revealed: list[int | None]
# TODO: these should reveal `list[int | None]`
# we currently do not use the call expression annotation as type context for argument inference
reveal_type(b) # revealed: list[Unknown | int]
reveal_type(c) # revealed: list[Unknown | int]
reveal_type(b) # revealed: list[int | None]
reveal_type(c) # revealed: list[int | None]
a: list[int | None] | None = [i]
b: list[int | None] | None = lst(i)
@ -489,9 +696,44 @@ def _(i: int):
reveal_type(c) # revealed: list[int | None]
a: list | None = []
b: list | None = id([])
c: list | int | None = id([])
b: list | None = identity([])
c: list | int | None = identity([])
reveal_type(a) # revealed: list[Unknown]
reveal_type(b) # revealed: list[Unknown]
reveal_type(c) # revealed: list[Unknown]
def f[T](x: list[T]) -> T:
return x[0]
def _(a: int, b: str, c: int | str):
x1: int = f(lst(a))
reveal_type(x1) # revealed: int
x2: int | str = f(lst(a))
reveal_type(x2) # revealed: int
x3: int | None = f(lst(a))
reveal_type(x3) # revealed: int
x4: str = f(lst(b))
reveal_type(x4) # revealed: str
x5: int | str = f(lst(b))
reveal_type(x5) # revealed: str
x6: str | None = f(lst(b))
reveal_type(x6) # revealed: str
x7: int | str = f(lst(c))
reveal_type(x7) # revealed: int | str
x8: int | str = f(lst(c))
reveal_type(x8) # revealed: int | str
# TODO: Ideally this would reveal `int | str`. This is a known limitation of our
# call inference solver, and would # require an extra inference attempt without type
# context, or with type context # of subsets of the union, both of which are impractical
# for performance reasons.
x9: int | str | None = f(lst(c))
reveal_type(x9) # revealed: int | str | None
```

View file

@ -50,8 +50,8 @@ def _(l: list[int] | None = None):
def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
# TODO: Better constraint solver.
# error: [invalid-assignment]
l5: int | list[int] = f(1, True)
```

View file

@ -37,7 +37,7 @@ class Data:
content: list[int] = field(default_factory=list)
timestamp: datetime = field(default_factory=datetime.now, init=False)
# revealed: (self: Data, content: list[int] = Unknown) -> None
# revealed: (self: Data, content: list[int] = list[int]) -> None
reveal_type(Data.__init__)
data = Data([1, 2, 3])
@ -63,7 +63,6 @@ class Person:
age: int | None = field(default=None, kw_only=True)
role: str = field(default="user", kw_only=True)
# TODO: this would ideally show a default value of `None` for `age`
# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None
reveal_type(Person.__init__)

View file

@ -885,20 +885,31 @@ impl<'db> Type<'db> {
}
}
// If the type is a specialized instance of the given `KnownClass`, returns the specialization.
/// If the type is a specialized instance of the given `KnownClass`, returns the specialization.
pub(crate) fn known_specialization(
&self,
db: &'db dyn Db,
known_class: KnownClass,
) -> Option<Specialization<'db>> {
let class_literal = known_class.try_to_class_literal(db)?;
self.specialization_of(db, Some(class_literal))
self.specialization_of(db, class_literal)
}
// If the type is a specialized instance of the given class, returns the specialization.
//
// If no class is provided, returns the specialization of any class instance.
/// If this type is a class instance, returns its specialization.
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, None)
}
/// If the type is a specialized instance of the given class, returns the specialization.
pub(crate) fn specialization_of(
self,
db: &'db dyn Db,
expected_class: ClassLiteral<'_>,
) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, Some(expected_class))
}
fn specialization_of_optional(
self,
db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>,
@ -5588,7 +5599,7 @@ impl<'db> Type<'db> {
) -> Result<Bindings<'db>, CallError<'db>> {
self.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &TypeContext::default(), &[])
.check_types(db, argument_types, TypeContext::default(), &[])
}
/// Look up a dunder method on the meta-type of `self` and call it.
@ -5640,7 +5651,8 @@ impl<'db> Type<'db> {
let bindings = dunder_callable
.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &tcx, &[])?;
.check_types(db, argument_types, tcx, &[])?;
if boundness == Definedness::PossiblyUndefined {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
}

View file

@ -35,11 +35,11 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType,
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext,
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression,
todo_type,
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@ -48,7 +48,7 @@ use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
/// compatible with _all_ of the types in the union for the call to be valid.
///
/// It's guaranteed that the wrapped bindings have no errors.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) struct Bindings<'db> {
/// The type that is (hopefully) callable.
callable_type: Type<'db>,
@ -150,9 +150,27 @@ impl<'db> Bindings<'db> {
mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
call_expression_tcx: TypeContext<'db>,
dataclass_field_specifiers: &[Type<'db>],
) -> Result<Self, CallError<'db>> {
match self.check_types_impl(
db,
argument_types,
call_expression_tcx,
dataclass_field_specifiers,
) {
Ok(()) => Ok(self),
Err(err) => Err(CallError(err, Box::new(self))),
}
}
pub(crate) fn check_types_impl(
&mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: TypeContext<'db>,
dataclass_field_specifiers: &[Type<'db>],
) -> Result<(), CallErrorKind> {
for element in &mut self.elements {
if let Some(mut updated_argument_forms) =
element.check_types(db, argument_types, call_expression_tcx)
@ -197,16 +215,13 @@ impl<'db> Bindings<'db> {
}
if all_ok {
Ok(self)
Ok(())
} else if any_binding_error {
Err(CallError(CallErrorKind::BindingError, Box::new(self)))
Err(CallErrorKind::BindingError)
} else if all_not_callable {
Err(CallError(CallErrorKind::NotCallable, Box::new(self)))
Err(CallErrorKind::NotCallable)
} else {
Err(CallError(
CallErrorKind::PossiblyNotCallable,
Box::new(self),
))
Err(CallErrorKind::PossiblyNotCallable)
}
}
@ -1365,7 +1380,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
/// If the arguments cannot be matched to formal parameters, we store information about the
/// specific errors that occurred when trying to match them up. If the callable has multiple
/// overloads, we store this error information for each overload.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) struct CallableBinding<'db> {
/// The type that is (hopefully) callable.
pub(crate) callable_type: Type<'db>,
@ -1486,7 +1501,7 @@ impl<'db> CallableBinding<'db> {
&mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
call_expression_tcx: TypeContext<'db>,
) -> Option<ArgumentForms> {
// If this callable is a bound method, prepend the self instance onto the arguments list
// before checking.
@ -2267,7 +2282,7 @@ pub(crate) enum MatchingOverloadIndex {
Multiple(Vec<usize>),
}
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone)]
struct ArgumentForms {
values: Vec<Option<ParameterForm>>,
conflicting: Vec<bool>,
@ -2672,7 +2687,7 @@ struct ArgumentTypeChecker<'a, 'db> {
arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
call_expression_tcx: TypeContext<'db>,
return_ty: Type<'db>,
errors: &'a mut Vec<BindingError<'db>>,
@ -2688,7 +2703,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
call_expression_tcx: TypeContext<'db>,
return_ty: Type<'db>,
errors: &'a mut Vec<BindingError<'db>>,
) -> Self {
@ -2738,9 +2753,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return;
};
let return_with_tcx = self
.signature
.return_ty
.zip(self.call_expression_tcx.annotation);
self.inferable_typevars = generic_context.inferable_typevars(self.db);
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
tcx.class_specialization(self.db)?;
builder.infer(return_ty, tcx).ok()?;
Some(builder.type_mappings().clone())
});
let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types()
@ -2753,9 +2780,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
continue;
};
if let Err(error) = builder.infer(
let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
// Avoid widening the inferred type if it is already assignable to the
// preferred declared type.
preferred_type_mappings
.as_ref()
.and_then(|types| types.get(&declared_ty))
.is_none_or(|preferred_ty| {
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
})
};
if let Err(error) = builder.infer_filter(
expected_type,
variadic_argument_type.unwrap_or(argument_type),
filter,
) {
self.errors.push(BindingError::SpecializationError {
error,
@ -2765,15 +2804,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}
// Build the specialization first without inferring the type context.
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
// Build the specialization first without inferring the complete type context.
let isolated_specialization = builder.build(generic_context, self.call_expression_tcx);
let isolated_return_ty = self
.return_ty
.apply_specialization(self.db, isolated_specialization);
let mut try_infer_tcx = || {
let return_ty = self.signature.return_ty?;
let call_expression_tcx = self.call_expression_tcx.annotation?;
let (return_ty, call_expression_tcx) = return_with_tcx?;
// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
@ -2781,8 +2819,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return None;
}
// If the return type is already assignable to the annotated type, we can ignore the
// type context and prefer the narrower inferred type.
// If the return type is already assignable to the annotated type, we ignore the rest of
// the type context and prefer the narrower inferred type.
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
return None;
}
@ -2791,8 +2829,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// annotated assignment, to closer match the order of any unions written in the type annotation.
builder.infer(return_ty, call_expression_tcx).ok()?;
// Otherwise, build the specialization again after inferring the type context.
let specialization = builder.build(generic_context, *self.call_expression_tcx);
// Otherwise, build the specialization again after inferring the complete type context.
let specialization = builder.build(generic_context, self.call_expression_tcx);
let return_ty = return_ty.apply_specialization(self.db, specialization);
Some((Some(specialization), return_ty))
@ -3051,7 +3089,7 @@ impl<'db> MatchedArgument<'db> {
pub(crate) struct UnknownParameterNameError;
/// Binding information for one of the overloads of a callable.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) struct Binding<'db> {
pub(crate) signature: Signature<'db>,
@ -3150,7 +3188,7 @@ impl<'db> Binding<'db> {
&mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
call_expression_tcx: TypeContext<'db>,
) {
let mut checker = ArgumentTypeChecker::new(
db,

View file

@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> {
) -> Self {
let tcx = tcx
.annotation
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db))))
.and_then(|ty| ty.specialization_of(db, self.origin(db)))
.map(|specialization| specialization.types(db))
.unwrap_or(&[]);

View file

@ -1,4 +1,5 @@
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Display;
use itertools::Itertools;
@ -1315,6 +1316,11 @@ impl<'db> SpecializationBuilder<'db> {
}
}
/// Returns the current set of type mappings for this specialization.
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
&self.types
}
pub(crate) fn build(
&mut self,
generic_context: GenericContext<'db>,
@ -1322,7 +1328,7 @@ impl<'db> SpecializationBuilder<'db> {
) -> Specialization<'db> {
let tcx_specialization = tcx
.annotation
.and_then(|annotation| annotation.specialization_of(self.db, None));
.and_then(|annotation| annotation.class_specialization(self.db));
let types =
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
@ -1345,19 +1351,43 @@ impl<'db> SpecializationBuilder<'db> {
generic_context.specialize_partial(self.db, types)
}
fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) {
self.types
.entry(bound_typevar.identity(self.db))
.and_modify(|existing| {
*existing = UnionType::from_elements(self.db, [*existing, ty]);
})
.or_insert(ty);
fn add_type_mapping(
&mut self,
bound_typevar: BoundTypeVarInstance<'db>,
ty: Type<'db>,
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
) {
let identity = bound_typevar.identity(self.db);
match self.types.entry(identity) {
Entry::Occupied(mut entry) => {
if filter(identity, ty) {
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
}
}
Entry::Vacant(entry) => {
entry.insert(ty);
}
}
}
/// Infer type mappings for the specialization based on a given type and its declared type.
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
self.infer_filter(formal, actual, |_, _| true)
}
/// Infer type mappings for the specialization based on a given type and its declared type.
///
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
/// mappings to which the predicate returns `false` will be ignored.
pub(crate) fn infer_filter(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
) -> Result<(), SpecializationError<'db>> {
if formal == actual {
return Ok(());
@ -1391,8 +1421,8 @@ impl<'db> SpecializationBuilder<'db> {
// Remove the union elements from `actual` that are not related to `formal`, and vice
// versa.
//
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
// to `int`, and so ignore the `None`.
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to
// specialize `T` to `int`, and so ignore the `None`.
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
let formal = formal.filter_disjoint_elements(self.db, actual, self.inferable);
@ -1440,7 +1470,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() {
return Ok(());
}
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
}
(Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
@ -1450,7 +1480,7 @@ impl<'db> SpecializationBuilder<'db> {
let bound_typevars =
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
self.add_type_mapping(bound_typevar, actual);
self.add_type_mapping(bound_typevar, actual, filter);
}
}
@ -1478,13 +1508,13 @@ impl<'db> SpecializationBuilder<'db> {
argument: ty,
});
}
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
// Prefer an exact match first.
for constraint in constraints.elements(self.db) {
if ty == *constraint {
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
return Ok(());
}
}
@ -1494,7 +1524,7 @@ impl<'db> SpecializationBuilder<'db> {
.when_assignable_to(self.db, *constraint, self.inferable)
.is_always_satisfied(self.db)
{
self.add_type_mapping(bound_typevar, *constraint);
self.add_type_mapping(bound_typevar, *constraint, filter);
return Ok(());
}
}
@ -1504,7 +1534,7 @@ impl<'db> SpecializationBuilder<'db> {
});
}
_ => {
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
}
}
}

View file

@ -1,4 +1,4 @@
use std::{iter, mem};
use std::iter;
use itertools::{Either, Itertools};
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};
@ -211,6 +211,7 @@ const NUM_FIELD_SPECIFIERS_INLINE: usize = 1;
/// don't infer its types more than once.
pub(super) struct TypeInferenceBuilder<'db, 'ast> {
context: InferContext<'db, 'ast>,
index: &'db SemanticIndex<'db>,
region: InferenceRegion<'db>,
@ -349,16 +350,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
assert_eq!(self.scope, inference.scope);
self.expressions.extend(inference.expressions.iter());
self.declarations.extend(inference.declarations());
self.declarations
.extend(inference.declarations(), self.multi_inference_state);
if !matches!(self.region, InferenceRegion::Scope(..)) {
self.bindings.extend(inference.bindings());
self.bindings
.extend(inference.bindings(), self.multi_inference_state);
}
if let Some(extra) = &inference.extra {
self.extend_cycle_recovery(extra.cycle_recovery);
self.context.extend(&extra.diagnostics);
self.deferred.extend(extra.deferred.iter().copied());
self.deferred
.extend(extra.deferred.iter().copied(), self.multi_inference_state);
}
}
@ -377,7 +381,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.extend_cycle_recovery(extra.cycle_recovery);
if !matches!(self.region, InferenceRegion::Scope(..)) {
self.bindings.extend(extra.bindings.iter().copied());
self.bindings
.extend(extra.bindings.iter().copied(), self.multi_inference_state);
}
}
}
@ -398,6 +403,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.scope
}
/// Set the multi-inference state, returning the previous value.
fn set_multi_inference_state(&mut self, state: MultiInferenceState) -> MultiInferenceState {
std::mem::replace(&mut self.multi_inference_state, state)
}
/// Are we currently inferring types in file with deferred types?
/// This is true for stub files, for files with `__future__.annotations`, and
/// by default for all source files in Python 3.14 and later.
@ -1637,7 +1647,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
self.bindings.insert(binding, bound_ty);
self.bindings
.insert(binding, bound_ty, self.multi_inference_state);
inferred_ty
}
@ -1704,7 +1715,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
TypeAndQualifiers::declared(Type::unknown())
};
self.declarations.insert(declaration, ty);
self.declarations
.insert(declaration, ty, self.multi_inference_state);
}
fn add_declaration_with_binding(
@ -1778,8 +1790,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
};
self.declarations.insert(definition, declared_ty);
self.bindings.insert(definition, inferred_ty);
self.declarations
.insert(definition, declared_ty, self.multi_inference_state);
self.bindings
.insert(definition, inferred_ty, self.multi_inference_state);
}
fn add_unknown_declaration_with_binding(
@ -2198,7 +2212,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// `infer_function_type_params`, rather than here.
if type_params.is_none() {
if self.defer_annotations() {
self.deferred.insert(definition);
self.deferred.insert(definition, self.multi_inference_state);
} else {
let previous_typevar_binding_context =
self.typevar_binding_context.replace(definition);
@ -2756,7 +2770,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Inference of bases deferred in stubs, or if any are string literals.
if self.in_stub() || class_node.bases().iter().any(contains_string_literal) {
self.deferred.insert(definition);
self.deferred.insert(definition, self.multi_inference_state);
} else {
let previous_typevar_binding_context =
self.typevar_binding_context.replace(definition);
@ -3126,7 +3140,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
None => None,
};
if bound_or_constraint.is_some() || default.is_some() {
self.deferred.insert(definition);
self.deferred.insert(definition, self.multi_inference_state);
}
let identity =
TypeVarIdentity::new(self.db(), &name.id, Some(definition), TypeVarKind::Pep695);
@ -3190,7 +3204,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
default,
} = node;
if default.is_some() {
self.deferred.insert(definition);
self.deferred.insert(definition, self.multi_inference_state);
}
let identity = TypeVarIdentity::new(
self.db(),
@ -3680,10 +3694,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Overwrite the previously inferred value, preferring later inferences, which are
// likely more precise. Note that we still ensure each inference is assignable to
// its declared type, so this mainly affects the IDE hover type.
let prev_multi_inference_state = mem::replace(
&mut builder.multi_inference_state,
MultiInferenceState::Overwrite,
);
let prev_multi_inference_state =
builder.set_multi_inference_state(MultiInferenceState::Overwrite);
// If we are inferring the argument multiple times, silence diagnostics to avoid duplicated warnings.
let was_in_multi_inference = if let Some(first_tcx) = first_tcx {
@ -4625,7 +4637,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
if default.is_some() {
self.deferred.insert(definition);
self.deferred.insert(definition, self.multi_inference_state);
}
let identity =
@ -4867,7 +4879,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
if bound_or_constraints.is_some() || default.is_some() {
self.deferred.insert(definition);
self.deferred.insert(definition, self.multi_inference_state);
}
let identity = TypeVarIdentity::new(db, target_name, Some(definition), TypeVarKind::Legacy);
@ -5961,27 +5973,156 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
/// Infer the argument types for multiple potential bindings and overloads.
fn infer_all_argument_types<'a>(
fn infer_and_check_argument_types(
&mut self,
ast_arguments: &ast::Arguments,
arguments: &mut CallArguments<'a, 'db>,
bindings: &Bindings<'db>,
) {
debug_assert!(
ast_arguments.len() == arguments.len()
&& arguments.len() == bindings.argument_forms().len()
argument_types: &mut CallArguments<'_, 'db>,
bindings: &mut Bindings<'db>,
call_expression_tcx: TypeContext<'db>,
) -> Result<(), CallErrorKind> {
let db = self.db();
// If the type context is a union, attempt to narrow to a specific element.
let narrow_targets: &[_] = match call_expression_tcx.annotation {
// TODO: We could theoretically attempt to narrow to every element of
// the power set of this union. However, this leads to an exponential
// explosion of inference attempts, and is rarely needed in practice.
Some(Type::Union(union)) => union.elements(db),
_ => &[],
};
// We silence diagnostics until we successfully narrow to a specific type.
let mut speculated_bindings = bindings.clone();
let was_in_multi_inference = self.context.set_multi_inference(true);
let mut try_narrow = |narrowed_ty| {
let narrowed_tcx = TypeContext::new(Some(narrowed_ty));
// Attempt to infer the argument types using the narrowed type context.
self.infer_all_argument_types(
ast_arguments,
argument_types,
bindings,
narrowed_tcx,
MultiInferenceState::Ignore,
);
// Ensure the argument types match their annotated types.
if speculated_bindings
.check_types_impl(
db,
argument_types,
narrowed_tcx,
&self.dataclass_field_specifiers,
)
.is_err()
{
return None;
}
// Ensure the inferred return type is assignable to the (narrowed) declared type.
//
// TODO: Checking assignability against the full declared type could help avoid
// cases where the constraint solver is not smart enough to solve complex unions.
// We should see revisit this after the new constraint solver is implemented.
if !speculated_bindings
.return_type(db)
.is_assignable_to(db, narrowed_ty)
{
return None;
}
// Successfully narrowed to an element of the union.
//
// If necessary, infer the argument types again with diagnostics enabled.
if !was_in_multi_inference {
self.context.set_multi_inference(was_in_multi_inference);
self.infer_all_argument_types(
ast_arguments,
argument_types,
bindings,
narrowed_tcx,
MultiInferenceState::Intersect,
);
}
Some(bindings.check_types_impl(
db,
argument_types,
narrowed_tcx,
&self.dataclass_field_specifiers,
))
};
// Prefer the declared type of generic classes.
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_some())
{
if let Some(result) = try_narrow(*narrowed_ty) {
return result;
}
}
// Try the remaining elements of the union.
//
// TODO: We could also attempt an inference without type context, but this
// leads to similar performance issues.
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_none())
{
if let Some(result) = try_narrow(*narrowed_ty) {
return result;
}
}
// Re-enable diagnostics, and infer against the entire union as a fallback.
self.context.set_multi_inference(was_in_multi_inference);
self.infer_all_argument_types(
ast_arguments,
argument_types,
bindings,
call_expression_tcx,
MultiInferenceState::Intersect,
);
bindings.check_types_impl(
db,
argument_types,
call_expression_tcx,
&self.dataclass_field_specifiers,
)
}
/// Infer the argument types for all bindings.
///
/// Note that this method may infer the type of a given argument expression multiple times with
/// distinct type context. The provided `MultiInferenceState` can be used to dictate multi-inference
/// behavior.
fn infer_all_argument_types(
&mut self,
ast_arguments: &ast::Arguments,
arguments_types: &mut CallArguments<'_, 'db>,
bindings: &Bindings<'db>,
call_expression_tcx: TypeContext<'db>,
multi_inference_state: MultiInferenceState,
) {
debug_assert_eq!(ast_arguments.len(), arguments_types.len());
debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len());
let db = self.db();
let iter = itertools::izip!(
0..,
arguments.iter_mut(),
arguments_types.iter_mut(),
bindings.argument_forms().iter().copied(),
ast_arguments.arguments_source_order()
);
let overloads_with_binding = bindings
.into_iter()
.iter()
.filter_map(|binding| {
match binding.matching_overload_index() {
MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => {
@ -6000,7 +6141,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
},
}
})
.flatten();
.flatten()
.collect::<Vec<_>>();
let old_multi_inference_state = self.set_multi_inference_state(multi_inference_state);
for (argument_index, (_, argument_type), argument_form, ast_argument) in iter {
let ast_argument = match ast_argument {
@ -6022,7 +6166,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// Retrieve the parameter type for the current argument in a given overload and its binding.
let db = self.db();
let parameter_type = |overload: &Binding<'db>, binding: &CallableBinding<'db>| {
let argument_index = if binding.bound_type.is_some() {
argument_index + 1
@ -6035,10 +6178,25 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return None;
};
let parameter_type =
let mut parameter_type =
overload.signature.parameters()[*parameter_index].annotated_type()?;
// TODO: For now, skip any parameter annotations that mention any typevars. There
// If this is a generic call, attempt to specialize the parameter type using the
// declared type context, if provided.
if let Some(generic_context) = overload.signature.generic_context
&& let Some(return_ty) = overload.signature.return_ty
&& let Some(declared_return_ty) = call_expression_tcx.annotation
{
let mut builder =
SpecializationBuilder::new(db, generic_context.inferable_typevars(db));
let _ = builder.infer(return_ty, declared_return_ty);
let specialization = builder.build(generic_context, call_expression_tcx);
parameter_type = parameter_type.apply_specialization(db, specialization);
}
// TODO: For now, skip any parameter annotations that still mention any typevars. There
// are two issues:
//
// First, if we include those typevars in the type context that we use to infer the
@ -6069,26 +6227,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// If there is only a single binding and overload, we can infer the argument directly with
// the unique parameter type annotation.
if let Ok((overload, binding)) = overloads_with_binding.clone().exactly_one() {
self.infer_expression_impl(
if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() {
*argument_type = Some(self.infer_expression(
ast_argument,
TypeContext::new(parameter_type(overload, binding)),
);
));
} else {
// Otherwise, each type is a valid independent inference of the given argument, and we may
// require different permutations of argument types to correctly perform argument expansion
// during overload evaluation, so we take the intersection of all the types we inferred for
// each argument.
//
// Note that this applies to all nested expressions within each argument.
let old_multi_inference_state = mem::replace(
&mut self.multi_inference_state,
MultiInferenceState::Intersect,
);
// We perform inference once without any type context, emitting any diagnostics that are unrelated
// to bidirectional type inference.
self.infer_expression_impl(ast_argument, TypeContext::default());
*argument_type = Some(self.infer_expression(ast_argument, TypeContext::default()));
// We then silence any diagnostics emitted during multi-inference, as the type context is only
// used as a hint to infer a more assignable argument type, and should not lead to diagnostics
@ -6097,24 +6244,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Infer the type of each argument once with each distinct parameter type as type context.
let parameter_types = overloads_with_binding
.clone()
.iter()
.filter_map(|(overload, binding)| parameter_type(overload, binding))
.collect::<FxHashSet<_>>();
for parameter_type in parameter_types {
self.infer_expression_impl(
ast_argument,
TypeContext::new(Some(parameter_type)),
);
let inferred_ty =
self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type)));
// Each type is a valid independent inference of the given argument, and we may require different
// permutations of argument types to correctly perform argument expansion during overload evaluation,
// so we take the intersection of all the types we inferred for each argument.
*argument_type = argument_type
.map(|current| IntersectionType::from_elements(db, [inferred_ty, current]))
.or(Some(inferred_ty));
}
// Restore the multi-inference state.
self.multi_inference_state = old_multi_inference_state;
// Re-enable diagnostics.
self.context.set_multi_inference(was_in_multi_inference);
}
*argument_type = self.try_expression_type(ast_argument);
}
self.set_multi_inference_state(old_multi_inference_state);
}
fn infer_argument_type(
@ -6275,6 +6426,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let db = self.db();
match self.multi_inference_state {
MultiInferenceState::Ignore => {}
MultiInferenceState::Panic => {
let previous = self.expressions.insert(expression.into(), ty);
assert_eq!(previous, None);
@ -6593,7 +6746,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
validate_typed_dict_dict_literal(&self.context, typed_dict, dict, dict.into(), |expr| {
self.expression_type(expr)
item_types
.get(&expr.node_index().load())
.copied()
.unwrap_or(Type::unknown())
})
.ok()
.map(|_| Type::TypedDict(typed_dict))
@ -7356,7 +7512,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
if let Some(bindings) = bindings {
let bindings = bindings.match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
self.infer_all_argument_types(
arguments,
&mut call_arguments,
&bindings,
tcx,
MultiInferenceState::Intersect,
);
} else {
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
@ -7374,10 +7536,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
let bindings = callable_type
let mut bindings = callable_type
.bindings(self.db())
.match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
let bindings_result =
self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx);
// Validate `TypedDict` constructor calls after argument type inference
if let Some(class_literal) = callable_type.as_class_literal() {
@ -7395,14 +7559,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
let mut bindings = match bindings.check_types(
self.db(),
&call_arguments,
&tcx,
&self.dataclass_field_specifiers[..],
) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
let mut bindings = match bindings_result {
Ok(()) => bindings,
Err(_) => {
bindings.report_diagnostics(&self.context, call_expression.into());
return bindings.return_type(self.db());
}
@ -10100,8 +10259,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.check_types(
self.db(),
&call_argument_types,
&TypeContext::default(),
&self.dataclass_field_specifiers[..],
TypeContext::default(),
&self.dataclass_field_specifiers,
) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
@ -10833,8 +10992,14 @@ enum MultiInferenceState {
Panic,
/// Overwrite the previously inferred value.
///
/// Note that `Overwrite` does not interact well with nested inferences:
/// it overwrites values that were written with `MultiInferenceState::Intersect`.
Overwrite,
/// Ignore the newly inferred value.
Ignore,
/// Store the intersection of all types inferred for the expression.
Intersect,
}
@ -11078,7 +11243,11 @@ where
self.0.iter().map(|(k, v)| (k, v))
}
fn insert(&mut self, key: K, value: V) {
fn insert(&mut self, key: K, value: V, multi_inference_state: MultiInferenceState) {
if matches!(multi_inference_state, MultiInferenceState::Ignore) {
return;
}
debug_assert!(
!self.0.iter().any(|(existing, _)| existing == &key),
"An existing entry already exists for key {key:?}",
@ -11092,17 +11261,21 @@ where
}
}
impl<K, V> Extend<(K, V)> for VecMap<K, V>
impl<K, V> VecMap<K, V>
where
K: Eq,
K: std::fmt::Debug,
V: std::fmt::Debug,
{
#[inline]
fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
fn extend<T: IntoIterator<Item = (K, V)>>(
&mut self,
iter: T,
multi_inference_state: MultiInferenceState,
) {
if cfg!(debug_assertions) {
for (key, value) in iter {
self.insert(key, value);
self.insert(key, value, multi_inference_state);
}
} else {
self.0.extend(iter);
@ -11140,7 +11313,11 @@ where
V: Eq,
V: std::fmt::Debug,
{
fn insert(&mut self, value: V) {
fn insert(&mut self, value: V, multi_inference_state: MultiInferenceState) {
if matches!(multi_inference_state, MultiInferenceState::Ignore) {
return;
}
debug_assert!(
!self.0.iter().any(|existing| existing == &value),
"An existing entry already exists for {value:?}",
@ -11150,16 +11327,20 @@ where
}
}
impl<V> Extend<V> for VecSet<V>
impl<V> VecSet<V>
where
V: Eq,
V: std::fmt::Debug,
{
#[inline]
fn extend<T: IntoIterator<Item = V>>(&mut self, iter: T) {
fn extend<T: IntoIterator<Item = V>>(
&mut self,
iter: T,
multi_inference_state: MultiInferenceState,
) {
if cfg!(debug_assertions) {
for value in iter {
self.insert(value);
self.insert(value, multi_inference_state);
}
} else {
self.0.extend(iter);