[ty] Fix the inferred interface of specialized generic protocols (#19866)

This commit is contained in:
Alex Waygood 2025-08-27 18:16:15 +01:00 committed by GitHub
parent 7d0c8e045c
commit ce1dc21e7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 200 additions and 57 deletions

View file

@ -1038,6 +1038,49 @@ def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, An
reveal_type(f(*(any_any,))) # revealed: Unknown reveal_type(f(*(any_any,))) # revealed: Unknown
``` ```
### `Unknown` passed into an overloaded function annotated with protocols
`Foo.join()` here has similar annotations to `str.join()` in typeshed:
`module.pyi`:
```pyi
from typing_extensions import Iterable, overload, LiteralString, Protocol
from ty_extensions import Unknown, is_assignable_to
class Foo:
@overload
def join(self, iterable: Iterable[LiteralString], /) -> LiteralString: ...
@overload
def join(self, iterable: Iterable[str], /) -> str: ...
```
`main.py`:
```py
from module import Foo
from typing_extensions import LiteralString
def f(a: Foo, b: list[str], c: list[LiteralString], e):
reveal_type(e) # revealed: Unknown
# TODO: we should select the second overload here and reveal `str`
# (the incorrect result is due to missing logic in protocol subtyping/assignability)
reveal_type(a.join(b)) # revealed: LiteralString
reveal_type(a.join(c)) # revealed: LiteralString
# since both overloads match and they have return types that are not equivalent,
# step (5) of the overload evaluation algorithm says we must evaluate the result of the
# call as `Unknown`.
#
# Note: although the spec does not say so explicitly (since intersections in general are not
# specified currently), `(str | LiteralString) & Unknown` might also be a reasonable type
# here (the union of all overload returns, intersected with `Unknown`) -- here that would
# simplify to `str & Unknown`.
reveal_type(a.join(e)) # revealed: Unknown
```
### Multiple arguments ### Multiple arguments
`overloaded.pyi`: `overloaded.pyi`:

View file

@ -95,6 +95,20 @@ class NotAProtocol: ...
reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False] reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False]
``` ```
Note, however, that `is_protocol` returns `False` at runtime for specializations of generic
protocols. We still consider these to be "protocol classes" internally, regardless:
```py
class MyGenericProtocol[T](Protocol):
x: T
reveal_type(is_protocol(MyGenericProtocol)) # revealed: Literal[True]
# We still consider this a protocol class internally,
# but the inferred type of the call here reflects the result at runtime:
reveal_type(is_protocol(MyGenericProtocol[int])) # revealed: Literal[False]
```
A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs
indicate that the argument passed in must be an instance of `type`. indicate that the argument passed in must be an instance of `type`.
@ -397,24 +411,38 @@ To see the kinds and types of the protocol members, you can use the debugging ai
```py ```py
from ty_extensions import reveal_protocol_interface from ty_extensions import reveal_protocol_interface
from typing import SupportsIndex, SupportsAbs, ClassVar from typing import SupportsIndex, SupportsAbs, ClassVar, Iterator
# error: [revealed-type] "Revealed protocol interface: `{"method_member": MethodMember(`(self) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self) -> str` }, "z": PropertyMember { getter: `def z(self) -> int`, setter: `def z(self, z: int) -> None` }}`" # error: [revealed-type] "Revealed protocol interface: `{"method_member": MethodMember(`(self) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self) -> str` }, "z": PropertyMember { getter: `def z(self) -> int`, setter: `def z(self, z: int) -> None` }}`"
reveal_protocol_interface(Foo) reveal_protocol_interface(Foo)
# error: [revealed-type] "Revealed protocol interface: `{"__index__": MethodMember(`(self) -> int`)}`" # error: [revealed-type] "Revealed protocol interface: `{"__index__": MethodMember(`(self) -> int`)}`"
reveal_protocol_interface(SupportsIndex) reveal_protocol_interface(SupportsIndex)
# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> _T_co@SupportsAbs`)}`" # error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> Unknown`)}`"
reveal_protocol_interface(SupportsAbs) reveal_protocol_interface(SupportsAbs)
# error: [revealed-type] "Revealed protocol interface: `{"__iter__": MethodMember(`(self) -> Iterator[Unknown]`), "__next__": MethodMember(`(self) -> Unknown`)}`"
reveal_protocol_interface(Iterator)
# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`" # error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`"
reveal_protocol_interface(int) reveal_protocol_interface(int)
# error: [invalid-argument-type] "Argument to function `reveal_protocol_interface` is incorrect: Expected `type`, found `Literal["foo"]`" # error: [invalid-argument-type] "Argument to function `reveal_protocol_interface` is incorrect: Expected `type`, found `Literal["foo"]`"
reveal_protocol_interface("foo") reveal_protocol_interface("foo")
```
# TODO: this should be a `revealed-type` diagnostic rather than `invalid-argument-type`, and it should reveal `{"__abs__": MethodMember(`(self) -> int`)}` for the protocol interface Similar to the way that `typing.is_protocol` returns `False` at runtime for all generic aliases,
# `typing.get_protocol_members` raises an exception at runtime if you pass it a generic alias, so we
# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`" do not implement any special handling for generic aliases passed to the function.
`ty_extensions.reveal_protocol_interface` can be used on both, however:
```py
# TODO: these fail at runtime, but we don't emit `[invalid-argument-type]` diagnostics
# currently due to https://github.com/astral-sh/ty/issues/116
reveal_type(get_protocol_members(SupportsAbs[int])) # revealed: frozenset[str]
reveal_type(get_protocol_members(Iterator[int])) # revealed: frozenset[str]
# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> int`)}`"
reveal_protocol_interface(SupportsAbs[int]) reveal_protocol_interface(SupportsAbs[int])
# error: [revealed-type] "Revealed protocol interface: `{"__iter__": MethodMember(`(self) -> Iterator[int]`), "__next__": MethodMember(`(self) -> int`)}`"
reveal_protocol_interface(Iterator[int])
class BaseProto(Protocol): class BaseProto(Protocol):
def member(self) -> int: ... def member(self) -> int: ...
@ -1032,6 +1060,11 @@ class A(Protocol):
## Equivalence of protocols ## Equivalence of protocols
```toml
[environment]
python-version = "3.12"
```
Two protocols are considered equivalent types if they specify the same interface, even if they have Two protocols are considered equivalent types if they specify the same interface, even if they have
different names: different names:
@ -1080,6 +1113,46 @@ static_assert(is_equivalent_to(UnionProto1, UnionProto2))
static_assert(is_equivalent_to(UnionProto1 | A | B, B | UnionProto2 | A)) static_assert(is_equivalent_to(UnionProto1 | A | B, B | UnionProto2 | A))
``` ```
Different generic protocols with equivalent specializations can be equivalent, but generic protocols
with different specializations are not considered equivalent:
```py
from typing import TypeVar
S = TypeVar("S")
class NonGenericProto1(Protocol):
x: int
y: str
class NonGenericProto2(Protocol):
y: str
x: int
class Nominal1: ...
class Nominal2: ...
class GenericProto[T](Protocol):
x: T
class LegacyGenericProto(Protocol[S]):
x: S
static_assert(is_equivalent_to(GenericProto[int], LegacyGenericProto[int]))
static_assert(is_equivalent_to(GenericProto[NonGenericProto1], LegacyGenericProto[NonGenericProto2]))
static_assert(
is_equivalent_to(
GenericProto[NonGenericProto1 | Nominal1 | Nominal2], LegacyGenericProto[Nominal2 | Nominal1 | NonGenericProto2]
)
)
static_assert(not is_equivalent_to(GenericProto[str], GenericProto[int]))
static_assert(not is_equivalent_to(GenericProto[str], LegacyGenericProto[int]))
static_assert(not is_equivalent_to(GenericProto, GenericProto[int]))
static_assert(not is_equivalent_to(LegacyGenericProto, LegacyGenericProto[int]))
```
## Intersections of protocols ## Intersections of protocols
An intersection of two protocol types `X` and `Y` is equivalent to a protocol type `Z` that inherits An intersection of two protocol types `X` and `Y` is equivalent to a protocol type `Z` that inherits

View file

@ -748,6 +748,10 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsProtocol) => { Some(KnownFunction::IsProtocol) => {
if let [Some(ty)] = overload.parameter_types() { if let [Some(ty)] = overload.parameter_types() {
// We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol`
// would return `True` for the given type. Internally we consider `SupportsAbs[int]` to
// be a "(specialised) protocol class", but `typing.is_protocol(SupportsAbs[int])` returns
// `False` at runtime, so we do not set the return type to `Literal[True]` in this case.
overload.set_return_type(Type::BooleanLiteral( overload.set_return_type(Type::BooleanLiteral(
ty.into_class_literal() ty.into_class_literal()
.is_some_and(|class| class.is_protocol(db)), .is_some_and(|class| class.is_protocol(db)),
@ -756,6 +760,9 @@ impl<'db> Bindings<'db> {
} }
Some(KnownFunction::GetProtocolMembers) => { Some(KnownFunction::GetProtocolMembers) => {
// Similarly to `is_protocol`, we only evaluate this to a frozenset of literal strings if a
// class-literal is passed in, not if a generic alias is passed in, to emulate the behaviour
// of `typing.get_protocol_members` at runtime.
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() { if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
if let Some(protocol_class) = class.into_protocol_class(db) { if let Some(protocol_class) = class.into_protocol_class(db) {
let member_names = protocol_class let member_names = protocol_class

View file

@ -1198,6 +1198,14 @@ impl<'db> ClassType<'db> {
} }
} }
} }
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
self.class_literal(db).0.is_protocol(db)
}
pub(super) fn header_span(self, db: &'db dyn Db) -> Span {
self.class_literal(db).0.header_span(db)
}
} }
impl<'db> From<GenericAlias<'db>> for ClassType<'db> { impl<'db> From<GenericAlias<'db>> for ClassType<'db> {

View file

@ -20,7 +20,7 @@ use crate::types::string_annotation::{
use crate::types::{ use crate::types::{
DynamicType, LintDiagnosticGuard, Protocol, ProtocolInstanceType, SubclassOfInner, binding_type, DynamicType, LintDiagnosticGuard, Protocol, ProtocolInstanceType, SubclassOfInner, binding_type,
}; };
use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClassLiteral}; use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClass};
use crate::util::diagnostics::format_enumeration; use crate::util::diagnostics::format_enumeration;
use crate::{Db, FxIndexMap, FxOrderMap, Module, ModuleName, Program, declare_lint}; use crate::{Db, FxIndexMap, FxOrderMap, Module, ModuleName, Program, declare_lint};
use itertools::Itertools; use itertools::Itertools;
@ -2467,7 +2467,7 @@ pub(crate) fn add_type_expression_reference_link<'db, 'ctx>(
pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
context: &InferContext, context: &InferContext,
call: &ast::ExprCall, call: &ast::ExprCall,
protocol: ProtocolClassLiteral, protocol: ProtocolClass,
function: KnownFunction, function: KnownFunction,
) { ) {
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else { let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else {
@ -2504,7 +2504,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
pub(crate) fn report_attempted_protocol_instantiation( pub(crate) fn report_attempted_protocol_instantiation(
context: &InferContext, context: &InferContext,
call: &ast::ExprCall, call: &ast::ExprCall,
protocol: ProtocolClassLiteral, protocol: ProtocolClass,
) { ) {
let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, call) else { let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, call) else {
return; return;
@ -2529,7 +2529,7 @@ pub(crate) fn report_attempted_protocol_instantiation(
pub(crate) fn report_undeclared_protocol_member( pub(crate) fn report_undeclared_protocol_member(
context: &InferContext, context: &InferContext,
definition: Definition, definition: Definition,
protocol_class: ProtocolClassLiteral, protocol_class: ProtocolClass,
class_symbol_table: &PlaceTable, class_symbol_table: &PlaceTable,
) { ) {
/// We want to avoid suggesting an annotation for e.g. `x = None`, /// We want to avoid suggesting an annotation for e.g. `x = None`,

View file

@ -1433,7 +1433,7 @@ impl KnownFunction {
return; return;
}; };
let Some(protocol_class) = param_type let Some(protocol_class) = param_type
.into_class_literal() .to_class_type(db)
.and_then(|class| class.into_protocol_class(db)) .and_then(|class| class.into_protocol_class(db))
else { else {
report_bad_argument_to_protocol_interface( report_bad_argument_to_protocol_interface(

View file

@ -1216,8 +1216,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
if is_protocol if is_protocol
&& !(base_class.class_literal(self.db()).0.is_protocol(self.db()) && !(base_class.is_protocol(self.db()) || base_class.is_object(self.db()))
|| base_class.is_known(self.db(), KnownClass::Object))
{ {
if let Some(builder) = self if let Some(builder) = self
.context .context
@ -6249,11 +6248,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// subclasses of the protocol to be passed to parameters that accept `type[SomeProtocol]`. // subclasses of the protocol to be passed to parameters that accept `type[SomeProtocol]`.
// <https://typing.python.org/en/latest/spec/protocol.html#type-and-class-objects-vs-protocols>. // <https://typing.python.org/en/latest/spec/protocol.html#type-and-class-objects-vs-protocols>.
if !callable_type.is_subclass_of() { if !callable_type.is_subclass_of() {
if let Some(protocol) = class if let Some(protocol) = class.into_protocol_class(self.db()) {
.class_literal(self.db())
.0
.into_protocol_class(self.db())
{
report_attempted_protocol_instantiation( report_attempted_protocol_instantiation(
&self.context, &self.context,
call_expression, call_expression,

View file

@ -645,10 +645,8 @@ impl<'db> Protocol<'db> {
fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
match self { match self {
Self::FromClass(class) => class Self::FromClass(class) => class
.class_literal(db)
.0
.into_protocol_class(db) .into_protocol_class(db)
.expect("Protocol class literal should be a protocol class") .expect("Class wrapped by `Protocol` should be a protocol class")
.interface(db), .interface(db),
Self::Synthesized(synthesized) => synthesized.interface(), Self::Synthesized(synthesized) => synthesized.interface(),
} }

View file

@ -9,6 +9,7 @@ use rustc_hash::FxHashMap;
use super::TypeVarVariance; use super::TypeVarVariance;
use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::place::ScopedPlaceId;
use crate::semantic_index::{SemanticIndex, place_table}; use crate::semantic_index::{SemanticIndex, place_table};
use crate::types::ClassType;
use crate::types::context::InferContext; use crate::types::context::InferContext;
use crate::types::diagnostic::report_undeclared_protocol_member; use crate::types::diagnostic::report_undeclared_protocol_member;
use crate::{ use crate::{
@ -26,16 +27,24 @@ use crate::{
impl<'db> ClassLiteral<'db> { impl<'db> ClassLiteral<'db> {
/// Returns `Some` if this is a protocol class, `None` otherwise. /// Returns `Some` if this is a protocol class, `None` otherwise.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> { pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
self.is_protocol(db).then_some(ProtocolClassLiteral(self)) self.is_protocol(db)
.then_some(ProtocolClass(ClassType::NonGeneric(self)))
}
}
impl<'db> ClassType<'db> {
/// Returns `Some` if this is a protocol class, `None` otherwise.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
self.is_protocol(db).then_some(ProtocolClass(self))
} }
} }
/// Representation of a single `Protocol` class definition. /// Representation of a single `Protocol` class definition.
#[derive(Debug, Copy, Clone, PartialEq, Eq)] #[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteral<'db>); pub(super) struct ProtocolClass<'db>(ClassType<'db>);
impl<'db> ProtocolClassLiteral<'db> { impl<'db> ProtocolClass<'db> {
/// Returns the protocol members of this class. /// Returns the protocol members of this class.
/// ///
/// A protocol's members define the interface declared by the protocol. /// A protocol's members define the interface declared by the protocol.
@ -56,7 +65,9 @@ impl<'db> ProtocolClassLiteral<'db> {
} }
pub(super) fn is_runtime_checkable(self, db: &'db dyn Db) -> bool { pub(super) fn is_runtime_checkable(self, db: &'db dyn Db) -> bool {
self.known_function_decorators(db) self.class_literal(db)
.0
.known_function_decorators(db)
.contains(&KnownFunction::RuntimeCheckable) .contains(&KnownFunction::RuntimeCheckable)
} }
@ -66,10 +77,11 @@ impl<'db> ProtocolClassLiteral<'db> {
pub(super) fn validate_members(self, context: &InferContext, index: &SemanticIndex<'db>) { pub(super) fn validate_members(self, context: &InferContext, index: &SemanticIndex<'db>) {
let db = context.db(); let db = context.db();
let interface = self.interface(db); let interface = self.interface(db);
let class_place_table = index.place_table(self.body_scope(db).file_scope_id(db)); let body_scope = self.class_literal(db).0.body_scope(db);
let class_place_table = index.place_table(body_scope.file_scope_id(db));
for (symbol_id, mut bindings_iterator) in for (symbol_id, mut bindings_iterator) in
use_def_map(db, self.body_scope(db)).all_end_of_scope_symbol_bindings() use_def_map(db, body_scope).all_end_of_scope_symbol_bindings()
{ {
let symbol_name = class_place_table.symbol(symbol_id).name(); let symbol_name = class_place_table.symbol(symbol_id).name();
@ -77,8 +89,8 @@ impl<'db> ProtocolClassLiteral<'db> {
continue; continue;
} }
let has_declaration = self let has_declaration =
.iter_mro(db, None) self.iter_mro(db)
.filter_map(ClassBase::into_class) .filter_map(ClassBase::into_class)
.any(|superclass| { .any(|superclass| {
let superclass_scope = superclass.class_literal(db).0.body_scope(db); let superclass_scope = superclass.class_literal(db).0.body_scope(db);
@ -114,8 +126,8 @@ impl<'db> ProtocolClassLiteral<'db> {
} }
} }
impl<'db> Deref for ProtocolClassLiteral<'db> { impl<'db> Deref for ProtocolClass<'db> {
type Target = ClassLiteral<'db>; type Target = ClassType<'db>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
@ -622,16 +634,19 @@ impl BoundOnClass {
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)] #[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn cached_protocol_interface<'db>( fn cached_protocol_interface<'db>(
db: &'db dyn Db, db: &'db dyn Db,
class: ClassLiteral<'db>, class: ClassType<'db>,
) -> ProtocolInterface<'db> { ) -> ProtocolInterface<'db> {
let mut members = BTreeMap::default(); let mut members = BTreeMap::default();
for parent_protocol in class for (parent_protocol, specialization) in class
.iter_mro(db, None) .iter_mro(db)
.filter_map(ClassBase::into_class) .filter_map(ClassBase::into_class)
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db)) .filter_map(|class| {
let (class, specialization) = class.class_literal(db);
Some((class.into_protocol_class(db)?, specialization))
})
{ {
let parent_scope = parent_protocol.body_scope(db); let parent_scope = parent_protocol.class_literal(db).0.body_scope(db);
let use_def_map = use_def_map(db, parent_scope); let use_def_map = use_def_map(db, parent_scope);
let place_table = place_table(db, parent_scope); let place_table = place_table(db, parent_scope);
let mut direct_members = FxHashMap::default(); let mut direct_members = FxHashMap::default();
@ -676,6 +691,8 @@ fn cached_protocol_interface<'db>(
continue; continue;
} }
let ty = ty.apply_optional_specialization(db, specialization);
let member = match ty { let member = match ty {
Type::PropertyInstance(property) => ProtocolMemberKind::Property(property), Type::PropertyInstance(property) => ProtocolMemberKind::Property(property),
Type::Callable(callable) Type::Callable(callable)
@ -702,19 +719,21 @@ fn cached_protocol_interface<'db>(
ProtocolInterface::new(db, members) ProtocolInterface::new(db, members)
} }
// If we use `expect(clippy::trivially_copy_pass_by_ref)` here,
// the lint expectation is unfulfilled on WASM
#[allow(clippy::trivially_copy_pass_by_ref)] #[allow(clippy::trivially_copy_pass_by_ref)]
fn proto_interface_cycle_recover<'db>( fn proto_interface_cycle_recover<'db>(
_db: &dyn Db, _db: &dyn Db,
_value: &ProtocolInterface<'db>, _value: &ProtocolInterface<'db>,
_count: u32, _count: u32,
_class: ClassLiteral<'db>, _class: ClassType<'db>,
) -> salsa::CycleRecoveryAction<ProtocolInterface<'db>> { ) -> salsa::CycleRecoveryAction<ProtocolInterface<'db>> {
salsa::CycleRecoveryAction::Iterate salsa::CycleRecoveryAction::Iterate
} }
fn proto_interface_cycle_initial<'db>( fn proto_interface_cycle_initial<'db>(
db: &'db dyn Db, db: &'db dyn Db,
_class: ClassLiteral<'db>, _class: ClassType<'db>,
) -> ProtocolInterface<'db> { ) -> ProtocolInterface<'db> {
ProtocolInterface::empty(db) ProtocolInterface::empty(db)
} }