mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
gh-82129: Improve annotations for make_dataclass() (#133406)
Co-authored-by: sobolevn <mail@sobolevn.me> Co-authored-by: Carl Meyer <carl@oddbird.net>
This commit is contained in:
parent
4e498d1e8b
commit
bb5ec6ea6e
3 changed files with 97 additions and 12 deletions
|
@ -244,6 +244,10 @@ _ATOMIC_TYPES = frozenset({
|
|||
property,
|
||||
})
|
||||
|
||||
# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
|
||||
# without importing `typing` module.
|
||||
_ANY_MARKER = object()
|
||||
|
||||
|
||||
class InitVar:
|
||||
__slots__ = ('type', )
|
||||
|
@ -1591,7 +1595,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
|||
for item in fields:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
tp = 'typing.Any'
|
||||
tp = _ANY_MARKER
|
||||
elif len(item) == 2:
|
||||
name, tp, = item
|
||||
elif len(item) == 3:
|
||||
|
@ -1610,15 +1614,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
|||
seen.add(name)
|
||||
annotations[name] = tp
|
||||
|
||||
# We initially block the VALUE format, because inside dataclass() we'll
|
||||
# call get_annotations(), which will try the VALUE format first. If we don't
|
||||
# block, that means we'd always end up eagerly importing typing here, which
|
||||
# is what we're trying to avoid.
|
||||
value_blocked = True
|
||||
|
||||
def annotate_method(format):
|
||||
def get_any():
|
||||
match format:
|
||||
case annotationlib.Format.STRING:
|
||||
return 'typing.Any'
|
||||
case annotationlib.Format.FORWARDREF:
|
||||
typing = sys.modules.get("typing")
|
||||
if typing is None:
|
||||
return annotationlib.ForwardRef("Any", module="typing")
|
||||
else:
|
||||
return typing.Any
|
||||
case annotationlib.Format.VALUE:
|
||||
if value_blocked:
|
||||
raise NotImplementedError
|
||||
from typing import Any
|
||||
return Any
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
annos = {
|
||||
ann: get_any() if t is _ANY_MARKER else t
|
||||
for ann, t in annotations.items()
|
||||
}
|
||||
if format == annotationlib.Format.STRING:
|
||||
return annotationlib.annotations_to_string(annos)
|
||||
else:
|
||||
return annos
|
||||
|
||||
# Update 'ns' with the user-supplied namespace plus our calculated values.
|
||||
def exec_body_callback(ns):
|
||||
ns.update(namespace)
|
||||
ns.update(defaults)
|
||||
ns['__annotations__'] = annotations
|
||||
|
||||
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
|
||||
# of generic dataclasses.
|
||||
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
|
||||
# For now, set annotations including the _ANY_MARKER.
|
||||
cls.__annotate__ = annotate_method
|
||||
|
||||
# For pickling to work, the __module__ variable needs to be set to the frame
|
||||
# where the dataclass is created.
|
||||
|
@ -1634,10 +1672,13 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
|||
cls.__module__ = module
|
||||
|
||||
# Apply the normal provided decorator.
|
||||
return decorator(cls, init=init, repr=repr, eq=eq, order=order,
|
||||
unsafe_hash=unsafe_hash, frozen=frozen,
|
||||
match_args=match_args, kw_only=kw_only, slots=slots,
|
||||
weakref_slot=weakref_slot)
|
||||
cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
|
||||
unsafe_hash=unsafe_hash, frozen=frozen,
|
||||
match_args=match_args, kw_only=kw_only, slots=slots,
|
||||
weakref_slot=weakref_slot)
|
||||
# Now that the class is ready, allow the VALUE format.
|
||||
value_blocked = False
|
||||
return cls
|
||||
|
||||
|
||||
def replace(obj, /, **changes):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue