[ty] Implement equivalence for protocols with method members (#18659)

## Summary

This PR implements the following pieces of `Protocol` semantics:
1. A protocol with a method member that does not have a fully static
signature should not be considered fully static. I.e., this protocol is
not fully static because `Foo.x` has no return type; we previously
incorrectly considered that it was:
  ```py
  class Foo(Protocol):
      def f(self): ...
  ```
2. Two protocols `P1` and `P2`, both with method members `x`, should be
considered equivalent if the signature of `P1.x` is equivalent to the
signature of `P2.x`. Currently we do not recognize this.

Implementing these semantics requires distinguishing between method
members and non-method members. The stored type of a method member must
be eagerly upcast to a `Callable` type when collecting the protocol's
interface: doing otherwise would mean that it would be hard to implement
equivalence of protocols even in the face of differently ordered unions,
since the two equivalent protocols would have different Salsa IDs even
when normalized.

The semantics implemented by this PR are that we consider something a
method member if:
1. It is accessible on the class itself; and
2. It is a function-like callable: a callable type that also has a
`__get__` method, meaning it can be used as a method when accessed on
instances.

Note that the spec has complicated things to say about classmethod
members and staticmethod members. These semantics are not implemented by
this PR; they are all deferred for now.

The infrastructure added in this PR fixes bugs in its own right, but
also lays the groundwork for implementing subtyping and assignability
rules for method members of protocols. A (currently failing) test is
added to verify this.

## Test Plan

mdtests
This commit is contained in:
Alex Waygood 2025-07-07 12:28:32 +01:00 committed by GitHub
parent c15aa572ff
commit a6637964d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 146 additions and 21 deletions

View file

@ -1476,8 +1476,7 @@ class P1(Protocol):
class P2(Protocol):
def x(self, y: int) -> None: ...
# TODO: this should pass
static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
static_assert(is_equivalent_to(P1, P2))
```
As with protocols that only have non-method members, this also holds true when they appear in
@ -1487,8 +1486,7 @@ differently ordered unions:
class A: ...
class B: ...
# TODO: this should pass
static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
```
## Narrowing of protocols
@ -1896,6 +1894,86 @@ if isinstance(obj, (B, A)):
reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A)
```
### Protocols that use `Self`
`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
overflows.
```toml
[environment]
python-version = "3.12"
```
```py
from typing_extensions import Protocol, Self
from ty_extensions import static_assert
class _HashObject(Protocol):
def copy(self) -> Self: ...
class Foo: ...
# Attempting to build this union caused us to overflow on an early version of
# <https://github.com/astral-sh/ruff/pull/18659>
x: Foo | _HashObject
```
Some other similar cases that caused issues in our early `Protocol` implementation:
`a.py`:
```py
from typing_extensions import Protocol, Self
class PGconn(Protocol):
def connect(self) -> Self: ...
class Connection:
pgconn: PGconn
def is_crdb(conn: PGconn) -> bool:
return isinstance(conn, Connection)
```
and:
`b.py`:
```py
from typing_extensions import Protocol
class PGconn(Protocol):
def connect[T: PGconn](self: T) -> T: ...
class Connection:
pgconn: PGconn
def f(x: PGconn):
isinstance(x, Connection)
```
### Recursive protocols used as the first argument to `cast()`
These caused issues in an early version of our `Protocol` implementation due to the fact that we use
a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:
```toml
[environment]
python-version = "3.12"
```
```py
from typing import cast, Protocol
class Iterator[T](Protocol):
def __iter__(self) -> Iterator[T]: ...
def f(value: Iterator):
cast(Iterator, value) # error: [redundant-cast]
```
## TODO
Add tests for:

View file

