diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index 1c49a8c9a7..6e4f4ed0e1 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -382,6 +382,31 @@ class Foo(Protocol): reveal_type(get_protocol_members(Foo)) # revealed: frozenset[Literal["method_member", "x", "y", "z"]] ``` +To see the kinds and types of the protocol members, you can use the debugging aid +`ty_extensions.reveal_protocol_interface`, meanwhile: + +```py +from ty_extensions import reveal_protocol_interface +from typing import SupportsIndex, SupportsAbs + +# error: [revealed-type] "Revealed protocol interface: `{"method_member": MethodMember(`(self) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self) -> str` }, "z": PropertyMember { getter: `def z(self) -> int`, setter: `def z(self, z: int) -> None` }}`" +reveal_protocol_interface(Foo) +# error: [revealed-type] "Revealed protocol interface: `{"__index__": MethodMember(`(self) -> int`)}`" +reveal_protocol_interface(SupportsIndex) +# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> _T_co`)}`" +reveal_protocol_interface(SupportsAbs) + +# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`" +reveal_protocol_interface(int) +# error: [invalid-argument-type] "Argument to function `reveal_protocol_interface` is incorrect: Expected `type`, found `Literal["foo"]`" +reveal_protocol_interface("foo") + +# TODO: this should be a `revealed-type` diagnostic rather than `invalid-argument-type`, and it should reveal `{"__abs__": MethodMember(`(self) -> int`)}` for the protocol interface +# +# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`" +reveal_protocol_interface(SupportsAbs[int]) +``` + Certain special attributes and methods are not considered protocol members at runtime, and should not be considered protocol members by type checkers either: diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index 391c1df289..8226a293e3 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -2251,6 +2251,41 @@ pub(crate) fn report_bad_argument_to_get_protocol_members( diagnostic.info("See https://typing.python.org/en/latest/spec/protocol.html#"); } +pub(crate) fn report_bad_argument_to_protocol_interface( + context: &InferContext, + call: &ast::ExprCall, + param_type: Type, +) { + let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else { + return; + }; + let db = context.db(); + let mut diagnostic = builder.into_diagnostic("Invalid argument to `reveal_protocol_interface`"); + diagnostic + .set_primary_message("Only protocol classes can be passed to `reveal_protocol_interface`"); + + if let Some(class) = param_type.to_class_type(context.db()) { + let mut class_def_diagnostic = SubDiagnostic::new( + SubDiagnosticSeverity::Info, + format_args!( + "`{}` is declared here, but it is not a protocol class:", + class.name(db) + ), + ); + class_def_diagnostic.annotate(Annotation::primary( + class.class_literal(db).0.header_span(db), + )); + diagnostic.sub(class_def_diagnostic); + } + + diagnostic.info( + "A class is only a protocol class if it directly inherits \ + from `typing.Protocol` or `typing_extensions.Protocol`", + ); + // See TODO in `report_bad_argument_to_get_protocol_members` above + diagnostic.info("See https://typing.python.org/en/latest/spec/protocol.html"); +} + pub(crate) fn report_invalid_arguments_to_callable( context: &InferContext, subscript: &ast::ExprSubscript, diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 2c13e03083..9e4225836a 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -68,7 +68,7 @@ use crate::types::call::{Binding, CallArguments}; use crate::types::context::InferContext; use crate::types::diagnostic::{ REDUNDANT_CAST, STATIC_ASSERT_ERROR, TYPE_ASSERTION_FAILURE, - report_bad_argument_to_get_protocol_members, + report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface, report_runtime_check_against_non_runtime_checkable_protocol, }; use crate::types::generics::{GenericContext, walk_generic_context}; @@ -1093,6 +1093,8 @@ pub enum KnownFunction { TopMaterialization, /// `ty_extensions.bottom_materialization` BottomMaterialization, + /// `ty_extensions.reveal_protocol_interface` + RevealProtocolInterface, } impl KnownFunction { @@ -1158,6 +1160,7 @@ impl KnownFunction { | Self::EnumMembers | Self::StaticAssert | Self::HasMember + | Self::RevealProtocolInterface | Self::AllMembers => module.is_ty_extensions(), Self::ImportModule => module.is_importlib(), } @@ -1350,6 +1353,33 @@ impl KnownFunction { report_bad_argument_to_get_protocol_members(context, call_expression, *class); } + KnownFunction::RevealProtocolInterface => { + let [Some(param_type)] = parameter_types else { + return; + }; + let Some(protocol_class) = param_type + .into_class_literal() + .and_then(|class| class.into_protocol_class(db)) + else { + report_bad_argument_to_protocol_interface( + context, + call_expression, + *param_type, + ); + return; + }; + if let Some(builder) = + context.report_diagnostic(DiagnosticId::RevealedType, Severity::Info) + { + let mut diag = builder.into_diagnostic("Revealed protocol interface"); + let span = context.span(&call_expression.arguments.args[0]); + diag.annotate(Annotation::primary(span).message(format_args!( + "`{}`", + protocol_class.interface(db).display(db) + ))); + } + } + KnownFunction::IsInstance | KnownFunction::IsSubclass => { let [Some(first_arg), Some(Type::ClassLiteral(class))] = parameter_types else { return; @@ -1463,6 +1493,7 @@ pub(crate) mod tests { | KnownFunction::TopMaterialization | KnownFunction::BottomMaterialization | KnownFunction::HasMember + | KnownFunction::RevealProtocolInterface | KnownFunction::AllMembers => KnownModule::TyExtensions, KnownFunction::ImportModule => KnownModule::ImportLib, diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 845a3b8f47..6525be8181 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -1,3 +1,4 @@ +use std::fmt::Write; use std::{collections::BTreeMap, ops::Deref}; use itertools::Itertools; @@ -215,6 +216,31 @@ impl<'db> ProtocolInterface<'db> { data.find_legacy_typevars(db, typevars); } } + + pub(super) fn display(self, db: &'db dyn Db) -> impl std::fmt::Display { + struct ProtocolInterfaceDisplay<'db> { + db: &'db dyn Db, + interface: ProtocolInterface<'db>, + } + + impl std::fmt::Display for ProtocolInterfaceDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_char('{')?; + for (i, (name, data)) in self.interface.inner(self.db).iter().enumerate() { + write!(f, "\"{name}\": {data}", data = data.display(self.db))?; + if i < self.interface.inner(self.db).len() - 1 { + f.write_str(", ")?; + } + } + f.write_char('}') + } + } + + ProtocolInterfaceDisplay { + db, + interface: self, + } + } } #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)] @@ -256,6 +282,41 @@ impl<'db> ProtocolMemberData<'db> { qualifiers: self.qualifiers, } } + + fn display(&self, db: &'db dyn Db) -> impl std::fmt::Display { + struct ProtocolMemberDataDisplay<'db> { + db: &'db dyn Db, + data: ProtocolMemberKind<'db>, + } + + impl std::fmt::Display for ProtocolMemberDataDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.data { + ProtocolMemberKind::Method(callable) => { + write!(f, "MethodMember(`{}`)", callable.display(self.db)) + } + ProtocolMemberKind::Property(property) => { + let mut d = f.debug_struct("PropertyMember"); + if let Some(getter) = property.getter(self.db) { + d.field("getter", &format_args!("`{}`", &getter.display(self.db))); + } + if let Some(setter) = property.setter(self.db) { + d.field("setter", &format_args!("`{}`", &setter.display(self.db))); + } + d.finish() + } + ProtocolMemberKind::Other(ty) => { + write!(f, "AttributeMember(`{}`)", ty.display(self.db)) + } + } + } + } + + ProtocolMemberDataDisplay { + db, + data: self.kind, + } + } } #[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)] diff --git a/crates/ty_vendored/ty_extensions/ty_extensions.pyi b/crates/ty_vendored/ty_extensions/ty_extensions.pyi index e14b3b2c21..4dd041762f 100644 --- a/crates/ty_vendored/ty_extensions/ty_extensions.pyi +++ b/crates/ty_vendored/ty_extensions/ty_extensions.pyi @@ -64,3 +64,8 @@ def all_members(obj: Any) -> tuple[str, ...]: ... # Returns `True` if the given object has a member with the given name. def has_member(obj: Any, name: str) -> bool: ... + +# Passing a protocol type to this function will cause ty to emit an info-level +# diagnostic describing the protocol's interface. Passing a non-protocol type +# will cause ty to emit an error diagnostic. +def reveal_protocol_interface(protocol: type) -> None: ...