bpo-26579: Add object.__getstate__(). (GH-2821)

Copying and pickling instances of subclasses of builtin types
bytearray, set, frozenset, collections.OrderedDict, collections.deque,
weakref.WeakSet, and datetime.tzinfo now copies and pickles instance attributes
implemented as slots.
This commit is contained in:
Serhiy Storchaka 2022-04-06 20:00:14 +03:00 committed by GitHub
parent f82f9ce323
commit 884eba3c76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 389 additions and 255 deletions

View file

@ -80,8 +80,7 @@ class WeakSet:
return wr in self.data
def __reduce__(self):
return (self.__class__, (list(self),),
getattr(self, '__dict__', None))
return self.__class__, (list(self),), self.__getstate__()
def add(self, item):
if self._pending_removals:

View file

@ -271,10 +271,22 @@ class OrderedDict(dict):
def __reduce__(self):
'Return state information for pickling'
inst_dict = vars(self).copy()
for k in vars(OrderedDict()):
inst_dict.pop(k, None)
return self.__class__, (), inst_dict or None, None, iter(self.items())
state = self.__getstate__()
if state:
if isinstance(state, tuple):
state, slots = state
else:
slots = {}
state = state.copy()
slots = slots.copy()
for k in vars(OrderedDict()):
state.pop(k, None)
slots.pop(k, None)
if slots:
state = state, slots
else:
state = state or None
return self.__class__, (), state, None, iter(self.items())
def copy(self):
'od.copy() -> a shallow copy of od'

View file

@ -89,6 +89,10 @@ def _reduce_ex(self, proto):
except AttributeError:
dict = None
else:
if (type(self).__getstate__ is object.__getstate__ and
getattr(self, "__slots__", None)):
raise TypeError("a class that defines __slots__ without "
"defining __getstate__ cannot be pickled")
dict = getstate()
if dict:
return _reconstructor, args, dict

View file

@ -1169,15 +1169,7 @@ class tzinfo:
args = getinitargs()
else:
args = ()
getstate = getattr(self, "__getstate__", None)
if getstate:
state = getstate()
else:
state = getattr(self, "__dict__", None) or None
if state is None:
return (self.__class__, args)
else:
return (self.__class__, args, state)
return (self.__class__, args, self.__getstate__())
class IsoCalendarDate(tuple):

View file

@ -218,7 +218,7 @@ class BaseHeader(str):
self.__class__.__bases__,
str(self),
),
self.__dict__)
self.__getstate__())
@classmethod
def _reconstruct(cls, value):

View file

@ -139,8 +139,8 @@ class PicklableFixedOffset(FixedOffset):
def __init__(self, offset=None, name=None, dstoffset=None):
FixedOffset.__init__(self, offset, name, dstoffset)
def __getstate__(self):
return self.__dict__
class PicklableFixedOffsetWithSlots(PicklableFixedOffset):
__slots__ = '_FixedOffset__offset', '_FixedOffset__name', 'spam'
class _TZInfo(tzinfo):
def utcoffset(self, datetime_module):
@ -202,6 +202,7 @@ class TestTZInfo(unittest.TestCase):
offset = timedelta(minutes=-300)
for otype, args in [
(PicklableFixedOffset, (offset, 'cookie')),
(PicklableFixedOffsetWithSlots, (offset, 'cookie')),
(timezone, (offset,)),
(timezone, (offset, "EST"))]:
orig = otype(*args)
@ -217,6 +218,7 @@ class TestTZInfo(unittest.TestCase):
self.assertIs(type(derived), otype)
self.assertEqual(derived.utcoffset(None), offset)
self.assertEqual(derived.tzname(None), oname)
self.assertFalse(hasattr(derived, 'spam'))
def test_issue23600(self):
DSTDIFF = DSTOFFSET = timedelta(hours=1)

View file

@ -2382,9 +2382,11 @@ class AbstractPickleTests:
def test_bad_getattr(self):
# Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr()
for proto in protocols:
for proto in range(2):
with support.infinite_recursion():
self.assertRaises(RuntimeError, self.dumps, x, proto)
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(x, proto)
def test_reduce_bad_iterator(self):
# Issue4176: crash when 4th and 5th items of __reduce__()

View file