@ -300,6 +300,20 @@ static_assert(not is_equivalent_to(CallableTypeOf[f12], CallableTypeOf[f13]))
static_assert(not is_equivalent_to(CallableTypeOf[f13], CallableTypeOf[f12]))
```
### Unions containing `Callable`s
Two unions containing different `Callable` types are equivalent even if the unions are differently
ordered:
```py
from ty_extensions import CallableTypeOf, Unknown, is_equivalent_to, static_assert
def f(x): ...
def g(x: Unknown): ...
static_assert(is_equivalent_to(CallableTypeOf[f] | int | str, str | int | CallableTypeOf[g]))
```
### Unions containing `Callable`s containing unions
Differently ordered unions inside `Callable`s inside unions can still be equivalent:

View file

@ -1102,7 +1102,7 @@ impl<'db> Type<'db> {
Type::Dynamic(_) => Some(CallableType::single(db, Signature::dynamic(self))),
Type::FunctionLiteral(function_literal) => {
Some(function_literal.into_callable_type(db))
Some(Type::Callable(function_literal.into_callable_type(db)))
}
Type::BoundMethod(bound_method) => Some(bound_method.into_callable_type(db)),
@ -7336,6 +7336,10 @@ impl<'db> CallableType<'db> {
///
/// See [`Type::is_equivalent_to`] for more details.
fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
if self == other {
return true;
}
self.is_function_like(db) == other.is_function_like(db)
&& self
.signatures(db)

View file

@ -1,3 +1,5 @@
use rustc_hash::FxHashMap;
use crate::FxIndexSet;
use crate::types::Type;
use std::cmp::Eq;
@ -19,14 +21,27 @@ pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>;
#[derive(Debug)]
pub(crate) struct CycleDetector<T, R> {
/// If the type we're visiting is present in `seen`,
/// it indicates that we've hit a cycle (due to a recursive type);
/// we need to immediately short circuit the whole operation and return the fallback value.
/// That's why we pop items off the end of `seen` after we've visited them.
seen: FxIndexSet<T>,
/// Unlike `seen`, this field is a pure performance optimisation (and an essential one).
/// If the type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit a cycle:
/// it just means that we've already visited this inner type as part of a bigger call chain we're currently in.
/// Since this cache is just a performance optimisation, it doesn't make sense to pop items off the end of the
/// cache after they've been visited (it would sort-of defeat the point of a cache if we did!)
cache: FxHashMap<T, R>,
fallback: R,
}
impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
impl<T: Hash + Eq + Copy, R: Copy> CycleDetector<T, R> {
pub(crate) fn new(fallback: R) -> Self {
CycleDetector {
seen: FxIndexSet::default(),
cache: FxHashMap::default(),
fallback,
}
}
@ -35,7 +50,12 @@ impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
if !self.seen.insert(item) {
return self.fallback;
}
if let Some(ty) = self.cache.get(&item) {
self.seen.pop();
return *ty;
}
let ret = func(self);
self.cache.insert(item, ret);
self.seen.pop();
ret
}

View file

@ -767,8 +767,8 @@ impl<'db> FunctionType<'db> {
}
/// Convert the `FunctionType` into a [`Type::Callable`].
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
Type::Callable(CallableType::new(db, self.signature(db), false))
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
CallableType::new(db, self.signature(db), false)
}
/// Convert the `FunctionType` into a [`Type::BoundMethod`].

View file

@ -270,7 +270,14 @@ impl<'db> ProtocolInstanceType<'db> {
///
/// TODO: consider the types of the members as well as their existence
pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
self.normalized(db) == other.normalized(db)
if self == other {
return true;
}
let self_normalized = self.normalized(db);
if self_normalized == Type::ProtocolInstance(other) {
return true;
}
self_normalized == other.normalized(db)
}
/// Return `true` if this protocol type is disjoint from the protocol `other`.

View file

@ -260,7 +260,7 @@ impl<'db> ProtocolMemberData<'db> {
#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
enum ProtocolMemberKind<'db> {
Method(Type<'db>), // TODO: use CallableType
Method(CallableType<'db>),
Property(PropertyInstanceType<'db>),
Other(Type<'db>),
}
@ -335,7 +335,7 @@ fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
visitor: &mut V,
) {
match member.kind {
ProtocolMemberKind::Method(method) => visitor.visit_type(db, method),
ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method),
ProtocolMemberKind::Property(property) => {
visitor.visit_property_instance_type(db, property);
}
@ -354,7 +354,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
fn ty(&self) -> Type<'db> {
match &self.kind {
ProtocolMemberKind::Method(callable) => *callable,
ProtocolMemberKind::Method(callable) => Type::Callable(*callable),
ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
ProtocolMemberKind::Other(ty) => *ty,
}
@ -508,13 +508,10 @@ fn cached_protocol_interface<'db>(
(Type::Callable(callable), BoundOnClass::Yes)
if callable.is_function_like(db) =>
{
ProtocolMemberKind::Method(ty)
ProtocolMemberKind::Method(callable)
}
// TODO: method members that have `FunctionLiteral` types should be upcast
// to `CallableType` so that two protocols with identical method members
// are recognized as equivalent.
(Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
ProtocolMemberKind::Method(ty)
(Type::FunctionLiteral(function), BoundOnClass::Yes) => {
ProtocolMemberKind::Method(function.into_callable_type(db))
}
_ => ProtocolMemberKind::Other(ty),
};

View file

@ -1318,8 +1318,13 @@ impl<'db> Parameter<'db> {
form,
} = self;
// Ensure unions and intersections are ordered in the annotated type (if there is one)
let annotated_type = annotated_type.map(|ty| ty.normalized_impl(db, visitor));
// Ensure unions and intersections are ordered in the annotated type (if there is one).
// Ensure that a parameter without an annotation is treated equivalently to a parameter
// with a dynamic type as its annotation. (We must use `Any` here as all dynamic types
// normalize to `Any`.)
let annotated_type = annotated_type
.map(|ty| ty.normalized_impl(db, visitor))
.unwrap_or_else(Type::any);
// Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters.
// Ensure that we only record whether a parameter *has* a default
@ -1351,7 +1356,7 @@ impl<'db> Parameter<'db> {
};
Self {
annotated_type,
annotated_type: Some(annotated_type),
kind,
form: *form,
}