mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
[3.11] gh-112281: Allow Union
with unhashable Annotated
metadata (GH-112283) (#116288)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
parent
6c2484bbf1
commit
8bfbeeb0a1
4 changed files with 154 additions and 14 deletions
|
@ -709,6 +709,26 @@ class UnionTests(unittest.TestCase):
|
|||
self.assertEqual(hash(int | str), hash(str | int))
|
||||
self.assertEqual(hash(int | str), hash(typing.Union[int, str]))
|
||||
|
||||
def test_union_of_unhashable(self):
|
||||
class UnhashableMeta(type):
|
||||
__hash__ = None
|
||||
|
||||
class A(metaclass=UnhashableMeta): ...
|
||||
class B(metaclass=UnhashableMeta): ...
|
||||
|
||||
self.assertEqual((A | B).__args__, (A, B))
|
||||
union1 = A | B
|
||||
with self.assertRaises(TypeError):
|
||||
hash(union1)
|
||||
|
||||
union2 = int | B
|
||||
with self.assertRaises(TypeError):
|
||||
hash(union2)
|
||||
|
||||
union3 = A | int
|
||||
with self.assertRaises(TypeError):
|
||||
hash(union3)
|
||||
|
||||
def test_instancecheck_and_subclasscheck(self):
|
||||
for x in (int | str, typing.Union[int, str]):
|
||||
with self.subTest(x=x):
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import contextlib
|
||||
import collections
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache, wraps
|
||||
from functools import lru_cache, wraps, reduce
|
||||
import inspect
|
||||
import itertools
|
||||
import gc
|
||||
import operator
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
|
@ -1705,6 +1706,26 @@ class UnionTests(BaseTestCase):
|
|||
v = Union[u, Employee]
|
||||
self.assertEqual(v, Union[int, float, Employee])
|
||||
|
||||
def test_union_of_unhashable(self):
|
||||
class UnhashableMeta(type):
|
||||
__hash__ = None
|
||||
|
||||
class A(metaclass=UnhashableMeta): ...
|
||||
class B(metaclass=UnhashableMeta): ...
|
||||
|
||||
self.assertEqual(Union[A, B].__args__, (A, B))
|
||||
union1 = Union[A, B]
|
||||
with self.assertRaises(TypeError):
|
||||
hash(union1)
|
||||
|
||||
union2 = Union[int, B]
|
||||
with self.assertRaises(TypeError):
|
||||
hash(union2)
|
||||
|
||||
union3 = Union[A, int]
|
||||
with self.assertRaises(TypeError):
|
||||
hash(union3)
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(Union), 'typing.Union')
|
||||
u = Union[Employee, int]
|
||||
|
@ -7374,6 +7395,76 @@ class AnnotatedTests(BaseTestCase):
|
|||
self.assertEqual(A.__metadata__, (4, 5))
|
||||
self.assertEqual(A.__origin__, int)
|
||||
|
||||
def test_deduplicate_from_union(self):
|
||||
# Regular:
|
||||
self.assertEqual(get_args(Annotated[int, 1] | int),
|
||||
(Annotated[int, 1], int))
|
||||
self.assertEqual(get_args(Union[Annotated[int, 1], int]),
|
||||
(Annotated[int, 1], int))
|
||||
self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
|
||||
(Annotated[int, 1], Annotated[int, 2], int))
|
||||
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
|
||||
(Annotated[int, 1], Annotated[int, 2], int))
|
||||
self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
|
||||
(Annotated[int, 1], Annotated[str, 1], int))
|
||||
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
|
||||
(Annotated[int, 1], Annotated[str, 1], int))
|
||||
|
||||
# Duplicates:
|
||||
self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
|
||||
Annotated[int, 1] | int)
|
||||
self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
|
||||
Union[Annotated[int, 1], int])
|
||||
|
||||
# Unhashable metadata:
|
||||
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
|
||||
(str, Annotated[int, {}], Annotated[int, set()], int))
|
||||
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
|
||||
(str, Annotated[int, {}], Annotated[int, set()], int))
|
||||
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
|
||||
(str, Annotated[int, {}], Annotated[str, {}], int))
|
||||
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
|
||||
(str, Annotated[int, {}], Annotated[str, {}], int))
|
||||
|
||||
self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
|
||||
(Annotated[int, 1], str, Annotated[str, {}], int))
|
||||
self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
|
||||
(Annotated[int, 1], str, Annotated[str, {}], int))
|
||||
|
||||
import dataclasses
|
||||
@dataclasses.dataclass
|
||||
class ValueRange:
|
||||
lo: int
|
||||
hi: int
|
||||
v = ValueRange(1, 2)
|
||||
self.assertEqual(get_args(Annotated[int, v] | None),
|
||||
(Annotated[int, v], types.NoneType))
|
||||
self.assertEqual(get_args(Union[Annotated[int, v], None]),
|
||||
(Annotated[int, v], types.NoneType))
|
||||
self.assertEqual(get_args(Optional[Annotated[int, v]]),
|
||||
(Annotated[int, v], types.NoneType))
|
||||
|
||||
# Unhashable metadata duplicated:
|
||||
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
|
||||
Annotated[int, {}] | int)
|
||||
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
|
||||
int | Annotated[int, {}])
|
||||
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
|
||||
Union[Annotated[int, {}], int])
|
||||
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
|
||||
Union[int, Annotated[int, {}]])
|
||||
|
||||
def test_order_in_union(self):
|
||||
expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
|
||||
for args in itertools.permutations(get_args(expr1)):
|
||||
with self.subTest(args=args):
|
||||
self.assertEqual(expr1, reduce(operator.or_, args))
|
||||
|
||||
expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
|
||||
for args in itertools.permutations(get_args(expr2)):
|
||||
with self.subTest(args=args):
|
||||
self.assertEqual(expr2, Union[args])
|
||||
|
||||
def test_specialize(self):
|
||||
L = Annotated[List[T], "my decoration"]
|
||||
LI = Annotated[List[int], "my decoration"]
|
||||
|
@ -7394,6 +7485,16 @@ class AnnotatedTests(BaseTestCase):
|
|||
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
|
||||
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
|
||||
)
|
||||
# Unhashable `metadata` raises `TypeError`:
|
||||
a1 = Annotated[int, []]
|
||||
with self.assertRaises(TypeError):
|
||||
hash(a1)
|
||||
|
||||
class A:
|
||||
__hash__ = None
|
||||
a2 = Annotated[int, A()]
|
||||
with self.assertRaises(TypeError):
|
||||
hash(a2)
|
||||
|
||||
def test_instantiate(self):
|
||||
class C:
|
||||
|
|
|
@ -303,19 +303,33 @@ def _unpack_args(args):
|
|||
newargs.append(arg)
|
||||
return newargs
|
||||
|
||||
def _deduplicate(params):
|
||||
def _deduplicate(params, *, unhashable_fallback=False):
|
||||
# Weed out strict duplicates, preserving the first of each occurrence.
|
||||
all_params = set(params)
|
||||
if len(all_params) < len(params):
|
||||
new_params = []
|
||||
for t in params:
|
||||
if t in all_params:
|
||||
new_params.append(t)
|
||||
all_params.remove(t)
|
||||
params = new_params
|
||||
assert not all_params, all_params
|
||||
return params
|
||||
try:
|
||||
return dict.fromkeys(params)
|
||||
except TypeError:
|
||||
if not unhashable_fallback:
|
||||
raise
|
||||
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
|
||||
return _deduplicate_unhashable(params)
|
||||
|
||||
def _deduplicate_unhashable(unhashable_params):
|
||||
new_unhashable = []
|
||||
for t in unhashable_params:
|
||||
if t not in new_unhashable:
|
||||
new_unhashable.append(t)
|
||||
return new_unhashable
|
||||
|
||||
def _compare_args_orderless(first_args, second_args):
|
||||
first_unhashable = _deduplicate_unhashable(first_args)
|
||||
second_unhashable = _deduplicate_unhashable(second_args)
|
||||
t = list(second_unhashable)
|
||||
try:
|
||||
for elem in first_unhashable:
|
||||
t.remove(elem)
|
||||
except ValueError:
|
||||
return False
|
||||
return not t
|
||||
|
||||
def _remove_dups_flatten(parameters):
|
||||
"""Internal helper for Union creation and substitution.
|
||||
|
@ -330,7 +344,7 @@ def _remove_dups_flatten(parameters):
|
|||
else:
|
||||
params.append(p)
|
||||
|
||||
return tuple(_deduplicate(params))
|
||||
return tuple(_deduplicate(params, unhashable_fallback=True))
|
||||
|
||||
|
||||
def _flatten_literal_params(parameters):
|
||||
|
@ -1673,7 +1687,10 @@ class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
|
|||
def __eq__(self, other):
|
||||
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
|
||||
return NotImplemented
|
||||
return set(self.__args__) == set(other.__args__)
|
||||
try: # fast path
|
||||
return set(self.__args__) == set(other.__args__)
|
||||
except TypeError: # not hashable, slow path
|
||||
return _compare_args_orderless(self.__args__, other.__args__)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(frozenset(self.__args__))
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Allow creating :ref:`union of types<types-union>` for
|
||||
:class:`typing.Annotated` with unhashable metadata.
|
Loading…
Add table
Add a link
Reference in a new issue