mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 10:48:32 +00:00
[red-knot] Correctly identify protocol classes (#17487)
This commit is contained in:
parent
c077b109ce
commit
9ff4772a2c
5 changed files with 90 additions and 36 deletions
|
@ -74,8 +74,6 @@ class Baz(Bar):
|
|||
T = TypeVar("T")
|
||||
|
||||
class Qux(Protocol[T]):
|
||||
# TODO: no error
|
||||
# error: [invalid-return-type]
|
||||
def f(self) -> int: ...
|
||||
|
||||
class Foo(Protocol):
|
||||
|
|
|
@ -40,27 +40,63 @@ class Foo(Protocol, Protocol): ... # error: [inconsistent-mro]
|
|||
reveal_type(Foo.__mro__) # revealed: tuple[Literal[Foo], Unknown, Literal[object]]
|
||||
```
|
||||
|
||||
Protocols can also be generic, either by including `Generic[]` in the bases list, subscripting
|
||||
`Protocol` directly in the bases list, using PEP-695 type parameters, or some combination of the
|
||||
above:
|
||||
|
||||
```py
|
||||
from typing import TypeVar, Generic
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class Bar0(Protocol[T]):
|
||||
x: T
|
||||
|
||||
class Bar1(Protocol[T], Generic[T]):
|
||||
x: T
|
||||
|
||||
class Bar2[T](Protocol):
|
||||
x: T
|
||||
|
||||
class Bar3[T](Protocol[T]):
|
||||
x: T
|
||||
```
|
||||
|
||||
It's an error to include both bare `Protocol` and subscripted `Protocol[]` in the bases list
|
||||
simultaneously:
|
||||
|
||||
```py
|
||||
# TODO: should emit a `[duplicate-bases]` error here:
|
||||
class DuplicateBases(Protocol, Protocol[T]):
|
||||
x: T
|
||||
|
||||
# TODO: should not have `Generic` multiple times and `Protocol` multiple times
|
||||
# revealed: tuple[Literal[DuplicateBases], typing.Protocol, typing.Generic, @Todo(`Protocol[]` subscript), @Todo(`Generic[]` subscript), Literal[object]]
|
||||
reveal_type(DuplicateBases.__mro__)
|
||||
```
|
||||
|
||||
The introspection helper `typing(_extensions).is_protocol` can be used to verify whether a class is
|
||||
a protocol class or not:
|
||||
|
||||
```py
|
||||
from typing_extensions import is_protocol
|
||||
|
||||
# TODO: should be `Literal[True]`
|
||||
reveal_type(is_protocol(MyProtocol)) # revealed: bool
|
||||
reveal_type(is_protocol(MyProtocol)) # revealed: Literal[True]
|
||||
reveal_type(is_protocol(Bar0)) # revealed: Literal[True]
|
||||
reveal_type(is_protocol(Bar1)) # revealed: Literal[True]
|
||||
reveal_type(is_protocol(Bar2)) # revealed: Literal[True]
|
||||
reveal_type(is_protocol(Bar3)) # revealed: Literal[True]
|
||||
|
||||
class NotAProtocol: ...
|
||||
|
||||
# TODO: should be `Literal[False]`
|
||||
reveal_type(is_protocol(NotAProtocol)) # revealed: bool
|
||||
reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False]
|
||||
```
|
||||
|
||||
A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs
|
||||
indicate that the argument passed in must be an instance of `type`. `Literal[False]` should be
|
||||
inferred as the return type, however.
|
||||
indicate that the argument passed in must be an instance of `type`.
|
||||
|
||||
```py
|
||||
# TODO: the diagnostic is correct, but should infer `Literal[False]`
|
||||
# We could also reasonably infer `Literal[False]` here, but it probably doesn't matter that much:
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(is_protocol("not a class")) # revealed: bool
|
||||
```
|
||||
|
@ -74,8 +110,7 @@ class SubclassOfMyProtocol(MyProtocol): ...
|
|||
# revealed: tuple[Literal[SubclassOfMyProtocol], Literal[MyProtocol], typing.Protocol, typing.Generic, Literal[object]]
|
||||
reveal_type(SubclassOfMyProtocol.__mro__)
|
||||
|
||||
# TODO: should be `Literal[False]`
|
||||
reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: bool
|
||||
reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: Literal[False]
|
||||
```
|
||||
|
||||
A protocol class may inherit from other protocols, however, as long as it re-inherits from
|
||||
|
@ -84,8 +119,7 @@ A protocol class may inherit from other protocols, however, as long as it re-inh
|
|||
```py
|
||||
class SubProtocol(MyProtocol, Protocol): ...
|
||||
|
||||
# TODO: should be `Literal[True]`
|
||||
reveal_type(is_protocol(SubProtocol)) # revealed: bool
|
||||
reveal_type(is_protocol(SubProtocol)) # revealed: Literal[True]
|
||||
|
||||
class OtherProtocol(Protocol):
|
||||
some_attribute: str
|
||||
|
@ -95,8 +129,7 @@ class ComplexInheritance(SubProtocol, OtherProtocol, Protocol): ...
|
|||
# revealed: tuple[Literal[ComplexInheritance], Literal[SubProtocol], Literal[MyProtocol], Literal[OtherProtocol], typing.Protocol, typing.Generic, Literal[object]]
|
||||
reveal_type(ComplexInheritance.__mro__)
|
||||
|
||||
# TODO: should be `Literal[True]`
|
||||
reveal_type(is_protocol(ComplexInheritance)) # revealed: bool
|
||||
reveal_type(is_protocol(ComplexInheritance)) # revealed: Literal[True]
|
||||
```
|
||||
|
||||
If `Protocol` is present in the bases tuple, all other bases in the tuple must be protocol classes,
|
||||
|
@ -134,6 +167,8 @@ reveal_type(Fine.__mro__) # revealed: tuple[Literal[Fine], typing.Protocol, typ
|
|||
|
||||
class StillFine(Protocol, Generic[T], object): ...
|
||||
class EvenThis[T](Protocol, object): ...
|
||||
class OrThis(Protocol[T], Generic[T]): ...
|
||||
class AndThis(Protocol[T], Generic[T], object): ...
|
||||
```
|
||||
|
||||
And multiple inheritance from a mix of protocol and non-protocol classes is fine as long as
|
||||
|
@ -150,8 +185,7 @@ But if `Protocol` is not present in the bases list, the resulting class doesn't
|
|||
class anymore:
|
||||
|
||||
```py
|
||||
# TODO: should reveal `Literal[False]`
|
||||
reveal_type(is_protocol(FineAndDandy)) # revealed: bool
|
||||
reveal_type(is_protocol(FineAndDandy)) # revealed: Literal[False]
|
||||
```
|
||||
|
||||
A class does not *have* to inherit from a protocol class in order for it to be considered a subtype
|
||||
|
@ -230,9 +264,10 @@ class Foo(typing.Protocol):
|
|||
class Bar(typing_extensions.Protocol):
|
||||
x: int
|
||||
|
||||
# TODO: these should pass
|
||||
static_assert(typing_extensions.is_protocol(Foo)) # error: [static-assert-error]
|
||||
static_assert(typing_extensions.is_protocol(Bar)) # error: [static-assert-error]
|
||||
static_assert(typing_extensions.is_protocol(Foo))
|
||||
static_assert(typing_extensions.is_protocol(Bar))
|
||||
|
||||
# TODO: should pass
|
||||
static_assert(is_equivalent_to(Foo, Bar)) # error: [static-assert-error]
|
||||
```
|
||||
|
||||
|
@ -247,9 +282,10 @@ class RuntimeCheckableFoo(typing.Protocol):
|
|||
class RuntimeCheckableBar(typing_extensions.Protocol):
|
||||
x: int
|
||||
|
||||
# TODO: these should pass
|
||||
static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo)) # error: [static-assert-error]
|
||||
static_assert(typing_extensions.is_protocol(RuntimeCheckableBar)) # error: [static-assert-error]
|
||||
static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo))
|
||||
static_assert(typing_extensions.is_protocol(RuntimeCheckableBar))
|
||||
|
||||
# TODO: should pass
|
||||
static_assert(is_equivalent_to(RuntimeCheckableFoo, RuntimeCheckableBar)) # error: [static-assert-error]
|
||||
|
||||
# These should not error because the protocols are decorated with `@runtime_checkable`
|
||||
|
|
|
@ -535,6 +535,15 @@ impl<'db> Bindings<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
Some(KnownFunction::IsProtocol) => {
|
||||
if let [Some(ty)] = overload.parameter_types() {
|
||||
overload.set_return_type(Type::BooleanLiteral(
|
||||
ty.into_class_literal()
|
||||
.is_some_and(|class| class.is_protocol(db)),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Some(KnownFunction::Overload) => {
|
||||
// TODO: This can be removed once we understand legacy generics because the
|
||||
// typeshed definition for `typing.overload` is an identity function.
|
||||
|
|
|
@ -582,6 +582,17 @@ impl<'db> ClassLiteralType<'db> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
/// Determine if this class is a protocol.
|
||||
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
|
||||
self.explicit_bases(db).iter().any(|base| {
|
||||
matches!(
|
||||
base,
|
||||
Type::KnownInstance(KnownInstanceType::Protocol)
|
||||
| Type::Dynamic(DynamicType::SubscriptedProtocol)
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Return the types of the decorators on this class
|
||||
#[salsa::tracked(return_ref)]
|
||||
fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
|
||||
|
|
|
@ -81,9 +81,9 @@ use crate::types::generics::GenericContext;
|
|||
use crate::types::mro::MroErrorKind;
|
||||
use crate::types::unpacker::{UnpackResult, Unpacker};
|
||||
use crate::types::{
|
||||
todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType,
|
||||
ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias,
|
||||
GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
|
||||
binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class,
|
||||
ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType,
|
||||
GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
|
||||
KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter,
|
||||
ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType,
|
||||
SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType,
|
||||
|
@ -1224,7 +1224,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
|
||||
/// Returns `true` if the current scope is the function body scope of a method of a protocol
|
||||
/// (that is, a class which directly inherits `typing.Protocol`.)
|
||||
fn in_class_that_inherits_protocol_directly(&self) -> bool {
|
||||
fn in_protocol_class(&self) -> bool {
|
||||
let current_scope_id = self.scope().file_scope_id(self.db());
|
||||
let current_scope = self.index.scope(current_scope_id);
|
||||
let Some(parent_scope_id) = current_scope.parent() else {
|
||||
|
@ -1252,13 +1252,13 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
return false;
|
||||
};
|
||||
|
||||
// TODO move this to `Class` once we add proper `Protocol` support
|
||||
node_ref.bases().iter().any(|base| {
|
||||
matches!(
|
||||
self.file_expression_type(base),
|
||||
Type::KnownInstance(KnownInstanceType::Protocol)
|
||||
)
|
||||
})
|
||||
let class_definition = self.index.expect_single_definition(node_ref.node());
|
||||
|
||||
let Type::ClassLiteral(class) = binding_type(self.db(), class_definition) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
class.is_protocol(self.db())
|
||||
}
|
||||
|
||||
/// Returns `true` if the current scope is the function body scope of a function overload (that
|
||||
|
@ -1322,7 +1322,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
|
||||
if (self.in_stub()
|
||||
|| self.in_function_overload_or_abstractmethod()
|
||||
|| self.in_class_that_inherits_protocol_directly())
|
||||
|| self.in_protocol_class())
|
||||
&& self.return_types_and_ranges.is_empty()
|
||||
&& is_stub_suite(&function.body)
|
||||
{
|
||||
|
@ -1625,7 +1625,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
} else if (self.in_stub()
|
||||
|| self.in_function_overload_or_abstractmethod()
|
||||
|| self.in_class_that_inherits_protocol_directly())
|
||||
|| self.in_protocol_class())
|
||||
&& default
|
||||
.as_ref()
|
||||
.is_some_and(|d| d.is_ellipsis_literal_expr())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue