bpo-36785: PEP 574 implementation (GH-7076)

This commit is contained in:
Antoine Pitrou 2019-05-26 17:10:09 +02:00 committed by GitHub
parent 22ccb0b490
commit 91f4380ced
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 1888 additions and 242 deletions

View file

@ -16,6 +16,16 @@ import weakref
from textwrap import dedent
from http.cookies import SimpleCookie
try:
import _testbuffer
except ImportError:
_testbuffer = None
try:
import numpy as np
except ImportError:
np = None
from test import support
from test.support import (
TestFailed, TESTFN, run_with_locale, no_tracing,
@ -162,6 +172,139 @@ def create_dynamic_class(name, bases):
result.reduce_args = (name, bases)
return result
class ZeroCopyBytes(bytes):
readonly = True
c_contiguous = True
f_contiguous = True
zero_copy_reconstruct = True
def __reduce_ex__(self, protocol):
if protocol >= 5:
return type(self)._reconstruct, (pickle.PickleBuffer(self),), None
else:
return type(self)._reconstruct, (bytes(self),)
def __repr__(self):
return "{}({!r})".format(self.__class__.__name__, bytes(self))
__str__ = __repr__
@classmethod
def _reconstruct(cls, obj):
with memoryview(obj) as m:
obj = m.obj
if type(obj) is cls:
# Zero-copy
return obj
else:
return cls(obj)
class ZeroCopyBytearray(bytearray):
readonly = False
c_contiguous = True
f_contiguous = True
zero_copy_reconstruct = True
def __reduce_ex__(self, protocol):
if protocol >= 5:
return type(self)._reconstruct, (pickle.PickleBuffer(self),), None
else:
return type(self)._reconstruct, (bytes(self),)
def __repr__(self):
return "{}({!r})".format(self.__class__.__name__, bytes(self))
__str__ = __repr__
@classmethod
def _reconstruct(cls, obj):
with memoryview(obj) as m:
obj = m.obj
if type(obj) is cls:
# Zero-copy
return obj
else:
return cls(obj)
if _testbuffer is not None:
class PicklableNDArray:
# A not-really-zero-copy picklable ndarray, as the ndarray()
# constructor doesn't allow for it
zero_copy_reconstruct = False
def __init__(self, *args, **kwargs):
self.array = _testbuffer.ndarray(*args, **kwargs)
def __getitem__(self, idx):
cls = type(self)
new = cls.__new__(cls)
new.array = self.array[idx]
return new
@property
def readonly(self):
return self.array.readonly
@property
def c_contiguous(self):
return self.array.c_contiguous
@property
def f_contiguous(self):
return self.array.f_contiguous
def __eq__(self, other):
if not isinstance(other, PicklableNDArray):
return NotImplemented
return (other.array.format == self.array.format and
other.array.shape == self.array.shape and
other.array.strides == self.array.strides and
other.array.readonly == self.array.readonly and
other.array.tobytes() == self.array.tobytes())
def __ne__(self, other):
if not isinstance(other, PicklableNDArray):
return NotImplemented
return not (self == other)
def __repr__(self):
return (f"{type(self)}(shape={self.array.shape},"
f"strides={self.array.strides}, "
f"bytes={self.array.tobytes()})")
def __reduce_ex__(self, protocol):
if not self.array.contiguous:
raise NotImplementedError("Reconstructing a non-contiguous "
"ndarray does not seem possible")
ndarray_kwargs = {"shape": self.array.shape,
"strides": self.array.strides,
"format": self.array.format,
"flags": (0 if self.readonly
else _testbuffer.ND_WRITABLE)}
pb = pickle.PickleBuffer(self.array)
if protocol >= 5:
return (type(self)._reconstruct,
(pb, ndarray_kwargs))
else:
# Need to serialize the bytes in physical order
with pb.raw() as m:
return (type(self)._reconstruct,
(m.tobytes(), ndarray_kwargs))
@classmethod
def _reconstruct(cls, obj, kwargs):
with memoryview(obj) as m:
# For some reason, ndarray() wants a list of integers...
# XXX This only works if format == 'B'
items = list(m.tobytes())
return cls(items, **kwargs)
# DATA0 .. DATA4 are the pickles we expect under the various protocols, for
# the object returned by create_data().
@ -888,12 +1031,22 @@ class AbstractUnpickleTests(unittest.TestCase):
dumped = b'\x80\x04\x8d\4\0\0\0\0\0\0\0\xe2\x82\xac\x00.'
self.assertEqual(self.loads(dumped), '\u20ac\x00')
def test_bytearray8(self):
dumped = b'\x80\x05\x96\x03\x00\x00\x00\x00\x00\x00\x00xxx.'
self.assertEqual(self.loads(dumped), bytearray(b'xxx'))
@requires_32b
def test_large_32b_binbytes8(self):
dumped = b'\x80\x04\x8e\4\0\0\0\1\0\0\0\xe2\x82\xac\x00.'
self.check_unpickling_error((pickle.UnpicklingError, OverflowError),
dumped)
@requires_32b
def test_large_32b_bytearray8(self):
dumped = b'\x80\x05\x96\4\0\0\0\1\0\0\0\xe2\x82\xac\x00.'
self.check_unpickling_error((pickle.UnpicklingError, OverflowError),
dumped)
@requires_32b
def test_large_32b_binunicode8(self):
dumped = b'\x80\x04\x8d\4\0\0\0\1\0\0\0\xe2\x82\xac\x00.'
@ -1171,6 +1324,10 @@ class AbstractUnpickleTests(unittest.TestCase):
b'\x8e\x03\x00\x00\x00\x00\x00\x00',
b'\x8e\x03\x00\x00\x00\x00\x00\x00\x00',
b'\x8e\x03\x00\x00\x00\x00\x00\x00\x00ab',
b'\x96', # BYTEARRAY8
b'\x96\x03\x00\x00\x00\x00\x00\x00',
b'\x96\x03\x00\x00\x00\x00\x00\x00\x00',
b'\x96\x03\x00\x00\x00\x00\x00\x00\x00ab',
b'\x95', # FRAME
b'\x95\x02\x00\x00\x00\x00\x00\x00',
b'\x95\x02\x00\x00\x00\x00\x00\x00\x00',
@ -1482,6 +1639,25 @@ class AbstractPickleTests(unittest.TestCase):
p = self.dumps(s, proto)
self.assert_is_copy(s, self.loads(p))
def test_bytearray(self):
for proto in protocols:
for s in b'', b'xyz', b'xyz'*100:
b = bytearray(s)
p = self.dumps(b, proto)
bb = self.loads(p)
self.assertIsNot(bb, b)
self.assert_is_copy(b, bb)
if proto <= 3:
# bytearray is serialized using a global reference
self.assertIn(b'bytearray', p)
self.assertTrue(opcode_in_pickle(pickle.GLOBAL, p))
elif proto == 4:
self.assertIn(b'bytearray', p)
self.assertTrue(opcode_in_pickle(pickle.STACK_GLOBAL, p))
elif proto == 5:
self.assertNotIn(b'bytearray', p)
self.assertTrue(opcode_in_pickle(pickle.BYTEARRAY8, p))
def test_ints(self):
for proto in protocols:
n = sys.maxsize
@ -2114,7 +2290,8 @@ class AbstractPickleTests(unittest.TestCase):
the following consistency check.
"""
frame_end = frameless_start = None
frameless_opcodes = {'BINBYTES', 'BINUNICODE', 'BINBYTES8', 'BINUNICODE8'}
frameless_opcodes = {'BINBYTES', 'BINUNICODE', 'BINBYTES8',
'BINUNICODE8', 'BYTEARRAY8'}
for op, arg, pos in pickletools.genops(pickled):
if frame_end is not None:
self.assertLessEqual(pos, frame_end)
@ -2225,19 +2402,20 @@ class AbstractPickleTests(unittest.TestCase):
num_frames = 20
# Large byte objects (dict values) intermittent with small objects
# (dict keys)
obj = {i: bytes([i]) * frame_size for i in range(num_frames)}
for bytes_type in (bytes, bytearray):
obj = {i: bytes_type([i]) * frame_size for i in range(num_frames)}
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
pickled = self.dumps(obj, proto)
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
pickled = self.dumps(obj, proto)
frameless_pickle = remove_frames(pickled)
self.assertEqual(count_opcode(pickle.FRAME, frameless_pickle), 0)
self.assertEqual(obj, self.loads(frameless_pickle))
frameless_pickle = remove_frames(pickled)
self.assertEqual(count_opcode(pickle.FRAME, frameless_pickle), 0)
self.assertEqual(obj, self.loads(frameless_pickle))
some_frames_pickle = remove_frames(pickled, lambda i: i % 2)
self.assertLess(count_opcode(pickle.FRAME, some_frames_pickle),
count_opcode(pickle.FRAME, pickled))
self.assertEqual(obj, self.loads(some_frames_pickle))
some_frames_pickle = remove_frames(pickled, lambda i: i % 2)
self.assertLess(count_opcode(pickle.FRAME, some_frames_pickle),
count_opcode(pickle.FRAME, pickled))
self.assertEqual(obj, self.loads(some_frames_pickle))
def test_framed_write_sizes_with_delayed_writer(self):
class ChunkAccumulator:
@ -2452,6 +2630,186 @@ class AbstractPickleTests(unittest.TestCase):
with self.assertRaises((AttributeError, pickle.PicklingError)):
pickletools.dis(self.dumps(f, proto))
#
# PEP 574 tests below
#
def buffer_like_objects(self):
# Yield buffer-like objects with the bytestring "abcdef" in them
bytestring = b"abcdefgh"
yield ZeroCopyBytes(bytestring)
yield ZeroCopyBytearray(bytestring)
if _testbuffer is not None:
items = list(bytestring)
value = int.from_bytes(bytestring, byteorder='little')
for flags in (0, _testbuffer.ND_WRITABLE):
# 1-D, contiguous
yield PicklableNDArray(items, format='B', shape=(8,),
flags=flags)
# 2-D, C-contiguous
yield PicklableNDArray(items, format='B', shape=(4, 2),
strides=(2, 1), flags=flags)
# 2-D, Fortran-contiguous
yield PicklableNDArray(items, format='B',
shape=(4, 2), strides=(1, 4),
flags=flags)
def test_in_band_buffers(self):
# Test in-band buffers (PEP 574)
for obj in self.buffer_like_objects():
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
data = self.dumps(obj, proto)
if obj.c_contiguous and proto >= 5:
# The raw memory bytes are serialized in physical order
self.assertIn(b"abcdefgh", data)
self.assertEqual(count_opcode(pickle.NEXT_BUFFER, data), 0)
if proto >= 5:
self.assertEqual(count_opcode(pickle.SHORT_BINBYTES, data),
1 if obj.readonly else 0)
self.assertEqual(count_opcode(pickle.BYTEARRAY8, data),
0 if obj.readonly else 1)
# Return a true value from buffer_callback should have
# the same effect
def buffer_callback(obj):
return True
data2 = self.dumps(obj, proto,
buffer_callback=buffer_callback)
self.assertEqual(data2, data)
new = self.loads(data)
# It's a copy
self.assertIsNot(new, obj)
self.assertIs(type(new), type(obj))
self.assertEqual(new, obj)
# XXX Unfortunately cannot test non-contiguous array
# (see comment in PicklableNDArray.__reduce_ex__)
def test_oob_buffers(self):
# Test out-of-band buffers (PEP 574)
for obj in self.buffer_like_objects():
for proto in range(0, 5):
# Need protocol >= 5 for buffer_callback
with self.assertRaises(ValueError):
self.dumps(obj, proto,
buffer_callback=[].append)
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
buffers = []
buffer_callback = lambda pb: buffers.append(pb.raw())
data = self.dumps(obj, proto,
buffer_callback=buffer_callback)
self.assertNotIn(b"abcdefgh", data)
self.assertEqual(count_opcode(pickle.SHORT_BINBYTES, data), 0)
self.assertEqual(count_opcode(pickle.BYTEARRAY8, data), 0)
self.assertEqual(count_opcode(pickle.NEXT_BUFFER, data), 1)
self.assertEqual(count_opcode(pickle.READONLY_BUFFER, data),
1 if obj.readonly else 0)
if obj.c_contiguous:
self.assertEqual(bytes(buffers[0]), b"abcdefgh")
# Need buffers argument to unpickle properly
with self.assertRaises(pickle.UnpicklingError):
self.loads(data)
new = self.loads(data, buffers=buffers)
if obj.zero_copy_reconstruct:
# Zero-copy achieved
self.assertIs(new, obj)
else:
self.assertIs(type(new), type(obj))
self.assertEqual(new, obj)
# Non-sequence buffers accepted too
new = self.loads(data, buffers=iter(buffers))
if obj.zero_copy_reconstruct:
# Zero-copy achieved
self.assertIs(new, obj)
else:
self.assertIs(type(new), type(obj))
self.assertEqual(new, obj)
def test_oob_buffers_writable_to_readonly(self):
# Test reconstructing readonly object from writable buffer
obj = ZeroCopyBytes(b"foobar")
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
buffers = []
buffer_callback = buffers.append
data = self.dumps(obj, proto, buffer_callback=buffer_callback)
buffers = map(bytearray, buffers)
new = self.loads(data, buffers=buffers)
self.assertIs(type(new), type(obj))
self.assertEqual(new, obj)
def test_picklebuffer_error(self):
# PickleBuffer forbidden with protocol < 5
pb = pickle.PickleBuffer(b"foobar")
for proto in range(0, 5):
with self.assertRaises(pickle.PickleError):
self.dumps(pb, proto)
def test_buffer_callback_error(self):
def buffer_callback(buffers):
1/0
pb = pickle.PickleBuffer(b"foobar")
with self.assertRaises(ZeroDivisionError):
self.dumps(pb, 5, buffer_callback=buffer_callback)
def test_buffers_error(self):
pb = pickle.PickleBuffer(b"foobar")
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
data = self.dumps(pb, proto, buffer_callback=[].append)
# Non iterable buffers
with self.assertRaises(TypeError):
self.loads(data, buffers=object())
# Buffer iterable exhausts too early
with self.assertRaises(pickle.UnpicklingError):
self.loads(data, buffers=[])
@unittest.skipIf(np is None, "Test needs Numpy")
def test_buffers_numpy(self):
def check_no_copy(x, y):
np.testing.assert_equal(x, y)
self.assertEqual(x.ctypes.data, y.ctypes.data)
def check_copy(x, y):
np.testing.assert_equal(x, y)
self.assertNotEqual(x.ctypes.data, y.ctypes.data)
def check_array(arr):
# In-band
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
data = self.dumps(arr, proto)
new = self.loads(data)
check_copy(arr, new)
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
buffer_callback = lambda _: True
data = self.dumps(arr, proto, buffer_callback=buffer_callback)
new = self.loads(data)
check_copy(arr, new)
# Out-of-band
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
buffers = []
buffer_callback = buffers.append
data = self.dumps(arr, proto, buffer_callback=buffer_callback)
new = self.loads(data, buffers=buffers)
if arr.flags.c_contiguous or arr.flags.f_contiguous:
check_no_copy(arr, new)
else:
check_copy(arr, new)
# 1-D
arr = np.arange(6)
check_array(arr)
# 1-D, non-contiguous
check_array(arr[::2])
# 2-D, C-contiguous
arr = np.arange(12).reshape((3, 4))
check_array(arr)
# 2-D, F-contiguous
check_array(arr.T)
# 2-D, non-contiguous
check_array(arr[::2])
class BigmemPickleTests(unittest.TestCase):
@ -2736,7 +3094,7 @@ class AbstractPickleModuleTests(unittest.TestCase):
def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes.
self.assertEqual(pickle.HIGHEST_PROTOCOL, 4)
self.assertEqual(pickle.HIGHEST_PROTOCOL, 5)
def test_callapi(self):
f = io.BytesIO()
@ -2760,6 +3118,47 @@ class AbstractPickleModuleTests(unittest.TestCase):
self.assertRaises(pickle.PicklingError, BadPickler().dump, 0)
self.assertRaises(pickle.UnpicklingError, BadUnpickler().load)
def check_dumps_loads_oob_buffers(self, dumps, loads):
# No need to do the full gamut of tests here, just enough to
# check that dumps() and loads() redirect their arguments
# to the underlying Pickler and Unpickler, respectively.
obj = ZeroCopyBytes(b"foo")
for proto in range(0, 5):
# Need protocol >= 5 for buffer_callback
with self.assertRaises(ValueError):
dumps(obj, protocol=proto,
buffer_callback=[].append)
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
buffers = []
buffer_callback = buffers.append
data = dumps(obj, protocol=proto,
buffer_callback=buffer_callback)
self.assertNotIn(b"foo", data)
self.assertEqual(bytes(buffers[0]), b"foo")
# Need buffers argument to unpickle properly
with self.assertRaises(pickle.UnpicklingError):
loads(data)
new = loads(data, buffers=buffers)
self.assertIs(new, obj)
def test_dumps_loads_oob_buffers(self):
# Test out-of-band buffers (PEP 574) with top-level dumps() and loads()
self.check_dumps_loads_oob_buffers(self.dumps, self.loads)
def test_dump_load_oob_buffers(self):
# Test out-of-band buffers (PEP 574) with top-level dump() and load()
def dumps(obj, **kwargs):
f = io.BytesIO()
self.dump(obj, f, **kwargs)
return f.getvalue()
def loads(data, **kwargs):
f = io.BytesIO(data)
return self.load(f, **kwargs)
self.check_dumps_loads_oob_buffers(dumps, loads)
class AbstractPersistentPicklerTests(unittest.TestCase):