@ -1940,28 +1940,30 @@ class SubclassTest:
def test_pickle(self):
a = self.type2test(b"abcd")
a.x = 10
a.y = self.type2test(b"efgh")
a.z = self.type2test(b"efgh")
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
b = pickle.loads(pickle.dumps(a, proto))
self.assertNotEqual(id(a), id(b))
self.assertEqual(a, b)
self.assertEqual(a.x, b.x)
self.assertEqual(a.y, b.y)
self.assertEqual(a.z, b.z)
self.assertEqual(type(a), type(b))
self.assertEqual(type(a.y), type(b.y))
self.assertEqual(type(a.z), type(b.z))
self.assertFalse(hasattr(b, 'y'))
def test_copy(self):
a = self.type2test(b"abcd")
a.x = 10
a.y = self.type2test(b"efgh")
a.z = self.type2test(b"efgh")
for copy_method in (copy.copy, copy.deepcopy):
b = copy_method(a)
self.assertNotEqual(id(a), id(b))
self.assertEqual(a, b)
self.assertEqual(a.x, b.x)
self.assertEqual(a.y, b.y)
self.assertEqual(a.z, b.z)
self.assertEqual(type(a), type(b))
self.assertEqual(type(a.y), type(b.y))
self.assertEqual(type(a.z), type(b.z))
self.assertFalse(hasattr(b, 'y'))
def test_fromhex(self):
b = self.type2test.fromhex('1a2B30')
@ -1994,6 +1996,9 @@ class SubclassTest:
class ByteArraySubclass(bytearray):
pass
class ByteArraySubclassWithSlots(bytearray):
__slots__ = ('x', 'y', '__dict__')
class BytesSubclass(bytes):
pass
@ -2014,6 +2019,9 @@ class ByteArraySubclassTest(SubclassTest, unittest.TestCase):
x = subclass(newarg=4, source=b"abcd")
self.assertEqual(x, b"abcd")
class ByteArraySubclassWithSlotsTest(SubclassTest, unittest.TestCase):
basetype = bytearray
type2test = ByteArraySubclassWithSlots
class BytesSubclassTest(SubclassTest, unittest.TestCase):
basetype = bytes

View file

@ -781,6 +781,9 @@ class TestVariousIteratorArgs(unittest.TestCase):
class Deque(deque):
pass
class DequeWithSlots(deque):
__slots__ = ('x', 'y', '__dict__')
class DequeWithBadIter(deque):
def __iter__(self):
raise TypeError
@ -810,40 +813,28 @@ class TestSubclass(unittest.TestCase):
self.assertEqual(len(d), 0)
def test_copy_pickle(self):
for cls in Deque, DequeWithSlots:
for d in cls('abc'), cls('abcde', maxlen=4):
d.x = ['x']
d.z = ['z']
d = Deque('abc')
e = d.__copy__()
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
e = d.__copy__()
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
e = cls(d)
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
e = Deque(d)
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
s = pickle.dumps(d, proto)
e = pickle.loads(s)
self.assertNotEqual(id(d), id(e))
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
d = Deque('abcde', maxlen=4)
e = d.__copy__()
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
e = Deque(d)
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
s = pickle.dumps(d, proto)
e = pickle.loads(s)
self.assertNotEqual(id(d), id(e))
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
s = pickle.dumps(d, proto)
e = pickle.loads(s)
self.assertNotEqual(id(d), id(e))
self.assertEqual(type(d), type(e))
self.assertEqual(list(d), list(e))
self.assertEqual(e.x, d.x)
self.assertEqual(e.z, d.z)
self.assertFalse(hasattr(e, 'y'))
def test_pickle_recursive(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):

View file

@ -181,6 +181,7 @@ You can get the information from the list type:
'__ge__',
'__getattribute__',
'__getitem__',
'__getstate__',
'__gt__',
'__hash__',
'__iadd__',

View file

@ -287,6 +287,8 @@ class OrderedDictTests:
# and have a repr/eval round-trip
pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
od = OrderedDict(pairs)
od.x = ['x']
od.z = ['z']
def check(dup):
msg = "\ncopy: %s\nod: %s" % (dup, od)
self.assertIsNot(dup, od, msg)
@ -295,13 +297,27 @@ class OrderedDictTests:
self.assertEqual(len(dup), len(od))
self.assertEqual(type(dup), type(od))
check(od.copy())
check(copy.copy(od))
check(copy.deepcopy(od))
dup = copy.copy(od)
check(dup)
self.assertIs(dup.x, od.x)
self.assertIs(dup.z, od.z)
self.assertFalse(hasattr(dup, 'y'))
dup = copy.deepcopy(od)
check(dup)
self.assertEqual(dup.x, od.x)
self.assertIsNot(dup.x, od.x)
self.assertEqual(dup.z, od.z)
self.assertIsNot(dup.z, od.z)
self.assertFalse(hasattr(dup, 'y'))
# pickle directly pulls the module, so we have to fake it
with replaced_module('collections', self.module):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
check(pickle.loads(pickle.dumps(od, proto)))
dup = pickle.loads(pickle.dumps(od, proto))
check(dup)
self.assertEqual(dup.x, od.x)
self.assertEqual(dup.z, od.z)
self.assertFalse(hasattr(dup, 'y'))
check(eval(repr(od)))
update_test = OrderedDict()
update_test.update(od)
@ -846,6 +862,23 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
pass
class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
module = py_coll
class OrderedDict(py_coll.OrderedDict):
__slots__ = ('x', 'y')
test_copying = OrderedDictTests.test_copying
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
module = c_coll
class OrderedDict(c_coll.OrderedDict):
__slots__ = ('x', 'y')
test_copying = OrderedDictTests.test_copying
class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
@classmethod

