gh-112281: Allow Union with unhashable Annotated metadata (#112283)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Nikita Sobolev 2024-03-01 19:19:24 +03:00 committed by GitHub
parent 2713c2abc8
commit a7549b03ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 155 additions and 17 deletions

View file

@ -308,19 +308,33 @@ def _unpack_args(args):
newargs.append(arg)
return newargs
def _deduplicate(params):
def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params
try:
return dict.fromkeys(params)
except TypeError:
if not unhashable_fallback:
raise
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
return _deduplicate_unhashable(params)
def _deduplicate_unhashable(unhashable_params):
new_unhashable = []
for t in unhashable_params:
if t not in new_unhashable:
new_unhashable.append(t)
return new_unhashable
def _compare_args_orderless(first_args, second_args):
first_unhashable = _deduplicate_unhashable(first_args)
second_unhashable = _deduplicate_unhashable(second_args)
t = list(second_unhashable)
try:
for elem in first_unhashable:
t.remove(elem)
except ValueError:
return False
return not t
def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
@ -335,7 +349,7 @@ def _remove_dups_flatten(parameters):
else:
params.append(p)
return tuple(_deduplicate(params))
return tuple(_deduplicate(params, unhashable_fallback=True))
def _flatten_literal_params(parameters):
@ -1555,7 +1569,10 @@ class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
return NotImplemented
return set(self.__args__) == set(other.__args__)
try: # fast path
return set(self.__args__) == set(other.__args__)
except TypeError: # not hashable, slow path
return _compare_args_orderless(self.__args__, other.__args__)
def __hash__(self):
return hash(frozenset(self.__args__))