[ty] Improve protocol member type checking and relation handling (#18847)

Co-authored-by: Alex Waygood <alex.waygood@gmail.com>
This commit is contained in:
Shunsuke Shibayama 2025-06-29 19:46:33 +09:00 committed by GitHub
parent 9218bf72ad
commit de1f8177be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 519 additions and 132 deletions

View file

@ -371,14 +371,23 @@ class MyCallable:
f_wrong(MyCallable()) # raises `AttributeError` at runtime
```
If users want to write to attributes such as `__qualname__`, they need to check the existence of the
attribute first:
If users want to read/write to attributes such as `__qualname__`, they need to check the existence
of the attribute first:
```py
from inspect import getattr_static
def f_okay(c: Callable[[], None]):
if hasattr(c, "__qualname__"):
c.__qualname__ # okay
c.__qualname__ = "my_callable" # also okay
# `hasattr` only guarantees that an attribute is readable.
# error: [invalid-assignment] "Object of type `Literal["my_callable"]` is not assignable to attribute `__qualname__` on type `(() -> None) & <Protocol with members '__qualname__'>`"
c.__qualname__ = "my_callable"
result = getattr_static(c, "__qualname__")
reveal_type(result) # revealed: Never
if isinstance(result, property) and result.fset:
c.__qualname__ = "my_callable" # okay
```
[gradual form]: https://typing.python.org/en/latest/spec/glossary.html#term-gradual-form

View file

@ -482,8 +482,8 @@ reveal_type(c.method3()) # revealed: LinkedList[int]
class SomeProtocol(Protocol[T]):
x: T
class Foo:
x: int
class Foo(Generic[T]):
x: T
class D(Generic[T, U]):
x: T

View file

@ -5,6 +5,7 @@ accomplished using an intersection with a synthesized protocol:
```py
from typing import final
from typing_extensions import LiteralString
class Foo: ...
@ -56,4 +57,10 @@ def h(obj: Baz):
# TODO: should emit `[unresolved-attribute]` and reveal `Unknown`
reveal_type(obj.x) # revealed: @Todo(map_with_boundness: intersections with negative contributions)
def i(x: int | LiteralString):
if hasattr(x, "capitalize"):
reveal_type(x) # revealed: (int & <Protocol with members 'capitalize'>) | LiteralString
else:
reveal_type(x) # revealed: int & ~<Protocol with members 'capitalize'>
```

View file

@ -489,35 +489,122 @@ python-version = "3.12"
```
```py
from typing import Protocol
from typing import Protocol, Any, ClassVar
from collections.abc import Sequence
from ty_extensions import static_assert, is_assignable_to, is_subtype_of
class HasX(Protocol):
x: int
class HasXY(Protocol):
x: int
y: int
class Foo:
x: int
static_assert(is_subtype_of(Foo, HasX))
static_assert(is_assignable_to(Foo, HasX))
static_assert(not is_subtype_of(Foo, HasXY))
static_assert(not is_assignable_to(Foo, HasXY))
class FooSub(Foo): ...
static_assert(is_subtype_of(FooSub, HasX))
static_assert(is_assignable_to(FooSub, HasX))
static_assert(not is_subtype_of(FooSub, HasXY))
static_assert(not is_assignable_to(FooSub, HasXY))
class FooBool(Foo):
x: bool
static_assert(not is_subtype_of(FooBool, HasX))
static_assert(not is_assignable_to(FooBool, HasX))
class FooAny:
x: Any
static_assert(not is_subtype_of(FooAny, HasX))
static_assert(is_assignable_to(FooAny, HasX))
class SubclassOfAny(Any): ...
class FooSubclassOfAny:
x: SubclassOfAny
static_assert(not is_subtype_of(FooSubclassOfAny, HasX))
static_assert(not is_assignable_to(FooSubclassOfAny, HasX))
class FooWithY(Foo):
y: int
assert is_subtype_of(FooWithY, HasXY)
static_assert(is_assignable_to(FooWithY, HasXY))
class Bar:
x: str
# TODO: these should pass
static_assert(not is_subtype_of(Bar, HasX)) # error: [static-assert-error]
static_assert(not is_assignable_to(Bar, HasX)) # error: [static-assert-error]
static_assert(not is_subtype_of(Bar, HasX))
static_assert(not is_assignable_to(Bar, HasX))
class Baz:
y: int
static_assert(not is_subtype_of(Baz, HasX))
static_assert(not is_assignable_to(Baz, HasX))
class Qux:
def __init__(self, x: int) -> None:
self.x: int = x
static_assert(is_subtype_of(Qux, HasX))
static_assert(is_assignable_to(Qux, HasX))
class HalfUnknownQux:
def __init__(self, x: int) -> None:
self.x = x
reveal_type(HalfUnknownQux(1).x) # revealed: Unknown | int
static_assert(not is_subtype_of(HalfUnknownQux, HasX))
static_assert(is_assignable_to(HalfUnknownQux, HasX))
class FullyUnknownQux:
def __init__(self, x) -> None:
self.x = x
static_assert(not is_subtype_of(FullyUnknownQux, HasX))
static_assert(is_assignable_to(FullyUnknownQux, HasX))
class HasXWithDefault(Protocol):
x: int = 0
class FooWithZero:
x: int = 0
# TODO: these should pass
static_assert(is_subtype_of(FooWithZero, HasXWithDefault)) # error: [static-assert-error]
static_assert(is_assignable_to(FooWithZero, HasXWithDefault)) # error: [static-assert-error]
static_assert(not is_subtype_of(Foo, HasXWithDefault))
static_assert(not is_assignable_to(Foo, HasXWithDefault))
static_assert(not is_subtype_of(Qux, HasXWithDefault))
static_assert(not is_assignable_to(Qux, HasXWithDefault))
class HasClassVarX(Protocol):
x: ClassVar[int]
static_assert(is_subtype_of(FooWithZero, HasClassVarX))
static_assert(is_assignable_to(FooWithZero, HasClassVarX))
# TODO: these should pass
static_assert(not is_subtype_of(Foo, HasClassVarX)) # error: [static-assert-error]
static_assert(not is_assignable_to(Foo, HasClassVarX)) # error: [static-assert-error]
static_assert(not is_subtype_of(Qux, HasClassVarX)) # error: [static-assert-error]
static_assert(not is_assignable_to(Qux, HasClassVarX)) # error: [static-assert-error]
static_assert(is_subtype_of(Sequence[Foo], Sequence[HasX]))
static_assert(is_assignable_to(Sequence[Foo], Sequence[HasX]))
static_assert(not is_subtype_of(list[Foo], list[HasX]))
static_assert(not is_assignable_to(list[Foo], list[HasX]))
```
Note that declaring an attribute member on a protocol mandates that the attribute must be mutable. A
@ -552,10 +639,8 @@ class C:
# due to invariance, a type is only a subtype of `HasX`
# if its `x` attribute is of type *exactly* `int`:
# a subclass of `int` does not satisfy the interface
#
# TODO: these should pass
static_assert(not is_subtype_of(C, HasX)) # error: [static-assert-error]
static_assert(not is_assignable_to(C, HasX)) # error: [static-assert-error]
static_assert(not is_subtype_of(C, HasX))
static_assert(not is_assignable_to(C, HasX))
```
All attributes on frozen dataclasses and namedtuples are immutable, so instances of these classes
@ -1229,6 +1314,62 @@ static_assert(is_subtype_of(HasGetAttrAndSetAttr, XAsymmetricProperty)) # error
static_assert(is_assignable_to(HasGetAttrAndSetAttr, XAsymmetricProperty)) # error: [static-assert-error]
```
## Subtyping of protocols with method members
A protocol can have method members. `T` is assignable to `P` in the following example because the
class `T` has a method `m` which is assignable to the `Callable` supertype of the method `P.m`:
```py
from typing import Protocol
from ty_extensions import is_subtype_of, static_assert
class P(Protocol):
def m(self, x: int, /) -> None: ...
class NominalSubtype:
def m(self, y: int) -> None: ...
class NotSubtype:
def m(self, x: int) -> int:
return 42
static_assert(is_subtype_of(NominalSubtype, P))
# TODO: should pass
static_assert(not is_subtype_of(NotSubtype, P)) # error: [static-assert-error]
```
## Equivalence of protocols with method members
Two protocols `P1` and `P2`, both with a method member `x`, are considered equivalent if the
signature of `P1.x` is equivalent to the signature of `P2.x`, even though ty would normally model
any two function definitions as inhabiting distinct function-literal types.
```py
from typing import Protocol
from ty_extensions import is_equivalent_to, static_assert
class P1(Protocol):
def x(self, y: int) -> None: ...
class P2(Protocol):
def x(self, y: int) -> None: ...
# TODO: this should pass
static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
```
As with protocols that only have non-method members, this also holds true when they appear in
differently ordered unions:
```py
class A: ...
class B: ...
# TODO: this should pass
static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
```
## Narrowing of protocols
<!-- snapshot-diagnostics -->
@ -1458,7 +1599,7 @@ def two(some_list: list, some_tuple: tuple[int, str], some_sized: Sized):
```py
from __future__ import annotations
from typing import Protocol, Any
from typing import Protocol, Any, TypeVar
from ty_extensions import static_assert, is_assignable_to, is_subtype_of, is_equivalent_to
class RecursiveFullyStatic(Protocol):
@ -1514,6 +1655,17 @@ class Bar(Protocol):
# TODO: this should pass
# error: [static-assert-error]
static_assert(is_equivalent_to(Foo, Bar))
T = TypeVar("T", bound="TypeVarRecursive")
class TypeVarRecursive(Protocol):
# TODO: commenting this out will cause a stack overflow.
# x: T
y: "TypeVarRecursive"
def _(t: TypeVarRecursive):
# reveal_type(t.x) # revealed: T
reveal_type(t.y) # revealed: TypeVarRecursive
```
### Nested occurrences of self-reference

View file

@ -501,6 +501,67 @@ static_assert(is_disjoint_from(str, TypeGuard[str])) # error: [static-assert-er
static_assert(is_disjoint_from(str, TypeIs[str]))
```
### `Protocol`
A protocol is disjoint from another type if any of the protocol's members are available as an
attribute on the other type *but* the type of the attribute on the other type is disjoint from the
type of the protocol's member.
```py
from typing_extensions import Protocol, Literal, final, ClassVar
from ty_extensions import is_disjoint_from, static_assert
class HasAttrA(Protocol):
attr: Literal["a"]
class SupportsInt(Protocol):
def __int__(self) -> int: ...
class A:
attr: Literal["a"]
class B:
attr: Literal["b"]
class C:
foo: int
class D:
attr: int
@final
class E:
pass
@final
class F:
def __int__(self) -> int:
return 1
static_assert(not is_disjoint_from(HasAttrA, A))
static_assert(is_disjoint_from(HasAttrA, B))
# A subclass of E may satisfy HasAttrA
static_assert(not is_disjoint_from(HasAttrA, C))
static_assert(is_disjoint_from(HasAttrA, D))
static_assert(is_disjoint_from(HasAttrA, E))
static_assert(is_disjoint_from(SupportsInt, E))
static_assert(not is_disjoint_from(SupportsInt, F))
class NotIterable(Protocol):
__iter__: ClassVar[None]
static_assert(is_disjoint_from(tuple[int, int], NotIterable))
class Foo:
BAR: ClassVar[int]
class BarNone(Protocol):
BAR: None
static_assert(is_disjoint_from(type[Foo], BarNone))
```
## Callables
No two callable types are disjoint because there exists a non-empty callable type

View file

@ -404,6 +404,22 @@ impl<'db> PropertyInstanceType<'db> {
ty.find_legacy_typevars(db, typevars);
}
}
fn materialize(self, db: &'db dyn Db, variance: TypeVarVariance) -> Self {
Self::new(
db,
self.getter(db).map(|ty| ty.materialize(db, variance)),
self.setter(db).map(|ty| ty.materialize(db, variance)),
)
}
fn any_over_type(self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool {
self.getter(db)
.is_some_and(|ty| ty.any_over_type(db, type_fn))
|| self
.setter(db)
.is_some_and(|ty| ty.any_over_type(db, type_fn))
}
}
bitflags! {
@ -681,10 +697,13 @@ impl<'db> Type<'db> {
| Type::KnownInstance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::PropertyInstance(_)
| Type::ClassLiteral(_)
| Type::BoundSuper(_) => *self,
Type::PropertyInstance(property_instance) => {
Type::PropertyInstance(property_instance.materialize(db, variance))
}
Type::FunctionLiteral(_) | Type::BoundMethod(_) => {
// TODO: Subtyping between function / methods with a callable accounts for the
// signature (parameters and return type), so we might need to do something here
@ -902,15 +921,7 @@ impl<'db> Type<'db> {
}
Self::ProtocolInstance(protocol) => protocol.any_over_type(db, type_fn),
Self::PropertyInstance(property) => {
property
.getter(db)
.is_some_and(|ty| ty.any_over_type(db, type_fn))
|| property
.setter(db)
.is_some_and(|ty| ty.any_over_type(db, type_fn))
}
Self::PropertyInstance(property) => property.any_over_type(db, type_fn),
Self::NominalInstance(instance) => match instance.class {
ClassType::NonGeneric(_) => false,
@ -1453,7 +1464,9 @@ impl<'db> Type<'db> {
}
// A protocol instance can never be a subtype of a nominal type, with the *sole* exception of `object`.
(Type::ProtocolInstance(_), _) => false,
(_, Type::ProtocolInstance(protocol)) => self.satisfies_protocol(db, protocol),
(_, Type::ProtocolInstance(protocol)) => {
self.satisfies_protocol(db, protocol, relation)
}
// All `StringLiteral` types are a subtype of `LiteralString`.
(Type::StringLiteral(_), Type::LiteralString) => true,
@ -1865,26 +1878,6 @@ impl<'db> Type<'db> {
Type::Tuple(..),
) => true,
(Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b))
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => false,
SubclassOfInner::Class(class_a) => !class_b.is_subclass_of(db, None, class_a),
}
}
(Type::SubclassOf(subclass_of_ty), Type::GenericAlias(alias_b))
| (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => false,
SubclassOfInner::Class(class_a) => {
!ClassType::from(alias_b).is_subclass_of(db, class_a)
}
}
}
(Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from(db, right),
(
Type::SubclassOf(_),
Type::BooleanLiteral(..)
@ -1912,28 +1905,6 @@ impl<'db> Type<'db> {
Type::SubclassOf(_),
) => true,
(Type::AlwaysTruthy, ty) | (ty, Type::AlwaysTruthy) => {
// `Truthiness::Ambiguous` may include `AlwaysTrue` as a subset, so it's not guaranteed to be disjoint.
// Thus, they are only disjoint if `ty.bool() == AlwaysFalse`.
ty.bool(db).is_always_false()
}
(Type::AlwaysFalsy, ty) | (ty, Type::AlwaysFalsy) => {
// Similarly, they are only disjoint if `ty.bool() == AlwaysTrue`.
ty.bool(db).is_always_true()
}
(Type::ProtocolInstance(left), Type::ProtocolInstance(right)) => {
left.is_disjoint_from(db, right)
}
// TODO: we could also consider `protocol` to be disjoint from `nominal` if `nominal`
// has the right member but the type of its member is disjoint from the type of the
// member on `protocol`.
(Type::ProtocolInstance(protocol), nominal @ Type::NominalInstance(n))
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => {
n.class.is_final(db) && !nominal.satisfies_protocol(db, protocol)
}
(
ty @ (Type::LiteralString
| Type::StringLiteral(..)
@ -1957,36 +1928,75 @@ impl<'db> Type<'db> {
| Type::ModuleLiteral(..)
| Type::GenericAlias(..)
| Type::IntLiteral(..)),
) => !ty.satisfies_protocol(db, protocol),
) => !ty.satisfies_protocol(db, protocol, TypeRelation::Assignability),
(Type::AlwaysTruthy, ty) | (ty, Type::AlwaysTruthy) => {
// `Truthiness::Ambiguous` may include `AlwaysTrue` as a subset, so it's not guaranteed to be disjoint.
// Thus, they are only disjoint if `ty.bool() == AlwaysFalse`.
ty.bool(db).is_always_false()
}
(Type::AlwaysFalsy, ty) | (ty, Type::AlwaysFalsy) => {
// Similarly, they are only disjoint if `ty.bool() == AlwaysTrue`.
ty.bool(db).is_always_true()
}
(Type::ProtocolInstance(left), Type::ProtocolInstance(right)) => {
left.is_disjoint_from(db, right)
}
(Type::ProtocolInstance(protocol), Type::SpecialForm(special_form))
| (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => !special_form
.instance_fallback(db)
.satisfies_protocol(db, protocol),
.satisfies_protocol(db, protocol, TypeRelation::Assignability),
(Type::ProtocolInstance(protocol), Type::KnownInstance(known_instance))
| (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => {
!known_instance
.instance_fallback(db)
.satisfies_protocol(db, protocol)
!known_instance.instance_fallback(db).satisfies_protocol(
db,
protocol,
TypeRelation::Assignability,
)
}
(Type::Callable(_), Type::ProtocolInstance(_))
| (Type::ProtocolInstance(_), Type::Callable(_)) => {
// TODO disjointness between `Callable` and `ProtocolInstance`
false
(Type::ProtocolInstance(protocol), nominal @ Type::NominalInstance(n))
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol))
if n.class.is_final(db) =>
{
!nominal.satisfies_protocol(db, protocol, TypeRelation::Assignability)
}
(Type::Tuple(..), Type::ProtocolInstance(..))
| (Type::ProtocolInstance(..), Type::Tuple(..)) => {
// Currently we do not make any general assumptions about the disjointness of a `Tuple` type
// and a `ProtocolInstance` type because a `Tuple` type can be an instance of a tuple
// subclass.
//
// TODO when we capture the types of the protocol members, we can improve on this.
false
(Type::ProtocolInstance(protocol), other)
| (other, Type::ProtocolInstance(protocol)) => {
protocol.interface(db).members(db).any(|member| {
// TODO: implement disjointness for property/method members as well as attribute members
member.is_attribute_member()
&& matches!(
other.member(db, member.name()).place,
Place::Type(ty, Boundness::Bound) if ty.is_disjoint_from(db, member.ty())
)
})
}
(Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b))
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => false,
SubclassOfInner::Class(class_a) => !class_b.is_subclass_of(db, None, class_a),
}
}
(Type::SubclassOf(subclass_of_ty), Type::GenericAlias(alias_b))
| (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => false,
SubclassOfInner::Class(class_a) => {
!ClassType::from(alias_b).is_subclass_of(db, class_a)
}
}
}
(Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from(db, right),
// for `type[Any]`/`type[Unknown]`/`type[Todo]`, we know the type cannot be any larger than `type`,
// so although the type is dynamic we can still determine disjointedness in some situations
(Type::SubclassOf(subclass_of_ty), other)
@ -2531,6 +2541,11 @@ impl<'db> Type<'db> {
Type::Intersection(inter) => inter.map_with_boundness_and_qualifiers(db, |elem| {
elem.class_member_with_policy(db, name.clone(), policy)
}),
// TODO: Once `to_meta_type` for the synthesized protocol is fully implemented, this handling should be removed.
Type::ProtocolInstance(ProtocolInstanceType {
inner: Protocol::Synthesized(_),
..
}) => self.instance_member(db, &name),
_ => self
.to_meta_type(db)
.find_name_in_mro_with_policy(db, name.as_str(), policy)

View file

@ -4,7 +4,7 @@ use std::marker::PhantomData;
use super::protocol_class::ProtocolInterface;
use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance};
use crate::place::{Boundness, Place, PlaceAndQualifiers};
use crate::place::{Place, PlaceAndQualifiers};
use crate::types::tuple::TupleType;
use crate::types::{ClassLiteral, DynamicType, TypeMapping, TypeRelation, TypeVarInstance};
use crate::{Db, FxOrderSet};
@ -35,31 +35,28 @@ impl<'db> Type<'db> {
}
}
pub(super) fn synthesized_protocol<'a, M>(db: &'db dyn Db, members: M) -> Self
/// Synthesize a protocol instance type with a given set of read-only property members.
pub(super) fn protocol_with_readonly_members<'a, M>(db: &'db dyn Db, members: M) -> Self
where
M: IntoIterator<Item = (&'a str, Type<'db>)>,
{
Self::ProtocolInstance(ProtocolInstanceType::synthesized(
SynthesizedProtocolType::new(db, ProtocolInterface::with_members(db, members)),
SynthesizedProtocolType::new(db, ProtocolInterface::with_property_members(db, members)),
))
}
/// Return `true` if `self` conforms to the interface described by `protocol`.
///
/// TODO: we may need to split this into two methods in the future, once we start
/// differentiating between fully-static and non-fully-static protocols.
pub(super) fn satisfies_protocol(
self,
db: &'db dyn Db,
protocol: ProtocolInstanceType<'db>,
relation: TypeRelation,
) -> bool {
// TODO: this should consider the types of the protocol members
protocol.inner.interface(db).members(db).all(|member| {
matches!(
self.member(db, member.name()).place,
Place::Type(_, Boundness::Bound)
)
})
protocol
.inner
.interface(db)
.members(db)
.all(|member| member.is_satisfied_by(db, self, relation))
}
}
@ -205,7 +202,7 @@ impl<'db> ProtocolInstanceType<'db> {
/// See [`Type::normalized`] for more details.
pub(super) fn normalized(self, db: &'db dyn Db) -> Type<'db> {
let object = KnownClass::Object.to_instance(db);
if object.satisfies_protocol(db, self) {
if object.satisfies_protocol(db, self, TypeRelation::Subtyping) {
return object;
}
match self.inner {
@ -322,6 +319,10 @@ impl<'db> ProtocolInstanceType<'db> {
}
}
}
pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
self.inner.interface(db)
}
}
/// An enumeration of the two kinds of protocol types: those that originate from a class

View file

@ -832,9 +832,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return None;
}
let constraint = Type::synthesized_protocol(
// Since `hasattr` only checks if an attribute is readable,
// the type of the protocol member should be a read-only property that returns `object`.
let constraint = Type::protocol_with_readonly_members(
self.db,
[(attr, KnownClass::Object.to_instance(self.db))],
[(attr, Type::object(self.db))],
);
return Some(NarrowingConstraints::from_iter([(

View file

@ -6,10 +6,12 @@ use ruff_python_ast::name::Name;
use crate::{
Db, FxOrderSet,
place::{place_from_bindings, place_from_declarations},
place::{Boundness, Place, place_from_bindings, place_from_declarations},
semantic_index::{place_table, use_def_map},
types::{
ClassBase, ClassLiteral, KnownFunction, Type, TypeMapping, TypeQualifiers, TypeVarInstance,
CallableType, ClassBase, ClassLiteral, KnownFunction, PropertyInstanceType, Signature,
Type, TypeMapping, TypeQualifiers, TypeRelation, TypeVarInstance,
signatures::{Parameter, Parameters},
},
};
@ -82,18 +84,30 @@ pub(super) enum ProtocolInterface<'db> {
}
impl<'db> ProtocolInterface<'db> {
pub(super) fn with_members<'a, M>(db: &'db dyn Db, members: M) -> Self
/// Synthesize a new protocol interface with the given members.
///
/// All created members will be covariant, read-only property members
/// rather than method members or mutable attribute members.
pub(super) fn with_property_members<'a, M>(db: &'db dyn Db, members: M) -> Self
where
M: IntoIterator<Item = (&'a str, Type<'db>)>,
{
let members: BTreeMap<_, _> = members
.into_iter()
.map(|(name, ty)| {
// Synthesize a read-only property (one that has a getter but no setter)
// which returns the specified type from its getter.
let property_getter_signature = Signature::new(
Parameters::new([Parameter::positional_only(Some(Name::new_static("self")))]),
Some(ty.normalized(db)),
);
let property_getter = CallableType::single(db, property_getter_signature);
let property = PropertyInstanceType::new(db, Some(property_getter), None);
(
Name::new(name),
ProtocolMemberData {
ty: ty.normalized(db),
qualifiers: TypeQualifiers::default(),
kind: ProtocolMemberKind::Property(property),
},
)
})
@ -116,7 +130,7 @@ impl<'db> ProtocolInterface<'db> {
Self::Members(members) => {
Either::Left(members.inner(db).iter().map(|(name, data)| ProtocolMember {
name,
ty: data.ty,
kind: data.kind,
qualifiers: data.qualifiers,
}))
}
@ -132,7 +146,7 @@ impl<'db> ProtocolInterface<'db> {
match self {
Self::Members(members) => members.inner(db).get(name).map(|data| ProtocolMember {
name,
ty: data.ty,
kind: data.kind,
qualifiers: data.qualifiers,
}),
Self::SelfReference => None,
@ -161,7 +175,7 @@ impl<'db> ProtocolInterface<'db> {
type_fn: &dyn Fn(Type<'db>) -> bool,
) -> bool {
self.members(db)
.any(|member| member.ty.any_over_type(db, type_fn))
.any(|member| member.any_over_type(db, type_fn))
}
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
@ -185,15 +199,7 @@ impl<'db> ProtocolInterface<'db> {
members
.inner(db)
.iter()
.map(|(name, data)| {
(
name.clone(),
ProtocolMemberData {
ty: data.ty.materialize(db, variance),
qualifiers: data.qualifiers,
},
)
})
.map(|(name, data)| (name.clone(), data.materialize(db, variance)))
.collect::<BTreeMap<_, _>>(),
)),
Self::SelfReference => Self::SelfReference,
@ -241,21 +247,21 @@ impl<'db> ProtocolInterface<'db> {
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
pub(super) struct ProtocolMemberData<'db> {
ty: Type<'db>,
kind: ProtocolMemberKind<'db>,
qualifiers: TypeQualifiers,
}
impl<'db> ProtocolMemberData<'db> {
fn normalized(&self, db: &'db dyn Db) -> Self {
Self {
ty: self.ty.normalized(db),
kind: self.kind.normalized(db),
qualifiers: self.qualifiers,
}
}
fn apply_type_mapping<'a>(&self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>) -> Self {
Self {
ty: self.ty.apply_type_mapping(db, type_mapping),
kind: self.kind.apply_type_mapping(db, type_mapping),
qualifiers: self.qualifiers,
}
}
@ -265,7 +271,75 @@ impl<'db> ProtocolMemberData<'db> {
db: &'db dyn Db,
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
) {
self.ty.find_legacy_typevars(db, typevars);
self.kind.find_legacy_typevars(db, typevars);
}
fn materialize(&self, db: &'db dyn Db, variance: TypeVarVariance) -> Self {
Self {
kind: self.kind.materialize(db, variance),
qualifiers: self.qualifiers,
}
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
enum ProtocolMemberKind<'db> {
Method(Type<'db>), // TODO: use CallableType
Property(PropertyInstanceType<'db>),
Other(Type<'db>),
}
impl<'db> ProtocolMemberKind<'db> {
fn normalized(&self, db: &'db dyn Db) -> Self {
match self {
ProtocolMemberKind::Method(callable) => {
ProtocolMemberKind::Method(callable.normalized(db))
}
ProtocolMemberKind::Property(property) => {
ProtocolMemberKind::Property(property.normalized(db))
}
ProtocolMemberKind::Other(ty) => ProtocolMemberKind::Other(ty.normalized(db)),
}
}
fn apply_type_mapping<'a>(&self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>) -> Self {
match self {
ProtocolMemberKind::Method(callable) => {
ProtocolMemberKind::Method(callable.apply_type_mapping(db, type_mapping))
}
ProtocolMemberKind::Property(property) => {
ProtocolMemberKind::Property(property.apply_type_mapping(db, type_mapping))
}
ProtocolMemberKind::Other(ty) => {
ProtocolMemberKind::Other(ty.apply_type_mapping(db, type_mapping))
}
}
}
fn find_legacy_typevars(
&self,
db: &'db dyn Db,
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
) {
match self {
ProtocolMemberKind::Method(callable) => callable.find_legacy_typevars(db, typevars),
ProtocolMemberKind::Property(property) => property.find_legacy_typevars(db, typevars),
ProtocolMemberKind::Other(ty) => ty.find_legacy_typevars(db, typevars),
}
}
fn materialize(self, db: &'db dyn Db, variance: TypeVarVariance) -> Self {
match self {
ProtocolMemberKind::Method(callable) => {
ProtocolMemberKind::Method(callable.materialize(db, variance))
}
ProtocolMemberKind::Property(property) => {
ProtocolMemberKind::Property(property.materialize(db, variance))
}
ProtocolMemberKind::Other(ty) => {
ProtocolMemberKind::Other(ty.materialize(db, variance))
}
}
}
}
@ -273,7 +347,7 @@ impl<'db> ProtocolMemberData<'db> {
#[derive(Debug, PartialEq, Eq)]
pub(super) struct ProtocolMember<'a, 'db> {
name: &'a str,
ty: Type<'db>,
kind: ProtocolMemberKind<'db>,
qualifiers: TypeQualifiers,
}
@ -282,13 +356,52 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
self.name
}
pub(super) fn ty(&self) -> Type<'db> {
self.ty
}
pub(super) fn qualifiers(&self) -> TypeQualifiers {
self.qualifiers
}
pub(super) fn ty(&self) -> Type<'db> {
match &self.kind {
ProtocolMemberKind::Method(callable) => *callable,
ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
ProtocolMemberKind::Other(ty) => *ty,
}
}
pub(super) const fn is_attribute_member(&self) -> bool {
matches!(self.kind, ProtocolMemberKind::Other(_))
}
/// Return `true` if `other` contains an attribute/method/property that satisfies
/// the part of the interface defined by this protocol member.
pub(super) fn is_satisfied_by(
&self,
db: &'db dyn Db,
other: Type<'db>,
relation: TypeRelation,
) -> bool {
let Place::Type(attribute_type, Boundness::Bound) = other.member(db, self.name).place
else {
return false;
};
match &self.kind {
// TODO: consider the types of the attribute on `other` for property/method members
ProtocolMemberKind::Method(_) | ProtocolMemberKind::Property(_) => true,
ProtocolMemberKind::Other(member_type) => {
member_type.has_relation_to(db, attribute_type, relation)
&& attribute_type.has_relation_to(db, *member_type, relation)
}
}
}
fn any_over_type(&self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool {
match &self.kind {
ProtocolMemberKind::Method(callable) => callable.any_over_type(db, type_fn),
ProtocolMemberKind::Property(property) => property.any_over_type(db, type_fn),
ProtocolMemberKind::Other(ty) => ty.any_over_type(db, type_fn),
}
}
}
/// Returns `true` if a declaration or binding to a given name in a protocol class body
@ -330,6 +443,12 @@ fn excluded_from_proto_members(member: &str) -> bool {
) || member.starts_with("_abc_")
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum BoundOnClass {
Yes,
No,
}
/// Inner Salsa query for [`ProtocolClassLiteral::interface`].
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)]
fn cached_protocol_interface<'db>(
@ -357,7 +476,7 @@ fn cached_protocol_interface<'db>(
place
.place
.ignore_possibly_unbound()
.map(|ty| (place_id, ty, place.qualifiers))
.map(|ty| (place_id, ty, place.qualifiers, BoundOnClass::No))
})
// Bindings in the class body that are not declared in the class body
// are not valid protocol members, and we plan to emit diagnostics for them
@ -371,20 +490,41 @@ fn cached_protocol_interface<'db>(
|(place_id, bindings)| {
place_from_bindings(db, bindings)
.ignore_possibly_unbound()
.map(|ty| (place_id, ty, TypeQualifiers::default()))
.map(|ty| (place_id, ty, TypeQualifiers::default(), BoundOnClass::Yes))
},
))
.filter_map(|(place_id, member, qualifiers)| {
.filter_map(|(place_id, member, qualifiers, bound_on_class)| {
Some((
place_table.place_expr(place_id).as_name()?,
member,
qualifiers,
bound_on_class,
))
})
.filter(|(name, _, _)| !excluded_from_proto_members(name))
.map(|(name, ty, qualifiers)| {
let ty = ty.replace_self_reference(db, class);
let member = ProtocolMemberData { ty, qualifiers };
.filter(|(name, _, _, _)| !excluded_from_proto_members(name))
.map(|(name, ty, qualifiers, bound_on_class)| {
let kind = match (ty, bound_on_class) {
// TODO: if the getter or setter is a function literal, we should
// upcast it to a `CallableType` so that two protocols with identical property
// members are recognized as equivalent.
(Type::PropertyInstance(property), _) => {
ProtocolMemberKind::Property(property)
}
(Type::Callable(callable), BoundOnClass::Yes)
if callable.is_function_like(db) =>
{
ProtocolMemberKind::Method(ty.replace_self_reference(db, class))
}
// TODO: method members that have `FunctionLiteral` types should be upcast
// to `CallableType` so that two protocols with identical method members
// are recognized as equivalent.
(Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
ProtocolMemberKind::Method(ty.replace_self_reference(db, class))
}
_ => ProtocolMemberKind::Other(ty.replace_self_reference(db, class)),
};
let member = ProtocolMemberData { kind, qualifiers };
(name.clone(), member)
}),
);