mirror of https://github.com/astral-sh/ruff.git
synced 2025-08-04 18:58:04 +00:00
[ty] Implement equivalence for protocols with method members (#18659)
## Summary

This PR implements the following pieces of `Protocol` semantics:

1. A protocol with a method member that does not have a fully static signature should not be considered fully static. I.e., this protocol is not fully static because `Foo.f` has no return annotation; we previously incorrectly considered that it was:

   ```py
   class Foo(Protocol):
       def f(self): ...
   ```

2. Two protocols `P1` and `P2`, both with method members `x`, should be considered equivalent if the signature of `P1.x` is equivalent to the signature of `P2.x`. Currently we do not recognize this.

Implementing these semantics requires distinguishing between method members and non-method members. The stored type of a method member must be eagerly upcast to a `Callable` type when collecting the protocol's interface: otherwise it would be hard to implement equivalence of protocols in the face of differently ordered unions, since two equivalent protocols would have different Salsa IDs even when normalized.

The semantics implemented by this PR are that we consider something a method member if:

1. It is accessible on the class itself; and
2. It is a function-like callable: a callable type that also has a `__get__` method, meaning it can be used as a method when accessed on instances.

Note that the spec has complicated things to say about classmethod members and staticmethod members. These semantics are not implemented by this PR; they are all deferred for now.

The infrastructure added in this PR fixes bugs in its own right, but also lays the groundwork for implementing subtyping and assignability rules for method members of protocols. A (currently failing) test is added to verify this.

## Test Plan

mdtests
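As a rough runtime analogy for the two method-member conditions in the summary, here is a minimal Python sketch. It is not how ty decides this (ty works on static types); the helper name `looks_like_method_member` and the `_MISSING` sentinel are hypothetical, and the sketch deliberately ignores the classmethod and staticmethod cases that the summary defers.

```python
import inspect
from typing import Protocol

_MISSING = object()  # hypothetical sentinel for "not found on the class"


class Foo(Protocol):
    attr: int  # only declared: nothing is bound on the class itself

    def f(self): ...  # a plain function: bound on the class and has __get__


def looks_like_method_member(cls: type, name: str) -> bool:
    """Runtime approximation of the two conditions, for illustration only."""
    # Condition 1: the member must be accessible on the class itself.
    member = inspect.getattr_static(cls, name, _MISSING)
    if member is _MISSING:
        return False
    # Condition 2: it must be callable and its type must define `__get__`,
    # which is what lets it bind as a method when accessed on an instance.
    return callable(member) and hasattr(type(member), "__get__")
```

Under these assumptions `f` qualifies as a method member while `attr` does not, because a bare annotation binds nothing on the class.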
parent c15aa572ff
commit a6637964d2
8 changed files with 146 additions and 21 deletions

@@ -1476,8 +1476,7 @@ class P1(Protocol):
 class P2(Protocol):
     def x(self, y: int) -> None: ...

-# TODO: this should pass
-static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
+static_assert(is_equivalent_to(P1, P2))
 ```

 As with protocols that only have non-method members, this also holds true when they appear in
@@ -1487,8 +1486,7 @@ differently ordered unions:
 class A: ...
 class B: ...

-# TODO: this should pass
-static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
+static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
 ```

 ## Narrowing of protocols
@@ -1896,6 +1894,86 @@ if isinstance(obj, (B, A)):
     reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A)
 ```

+### Protocols that use `Self`
+
+`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
+`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
+overflows.
+
+```toml
+[environment]
+python-version = "3.12"
+```
+
+```py
+from typing_extensions import Protocol, Self
+from ty_extensions import static_assert
+
+class _HashObject(Protocol):
+    def copy(self) -> Self: ...
+
+class Foo: ...
+
+# Attempting to build this union caused us to overflow on an early version of
+# <https://github.com/astral-sh/ruff/pull/18659>
+x: Foo | _HashObject
+```
+
+Some other similar cases that caused issues in our early `Protocol` implementation:
+
+`a.py`:
+
+```py
+from typing_extensions import Protocol, Self
+
+class PGconn(Protocol):
+    def connect(self) -> Self: ...
+
+class Connection:
+    pgconn: PGconn
+
+def is_crdb(conn: PGconn) -> bool:
+    return isinstance(conn, Connection)
+```
+
+and:
+
+`b.py`:
+
+```py
+from typing_extensions import Protocol
+
+class PGconn(Protocol):
+    def connect[T: PGconn](self: T) -> T: ...
+
+class Connection:
+    pgconn: PGconn
+
+def f(x: PGconn):
+    isinstance(x, Connection)
+```
+
+### Recursive protocols used as the first argument to `cast()`
+
+These caused issues in an early version of our `Protocol` implementation due to the fact that we use
+a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
+`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:
+
+```toml
+[environment]
+python-version = "3.12"
+```
+
+```py
+from typing import cast, Protocol
+
+class Iterator[T](Protocol):
+    def __iter__(self) -> Iterator[T]: ...
+
+def f(value: Iterator):
+    cast(Iterator, value) # error: [redundant-cast]
+```
+
 ## TODO

 Add tests for:

@@ -300,6 +300,20 @@ static_assert(not is_equivalent_to(CallableTypeOf[f12], CallableTypeOf[f13]))
 static_assert(not is_equivalent_to(CallableTypeOf[f13], CallableTypeOf[f12]))
 ```

+### Unions containing `Callable`s
+
+Two unions containing different `Callable` types are equivalent even if the unions are differently
+ordered:
+
+```py
+from ty_extensions import CallableTypeOf, Unknown, is_equivalent_to, static_assert
+
+def f(x): ...
+def g(x: Unknown): ...
+
+static_assert(is_equivalent_to(CallableTypeOf[f] | int | str, str | int | CallableTypeOf[g]))
+```
+
 ### Unions containing `Callable`s containing unions

 Differently ordered unions inside `Callable`s inside unions can still be equivalent:
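
The order-insensitive comparison these tests exercise can be illustrated with a toy model. This is only a sketch under stated assumptions: the tuple encoding of types and the names `normalize` and `is_equivalent` are hypothetical, and it canonicalizes unions with sets, whereas the diff comments describe ty as ordering union members during normalization.

```python
def normalize(ty):
    """Return a canonical, hashable form of a toy type description.

    Toy encoding (an assumption for this sketch):
      "int"                              an atomic type
      ("union", [member, ...])           a union
      ("callable", [param, ...], ret)    a callable signature
    """
    if isinstance(ty, str):
        return ty
    tag = ty[0]
    if tag == "union":
        # Member order carries no meaning, so use a frozenset as the canonical form.
        return ("union", frozenset(normalize(member) for member in ty[1]))
    if tag == "callable":
        # Parameter order does matter, so parameters stay in a tuple.
        params = tuple(normalize(param) for param in ty[1])
        return ("callable", params, normalize(ty[2]))
    raise ValueError(f"unknown type description: {ty!r}")


def is_equivalent(left, right) -> bool:
    """Two toy types are equivalent when their canonical forms are equal."""
    return normalize(left) == normalize(right)
```

Because normalization recurses into callable parameters, a union nested inside a callable inside a union is canonicalized as well, which is the situation the last heading above describes.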

@@ -1102,7 +1102,7 @@ impl<'db> Type<'db> {
             Type::Dynamic(_) => Some(CallableType::single(db, Signature::dynamic(self))),

             Type::FunctionLiteral(function_literal) => {
-                Some(function_literal.into_callable_type(db))
+                Some(Type::Callable(function_literal.into_callable_type(db)))
             }
             Type::BoundMethod(bound_method) => Some(bound_method.into_callable_type(db)),

@@ -7336,6 +7336,10 @@ impl<'db> CallableType<'db> {
     ///
     /// See [`Type::is_equivalent_to`] for more details.
     fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
+        if self == other {
+            return true;
+        }
+
         self.is_function_like(db) == other.is_function_like(db)
             && self
                 .signatures(db)

@@ -1,3 +1,5 @@
+use rustc_hash::FxHashMap;
+
 use crate::FxIndexSet;
 use crate::types::Type;
 use std::cmp::Eq;
@@ -19,14 +21,27 @@ pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>;

 #[derive(Debug)]
 pub(crate) struct CycleDetector<T, R> {
+    /// If the type we're visiting is present in `seen`,
+    /// it indicates that we've hit a cycle (due to a recursive type);
+    /// we need to immediately short circuit the whole operation and return the fallback value.
+    /// That's why we pop items off the end of `seen` after we've visited them.
     seen: FxIndexSet<T>,
+
+    /// Unlike `seen`, this field is a pure performance optimisation (and an essential one).
+    /// If the type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit a cycle:
+    /// it just means that we've already visited this inner type as part of a bigger call chain we're currently in.
+    /// Since this cache is just a performance optimisation, it doesn't make sense to pop items off the end of the
+    /// cache after they've been visited (it would sort-of defeat the point of a cache if we did!)
+    cache: FxHashMap<T, R>,
+
     fallback: R,
 }

-impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
+impl<T: Hash + Eq + Copy, R: Copy> CycleDetector<T, R> {
     pub(crate) fn new(fallback: R) -> Self {
         CycleDetector {
             seen: FxIndexSet::default(),
+            cache: FxHashMap::default(),
             fallback,
         }
     }
@@ -35,7 +50,12 @@ impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
         if !self.seen.insert(item) {
             return self.fallback;
         }
+        if let Some(ty) = self.cache.get(&item) {
+            self.seen.pop();
+            return *ty;
+        }
         let ret = func(self);
+        self.cache.insert(item, ret);
         self.seen.pop();
         ret
     }
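
The interplay between `seen` (cycle detection) and `cache` (memoization) in the hunks above can be sketched in Python. This is an illustrative transliteration and not the crate's code: a plain `dict` stands in for the insertion-ordered `FxIndexSet`, and the `GRAPH` data and the `contains_unknown` helper are hypothetical.

```python
class CycleDetector:
    """Toy counterpart of the Rust struct: `seen` is a stack, `cache` persists."""

    def __init__(self, fallback):
        self.seen = {}  # insertion-ordered dict used as an ordered set
        self.cache = {}  # results of finished visits; never popped
        self.fallback = fallback

    def visit(self, item, func):
        if item in self.seen:
            # `item` is already on the current call chain: this is a cycle.
            return self.fallback
        self.seen[item] = None
        if item in self.cache:
            # Already computed earlier: reuse the result, which is not a cycle.
            self.seen.popitem()
            return self.cache[item]
        result = func(self)
        self.cache[item] = result
        self.seen.popitem()  # removes `item`, the most recently inserted key
        return result


# Hypothetical recursive "types": each name maps to the names it refers to.
GRAPH = {"Iterator": ["Iterator"], "Wrapper": ["Iterator", "Unknown"], "Unknown": []}


def contains_unknown(name, detector):
    """Check whether `name` reaches "Unknown" without recursing forever."""
    if name == "Unknown":
        return True
    return detector.visit(
        name,
        lambda d: any(contains_unknown(child, d) for child in GRAPH[name]),
    )
```

In this sketch the self-referential `Iterator` entry terminates because the second visit finds it in `seen` and returns the fallback, while the later lookup from `Wrapper` is answered from `cache` without re-walking it.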

@@ -767,8 +767,8 @@ impl<'db> FunctionType<'db> {
     }

     /// Convert the `FunctionType` into a [`Type::Callable`].
-    pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
-        Type::Callable(CallableType::new(db, self.signature(db), false))
+    pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
+        CallableType::new(db, self.signature(db), false)
     }

     /// Convert the `FunctionType` into a [`Type::BoundMethod`].

@@ -270,7 +270,14 @@ impl<'db> ProtocolInstanceType<'db> {
     ///
     /// TODO: consider the types of the members as well as their existence
     pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
-        self.normalized(db) == other.normalized(db)
+        if self == other {
+            return true;
+        }
+        let self_normalized = self.normalized(db);
+        if self_normalized == Type::ProtocolInstance(other) {
+            return true;
+        }
+        self_normalized == other.normalized(db)
     }

     /// Return `true` if this protocol type is disjoint from the protocol `other`.

@@ -260,7 +260,7 @@ impl<'db> ProtocolMemberData<'db> {

 #[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
 enum ProtocolMemberKind<'db> {
-    Method(Type<'db>), // TODO: use CallableType
+    Method(CallableType<'db>),
     Property(PropertyInstanceType<'db>),
     Other(Type<'db>),
 }
@@ -335,7 +335,7 @@ fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
     visitor: &mut V,
 ) {
     match member.kind {
-        ProtocolMemberKind::Method(method) => visitor.visit_type(db, method),
+        ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method),
         ProtocolMemberKind::Property(property) => {
             visitor.visit_property_instance_type(db, property);
         }
@@ -354,7 +354,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {

     fn ty(&self) -> Type<'db> {
         match &self.kind {
-            ProtocolMemberKind::Method(callable) => *callable,
+            ProtocolMemberKind::Method(callable) => Type::Callable(*callable),
             ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
             ProtocolMemberKind::Other(ty) => *ty,
         }
@@ -508,13 +508,10 @@ fn cached_protocol_interface<'db>(
                 (Type::Callable(callable), BoundOnClass::Yes)
                     if callable.is_function_like(db) =>
                 {
-                    ProtocolMemberKind::Method(ty)
+                    ProtocolMemberKind::Method(callable)
                 }
-                // TODO: method members that have `FunctionLiteral` types should be upcast
-                // to `CallableType` so that two protocols with identical method members
-                // are recognized as equivalent.
-                (Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
-                    ProtocolMemberKind::Method(ty)
+                (Type::FunctionLiteral(function), BoundOnClass::Yes) => {
+                    ProtocolMemberKind::Method(function.into_callable_type(db))
                 }
                 _ => ProtocolMemberKind::Other(ty),
             };
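
Why storing the upcast callable instead of the function literal helps can be shown with a runtime analogy. The sketch below is only an analogy and assumes nothing about ty's internals: two distinct function objects never compare equal, but their `inspect.Signature` objects do when the parameters and return annotation match. The helper `method_signatures` is hypothetical.

```python
import inspect
from typing import Protocol


class P1(Protocol):
    def x(self, y: int) -> None: ...


class P2(Protocol):
    def x(self, y: int) -> None: ...


def method_signatures(cls: type) -> dict:
    """Hypothetical helper: map each public function member to its signature."""
    return {
        name: inspect.signature(member)
        for name, member in vars(cls).items()
        # Skip private and dunder helpers that `typing.Protocol` adds to the class.
        if inspect.isfunction(member) and not name.startswith("_")
    }
```

Comparing `P1.x` with `P2.x` directly is an identity check on two separate function objects, while comparing the collected signatures asks the question that matters for protocol equivalence.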

@@ -1318,8 +1318,13 @@ impl<'db> Parameter<'db> {
             form,
         } = self;

-        // Ensure unions and intersections are ordered in the annotated type (if there is one)
-        let annotated_type = annotated_type.map(|ty| ty.normalized_impl(db, visitor));
+        // Ensure unions and intersections are ordered in the annotated type (if there is one).
+        // Ensure that a parameter without an annotation is treated equivalently to a parameter
+        // with a dynamic type as its annotation. (We must use `Any` here as all dynamic types
+        // normalize to `Any`.)
+        let annotated_type = annotated_type
+            .map(|ty| ty.normalized_impl(db, visitor))
+            .unwrap_or_else(Type::any);

         // Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters.
         // Ensure that we only record whether a parameter *has* a default
@@ -1351,7 +1356,7 @@ impl<'db> Parameter<'db> {
         };

         Self {
-            annotated_type,
+            annotated_type: Some(annotated_type),
             kind,
             form: *form,
         }
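
The normalization rule in the comment above (a parameter without an annotation is treated like one annotated with a dynamic type) has a simple runtime counterpart. The following is a minimal sketch using the standard `inspect` module; the helper name `normalized_signature` is hypothetical and `typing.Any` stands in for the dynamic type that the comment says everything normalizes to.

```python
import inspect
from typing import Any


def normalized_signature(func) -> inspect.Signature:
    """Replace every missing parameter annotation with `Any`."""
    sig = inspect.signature(func)
    params = [
        # A parameter with no annotation is rewritten as if annotated with `Any`.
        p.replace(annotation=Any) if p.annotation is inspect.Parameter.empty else p
        for p in sig.parameters.values()
    ]
    return sig.replace(parameters=params)


def f(x): ...


def g(x: Any): ...
```

Before normalization the two signatures differ, because one has an empty annotation; afterwards they compare equal, mirroring the `CallableTypeOf[f]` and `CallableTypeOf[g]` test added earlier in this commit.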