mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
bpo-26579: Add object.__getstate__(). (GH-2821)
Copying and pickling instances of subclasses of builtin types bytearray, set, frozenset, collections.OrderedDict, collections.deque, weakref.WeakSet, and datetime.tzinfo now copies and pickles instance attributes implemented as slots.
This commit is contained in:
parent
f82f9ce323
commit
884eba3c76
25 changed files with 389 additions and 255 deletions
|
@ -80,8 +80,7 @@ class WeakSet:
|
|||
return wr in self.data
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (list(self),),
|
||||
getattr(self, '__dict__', None))
|
||||
return self.__class__, (list(self),), self.__getstate__()
|
||||
|
||||
def add(self, item):
|
||||
if self._pending_removals:
|
||||
|
|
|
@ -271,10 +271,22 @@ class OrderedDict(dict):
|
|||
|
||||
def __reduce__(self):
|
||||
'Return state information for pickling'
|
||||
inst_dict = vars(self).copy()
|
||||
for k in vars(OrderedDict()):
|
||||
inst_dict.pop(k, None)
|
||||
return self.__class__, (), inst_dict or None, None, iter(self.items())
|
||||
state = self.__getstate__()
|
||||
if state:
|
||||
if isinstance(state, tuple):
|
||||
state, slots = state
|
||||
else:
|
||||
slots = {}
|
||||
state = state.copy()
|
||||
slots = slots.copy()
|
||||
for k in vars(OrderedDict()):
|
||||
state.pop(k, None)
|
||||
slots.pop(k, None)
|
||||
if slots:
|
||||
state = state, slots
|
||||
else:
|
||||
state = state or None
|
||||
return self.__class__, (), state, None, iter(self.items())
|
||||
|
||||
def copy(self):
|
||||
'od.copy() -> a shallow copy of od'
|
||||
|
|
|
@ -89,6 +89,10 @@ def _reduce_ex(self, proto):
|
|||
except AttributeError:
|
||||
dict = None
|
||||
else:
|
||||
if (type(self).__getstate__ is object.__getstate__ and
|
||||
getattr(self, "__slots__", None)):
|
||||
raise TypeError("a class that defines __slots__ without "
|
||||
"defining __getstate__ cannot be pickled")
|
||||
dict = getstate()
|
||||
if dict:
|
||||
return _reconstructor, args, dict
|
||||
|
|
|
@ -1169,15 +1169,7 @@ class tzinfo:
|
|||
args = getinitargs()
|
||||
else:
|
||||
args = ()
|
||||
getstate = getattr(self, "__getstate__", None)
|
||||
if getstate:
|
||||
state = getstate()
|
||||
else:
|
||||
state = getattr(self, "__dict__", None) or None
|
||||
if state is None:
|
||||
return (self.__class__, args)
|
||||
else:
|
||||
return (self.__class__, args, state)
|
||||
return (self.__class__, args, self.__getstate__())
|
||||
|
||||
|
||||
class IsoCalendarDate(tuple):
|
||||
|
|
|
@ -218,7 +218,7 @@ class BaseHeader(str):
|
|||
self.__class__.__bases__,
|
||||
str(self),
|
||||
),
|
||||
self.__dict__)
|
||||
self.__getstate__())
|
||||
|
||||
@classmethod
|
||||
def _reconstruct(cls, value):
|
||||
|
|
|
@ -139,8 +139,8 @@ class PicklableFixedOffset(FixedOffset):
|
|||
def __init__(self, offset=None, name=None, dstoffset=None):
|
||||
FixedOffset.__init__(self, offset, name, dstoffset)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
class PicklableFixedOffsetWithSlots(PicklableFixedOffset):
|
||||
__slots__ = '_FixedOffset__offset', '_FixedOffset__name', 'spam'
|
||||
|
||||
class _TZInfo(tzinfo):
|
||||
def utcoffset(self, datetime_module):
|
||||
|
@ -202,6 +202,7 @@ class TestTZInfo(unittest.TestCase):
|
|||
offset = timedelta(minutes=-300)
|
||||
for otype, args in [
|
||||
(PicklableFixedOffset, (offset, 'cookie')),
|
||||
(PicklableFixedOffsetWithSlots, (offset, 'cookie')),
|
||||
(timezone, (offset,)),
|
||||
(timezone, (offset, "EST"))]:
|
||||
orig = otype(*args)
|
||||
|
@ -217,6 +218,7 @@ class TestTZInfo(unittest.TestCase):
|
|||
self.assertIs(type(derived), otype)
|
||||
self.assertEqual(derived.utcoffset(None), offset)
|
||||
self.assertEqual(derived.tzname(None), oname)
|
||||
self.assertFalse(hasattr(derived, 'spam'))
|
||||
|
||||
def test_issue23600(self):
|
||||
DSTDIFF = DSTOFFSET = timedelta(hours=1)
|
||||
|
|
|
@ -2382,9 +2382,11 @@ class AbstractPickleTests:
|
|||
def test_bad_getattr(self):
|
||||
# Issue #3514: crash when there is an infinite loop in __getattr__
|
||||
x = BadGetattr()
|
||||
for proto in protocols:
|
||||
for proto in range(2):
|
||||
with support.infinite_recursion():
|
||||
self.assertRaises(RuntimeError, self.dumps, x, proto)
|
||||
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
|
||||
s = self.dumps(x, proto)
|
||||
|
||||
def test_reduce_bad_iterator(self):
|
||||
# Issue4176: crash when 4th and 5th items of __reduce__()
|
||||
|
|
|
@ -1940,28 +1940,30 @@ class SubclassTest:
|
|||
def test_pickle(self):
|
||||
a = self.type2test(b"abcd")
|
||||
a.x = 10
|
||||
a.y = self.type2test(b"efgh")
|
||||
a.z = self.type2test(b"efgh")
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
b = pickle.loads(pickle.dumps(a, proto))
|
||||
self.assertNotEqual(id(a), id(b))
|
||||
self.assertEqual(a, b)
|
||||
self.assertEqual(a.x, b.x)
|
||||
self.assertEqual(a.y, b.y)
|
||||
self.assertEqual(a.z, b.z)
|
||||
self.assertEqual(type(a), type(b))
|
||||
self.assertEqual(type(a.y), type(b.y))
|
||||
self.assertEqual(type(a.z), type(b.z))
|
||||
self.assertFalse(hasattr(b, 'y'))
|
||||
|
||||
def test_copy(self):
|
||||
a = self.type2test(b"abcd")
|
||||
a.x = 10
|
||||
a.y = self.type2test(b"efgh")
|
||||
a.z = self.type2test(b"efgh")
|
||||
for copy_method in (copy.copy, copy.deepcopy):
|
||||
b = copy_method(a)
|
||||
self.assertNotEqual(id(a), id(b))
|
||||
self.assertEqual(a, b)
|
||||
self.assertEqual(a.x, b.x)
|
||||
self.assertEqual(a.y, b.y)
|
||||
self.assertEqual(a.z, b.z)
|
||||
self.assertEqual(type(a), type(b))
|
||||
self.assertEqual(type(a.y), type(b.y))
|
||||
self.assertEqual(type(a.z), type(b.z))
|
||||
self.assertFalse(hasattr(b, 'y'))
|
||||
|
||||
def test_fromhex(self):
|
||||
b = self.type2test.fromhex('1a2B30')
|
||||
|
@ -1994,6 +1996,9 @@ class SubclassTest:
|
|||
class ByteArraySubclass(bytearray):
|
||||
pass
|
||||
|
||||
class ByteArraySubclassWithSlots(bytearray):
|
||||
__slots__ = ('x', 'y', '__dict__')
|
||||
|
||||
class BytesSubclass(bytes):
|
||||
pass
|
||||
|
||||
|
@ -2014,6 +2019,9 @@ class ByteArraySubclassTest(SubclassTest, unittest.TestCase):
|
|||
x = subclass(newarg=4, source=b"abcd")
|
||||
self.assertEqual(x, b"abcd")
|
||||
|
||||
class ByteArraySubclassWithSlotsTest(SubclassTest, unittest.TestCase):
|
||||
basetype = bytearray
|
||||
type2test = ByteArraySubclassWithSlots
|
||||
|
||||
class BytesSubclassTest(SubclassTest, unittest.TestCase):
|
||||
basetype = bytes
|
||||
|
|
|
@ -781,6 +781,9 @@ class TestVariousIteratorArgs(unittest.TestCase):
|
|||
class Deque(deque):
|
||||
pass
|
||||
|
||||
class DequeWithSlots(deque):
|
||||
__slots__ = ('x', 'y', '__dict__')
|
||||
|
||||
class DequeWithBadIter(deque):
|
||||
def __iter__(self):
|
||||
raise TypeError
|
||||
|
@ -810,40 +813,28 @@ class TestSubclass(unittest.TestCase):
|
|||
self.assertEqual(len(d), 0)
|
||||
|
||||
def test_copy_pickle(self):
|
||||
for cls in Deque, DequeWithSlots:
|
||||
for d in cls('abc'), cls('abcde', maxlen=4):
|
||||
d.x = ['x']
|
||||
d.z = ['z']
|
||||
|
||||
d = Deque('abc')
|
||||
e = d.__copy__()
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
|
||||
e = d.__copy__()
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
e = cls(d)
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
|
||||
e = Deque(d)
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
s = pickle.dumps(d, proto)
|
||||
e = pickle.loads(s)
|
||||
self.assertNotEqual(id(d), id(e))
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
|
||||
d = Deque('abcde', maxlen=4)
|
||||
|
||||
e = d.__copy__()
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
|
||||
e = Deque(d)
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
s = pickle.dumps(d, proto)
|
||||
e = pickle.loads(s)
|
||||
self.assertNotEqual(id(d), id(e))
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
s = pickle.dumps(d, proto)
|
||||
e = pickle.loads(s)
|
||||
self.assertNotEqual(id(d), id(e))
|
||||
self.assertEqual(type(d), type(e))
|
||||
self.assertEqual(list(d), list(e))
|
||||
self.assertEqual(e.x, d.x)
|
||||
self.assertEqual(e.z, d.z)
|
||||
self.assertFalse(hasattr(e, 'y'))
|
||||
|
||||
def test_pickle_recursive(self):
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
|
|
|
@ -181,6 +181,7 @@ You can get the information from the list type:
|
|||
'__ge__',
|
||||
'__getattribute__',
|
||||
'__getitem__',
|
||||
'__getstate__',
|
||||
'__gt__',
|
||||
'__hash__',
|
||||
'__iadd__',
|
||||
|
|
|
@ -287,6 +287,8 @@ class OrderedDictTests:
|
|||
# and have a repr/eval round-trip
|
||||
pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
|
||||
od = OrderedDict(pairs)
|
||||
od.x = ['x']
|
||||
od.z = ['z']
|
||||
def check(dup):
|
||||
msg = "\ncopy: %s\nod: %s" % (dup, od)
|
||||
self.assertIsNot(dup, od, msg)
|
||||
|
@ -295,13 +297,27 @@ class OrderedDictTests:
|
|||
self.assertEqual(len(dup), len(od))
|
||||
self.assertEqual(type(dup), type(od))
|
||||
check(od.copy())
|
||||
check(copy.copy(od))
|
||||
check(copy.deepcopy(od))
|
||||
dup = copy.copy(od)
|
||||
check(dup)
|
||||
self.assertIs(dup.x, od.x)
|
||||
self.assertIs(dup.z, od.z)
|
||||
self.assertFalse(hasattr(dup, 'y'))
|
||||
dup = copy.deepcopy(od)
|
||||
check(dup)
|
||||
self.assertEqual(dup.x, od.x)
|
||||
self.assertIsNot(dup.x, od.x)
|
||||
self.assertEqual(dup.z, od.z)
|
||||
self.assertIsNot(dup.z, od.z)
|
||||
self.assertFalse(hasattr(dup, 'y'))
|
||||
# pickle directly pulls the module, so we have to fake it
|
||||
with replaced_module('collections', self.module):
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
with self.subTest(proto=proto):
|
||||
check(pickle.loads(pickle.dumps(od, proto)))
|
||||
dup = pickle.loads(pickle.dumps(od, proto))
|
||||
check(dup)
|
||||
self.assertEqual(dup.x, od.x)
|
||||
self.assertEqual(dup.z, od.z)
|
||||
self.assertFalse(hasattr(dup, 'y'))
|
||||
check(eval(repr(od)))
|
||||
update_test = OrderedDict()
|
||||
update_test.update(od)
|
||||
|
@ -846,6 +862,23 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
|
|||
pass
|
||||
|
||||
|
||||
class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
|
||||
|
||||
module = py_coll
|
||||
class OrderedDict(py_coll.OrderedDict):
|
||||
__slots__ = ('x', 'y')
|
||||
test_copying = OrderedDictTests.test_copying
|
||||
|
||||
|
||||
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
|
||||
class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
|
||||
|
||||
module = c_coll
|
||||
class OrderedDict(c_coll.OrderedDict):
|
||||
__slots__ = ('x', 'y')
|
||||
test_copying = OrderedDictTests.test_copying
|
||||
|
||||
|
||||
class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -227,14 +227,17 @@ class TestJointOps:
|
|||
|
||||
def test_pickling(self):
|
||||
for i in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
if type(self.s) not in (set, frozenset):
|
||||
self.s.x = ['x']
|
||||
self.s.z = ['z']
|
||||
p = pickle.dumps(self.s, i)
|
||||
dup = pickle.loads(p)
|
||||
self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
|
||||
if type(self.s) not in (set, frozenset):
|
||||
self.s.x = 10
|
||||
p = pickle.dumps(self.s, i)
|
||||
dup = pickle.loads(p)
|
||||
self.assertEqual(self.s.x, dup.x)
|
||||
self.assertEqual(self.s.z, dup.z)
|
||||
self.assertFalse(hasattr(self.s, 'y'))
|
||||
del self.s.x, self.s.z
|
||||
|
||||
def test_iterator_pickling(self):
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
|
@ -808,6 +811,21 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
|||
# All empty frozenset subclass instances should have different ids
|
||||
self.assertEqual(len(set(map(id, efs))), len(efs))
|
||||
|
||||
|
||||
class SetSubclassWithSlots(set):
|
||||
__slots__ = ('x', 'y', '__dict__')
|
||||
|
||||
class TestSetSubclassWithSlots(unittest.TestCase):
|
||||
thetype = SetSubclassWithSlots
|
||||
setUp = TestJointOps.setUp
|
||||
test_pickling = TestJointOps.test_pickling
|
||||
|
||||
class FrozenSetSubclassWithSlots(frozenset):
|
||||
__slots__ = ('x', 'y', '__dict__')
|
||||
|
||||
class TestFrozenSetSubclassWithSlots(TestSetSubclassWithSlots):
|
||||
thetype = FrozenSetSubclassWithSlots
|
||||
|
||||
# Tests taken from test_sets.py =============================================
|
||||
|
||||
empty_set = set()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import unittest
|
||||
from weakref import WeakSet
|
||||
import copy
|
||||
import string
|
||||
from collections import UserString as ustr
|
||||
from collections.abc import Set, MutableSet
|
||||
|
@ -15,6 +16,12 @@ class RefCycle:
|
|||
def __init__(self):
|
||||
self.cycle = self
|
||||
|
||||
class WeakSetSubclass(WeakSet):
|
||||
pass
|
||||
|
||||
class WeakSetWithSlots(WeakSet):
|
||||
__slots__ = ('x', 'y')
|
||||
|
||||
|
||||
class TestWeakSet(unittest.TestCase):
|
||||
|
||||
|
@ -447,6 +454,30 @@ class TestWeakSet(unittest.TestCase):
|
|||
self.assertIsInstance(self.s, Set)
|
||||
self.assertIsInstance(self.s, MutableSet)
|
||||
|
||||
def test_copying(self):
|
||||
for cls in WeakSet, WeakSetWithSlots:
|
||||
s = cls(self.items)
|
||||
s.x = ['x']
|
||||
s.z = ['z']
|
||||
|
||||
dup = copy.copy(s)
|
||||
self.assertIsInstance(dup, cls)
|
||||
self.assertEqual(dup, s)
|
||||
self.assertIsNot(dup, s)
|
||||
self.assertIs(dup.x, s.x)
|
||||
self.assertIs(dup.z, s.z)
|
||||
self.assertFalse(hasattr(dup, 'y'))
|
||||
|
||||
dup = copy.deepcopy(s)
|
||||
self.assertIsInstance(dup, cls)
|
||||
self.assertEqual(dup, s)
|
||||
self.assertIsNot(dup, s)
|
||||
self.assertEqual(dup.x, s.x)
|
||||
self.assertIsNot(dup.x, s.x)
|
||||
self.assertEqual(dup.z, s.z)
|
||||
self.assertIsNot(dup.z, s.z)
|
||||
self.assertFalse(hasattr(dup, 'y'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -2524,8 +2524,7 @@ class BasicElementTest(ElementTestCase, unittest.TestCase):
|
|||
<group><dogs>4</dogs>
|
||||
</group>"""
|
||||
e1 = dumper.fromstring(XMLTEXT)
|
||||
if hasattr(e1, '__getstate__'):
|
||||
self.assertEqual(e1.__getstate__()['tag'], 'group')
|
||||
self.assertEqual(e1.__getstate__()['tag'], 'group')
|
||||
e2 = self.pickleRoundTrip(e1, 'xml.etree.ElementTree',
|
||||
dumper, loader, proto)
|
||||
self.assertEqual(e2.tag, 'group')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue