[ty] Add subtyping between Callable types and class literals with __init__ (#17638)

## Summary

Allow classes with `__init__` to be subtypes of `Callable`

Fixes https://github.com/astral-sh/ty/issues/358

## Test Plan

Update is_subtype_of.md

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Matthew Mckee 2025-05-28 21:43:07 +01:00 committed by GitHub
parent 16621fa19d
commit c60b4d7f30
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 329 additions and 46 deletions

View file

@ -608,11 +608,49 @@ c: Callable[[Any], str] = A().g
```py
from typing import Any, Callable
c: Callable[[object], type] = type
c: Callable[[str], Any] = str
c: Callable[[str], Any] = int
# error: [invalid-assignment]
c: Callable[[str], Any] = object
class A:
def __init__(self, x: int) -> None: ...
a: Callable[[int], A] = A
class C:
def __new__(cls, *args, **kwargs) -> "C":
return super().__new__(cls)
def __init__(self, x: int) -> None: ...
c: Callable[[int], C] = C
```
### Generic class literal types
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Callable
class B[T]:
def __init__(self, x: T) -> None: ...
b: Callable[[int], B[int]] = B[int]
class C[T]:
def __new__(cls, *args, **kwargs) -> "C[T]":
return super().__new__(cls)
def __init__(self, x: T) -> None: ...
c: Callable[[int], C[int]] = C[int]
```
### Overloads

View file

@ -1219,7 +1219,7 @@ static_assert(is_subtype_of(TypeOf[C], Callable[[], str]))
#### Classes with `__new__`
```py
from typing import Callable
from typing import Callable, overload
from ty_extensions import TypeOf, static_assert, is_subtype_of
class A:
@ -1244,6 +1244,20 @@ static_assert(is_subtype_of(TypeOf[E], Callable[[], C]))
static_assert(is_subtype_of(TypeOf[E], Callable[[], B]))
static_assert(not is_subtype_of(TypeOf[D], Callable[[], C]))
static_assert(is_subtype_of(TypeOf[D], Callable[[], B]))
class F:
@overload
def __new__(cls) -> int: ...
@overload
def __new__(cls, x: int) -> "F": ...
def __new__(cls, x: int | None = None) -> "int | F":
return 1 if x is None else object.__new__(cls)
def __init__(self, y: str) -> None: ...
static_assert(is_subtype_of(TypeOf[F], Callable[[int], F]))
static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
static_assert(not is_subtype_of(TypeOf[F], Callable[[str], F]))
```
#### Classes with `__call__` and `__new__`
@ -1266,6 +1280,123 @@ static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
static_assert(not is_subtype_of(TypeOf[F], Callable[[], str]))
```
#### Classes with `__init__`
```py
from typing import Callable, overload
from ty_extensions import TypeOf, static_assert, is_subtype_of
class A:
def __init__(self, a: int) -> None: ...
static_assert(is_subtype_of(TypeOf[A], Callable[[int], A]))
static_assert(not is_subtype_of(TypeOf[A], Callable[[], A]))
class B:
@overload
def __init__(self, a: int) -> None: ...
@overload
def __init__(self) -> None: ...
def __init__(self, a: int | None = None) -> None: ...
static_assert(is_subtype_of(TypeOf[B], Callable[[int], B]))
static_assert(is_subtype_of(TypeOf[B], Callable[[], B]))
class C: ...
# TODO: This assertion should be true once we understand `Self`
# error: [static-assert-error] "Static assertion error: argument evaluates to `False`"
static_assert(is_subtype_of(TypeOf[C], Callable[[], C]))
class D[T]:
def __init__(self, x: T) -> None: ...
static_assert(is_subtype_of(TypeOf[D[int]], Callable[[int], D[int]]))
static_assert(not is_subtype_of(TypeOf[D[int]], Callable[[str], D[int]]))
```
#### Classes with `__init__` and `__new__`
```py
from typing import Callable, overload, Self
from ty_extensions import TypeOf, static_assert, is_subtype_of
class A:
def __new__(cls, a: int) -> Self:
return super().__new__(cls)
def __init__(self, a: int) -> None: ...
static_assert(is_subtype_of(TypeOf[A], Callable[[int], A]))
static_assert(not is_subtype_of(TypeOf[A], Callable[[], A]))
class B:
def __new__(cls, a: int) -> int:
return super().__new__(cls)
def __init__(self, a: str) -> None: ...
static_assert(is_subtype_of(TypeOf[B], Callable[[int], int]))
static_assert(not is_subtype_of(TypeOf[B], Callable[[str], B]))
class C:
def __new__(cls, *args, **kwargs) -> "C":
return super().__new__(cls)
def __init__(self, x: int) -> None: ...
# Not subtype because __new__ signature is not fully static
static_assert(not is_subtype_of(TypeOf[C], Callable[[int], C]))
static_assert(not is_subtype_of(TypeOf[C], Callable[[], C]))
class D: ...
class E:
@overload
def __new__(cls) -> int: ...
@overload
def __new__(cls, x: int) -> D: ...
def __new__(cls, x: int | None = None) -> int | D:
return D()
def __init__(self, y: str) -> None: ...
static_assert(is_subtype_of(TypeOf[E], Callable[[int], D]))
static_assert(is_subtype_of(TypeOf[E], Callable[[], int]))
class F[T]:
def __new__(cls, x: T) -> "F[T]":
return super().__new__(cls)
def __init__(self, x: T) -> None: ...
static_assert(is_subtype_of(TypeOf[F[int]], Callable[[int], F[int]]))
static_assert(not is_subtype_of(TypeOf[F[int]], Callable[[str], F[int]]))
```
#### Classes with `__call__`, `__new__` and `__init__`
If `__call__`, `__new__` and `__init__` are all present, `__call__` takes precedence.
```py
from typing import Callable
from ty_extensions import TypeOf, static_assert, is_subtype_of
class MetaWithIntReturn(type):
def __call__(cls) -> int:
return super().__call__()
class F(metaclass=MetaWithIntReturn):
def __new__(cls) -> str:
return super().__new__(cls)
def __init__(self, x: int) -> None: ...
static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
static_assert(not is_subtype_of(TypeOf[F], Callable[[], str]))
static_assert(not is_subtype_of(TypeOf[F], Callable[[int], F]))
```
### Bound methods
```py

View file

@ -1351,12 +1351,15 @@ impl<'db> Type<'db> {
}
(Type::ClassLiteral(class_literal), Type::Callable(_)) => {
if let Some(callable) = class_literal.into_callable(db) {
return callable.is_subtype_of(db, target);
}
false
ClassType::NonGeneric(class_literal)
.into_callable(db)
.is_subtype_of(db, target)
}
(Type::GenericAlias(alias), Type::Callable(_)) => ClassType::Generic(alias)
.into_callable(db)
.is_subtype_of(db, target),
// `Literal[str]` is a subtype of `type` because the `str` class object is an instance of its metaclass `type`.
// `Literal[abc.ABC]` is a subtype of `abc.ABCMeta` because the `abc.ABC` class object
// is an instance of its metaclass `abc.ABCMeta`.
@ -1656,12 +1659,15 @@ impl<'db> Type<'db> {
}
(Type::ClassLiteral(class_literal), Type::Callable(_)) => {
if let Some(callable) = class_literal.into_callable(db) {
return callable.is_assignable_to(db, target);
}
false
ClassType::NonGeneric(class_literal)
.into_callable(db)
.is_assignable_to(db, target)
}
(Type::GenericAlias(alias), Type::Callable(_)) => ClassType::Generic(alias)
.into_callable(db)
.is_assignable_to(db, target),
(Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => {
self_function_literal
.into_callable_type(db)

View file

@ -9,7 +9,7 @@ use super::{
use crate::semantic_index::DeclarationWithConstraint;
use crate::semantic_index::definition::Definition;
use crate::types::generics::{GenericContext, Specialization};
use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature};
use crate::types::{
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, TypeMapping,
TypeVarInstance,
@ -493,6 +493,142 @@ impl<'db> ClassType<'db> {
.own_instance_member(db, name)
.map_type(|ty| ty.apply_optional_specialization(db, specialization))
}
/// Return a callable type (or union of callable types) that represents the callable
/// constructor signature of this class.
pub(super) fn into_callable(self, db: &'db dyn Db) -> Type<'db> {
let self_ty = Type::from(self);
let metaclass_dunder_call_function_symbol = self_ty
.member_lookup_with_policy(
db,
"__call__".into(),
MemberLookupPolicy::NO_INSTANCE_FALLBACK
| MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK,
)
.symbol;
if let Symbol::Type(Type::BoundMethod(metaclass_dunder_call_function), _) =
metaclass_dunder_call_function_symbol
{
// TODO: this intentionally diverges from step 1 in
// https://typing.python.org/en/latest/spec/constructors.html#converting-a-constructor-to-callable
// by always respecting the signature of the metaclass `__call__`, rather than
// using a heuristic which makes unwarranted assumptions to sometimes ignore it.
return metaclass_dunder_call_function.into_callable_type(db);
}
let dunder_new_function_symbol = self_ty
.member_lookup_with_policy(
db,
"__new__".into(),
MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
)
.symbol;
let dunder_new_function =
if let Symbol::Type(Type::FunctionLiteral(dunder_new_function), _) =
dunder_new_function_symbol
{
// Step 3: If the return type of the `__new__` evaluates to a type that is not a subclass of this class,
// then we should ignore the `__init__` and just return the `__new__` method.
let returns_non_subclass =
dunder_new_function
.signature(db)
.overloads
.iter()
.any(|signature| {
signature.return_ty.is_some_and(|return_ty| {
!return_ty.is_assignable_to(
db,
self_ty
.to_instance(db)
.expect("ClassType should be instantiable"),
)
})
});
let dunder_new_bound_method =
dunder_new_function.into_bound_method_type(db, self_ty);
if returns_non_subclass {
return dunder_new_bound_method;
}
Some(dunder_new_bound_method)
} else {
None
};
let dunder_init_function_symbol = self_ty
.member_lookup_with_policy(
db,
"__init__".into(),
MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK
| MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK,
)
.symbol;
let correct_return_type = self_ty.to_instance(db).unwrap_or_else(Type::unknown);
// If the class defines an `__init__` method, then we synthesize a callable type with the
// same parameters as the `__init__` method after it is bound, and with the return type of
// the concrete type of `Self`.
let synthesized_dunder_init_callable =
if let Symbol::Type(Type::FunctionLiteral(dunder_init_function), _) =
dunder_init_function_symbol
{
let synthesized_signature = |signature: Signature<'db>| {
Signature::new(signature.parameters().clone(), Some(correct_return_type))
.bind_self()
};
let synthesized_dunder_init_signature = CallableSignature::from_overloads(
dunder_init_function
.signature(db)
.overloads
.iter()
.cloned()
.map(synthesized_signature),
);
Some(Type::Callable(CallableType::new(
db,
synthesized_dunder_init_signature,
true,
)))
} else {
None
};
match (dunder_new_function, synthesized_dunder_init_callable) {
(Some(dunder_new_function), Some(synthesized_dunder_init_callable)) => {
UnionType::from_elements(
db,
vec![dunder_new_function, synthesized_dunder_init_callable],
)
}
(Some(constructor), None) | (None, Some(constructor)) => constructor,
(None, None) => {
// If no `__new__` or `__init__` method is found, then we fall back to looking for
// an `object.__new__` method.
let new_function_symbol = self_ty
.member_lookup_with_policy(
db,
"__new__".into(),
MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK,
)
.symbol;
if let Symbol::Type(Type::FunctionLiteral(new_function), _) = new_function_symbol {
new_function.into_bound_method_type(db, self_ty)
} else {
// Fallback if no `object.__new__` is found.
CallableType::single(
db,
Signature::new(Parameters::empty(), Some(correct_return_type)),
)
}
}
}
}
}
impl<'db> From<GenericAlias<'db>> for ClassType<'db> {
@ -991,40 +1127,6 @@ impl<'db> ClassLiteral<'db> {
))
}
pub(super) fn into_callable(self, db: &'db dyn Db) -> Option<Type<'db>> {
let self_ty = Type::from(self);
let metaclass_call_function_symbol = self_ty
.member_lookup_with_policy(
db,
"__call__".into(),
MemberLookupPolicy::NO_INSTANCE_FALLBACK
| MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK,
)
.symbol;
if let Symbol::Type(Type::BoundMethod(metaclass_call_function), _) =
metaclass_call_function_symbol
{
// TODO: this intentionally diverges from step 1 in
// https://typing.python.org/en/latest/spec/constructors.html#converting-a-constructor-to-callable
// by always respecting the signature of the metaclass `__call__`, rather than
// using a heuristic which makes unwarranted assumptions to sometimes ignore it.
return Some(metaclass_call_function.into_callable_type(db));
}
let dunder_new_method = self_ty
.find_name_in_mro(db, "__new__")
.expect("find_name_in_mro always succeeds for class literals")
.symbol
.try_call_dunder_get(db, self_ty);
if let Symbol::Type(Type::FunctionLiteral(dunder_new_method), _) = dunder_new_method {
return Some(dunder_new_method.into_bound_method_type(db, self.into()));
}
// TODO handle `__init__` also
None
}
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member on the class itself or any of its proper superclasses.

View file

@ -88,7 +88,8 @@ pub enum KnownInstanceType<'db> {
/// The symbol `typing.Callable`
/// (which can also be found as `typing_extensions.Callable` or as `collections.abc.Callable`)
Callable,
/// The symbol `typing.Self` (which can also be found as `typing_extensions.Self`)
/// The symbol `typing.Self` (which can also be found as `typing_extensions.Self` or
/// `_typeshed.Self`)
TypingSelf,
// Various special forms, special aliases and type qualifiers that we don't yet understand
@ -307,7 +308,6 @@ impl<'db> KnownInstanceType<'db> {
| Self::Literal
| Self::LiteralString
| Self::Never
| Self::TypingSelf
| Self::Final
| Self::Concatenate
| Self::Unpack
@ -322,6 +322,12 @@ impl<'db> KnownInstanceType<'db> {
| Self::TypeVar(_) => {
matches!(module, KnownModule::Typing | KnownModule::TypingExtensions)
}
Self::TypingSelf => {
matches!(
module,
KnownModule::Typing | KnownModule::TypingExtensions | KnownModule::Typeshed
)
}
Self::Unknown
| Self::AlwaysTruthy
| Self::AlwaysFalsy