Issue #17810: Implement PEP 3154, pickle protocol 4.

Most of the work is by Alexandre.
This commit is contained in:
Antoine Pitrou 2013-11-23 18:59:12 +01:00
parent 95401c5f6b
commit c9dc4a2a8a
12 changed files with 3132 additions and 1006 deletions

View file

@ -11,6 +11,7 @@ dis(pickle, out=None, memo=None, indentlevel=4)
'''
import codecs
import io
import pickle
import re
import sys
@ -168,6 +169,7 @@ UP_TO_NEWLINE = -1
TAKEN_FROM_ARGUMENT1 = -2 # num bytes is 1-byte unsigned int
TAKEN_FROM_ARGUMENT4 = -3 # num bytes is 4-byte signed little-endian int
TAKEN_FROM_ARGUMENT4U = -4 # num bytes is 4-byte unsigned little-endian int
TAKEN_FROM_ARGUMENT8U = -5 # num bytes is 8-byte unsigned little-endian int
class ArgumentDescriptor(object):
__slots__ = (
@ -175,7 +177,7 @@ class ArgumentDescriptor(object):
'name',
# length of argument, in bytes; an int; UP_TO_NEWLINE and
# TAKEN_FROM_ARGUMENT{1,4} are negative values for variable-length
# TAKEN_FROM_ARGUMENT{1,4,8} are negative values for variable-length
# cases
'n',
@ -196,7 +198,8 @@ class ArgumentDescriptor(object):
n in (UP_TO_NEWLINE,
TAKEN_FROM_ARGUMENT1,
TAKEN_FROM_ARGUMENT4,
TAKEN_FROM_ARGUMENT4U))
TAKEN_FROM_ARGUMENT4U,
TAKEN_FROM_ARGUMENT8U))
self.n = n
self.reader = reader
@ -288,6 +291,27 @@ uint4 = ArgumentDescriptor(
doc="Four-byte unsigned integer, little-endian.")
def read_uint8(f):
    r"""Read an eight-byte little-endian unsigned integer from stream f.

    Raises ValueError when fewer than eight bytes can be read.

    >>> import io
    >>> read_uint8(io.BytesIO(b'\xff\x00\x00\x00\x00\x00\x00\x00'))
    255
    >>> read_uint8(io.BytesIO(b'\xff' * 8)) == 2**64-1
    True
    """
    raw = f.read(8)
    if len(raw) != 8:
        # A short read means the stream ended inside the integer.
        raise ValueError("not enough data in stream to read uint8")
    (value,) = _unpack("<Q", raw)
    return value
# Descriptor for a fixed-width opcode argument: exactly eight bytes (n=8),
# decoded by read_uint8.
uint8 = ArgumentDescriptor(
    name='uint8',
    n=8,
    reader=read_uint8,
    doc="Eight-byte unsigned integer, little-endian.")
def read_stringnl(f, decode=True, stripquotes=True):
r"""
>>> import io
@ -381,6 +405,36 @@ stringnl_noescape_pair = ArgumentDescriptor(
a single blank separating the two strings.
""")
def read_string1(f):
    r"""Read a counted string whose length is a 1-byte unsigned prefix.

    The payload bytes are decoded as latin-1 and returned as a str.
    Raises ValueError when the stream holds fewer bytes than announced.

    >>> import io
    >>> read_string1(io.BytesIO(b"\x00"))
    ''
    >>> read_string1(io.BytesIO(b"\x03abcdef"))
    'abc'
    """
    count = read_uint1(f)
    assert count >= 0
    payload = f.read(count)
    if len(payload) != count:
        raise ValueError("expected %d bytes in a string1, but only %d remain" %
                         (count, len(payload)))
    return payload.decode("latin-1")
# Descriptor for a length-prefixed argument: the size is taken from a leading
# 1-byte count (TAKEN_FROM_ARGUMENT1) and read_string1 does the decoding.
string1 = ArgumentDescriptor(
    name="string1",
    n=TAKEN_FROM_ARGUMENT1,
    reader=read_string1,
    doc="""A counted string.
The first argument is a 1-byte unsigned int giving the number
of bytes in the string, and the second argument is that many
bytes.
""")
def read_string4(f):
r"""
>>> import io
@ -415,28 +469,28 @@ string4 = ArgumentDescriptor(
""")
def read_string1(f):
def read_bytes1(f):
r"""
>>> import io
>>> read_string1(io.BytesIO(b"\x00"))
''
>>> read_string1(io.BytesIO(b"\x03abcdef"))
'abc'
>>> read_bytes1(io.BytesIO(b"\x00"))
b''
>>> read_bytes1(io.BytesIO(b"\x03abcdef"))
b'abc'
"""
n = read_uint1(f)
assert n >= 0
data = f.read(n)
if len(data) == n:
return data.decode("latin-1")
raise ValueError("expected %d bytes in a string1, but only %d remain" %
return data
raise ValueError("expected %d bytes in a bytes1, but only %d remain" %
(n, len(data)))
string1 = ArgumentDescriptor(
name="string1",
bytes1 = ArgumentDescriptor(
name="bytes1",
n=TAKEN_FROM_ARGUMENT1,
reader=read_string1,
doc="""A counted string.
reader=read_bytes1,
doc="""A counted bytes string.
The first argument is a 1-byte unsigned int giving the number
of bytes in the string, and the second argument is that many
@ -486,6 +540,7 @@ def read_bytes4(f):
"""
n = read_uint4(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("bytes4 byte count > sys.maxsize: %d" % n)
data = f.read(n)
@ -505,6 +560,39 @@ bytes4 = ArgumentDescriptor(
""")
def read_bytes8(f):
    r"""Read a bytes object preceded by an 8-byte little-endian length.

    Returns the payload as bytes.  Raises ValueError when the announced
    length exceeds sys.maxsize or when the stream holds fewer bytes than
    announced.

    >>> import io, struct, sys
    >>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x00\x00abc"))
    b''
    >>> read_bytes8(io.BytesIO(b"\x03\x00\x00\x00\x00\x00\x00\x00abcdef"))
    b'abc'
    >>> bigsize8 = struct.pack("<Q", sys.maxsize//3)
    >>> read_bytes8(io.BytesIO(bigsize8 + b"abcdef"))  #doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    ValueError: expected ... bytes in a bytes8, but only 6 remain
    """
    # The oversized-length example above is derived from sys.maxsize so that
    # it stays below the limit on 32-bit as well as 64-bit builds; a
    # hard-coded length would trip the sys.maxsize check on 32-bit platforms.
    n = read_uint8(f)
    if n > sys.maxsize:
        # f.read() cannot honour a count larger than the platform's ssize_t.
        raise ValueError("bytes8 byte count > sys.maxsize: %d" % n)
    data = f.read(n)
    if len(data) != n:
        raise ValueError("expected %d bytes in a bytes8, but only %d remain" %
                         (n, len(data)))
    return data
# Descriptor for a length-prefixed bytes argument: the size is taken from a
# leading 8-byte unsigned count (TAKEN_FROM_ARGUMENT8U) and read_bytes8 does
# the reading.
bytes8 = ArgumentDescriptor(
    name="bytes8",
    n=TAKEN_FROM_ARGUMENT8U,
    reader=read_bytes8,
    doc="""A counted bytes string.
The first argument is a 8-byte little-endian unsigned int giving
the number of bytes, and the second argument is that many bytes.
""")
def read_unicodestringnl(f):
r"""
>>> import io
@ -530,6 +618,46 @@ unicodestringnl = ArgumentDescriptor(
escape sequences.
""")
def read_unicodestring1(f):
    r"""Read a UTF-8 encoded string preceded by a 1-byte unsigned length.

    The payload is decoded with the 'surrogatepass' error handler.
    Raises ValueError when the stream holds fewer bytes than announced.

    >>> import io
    >>> s = 'abcd\uabcd'
    >>> enc = s.encode('utf-8')
    >>> enc
    b'abcd\xea\xaf\x8d'
    >>> n = bytes([len(enc)])  # little-endian 1-byte length
    >>> t = read_unicodestring1(io.BytesIO(n + enc + b'junk'))
    >>> s == t
    True

    >>> read_unicodestring1(io.BytesIO(n + enc[:-1]))
    Traceback (most recent call last):
    ...
    ValueError: expected 7 bytes in a unicodestring1, but only 6 remain
    """
    size = read_uint1(f)
    assert size >= 0
    encoded = f.read(size)
    if len(encoded) != size:
        raise ValueError("expected %d bytes in a unicodestring1, but only %d "
                         "remain" % (size, len(encoded)))
    return str(encoded, 'utf-8', 'surrogatepass')
# Descriptor for a length-prefixed UTF-8 argument.  The length is read with
# read_uint1 (TAKEN_FROM_ARGUMENT1), i.e. it is an *unsigned* byte, so the
# doc text says "unsigned" rather than "signed".
unicodestring1 = ArgumentDescriptor(
    name="unicodestring1",
    n=TAKEN_FROM_ARGUMENT1,
    reader=read_unicodestring1,
    doc="""A counted Unicode string.
The first argument is a 1-byte little-endian unsigned int
giving the number of bytes in the string, and the second
argument-- the UTF-8 encoding of the Unicode string --
contains that many bytes.
""")
def read_unicodestring4(f):
r"""
>>> import io
@ -549,6 +677,7 @@ def read_unicodestring4(f):
"""
n = read_uint4(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("unicodestring4 byte count > sys.maxsize: %d" % n)
data = f.read(n)
@ -570,6 +699,47 @@ unicodestring4 = ArgumentDescriptor(
""")
def read_unicodestring8(f):
    r"""Read a UTF-8 encoded string preceded by an 8-byte unsigned length.

    The payload is decoded with the 'surrogatepass' error handler.
    Raises ValueError when the announced length exceeds sys.maxsize or
    when the stream holds fewer bytes than announced.

    >>> import io
    >>> s = 'abcd\uabcd'
    >>> enc = s.encode('utf-8')
    >>> enc
    b'abcd\xea\xaf\x8d'
    >>> n = bytes([len(enc)]) + bytes(7)  # little-endian 8-byte length
    >>> t = read_unicodestring8(io.BytesIO(n + enc + b'junk'))
    >>> s == t
    True

    >>> read_unicodestring8(io.BytesIO(n + enc[:-1]))
    Traceback (most recent call last):
    ...
    ValueError: expected 7 bytes in a unicodestring8, but only 6 remain
    """
    size = read_uint8(f)
    assert size >= 0
    if size > sys.maxsize:
        raise ValueError("unicodestring8 byte count > sys.maxsize: %d" % size)
    encoded = f.read(size)
    if len(encoded) != size:
        raise ValueError("expected %d bytes in a unicodestring8, but only %d "
                         "remain" % (size, len(encoded)))
    return str(encoded, 'utf-8', 'surrogatepass')
# Descriptor for a length-prefixed UTF-8 argument.  The length is read with
# read_uint8 (TAKEN_FROM_ARGUMENT8U), i.e. it is an *unsigned* 8-byte int, so
# the doc text says "unsigned" rather than "signed".
unicodestring8 = ArgumentDescriptor(
    name="unicodestring8",
    n=TAKEN_FROM_ARGUMENT8U,
    reader=read_unicodestring8,
    doc="""A counted Unicode string.
The first argument is a 8-byte little-endian unsigned int
giving the number of bytes in the string, and the second
argument-- the UTF-8 encoding of the Unicode string --
contains that many bytes.
""")
def read_decimalnl_short(f):
r"""
>>> import io
@ -859,6 +1029,16 @@ pydict = StackObject(
obtype=dict,
doc="A Python dict object.")
# Stack-object marker for opcode descriptions whose stack effect involves a
# mutable set.
pyset = StackObject(
    name="set",
    obtype=set,
    doc="A Python set object.")
# Stack-object marker for opcode descriptions that produce a frozenset.
# obtype must be frozenset (not set): a frozenset instance is not an
# instance of set, so the previous value described the wrong type.
pyfrozenset = StackObject(
    name="frozenset",
    obtype=frozenset,
    doc="A Python frozenset object.")
anyobject = StackObject(
name='any',
obtype=object,
@ -1142,6 +1322,19 @@ opcodes = [
literally as the string content.
"""),
# Protocol-4 opcode: its argument is decoded by the bytes8 descriptor and it
# leaves a single pybytes on the stack.
I(name='BINBYTES8',
  code='\x8e',
  arg=bytes8,
  stack_before=[],
  stack_after=[pybytes],
  proto=4,
  doc="""Push a Python bytes object.
There are two arguments: the first is a 8-byte unsigned int giving
the number of bytes in the string, and the second is that many bytes,
which are taken literally as the string content.
"""),
# Ways to spell None.
I(name='NONE',
@ -1190,6 +1383,19 @@ opcodes = [
until the next newline character.
"""),
# Protocol-4 opcode: its argument is decoded by the unicodestring1
# descriptor, whose length prefix is an unsigned byte -- the doc text
# therefore says "unsigned".
I(name='SHORT_BINUNICODE',
  code='\x8c',
  arg=unicodestring1,
  stack_before=[],
  stack_after=[pyunicode],
  proto=4,
  doc="""Push a Python Unicode string object.
There are two arguments: the first is a 1-byte little-endian unsigned int
giving the number of bytes in the string. The second is that many
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
I(name='BINUNICODE',
code='X',
arg=unicodestring4,
@ -1203,6 +1409,19 @@ opcodes = [
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
# Protocol-4 opcode: its argument is decoded by the unicodestring8
# descriptor, whose length prefix is an unsigned 8-byte int -- the doc text
# therefore says "unsigned".
I(name='BINUNICODE8',
  code='\x8d',
  arg=unicodestring8,
  stack_before=[],
  stack_after=[pyunicode],
  proto=4,
  doc="""Push a Python Unicode string object.
There are two arguments: the first is a 8-byte little-endian unsigned int
giving the number of bytes in the string. The second is that many
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
# Ways to spell floats.
I(name='FLOAT',
@ -1428,6 +1647,54 @@ opcodes = [
1, 2, ..., n, and in that order.
"""),
# Ways to build sets
# Takes no inline argument and requires nothing on the stack; the result is
# a single pyset.
I(name='EMPTY_SET',
  code='\x8f',
  arg=None,
  stack_before=[],
  stack_after=[pyset],
  proto=4,
  doc="Push an empty set."),
# Set counterpart of the mark-delimited "add many" opcodes: consumes the
# markobject and the slice above it, leaving only the set underneath.
I(name='ADDITEMS',
  code='\x90',
  arg=None,
  stack_before=[pyset, markobject, stackslice],
  stack_after=[pyset],
  proto=4,
  doc="""Add an arbitrary number of items to an existing set.
The slice of the stack following the topmost markobject is taken as
a sequence of items, added to the set immediately under the topmost
markobject. Everything at and after the topmost markobject is popped,
leaving the mutated set at the top of the stack.
Stack before: ... pyset markobject item_1 ... item_n
Stack after: ... pyset
where pyset has been modified via pyset.add(item_i) for i in
1, 2, ..., n, and in that order.
"""),
# Way to build frozensets
# Takes no inline argument; consumes a markobject plus the slice above it and
# leaves a single pyfrozenset.
I(name='FROZENSET',
  code='\x91',
  arg=None,
  stack_before=[markobject, stackslice],
  stack_after=[pyfrozenset],
  proto=4,
  doc="""Build a frozenset out of the topmost slice, after markobject.
All the stack entries following the topmost markobject are placed into
a single Python frozenset, which single frozenset object replaces all
of the stack from the topmost markobject onward. For example,
Stack before: ... markobject 1 2 3
Stack after: ... frozenset({1, 2, 3})
"""),
# Stack manipulation.
I(name='POP',
@ -1549,6 +1816,18 @@ opcodes = [
unsigned little-endian integer following.
"""),
# No inline argument: per the doc text below, the memo slot is implied by the
# current memo size.  The stack shape is unchanged (one anyobject in, the
# same anyobject out).
I(name='MEMOIZE',
  code='\x94',
  arg=None,
  stack_before=[anyobject],
  stack_after=[anyobject],
  proto=4,
  doc="""Store the stack top into the memo. The stack is not popped.
The index of the memo location to write is the number of
elements currently present in the memo.
"""),
# Access the extension registry (predefined objects). Akin to the GET
# family.
@ -1614,6 +1893,15 @@ opcodes = [
stack, so unpickling subclasses can override this form of lookup.
"""),
# Takes its module and attribute names from the stack (two pyunicode
# entries) instead of from an inline argument.  The opcode is introduced by
# pickle protocol 4, so proto must be 4; with proto=0 any code that derives
# the minimum protocol from the opcodes used would under-report it.
I(name='STACK_GLOBAL',
  code='\x93',
  arg=None,
  stack_before=[pyunicode, pyunicode],
  stack_after=[anyobject],
  proto=4,
  doc="""Push a global object (module.attr) on the stack.
"""),
# Ways to build objects of classes pickle doesn't know about directly
# (user-defined classes). I despair of documenting this accurately
# and comprehensibly -- you really have to read the pickle code to
@ -1770,6 +2058,21 @@ opcodes = [
onto the stack.
"""),
# Pops three objects (class, positional-argument tuple, keyword-argument
# dict) and pushes one.  The doc text names all three and spells the keyword
# expansion as **kwargs.
I(name='NEWOBJ_EX',
  code='\x92',
  arg=None,
  stack_before=[anyobject, anyobject, anyobject],
  stack_after=[anyobject],
  proto=4,
  doc="""Build an object instance.
The stack before should be thought of as containing a class
object followed by an argument tuple and by a keyword argument dict
(the dict being the stack top). Call these cls, args and kwargs. They
are popped off the stack, and the value returned by
cls.__new__(cls, *args, **kwargs) is pushed back onto the stack.
"""),
# Machine control.
I(name='PROTO',
@ -1797,6 +2100,20 @@ opcodes = [
empty then.
"""),
# Framing support.
# The only argument is a uint8 (eight-byte unsigned) value; the stack is
# untouched.
# NOTE(review): presumably the argument is the frame's length in bytes --
# confirm against the pickle module's framing code.
I(name='FRAME',
  code='\x95',
  arg=uint8,
  stack_before=[],
  stack_after=[],
  proto=4,
  doc="""Indicate the beginning of a new frame.
The unpickler may use this opcode to safely prefetch data from its
underlying stream.
"""),
# Ways to deal with persistent IDs.
I(name='PERSID',
@ -1903,6 +2220,38 @@ del assure_pickle_consistency
##############################################################################
# A pickle opcode generator.
def _genops(data, yield_end_pos=False):
    """Yield one tuple per opcode found in data, stopping after STOP.

    data is either a bytes-like pickle (wrapped in a BytesIO) or a
    file-like object with read().  Each item is (opcode, arg, pos), or
    (opcode, arg, pos, end_pos) when yield_end_pos is true.  Positions
    come from data.tell() and are None when the object has no tell().

    Raises ValueError on an unknown opcode or when the input ends before
    a STOP opcode is seen.
    """
    if isinstance(data, bytes_types):
        data = io.BytesIO(data)

    # Streams without tell() cannot report positions; use None instead.
    getpos = data.tell if hasattr(data, "tell") else (lambda: None)

    while True:
        pos = getpos()
        code = data.read(1)
        opcode = code2op.get(code.decode("latin-1"))
        if opcode is None:
            if code == b"":
                raise ValueError("pickle exhausted before seeing STOP")
            where = "<unknown>" if pos is None else pos
            raise ValueError("at position %s, opcode %r unknown" % (
                             where,
                             code))

        arg = None if opcode.arg is None else opcode.arg.reader(data)

        if yield_end_pos:
            yield opcode, arg, pos, getpos()
        else:
            yield opcode, arg, pos

        if code == b'.':
            assert opcode.name == 'STOP'
            break
def genops(pickle):
"""Generate all the opcodes in a pickle.
@ -1926,62 +2275,47 @@ def genops(pickle):
used. Else (the pickle doesn't have a tell(), and it's not obvious how
to query its current position) pos is None.
"""
if isinstance(pickle, bytes_types):
import io
pickle = io.BytesIO(pickle)
if hasattr(pickle, "tell"):
getpos = pickle.tell
else:
getpos = lambda: None
while True:
pos = getpos()
code = pickle.read(1)
opcode = code2op.get(code.decode("latin-1"))
if opcode is None:
if code == b"":
raise ValueError("pickle exhausted before seeing STOP")
else:
raise ValueError("at position %s, opcode %r unknown" % (
pos is None and "<unknown>" or pos,
code))
if opcode.arg is None:
arg = None
else:
arg = opcode.arg.reader(pickle)
yield opcode, arg, pos
if code == b'.':
assert opcode.name == 'STOP'
break
return _genops(pickle)
##############################################################################
# A pickle optimizer.
def optimize(p):
'Optimize a pickle string by removing unused PUT opcodes'
gets = set() # set of args used by a GET opcode
puts = [] # (arg, startpos, stoppos) for the PUT opcodes
prevpos = None # set to pos if previous opcode was a PUT
for opcode, arg, pos in genops(p):
if prevpos is not None:
puts.append((prevarg, prevpos, pos))
prevpos = None
not_a_put = object()
gets = { not_a_put } # set of args used by a GET opcode
opcodes = [] # (startpos, stoppos, putid)
proto = 0
for opcode, arg, pos, end_pos in _genops(p, yield_end_pos=True):
if 'PUT' in opcode.name:
prevarg, prevpos = arg, pos
elif 'GET' in opcode.name:
gets.add(arg)
opcodes.append((pos, end_pos, arg))
elif 'FRAME' in opcode.name:
pass
else:
if 'GET' in opcode.name:
gets.add(arg)
elif opcode.name == 'PROTO':
assert pos == 0, pos
proto = arg
opcodes.append((pos, end_pos, not_a_put))
prevpos, prevarg = pos, None
# Copy the pickle string except for PUTS without a corresponding GET
s = []
i = 0
for arg, start, stop in puts:
j = stop if (arg in gets) else start
s.append(p[i:j])
i = stop
s.append(p[i:])
return b''.join(s)
# Copy the opcodes except for PUTS without a corresponding GET
out = io.BytesIO()
opcodes = iter(opcodes)
if proto >= 2:
# Write the PROTO header before any framing
start, stop, _ = next(opcodes)
out.write(p[start:stop])
buf = pickle._Framer(out.write)
if proto >= 4:
buf.start_framing()
for start, stop, putid in opcodes:
if putid in gets:
buf.write(p[start:stop])
if proto >= 4:
buf.end_framing()
return out.getvalue()
##############################################################################
# A symbolic pickle disassembler.
@ -2081,17 +2415,20 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0):
errormsg = markmsg = "no MARK exists on stack"
# Check for correct memo usage.
if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT"):
assert arg is not None
if arg in memo:
if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT", "MEMOIZE"):
if opcode.name == "MEMOIZE":
memo_idx = len(memo)
else:
assert arg is not None
memo_idx = arg
if memo_idx in memo:
errormsg = "memo key %r already defined" % arg
elif not stack:
errormsg = "stack is empty -- can't store into memo"
elif stack[-1] is markobject:
errormsg = "can't store markobject in the memo"
else:
memo[arg] = stack[-1]
memo[memo_idx] = stack[-1]
elif opcode.name in ("GET", "BINGET", "LONG_BINGET"):
if arg in memo:
assert len(after) == 1