Issue #11743: Rewrite multiprocessing connection classes in pure Python.

This commit is contained in:
Antoine Pitrou 2011-05-09 17:04:27 +02:00
parent df77e3d4a0
commit 87cf220972
15 changed files with 490 additions and 983 deletions

View file

@ -34,19 +34,27 @@
__all__ = [ 'Client', 'Listener', 'Pipe' ]
import io
import os
import sys
import pickle
import select
import socket
import struct
import errno
import time
import tempfile
import itertools
import _multiprocessing
from multiprocessing import current_process, AuthenticationError
from multiprocessing import current_process, AuthenticationError, BufferTooShort
from multiprocessing.util import get_temp_dir, Finalize, sub_debug, debug
from multiprocessing.forking import duplicate, close
try:
from _multiprocessing import win32
except ImportError:
if sys.platform == 'win32':
raise
win32 = None
#
#
@ -110,6 +118,281 @@ def address_type(address):
else:
raise ValueError('address type of %r unrecognized' % address)
#
# Connection classes
#
class _ConnectionBase:
_handle = None
def __init__(self, handle, readable=True, writable=True):
handle = handle.__index__()
if handle < 0:
raise ValueError("invalid handle")
if not readable and not writable:
raise ValueError(
"at least one of `readable` and `writable` must be True")
self._handle = handle
self._readable = readable
self._writable = writable
def __del__(self):
if self._handle is not None:
self._close()
def _check_closed(self):
if self._handle is None:
raise IOError("handle is closed")
def _check_readable(self):
if not self._readable:
raise IOError("connection is write-only")
def _check_writable(self):
if not self._writable:
raise IOError("connection is read-only")
def _bad_message_length(self):
if self._writable:
self._readable = False
else:
self.close()
raise IOError("bad message length")
@property
def closed(self):
"""True if the connection is closed"""
return self._handle is None
@property
def readable(self):
"""True if the connection is readable"""
return self._readable
@property
def writable(self):
"""True if the connection is writable"""
return self._writable
def fileno(self):
"""File descriptor or handle of the connection"""
self._check_closed()
return self._handle
def close(self):
"""Close the connection"""
if self._handle is not None:
try:
self._close()
finally:
self._handle = None
def send_bytes(self, buf, offset=0, size=None):
"""Send the bytes data from a bytes-like object"""
self._check_closed()
self._check_writable()
m = memoryview(buf)
# HACK for byte-indexing of non-bytewise buffers (e.g. array.array)
if m.itemsize > 1:
m = memoryview(bytes(m))
n = len(m)
if offset < 0:
raise ValueError("offset is negative")
if n < offset:
raise ValueError("buffer length < offset")
if size is None:
size = n - offset
elif size < 0:
raise ValueError("size is negative")
elif offset + size > n:
raise ValueError("buffer length < offset + size")
self._send_bytes(m[offset:offset + size])
def send(self, obj):
"""Send a (picklable) object"""
self._check_closed()
self._check_writable()
buf = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
self._send_bytes(memoryview(buf))
def recv_bytes(self, maxlength=None):
"""
Receive bytes data as a bytes object.
"""
self._check_closed()
self._check_readable()
if maxlength is not None and maxlength < 0:
raise ValueError("negative maxlength")
buf = self._recv_bytes(maxlength)
if buf is None:
self._bad_message_length()
return buf.getvalue()
def recv_bytes_into(self, buf, offset=0):
"""
Receive bytes data into a writeable buffer-like object.
Return the number of bytes read.
"""
self._check_closed()
self._check_readable()
with memoryview(buf) as m:
# Get bytesize of arbitrary buffer
itemsize = m.itemsize
bytesize = itemsize * len(m)
if offset < 0:
raise ValueError("negative offset")
elif offset > bytesize:
raise ValueError("offset too large")
result = self._recv_bytes()
size = result.tell()
if bytesize < offset + size:
raise BufferTooShort(result.getvalue())
# Message can fit in dest
result.seek(0)
result.readinto(m[offset // itemsize :
(offset + size) // itemsize])
return size
def recv(self):
"""Receive a (picklable) object"""
self._check_closed()
self._check_readable()
buf = self._recv_bytes()
return pickle.loads(buf.getbuffer())
def poll(self, timeout=0.0):
"""Whether there is any input available to be read"""
self._check_closed()
self._check_readable()
if timeout < 0.0:
timeout = None
return self._poll(timeout)
if win32:
class PipeConnection(_ConnectionBase):
"""
Connection class based on a Windows named pipe.
"""
def _close(self):
win32.CloseHandle(self._handle)
def _send_bytes(self, buf):
nwritten = win32.WriteFile(self._handle, buf)
assert nwritten == len(buf)
def _recv_bytes(self, maxsize=None):
buf = io.BytesIO()
bufsize = 512
if maxsize is not None:
bufsize = min(bufsize, maxsize)
try:
firstchunk, complete = win32.ReadFile(self._handle, bufsize)
except IOError as e:
if e.errno == win32.ERROR_BROKEN_PIPE:
raise EOFError
raise
lenfirstchunk = len(firstchunk)
buf.write(firstchunk)
if complete:
return buf
navail, nleft = win32.PeekNamedPipe(self._handle)
if maxsize is not None and lenfirstchunk + nleft > maxsize:
return None
lastchunk, complete = win32.ReadFile(self._handle, nleft)
assert complete
buf.write(lastchunk)
return buf
def _poll(self, timeout):
navail, nleft = win32.PeekNamedPipe(self._handle)
if navail > 0:
return True
elif timeout == 0.0:
return False
# Setup a polling loop (translated straight from old
# pipe_connection.c)
if timeout < 0.0:
deadline = None
else:
deadline = time.time() + timeout
delay = 0.001
max_delay = 0.02
while True:
time.sleep(delay)
navail, nleft = win32.PeekNamedPipe(self._handle)
if navail > 0:
return True
if deadline and time.time() > deadline:
return False
if delay < max_delay:
delay += 0.001
class Connection(_ConnectionBase):
"""
Connection class based on an arbitrary file descriptor (Unix only), or
a socket handle (Windows).
"""
if win32:
def _close(self):
win32.closesocket(self._handle)
_write = win32.send
_read = win32.recv
else:
def _close(self):
os.close(self._handle)
_write = os.write
_read = os.read
def _send(self, buf, write=_write):
remaining = len(buf)
while True:
n = write(self._handle, buf)
remaining -= n
if remaining == 0:
break
buf = buf[n:]
def _recv(self, size, read=_read):
buf = io.BytesIO()
remaining = size
while remaining > 0:
chunk = read(self._handle, remaining)
n = len(chunk)
if n == 0:
if remaining == size:
raise EOFError
else:
raise IOError("got end of file during message")
buf.write(chunk)
remaining -= n
return buf
def _send_bytes(self, buf):
# For wire compatibility with 3.2 and lower
n = len(buf)
self._send(struct.pack("=i", len(buf)))
# The condition is necessary to avoid "broken pipe" errors
# when sending a 0-length buffer if the other end closed the pipe.
if n > 0:
self._send(buf)
def _recv_bytes(self, maxsize=None):
buf = self._recv(4)
size, = struct.unpack("=i", buf.getvalue())
if maxsize is not None and size > maxsize:
return None
return self._recv(size)
def _poll(self, timeout):
r = select.select([self._handle], [], [], timeout)[0]
return bool(r)
#
# Public functions
#
@ -186,21 +469,19 @@ if sys.platform != 'win32':
'''
if duplex:
s1, s2 = socket.socketpair()
c1 = _multiprocessing.Connection(os.dup(s1.fileno()))
c2 = _multiprocessing.Connection(os.dup(s2.fileno()))
c1 = Connection(os.dup(s1.fileno()))
c2 = Connection(os.dup(s2.fileno()))
s1.close()
s2.close()
else:
fd1, fd2 = os.pipe()
c1 = _multiprocessing.Connection(fd1, writable=False)
c2 = _multiprocessing.Connection(fd2, readable=False)
c1 = Connection(fd1, writable=False)
c2 = Connection(fd2, readable=False)
return c1, c2
else:
from _multiprocessing import win32
def Pipe(duplex=True):
'''
Returns pair of connection objects at either end of a pipe
@ -234,8 +515,8 @@ else:
if e.args[0] != win32.ERROR_PIPE_CONNECTED:
raise
c1 = _multiprocessing.PipeConnection(h1, writable=duplex)
c2 = _multiprocessing.PipeConnection(h2, readable=duplex)
c1 = PipeConnection(h1, writable=duplex)
c2 = PipeConnection(h2, readable=duplex)
return c1, c2
@ -266,7 +547,7 @@ class SocketListener(object):
def accept(self):
s, self._last_accepted = self._socket.accept()
fd = duplicate(s.fileno())
conn = _multiprocessing.Connection(fd)
conn = Connection(fd)
s.close()
return conn
@ -298,7 +579,7 @@ def SocketClient(address):
raise
fd = duplicate(s.fileno())
conn = _multiprocessing.Connection(fd)
conn = Connection(fd)
return conn
#
@ -345,7 +626,7 @@ if sys.platform == 'win32':
except WindowsError as e:
if e.args[0] != win32.ERROR_PIPE_CONNECTED:
raise
return _multiprocessing.PipeConnection(handle)
return PipeConnection(handle)
@staticmethod
def _finalize_pipe_listener(queue, address):
@ -377,7 +658,7 @@ if sys.platform == 'win32':
win32.SetNamedPipeHandleState(
h, win32.PIPE_READMODE_MESSAGE, None, None
)
return _multiprocessing.PipeConnection(h)
return PipeConnection(h)
#
# Authentication stuff
@ -451,3 +732,7 @@ def XmlClient(*args, **kwds):
global xmlrpclib
import xmlrpc.client as xmlrpclib
return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
# Late import because of circular import
from multiprocessing.forking import duplicate, close