[3.12] gh-112281: Allow Union with unhashable Annotated metadata (GH-112283) (#116213)

Co-authored-by: Nikita Sobolev <mail@sobolevn.me>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Miss Islington (bot) 2024-03-01 19:01:27 +01:00 committed by GitHub
parent 16be4a3b93
commit 90f75e1069
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 155 additions and 17 deletions

View file

@ -314,19 +314,33 @@ def _unpack_args(args):
newargs.append(arg)
return newargs
def _deduplicate(params):
def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params
try:
return dict.fromkeys(params)
except TypeError:
if not unhashable_fallback:
raise
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
return _deduplicate_unhashable(params)
def _deduplicate_unhashable(unhashable_params):
new_unhashable = []
for t in unhashable_params:
if t not in new_unhashable:
new_unhashable.append(t)
return new_unhashable
def _compare_args_orderless(first_args, second_args):
first_unhashable = _deduplicate_unhashable(first_args)
second_unhashable = _deduplicate_unhashable(second_args)
t = list(second_unhashable)
try:
for elem in first_unhashable:
t.remove(elem)
except ValueError:
return False
return not t
def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
@ -341,7 +355,7 @@ def _remove_dups_flatten(parameters):
else:
params.append(p)
return tuple(_deduplicate(params))
return tuple(_deduplicate(params, unhashable_fallback=True))
def _flatten_literal_params(parameters):
@ -1548,7 +1562,10 @@ class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
return NotImplemented
return set(self.__args__) == set(other.__args__)
try: # fast path
return set(self.__args__) == set(other.__args__)
except TypeError: # not hashable, slow path
return _compare_args_orderless(self.__args__, other.__args__)
def __hash__(self):
return hash(frozenset(self.__args__))