mirror of
https://github.com/python/cpython.git
synced 2025-07-07 19:35:27 +00:00
gh-82129: Improve annotations for make_dataclass() (#133406)
Co-authored-by: sobolevn <mail@sobolevn.me> Co-authored-by: Carl Meyer <carl@oddbird.net>
This commit is contained in:
parent
4e498d1e8b
commit
bb5ec6ea6e
3 changed files with 97 additions and 12 deletions
|
@ -244,6 +244,10 @@ _ATOMIC_TYPES = frozenset({
|
||||||
property,
|
property,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
|
||||||
|
# without importing `typing` module.
|
||||||
|
_ANY_MARKER = object()
|
||||||
|
|
||||||
|
|
||||||
class InitVar:
|
class InitVar:
|
||||||
__slots__ = ('type', )
|
__slots__ = ('type', )
|
||||||
|
@ -1591,7 +1595,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
||||||
for item in fields:
|
for item in fields:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
name = item
|
name = item
|
||||||
tp = 'typing.Any'
|
tp = _ANY_MARKER
|
||||||
elif len(item) == 2:
|
elif len(item) == 2:
|
||||||
name, tp, = item
|
name, tp, = item
|
||||||
elif len(item) == 3:
|
elif len(item) == 3:
|
||||||
|
@ -1610,15 +1614,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
||||||
seen.add(name)
|
seen.add(name)
|
||||||
annotations[name] = tp
|
annotations[name] = tp
|
||||||
|
|
||||||
|
# We initially block the VALUE format, because inside dataclass() we'll
|
||||||
|
# call get_annotations(), which will try the VALUE format first. If we don't
|
||||||
|
# block, that means we'd always end up eagerly importing typing here, which
|
||||||
|
# is what we're trying to avoid.
|
||||||
|
value_blocked = True
|
||||||
|
|
||||||
|
def annotate_method(format):
|
||||||
|
def get_any():
|
||||||
|
match format:
|
||||||
|
case annotationlib.Format.STRING:
|
||||||
|
return 'typing.Any'
|
||||||
|
case annotationlib.Format.FORWARDREF:
|
||||||
|
typing = sys.modules.get("typing")
|
||||||
|
if typing is None:
|
||||||
|
return annotationlib.ForwardRef("Any", module="typing")
|
||||||
|
else:
|
||||||
|
return typing.Any
|
||||||
|
case annotationlib.Format.VALUE:
|
||||||
|
if value_blocked:
|
||||||
|
raise NotImplementedError
|
||||||
|
from typing import Any
|
||||||
|
return Any
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
annos = {
|
||||||
|
ann: get_any() if t is _ANY_MARKER else t
|
||||||
|
for ann, t in annotations.items()
|
||||||
|
}
|
||||||
|
if format == annotationlib.Format.STRING:
|
||||||
|
return annotationlib.annotations_to_string(annos)
|
||||||
|
else:
|
||||||
|
return annos
|
||||||
|
|
||||||
# Update 'ns' with the user-supplied namespace plus our calculated values.
|
# Update 'ns' with the user-supplied namespace plus our calculated values.
|
||||||
def exec_body_callback(ns):
|
def exec_body_callback(ns):
|
||||||
ns.update(namespace)
|
ns.update(namespace)
|
||||||
ns.update(defaults)
|
ns.update(defaults)
|
||||||
ns['__annotations__'] = annotations
|
|
||||||
|
|
||||||
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
|
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
|
||||||
# of generic dataclasses.
|
# of generic dataclasses.
|
||||||
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
|
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
|
||||||
|
# For now, set annotations including the _ANY_MARKER.
|
||||||
|
cls.__annotate__ = annotate_method
|
||||||
|
|
||||||
# For pickling to work, the __module__ variable needs to be set to the frame
|
# For pickling to work, the __module__ variable needs to be set to the frame
|
||||||
# where the dataclass is created.
|
# where the dataclass is created.
|
||||||
|
@ -1634,10 +1672,13 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
||||||
cls.__module__ = module
|
cls.__module__ = module
|
||||||
|
|
||||||
# Apply the normal provided decorator.
|
# Apply the normal provided decorator.
|
||||||
return decorator(cls, init=init, repr=repr, eq=eq, order=order,
|
cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
|
||||||
unsafe_hash=unsafe_hash, frozen=frozen,
|
unsafe_hash=unsafe_hash, frozen=frozen,
|
||||||
match_args=match_args, kw_only=kw_only, slots=slots,
|
match_args=match_args, kw_only=kw_only, slots=slots,
|
||||||
weakref_slot=weakref_slot)
|
weakref_slot=weakref_slot)
|
||||||
|
# Now that the class is ready, allow the VALUE format.
|
||||||
|
value_blocked = False
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
def replace(obj, /, **changes):
|
def replace(obj, /, **changes):
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
from dataclasses import *
|
from dataclasses import *
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import annotationlib
|
||||||
import io
|
import io
|
||||||
import pickle
|
import pickle
|
||||||
import inspect
|
import inspect
|
||||||
|
@ -12,6 +13,7 @@ import builtins
|
||||||
import types
|
import types
|
||||||
import weakref
|
import weakref
|
||||||
import traceback
|
import traceback
|
||||||
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
@ -25,6 +27,7 @@ import typing # Needed for the string "typing.ClassVar[int]" to work as an
|
||||||
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
|
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
|
||||||
|
|
||||||
from test import support
|
from test import support
|
||||||
|
from test.support import import_helper
|
||||||
|
|
||||||
# Just any custom exception we can catch.
|
# Just any custom exception we can catch.
|
||||||
class CustomError(Exception): pass
|
class CustomError(Exception): pass
|
||||||
|
@ -3754,7 +3757,6 @@ class TestSlots(unittest.TestCase):
|
||||||
@support.cpython_only
|
@support.cpython_only
|
||||||
def test_dataclass_slot_dict_ctype(self):
|
def test_dataclass_slot_dict_ctype(self):
|
||||||
# https://github.com/python/cpython/issues/123935
|
# https://github.com/python/cpython/issues/123935
|
||||||
from test.support import import_helper
|
|
||||||
# Skips test if `_testcapi` is not present:
|
# Skips test if `_testcapi` is not present:
|
||||||
_testcapi = import_helper.import_module('_testcapi')
|
_testcapi = import_helper.import_module('_testcapi')
|
||||||
|
|
||||||
|
@ -4246,16 +4248,56 @@ class TestMakeDataclass(unittest.TestCase):
|
||||||
C = make_dataclass('Point', ['x', 'y', 'z'])
|
C = make_dataclass('Point', ['x', 'y', 'z'])
|
||||||
c = C(1, 2, 3)
|
c = C(1, 2, 3)
|
||||||
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
|
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
|
||||||
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
|
self.assertEqual(C.__annotations__, {'x': typing.Any,
|
||||||
'y': 'typing.Any',
|
'y': typing.Any,
|
||||||
'z': 'typing.Any'})
|
'z': typing.Any})
|
||||||
|
|
||||||
C = make_dataclass('Point', ['x', ('y', int), 'z'])
|
C = make_dataclass('Point', ['x', ('y', int), 'z'])
|
||||||
c = C(1, 2, 3)
|
c = C(1, 2, 3)
|
||||||
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
|
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
|
||||||
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
|
self.assertEqual(C.__annotations__, {'x': typing.Any,
|
||||||
'y': int,
|
'y': int,
|
||||||
'z': 'typing.Any'})
|
'z': typing.Any})
|
||||||
|
|
||||||
|
def test_no_types_get_annotations(self):
|
||||||
|
C = make_dataclass('C', ['x', ('y', int), 'z'])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
|
||||||
|
{'x': typing.Any, 'y': int, 'z': typing.Any},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(
|
||||||
|
C, format=annotationlib.Format.FORWARDREF),
|
||||||
|
{'x': typing.Any, 'y': int, 'z': typing.Any},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(
|
||||||
|
C, format=annotationlib.Format.STRING),
|
||||||
|
{'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_types_no_typing_import(self):
|
||||||
|
with import_helper.CleanImport('typing'):
|
||||||
|
self.assertNotIn('typing', sys.modules)
|
||||||
|
C = make_dataclass('C', ['x', ('y', int)])
|
||||||
|
|
||||||
|
self.assertNotIn('typing', sys.modules)
|
||||||
|
self.assertEqual(
|
||||||
|
C.__annotate__(annotationlib.Format.FORWARDREF),
|
||||||
|
{
|
||||||
|
'x': annotationlib.ForwardRef('Any', module='typing'),
|
||||||
|
'y': int,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertNotIn('typing', sys.modules)
|
||||||
|
|
||||||
|
for field in fields(C):
|
||||||
|
if field.name == "x":
|
||||||
|
self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
|
||||||
|
else:
|
||||||
|
self.assertEqual(field.name, "y")
|
||||||
|
self.assertIs(field.type, int)
|
||||||
|
|
||||||
def test_module_attr(self):
|
def test_module_attr(self):
|
||||||
self.assertEqual(ByMakeDataClass.__module__, __name__)
|
self.assertEqual(ByMakeDataClass.__module__, __name__)
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
Fix :exc:`NameError` when calling :func:`typing.get_type_hints` on a :func:`dataclasses.dataclass` created by
|
||||||
|
:func:`dataclasses.make_dataclass` with un-annotated fields.
|
Loading…
Add table
Add a link
Reference in a new issue