gh-82129: Improve annotations for make_dataclass() (#133406)

Co-authored-by: sobolevn <mail@sobolevn.me>
Co-authored-by: Carl Meyer <carl@oddbird.net>
This commit is contained in:
Jelle Zijlstra 2025-05-05 08:21:32 -07:00 committed by GitHub
parent 4e498d1e8b
commit bb5ec6ea6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 97 additions and 12 deletions

View file

@ -244,6 +244,10 @@ _ATOMIC_TYPES = frozenset({
property, property,
}) })
# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
# without importing `typing` module.
_ANY_MARKER = object()
class InitVar: class InitVar:
__slots__ = ('type', ) __slots__ = ('type', )
@ -1591,7 +1595,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
for item in fields: for item in fields:
if isinstance(item, str): if isinstance(item, str):
name = item name = item
tp = 'typing.Any' tp = _ANY_MARKER
elif len(item) == 2: elif len(item) == 2:
name, tp, = item name, tp, = item
elif len(item) == 3: elif len(item) == 3:
@ -1610,15 +1614,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
seen.add(name) seen.add(name)
annotations[name] = tp annotations[name] = tp
# We initially block the VALUE format, because inside dataclass() we'll
# call get_annotations(), which will try the VALUE format first. If we don't
# block, that means we'd always end up eagerly importing typing here, which
# is what we're trying to avoid.
value_blocked = True
def annotate_method(format):
def get_any():
match format:
case annotationlib.Format.STRING:
return 'typing.Any'
case annotationlib.Format.FORWARDREF:
typing = sys.modules.get("typing")
if typing is None:
return annotationlib.ForwardRef("Any", module="typing")
else:
return typing.Any
case annotationlib.Format.VALUE:
if value_blocked:
raise NotImplementedError
from typing import Any
return Any
case _:
raise NotImplementedError
annos = {
ann: get_any() if t is _ANY_MARKER else t
for ann, t in annotations.items()
}
if format == annotationlib.Format.STRING:
return annotationlib.annotations_to_string(annos)
else:
return annos
# Update 'ns' with the user-supplied namespace plus our calculated values. # Update 'ns' with the user-supplied namespace plus our calculated values.
def exec_body_callback(ns): def exec_body_callback(ns):
ns.update(namespace) ns.update(namespace)
ns.update(defaults) ns.update(defaults)
ns['__annotations__'] = annotations
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
# of generic dataclasses. # of generic dataclasses.
cls = types.new_class(cls_name, bases, {}, exec_body_callback) cls = types.new_class(cls_name, bases, {}, exec_body_callback)
# For now, set annotations including the _ANY_MARKER.
cls.__annotate__ = annotate_method
# For pickling to work, the __module__ variable needs to be set to the frame # For pickling to work, the __module__ variable needs to be set to the frame
# where the dataclass is created. # where the dataclass is created.
@ -1634,10 +1672,13 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
cls.__module__ = module cls.__module__ = module
# Apply the normal provided decorator. # Apply the normal provided decorator.
return decorator(cls, init=init, repr=repr, eq=eq, order=order, cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash, frozen=frozen, unsafe_hash=unsafe_hash, frozen=frozen,
match_args=match_args, kw_only=kw_only, slots=slots, match_args=match_args, kw_only=kw_only, slots=slots,
weakref_slot=weakref_slot) weakref_slot=weakref_slot)
# Now that the class is ready, allow the VALUE format.
value_blocked = False
return cls
def replace(obj, /, **changes): def replace(obj, /, **changes):

View file

@ -5,6 +5,7 @@
from dataclasses import * from dataclasses import *
import abc import abc
import annotationlib
import io import io
import pickle import pickle
import inspect import inspect
@ -12,6 +13,7 @@ import builtins
import types import types
import weakref import weakref
import traceback import traceback
import sys
import textwrap import textwrap
import unittest import unittest
from unittest.mock import Mock from unittest.mock import Mock
@ -25,6 +27,7 @@ import typing # Needed for the string "typing.ClassVar[int]" to work as an
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
from test import support from test import support
from test.support import import_helper
# Just any custom exception we can catch. # Just any custom exception we can catch.
class CustomError(Exception): pass class CustomError(Exception): pass
@ -3754,7 +3757,6 @@ class TestSlots(unittest.TestCase):
@support.cpython_only @support.cpython_only
def test_dataclass_slot_dict_ctype(self): def test_dataclass_slot_dict_ctype(self):
# https://github.com/python/cpython/issues/123935 # https://github.com/python/cpython/issues/123935
from test.support import import_helper
# Skips test if `_testcapi` is not present: # Skips test if `_testcapi` is not present:
_testcapi = import_helper.import_module('_testcapi') _testcapi = import_helper.import_module('_testcapi')
@ -4246,16 +4248,56 @@ class TestMakeDataclass(unittest.TestCase):
C = make_dataclass('Point', ['x', 'y', 'z']) C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3) c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any', self.assertEqual(C.__annotations__, {'x': typing.Any,
'y': 'typing.Any', 'y': typing.Any,
'z': 'typing.Any'}) 'z': typing.Any})
C = make_dataclass('Point', ['x', ('y', int), 'z']) C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3) c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any', self.assertEqual(C.__annotations__, {'x': typing.Any,
'y': int, 'y': int,
'z': 'typing.Any'}) 'z': typing.Any})
def test_no_types_get_annotations(self):
C = make_dataclass('C', ['x', ('y', int), 'z'])
self.assertEqual(
annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
{'x': typing.Any, 'y': int, 'z': typing.Any},
)
self.assertEqual(
annotationlib.get_annotations(
C, format=annotationlib.Format.FORWARDREF),
{'x': typing.Any, 'y': int, 'z': typing.Any},
)
self.assertEqual(
annotationlib.get_annotations(
C, format=annotationlib.Format.STRING),
{'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
)
def test_no_types_no_typing_import(self):
with import_helper.CleanImport('typing'):
self.assertNotIn('typing', sys.modules)
C = make_dataclass('C', ['x', ('y', int)])
self.assertNotIn('typing', sys.modules)
self.assertEqual(
C.__annotate__(annotationlib.Format.FORWARDREF),
{
'x': annotationlib.ForwardRef('Any', module='typing'),
'y': int,
},
)
self.assertNotIn('typing', sys.modules)
for field in fields(C):
if field.name == "x":
self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
else:
self.assertEqual(field.name, "y")
self.assertIs(field.type, int)
def test_module_attr(self): def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__) self.assertEqual(ByMakeDataClass.__module__, __name__)

View file

@ -0,0 +1,2 @@
Fix :exc:`NameError` when calling :func:`typing.get_type_hints` on a :func:`dataclasses.dataclass` created by
:func:`dataclasses.make_dataclass` with un-annotated fields.