bpo-46032: Check types in singledispatch's register() at declaration time (GH-30050)

The registry() method of functools.singledispatch() functions checks now
the first argument or the first parameter annotation and raises a TypeError if it is
not supported. Previously unsupported "types" were ignored (e.g. typing.List[int])
or caused an error at calling time (e.g. list[int]).
This commit is contained in:
Serhiy Storchaka 2021-12-25 14:16:14 +02:00 committed by GitHub
parent 1b30660c3b
commit 078abb676c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 5 deletions

View file

@ -740,6 +740,7 @@ def _compose_mro(cls, types):
# Remove entries which are already present in the __mro__ or unrelated.
def is_related(typ):
return (typ not in bases and hasattr(typ, '__mro__')
and not isinstance(typ, GenericAlias)
and issubclass(cls, typ))
types = [n for n in types if is_related(n)]
# Remove entries which are strict bases of other entries (they will end up
@ -841,9 +842,13 @@ def singledispatch(func):
from typing import get_origin, Union
return get_origin(cls) in {Union, types.UnionType}
def _is_valid_union_type(cls):
def _is_valid_dispatch_type(cls):
if isinstance(cls, type) and not isinstance(cls, GenericAlias):
return True
from typing import get_args
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
return (_is_union_type(cls) and
all(isinstance(arg, type) and not isinstance(arg, GenericAlias)
for arg in get_args(cls)))
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
@ -852,9 +857,15 @@ def singledispatch(func):
"""
nonlocal cache_token
if func is None:
if isinstance(cls, type) or _is_valid_union_type(cls):
if _is_valid_dispatch_type(cls):
if func is None:
return lambda f: register(cls, f)
else:
if func is not None:
raise TypeError(
f"Invalid first argument to `register()`. "
f"{cls!r} is not a class or union type."
)
ann = getattr(cls, '__annotations__', {})
if not ann:
raise TypeError(
@ -867,7 +878,7 @@ def singledispatch(func):
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
if not isinstance(cls, type) and not _is_valid_union_type(cls):
if not _is_valid_dispatch_type(cls):
if _is_union_type(cls):
raise TypeError(
f"Invalid annotation for {argname!r}. "