[red-knot] Correctly identify protocol classes (#17487)

This commit is contained in:
Alex Waygood 2025-04-21 16:17:06 +01:00 committed by GitHub
parent c077b109ce
commit 9ff4772a2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 90 additions and 36 deletions

View file

@ -74,8 +74,6 @@ class Baz(Bar):
T = TypeVar("T")
class Qux(Protocol[T]):
# TODO: no error
# error: [invalid-return-type]
def f(self) -> int: ...
class Foo(Protocol):

View file

@ -40,27 +40,63 @@ class Foo(Protocol, Protocol): ... # error: [inconsistent-mro]
reveal_type(Foo.__mro__) # revealed: tuple[Literal[Foo], Unknown, Literal[object]]
```
Protocols can also be generic, either by including `Generic[]` in the bases list, subscripting
`Protocol` directly in the bases list, using PEP-695 type parameters, or some combination of the
above:
```py
from typing import TypeVar, Generic
T = TypeVar("T")
class Bar0(Protocol[T]):
x: T
class Bar1(Protocol[T], Generic[T]):
x: T
class Bar2[T](Protocol):
x: T
class Bar3[T](Protocol[T]):
x: T
```
It's an error to include both bare `Protocol` and subscripted `Protocol[]` in the bases list
simultaneously:
```py
# TODO: should emit a `[duplicate-bases]` error here:
class DuplicateBases(Protocol, Protocol[T]):
x: T
# TODO: should not have `Generic` multiple times and `Protocol` multiple times
# revealed: tuple[Literal[DuplicateBases], typing.Protocol, typing.Generic, @Todo(`Protocol[]` subscript), @Todo(`Generic[]` subscript), Literal[object]]
reveal_type(DuplicateBases.__mro__)
```
The introspection helper `typing(_extensions).is_protocol` can be used to verify whether a class is
a protocol class or not:
```py
from typing_extensions import is_protocol
# TODO: should be `Literal[True]`
reveal_type(is_protocol(MyProtocol)) # revealed: bool
reveal_type(is_protocol(MyProtocol)) # revealed: Literal[True]
reveal_type(is_protocol(Bar0)) # revealed: Literal[True]
reveal_type(is_protocol(Bar1)) # revealed: Literal[True]
reveal_type(is_protocol(Bar2)) # revealed: Literal[True]
reveal_type(is_protocol(Bar3)) # revealed: Literal[True]
class NotAProtocol: ...
# TODO: should be `Literal[False]`
reveal_type(is_protocol(NotAProtocol)) # revealed: bool
reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False]
```
A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs
indicate that the argument passed in must be an instance of `type`. `Literal[False]` should be
inferred as the return type, however.
indicate that the argument passed in must be an instance of `type`.
```py
# TODO: the diagnostic is correct, but should infer `Literal[False]`
# We could also reasonably infer `Literal[False]` here, but it probably doesn't matter that much:
# error: [invalid-argument-type]
reveal_type(is_protocol("not a class")) # revealed: bool
```
@ -74,8 +110,7 @@ class SubclassOfMyProtocol(MyProtocol): ...
# revealed: tuple[Literal[SubclassOfMyProtocol], Literal[MyProtocol], typing.Protocol, typing.Generic, Literal[object]]
reveal_type(SubclassOfMyProtocol.__mro__)
# TODO: should be `Literal[False]`
reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: bool
reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: Literal[False]
```
A protocol class may inherit from other protocols, however, as long as it re-inherits from
@ -84,8 +119,7 @@ A protocol class may inherit from other protocols, however, as long as it re-inh
```py
class SubProtocol(MyProtocol, Protocol): ...
# TODO: should be `Literal[True]`
reveal_type(is_protocol(SubProtocol)) # revealed: bool
reveal_type(is_protocol(SubProtocol)) # revealed: Literal[True]
class OtherProtocol(Protocol):
some_attribute: str
@ -95,8 +129,7 @@ class ComplexInheritance(SubProtocol, OtherProtocol, Protocol): ...
# revealed: tuple[Literal[ComplexInheritance], Literal[SubProtocol], Literal[MyProtocol], Literal[OtherProtocol], typing.Protocol, typing.Generic, Literal[object]]
reveal_type(ComplexInheritance.__mro__)
# TODO: should be `Literal[True]`
reveal_type(is_protocol(ComplexInheritance)) # revealed: bool
reveal_type(is_protocol(ComplexInheritance)) # revealed: Literal[True]
```
If `Protocol` is present in the bases tuple, all other bases in the tuple must be protocol classes,
@ -134,6 +167,8 @@ reveal_type(Fine.__mro__) # revealed: tuple[Literal[Fine], typing.Protocol, typ
class StillFine(Protocol, Generic[T], object): ...
class EvenThis[T](Protocol, object): ...
class OrThis(Protocol[T], Generic[T]): ...
class AndThis(Protocol[T], Generic[T], object): ...
```
And multiple inheritance from a mix of protocol and non-protocol classes is fine as long as
@ -150,8 +185,7 @@ But if `Protocol` is not present in the bases list, the resulting class doesn't
class anymore:
```py
# TODO: should reveal `Literal[False]`
reveal_type(is_protocol(FineAndDandy)) # revealed: bool
reveal_type(is_protocol(FineAndDandy)) # revealed: Literal[False]
```
A class does not *have* to inherit from a protocol class in order for it to be considered a subtype
@ -230,9 +264,10 @@ class Foo(typing.Protocol):
class Bar(typing_extensions.Protocol):
x: int
# TODO: these should pass
static_assert(typing_extensions.is_protocol(Foo)) # error: [static-assert-error]
static_assert(typing_extensions.is_protocol(Bar)) # error: [static-assert-error]
static_assert(typing_extensions.is_protocol(Foo))
static_assert(typing_extensions.is_protocol(Bar))
# TODO: should pass
static_assert(is_equivalent_to(Foo, Bar)) # error: [static-assert-error]
```
@ -247,9 +282,10 @@ class RuntimeCheckableFoo(typing.Protocol):
class RuntimeCheckableBar(typing_extensions.Protocol):
x: int
# TODO: these should pass
static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo)) # error: [static-assert-error]
static_assert(typing_extensions.is_protocol(RuntimeCheckableBar)) # error: [static-assert-error]
static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo))
static_assert(typing_extensions.is_protocol(RuntimeCheckableBar))
# TODO: should pass
static_assert(is_equivalent_to(RuntimeCheckableFoo, RuntimeCheckableBar)) # error: [static-assert-error]
# These should not error because the protocols are decorated with `@runtime_checkable`

