mirror of
https://github.com/python/cpython.git
synced 2025-08-04 00:48:58 +00:00
[3.9] bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) (GH-23335)
Literal equality no longer depends on the order of arguments.
Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function.
Add deduplication of `typing.Literal` arguments.
(cherry picked from commit f03d318ca4
)
Co-authored-by: Yurii Karabas <1998uriyyo@gmail.com>
This commit is contained in:
parent
656d50f98d
commit
ac472b316c
4 changed files with 105 additions and 23 deletions
100
Lib/typing.py
100
Lib/typing.py
|
@ -200,6 +200,20 @@ def _check_generic(cls, parameters, elen):
|
|||
f" actual {alen}, expected {elen}")
|
||||
|
||||
|
||||
def _deduplicate(params):
|
||||
# Weed out strict duplicates, preserving the first of each occurrence.
|
||||
all_params = set(params)
|
||||
if len(all_params) < len(params):
|
||||
new_params = []
|
||||
for t in params:
|
||||
if t in all_params:
|
||||
new_params.append(t)
|
||||
all_params.remove(t)
|
||||
params = new_params
|
||||
assert not all_params, all_params
|
||||
return params
|
||||
|
||||
|
||||
def _remove_dups_flatten(parameters):
|
||||
"""An internal helper for Union creation and substitution: flatten Unions
|
||||
among parameters, then remove duplicates.
|
||||
|
@ -213,38 +227,45 @@ def _remove_dups_flatten(parameters):
|
|||
params.extend(p[1:])
|
||||
else:
|
||||
params.append(p)
|
||||
# Weed out strict duplicates, preserving the first of each occurrence.
|
||||
all_params = set(params)
|
||||
if len(all_params) < len(params):
|
||||
new_params = []
|
||||
for t in params:
|
||||
if t in all_params:
|
||||
new_params.append(t)
|
||||
all_params.remove(t)
|
||||
params = new_params
|
||||
assert not all_params, all_params
|
||||
|
||||
return tuple(_deduplicate(params))
|
||||
|
||||
|
||||
def _flatten_literal_params(parameters):
|
||||
"""An internal helper for Literal creation: flatten Literals among parameters"""
|
||||
params = []
|
||||
for p in parameters:
|
||||
if isinstance(p, _LiteralGenericAlias):
|
||||
params.extend(p.__args__)
|
||||
else:
|
||||
params.append(p)
|
||||
return tuple(params)
|
||||
|
||||
|
||||
_cleanups = []
|
||||
|
||||
|
||||
def _tp_cache(func):
|
||||
def _tp_cache(func=None, /, *, typed=False):
|
||||
"""Internal wrapper caching __getitem__ of generic types with a fallback to
|
||||
original function for non-hashable arguments.
|
||||
"""
|
||||
cached = functools.lru_cache()(func)
|
||||
_cleanups.append(cached.cache_clear)
|
||||
def decorator(func):
|
||||
cached = functools.lru_cache(typed=typed)(func)
|
||||
_cleanups.append(cached.cache_clear)
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
try:
|
||||
return cached(*args, **kwds)
|
||||
except TypeError:
|
||||
pass # All real errors (not unhashable args) are raised below.
|
||||
return func(*args, **kwds)
|
||||
return inner
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
try:
|
||||
return cached(*args, **kwds)
|
||||
except TypeError:
|
||||
pass # All real errors (not unhashable args) are raised below.
|
||||
return func(*args, **kwds)
|
||||
return inner
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
|
||||
return decorator
|
||||
|
||||
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
|
||||
"""Evaluate all forward references in the given type t.
|
||||
|
@ -317,6 +338,13 @@ class _SpecialForm(_Final, _root=True):
|
|||
def __getitem__(self, parameters):
|
||||
return self._getitem(self, parameters)
|
||||
|
||||
|
||||
class _LiteralSpecialForm(_SpecialForm, _root=True):
|
||||
@_tp_cache(typed=True)
|
||||
def __getitem__(self, parameters):
|
||||
return self._getitem(self, parameters)
|
||||
|
||||
|
||||
@_SpecialForm
|
||||
def Any(self, parameters):
|
||||
"""Special type indicating an unconstrained type.
|
||||
|
@ -434,7 +462,7 @@ def Optional(self, parameters):
|
|||
arg = _type_check(parameters, f"{self} requires a single type.")
|
||||
return Union[arg, type(None)]
|
||||
|
||||
@_SpecialForm
|
||||
@_LiteralSpecialForm
|
||||
def Literal(self, parameters):
|
||||
"""Special typing form to define literal types (a.k.a. value types).
|
||||
|
||||
|
@ -458,7 +486,17 @@ def Literal(self, parameters):
|
|||
"""
|
||||
# There is no '_type_check' call because arguments to Literal[...] are
|
||||
# values, not types.
|
||||
return _GenericAlias(self, parameters)
|
||||
if not isinstance(parameters, tuple):
|
||||
parameters = (parameters,)
|
||||
|
||||
parameters = _flatten_literal_params(parameters)
|
||||
|
||||
try:
|
||||
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
|
||||
except TypeError: # unhashable parameters
|
||||
pass
|
||||
|
||||
return _LiteralGenericAlias(self, parameters)
|
||||
|
||||
|
||||
class ForwardRef(_Final, _root=True):
|
||||
|
@ -881,6 +919,22 @@ class _UnionGenericAlias(_GenericAlias, _root=True):
|
|||
return super().__repr__()
|
||||
|
||||
|
||||
def _value_and_type_iter(parameters):
|
||||
return ((p, type(p)) for p in parameters)
|
||||
|
||||
|
||||
class _LiteralGenericAlias(_GenericAlias, _root=True):
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, _LiteralGenericAlias):
|
||||
return NotImplemented
|
||||
|
||||
return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(_value_and_type_iter(self.__args__)))
|
||||
|
||||
|
||||
class Generic:
|
||||
"""Abstract base class for generic types.
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue