From ce1dc21e7ee99eebd5c2a00c2a0c9be4827321d1 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 27 Aug 2025 18:16:15 +0100 Subject: [PATCH] [ty] Fix the inferred interface of specialized generic protocols (#19866) --- .../resources/mdtest/call/overloads.md | 43 +++++++++ .../resources/mdtest/protocols.md | 83 ++++++++++++++++- .../ty_python_semantic/src/types/call/bind.rs | 7 ++ crates/ty_python_semantic/src/types/class.rs | 8 ++ .../src/types/diagnostic.rs | 8 +- .../ty_python_semantic/src/types/function.rs | 2 +- crates/ty_python_semantic/src/types/infer.rs | 9 +- .../ty_python_semantic/src/types/instance.rs | 4 +- .../src/types/protocol_class.rs | 93 +++++++++++-------- 9 files changed, 200 insertions(+), 57 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index c2cc47d1ee..ca34c94b5e 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -1038,6 +1038,49 @@ def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, An reveal_type(f(*(any_any,))) # revealed: Unknown ``` +### `Unknown` passed into an overloaded function annotated with protocols + +`Foo.join()` here has similar annotations to `str.join()` in typeshed: + +`module.pyi`: + +```pyi +from typing_extensions import Iterable, overload, LiteralString, Protocol +from ty_extensions import Unknown, is_assignable_to + +class Foo: + @overload + def join(self, iterable: Iterable[LiteralString], /) -> LiteralString: ... + @overload + def join(self, iterable: Iterable[str], /) -> str: ... 
+``` + +`main.py`: + +```py +from module import Foo +from typing_extensions import LiteralString + +def f(a: Foo, b: list[str], c: list[LiteralString], e): + reveal_type(e) # revealed: Unknown + + # TODO: we should select the second overload here and reveal `str` + # (the incorrect result is due to missing logic in protocol subtyping/assignability) + reveal_type(a.join(b)) # revealed: LiteralString + + reveal_type(a.join(c)) # revealed: LiteralString + + # since both overloads match and they have return types that are not equivalent, + # step (5) of the overload evaluation algorithm says we must evaluate the result of the + # call as `Unknown`. + # + # Note: although the spec does not state as such (since intersections in general are not + # specified currently), `(str | LiteralString) & Unknown` might also be a reasonable type + # here (the union of all overload returns, intersected with `Unknown`) -- here that would + # simplify to `str & Unknown`. + reveal_type(a.join(e)) # revealed: Unknown +``` + ### Multiple arguments `overloaded.pyi`: diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index 42971c61f4..d0b72df047 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -95,6 +95,20 @@ class NotAProtocol: ... reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False] ``` +Note, however, that `is_protocol` returns `False` at runtime for specializations of generic +protocols. 
We still consider these to be "protocol classes" internally, regardless: + +```py +class MyGenericProtocol[T](Protocol): + x: T + +reveal_type(is_protocol(MyGenericProtocol)) # revealed: Literal[True] + +# We still consider this a protocol class internally, +# but the inferred type of the call here reflects the result at runtime: +reveal_type(is_protocol(MyGenericProtocol[int])) # revealed: Literal[False] +``` + A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs indicate that the argument passed in must be an instance of `type`. @@ -397,24 +411,38 @@ To see the kinds and types of the protocol members, you can use the debugging ai ```py from ty_extensions import reveal_protocol_interface -from typing import SupportsIndex, SupportsAbs, ClassVar +from typing import SupportsIndex, SupportsAbs, ClassVar, Iterator # error: [revealed-type] "Revealed protocol interface: `{"method_member": MethodMember(`(self) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self) -> str` }, "z": PropertyMember { getter: `def z(self) -> int`, setter: `def z(self, z: int) -> None` }}`" reveal_protocol_interface(Foo) # error: [revealed-type] "Revealed protocol interface: `{"__index__": MethodMember(`(self) -> int`)}`" reveal_protocol_interface(SupportsIndex) -# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> _T_co@SupportsAbs`)}`" +# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> Unknown`)}`" reveal_protocol_interface(SupportsAbs) +# error: [revealed-type] "Revealed protocol interface: `{"__iter__": MethodMember(`(self) -> Iterator[Unknown]`), "__next__": MethodMember(`(self) -> Unknown`)}`" +reveal_protocol_interface(Iterator) # error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`" reveal_protocol_interface(int) # error: 
[invalid-argument-type] "Argument to function `reveal_protocol_interface` is incorrect: Expected `type`, found `Literal["foo"]`" reveal_protocol_interface("foo") +``` -# TODO: this should be a `revealed-type` diagnostic rather than `invalid-argument-type`, and it should reveal `{"__abs__": MethodMember(`(self) -> int`)}` for the protocol interface -# -# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`" +Similar to the way that `typing.is_protocol` returns `False` at runtime for all generic aliases, +`typing.get_protocol_members` raises an exception at runtime if you pass it a generic alias, so we +do not implement any special handling for generic aliases passed to the function. +`ty_extensions.reveal_protocol_interface` can be used on both, however: + +```py +# TODO: these fail at runtime, but we don't emit `[invalid-argument-type]` diagnostics +# currently due to https://github.com/astral-sh/ty/issues/116 +reveal_type(get_protocol_members(SupportsAbs[int])) # revealed: frozenset[str] +reveal_type(get_protocol_members(Iterator[int])) # revealed: frozenset[str] + +# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> int`)}`" reveal_protocol_interface(SupportsAbs[int]) +# error: [revealed-type] "Revealed protocol interface: `{"__iter__": MethodMember(`(self) -> Iterator[int]`), "__next__": MethodMember(`(self) -> int`)}`" +reveal_protocol_interface(Iterator[int]) class BaseProto(Protocol): def member(self) -> int: ... 
@@ -1032,6 +1060,11 @@ class A(Protocol): ## Equivalence of protocols +```toml +[environment] +python-version = "3.12" +``` + Two protocols are considered equivalent types if they specify the same interface, even if they have different names: @@ -1080,6 +1113,46 @@ static_assert(is_equivalent_to(UnionProto1, UnionProto2)) static_assert(is_equivalent_to(UnionProto1 | A | B, B | UnionProto2 | A)) ``` +Different generic protocols with equivalent specializations can be equivalent, but generic protocols +with different specializations are not considered equivalent: + +```py +from typing import TypeVar + +S = TypeVar("S") + +class NonGenericProto1(Protocol): + x: int + y: str + +class NonGenericProto2(Protocol): + y: str + x: int + +class Nominal1: ... +class Nominal2: ... + +class GenericProto[T](Protocol): + x: T + +class LegacyGenericProto(Protocol[S]): + x: S + +static_assert(is_equivalent_to(GenericProto[int], LegacyGenericProto[int])) +static_assert(is_equivalent_to(GenericProto[NonGenericProto1], LegacyGenericProto[NonGenericProto2])) + +static_assert( + is_equivalent_to( + GenericProto[NonGenericProto1 | Nominal1 | Nominal2], LegacyGenericProto[Nominal2 | Nominal1 | NonGenericProto2] + ) +) + +static_assert(not is_equivalent_to(GenericProto[str], GenericProto[int])) +static_assert(not is_equivalent_to(GenericProto[str], LegacyGenericProto[int])) +static_assert(not is_equivalent_to(GenericProto, GenericProto[int])) +static_assert(not is_equivalent_to(LegacyGenericProto, LegacyGenericProto[int])) +``` + ## Intersections of protocols An intersection of two protocol types `X` and `Y` is equivalent to a protocol type `Z` that inherits diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index da1d962f99..4933e515e6 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -748,6 +748,10 @@ impl<'db> Bindings<'db> { Some(KnownFunction::IsProtocol) 
=> { if let [Some(ty)] = overload.parameter_types() { + // We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol` + // would return `True` for the given type. Internally we consider `SupportsAbs[int]` to + // be a "(specialised) protocol class", but `typing.is_protocol(SupportsAbs[int])` returns + // `False` at runtime, so we do not set the return type to `Literal[True]` in this case. overload.set_return_type(Type::BooleanLiteral( ty.into_class_literal() .is_some_and(|class| class.is_protocol(db)), @@ -756,6 +760,9 @@ impl<'db> Bindings<'db> { } Some(KnownFunction::GetProtocolMembers) => { + // Similarly to `is_protocol`, we only evaluate this to a frozenset of literal strings if a + // class-literal is passed in, not if a generic alias is passed in, to emulate the behaviour + // of `typing.get_protocol_members` at runtime. if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() { if let Some(protocol_class) = class.into_protocol_class(db) { let member_names = protocol_class diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index d29a0afdee..a99a2af6ca 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1198,6 +1198,14 @@ impl<'db> ClassType<'db> { } } } + + pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool { + self.class_literal(db).0.is_protocol(db) + } + + pub(super) fn header_span(self, db: &'db dyn Db) -> Span { + self.class_literal(db).0.header_span(db) + } } impl<'db> From> for ClassType<'db> { diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index 5377614e5a..44981271c1 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -20,7 +20,7 @@ use crate::types::string_annotation::{ use crate::types::{ DynamicType, LintDiagnosticGuard, Protocol, ProtocolInstanceType,
SubclassOfInner, binding_type, }; -use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClassLiteral}; +use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClass}; use crate::util::diagnostics::format_enumeration; use crate::{Db, FxIndexMap, FxOrderMap, Module, ModuleName, Program, declare_lint}; use itertools::Itertools; @@ -2467,7 +2467,7 @@ pub(crate) fn add_type_expression_reference_link<'db, 'ctx>( pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( context: &InferContext, call: &ast::ExprCall, - protocol: ProtocolClassLiteral, + protocol: ProtocolClass, function: KnownFunction, ) { let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else { @@ -2504,7 +2504,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( pub(crate) fn report_attempted_protocol_instantiation( context: &InferContext, call: &ast::ExprCall, - protocol: ProtocolClassLiteral, + protocol: ProtocolClass, ) { let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, call) else { return; @@ -2529,7 +2529,7 @@ pub(crate) fn report_attempted_protocol_instantiation( pub(crate) fn report_undeclared_protocol_member( context: &InferContext, definition: Definition, - protocol_class: ProtocolClassLiteral, + protocol_class: ProtocolClass, class_symbol_table: &PlaceTable, ) { /// We want to avoid suggesting an annotation for e.g. 
`x = None`, diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 5fc02fc0c3..8840e00413 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1433,7 +1433,7 @@ impl KnownFunction { return; }; let Some(protocol_class) = param_type - .into_class_literal() + .to_class_type(db) .and_then(|class| class.into_protocol_class(db)) else { report_bad_argument_to_protocol_interface( diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 6413bbefaa..67fb854f17 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -1216,8 +1216,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if is_protocol - && !(base_class.class_literal(self.db()).0.is_protocol(self.db()) - || base_class.is_known(self.db(), KnownClass::Object)) + && !(base_class.is_protocol(self.db()) || base_class.is_object(self.db())) { if let Some(builder) = self .context @@ -6249,11 +6248,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // subclasses of the protocol to be passed to parameters that accept `type[SomeProtocol]`. // . 
if !callable_type.is_subclass_of() { - if let Some(protocol) = class - .class_literal(self.db()) - .0 - .into_protocol_class(self.db()) - { + if let Some(protocol) = class.into_protocol_class(self.db()) { report_attempted_protocol_instantiation( &self.context, call_expression, diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 1a403e11ea..fb6bff18fc 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -645,10 +645,8 @@ impl<'db> Protocol<'db> { fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { match self { Self::FromClass(class) => class - .class_literal(db) - .0 .into_protocol_class(db) - .expect("Protocol class literal should be a protocol class") + .expect("Class wrapped by `Protocol` should be a protocol class") .interface(db), Self::Synthesized(synthesized) => synthesized.interface(), } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 6bb6a16ac4..ec1e394167 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -9,6 +9,7 @@ use rustc_hash::FxHashMap; use super::TypeVarVariance; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::{SemanticIndex, place_table}; +use crate::types::ClassType; use crate::types::context::InferContext; use crate::types::diagnostic::report_undeclared_protocol_member; use crate::{ @@ -26,16 +27,24 @@ use crate::{ impl<'db> ClassLiteral<'db> { /// Returns `Some` if this is a protocol class, `None` otherwise. 
- pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> { - self.is_protocol(db).then_some(ProtocolClassLiteral(self)) + pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> { + self.is_protocol(db) + .then_some(ProtocolClass(ClassType::NonGeneric(self))) + } +} + +impl<'db> ClassType<'db> { + /// Returns `Some` if this is a protocol class, `None` otherwise. + pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> { + self.is_protocol(db).then_some(ProtocolClass(self)) } } /// Representation of a single `Protocol` class definition. #[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub(super) struct ProtocolClassLiteral<'db>(ClassLiteral<'db>); +pub(super) struct ProtocolClass<'db>(ClassType<'db>); -impl<'db> ProtocolClassLiteral<'db> { +impl<'db> ProtocolClass<'db> { /// Returns the protocol members of this class. /// /// A protocol's members define the interface declared by the protocol. @@ -56,7 +65,9 @@ impl<'db> ProtocolClassLiteral<'db> { } pub(super) fn is_runtime_checkable(self, db: &'db dyn Db) -> bool { - self.known_function_decorators(db) + self.class_literal(db) + .0 + .known_function_decorators(db) .contains(&KnownFunction::RuntimeCheckable) } @@ -66,10 +77,11 @@ impl<'db> ProtocolClassLiteral<'db> { pub(super) fn validate_members(self, context: &InferContext, index: &SemanticIndex<'db>) { let db = context.db(); let interface = self.interface(db); - let class_place_table = index.place_table(self.body_scope(db).file_scope_id(db)); + let body_scope = self.class_literal(db).0.body_scope(db); + let class_place_table = index.place_table(body_scope.file_scope_id(db)); for (symbol_id, mut bindings_iterator) in - use_def_map(db, self.body_scope(db)).all_end_of_scope_symbol_bindings() + use_def_map(db, body_scope).all_end_of_scope_symbol_bindings() { let symbol_name = class_place_table.symbol(symbol_id).name(); @@ -77,27 +89,27 @@ impl<'db> ProtocolClassLiteral<'db> { continue; } - let has_declaration = self - .iter_mro(db,
None) - .filter_map(ClassBase::into_class) - .any(|superclass| { - let superclass_scope = superclass.class_literal(db).0.body_scope(db); - let Some(scoped_symbol_id) = - place_table(db, superclass_scope).symbol_id(symbol_name) - else { - return false; - }; - !place_from_declarations( - db, - index - .use_def_map(superclass_scope.file_scope_id(db)) - .end_of_scope_declarations(ScopedPlaceId::Symbol(scoped_symbol_id)), - ) - .into_place_and_conflicting_declarations() - .0 - .place - .is_unbound() - }); + let has_declaration = + self.iter_mro(db) + .filter_map(ClassBase::into_class) + .any(|superclass| { + let superclass_scope = superclass.class_literal(db).0.body_scope(db); + let Some(scoped_symbol_id) = + place_table(db, superclass_scope).symbol_id(symbol_name) + else { + return false; + }; + !place_from_declarations( + db, + index + .use_def_map(superclass_scope.file_scope_id(db)) + .end_of_scope_declarations(ScopedPlaceId::Symbol(scoped_symbol_id)), + ) + .into_place_and_conflicting_declarations() + .0 + .place + .is_unbound() + }); if has_declaration { continue; @@ -114,8 +126,8 @@ impl<'db> ProtocolClassLiteral<'db> { } } -impl<'db> Deref for ProtocolClassLiteral<'db> { - type Target = ClassLiteral<'db>; +impl<'db> Deref for ProtocolClass<'db> { + type Target = ClassType<'db>; fn deref(&self) -> &Self::Target { &self.0 @@ -622,16 +634,19 @@ impl BoundOnClass { #[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)] fn cached_protocol_interface<'db>( db: &'db dyn Db, - class: ClassLiteral<'db>, + class: ClassType<'db>, ) -> ProtocolInterface<'db> { let mut members = BTreeMap::default(); - for parent_protocol in class - .iter_mro(db, None) + for (parent_protocol, specialization) in class + .iter_mro(db) .filter_map(ClassBase::into_class) - .filter_map(|class| class.class_literal(db).0.into_protocol_class(db)) + .filter_map(|class| { + let (class, specialization) = 
class.class_literal(db); + Some((class.into_protocol_class(db)?, specialization)) + }) { - let parent_scope = parent_protocol.body_scope(db); + let parent_scope = parent_protocol.class_literal(db).0.body_scope(db); let use_def_map = use_def_map(db, parent_scope); let place_table = place_table(db, parent_scope); let mut direct_members = FxHashMap::default(); @@ -676,6 +691,8 @@ fn cached_protocol_interface<'db>( continue; } + let ty = ty.apply_optional_specialization(db, specialization); + let member = match ty { Type::PropertyInstance(property) => ProtocolMemberKind::Property(property), Type::Callable(callable) @@ -702,19 +719,21 @@ fn cached_protocol_interface<'db>( ProtocolInterface::new(db, members) } +// If we use `expect(clippy::trivially_copy_pass_by_ref)` here, +// the lint expectation is unfulfilled on WASM #[allow(clippy::trivially_copy_pass_by_ref)] fn proto_interface_cycle_recover<'db>( _db: &dyn Db, _value: &ProtocolInterface<'db>, _count: u32, - _class: ClassLiteral<'db>, + _class: ClassType<'db>, ) -> salsa::CycleRecoveryAction<ProtocolInterface<'db>> { salsa::CycleRecoveryAction::Iterate } fn proto_interface_cycle_initial<'db>( db: &'db dyn Db, - _class: ClassLiteral<'db>, + _class: ClassType<'db>, ) -> ProtocolInterface<'db> { ProtocolInterface::empty(db) }