View file

@ -227,14 +227,17 @@ class TestJointOps:
def test_pickling(self):
for i in range(pickle.HIGHEST_PROTOCOL + 1):
if type(self.s) not in (set, frozenset):
self.s.x = ['x']
self.s.z = ['z']
p = pickle.dumps(self.s, i)
dup = pickle.loads(p)
self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
if type(self.s) not in (set, frozenset):
self.s.x = 10
p = pickle.dumps(self.s, i)
dup = pickle.loads(p)
self.assertEqual(self.s.x, dup.x)
self.assertEqual(self.s.z, dup.z)
self.assertFalse(hasattr(self.s, 'y'))
del self.s.x, self.s.z
def test_iterator_pickling(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@ -808,6 +811,21 @@ class TestFrozenSetSubclass(TestFrozenSet):
# All empty frozenset subclass instances should have different ids
self.assertEqual(len(set(map(id, efs))), len(efs))
class SetSubclassWithSlots(set):
__slots__ = ('x', 'y', '__dict__')
class TestSetSubclassWithSlots(unittest.TestCase):
thetype = SetSubclassWithSlots
setUp = TestJointOps.setUp
test_pickling = TestJointOps.test_pickling
class FrozenSetSubclassWithSlots(frozenset):
__slots__ = ('x', 'y', '__dict__')
class TestFrozenSetSubclassWithSlots(TestSetSubclassWithSlots):
thetype = FrozenSetSubclassWithSlots
# Tests taken from test_sets.py =============================================
empty_set = set()

View file

@ -1,5 +1,6 @@
import unittest
from weakref import WeakSet
import copy
import string
from collections import UserString as ustr
from collections.abc import Set, MutableSet
@ -15,6 +16,12 @@ class RefCycle:
def __init__(self):
self.cycle = self
class WeakSetSubclass(WeakSet):
pass
class WeakSetWithSlots(WeakSet):
__slots__ = ('x', 'y')
class TestWeakSet(unittest.TestCase):
@ -447,6 +454,30 @@ class TestWeakSet(unittest.TestCase):
self.assertIsInstance(self.s, Set)
self.assertIsInstance(self.s, MutableSet)
def test_copying(self):
for cls in WeakSet, WeakSetWithSlots:
s = cls(self.items)
s.x = ['x']
s.z = ['z']
dup = copy.copy(s)
self.assertIsInstance(dup, cls)
self.assertEqual(dup, s)
self.assertIsNot(dup, s)
self.assertIs(dup.x, s.x)
self.assertIs(dup.z, s.z)
self.assertFalse(hasattr(dup, 'y'))
dup = copy.deepcopy(s)
self.assertIsInstance(dup, cls)
self.assertEqual(dup, s)
self.assertIsNot(dup, s)
self.assertEqual(dup.x, s.x)
self.assertIsNot(dup.x, s.x)
self.assertEqual(dup.z, s.z)
self.assertIsNot(dup.z, s.z)
self.assertFalse(hasattr(dup, 'y'))
if __name__ == "__main__":
unittest.main()

View file

@ -2524,8 +2524,7 @@ class BasicElementTest(ElementTestCase, unittest.TestCase):
<group><dogs>4</dogs>
</group>"""
e1 = dumper.fromstring(XMLTEXT)
if hasattr(e1, '__getstate__'):
self.assertEqual(e1.__getstate__()['tag'], 'group')
self.assertEqual(e1.__getstate__()['tag'], 'group')
e2 = self.pickleRoundTrip(e1, 'xml.etree.ElementTree',
dumper, loader, proto)
self.assertEqual(e2.tag, 'group')