gh-118647: Add defaults to typing.Generator and typing.AsyncGenerator (#118648)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Nikita Sobolev <mail@sobolevn.me>
This commit is contained in:
Jelle Zijlstra 2024-05-06 15:35:06 -07:00 committed by GitHub
parent 9fd33af5ac
commit 8419f01673
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 51 additions and 9 deletions

View file

@ -1328,7 +1328,7 @@ class _BaseGenericAlias(_Final, _root=True):
raise AttributeError(attr)
def __setattr__(self, attr, val):
if _is_dunder(attr) or attr in {'_name', '_inst', '_nparams'}:
if _is_dunder(attr) or attr in {'_name', '_inst', '_nparams', '_defaults'}:
super().__setattr__(attr, val)
else:
setattr(self.__origin__, attr, val)
@ -1578,11 +1578,12 @@ class _GenericAlias(_BaseGenericAlias, _root=True):
# parameters are accepted (needs custom __getitem__).
class _SpecialGenericAlias(_NotIterable, _BaseGenericAlias, _root=True):
def __init__(self, origin, nparams, *, inst=True, name=None):
def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()):
if name is None:
name = origin.__name__
super().__init__(origin, inst=inst, name=name)
self._nparams = nparams
self._defaults = defaults
if origin.__module__ == 'builtins':
self.__doc__ = f'A generic version of {origin.__qualname__}.'
else:
@ -1594,12 +1595,22 @@ class _SpecialGenericAlias(_NotIterable, _BaseGenericAlias, _root=True):
params = (params,)
msg = "Parameters to generic types must be types."
params = tuple(_type_check(p, msg) for p in params)
if (self._defaults
and len(params) < self._nparams
and len(params) + len(self._defaults) >= self._nparams
):
params = (*params, *self._defaults[len(params) - self._nparams:])
actual_len = len(params)
if actual_len != self._nparams:
if self._defaults:
expected = f"at least {self._nparams - len(self._defaults)}"
else:
expected = str(self._nparams)
if not self._nparams:
raise TypeError(f"{self} is not a generic class")
raise TypeError(f"Too {'many' if actual_len > self._nparams else 'few'} arguments for {self};"
f" actual {actual_len}, expected {self._nparams}")
f" actual {actual_len}, expected {expected}")
return self.copy_with(params)
def copy_with(self, params):
@ -2813,8 +2824,8 @@ DefaultDict = _alias(collections.defaultdict, 2, name='DefaultDict')
OrderedDict = _alias(collections.OrderedDict, 2)
Counter = _alias(collections.Counter, 1)
ChainMap = _alias(collections.ChainMap, 2)
Generator = _alias(collections.abc.Generator, 3)
AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2)
Generator = _alias(collections.abc.Generator, 3, defaults=(types.NoneType, types.NoneType))
AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2, defaults=(types.NoneType,))
Type = _alias(type, 1, inst=False, name='Type')
Type.__doc__ = \
"""Deprecated alias to builtins.type.