If using a frozen class with slots, add __getstate__ and __setstate__ to set the instance values. (GH-25786)

This commit is contained in:
Eric V. Smith 2021-05-01 13:27:30 -04:00 committed by GitHub
parent f82fd77717
commit 823fbf4e0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 2 deletions

View file

@ -1087,14 +1087,28 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
tuple(f.name for f in std_init_fields)) tuple(f.name for f in std_init_fields))
if slots: if slots:
cls = _add_slots(cls) cls = _add_slots(cls, frozen)
abc.update_abstractmethods(cls) abc.update_abstractmethods(cls)
return cls return cls
def _add_slots(cls): # _dataclass_getstate and _dataclass_setstate are needed for pickling frozen
# classes with slots. These could be slighly more performant if we generated
# the code instead of iterating over fields. But that can be a project for
# another day, if performance becomes an issue.
def _dataclass_getstate(self):
return [getattr(self, f.name) for f in fields(self)]
def _dataclass_setstate(self, state):
for field, value in zip(fields(self), state):
# use setattr because dataclass may be frozen
object.__setattr__(self, field.name, value)
def _add_slots(cls, is_frozen):
# Need to create a new class, since we can't set __slots__ # Need to create a new class, since we can't set __slots__
# after a class has been created. # after a class has been created.
@ -1120,6 +1134,11 @@ def _add_slots(cls):
if qualname is not None: if qualname is not None:
cls.__qualname__ = qualname cls.__qualname__ = qualname
if is_frozen:
# Need this for pickling frozen classes with slots.
cls.__getstate__ = _dataclass_getstate
cls.__setstate__ = _dataclass_setstate
return cls return cls

View file

@ -2833,6 +2833,19 @@ class TestSlots(unittest.TestCase):
self.assertFalse(hasattr(A, "__slots__")) self.assertFalse(hasattr(A, "__slots__"))
self.assertTrue(hasattr(B, "__slots__")) self.assertTrue(hasattr(B, "__slots__"))
# Can't be local to test_frozen_pickle.
@dataclass(frozen=True, slots=True)
class FrozenSlotsClass:
foo: str
bar: int
def test_frozen_pickle(self):
# bpo-43999
assert self.FrozenSlotsClass.__slots__ == ("foo", "bar")
p = pickle.dumps(self.FrozenSlotsClass("a", 1))
assert pickle.loads(p) == self.FrozenSlotsClass("a", 1)
class TestDescriptors(unittest.TestCase): class TestDescriptors(unittest.TestCase):
def test_set_name(self): def test_set_name(self):