[ty] Implement equivalence for protocols with method members (#18659)

## Summary

This PR implements the following pieces of `Protocol` semantics:
1. A protocol with a method member that does not have a fully static
signature should not be considered fully static. For example, this
protocol is not fully static because `Foo.f` has no return annotation;
we previously (and incorrectly) considered it fully static:
  ```py
  class Foo(Protocol):
      def f(self): ...
  ```
2. Two protocols `P1` and `P2`, each with a method member `x`, should be
considered equivalent if the signature of `P1.x` is equivalent to the
signature of `P2.x`. Currently we do not recognize this.

Implementing these semantics requires distinguishing between method
members and non-method members. The stored type of a method member must
be eagerly upcast to a `Callable` type when collecting the protocol's
interface. Otherwise, it would be hard to implement equivalence of
protocols in the face of differently ordered unions, since two
equivalent protocols would have different Salsa IDs even when
normalized.
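
As a rough runtime analogy only (this is not ty's implementation, and it says nothing about Salsa), the sketch below shows why comparing a normalized callable signature is easier than comparing the function objects themselves. The `interface` helper is hypothetical and exists purely for illustration:

```python
import inspect
from typing import Protocol


class P1(Protocol):
    def x(self, y: int) -> None: ...


class P2(Protocol):
    def x(self, y: int) -> None: ...


def interface(proto: type) -> dict[str, inspect.Signature]:
    """Hypothetical helper: map each public function defined on `proto` to its signature."""
    return {
        name: inspect.signature(member)
        for name, member in vars(proto).items()
        if inspect.isfunction(member) and not name.startswith("_")
    }


# The two function objects are distinct, so comparing them directly says "not equal"...
functions_equal = P1.x == P2.x
# ...whereas comparing the signatures they were "upcast" to says "equal".
interfaces_equal = interface(P1) == interface(P2)
```

Storing the signature rather than the function is the runtime counterpart of storing a `Callable` type rather than a function-literal type: two structurally identical members then compare equal regardless of where they were defined.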

The semantics implemented by this PR are that we consider something a
method member if:
1. It is accessible on the class itself; and
2. It is a function-like callable: a callable type that also has a
`__get__` method, meaning it can be used as a method when accessed on
instances.
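
The two conditions above can be approximated at runtime as follows. This is a sketch under stated assumptions, not the check ty performs: `looks_like_method_member`, `Notifier`, `P` and `Q` are hypothetical names, and the sketch makes no attempt to model the deferred classmethod and staticmethod cases.

```python
import inspect
from typing import Protocol

_MISSING = object()  # sentinel for "not found on the class"


class Notifier:
    """A callable object whose type defines no `__get__`."""

    def __call__(self) -> None: ...


class P(Protocol):
    attr: int  # declared only: not accessible on the class itself

    def method(self) -> None: ...  # plain function: callable and has `__get__`


class Q:
    hook = Notifier()  # accessible on the class and callable, but not function-like


def looks_like_method_member(cls: type, name: str) -> bool:
    """Hypothetical approximation of the two conditions listed above."""
    # Condition 1: the member is accessible on the class itself.
    member = inspect.getattr_static(cls, name, _MISSING)
    if member is _MISSING:
        return False
    # Condition 2: it is callable and its type defines `__get__`.
    return callable(member) and hasattr(type(member), "__get__")
```

Under these assumptions, `looks_like_method_member(P, "method")` holds, while `P.attr` fails the first condition and `Q.hook` fails the second.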

Note that the spec has complicated things to say about classmethod
members and staticmethod members. These semantics are not implemented by
this PR; they are all deferred for now.

The infrastructure added in this PR fixes bugs in its own right, but
also lays the groundwork for implementing subtyping and assignability
rules for method members of protocols. A (currently failing) test is
added to verify this.

## Test Plan

mdtests
Alex Waygood 2025-07-07 12:28:32 +01:00 committed by GitHub
parent c15aa572ff
commit a6637964d2
8 changed files with 146 additions and 21 deletions


@@ -1476,8 +1476,7 @@ class P1(Protocol):
class P2(Protocol):
    def x(self, y: int) -> None: ...

-# TODO: this should pass
-static_assert(is_equivalent_to(P1, P2))  # error: [static-assert-error]
+static_assert(is_equivalent_to(P1, P2))
```
As with protocols that only have non-method members, this also holds true when they appear in
@@ -1487,8 +1486,7 @@ differently ordered unions:
class A: ...
class B: ...

-# TODO: this should pass
-static_assert(is_equivalent_to(A | B | P1, P2 | B | A))  # error: [static-assert-error]
+static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
```
## Narrowing of protocols
@@ -1896,6 +1894,86 @@ if isinstance(obj, (B, A)):
    reveal_type(obj)  # revealed: (Unknown & B) | (Unknown & A)
```
### Protocols that use `Self`
`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
overflows.
```toml
[environment]
python-version = "3.12"
```
```py
from typing_extensions import Protocol, Self
from ty_extensions import static_assert

class _HashObject(Protocol):
    def copy(self) -> Self: ...

class Foo: ...

# Attempting to build this union caused us to overflow on an early version of
# <https://github.com/astral-sh/ruff/pull/18659>
x: Foo | _HashObject
```
Some other similar cases that caused issues in our early `Protocol` implementation:
`a.py`:
```py
from typing_extensions import Protocol, Self

class PGconn(Protocol):
    def connect(self) -> Self: ...

class Connection:
    pgconn: PGconn

def is_crdb(conn: PGconn) -> bool:
    return isinstance(conn, Connection)
```
and:
`b.py`:
```py
from typing_extensions import Protocol

class PGconn(Protocol):
    def connect[T: PGconn](self: T) -> T: ...

class Connection:
    pgconn: PGconn

def f(x: PGconn):
    isinstance(x, Connection)
```
### Recursive protocols used as the first argument to `cast()`
These caused issues in an early version of our `Protocol` implementation due to the fact that we use
a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:
```toml
[environment]
python-version = "3.12"
```
```py
from typing import cast, Protocol

class Iterator[T](Protocol):
    def __iter__(self) -> Iterator[T]: ...

def f(value: Iterator):
    cast(Iterator, value)  # error: [redundant-cast]
```
## TODO
Add tests for:


@@ -300,6 +300,20 @@ static_assert(not is_equivalent_to(CallableTypeOf[f12], CallableTypeOf[f13]))
static_assert(not is_equivalent_to(CallableTypeOf[f13], CallableTypeOf[f12]))
```
### Unions containing `Callable`s
Two unions containing different `Callable` types are equivalent even if the unions are differently
ordered:
```py
from ty_extensions import CallableTypeOf, Unknown, is_equivalent_to, static_assert
def f(x): ...
def g(x: Unknown): ...
static_assert(is_equivalent_to(CallableTypeOf[f] | int | str, str | int | CallableTypeOf[g]))
```
### Unions containing `Callable`s containing unions
Differently ordered unions inside `Callable`s inside unions can still be equivalent: