bpo-35900: Enable custom reduction callback registration in _pickle (GH-12499)

Enable custom reduction callback registration for functions and classes in
_pickle.c, using the new Pickler's attribute ``reducer_override``.
This commit is contained in:
Pierre Glaser 2019-05-08 23:08:25 +02:00 committed by Antoine Pitrou
parent 9a4135e939
commit 289f1f80ee
6 changed files with 227 additions and 24 deletions

View file

@ -497,34 +497,42 @@ class _Pickler:
self.write(self.get(x[0]))
return
# Check the type dispatch table
t = type(obj)
f = self.dispatch.get(t)
if f is not None:
f(self, obj) # Call unbound method with explicit self
return
# Check private dispatch table if any, or else copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
rv = NotImplemented
reduce = getattr(self, "reducer_override", None)
if reduce is not None:
rv = reduce(obj)
else:
# Check for a class with a custom metaclass; treat as regular class
if issubclass(t, type):
self.save_global(obj)
if rv is NotImplemented:
# Check the type dispatch table
t = type(obj)
f = self.dispatch.get(t)
if f is not None:
f(self, obj) # Call unbound method with explicit self
return
# Check for a __reduce_ex__ method, fall back to __reduce__
reduce = getattr(obj, "__reduce_ex__", None)
# Check private dispatch table if any, or else
# copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
if reduce is not None:
rv = reduce(self.proto)
rv = reduce(obj)
else:
reduce = getattr(obj, "__reduce__", None)
# Check for a class with a custom metaclass; treat as regular
# class
if issubclass(t, type):
self.save_global(obj)
return
# Check for a __reduce_ex__ method, fall back to __reduce__
reduce = getattr(obj, "__reduce_ex__", None)
if reduce is not None:
rv = reduce()
rv = reduce(self.proto)
else:
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
reduce = getattr(obj, "__reduce__", None)
if reduce is not None:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
# Check for string returned by reduce(), meaning "save as global"
if isinstance(rv, str):

View file

@ -4,6 +4,7 @@ import dbm
import io
import functools
import os
import math
import pickle
import pickletools
import shutil
@ -3013,6 +3014,73 @@ def setstate_bbb(obj, state):
obj.a = "custom state_setter"
class AbstractCustomPicklerClass:
"""Pickler implementing a reducing hook using reducer_override."""
def reducer_override(self, obj):
obj_name = getattr(obj, "__name__", None)
if obj_name == 'f':
# asking the pickler to save f as 5
return int, (5, )
if obj_name == 'MyClass':
return str, ('some str',)
elif obj_name == 'g':
# in this case, the callback returns an invalid result (not a 2-5
# tuple or a string), the pickler should raise a proper error.
return False
elif obj_name == 'h':
# Simulate a case when the reducer fails. The error should
# be propagated to the original ``dump`` call.
raise ValueError('The reducer just failed')
return NotImplemented
class AbstractHookTests(unittest.TestCase):
def test_pickler_hook(self):
# test the ability of a custom, user-defined CPickler subclass to
# override the default reducing routines of any type using the method
# reducer_override
def f():
pass
def g():
pass
def h():
pass
class MyClass:
pass
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
bio = io.BytesIO()
p = self.pickler_class(bio, proto)
p.dump([f, MyClass, math.log])
new_f, some_str, math_log = pickle.loads(bio.getvalue())
self.assertEqual(new_f, 5)
self.assertEqual(some_str, 'some str')
# math.log does not have its usual reducer overriden, so the
# custom reduction callback should silently direct the pickler
# to the default pickling by attribute, by returning
# NotImplemented
self.assertIs(math_log, math.log)
with self.assertRaises(pickle.PicklingError):
p.dump(g)
with self.assertRaisesRegex(
ValueError, 'The reducer just failed'):
p.dump(h)
class AbstractDispatchTableTests(unittest.TestCase):
def test_default_dispatch_table(self):

View file

@ -11,6 +11,7 @@ import weakref
import unittest
from test import support
from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
@ -18,6 +19,7 @@ from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests
from test.pickletester import AbstractCustomPicklerClass
from test.pickletester import BigmemPickleTests
try:
@ -253,12 +255,23 @@ if has_c_implementation:
def get_dispatch_table(self):
return collections.ChainMap({}, pickle.dispatch_table)
class PyPicklerHookTests(AbstractHookTests):
class CustomPyPicklerClass(pickle._Pickler,
AbstractCustomPicklerClass):
pass
pickler_class = CustomPyPicklerClass
class CPicklerHookTests(AbstractHookTests):
class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
pass
pickler_class = CustomCPicklerClass
@support.cpython_only
class SizeofTests(unittest.TestCase):
check_sizeof = support.check_sizeof
def test_pickler(self):
basesize = support.calcobjsize('6P2n3i2n3iP')
basesize = support.calcobjsize('6P2n3i2n3i2P')
p = _pickle.Pickler(io.BytesIO())
self.assertEqual(object.__sizeof__(p), basesize)
MT_size = struct.calcsize('3nP0n')
@ -498,7 +511,7 @@ def test_main():
tests = [PyPickleTests, PyUnpicklerTests, PyPicklerTests,
PyPersPicklerTests, PyIdPersPicklerTests,
PyDispatchTableTests, PyChainDispatchTableTests,
CompatPickleTests]
CompatPickleTests, PyPicklerHookTests]
if has_c_implementation:
tests.extend([CPickleTests, CUnpicklerTests, CPicklerTests,
CPersPicklerTests, CIdPersPicklerTests,
@ -506,6 +519,7 @@ def test_main():
PyPicklerUnpicklerObjectTests,
CPicklerUnpicklerObjectTests,
CDispatchTableTests, CChainDispatchTableTests,
CPicklerHookTests,
InMemoryPickleTests, SizeofTests])
support.run_unittest(*tests)
support.run_doctest(pickle)