mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
This commit is contained in:
parent
810c1769f1
commit
3cb357a2e6
3 changed files with 60 additions and 7 deletions
|
@ -837,6 +837,14 @@ def singledispatch(func):
|
||||||
dispatch_cache[cls] = impl
|
dispatch_cache[cls] = impl
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
def _is_union_type(cls):
|
||||||
|
from typing import get_origin, Union
|
||||||
|
return get_origin(cls) in {Union, types.UnionType}
|
||||||
|
|
||||||
|
def _is_valid_union_type(cls):
|
||||||
|
from typing import get_args
|
||||||
|
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
|
||||||
|
|
||||||
def register(cls, func=None):
|
def register(cls, func=None):
|
||||||
"""generic_func.register(cls, func) -> func
|
"""generic_func.register(cls, func) -> func
|
||||||
|
|
||||||
|
@ -845,7 +853,7 @@ def singledispatch(func):
|
||||||
"""
|
"""
|
||||||
nonlocal cache_token
|
nonlocal cache_token
|
||||||
if func is None:
|
if func is None:
|
||||||
if isinstance(cls, type):
|
if isinstance(cls, type) or _is_valid_union_type(cls):
|
||||||
return lambda f: register(cls, f)
|
return lambda f: register(cls, f)
|
||||||
ann = getattr(cls, '__annotations__', {})
|
ann = getattr(cls, '__annotations__', {})
|
||||||
if not ann:
|
if not ann:
|
||||||
|
@ -859,11 +867,24 @@ def singledispatch(func):
|
||||||
# only import typing if annotation parsing is necessary
|
# only import typing if annotation parsing is necessary
|
||||||
from typing import get_type_hints
|
from typing import get_type_hints
|
||||||
argname, cls = next(iter(get_type_hints(func).items()))
|
argname, cls = next(iter(get_type_hints(func).items()))
|
||||||
if not isinstance(cls, type):
|
if not isinstance(cls, type) and not _is_valid_union_type(cls):
|
||||||
|
if _is_union_type(cls):
|
||||||
|
raise TypeError(
|
||||||
|
f"Invalid annotation for {argname!r}. "
|
||||||
|
f"{cls!r} not all arguments are classes."
|
||||||
|
)
|
||||||
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Invalid annotation for {argname!r}. "
|
f"Invalid annotation for {argname!r}. "
|
||||||
f"{cls!r} is not a class."
|
f"{cls!r} is not a class."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _is_union_type(cls):
|
||||||
|
from typing import get_args
|
||||||
|
|
||||||
|
for arg in get_args(cls):
|
||||||
|
registry[arg] = func
|
||||||
|
else:
|
||||||
registry[cls] = func
|
registry[cls] = func
|
||||||
if cache_token is None and hasattr(cls, '__abstractmethods__'):
|
if cache_token is None and hasattr(cls, '__abstractmethods__'):
|
||||||
cache_token = get_cache_token()
|
cache_token = get_cache_token()
|
||||||
|
|
|
@ -2684,6 +2684,17 @@ class TestSingleDispatch(unittest.TestCase):
|
||||||
'typing.Iterable[str] is not a class.'
|
'typing.Iterable[str] is not a class.'
|
||||||
))
|
))
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError) as exc:
|
||||||
|
@i.register
|
||||||
|
def _(arg: typing.Union[int, typing.Iterable[str]]):
|
||||||
|
return "Invalid Union"
|
||||||
|
self.assertTrue(str(exc.exception).startswith(
|
||||||
|
"Invalid annotation for 'arg'."
|
||||||
|
))
|
||||||
|
self.assertTrue(str(exc.exception).endswith(
|
||||||
|
'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
|
||||||
|
))
|
||||||
|
|
||||||
def test_invalid_positional_argument(self):
|
def test_invalid_positional_argument(self):
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def f(*args):
|
def f(*args):
|
||||||
|
@ -2692,6 +2703,25 @@ class TestSingleDispatch(unittest.TestCase):
|
||||||
with self.assertRaisesRegex(TypeError, msg):
|
with self.assertRaisesRegex(TypeError, msg):
|
||||||
f()
|
f()
|
||||||
|
|
||||||
|
def test_union(self):
|
||||||
|
@functools.singledispatch
|
||||||
|
def f(arg):
|
||||||
|
return "default"
|
||||||
|
|
||||||
|
@f.register
|
||||||
|
def _(arg: typing.Union[str, bytes]):
|
||||||
|
return "typing.Union"
|
||||||
|
|
||||||
|
@f.register
|
||||||
|
def _(arg: int | float):
|
||||||
|
return "types.UnionType"
|
||||||
|
|
||||||
|
self.assertEqual(f([]), "default")
|
||||||
|
self.assertEqual(f(""), "typing.Union")
|
||||||
|
self.assertEqual(f(b""), "typing.Union")
|
||||||
|
self.assertEqual(f(1), "types.UnionType")
|
||||||
|
self.assertEqual(f(1.0), "types.UnionType")
|
||||||
|
|
||||||
|
|
||||||
class CachedCostItem:
|
class CachedCostItem:
|
||||||
_cost = 1
|
_cost = 1
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
Add ability to use ``typing.Union`` and ``types.UnionType`` as dispatch
|
||||||
|
argument to ``functools.singledispatch``. Patch provided by Yurii Karabas.
|
Loading…
Add table
Add a link
Reference in a new issue