bpo-37058: PEP 544: Add Protocol to typing module (GH-13585)

I tried to get rid of the `_ProtocolMeta`, but unfortunately it didn'y work. My idea to return a generic alias from `@runtime_checkable` made runtime protocols unpickleable. I am not sure what is worse (a custom metaclass or having some classes unpickleable), so I decided to stick with the status quo (since there were no complains so far). So essentially this is a copy of the implementation in `typing_extensions` with two modifications:
* Rename `@runtime` to `@runtime_checkable` (plus corresponding updates).
* Allow protocols that extend `collections.abc.Iterable` etc.
This commit is contained in:
Ivan Levkivskyi 2019-05-28 08:40:15 +01:00 committed by GitHub
parent 3880f263d2
commit 74d7f76e2c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 1053 additions and 119 deletions

View file

@ -9,8 +9,7 @@ At large scale, the structure of the module is following:
* The core of internal generics API: _GenericAlias and _VariadicGenericAlias, the latter is
currently only used by Tuple and Callable. All subscripted types like X[int], Union[int, str],
etc., are instances of either of these classes.
* The public counterpart of the generics API consists of two classes: Generic and Protocol
(the latter is currently private, but will be made public after PEP 544 acceptance).
* The public counterpart of the generics API consists of two classes: Generic and Protocol.
* Public helper functions: get_type_hints, overload, cast, no_type_check,
no_type_check_decorator.
* Generic aliases for collections.abc ABCs and few additional protocols.
@ -18,7 +17,7 @@ At large scale, the structure of the module is following:
* Wrapper submodules for re and io related types.
"""
from abc import abstractmethod, abstractproperty
from abc import abstractmethod, abstractproperty, ABCMeta
import collections
import collections.abc
import contextlib
@ -39,6 +38,7 @@ __all__ = [
'Generic',
'Literal',
'Optional',
'Protocol',
'Tuple',
'Type',
'TypeVar',
@ -102,6 +102,7 @@ __all__ = [
'no_type_check_decorator',
'NoReturn',
'overload',
'runtime_checkable',
'Text',
'TYPE_CHECKING',
]
@ -123,7 +124,7 @@ def _type_check(arg, msg, is_argument=True):
We append the repr() of the actual value (truncated to 100 chars).
"""
invalid_generic_forms = (Generic, _Protocol)
invalid_generic_forms = (Generic, Protocol)
if is_argument:
invalid_generic_forms = invalid_generic_forms + (ClassVar, Final)
@ -135,7 +136,7 @@ def _type_check(arg, msg, is_argument=True):
arg.__origin__ in invalid_generic_forms):
raise TypeError(f"{arg} is not valid as type argument")
if (isinstance(arg, _SpecialForm) and arg not in (Any, NoReturn) or
arg in (Generic, _Protocol)):
arg in (Generic, Protocol)):
raise TypeError(f"Plain {arg} is not valid as type argument")
if isinstance(arg, (type, TypeVar, ForwardRef)):
return arg
@ -665,8 +666,8 @@ class _GenericAlias(_Final, _root=True):
@_tp_cache
def __getitem__(self, params):
if self.__origin__ in (Generic, _Protocol):
# Can't subscript Generic[...] or _Protocol[...].
if self.__origin__ in (Generic, Protocol):
# Can't subscript Generic[...] or Protocol[...].
raise TypeError(f"Cannot subscript already-subscripted {self}")
if not isinstance(params, tuple):
params = (params,)
@ -733,6 +734,8 @@ class _GenericAlias(_Final, _root=True):
res.append(Generic)
return tuple(res)
if self.__origin__ is Generic:
if Protocol in bases:
return ()
i = bases.index(self)
for b in bases[i+1:]:
if isinstance(b, _GenericAlias) and b is not self:
@ -850,10 +853,11 @@ class Generic:
return default
"""
__slots__ = ()
_is_protocol = False
def __new__(cls, *args, **kwds):
if cls is Generic:
raise TypeError("Type Generic cannot be instantiated; "
if cls in (Generic, Protocol):
raise TypeError(f"Type {cls.__name__} cannot be instantiated; "
"it can be used only as a base class")
if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:
obj = super().__new__(cls)
@ -870,17 +874,14 @@ class Generic:
f"Parameter list to {cls.__qualname__}[...] cannot be empty")
msg = "Parameters to generic types must be types."
params = tuple(_type_check(p, msg) for p in params)
if cls is Generic:
# Generic can only be subscripted with unique type variables.
if cls in (Generic, Protocol):
# Generic and Protocol can only be subscripted with unique type variables.
if not all(isinstance(p, TypeVar) for p in params):
raise TypeError(
"Parameters to Generic[...] must all be type variables")
f"Parameters to {cls.__name__}[...] must all be type variables")
if len(set(params)) != len(params):
raise TypeError(
"Parameters to Generic[...] must all be unique")
elif cls is _Protocol:
# _Protocol is internal at the moment, just skip the check
pass
f"Parameters to {cls.__name__}[...] must all be unique")
else:
# Subscripting a regular Generic subclass.
_check_generic(cls, params)
@ -892,7 +893,7 @@ class Generic:
if '__orig_bases__' in cls.__dict__:
error = Generic in cls.__orig_bases__
else:
error = Generic in cls.__bases__ and cls.__name__ != '_Protocol'
error = Generic in cls.__bases__ and cls.__name__ != 'Protocol'
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
@ -910,9 +911,7 @@ class Generic:
raise TypeError(
"Cannot inherit from Generic[...] multiple types.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
if gvars is not None:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
@ -935,6 +934,204 @@ class _TypingEllipsis:
"""Internal placeholder for ... (ellipsis)."""
_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__',
'_is_protocol', '_is_runtime_protocol']
_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
'__init__', '__module__', '__new__', '__slots__',
'__subclasshook__', '__weakref__']
# These special attributes will be not collected as protocol members.
EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']
def _get_protocol_attrs(cls):
"""Collect protocol members from a protocol class objects.
This includes names actually defined in the class dictionary, as well
as names that appear in annotations. Special names (above) are skipped.
"""
attrs = set()
for base in cls.__mro__[:-1]: # without object
if base.__name__ in ('Protocol', 'Generic'):
continue
annotations = getattr(base, '__annotations__', {})
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES:
attrs.add(attr)
return attrs
def _is_callable_members_only(cls):
# PEP 544 prohibits using issubclass() with protocols that have non-method members.
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:
raise TypeError('Protocols cannot be instantiated')
def _allow_reckless_class_cheks():
"""Allow instnance and class checks for special stdlib modules.
The abc and functools modules indiscriminately call isinstance() and
issubclass() on the whole MRO of a user class, which may contain protocols.
"""
try:
return sys._getframe(3).f_globals['__name__'] in ['abc', 'functools']
except (AttributeError, ValueError): # For platforms without _getframe().
return True
_PROTO_WHITELIST = ['Callable', 'Awaitable',
'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator',
'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
'ContextManager', 'AsyncContextManager']
class _ProtocolMeta(ABCMeta):
# This metaclass is really unfortunate and exists only because of
# the lack of __instancehook__.
def __instancecheck__(cls, instance):
# We need this method for situations where attributes are
# assigned in __init__.
if ((not getattr(cls, '_is_protocol', False) or
_is_callable_members_only(cls)) and
issubclass(instance.__class__, cls)):
return True
if cls._is_protocol:
if all(hasattr(instance, attr) and
# All *methods* can be blocked by setting them to None.
(not callable(getattr(cls, attr, None)) or
getattr(instance, attr) is not None)
for attr in _get_protocol_attrs(cls)):
return True
return super().__instancecheck__(instance)
class Protocol(Generic, metaclass=_ProtocolMeta):
"""Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize
structural subtyping (static duck-typing), for example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with
@typing.runtime_checkable act as simple-minded runtime protocols that check
only the presence of given attributes, ignoring their type signatures.
Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
"""
__slots__ = ()
_is_protocol = True
_is_runtime_protocol = False
def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
# Determine if this is a protocol or a concrete subclass.
if not cls.__dict__.get('_is_protocol', False):
cls._is_protocol = any(b is Protocol for b in cls.__bases__)
# Set (or override) the protocol subclass hook.
def _proto_hook(other):
if not cls.__dict__.get('_is_protocol', False):
return NotImplemented
# First, perform various sanity checks.
if not getattr(cls, '_is_runtime_protocol', False):
if _allow_reckless_class_cheks():
return NotImplemented
raise TypeError("Instance and class checks can only be used with"
" @runtime_checkable protocols")
if not _is_callable_members_only(cls):
if _allow_reckless_class_cheks():
return NotImplemented
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
if not isinstance(other, type):
# Same error message as for issubclass(1, int).
raise TypeError('issubclass() arg 1 must be a class')
# Second, perform the actual structural compatibility check.
for attr in _get_protocol_attrs(cls):
for base in other.__mro__:
# Check if the members appears in the class dictionary...
if attr in base.__dict__:
if base.__dict__[attr] is None:
return NotImplemented
break
# ...or in annotations, if it is a sub-protocol.
annotations = getattr(base, '__annotations__', {})
if (isinstance(annotations, collections.abc.Mapping) and
attr in annotations and
issubclass(other, Generic) and other._is_protocol):
break
else:
return NotImplemented
return True
if '__subclasshook__' not in cls.__dict__:
cls.__subclasshook__ = _proto_hook
# We have nothing more to do for non-protocols...
if not cls._is_protocol:
return
# ... otherwise check consistency of bases, and prohibit instantiation.
for base in cls.__bases__:
if not (base in (object, Generic) or
base.__module__ == 'collections.abc' and base.__name__ in _PROTO_WHITELIST or
issubclass(base, Generic) and base._is_protocol):
raise TypeError('Protocols can only inherit from other'
' protocols, got %r' % base)
cls.__init__ = _no_init
def runtime_checkable(cls):
"""Mark a protocol class as a runtime protocol.
Such protocol can be used with isinstance() and issubclass().
Raise TypeError if applied to a non-protocol class.
This allows a simple-minded structural check very similar to
one trick ponies in collections.abc such as Iterable.
For example::
@runtime_checkable
class Closable(Protocol):
def close(self): ...
assert isinstance(open('/some/file'), Closable)
Warning: this will check only the presence of the required methods,
not their type signatures!
"""
if not issubclass(cls, Generic) or not cls._is_protocol:
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
' got %r' % cls)
cls._is_runtime_protocol = True
return cls
def cast(typ, val):
"""Cast a value to a type.
@ -1159,90 +1356,6 @@ def final(f):
return f
class _ProtocolMeta(type):
"""Internal metaclass for _Protocol.
This exists so _Protocol classes can be generic without deriving
from Generic.
"""
def __instancecheck__(self, obj):
if _Protocol not in self.__bases__:
return super().__instancecheck__(obj)
raise TypeError("Protocols cannot be used with isinstance().")
def __subclasscheck__(self, cls):
if not self._is_protocol:
# No structural checks since this isn't a protocol.
return NotImplemented
if self is _Protocol:
# Every class is a subclass of the empty protocol.
return True
# Find all attributes defined in the protocol.
attrs = self._get_protocol_attrs()
for attr in attrs:
if not any(attr in d.__dict__ for d in cls.__mro__):
return False
return True
def _get_protocol_attrs(self):
# Get all Protocol base classes.
protocol_bases = []
for c in self.__mro__:
if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol':
protocol_bases.append(c)
# Get attributes included in protocol.
attrs = set()
for base in protocol_bases:
for attr in base.__dict__.keys():
# Include attributes not defined in any non-protocol bases.
for c in self.__mro__:
if (c is not base and attr in c.__dict__ and
not getattr(c, '_is_protocol', False)):
break
else:
if (not attr.startswith('_abc_') and
attr != '__abstractmethods__' and
attr != '__annotations__' and
attr != '__weakref__' and
attr != '_is_protocol' and
attr != '_gorg' and
attr != '__dict__' and
attr != '__args__' and
attr != '__slots__' and
attr != '_get_protocol_attrs' and
attr != '__next_in_mro__' and
attr != '__parameters__' and
attr != '__origin__' and
attr != '__orig_bases__' and
attr != '__extra__' and
attr != '__tree_hash__' and
attr != '__module__'):
attrs.add(attr)
return attrs
class _Protocol(Generic, metaclass=_ProtocolMeta):
"""Internal base class for protocol classes.
This implements a simple-minded structural issubclass check
(similar but more general than the one-offs in collections.abc
such as Hashable).
"""
__slots__ = ()
_is_protocol = True
def __class_getitem__(cls, params):
return super().__class_getitem__(params)
# Some unconstrained type variables. These are used by the container types.
# (These are not for export.)
T = TypeVar('T') # Any type.
@ -1347,7 +1460,8 @@ Type.__doc__ = \
"""
class SupportsInt(_Protocol):
@runtime_checkable
class SupportsInt(Protocol):
__slots__ = ()
@abstractmethod
@ -1355,7 +1469,8 @@ class SupportsInt(_Protocol):
pass
class SupportsFloat(_Protocol):
@runtime_checkable
class SupportsFloat(Protocol):
__slots__ = ()
@abstractmethod
@ -1363,7 +1478,8 @@ class SupportsFloat(_Protocol):
pass
class SupportsComplex(_Protocol):
@runtime_checkable
class SupportsComplex(Protocol):
__slots__ = ()
@abstractmethod
@ -1371,7 +1487,8 @@ class SupportsComplex(_Protocol):
pass
class SupportsBytes(_Protocol):
@runtime_checkable
class SupportsBytes(Protocol):
__slots__ = ()
@abstractmethod
@ -1379,7 +1496,8 @@ class SupportsBytes(_Protocol):
pass
class SupportsIndex(_Protocol):
@runtime_checkable
class SupportsIndex(Protocol):
__slots__ = ()
@abstractmethod
@ -1387,7 +1505,8 @@ class SupportsIndex(_Protocol):
pass
class SupportsAbs(_Protocol[T_co]):
@runtime_checkable
class SupportsAbs(Protocol[T_co]):
__slots__ = ()
@abstractmethod
@ -1395,7 +1514,8 @@ class SupportsAbs(_Protocol[T_co]):
pass
class SupportsRound(_Protocol[T_co]):
@runtime_checkable
class SupportsRound(Protocol[T_co]):
__slots__ = ()
@abstractmethod