Issue #6137: The pickle module now translates module names when loading
or dumping pickles with a 2.x-compatible protocol, in order to make data
sharing and migration easier. This behaviour can be disabled using the
new `fix_imports` optional argument.
This commit is contained in:
Antoine Pitrou 2009-06-04 20:32:06 +00:00
parent 751899a59f
commit d9dfaa9487
8 changed files with 532 additions and 157 deletions

View file

@ -34,6 +34,7 @@ import struct
import re
import io
import codecs
import _compat_pickle
__all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler",
"Unpickler", "dump", "dumps", "load", "loads"]
@ -171,12 +172,11 @@ SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes
__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)])
# Pickling machinery
class _Pickler:
def __init__(self, file, protocol=None):
def __init__(self, file, protocol=None, *, fix_imports=True):
"""This takes a binary file for writing a pickle data stream.
The optional protocol argument tells the pickler to use the
@ -193,6 +193,10 @@ class _Pickler:
bytes argument. It can thus be a file object opened for binary
writing, an io.BytesIO instance, or any other custom object that
meets this interface.
If fix_imports is True and protocol is less than 3, pickle will try to
map the new Python 3.x names to the old module names used in Python
2.x, so that the pickle data stream is readable with Python 2.x.
"""
if protocol is None:
protocol = DEFAULT_PROTOCOL
@ -208,6 +212,7 @@ class _Pickler:
self.proto = int(protocol)
self.bin = protocol >= 1
self.fast = 0
self.fix_imports = fix_imports and protocol < 3
def clear_memo(self):
"""Clears the pickler's "memo".
@ -698,6 +703,11 @@ class _Pickler:
write(GLOBAL + bytes(module, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else:
if self.fix_imports:
if (module, name) in _compat_pickle.REVERSE_NAME_MAPPING:
module, name = _compat_pickle.REVERSE_NAME_MAPPING[(module, name)]
if module in _compat_pickle.REVERSE_IMPORT_MAPPING:
module = _compat_pickle.REVERSE_IMPORT_MAPPING[module]
try:
write(GLOBAL + bytes(module, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
@ -766,7 +776,8 @@ def whichmodule(func, funcname):
class _Unpickler:
def __init__(self, file, *, encoding="ASCII", errors="strict"):
def __init__(self, file, *, fix_imports=True,
encoding="ASCII", errors="strict"):
"""This takes a binary file for reading a pickle data stream.
The protocol version of the pickle is detected automatically, so no
@ -779,15 +790,21 @@ class _Unpickler:
reading, a BytesIO object, or any other custom object that
meets this interface.
Optional keyword arguments are encoding and errors, which are
used to decode 8-bit string instances pickled by Python 2.x.
These default to 'ASCII' and 'strict', respectively.
Optional keyword arguments are *fix_imports*, *encoding* and *errors*,
which are used to control compatibility support for pickle streams
generated by Python 2.x. If *fix_imports* is True, pickle will try to
map the old Python 2.x names to the new names used in Python 3.x. The
*encoding* and *errors* tell pickle how to decode 8-bit string
instances pickled by Python 2.x; these default to 'ASCII' and
'strict', respectively.
"""
self.readline = file.readline
self.read = file.read
self.memo = {}
self.encoding = encoding
self.errors = errors
self.proto = 0
self.fix_imports = fix_imports
def load(self):
"""Read a pickled object representation from the open file.
@ -838,6 +855,7 @@ class _Unpickler:
proto = ord(self.read(1))
if not 0 <= proto <= HIGHEST_PROTOCOL:
raise ValueError("unsupported pickle protocol: %d" % proto)
self.proto = proto
dispatch[PROTO[0]] = load_proto
def load_persid(self):
@ -1088,7 +1106,12 @@ class _Unpickler:
self.append(obj)
def find_class(self, module, name):
# Subclasses may override this
# Subclasses may override this.
if self.proto < 3 and self.fix_imports:
if (module, name) in _compat_pickle.NAME_MAPPING:
module, name = _compat_pickle.NAME_MAPPING[(module, name)]
if module in _compat_pickle.IMPORT_MAPPING:
module = _compat_pickle.IMPORT_MAPPING[module]
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
@ -1327,27 +1350,28 @@ except ImportError:
# Shorthands
def dump(obj, file, protocol=None):
Pickler(file, protocol).dump(obj)
def dump(obj, file, protocol=None, *, fix_imports=True):
Pickler(file, protocol, fix_imports=fix_imports).dump(obj)
def dumps(obj, protocol=None):
def dumps(obj, protocol=None, *, fix_imports=True):
f = io.BytesIO()
Pickler(f, protocol).dump(obj)
Pickler(f, protocol, fix_imports=fix_imports).dump(obj)
res = f.getvalue()
assert isinstance(res, bytes_types)
return res
def load(file, *, encoding="ASCII", errors="strict"):
return Unpickler(file, encoding=encoding, errors=errors).load()
def load(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
return Unpickler(file, fix_imports=fix_imports,
encoding=encoding, errors=errors).load()
def loads(s, *, encoding="ASCII", errors="strict"):
def loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
return Unpickler(file, encoding=encoding, errors=errors).load()
return Unpickler(file, fix_imports=fix_imports,
encoding=encoding, errors=errors).load()
# Doctest
def _test():
import doctest
return doctest.testmod()