[ty] Use constructor parameter types as type context (#21054)

## Summary

Resolves https://github.com/astral-sh/ty/issues/1408.
This commit is contained in:
Ibraheem Ahmed 2025-10-24 16:14:18 -04:00 committed by GitHub
parent c3de8847d5
commit 304ac22e74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 161 additions and 24 deletions

View file

@ -145,3 +145,84 @@ def h[T](x: T, cond: bool) -> T | list[T]:
def i[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
```
## Type context sources
Type context is sourced from various places, including annotated assignments:
```py
from typing import Literal
a: list[Literal[1]] = [1]
```
Function parameter annotations:
```py
def b(x: list[Literal[1]]): ...
b([1])
```
Bound method parameter annotations:
```py
class C:
def __init__(self, x: list[Literal[1]]): ...
def foo(self, x: list[Literal[1]]): ...
C([1]).foo([1])
```
Declared variable types:
```py
d: list[Literal[1]]
d = [1]
```
Declared attribute types:
```py
class E:
e: list[Literal[1]]
def _(e: E):
# TODO: Implement attribute type context.
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to attribute `e` of type `list[Literal[1]]`"
e.e = [1]
```
Function return types:
```py
def f() -> list[Literal[1]]:
return [1]
```
## Class constructor parameters
```toml
[environment]
python-version = "3.12"
```
The parameters of both `__init__` and `__new__` are used as type context sources for constructor
calls:
```py
def f[T](x: T) -> list[T]:
return [x]
class A:
def __new__(cls, value: list[int | str]):
return super().__new__(cls, value)
def __init__(self, value: list[int | None]): ...
A(f(1))
# error: [invalid-argument-type] "Argument to function `__new__` is incorrect: Expected `list[int | str]`, found `list[list[Unknown]]`"
# error: [invalid-argument-type] "Argument to bound method `__init__` is incorrect: Expected `list[int | None]`, found `list[list[Unknown]]`"
A(f([]))
```

View file

@ -6007,6 +6007,9 @@ impl<'db> Type<'db> {
/// Given a class literal or non-dynamic `SubclassOf` type, try calling it (creating an instance)
/// and return the resulting instance type.
///
/// The `infer_argument_types` closure should be invoked with the signatures of `__new__` and
/// `__init__`, such that the argument types can be inferred with the correct type context.
///
/// Models `type.__call__` behavior.
/// TODO: model metaclass `__call__`.
///
@ -6017,10 +6020,10 @@ impl<'db> Type<'db> {
///
/// Foo()
/// ```
fn try_call_constructor(
fn try_call_constructor<'ast>(
self,
db: &'db dyn Db,
argument_types: CallArguments<'_, 'db>,
infer_argument_types: impl FnOnce(Option<Bindings<'db>>) -> CallArguments<'ast, 'db>,
tcx: TypeContext<'db>,
) -> Result<Type<'db>, ConstructorCallError<'db>> {
debug_assert!(matches!(
@ -6076,11 +6079,63 @@ impl<'db> Type<'db> {
// easy to check if that's the one we found?
// Note that `__new__` is a static method, so we must inject the `cls` argument.
let new_method = self_type.lookup_dunder_new(db, ());
// Construct an instance type that we can use to look up the `__init__` instance method.
// This performs the same logic as `Type::to_instance`, except for generic class literals.
// TODO: we should use the actual return type of `__new__` to determine the instance type
let init_ty = self_type
.to_instance(db)
.expect("type should be convertible to instance type");
// Lookup the `__init__` instance method in the MRO.
let init_method = init_ty.member_lookup_with_policy(
db,
"__init__".into(),
MemberLookupPolicy::NO_INSTANCE_FALLBACK | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
);
// Infer the call argument types, using both `__new__` and `__init__` for type-context.
let bindings = match (
new_method.as_ref().map(|method| &method.place),
&init_method.place,
) {
(Some(Place::Defined(new_method, ..)), Place::Undefined) => Some(
new_method
.bindings(db)
.map(|binding| binding.with_bound_type(self_type)),
),
(Some(Place::Undefined) | None, Place::Defined(init_method, ..)) => {
Some(init_method.bindings(db))
}
(Some(Place::Defined(new_method, ..)), Place::Defined(init_method, ..)) => {
let callable = UnionBuilder::new(db)
.add(*new_method)
.add(*init_method)
.build();
let new_method_bindings = new_method
.bindings(db)
.map(|binding| binding.with_bound_type(self_type));
Some(Bindings::from_union(
callable,
[new_method_bindings, init_method.bindings(db)],
))
}
_ => None,
};
let argument_types = infer_argument_types(bindings);
let new_call_outcome = new_method.and_then(|new_method| {
match new_method.place.try_call_dunder_get(db, self_type) {
Place::Defined(new_method, _, boundness) => {
let result =
new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref());
if boundness == Definedness::PossiblyUndefined {
Some(Err(DunderNewCallError::PossiblyUnbound(result.err())))
} else {
@ -6091,24 +6146,7 @@ impl<'db> Type<'db> {
}
});
// Construct an instance type that we can use to look up the `__init__` instance method.
// This performs the same logic as `Type::to_instance`, except for generic class literals.
// TODO: we should use the actual return type of `__new__` to determine the instance type
let init_ty = self_type
.to_instance(db)
.expect("type should be convertible to instance type");
let init_call_outcome = if new_call_outcome.is_none()
|| !init_ty
.member_lookup_with_policy(
db,
"__init__".into(),
MemberLookupPolicy::NO_INSTANCE_FALLBACK
| MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
)
.place
.is_undefined()
{
let init_call_outcome = if new_call_outcome.is_none() || !init_method.is_undefined() {
Some(init_ty.try_call_dunder(db, "__init__", argument_types, tcx))
} else {
None

View file

@ -100,6 +100,14 @@ impl<'db> Bindings<'db> {
self.elements.iter()
}
pub(crate) fn map(self, f: impl Fn(CallableBinding<'db>) -> CallableBinding<'db>) -> Self {
Self {
callable_type: self.callable_type,
argument_forms: self.argument_forms,
elements: self.elements.into_iter().map(f).collect(),
}
}
/// Match the arguments of a call site against the parameters of a collection of possibly
/// unioned, possibly overloaded signatures.
///

View file

@ -6798,9 +6798,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.to_class_type(self.db())
.is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class))
{
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
if matches!(
class.known(self.db()),
Some(KnownClass::TypeVar | KnownClass::ExtensionsTypeVar)
@ -6819,8 +6816,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
let db = self.db();
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
if let Some(bindings) = bindings {
let bindings = bindings.match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
} else {
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
}
call_arguments
};
return callable_type
.try_call_constructor(self.db(), call_arguments, tcx)
.try_call_constructor(db, infer_call_arguments, tcx)
.unwrap_or_else(|err| {
err.report_diagnostic(&self.context, callable_type, call_expression.into());
err.return_type()