[ty] Fix subtyping/assignability of function- and class-literal types to callback protocols (#20363)

## Summary

Fixes https://github.com/astral-sh/ty/issues/377.

We were treating any function as being assignable to any callback
protocol, because we were trying to figure out a type's `Callable`
supertype by looking up the `__call__` attribute on the type's
meta-type. But a function-literal's meta-type is `types.FunctionType`,
and `types.FunctionType.__call__` is `(...) -> Any`, which is not very
helpful!

While working on this PR, I also realised that assignability between
class-literals and callback protocols was somewhat broken too, so I
fixed that at the same time.

## Test Plan

Added mdtests
This commit is contained in:
Alex Waygood 2025-09-12 22:20:09 +01:00 committed by GitHub
parent c7f6b85fb3
commit 98708976e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 137 additions and 29 deletions

View file

@ -2227,9 +2227,32 @@ def satisfies_foo(x: int) -> str:
return "foo" return "foo"
static_assert(is_assignable_to(TypeOf[satisfies_foo], Foo)) static_assert(is_assignable_to(TypeOf[satisfies_foo], Foo))
static_assert(is_subtype_of(TypeOf[satisfies_foo], Foo))
# TODO: this should pass def doesnt_satisfy_foo(x: str) -> int:
static_assert(is_subtype_of(TypeOf[satisfies_foo], Foo)) # error: [static-assert-error] return 42
static_assert(not is_assignable_to(TypeOf[doesnt_satisfy_foo], Foo))
static_assert(not is_subtype_of(TypeOf[doesnt_satisfy_foo], Foo))
```
Class-literals and generic aliases can also be subtypes of callback protocols:
```py
from typing import Sequence, TypeVar
static_assert(is_subtype_of(TypeOf[str], Foo))
T = TypeVar("T")
class SequenceMaker(Protocol[T]):
def __call__(self, arg: Sequence[T], /) -> Sequence[T]: ...
static_assert(is_subtype_of(TypeOf[list[int]], SequenceMaker[int]))
# TODO: these should pass
static_assert(is_subtype_of(TypeOf[tuple[str, ...]], SequenceMaker[str])) # error: [static-assert-error]
static_assert(is_subtype_of(TypeOf[tuple[str, ...]], SequenceMaker[int | str])) # error: [static-assert-error]
``` ```
## Nominal subtyping of protocols ## Nominal subtyping of protocols

View file

@ -1580,10 +1580,19 @@ reveal_type(A()(1)) # revealed: str
### Class literals ### Class literals
This section also tests assignability of class-literals to callback protocols, since the rules for
assignability of class-literals to callback protocols are the same as the rules for assignability of
class-literals to `Callable` types.
```toml
[environment]
python-version = "3.12"
```
#### Classes with metaclasses #### Classes with metaclasses
```py ```py
from typing import Callable, overload from typing import Callable, Protocol, overload
from typing_extensions import Self from typing_extensions import Self
from ty_extensions import TypeOf, static_assert, is_subtype_of from ty_extensions import TypeOf, static_assert, is_subtype_of
@ -1593,8 +1602,16 @@ class MetaWithReturn(type):
class A(metaclass=MetaWithReturn): ... class A(metaclass=MetaWithReturn): ...
class Returns[T](Protocol):
def __call__(self) -> T: ...
class ReturnsWithArgument[T1, T2](Protocol):
def __call__(self, arg: T1, /) -> T2: ...
static_assert(is_subtype_of(TypeOf[A], Callable[[], A])) static_assert(is_subtype_of(TypeOf[A], Callable[[], A]))
static_assert(is_subtype_of(TypeOf[A], Returns[A]))
static_assert(not is_subtype_of(TypeOf[A], Callable[[object], A])) static_assert(not is_subtype_of(TypeOf[A], Callable[[object], A]))
static_assert(not is_subtype_of(TypeOf[A], ReturnsWithArgument[object, A]))
class MetaWithDifferentReturn(type): class MetaWithDifferentReturn(type):
def __call__(cls) -> int: def __call__(cls) -> int:
@ -1603,7 +1620,9 @@ class MetaWithDifferentReturn(type):
class B(metaclass=MetaWithDifferentReturn): ... class B(metaclass=MetaWithDifferentReturn): ...
static_assert(is_subtype_of(TypeOf[B], Callable[[], int])) static_assert(is_subtype_of(TypeOf[B], Callable[[], int]))
static_assert(is_subtype_of(TypeOf[B], Returns[int]))
static_assert(not is_subtype_of(TypeOf[B], Callable[[], B])) static_assert(not is_subtype_of(TypeOf[B], Callable[[], B]))
static_assert(not is_subtype_of(TypeOf[B], Returns[B]))
class MetaWithOverloadReturn(type): class MetaWithOverloadReturn(type):
@overload @overload
@ -1617,20 +1636,30 @@ class C(metaclass=MetaWithOverloadReturn): ...
static_assert(is_subtype_of(TypeOf[C], Callable[[int], int])) static_assert(is_subtype_of(TypeOf[C], Callable[[int], int]))
static_assert(is_subtype_of(TypeOf[C], Callable[[], str])) static_assert(is_subtype_of(TypeOf[C], Callable[[], str]))
static_assert(is_subtype_of(TypeOf[C], ReturnsWithArgument[int, int]))
static_assert(is_subtype_of(TypeOf[C], Returns[str]))
``` ```
#### Classes with `__new__` #### Classes with `__new__`
```py ```py
from typing import Callable, overload from typing import Callable, overload, Protocol
from ty_extensions import TypeOf, static_assert, is_subtype_of from ty_extensions import TypeOf, static_assert, is_subtype_of
class A: class A:
def __new__(cls, a: int) -> int: def __new__(cls, a: int) -> int:
return a return a
class Returns[T](Protocol):
def __call__(self) -> T: ...
class ReturnsWithArgument[T1, T2](Protocol):
def __call__(self, arg: T1, /) -> T2: ...
static_assert(is_subtype_of(TypeOf[A], Callable[[int], int])) static_assert(is_subtype_of(TypeOf[A], Callable[[int], int]))
static_assert(is_subtype_of(TypeOf[A], ReturnsWithArgument[int, int]))
static_assert(not is_subtype_of(TypeOf[A], Callable[[], int])) static_assert(not is_subtype_of(TypeOf[A], Callable[[], int]))
static_assert(not is_subtype_of(TypeOf[A], Returns[int]))
class B: ... class B: ...
class C(B): ... class C(B): ...
@ -1644,9 +1673,13 @@ class E(D):
return C() return C()
static_assert(is_subtype_of(TypeOf[E], Callable[[], C])) static_assert(is_subtype_of(TypeOf[E], Callable[[], C]))
static_assert(is_subtype_of(TypeOf[E], Returns[C]))
static_assert(is_subtype_of(TypeOf[E], Callable[[], B])) static_assert(is_subtype_of(TypeOf[E], Callable[[], B]))
static_assert(is_subtype_of(TypeOf[E], Returns[B]))
static_assert(not is_subtype_of(TypeOf[D], Callable[[], C])) static_assert(not is_subtype_of(TypeOf[D], Callable[[], C]))
static_assert(not is_subtype_of(TypeOf[D], Returns[C]))
static_assert(is_subtype_of(TypeOf[D], Callable[[], B])) static_assert(is_subtype_of(TypeOf[D], Callable[[], B]))
static_assert(is_subtype_of(TypeOf[D], Returns[B]))
class F: class F:
@overload @overload
@ -1668,7 +1701,7 @@ static_assert(not is_subtype_of(TypeOf[F], Callable[[str], F]))
If `__call__` and `__new__` are both present, `__call__` takes precedence. If `__call__` and `__new__` are both present, `__call__` takes precedence.
```py ```py
from typing import Callable from typing import Callable, Protocol
from ty_extensions import TypeOf, static_assert, is_subtype_of from ty_extensions import TypeOf, static_assert, is_subtype_of
class MetaWithIntReturn(type): class MetaWithIntReturn(type):
@ -1679,21 +1712,34 @@ class F(metaclass=MetaWithIntReturn):
def __new__(cls) -> str: def __new__(cls) -> str:
return super().__new__(cls) return super().__new__(cls)
class Returns[T](Protocol):
def __call__(self) -> T: ...
static_assert(is_subtype_of(TypeOf[F], Callable[[], int])) static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
static_assert(is_subtype_of(TypeOf[F], Returns[int]))
static_assert(not is_subtype_of(TypeOf[F], Callable[[], str])) static_assert(not is_subtype_of(TypeOf[F], Callable[[], str]))
static_assert(not is_subtype_of(TypeOf[F], Returns[str]))
``` ```
#### Classes with `__init__` #### Classes with `__init__`
```py ```py
from typing import Callable, overload from typing import Callable, overload, Protocol
from ty_extensions import TypeOf, static_assert, is_subtype_of from ty_extensions import TypeOf, static_assert, is_subtype_of
class Returns[T](Protocol):
def __call__(self) -> T: ...
class ReturnsWithArgument[T1, T2](Protocol):
def __call__(self, arg: T1, /) -> T2: ...
class A: class A:
def __init__(self, a: int) -> None: ... def __init__(self, a: int) -> None: ...
static_assert(is_subtype_of(TypeOf[A], Callable[[int], A])) static_assert(is_subtype_of(TypeOf[A], Callable[[int], A]))
static_assert(is_subtype_of(TypeOf[A], ReturnsWithArgument[int, A]))
static_assert(not is_subtype_of(TypeOf[A], Callable[[], A])) static_assert(not is_subtype_of(TypeOf[A], Callable[[], A]))
static_assert(not is_subtype_of(TypeOf[A], Returns[A]))
class B: class B:
@overload @overload
@ -1703,27 +1749,37 @@ class B:
def __init__(self, a: int | None = None) -> None: ... def __init__(self, a: int | None = None) -> None: ...
static_assert(is_subtype_of(TypeOf[B], Callable[[int], B])) static_assert(is_subtype_of(TypeOf[B], Callable[[int], B]))
static_assert(is_subtype_of(TypeOf[B], ReturnsWithArgument[int, B]))
static_assert(is_subtype_of(TypeOf[B], Callable[[], B])) static_assert(is_subtype_of(TypeOf[B], Callable[[], B]))
static_assert(is_subtype_of(TypeOf[B], Returns[B]))
class C: ... class C: ...
# TODO: This assertion should be true once we understand `Self` # TODO: These assertions should be true once we understand `Self`
# error: [static-assert-error] "Static assertion error: argument of type `ty_extensions.ConstraintSet[never]` is statically known to be falsy" static_assert(is_subtype_of(TypeOf[C], Callable[[], C])) # error: [static-assert-error]
static_assert(is_subtype_of(TypeOf[C], Callable[[], C])) static_assert(is_subtype_of(TypeOf[C], Returns[C])) # error: [static-assert-error]
class D[T]: class D[T]:
def __init__(self, x: T) -> None: ... def __init__(self, x: T) -> None: ...
static_assert(is_subtype_of(TypeOf[D[int]], Callable[[int], D[int]])) static_assert(is_subtype_of(TypeOf[D[int]], Callable[[int], D[int]]))
static_assert(is_subtype_of(TypeOf[D[int]], ReturnsWithArgument[int, D[int]]))
static_assert(not is_subtype_of(TypeOf[D[int]], Callable[[str], D[int]])) static_assert(not is_subtype_of(TypeOf[D[int]], Callable[[str], D[int]]))
static_assert(not is_subtype_of(TypeOf[D[int]], ReturnsWithArgument[str, D[int]]))
``` ```
#### Classes with `__init__` and `__new__` #### Classes with `__init__` and `__new__`
```py ```py
from typing import Callable, overload, Self from typing import Callable, overload, Self, Protocol
from ty_extensions import TypeOf, static_assert, is_subtype_of from ty_extensions import TypeOf, static_assert, is_subtype_of
class Returns[T](Protocol):
def __call__(self) -> T: ...
class ReturnsWithArgument[T1, T2](Protocol):
def __call__(self, arg: T1, /) -> T2: ...
class A: class A:
def __new__(cls, a: int) -> Self: def __new__(cls, a: int) -> Self:
return super().__new__(cls) return super().__new__(cls)
@ -1731,7 +1787,9 @@ class A:
def __init__(self, a: int) -> None: ... def __init__(self, a: int) -> None: ...
static_assert(is_subtype_of(TypeOf[A], Callable[[int], A])) static_assert(is_subtype_of(TypeOf[A], Callable[[int], A]))
static_assert(is_subtype_of(TypeOf[A], ReturnsWithArgument[int, A]))
static_assert(not is_subtype_of(TypeOf[A], Callable[[], A])) static_assert(not is_subtype_of(TypeOf[A], Callable[[], A]))
static_assert(not is_subtype_of(TypeOf[A], Returns[A]))
class B: class B:
def __new__(cls, a: int) -> int: def __new__(cls, a: int) -> int:
@ -1740,7 +1798,9 @@ class B:
def __init__(self, a: str) -> None: ... def __init__(self, a: str) -> None: ...
static_assert(is_subtype_of(TypeOf[B], Callable[[int], int])) static_assert(is_subtype_of(TypeOf[B], Callable[[int], int]))
static_assert(is_subtype_of(TypeOf[B], ReturnsWithArgument[int, int]))
static_assert(not is_subtype_of(TypeOf[B], Callable[[str], B])) static_assert(not is_subtype_of(TypeOf[B], Callable[[str], B]))
static_assert(not is_subtype_of(TypeOf[B], ReturnsWithArgument[str, B]))
class C: class C:
def __new__(cls, *args, **kwargs) -> "C": def __new__(cls, *args, **kwargs) -> "C":
@ -1750,7 +1810,9 @@ class C:
# Not subtype because __new__ signature is not fully static # Not subtype because __new__ signature is not fully static
static_assert(not is_subtype_of(TypeOf[C], Callable[[int], C])) static_assert(not is_subtype_of(TypeOf[C], Callable[[int], C]))
static_assert(not is_subtype_of(TypeOf[C], ReturnsWithArgument[int, C]))
static_assert(not is_subtype_of(TypeOf[C], Callable[[], C])) static_assert(not is_subtype_of(TypeOf[C], Callable[[], C]))
static_assert(not is_subtype_of(TypeOf[C], Returns[C]))
class D: ... class D: ...
@ -1765,7 +1827,9 @@ class E:
def __init__(self, y: str) -> None: ... def __init__(self, y: str) -> None: ...
static_assert(is_subtype_of(TypeOf[E], Callable[[int], D])) static_assert(is_subtype_of(TypeOf[E], Callable[[int], D]))
static_assert(is_subtype_of(TypeOf[E], ReturnsWithArgument[int, D]))
static_assert(is_subtype_of(TypeOf[E], Callable[[], int])) static_assert(is_subtype_of(TypeOf[E], Callable[[], int]))
static_assert(is_subtype_of(TypeOf[E], Returns[int]))
class F[T]: class F[T]:
def __new__(cls, x: T) -> "F[T]": def __new__(cls, x: T) -> "F[T]":
@ -1774,7 +1838,9 @@ class F[T]:
def __init__(self, x: T) -> None: ... def __init__(self, x: T) -> None: ...
static_assert(is_subtype_of(TypeOf[F[int]], Callable[[int], F[int]])) static_assert(is_subtype_of(TypeOf[F[int]], Callable[[int], F[int]]))
static_assert(is_subtype_of(TypeOf[F[int]], ReturnsWithArgument[int, F[int]]))
static_assert(not is_subtype_of(TypeOf[F[int]], Callable[[str], F[int]])) static_assert(not is_subtype_of(TypeOf[F[int]], Callable[[str], F[int]]))
static_assert(not is_subtype_of(TypeOf[F[int]], ReturnsWithArgument[str, F[int]]))
``` ```
#### Classes with `__call__`, `__new__` and `__init__` #### Classes with `__call__`, `__new__` and `__init__`
@ -1782,9 +1848,15 @@ static_assert(not is_subtype_of(TypeOf[F[int]], Callable[[str], F[int]]))
If `__call__`, `__new__` and `__init__` are all present, `__call__` takes precedence. If `__call__`, `__new__` and `__init__` are all present, `__call__` takes precedence.
```py ```py
from typing import Callable from typing import Callable, Protocol
from ty_extensions import TypeOf, static_assert, is_subtype_of from ty_extensions import TypeOf, static_assert, is_subtype_of
class Returns[T](Protocol):
def __call__(self) -> T: ...
class ReturnsWithArgument[T1, T2](Protocol):
def __call__(self, arg: T1, /) -> T2: ...
class MetaWithIntReturn(type): class MetaWithIntReturn(type):
def __call__(cls) -> int: def __call__(cls) -> int:
return super().__call__() return super().__call__()
@ -1796,8 +1868,11 @@ class F(metaclass=MetaWithIntReturn):
def __init__(self, x: int) -> None: ... def __init__(self, x: int) -> None: ...
static_assert(is_subtype_of(TypeOf[F], Callable[[], int])) static_assert(is_subtype_of(TypeOf[F], Callable[[], int]))
static_assert(is_subtype_of(TypeOf[F], Returns[int]))
static_assert(not is_subtype_of(TypeOf[F], Callable[[], str])) static_assert(not is_subtype_of(TypeOf[F], Callable[[], str]))
static_assert(not is_subtype_of(TypeOf[F], Returns[str]))
static_assert(not is_subtype_of(TypeOf[F], Callable[[int], F])) static_assert(not is_subtype_of(TypeOf[F], Callable[[int], F]))
static_assert(not is_subtype_of(TypeOf[F], ReturnsWithArgument[int, F]))
``` ```
### Subclass of ### Subclass of

View file

@ -3268,13 +3268,6 @@ impl<'db> Type<'db> {
policy: InstanceFallbackShadowsNonDataDescriptor, policy: InstanceFallbackShadowsNonDataDescriptor,
member_policy: MemberLookupPolicy, member_policy: MemberLookupPolicy,
) -> PlaceAndQualifiers<'db> { ) -> PlaceAndQualifiers<'db> {
// TODO: this is a workaround for the fact that looking up the `__call__` attribute on the
// meta-type of a `Callable` type currently returns `Unbound`. We should fix this by inferring
// a more sophisticated meta-type for `Callable` types; that would allow us to remove this branch.
if name == "__call__" && matches!(self, Type::Callable(_) | Type::DataclassTransformer(_)) {
return Place::bound(self).into();
}
let ( let (
PlaceAndQualifiers { PlaceAndQualifiers {
place: meta_attr, place: meta_attr,

View file

@ -541,17 +541,34 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match &self.kind { match &self.kind {
ProtocolMemberKind::Method(method) => { ProtocolMemberKind::Method(method) => {
let Place::Type(attribute_type, Boundness::Bound) = other // `__call__` members must be special cased for several reasons:
.invoke_descriptor_protocol( //
db, // 1. Looking up `__call__` on the meta-type of a `Callable` type returns `Place::Unbound` currently
self.name, // 2. Looking up `__call__` on the meta-type of a function-literal type currently returns a type that
Place::Unbound.into(), // has an extremely vague signature (`(*args, **kwargs) -> Any`), which is not useful for protocol
InstanceFallbackShadowsNonDataDescriptor::No, // checking.
MemberLookupPolicy::default(), // 3. Looking up `__call__` on the meta-type of a class-literal, generic-alias or subclass-of type is
) // unfortunately not sufficient to obtain the `Callable` supertypes of these types, due to the
.place // complex interaction between `__new__`, `__init__` and metaclass `__call__`.
else { let attribute_type = if self.name == "__call__" {
return ConstraintSet::from(false); let Some(attribute_type) = other.into_callable(db) else {
return ConstraintSet::from(false);
};
attribute_type
} else {
let Place::Type(attribute_type, Boundness::Bound) = other
.invoke_descriptor_protocol(
db,
self.name,
Place::Unbound.into(),
InstanceFallbackShadowsNonDataDescriptor::No,
MemberLookupPolicy::default(),
)
.place
else {
return ConstraintSet::from(false);
};
attribute_type
}; };
let proto_member_as_bound_method = method.bind_self(db); let proto_member_as_bound_method = method.bind_self(db);