[red-knot] Infer the members of a protocol class (#17556)

This commit is contained in:
Alex Waygood 2025-04-23 22:36:12 +01:00 committed by GitHub
parent 7b6222700b
commit 00e73dc331
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 205 additions and 28 deletions

View file

@ -315,7 +315,7 @@ reveal_type(Protocol()) # revealed: Unknown
class MyProtocol(Protocol): class MyProtocol(Protocol):
x: int x: int
# error # TODO: should emit error
reveal_type(MyProtocol()) # revealed: MyProtocol reveal_type(MyProtocol()) # revealed: MyProtocol
``` ```
@ -363,16 +363,8 @@ class Foo(Protocol):
def method_member(self) -> bytes: def method_member(self) -> bytes:
return b"foo" return b"foo"
# TODO: at runtime, `get_protocol_members` returns a `frozenset`, # TODO: actually a frozenset (requires support for legacy generics)
# but for now we might pretend it returns a `tuple`, as we support heterogeneous `tuple` types reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["method_member"], Literal["x"], Literal["y"], Literal["z"]]
# but not yet generic `frozenset`s
#
# So this should either be
#
# `tuple[Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
#
# `frozenset[Literal["x", "y", "z", "method_member"]]`
reveal_type(get_protocol_members(Foo)) # revealed: @Todo(specialized non-generic class)
``` ```
Certain special attributes and methods are not considered protocol members at runtime, and should Certain special attributes and methods are not considered protocol members at runtime, and should
@ -390,8 +382,8 @@ class Lumberjack(Protocol):
def __init__(self, x: int) -> None: def __init__(self, x: int) -> None:
self.x = x self.x = x
# TODO: `tuple[Literal["x"]]` or `frozenset[Literal["x"]]` # TODO: actually a frozenset
reveal_type(get_protocol_members(Lumberjack)) # revealed: @Todo(specialized non-generic class) reveal_type(get_protocol_members(Lumberjack)) # revealed: tuple[Literal["x"]]
``` ```
A sub-protocol inherits and extends the members of its superclass protocol(s): A sub-protocol inherits and extends the members of its superclass protocol(s):
@ -403,15 +395,42 @@ class Bar(Protocol):
class Baz(Bar, Protocol): class Baz(Bar, Protocol):
ham: memoryview ham: memoryview
# TODO: `tuple[Literal["spam", "ham"]]` or `frozenset[Literal["spam", "ham"]]` # TODO: actually a frozenset
reveal_type(get_protocol_members(Baz)) # revealed: @Todo(specialized non-generic class) reveal_type(get_protocol_members(Baz)) # revealed: tuple[Literal["ham"], Literal["spam"]]
class Baz2(Bar, Foo, Protocol): ... class Baz2(Bar, Foo, Protocol): ...
# TODO: either # TODO: actually a frozenset
# `tuple[Literal["spam"], Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]` # revealed: tuple[Literal["method_member"], Literal["spam"], Literal["x"], Literal["y"], Literal["z"]]
# or `frozenset[Literal["spam", "x", "y", "z", "method_member"]]` reveal_type(get_protocol_members(Baz2))
reveal_type(get_protocol_members(Baz2)) # revealed: @Todo(specialized non-generic class) ```
## Protocol members in statically known branches
The list of protocol members does not include any members declared in branches that are statically
known to be unreachable:
```toml
[environment]
python-version = "3.9"
```
```py
import sys
from typing_extensions import Protocol, get_protocol_members
class Foo(Protocol):
if sys.version_info >= (3, 10):
a: int
b = 42
def c(self) -> None: ...
else:
d: int
e = 56
def f(self) -> None: ...
# TODO: actually a frozenset
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["d"], Literal["e"], Literal["f"]]
``` ```
## Invalid calls to `get_protocol_members()` ## Invalid calls to `get_protocol_members()`
@ -639,14 +658,14 @@ class LotsOfBindings(Protocol):
case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared) case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared)
... ...
# TODO: all bindings in the above class should be understood as protocol members, # TODO: actually a frozenset
# even those that we complained about with a diagnostic # revealed: tuple[Literal["Nested"], Literal["NestedProtocol"], Literal["a"], Literal["b"], Literal["c"], Literal["d"], Literal["e"], Literal["f"], Literal["g"], Literal["h"], Literal["i"], Literal["j"], Literal["k"], Literal["l"]]
reveal_type(get_protocol_members(LotsOfBindings)) # revealed: @Todo(specialized non-generic class) reveal_type(get_protocol_members(LotsOfBindings))
``` ```
Attribute members are allowed to have assignments in methods on the protocol class, just like Attribute members are allowed to have assignments in methods on the protocol class, just like
non-protocol classes. Unlike other classes, however, *implicit* instance attributes -- those that non-protocol classes. Unlike other classes, however, instance attributes that are not declared in
are not declared in the class body -- are not allowed: the class body are disallowed:
```py ```py
class Foo(Protocol): class Foo(Protocol):
@ -655,11 +674,18 @@ class Foo(Protocol):
def __init__(self) -> None: def __init__(self) -> None:
self.x = 42 # fine self.x = 42 # fine
self.a = 56 # error self.a = 56 # TODO: should emit diagnostic
self.b: int = 128 # TODO: should emit diagnostic
def non_init_method(self) -> None: def non_init_method(self) -> None:
self.y = 64 # fine self.y = 64 # fine
self.b = 72 # error self.c = 72 # TODO: should emit diagnostic
# Note: the list of members does not include `a`, `b` or `c`,
# as none of these attributes is declared in the class body.
#
# TODO: actually a frozenset
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["non_init_method"], Literal["x"], Literal["y"]]
``` ```
If a protocol has 0 members, then all other types are assignable to it, and all fully static types If a protocol has 0 members, then all other types are assignable to it, and all fully static types

View file

@ -437,6 +437,15 @@ impl<'db> UseDefMap<'db> {
.map(|symbol_id| (symbol_id, self.public_declarations(symbol_id))) .map(|symbol_id| (symbol_id, self.public_declarations(symbol_id)))
} }
/// Returns an iterator that pairs every public symbol of this map with its public bindings.
///
/// Symbols are visited in index order, from `0` up to (but excluding)
/// `self.public_symbols.len()`. Each index is converted to a `ScopedSymbolId`
/// and yielded together with the result of `public_bindings` for that symbol.
pub(crate) fn all_public_bindings<'map>(
    &'map self,
) -> impl Iterator<Item = (ScopedSymbolId, BindingWithConstraintsIterator<'map, 'db>)> + 'map
{
    let symbol_count = self.public_symbols.len();
    (0..symbol_count).map(move |index| {
        let symbol_id = ScopedSymbolId::from_usize(index);
        (symbol_id, self.public_bindings(symbol_id))
    })
}
/// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`. /// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`.
pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool { pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool {
!self !self

View file

@ -20,8 +20,8 @@ use crate::types::generics::{Specialization, SpecializationBuilder};
use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::{ use crate::types::{
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass, BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, TupleType,
WrapperDescriptorKind, UnionType, WrapperDescriptorKind,
}; };
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
use ruff_python_ast as ast; use ruff_python_ast as ast;
@ -561,6 +561,22 @@ impl<'db> Bindings<'db> {
} }
} }
Some(KnownFunction::GetProtocolMembers) => {
    // Only a call with exactly one argument that is a class literal is handled here;
    // every other argument shape leaves the overload's return type untouched.
    let protocol_class = match overload.parameter_types() {
        [Some(Type::ClassLiteral(class))] => class.into_protocol_class(db),
        _ => None,
    };

    // Non-protocol classes are likewise left alone.
    if let Some(protocol_class) = protocol_class {
        // One string-literal type per member name, in the order reported by
        // `protocol_members`.
        let member_literals: Box<[Type<'db>]> = protocol_class
            .protocol_members(db)
            .iter()
            .map(|member| Type::string_literal(db, member))
            .collect();

        // TODO: actually a frozenset at runtime (requires support for legacy generic classes)
        overload.set_return_type(Type::Tuple(TupleType::new(db, member_literals)));
    }
}
Some(KnownFunction::Overload) => { Some(KnownFunction::Overload) => {
// TODO: This can be removed once we understand legacy generics because the // TODO: This can be removed once we understand legacy generics because the
// typeshed definition for `typing.overload` is an identity function. // typeshed definition for `typing.overload` is an identity function.

View file

@ -1,4 +1,5 @@
use std::hash::BuildHasherDefault; use std::hash::BuildHasherDefault;
use std::ops::Deref;
use std::sync::{LazyLock, Mutex}; use std::sync::{LazyLock, Mutex};
use super::{ use super::{
@ -13,6 +14,7 @@ use crate::types::signatures::{Parameter, Parameters};
use crate::types::{ use crate::types::{
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature, CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature,
}; };
use crate::FxOrderSet;
use crate::{ use crate::{
module_resolver::file_to_module, module_resolver::file_to_module,
semantic_index::{ semantic_index::{
@ -1710,6 +1712,11 @@ impl<'db> ClassLiteralType<'db> {
Some(InheritanceCycle::Inherited) Some(InheritanceCycle::Inherited)
} }
} }
/// Converts this class literal into a [`ProtocolClassLiteral`] when the class is a protocol.
///
/// Returns `None` for any class for which `is_protocol` reports `false`.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
    if self.is_protocol(db) {
        Some(ProtocolClassLiteral(self))
    } else {
        None
    }
}
} }
impl<'db> From<ClassLiteralType<'db>> for Type<'db> { impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
@ -1721,6 +1728,125 @@ impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
} }
} }
/// Representation of a single `Protocol` class definition.
///
/// This is a thin newtype around a [`ClassLiteralType`]. In the code visible next to
/// this definition, values are created by `ClassLiteralType::into_protocol_class`,
/// which only returns `Some` when `is_protocol` holds for the wrapped class.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteralType<'db>);
impl<'db> ProtocolClassLiteral<'db> {
    /// Returns the names of the protocol members of this class, sorted by name.
    ///
    /// A protocol's members define the interface declared by the protocol, and therefore
    /// determine how the protocol behaves with regard to assignability and subtyping.
    ///
    /// Members are collected from the class body of every protocol class found in this
    /// class's MRO. Both declarations and bindings in those bodies count, except for a
    /// fixed set of excluded names. (That set includes `__init__` and `__new__`, which
    /// can legally be defined on protocol classes but do not constitute protocol members.)
    ///
    /// It is illegal for a protocol class to have any instance attributes that are not
    /// declared in the protocol's class body. If any are assigned to, they are not taken
    /// into account in the protocol's list of members.
    pub(super) fn protocol_members(self, db: &'db dyn Db) -> &'db ordermap::set::Slice<Name> {
        /// Names that are never treated as protocol members.
        ///
        /// The list of excluded members is subject to change between Python versions,
        /// especially for dunders, but it probably doesn't matter *too* much if this
        /// list goes out of date. It's up to date as of Python commit
        /// 87b1ea016b1454b1e83b9113fa9435849b7743aa
        /// (<https://github.com/python/cpython/blob/87b1ea016b1454b1e83b9113fa9435849b7743aa/Lib/typing.py#L1776-L1791>)
        const EXCLUDED_MEMBER_NAMES: &[&str] = &[
            "_is_protocol",
            "__non_callable_proto_members__",
            "__static_attributes__",
            "__orig_class__",
            "__match_args__",
            "__weakref__",
            "__doc__",
            "__parameters__",
            "__module__",
            "_MutableMapping__marker",
            "__slots__",
            "__dict__",
            "__new__",
            "__protocol_attrs__",
            "__init__",
            "__class_getitem__",
            "__firstlineno__",
            "__abstractmethods__",
            "__orig_bases__",
            "_is_runtime_protocol",
            "__subclasshook__",
            "__type_params__",
            "__annotations__",
            "__annotate__",
            "__annotate_func__",
            "__annotations_cache__",
        ];

        /// Returns `true` if `member` is one of the names in `EXCLUDED_MEMBER_NAMES`.
        fn excluded_from_proto_members(member: &str) -> bool {
            EXCLUDED_MEMBER_NAMES.contains(&member)
        }

        /// Salsa-tracked worker that computes the sorted set of member names for `class`.
        #[salsa::tracked(return_ref)]
        fn cached_protocol_members<'db>(
            db: &'db dyn Db,
            class: ClassLiteralType<'db>,
        ) -> Box<ordermap::set::Slice<Name>> {
            let mut member_names = FxOrderSet::default();

            // Every class in the MRO that is itself a protocol contributes members.
            let protocols_in_mro = class
                .iter_mro(db, None)
                .filter_map(ClassBase::into_class)
                .filter_map(|class| class.class_literal(db).0.into_protocol_class(db));

            for protocol in protocols_in_mro {
                let body_scope = protocol.body_scope(db);
                let body_use_def = use_def_map(db, body_scope);
                let body_symbols = symbol_table(db, body_scope);

                // Symbols whose declarations resolve to something other than "unbound".
                let declared_ids = body_use_def
                    .all_public_declarations()
                    .flat_map(|(symbol_id, declarations)| {
                        symbol_from_declarations(db, declarations)
                            .map(|symbol| (symbol_id, symbol))
                    })
                    .filter_map(|(symbol_id, symbol)| {
                        symbol.symbol.ignore_possibly_unbound().map(|_| symbol_id)
                    });

                // Bindings in the class body that are not declared in the class body
                // are not valid protocol members, and we plan to emit diagnostics for
                // them elsewhere. Invalid or not, however, it's important that we still
                // consider them to be protocol members. The implementation of
                // `issubclass()` and `isinstance()` for runtime-checkable protocols
                // considers them to be protocol members at runtime, and it's important
                // that we accurately understand type narrowing that uses `isinstance()`
                // or `issubclass()` with runtime-checkable protocols.
                let bound_ids =
                    body_use_def
                        .all_public_bindings()
                        .filter_map(|(symbol_id, bindings)| {
                            symbol_from_bindings(db, bindings)
                                .ignore_possibly_unbound()
                                .map(|_| symbol_id)
                        });

                member_names.extend(
                    declared_ids
                        .chain(bound_ids)
                        .map(|symbol_id| body_symbols.symbol(symbol_id).name())
                        .filter(|name| !excluded_from_proto_members(name))
                        .cloned(),
                );
            }

            member_names.sort();
            member_names.into_boxed_slice()
        }

        cached_protocol_members(db, *self)
    }
}
impl<'db> Deref for ProtocolClassLiteral<'db> {
type Target = ClassLiteralType<'db>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(super) enum InheritanceCycle { pub(super) enum InheritanceCycle {
/// The class is cyclically defined and is a participant in the cycle. /// The class is cyclically defined and is a participant in the cycle.