Closes #13297: use bytes type to send and receive binary data through XMLRPC.

This commit is contained in:
Florent Xicluna 2011-11-15 20:53:25 +01:00
parent 1d8f3f451c
commit 6166519d2b
4 changed files with 132 additions and 48 deletions

View file

@ -386,8 +386,8 @@ class Binary:
if data is None:
data = b""
else:
if not isinstance(data, bytes):
raise TypeError("expected bytes, not %s" %
if not isinstance(data, (bytes, bytearray)):
raise TypeError("expected bytes or bytearray, not %s" %
data.__class__.__name__)
data = bytes(data) # Make a copy of the bytes!
self.data = data
@ -559,6 +559,14 @@ class Marshaller:
write("</string></value>\n")
dispatch[str] = dump_unicode
def dump_bytes(self, value, write):
write("<value><base64>\n")
encoded = base64.encodebytes(value)
write(encoded.decode('ascii'))
write("</base64></value>\n")
dispatch[bytes] = dump_bytes
dispatch[bytearray] = dump_bytes
def dump_array(self, value, write):
i = id(value)
if i in self.memo:
@ -629,7 +637,7 @@ class Unmarshaller:
# and again, if you don't understand what's going on in here,
# that's perfectly ok.
def __init__(self, use_datetime=False):
def __init__(self, use_datetime=False, use_builtin_types=False):
self._type = None
self._stack = []
self._marks = []
@ -637,7 +645,8 @@ class Unmarshaller:
self._methodname = None
self._encoding = "utf-8"
self.append = self._stack.append
self._use_datetime = use_datetime
self._use_datetime = use_builtin_types or use_datetime
self._use_bytes = use_builtin_types
def close(self):
# return response tuple and target method
@ -749,6 +758,8 @@ class Unmarshaller:
def end_base64(self, data):
value = Binary()
value.decode(data.encode("ascii"))
if self._use_bytes:
value = value.data
self.append(value)
self._value = 0
dispatch["base64"] = end_base64
@ -860,21 +871,26 @@ FastMarshaller = FastParser = FastUnmarshaller = None
#
# return A (parser, unmarshaller) tuple.
def getparser(use_datetime=False):
def getparser(use_datetime=False, use_builtin_types=False):
"""getparser() -> parser, unmarshaller
Create an instance of the fastest available parser, and attach it
to an unmarshalling object. Return both objects.
"""
if FastParser and FastUnmarshaller:
if use_datetime:
if use_builtin_types:
mkdatetime = _datetime_type
mkbytes = base64.decodebytes
elif use_datetime:
mkdatetime = _datetime_type
mkbytes = _binary
else:
mkdatetime = _datetime
target = FastUnmarshaller(True, False, _binary, mkdatetime, Fault)
mkbytes = _binary
target = FastUnmarshaller(True, False, mkbytes, mkdatetime, Fault)
parser = FastParser(target)
else:
target = Unmarshaller(use_datetime=use_datetime)
target = Unmarshaller(use_datetime=use_datetime, use_builtin_types=use_builtin_types)
if FastParser:
parser = FastParser(target)
else:
@ -912,7 +928,7 @@ def dumps(params, methodname=None, methodresponse=None, encoding=None,
encoding: the packet encoding (default is UTF-8)
All 8-bit strings in the data structure are assumed to use the
All byte strings in the data structure are assumed to use the
packet encoding. Unicode strings are automatically converted,
where necessary.
"""
@ -971,7 +987,7 @@ def dumps(params, methodname=None, methodresponse=None, encoding=None,
# (None if not present).
# @see Fault
def loads(data, use_datetime=False):
def loads(data, use_datetime=False, use_builtin_types=False):
"""data -> unmarshalled data, method name
Convert an XML-RPC packet to unmarshalled data plus a method
@ -980,7 +996,7 @@ def loads(data, use_datetime=False):
If the XML-RPC packet represents a fault condition, this function
raises a Fault exception.
"""
p, u = getparser(use_datetime=use_datetime)
p, u = getparser(use_datetime=use_datetime, use_builtin_types=use_builtin_types)
p.feed(data)
p.close()
return u.close(), u.getmethodname()
@ -1092,8 +1108,9 @@ class Transport:
# that they can decode such a request
encode_threshold = None #None = don't encode
def __init__(self, use_datetime=False):
def __init__(self, use_datetime=False, use_builtin_types=False):
self._use_datetime = use_datetime
self._use_builtin_types = use_builtin_types
self._connection = (None, None)
self._extra_headers = []
@ -1154,7 +1171,8 @@ class Transport:
def getparser(self):
# get parser and unmarshaller
return getparser(use_datetime=self._use_datetime)
return getparser(use_datetime=self._use_datetime,
use_builtin_types=self._use_builtin_types)
##
# Get authorization info from host parameter
@ -1361,7 +1379,7 @@ class ServerProxy:
"""
def __init__(self, uri, transport=None, encoding=None, verbose=False,
allow_none=False, use_datetime=False):
allow_none=False, use_datetime=False, use_builtin_types=False):
# establish a "logical" server connection
# get the url
@ -1375,9 +1393,11 @@ class ServerProxy:
if transport is None:
if type == "https":
transport = SafeTransport(use_datetime=use_datetime)
handler = SafeTransport
else:
transport = Transport(use_datetime=use_datetime)
handler = Transport
transport = handler(use_datetime=use_datetime,
use_builtin_types=use_builtin_types)
self.__transport = transport
self.__encoding = encoding or 'utf-8'