mirror of
https://github.com/astral-sh/ruff.git
synced 2025-07-23 21:15:19 +00:00
[red-knot] Infer the members of a protocol class (#17556)
This commit is contained in:
parent
7b6222700b
commit
00e73dc331
4 changed files with 205 additions and 28 deletions
|
@ -315,7 +315,7 @@ reveal_type(Protocol()) # revealed: Unknown
|
||||||
class MyProtocol(Protocol):
|
class MyProtocol(Protocol):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
# error
|
# TODO: should emit error
|
||||||
reveal_type(MyProtocol()) # revealed: MyProtocol
|
reveal_type(MyProtocol()) # revealed: MyProtocol
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -363,16 +363,8 @@ class Foo(Protocol):
|
||||||
def method_member(self) -> bytes:
|
def method_member(self) -> bytes:
|
||||||
return b"foo"
|
return b"foo"
|
||||||
|
|
||||||
# TODO: at runtime, `get_protocol_members` returns a `frozenset`,
|
# TODO: actually a frozenset (requires support for legacy generics)
|
||||||
# but for now we might pretend it returns a `tuple`, as we support heterogeneous `tuple` types
|
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["method_member"], Literal["x"], Literal["y"], Literal["z"]]
|
||||||
# but not yet generic `frozenset`s
|
|
||||||
#
|
|
||||||
# So this should either be
|
|
||||||
#
|
|
||||||
# `tuple[Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
|
|
||||||
#
|
|
||||||
# `frozenset[Literal["x", "y", "z", "method_member"]]`
|
|
||||||
reveal_type(get_protocol_members(Foo)) # revealed: @Todo(specialized non-generic class)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Certain special attributes and methods are not considered protocol members at runtime, and should
|
Certain special attributes and methods are not considered protocol members at runtime, and should
|
||||||
|
@ -390,8 +382,8 @@ class Lumberjack(Protocol):
|
||||||
def __init__(self, x: int) -> None:
|
def __init__(self, x: int) -> None:
|
||||||
self.x = x
|
self.x = x
|
||||||
|
|
||||||
# TODO: `tuple[Literal["x"]]` or `frozenset[Literal["x"]]`
|
# TODO: actually a frozenset
|
||||||
reveal_type(get_protocol_members(Lumberjack)) # revealed: @Todo(specialized non-generic class)
|
reveal_type(get_protocol_members(Lumberjack)) # revealed: tuple[Literal["x"]]
|
||||||
```
|
```
|
||||||
|
|
||||||
A sub-protocol inherits and extends the members of its superclass protocol(s):
|
A sub-protocol inherits and extends the members of its superclass protocol(s):
|
||||||
|
@ -403,15 +395,42 @@ class Bar(Protocol):
|
||||||
class Baz(Bar, Protocol):
|
class Baz(Bar, Protocol):
|
||||||
ham: memoryview
|
ham: memoryview
|
||||||
|
|
||||||
# TODO: `tuple[Literal["spam", "ham"]]` or `frozenset[Literal["spam", "ham"]]`
|
# TODO: actually a frozenset
|
||||||
reveal_type(get_protocol_members(Baz)) # revealed: @Todo(specialized non-generic class)
|
reveal_type(get_protocol_members(Baz)) # revealed: tuple[Literal["ham"], Literal["spam"]]
|
||||||
|
|
||||||
class Baz2(Bar, Foo, Protocol): ...
|
class Baz2(Bar, Foo, Protocol): ...
|
||||||
|
|
||||||
# TODO: either
|
# TODO: actually a frozenset
|
||||||
# `tuple[Literal["spam"], Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
|
# revealed: tuple[Literal["method_member"], Literal["spam"], Literal["x"], Literal["y"], Literal["z"]]
|
||||||
# or `frozenset[Literal["spam", "x", "y", "z", "method_member"]]`
|
reveal_type(get_protocol_members(Baz2))
|
||||||
reveal_type(get_protocol_members(Baz2)) # revealed: @Todo(specialized non-generic class)
|
```
|
||||||
|
|
||||||
|
## Protocol members in statically known branches
|
||||||
|
|
||||||
|
The list of protocol members does not include any members declared in branches that are statically
|
||||||
|
known to be unreachable:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[environment]
|
||||||
|
python-version = "3.9"
|
||||||
|
```
|
||||||
|
|
||||||
|
```py
|
||||||
|
import sys
|
||||||
|
from typing_extensions import Protocol, get_protocol_members
|
||||||
|
|
||||||
|
class Foo(Protocol):
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
a: int
|
||||||
|
b = 42
|
||||||
|
def c(self) -> None: ...
|
||||||
|
else:
|
||||||
|
d: int
|
||||||
|
e = 56
|
||||||
|
def f(self) -> None: ...
|
||||||
|
|
||||||
|
# TODO: actually a frozenset
|
||||||
|
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["d"], Literal["e"], Literal["f"]]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Invalid calls to `get_protocol_members()`
|
## Invalid calls to `get_protocol_members()`
|
||||||
|
@ -639,14 +658,14 @@ class LotsOfBindings(Protocol):
|
||||||
case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared)
|
case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared)
|
||||||
...
|
...
|
||||||
|
|
||||||
# TODO: all bindings in the above class should be understood as protocol members,
|
# TODO: actually a frozenset
|
||||||
# even those that we complained about with a diagnostic
|
# revealed: tuple[Literal["Nested"], Literal["NestedProtocol"], Literal["a"], Literal["b"], Literal["c"], Literal["d"], Literal["e"], Literal["f"], Literal["g"], Literal["h"], Literal["i"], Literal["j"], Literal["k"], Literal["l"]]
|
||||||
reveal_type(get_protocol_members(LotsOfBindings)) # revealed: @Todo(specialized non-generic class)
|
reveal_type(get_protocol_members(LotsOfBindings))
|
||||||
```
|
```
|
||||||
|
|
||||||
Attribute members are allowed to have assignments in methods on the protocol class, just like
|
Attribute members are allowed to have assignments in methods on the protocol class, just like
|
||||||
non-protocol classes. Unlike other classes, however, *implicit* instance attributes -- those that
|
non-protocol classes. Unlike other classes, however, instance attributes that are not declared in
|
||||||
are not declared in the class body -- are not allowed:
|
the class body are disallowed:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
class Foo(Protocol):
|
class Foo(Protocol):
|
||||||
|
@ -655,11 +674,18 @@ class Foo(Protocol):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.x = 42 # fine
|
self.x = 42 # fine
|
||||||
self.a = 56 # error
|
self.a = 56 # TODO: should emit diagnostic
|
||||||
|
self.b: int = 128 # TODO: should emit diagnostic
|
||||||
|
|
||||||
def non_init_method(self) -> None:
|
def non_init_method(self) -> None:
|
||||||
self.y = 64 # fine
|
self.y = 64 # fine
|
||||||
self.b = 72 # error
|
self.c = 72 # TODO: should emit diagnostic
|
||||||
|
|
||||||
|
# Note: the list of members does not include `a`, `b` or `c`,
|
||||||
|
# as none of these attributes is declared in the class body.
|
||||||
|
#
|
||||||
|
# TODO: actually a frozenset
|
||||||
|
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["non_init_method"], Literal["x"], Literal["y"]]
|
||||||
```
|
```
|
||||||
|
|
||||||
If a protocol has 0 members, then all other types are assignable to it, and all fully static types
|
If a protocol has 0 members, then all other types are assignable to it, and all fully static types
|
||||||
|
|
|
@ -437,6 +437,15 @@ impl<'db> UseDefMap<'db> {
|
||||||
.map(|symbol_id| (symbol_id, self.public_declarations(symbol_id)))
|
.map(|symbol_id| (symbol_id, self.public_declarations(symbol_id)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn all_public_bindings<'map>(
|
||||||
|
&'map self,
|
||||||
|
) -> impl Iterator<Item = (ScopedSymbolId, BindingWithConstraintsIterator<'map, 'db>)> + 'map
|
||||||
|
{
|
||||||
|
(0..self.public_symbols.len())
|
||||||
|
.map(ScopedSymbolId::from_usize)
|
||||||
|
.map(|symbol_id| (symbol_id, self.public_bindings(symbol_id)))
|
||||||
|
}
|
||||||
|
|
||||||
/// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`.
|
/// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`.
|
||||||
pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool {
|
pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool {
|
||||||
!self
|
!self
|
||||||
|
|
|
@ -20,8 +20,8 @@ use crate::types::generics::{Specialization, SpecializationBuilder};
|
||||||
use crate::types::signatures::{Parameter, ParameterForm};
|
use crate::types::signatures::{Parameter, ParameterForm};
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
|
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
|
||||||
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType,
|
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, TupleType,
|
||||||
WrapperDescriptorKind,
|
UnionType, WrapperDescriptorKind,
|
||||||
};
|
};
|
||||||
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
|
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
|
||||||
use ruff_python_ast as ast;
|
use ruff_python_ast as ast;
|
||||||
|
@ -561,6 +561,22 @@ impl<'db> Bindings<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Some(KnownFunction::GetProtocolMembers) => {
|
||||||
|
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
|
||||||
|
if let Some(protocol_class) = class.into_protocol_class(db) {
|
||||||
|
// TODO: actually a frozenset at runtime (requires support for legacy generic classes)
|
||||||
|
overload.set_return_type(Type::Tuple(TupleType::new(
|
||||||
|
db,
|
||||||
|
protocol_class
|
||||||
|
.protocol_members(db)
|
||||||
|
.iter()
|
||||||
|
.map(|member| Type::string_literal(db, member))
|
||||||
|
.collect::<Box<[Type<'db>]>>(),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Some(KnownFunction::Overload) => {
|
Some(KnownFunction::Overload) => {
|
||||||
// TODO: This can be removed once we understand legacy generics because the
|
// TODO: This can be removed once we understand legacy generics because the
|
||||||
// typeshed definition for `typing.overload` is an identity function.
|
// typeshed definition for `typing.overload` is an identity function.
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use std::hash::BuildHasherDefault;
|
use std::hash::BuildHasherDefault;
|
||||||
|
use std::ops::Deref;
|
||||||
use std::sync::{LazyLock, Mutex};
|
use std::sync::{LazyLock, Mutex};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
|
@ -13,6 +14,7 @@ use crate::types::signatures::{Parameter, Parameters};
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature,
|
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature,
|
||||||
};
|
};
|
||||||
|
use crate::FxOrderSet;
|
||||||
use crate::{
|
use crate::{
|
||||||
module_resolver::file_to_module,
|
module_resolver::file_to_module,
|
||||||
semantic_index::{
|
semantic_index::{
|
||||||
|
@ -1710,6 +1712,11 @@ impl<'db> ClassLiteralType<'db> {
|
||||||
Some(InheritanceCycle::Inherited)
|
Some(InheritanceCycle::Inherited)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns `Some` if this is a protocol class, `None` otherwise.
|
||||||
|
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
|
||||||
|
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
|
impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
|
||||||
|
@ -1721,6 +1728,125 @@ impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Representation of a single `Protocol` class definition.
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||||
|
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteralType<'db>);
|
||||||
|
|
||||||
|
impl<'db> ProtocolClassLiteral<'db> {
|
||||||
|
/// Returns the protocol members of this class.
|
||||||
|
///
|
||||||
|
/// A protocol's members define the interface declared by the protocol.
|
||||||
|
/// They therefore determine how the protocol should behave with regards to
|
||||||
|
/// assignability and subtyping.
|
||||||
|
///
|
||||||
|
/// The list of members consists of all bindings and declarations that take place
|
||||||
|
/// in the protocol's class body, except for a list of excluded attributes which should
|
||||||
|
/// not be taken into account. (This list includes `__init__` and `__new__`, which can
|
||||||
|
/// legally be defined on protocol classes but do not constitute protocol members.)
|
||||||
|
///
|
||||||
|
/// It is illegal for a protocol class to have any instance attributes that are not declared
|
||||||
|
/// in the protocol's class body. If any are assigned to, they are not taken into account in
|
||||||
|
/// the protocol's list of members.
|
||||||
|
pub(super) fn protocol_members(self, db: &'db dyn Db) -> &'db ordermap::set::Slice<Name> {
|
||||||
|
/// The list of excluded members is subject to change between Python versions,
|
||||||
|
/// especially for dunders, but it probably doesn't matter *too* much if this
|
||||||
|
/// list goes out of date. It's up to date as of Python commit 87b1ea016b1454b1e83b9113fa9435849b7743aa
|
||||||
|
/// (<https://github.com/python/cpython/blob/87b1ea016b1454b1e83b9113fa9435849b7743aa/Lib/typing.py#L1776-L1791>)
|
||||||
|
fn excluded_from_proto_members(member: &str) -> bool {
|
||||||
|
matches!(
|
||||||
|
member,
|
||||||
|
"_is_protocol"
|
||||||
|
| "__non_callable_proto_members__"
|
||||||
|
| "__static_attributes__"
|
||||||
|
| "__orig_class__"
|
||||||
|
| "__match_args__"
|
||||||
|
| "__weakref__"
|
||||||
|
| "__doc__"
|
||||||
|
| "__parameters__"
|
||||||
|
| "__module__"
|
||||||
|
| "_MutableMapping__marker"
|
||||||
|
| "__slots__"
|
||||||
|
| "__dict__"
|
||||||
|
| "__new__"
|
||||||
|
| "__protocol_attrs__"
|
||||||
|
| "__init__"
|
||||||
|
| "__class_getitem__"
|
||||||
|
| "__firstlineno__"
|
||||||
|
| "__abstractmethods__"
|
||||||
|
| "__orig_bases__"
|
||||||
|
| "_is_runtime_protocol"
|
||||||
|
| "__subclasshook__"
|
||||||
|
| "__type_params__"
|
||||||
|
| "__annotations__"
|
||||||
|
| "__annotate__"
|
||||||
|
| "__annotate_func__"
|
||||||
|
| "__annotations_cache__"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked(return_ref)]
|
||||||
|
fn cached_protocol_members<'db>(
|
||||||
|
db: &'db dyn Db,
|
||||||
|
class: ClassLiteralType<'db>,
|
||||||
|
) -> Box<ordermap::set::Slice<Name>> {
|
||||||
|
let mut members = FxOrderSet::default();
|
||||||
|
|
||||||
|
for parent_protocol in class
|
||||||
|
.iter_mro(db, None)
|
||||||
|
.filter_map(ClassBase::into_class)
|
||||||
|
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
|
||||||
|
{
|
||||||
|
let parent_scope = parent_protocol.body_scope(db);
|
||||||
|
let use_def_map = use_def_map(db, parent_scope);
|
||||||
|
let symbol_table = symbol_table(db, parent_scope);
|
||||||
|
|
||||||
|
members.extend(
|
||||||
|
use_def_map
|
||||||
|
.all_public_declarations()
|
||||||
|
.flat_map(|(symbol_id, declarations)| {
|
||||||
|
symbol_from_declarations(db, declarations)
|
||||||
|
.map(|symbol| (symbol_id, symbol))
|
||||||
|
})
|
||||||
|
.filter_map(|(symbol_id, symbol)| {
|
||||||
|
symbol.symbol.ignore_possibly_unbound().map(|_| symbol_id)
|
||||||
|
})
|
||||||
|
// Bindings in the class body that are not declared in the class body
|
||||||
|
// are not valid protocol members, and we plan to emit diagnostics for them
|
||||||
|
// elsewhere. Invalid or not, however, it's important that we still consider
|
||||||
|
// them to be protocol members. The implementation of `issubclass()` and
|
||||||
|
// `isinstance()` for runtime-checkable protocols considers them to be protocol
|
||||||
|
// members at runtime, and it's important that we accurately understand
|
||||||
|
// type narrowing that uses `isinstance()` or `issubclass()` with
|
||||||
|
// runtime-checkable protocols.
|
||||||
|
.chain(use_def_map.all_public_bindings().filter_map(
|
||||||
|
|(symbol_id, bindings)| {
|
||||||
|
symbol_from_bindings(db, bindings)
|
||||||
|
.ignore_possibly_unbound()
|
||||||
|
.map(|_| symbol_id)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
.map(|symbol_id| symbol_table.symbol(symbol_id).name())
|
||||||
|
.filter(|name| !excluded_from_proto_members(name))
|
||||||
|
.cloned(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
members.sort();
|
||||||
|
members.into_boxed_slice()
|
||||||
|
}
|
||||||
|
|
||||||
|
cached_protocol_members(db, *self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'db> Deref for ProtocolClassLiteral<'db> {
|
||||||
|
type Target = ClassLiteralType<'db>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub(super) enum InheritanceCycle {
|
pub(super) enum InheritanceCycle {
|
||||||
/// The class is cyclically defined and is a participant in the cycle.
|
/// The class is cyclically defined and is a participant in the cycle.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue