mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
This commit is contained in:
parent
810c1769f1
commit
3cb357a2e6
3 changed files with 60 additions and 7 deletions
|
@ -837,6 +837,14 @@ def singledispatch(func):
|
|||
dispatch_cache[cls] = impl
|
||||
return impl
|
||||
|
||||
def _is_union_type(cls):
|
||||
from typing import get_origin, Union
|
||||
return get_origin(cls) in {Union, types.UnionType}
|
||||
|
||||
def _is_valid_union_type(cls):
|
||||
from typing import get_args
|
||||
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
|
||||
|
||||
def register(cls, func=None):
|
||||
"""generic_func.register(cls, func) -> func
|
||||
|
||||
|
@ -845,7 +853,7 @@ def singledispatch(func):
|
|||
"""
|
||||
nonlocal cache_token
|
||||
if func is None:
|
||||
if isinstance(cls, type):
|
||||
if isinstance(cls, type) or _is_valid_union_type(cls):
|
||||
return lambda f: register(cls, f)
|
||||
ann = getattr(cls, '__annotations__', {})
|
||||
if not ann:
|
||||
|
@ -859,12 +867,25 @@ def singledispatch(func):
|
|||
# only import typing if annotation parsing is necessary
|
||||
from typing import get_type_hints
|
||||
argname, cls = next(iter(get_type_hints(func).items()))
|
||||
if not isinstance(cls, type):
|
||||
raise TypeError(
|
||||
f"Invalid annotation for {argname!r}. "
|
||||
f"{cls!r} is not a class."
|
||||
)
|
||||
registry[cls] = func
|
||||
if not isinstance(cls, type) and not _is_valid_union_type(cls):
|
||||
if _is_union_type(cls):
|
||||
raise TypeError(
|
||||
f"Invalid annotation for {argname!r}. "
|
||||
f"{cls!r} not all arguments are classes."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Invalid annotation for {argname!r}. "
|
||||
f"{cls!r} is not a class."
|
||||
)
|
||||
|
||||
if _is_union_type(cls):
|
||||
from typing import get_args
|
||||
|
||||
for arg in get_args(cls):
|
||||
registry[arg] = func
|
||||
else:
|
||||
registry[cls] = func
|
||||
if cache_token is None and hasattr(cls, '__abstractmethods__'):
|
||||
cache_token = get_cache_token()
|
||||
dispatch_cache.clear()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue