gh-98169 dataclasses.astuple support DefaultDict (#98170)

Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
This commit is contained in:
T 2023-03-14 04:46:35 +08:00 committed by GitHub
parent 85ba8a3e03
commit 71e37d9079
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 13 deletions

View file

@ -1321,15 +1321,14 @@ def _asdict_inner(obj, dict_factory):
# generator (which is not true for namedtuples, handled # generator (which is not true for namedtuples, handled
# above). # above).
return type(obj)(_asdict_inner(v, dict_factory) for v in obj) return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
elif isinstance(obj, dict) and hasattr(type(obj), 'default_factory'):
# obj is a defaultdict, which has a different constructor from
# dict as it requires the default_factory as its first arg.
# https://bugs.python.org/issue35540
result = type(obj)(getattr(obj, 'default_factory'))
for k, v in obj.items():
result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
return result
elif isinstance(obj, dict): elif isinstance(obj, dict):
if hasattr(type(obj), 'default_factory'):
# obj is a defaultdict, which has a different constructor from
# dict as it requires the default_factory as its first arg.
result = type(obj)(getattr(obj, 'default_factory'))
for k, v in obj.items():
result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
return result
return type(obj)((_asdict_inner(k, dict_factory), return type(obj)((_asdict_inner(k, dict_factory),
_asdict_inner(v, dict_factory)) _asdict_inner(v, dict_factory))
for k, v in obj.items()) for k, v in obj.items())
@ -1382,7 +1381,15 @@ def _astuple_inner(obj, tuple_factory):
# above). # above).
return type(obj)(_astuple_inner(v, tuple_factory) for v in obj) return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
elif isinstance(obj, dict): elif isinstance(obj, dict):
return type(obj)((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory)) obj_type = type(obj)
if hasattr(obj_type, 'default_factory'):
# obj is a defaultdict, which has a different constructor from
# dict as it requires the default_factory as its first arg.
result = obj_type(getattr(obj, 'default_factory'))
for k, v in obj.items():
result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, tuple_factory)
return result
return obj_type((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory))
for k, v in obj.items()) for k, v in obj.items())
else: else:
return copy.deepcopy(obj) return copy.deepcopy(obj)

View file

@ -1706,19 +1706,17 @@ class TestCase(unittest.TestCase):
def test_helper_asdict_defaultdict(self): def test_helper_asdict_defaultdict(self):
# Ensure asdict() does not throw exceptions when a # Ensure asdict() does not throw exceptions when a
# defaultdict is a member of a dataclass # defaultdict is a member of a dataclass
@dataclass @dataclass
class C: class C:
mp: DefaultDict[str, List] mp: DefaultDict[str, List]
dd = defaultdict(list) dd = defaultdict(list)
dd["x"].append(12) dd["x"].append(12)
c = C(mp=dd) c = C(mp=dd)
d = asdict(c) d = asdict(c)
assert d == {"mp": {"x": [12]}} self.assertEqual(d, {"mp": {"x": [12]}})
assert d["mp"] is not c.mp # make sure defaultdict is copied self.assertTrue(d["mp"] is not c.mp) # make sure defaultdict is copied
def test_helper_astuple(self): def test_helper_astuple(self):
# Basic tests for astuple(), it should return a new tuple. # Basic tests for astuple(), it should return a new tuple.
@ -1847,6 +1845,21 @@ class TestCase(unittest.TestCase):
t = astuple(c, tuple_factory=list) t = astuple(c, tuple_factory=list)
self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
def test_helper_astuple_defaultdict(self):
# Ensure astuple() does not throw exceptions when a
# defaultdict is a member of a dataclass
@dataclass
class C:
mp: DefaultDict[str, List]
dd = defaultdict(list)
dd["x"].append(12)
c = C(mp=dd)
t = astuple(c)
self.assertEqual(t, ({"x": [12]},))
self.assertTrue(t[0] is not dd) # make sure defaultdict is copied
def test_dynamic_class_creation(self): def test_dynamic_class_creation(self):
cls_dict = {'__annotations__': {'x': int, 'y': int}, cls_dict = {'__annotations__': {'x': int, 'y': int},
} }

View file

@ -0,0 +1,2 @@
Fix :func:`dataclasses.astuple` crash when :class:`collections.defaultdict`
is present in the attributes.