bpo-37058: PEP 544: Add Protocol to typing module (GH-13585)

I tried to get rid of the `_ProtocolMeta`, but unfortunately it didn'y work. My idea to return a generic alias from `@runtime_checkable` made runtime protocols unpickleable. I am not sure what is worse (a custom metaclass or having some classes unpickleable), so I decided to stick with the status quo (since there were no complains so far). So essentially this is a copy of the implementation in `typing_extensions` with two modifications:
* Rename `@runtime` to `@runtime_checkable` (plus corresponding updates).
* Allow protocols that extend `collections.abc.Iterable` etc.
This commit is contained in:
Ivan Levkivskyi 2019-05-28 08:40:15 +01:00 committed by GitHub
parent 3880f263d2
commit 74d7f76e2c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 1053 additions and 119 deletions

View file

@ -12,8 +12,8 @@ from typing import T, KT, VT # Not in __all__.
from typing import Union, Optional, Literal
from typing import Tuple, List, MutableMapping
from typing import Callable
from typing import Generic, ClassVar, Final, final
from typing import cast
from typing import Generic, ClassVar, Final, final, Protocol
from typing import cast, runtime_checkable
from typing import get_type_hints
from typing import no_type_check, no_type_check_decorator
from typing import Type
@ -24,6 +24,7 @@ from typing import Pattern, Match
import abc
import typing
import weakref
import types
from test import mod_generics_cache
@ -585,7 +586,710 @@ class MySimpleMapping(SimpleMapping[XK, XV]):
return default
class Coordinate(Protocol):
x: int
y: int
@runtime_checkable
class Point(Coordinate, Protocol):
label: str
class MyPoint:
x: int
y: int
label: str
class XAxis(Protocol):
x: int
class YAxis(Protocol):
y: int
@runtime_checkable
class Position(XAxis, YAxis, Protocol):
pass
@runtime_checkable
class Proto(Protocol):
attr: int
def meth(self, arg: str) -> int:
...
class Concrete(Proto):
pass
class Other:
attr: int = 1
def meth(self, arg: str) -> int:
if arg == 'this':
return 1
return 0
class NT(NamedTuple):
x: int
y: int
@runtime_checkable
class HasCallProtocol(Protocol):
__call__: typing.Callable
class ProtocolTests(BaseTestCase):
def test_basic_protocol(self):
@runtime_checkable
class P(Protocol):
def meth(self):
pass
class C: pass
class D:
def meth(self):
pass
def f():
pass
self.assertIsSubclass(D, P)
self.assertIsInstance(D(), P)
self.assertNotIsSubclass(C, P)
self.assertNotIsInstance(C(), P)
self.assertNotIsSubclass(types.FunctionType, P)
self.assertNotIsInstance(f, P)
def test_everything_implements_empty_protocol(self):
@runtime_checkable
class Empty(Protocol):
pass
class C:
pass
def f():
pass
for thing in (object, type, tuple, C, types.FunctionType):
self.assertIsSubclass(thing, Empty)
for thing in (object(), 1, (), typing, f):
self.assertIsInstance(thing, Empty)
def test_function_implements_protocol(self):
def f():
pass
self.assertIsInstance(f, HasCallProtocol)
def test_no_inheritance_from_nominal(self):
class C: pass
class BP(Protocol): pass
with self.assertRaises(TypeError):
class P(C, Protocol):
pass
with self.assertRaises(TypeError):
class P(Protocol, C):
pass
with self.assertRaises(TypeError):
class P(BP, C, Protocol):
pass
class D(BP, C): pass
class E(C, BP): pass
self.assertNotIsInstance(D(), E)
self.assertNotIsInstance(E(), D)
def test_no_instantiation(self):
class P(Protocol): pass
with self.assertRaises(TypeError):
P()
class C(P): pass
self.assertIsInstance(C(), C)
T = TypeVar('T')
class PG(Protocol[T]): pass
with self.assertRaises(TypeError):
PG()
with self.assertRaises(TypeError):
PG[int]()
with self.assertRaises(TypeError):
PG[T]()
class CG(PG[T]): pass
self.assertIsInstance(CG[int](), CG)
def test_cannot_instantiate_abstract(self):
@runtime_checkable
class P(Protocol):
@abc.abstractmethod
def ameth(self) -> int:
raise NotImplementedError
class B(P):
pass
class C(B):
def ameth(self) -> int:
return 26
with self.assertRaises(TypeError):
B()
self.assertIsInstance(C(), P)
def test_subprotocols_extending(self):
class P1(Protocol):
def meth1(self):
pass
@runtime_checkable
class P2(P1, Protocol):
def meth2(self):
pass
class C:
def meth1(self):
pass
def meth2(self):
pass
class C1:
def meth1(self):
pass
class C2:
def meth2(self):
pass
self.assertNotIsInstance(C1(), P2)
self.assertNotIsInstance(C2(), P2)
self.assertNotIsSubclass(C1, P2)
self.assertNotIsSubclass(C2, P2)
self.assertIsInstance(C(), P2)
self.assertIsSubclass(C, P2)
def test_subprotocols_merging(self):
class P1(Protocol):
def meth1(self):
pass
class P2(Protocol):
def meth2(self):
pass
@runtime_checkable
class P(P1, P2, Protocol):
pass
class C:
def meth1(self):
pass
def meth2(self):
pass
class C1:
def meth1(self):
pass
class C2:
def meth2(self):
pass
self.assertNotIsInstance(C1(), P)
self.assertNotIsInstance(C2(), P)
self.assertNotIsSubclass(C1, P)
self.assertNotIsSubclass(C2, P)
self.assertIsInstance(C(), P)
self.assertIsSubclass(C, P)
def test_protocols_issubclass(self):
T = TypeVar('T')
@runtime_checkable
class P(Protocol):
def x(self): ...
@runtime_checkable
class PG(Protocol[T]):
def x(self): ...
class BadP(Protocol):
def x(self): ...
class BadPG(Protocol[T]):
def x(self): ...
class C:
def x(self): ...
self.assertIsSubclass(C, P)
self.assertIsSubclass(C, PG)
self.assertIsSubclass(BadP, PG)
with self.assertRaises(TypeError):
issubclass(C, PG[T])
with self.assertRaises(TypeError):
issubclass(C, PG[C])
with self.assertRaises(TypeError):
issubclass(C, BadP)
with self.assertRaises(TypeError):
issubclass(C, BadPG)
with self.assertRaises(TypeError):
issubclass(P, PG[T])
with self.assertRaises(TypeError):
issubclass(PG, PG[int])
def test_protocols_issubclass_non_callable(self):
class C:
x = 1
@runtime_checkable
class PNonCall(Protocol):
x = 1
with self.assertRaises(TypeError):
issubclass(C, PNonCall)
self.assertIsInstance(C(), PNonCall)
PNonCall.register(C)
with self.assertRaises(TypeError):
issubclass(C, PNonCall)
self.assertIsInstance(C(), PNonCall)
# check that non-protocol subclasses are not affected
class D(PNonCall): ...
self.assertNotIsSubclass(C, D)
self.assertNotIsInstance(C(), D)
D.register(C)
self.assertIsSubclass(C, D)
self.assertIsInstance(C(), D)
with self.assertRaises(TypeError):
issubclass(D, PNonCall)
def test_protocols_isinstance(self):
T = TypeVar('T')
@runtime_checkable
class P(Protocol):
def meth(x): ...
@runtime_checkable
class PG(Protocol[T]):
def meth(x): ...
class BadP(Protocol):
def meth(x): ...
class BadPG(Protocol[T]):
def meth(x): ...
class C:
def meth(x): ...
self.assertIsInstance(C(), P)
self.assertIsInstance(C(), PG)
with self.assertRaises(TypeError):
isinstance(C(), PG[T])
with self.assertRaises(TypeError):
isinstance(C(), PG[C])
with self.assertRaises(TypeError):
isinstance(C(), BadP)
with self.assertRaises(TypeError):
isinstance(C(), BadPG)
def test_protocols_isinstance_py36(self):
class APoint:
def __init__(self, x, y, label):
self.x = x
self.y = y
self.label = label
class BPoint:
label = 'B'
def __init__(self, x, y):
self.x = x
self.y = y
class C:
def __init__(self, attr):
self.attr = attr
def meth(self, arg):
return 0
class Bad: pass
self.assertIsInstance(APoint(1, 2, 'A'), Point)
self.assertIsInstance(BPoint(1, 2), Point)
self.assertNotIsInstance(MyPoint(), Point)
self.assertIsInstance(BPoint(1, 2), Position)
self.assertIsInstance(Other(), Proto)
self.assertIsInstance(Concrete(), Proto)
self.assertIsInstance(C(42), Proto)
self.assertNotIsInstance(Bad(), Proto)
self.assertNotIsInstance(Bad(), Point)
self.assertNotIsInstance(Bad(), Position)
self.assertNotIsInstance(Bad(), Concrete)
self.assertNotIsInstance(Other(), Concrete)
self.assertIsInstance(NT(1, 2), Position)
def test_protocols_isinstance_init(self):
T = TypeVar('T')
@runtime_checkable
class P(Protocol):
x = 1
@runtime_checkable
class PG(Protocol[T]):
x = 1
class C:
def __init__(self, x):
self.x = x
self.assertIsInstance(C(1), P)
self.assertIsInstance(C(1), PG)
def test_protocols_support_register(self):
@runtime_checkable
class P(Protocol):
x = 1
class PM(Protocol):
def meth(self): pass
class D(PM): pass
class C: pass
D.register(C)
P.register(C)
self.assertIsInstance(C(), P)
self.assertIsInstance(C(), D)
def test_none_on_non_callable_doesnt_block_implementation(self):
@runtime_checkable
class P(Protocol):
x = 1
class A:
x = 1
class B(A):
x = None
class C:
def __init__(self):
self.x = None
self.assertIsInstance(B(), P)
self.assertIsInstance(C(), P)
def test_none_on_callable_blocks_implementation(self):
@runtime_checkable
class P(Protocol):
def x(self): ...
class A:
def x(self): ...
class B(A):
x = None
class C:
def __init__(self):
self.x = None
self.assertNotIsInstance(B(), P)
self.assertNotIsInstance(C(), P)
def test_non_protocol_subclasses(self):
class P(Protocol):
x = 1
@runtime_checkable
class PR(Protocol):
def meth(self): pass
class NonP(P):
x = 1
class NonPR(PR): pass
class C:
x = 1
class D:
def meth(self): pass
self.assertNotIsInstance(C(), NonP)
self.assertNotIsInstance(D(), NonPR)
self.assertNotIsSubclass(C, NonP)
self.assertNotIsSubclass(D, NonPR)
self.assertIsInstance(NonPR(), PR)
self.assertIsSubclass(NonPR, PR)
def test_custom_subclasshook(self):
class P(Protocol):
x = 1
class OKClass: pass
class BadClass:
x = 1
class C(P):
@classmethod
def __subclasshook__(cls, other):
return other.__name__.startswith("OK")
self.assertIsInstance(OKClass(), C)
self.assertNotIsInstance(BadClass(), C)
self.assertIsSubclass(OKClass, C)
self.assertNotIsSubclass(BadClass, C)
def test_issubclass_fails_correctly(self):
@runtime_checkable
class P(Protocol):
x = 1
class C: pass
with self.assertRaises(TypeError):
issubclass(C(), P)
def test_defining_generic_protocols(self):
T = TypeVar('T')
S = TypeVar('S')
@runtime_checkable
class PR(Protocol[T, S]):
def meth(self): pass
class P(PR[int, T], Protocol[T]):
y = 1
with self.assertRaises(TypeError):
PR[int]
with self.assertRaises(TypeError):
P[int, str]
with self.assertRaises(TypeError):
PR[int, 1]
with self.assertRaises(TypeError):
PR[int, ClassVar]
class C(PR[int, T]): pass
self.assertIsInstance(C[str](), C)
def test_defining_generic_protocols_old_style(self):
T = TypeVar('T')
S = TypeVar('S')
@runtime_checkable
class PR(Protocol, Generic[T, S]):
def meth(self): pass
class P(PR[int, str], Protocol):
y = 1
with self.assertRaises(TypeError):
issubclass(PR[int, str], PR)
self.assertIsSubclass(P, PR)
with self.assertRaises(TypeError):
PR[int]
with self.assertRaises(TypeError):
PR[int, 1]
class P1(Protocol, Generic[T]):
def bar(self, x: T) -> str: ...
class P2(Generic[T], Protocol):
def bar(self, x: T) -> str: ...
@runtime_checkable
class PSub(P1[str], Protocol):
x = 1
class Test:
x = 1
def bar(self, x: str) -> str:
return x
self.assertIsInstance(Test(), PSub)
with self.assertRaises(TypeError):
PR[int, ClassVar]
def test_init_called(self):
T = TypeVar('T')
class P(Protocol[T]): pass
class C(P[T]):
def __init__(self):
self.test = 'OK'
self.assertEqual(C[int]().test, 'OK')
def test_protocols_bad_subscripts(self):
T = TypeVar('T')
S = TypeVar('S')
with self.assertRaises(TypeError):
class P(Protocol[T, T]): pass
with self.assertRaises(TypeError):
class P(Protocol[int]): pass
with self.assertRaises(TypeError):
class P(Protocol[T], Protocol[S]): pass
with self.assertRaises(TypeError):
class P(typing.Mapping[T, S], Protocol[T]): pass
def test_generic_protocols_repr(self):
T = TypeVar('T')
S = TypeVar('S')
class P(Protocol[T, S]): pass
self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]'))
self.assertTrue(repr(P[int, str]).endswith('P[int, str]'))
def test_generic_protocols_eq(self):
T = TypeVar('T')
S = TypeVar('S')
class P(Protocol[T, S]): pass
self.assertEqual(P, P)
self.assertEqual(P[int, T], P[int, T])
self.assertEqual(P[T, T][Tuple[T, S]][int, str],
P[Tuple[int, str], Tuple[int, str]])
def test_generic_protocols_special_from_generic(self):
T = TypeVar('T')
class P(Protocol[T]): pass
self.assertEqual(P.__parameters__, (T,))
self.assertEqual(P[int].__parameters__, ())
self.assertEqual(P[int].__args__, (int,))
self.assertIs(P[int].__origin__, P)
def test_generic_protocols_special_from_protocol(self):
@runtime_checkable
class PR(Protocol):
x = 1
class P(Protocol):
def meth(self):
pass
T = TypeVar('T')
class PG(Protocol[T]):
x = 1
def meth(self):
pass
self.assertTrue(P._is_protocol)
self.assertTrue(PR._is_protocol)
self.assertTrue(PG._is_protocol)
self.assertFalse(P._is_runtime_protocol)
self.assertTrue(PR._is_runtime_protocol)
self.assertTrue(PG[int]._is_protocol)
self.assertEqual(typing._get_protocol_attrs(P), {'meth'})
self.assertEqual(typing._get_protocol_attrs(PR), {'x'})
self.assertEqual(frozenset(typing._get_protocol_attrs(PG)),
frozenset({'x', 'meth'}))
def test_no_runtime_deco_on_nominal(self):
with self.assertRaises(TypeError):
@runtime_checkable
class C: pass
class Proto(Protocol):
x = 1
with self.assertRaises(TypeError):
@runtime_checkable
class Concrete(Proto):
pass
def test_none_treated_correctly(self):
@runtime_checkable
class P(Protocol):
x = None # type: int
class B(object): pass
self.assertNotIsInstance(B(), P)
class C:
x = 1
class D:
x = None
self.assertIsInstance(C(), P)
self.assertIsInstance(D(), P)
class CI:
def __init__(self):
self.x = 1
class DI:
def __init__(self):
self.x = None
self.assertIsInstance(C(), P)
self.assertIsInstance(D(), P)
def test_protocols_in_unions(self):
class P(Protocol):
x = None # type: int
Alias = typing.Union[typing.Iterable, P]
Alias2 = typing.Union[P, typing.Iterable]
self.assertEqual(Alias, Alias2)
def test_protocols_pickleable(self):
global P, CP # pickle wants to reference the class by name
T = TypeVar('T')
@runtime_checkable
class P(Protocol[T]):
x = 1
class CP(P[int]):
pass
c = CP()
c.foo = 42
c.bar = 'abc'
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
z = pickle.dumps(c, proto)
x = pickle.loads(z)
self.assertEqual(x.foo, 42)
self.assertEqual(x.bar, 'abc')
self.assertEqual(x.x, 1)
self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'})
s = pickle.dumps(P)
D = pickle.loads(s)
class E:
x = 1
self.assertIsInstance(E(), D)
def test_supports_int(self):
self.assertIsSubclass(int, typing.SupportsInt)
@ -634,9 +1338,8 @@ class ProtocolTests(BaseTestCase):
self.assertIsSubclass(int, typing.SupportsIndex)
self.assertNotIsSubclass(str, typing.SupportsIndex)
def test_protocol_instance_type_error(self):
with self.assertRaises(TypeError):
isinstance(0, typing.SupportsAbs)
def test_bundled_protocol_instance_works(self):
self.assertIsInstance(0, typing.SupportsAbs)
class C1(typing.SupportsInt):
def __int__(self) -> int:
return 42
@ -645,6 +1348,20 @@ class ProtocolTests(BaseTestCase):
c = C2()
self.assertIsInstance(c, C1)
def test_collections_protocols_allowed(self):
@runtime_checkable
class Custom(collections.abc.Iterable, Protocol):
def close(self): ...
class A: pass
class B:
def __iter__(self):
return []
def close(self):
return 0
self.assertIsSubclass(B, Custom)
self.assertNotIsSubclass(A, Custom)
class GenericTests(BaseTestCase):
@ -771,7 +1488,7 @@ class GenericTests(BaseTestCase):
def test_new_repr_bare(self):
T = TypeVar('T')
self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]')
self.assertEqual(repr(typing._Protocol[T]), 'typing._Protocol[~T]')
self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]')
class C(typing.Dict[Any, Any]): ...
# this line should just work
repr(C.__mro__)
@ -1067,7 +1784,7 @@ class GenericTests(BaseTestCase):
with self.assertRaises(TypeError):
Tuple[Generic[T]]
with self.assertRaises(TypeError):
List[typing._Protocol]
List[typing.Protocol]
def test_type_erasure_special(self):
T = TypeVar('T')