diff --git a/crates/red_knot_python_semantic/resources/mdtest/protocols.md b/crates/red_knot_python_semantic/resources/mdtest/protocols.md index 966cfcebab..9988ea761f 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/protocols.md +++ b/crates/red_knot_python_semantic/resources/mdtest/protocols.md @@ -315,7 +315,7 @@ reveal_type(Protocol()) # revealed: Unknown class MyProtocol(Protocol): x: int -# error +# TODO: should emit error reveal_type(MyProtocol()) # revealed: MyProtocol ``` @@ -363,16 +363,8 @@ class Foo(Protocol): def method_member(self) -> bytes: return b"foo" -# TODO: at runtime, `get_protocol_members` returns a `frozenset`, -# but for now we might pretend it returns a `tuple`, as we support heterogeneous `tuple` types -# but not yet generic `frozenset`s -# -# So this should either be -# -# `tuple[Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]` -# -# `frozenset[Literal["x", "y", "z", "method_member"]]` -reveal_type(get_protocol_members(Foo)) # revealed: @Todo(specialized non-generic class) +# TODO: actually a frozenset (requires support for legacy generics) +reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["method_member"], Literal["x"], Literal["y"], Literal["z"]] ``` Certain special attributes and methods are not considered protocol members at runtime, and should @@ -390,8 +382,8 @@ class Lumberjack(Protocol): def __init__(self, x: int) -> None: self.x = x -# TODO: `tuple[Literal["x"]]` or `frozenset[Literal["x"]]` -reveal_type(get_protocol_members(Lumberjack)) # revealed: @Todo(specialized non-generic class) +# TODO: actually a frozenset +reveal_type(get_protocol_members(Lumberjack)) # revealed: tuple[Literal["x"]] ``` A sub-protocol inherits and extends the members of its superclass protocol(s): @@ -403,15 +395,42 @@ class Bar(Protocol): class Baz(Bar, Protocol): ham: memoryview -# TODO: `tuple[Literal["spam", "ham"]]` or `frozenset[Literal["spam", "ham"]]` -reveal_type(get_protocol_members(Baz)) # revealed: @Todo(specialized non-generic class) +# TODO: actually a frozenset +reveal_type(get_protocol_members(Baz)) # revealed: tuple[Literal["ham"], Literal["spam"]] class Baz2(Bar, Foo, Protocol): ... -# TODO: either -# `tuple[Literal["spam"], Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]` -# or `frozenset[Literal["spam", "x", "y", "z", "method_member"]]` -reveal_type(get_protocol_members(Baz2)) # revealed: @Todo(specialized non-generic class) +# TODO: actually a frozenset +# revealed: tuple[Literal["method_member"], Literal["spam"], Literal["x"], Literal["y"], Literal["z"]] +reveal_type(get_protocol_members(Baz2)) +``` + +## Protocol members in statically known branches + +The list of protocol members does not include any members declared in branches that are statically +known to be unreachable: + +```toml +[environment] +python-version = "3.9" +``` + +```py +import sys +from typing_extensions import Protocol, get_protocol_members + +class Foo(Protocol): + if sys.version_info >= (3, 10): + a: int + b = 42 + def c(self) -> None: ... + else: + d: int + e = 56 + def f(self) -> None: ... + +# TODO: actually a frozenset +reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["d"], Literal["e"], Literal["f"]] ``` ## Invalid calls to `get_protocol_members()` @@ -639,14 +658,14 @@ class LotsOfBindings(Protocol): case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared) ... -# TODO: all bindings in the above class should be understood as protocol members, -# even those that we complained about with a diagnostic -reveal_type(get_protocol_members(LotsOfBindings)) # revealed: @Todo(specialized non-generic class) +# TODO: actually a frozenset +# revealed: tuple[Literal["Nested"], Literal["NestedProtocol"], Literal["a"], Literal["b"], Literal["c"], Literal["d"], Literal["e"], Literal["f"], Literal["g"], Literal["h"], Literal["i"], Literal["j"], Literal["k"], Literal["l"]] +reveal_type(get_protocol_members(LotsOfBindings)) ``` Attribute members are allowed to have assignments in methods on the protocol class, just like -non-protocol classes. Unlike other classes, however, *implicit* instance attributes -- those that -are not declared in the class body -- are not allowed: +non-protocol classes. Unlike other classes, however, instance attributes that are not declared in +the class body are disallowed: ```py class Foo(Protocol): @@ -655,11 +674,18 @@ class Foo(Protocol): def __init__(self) -> None: self.x = 42 # fine - self.a = 56 # error + self.a = 56 # TODO: should emit diagnostic + self.b: int = 128 # TODO: should emit diagnostic def non_init_method(self) -> None: self.y = 64 # fine - self.b = 72 # error + self.c = 72 # TODO: should emit diagnostic + +# Note: the list of members does not include `a`, `b` or `c`, +# as none of these attributes is declared in the class body. +# +# TODO: actually a frozenset +reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["non_init_method"], Literal["x"], Literal["y"]] ``` If a protocol has 0 members, then all other types are assignable to it, and all fully static types diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index ed540c380a..fe7f4ed0db 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -437,6 +437,15 @@ impl<'db> UseDefMap<'db> { .map(|symbol_id| (symbol_id, self.public_declarations(symbol_id))) } + pub(crate) fn all_public_bindings<'map>( + &'map self, + ) -> impl Iterator)> + 'map + { + (0..self.public_symbols.len()) + .map(ScopedSymbolId::from_usize) + .map(|symbol_id| (symbol_id, self.public_bindings(symbol_id))) + } + /// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`. pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool { !self diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 4a2b5a7bfd..4a36ab726c 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -20,8 +20,8 @@ use crate::types::generics::{Specialization, SpecializationBuilder}; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass, - KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, - WrapperDescriptorKind, + KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, TupleType, + UnionType, WrapperDescriptorKind, }; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; use ruff_python_ast as ast; @@ -561,6 +561,22 @@ impl<'db> Bindings<'db> { } } + Some(KnownFunction::GetProtocolMembers) => { + if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() { + if let Some(protocol_class) = class.into_protocol_class(db) { + // TODO: actually a frozenset at runtime (requires support for legacy generic classes) + overload.set_return_type(Type::Tuple(TupleType::new( + db, + protocol_class + .protocol_members(db) + .iter() + .map(|member| Type::string_literal(db, member)) + .collect::]>>(), + ))); + } + } + } + Some(KnownFunction::Overload) => { // TODO: This can be removed once we understand legacy generics because the // typeshed definition for `typing.overload` is an identity function. diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 85aa50aa0d..82f88a35de 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -1,4 +1,5 @@ use std::hash::BuildHasherDefault; +use std::ops::Deref; use std::sync::{LazyLock, Mutex}; use super::{ @@ -13,6 +14,7 @@ use crate::types::signatures::{Parameter, Parameters}; use crate::types::{ CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature, }; +use crate::FxOrderSet; use crate::{ module_resolver::file_to_module, semantic_index::{ @@ -1710,6 +1712,11 @@ impl<'db> ClassLiteralType<'db> { Some(InheritanceCycle::Inherited) } } + + /// Returns `Some` if this is a protocol class, `None` otherwise. + pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option> { + self.is_protocol(db).then_some(ProtocolClassLiteral(self)) + } } impl<'db> From> for Type<'db> { @@ -1721,6 +1728,125 @@ impl<'db> From> for Type<'db> { } } +/// Representation of a single `Protocol` class definition. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(super) struct ProtocolClassLiteral<'db>(ClassLiteralType<'db>); + +impl<'db> ProtocolClassLiteral<'db> { + /// Returns the protocol members of this class. + /// + /// A protocol's members define the interface declared by the protocol. + /// They therefore determine how the protocol should behave with regards to + /// assignability and subtyping. + /// + /// The list of members consists of all bindings and declarations that take place + /// in the protocol's class body, except for a list of excluded attributes which should + /// not be taken into account. (This list includes `__init__` and `__new__`, which can + /// legally be defined on protocol classes but do not constitute protocol members.) + /// + /// It is illegal for a protocol class to have any instance attributes that are not declared + /// in the protocol's class body. If any are assigned to, they are not taken into account in + /// the protocol's list of members. + pub(super) fn protocol_members(self, db: &'db dyn Db) -> &'db ordermap::set::Slice { + /// The list of excluded members is subject to change between Python versions, + /// especially for dunders, but it probably doesn't matter *too* much if this + /// list goes out of date. It's up to date as of Python commit 87b1ea016b1454b1e83b9113fa9435849b7743aa + /// () + fn excluded_from_proto_members(member: &str) -> bool { + matches!( + member, + "_is_protocol" + | "__non_callable_proto_members__" + | "__static_attributes__" + | "__orig_class__" + | "__match_args__" + | "__weakref__" + | "__doc__" + | "__parameters__" + | "__module__" + | "_MutableMapping__marker" + | "__slots__" + | "__dict__" + | "__new__" + | "__protocol_attrs__" + | "__init__" + | "__class_getitem__" + | "__firstlineno__" + | "__abstractmethods__" + | "__orig_bases__" + | "_is_runtime_protocol" + | "__subclasshook__" + | "__type_params__" + | "__annotations__" + | "__annotate__" + | "__annotate_func__" + | "__annotations_cache__" + ) + } + + #[salsa::tracked(return_ref)] + fn cached_protocol_members<'db>( + db: &'db dyn Db, + class: ClassLiteralType<'db>, + ) -> Box> { + let mut members = FxOrderSet::default(); + + for parent_protocol in class + .iter_mro(db, None) + .filter_map(ClassBase::into_class) + .filter_map(|class| class.class_literal(db).0.into_protocol_class(db)) + { + let parent_scope = parent_protocol.body_scope(db); + let use_def_map = use_def_map(db, parent_scope); + let symbol_table = symbol_table(db, parent_scope); + + members.extend( + use_def_map + .all_public_declarations() + .flat_map(|(symbol_id, declarations)| { + symbol_from_declarations(db, declarations) + .map(|symbol| (symbol_id, symbol)) + }) + .filter_map(|(symbol_id, symbol)| { + symbol.symbol.ignore_possibly_unbound().map(|_| symbol_id) + }) + // Bindings in the class body that are not declared in the class body + // are not valid protocol members, and we plan to emit diagnostics for them + // elsewhere. Invalid or not, however, it's important that we still consider + // them to be protocol members. The implementation of `issubclass()` and + // `isinstance()` for runtime-checkable protocols considers them to be protocol + // members at runtime, and it's important that we accurately understand + // type narrowing that uses `isinstance()` or `issubclass()` with + // runtime-checkable protocols. + .chain(use_def_map.all_public_bindings().filter_map( + |(symbol_id, bindings)| { + symbol_from_bindings(db, bindings) + .ignore_possibly_unbound() + .map(|_| symbol_id) + }, + )) + .map(|symbol_id| symbol_table.symbol(symbol_id).name()) + .filter(|name| !excluded_from_proto_members(name)) + .cloned(), + ); + } + + members.sort(); + members.into_boxed_slice() + } + + cached_protocol_members(db, *self) + } +} + +impl<'db> Deref for ProtocolClassLiteral<'db> { + type Target = ClassLiteralType<'db>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub(super) enum InheritanceCycle { /// The class is cyclically defined and is a participant in the cycle.