View file

@ -535,6 +535,15 @@ impl<'db> Bindings<'db> {
}
}
Some(KnownFunction::IsProtocol) => {
if let [Some(ty)] = overload.parameter_types() {
overload.set_return_type(Type::BooleanLiteral(
ty.into_class_literal()
.is_some_and(|class| class.is_protocol(db)),
));
}
}
Some(KnownFunction::Overload) => {
// TODO: This can be removed once we understand legacy generics because the
// typeshed definition for `typing.overload` is an identity function.

View file

@ -582,6 +582,17 @@ impl<'db> ClassLiteralType<'db> {
.collect()
}
/// Determine if this class is a protocol.
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
self.explicit_bases(db).iter().any(|base| {
matches!(
base,
Type::KnownInstance(KnownInstanceType::Protocol)
| Type::Dynamic(DynamicType::SubscriptedProtocol)
)
})
}
/// Return the types of the decorators on this class
#[salsa::tracked(return_ref)]
fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> {

View file

@ -81,9 +81,9 @@ use crate::types::generics::GenericContext;
use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType,
ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias,
GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class,
ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType,
GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter,
ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType,
SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType,
@ -1224,7 +1224,7 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Returns `true` if the current scope is the function body scope of a method of a protocol
/// (that is, a class which directly inherits `typing.Protocol`.)
fn in_class_that_inherits_protocol_directly(&self) -> bool {
fn in_protocol_class(&self) -> bool {
let current_scope_id = self.scope().file_scope_id(self.db());
let current_scope = self.index.scope(current_scope_id);
let Some(parent_scope_id) = current_scope.parent() else {
@ -1252,13 +1252,13 @@ impl<'db> TypeInferenceBuilder<'db> {
return false;
};
// TODO move this to `Class` once we add proper `Protocol` support
node_ref.bases().iter().any(|base| {
matches!(
self.file_expression_type(base),
Type::KnownInstance(KnownInstanceType::Protocol)
)
})
let class_definition = self.index.expect_single_definition(node_ref.node());
let Type::ClassLiteral(class) = binding_type(self.db(), class_definition) else {
return false;
};
class.is_protocol(self.db())
}
/// Returns `true` if the current scope is the function body scope of a function overload (that
@ -1322,7 +1322,7 @@ impl<'db> TypeInferenceBuilder<'db> {
if (self.in_stub()
|| self.in_function_overload_or_abstractmethod()
|| self.in_class_that_inherits_protocol_directly())
|| self.in_protocol_class())
&& self.return_types_and_ranges.is_empty()
&& is_stub_suite(&function.body)
{
@ -1625,7 +1625,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
} else if (self.in_stub()
|| self.in_function_overload_or_abstractmethod()
|| self.in_class_that_inherits_protocol_directly())
|| self.in_protocol_class())
&& default
.as_ref()
.is_some_and(|d| d.is_ellipsis_literal_expr())