mirror of
https://github.com/python/cpython.git
synced 2025-12-04 00:30:19 +00:00
asyncio: Change write buffer use to avoid O(N**2). Make write()/sendto() accept bytearray/memoryview too. Change some asserts with proper exceptions.
This commit is contained in:
parent
f28ce60441
commit
a5062c5d81
2 changed files with 207 additions and 78 deletions
|
|
@ -340,6 +340,8 @@ class _SelectorTransport(transports.Transport):
|
|||
|
||||
max_size = 256 * 1024 # Buffer size passed to recv().
|
||||
|
||||
_buffer_factory = bytearray # Constructs initial value for self._buffer.
|
||||
|
||||
def __init__(self, loop, sock, protocol, extra, server=None):
|
||||
super().__init__(extra)
|
||||
self._extra['socket'] = sock
|
||||
|
|
@ -354,7 +356,7 @@ class _SelectorTransport(transports.Transport):
|
|||
self._sock_fd = sock.fileno()
|
||||
self._protocol = protocol
|
||||
self._server = server
|
||||
self._buffer = collections.deque()
|
||||
self._buffer = self._buffer_factory()
|
||||
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
||||
self._closing = False # Set when close() called.
|
||||
self._protocol_paused = False
|
||||
|
|
@ -433,12 +435,14 @@ class _SelectorTransport(transports.Transport):
|
|||
high = 4*low
|
||||
if low is None:
|
||||
low = high // 4
|
||||
assert 0 <= low <= high, repr((low, high))
|
||||
if not high >= low >= 0:
|
||||
raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
|
||||
(high, low))
|
||||
self._high_water = high
|
||||
self._low_water = low
|
||||
|
||||
def get_write_buffer_size(self):
|
||||
return sum(len(data) for data in self._buffer)
|
||||
return len(self._buffer)
|
||||
|
||||
|
||||
class _SelectorSocketTransport(_SelectorTransport):
|
||||
|
|
@ -455,13 +459,16 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
self._loop.call_soon(waiter.set_result, None)
|
||||
|
||||
def pause_reading(self):
|
||||
assert not self._closing, 'Cannot pause_reading() when closing'
|
||||
assert not self._paused, 'Already paused'
|
||||
if self._closing:
|
||||
raise RuntimeError('Cannot pause_reading() when closing')
|
||||
if self._paused:
|
||||
raise RuntimeError('Already paused')
|
||||
self._paused = True
|
||||
self._loop.remove_reader(self._sock_fd)
|
||||
|
||||
def resume_reading(self):
|
||||
assert self._paused, 'Not paused'
|
||||
if not self._paused:
|
||||
raise RuntimeError('Not paused')
|
||||
self._paused = False
|
||||
if self._closing:
|
||||
return
|
||||
|
|
@ -488,8 +495,11 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
self.close()
|
||||
|
||||
def write(self, data):
|
||||
assert isinstance(data, bytes), repr(type(data))
|
||||
assert not self._eof, 'Cannot call write() after write_eof()'
|
||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||
raise TypeError('data argument must be byte-ish (%r)',
|
||||
type(data))
|
||||
if self._eof:
|
||||
raise RuntimeError('Cannot call write() after write_eof()')
|
||||
if not data:
|
||||
return
|
||||
|
||||
|
|
@ -516,25 +526,23 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
self._loop.add_writer(self._sock_fd, self._write_ready)
|
||||
|
||||
# Add it to the buffer.
|
||||
self._buffer.append(data)
|
||||
self._buffer.extend(data)
|
||||
self._maybe_pause_protocol()
|
||||
|
||||
def _write_ready(self):
|
||||
data = b''.join(self._buffer)
|
||||
assert data, 'Data should not be empty'
|
||||
assert self._buffer, 'Data should not be empty'
|
||||
|
||||
self._buffer.clear() # Optimistically; may have to put it back later.
|
||||
try:
|
||||
n = self._sock.send(data)
|
||||
n = self._sock.send(self._buffer)
|
||||
except (BlockingIOError, InterruptedError):
|
||||
self._buffer.append(data) # Still need to write this.
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
self._buffer.clear()
|
||||
self._fatal_error(exc)
|
||||
else:
|
||||
data = data[n:]
|
||||
if data:
|
||||
self._buffer.append(data) # Still need to write this.
|
||||
if n:
|
||||
del self._buffer[:n]
|
||||
self._maybe_resume_protocol() # May append to buffer.
|
||||
if not self._buffer:
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
|
|
@ -556,6 +564,8 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
|
||||
class _SelectorSslTransport(_SelectorTransport):
|
||||
|
||||
_buffer_factory = bytearray
|
||||
|
||||
def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
|
||||
server_side=False, server_hostname=None,
|
||||
extra=None, server=None):
|
||||
|
|
@ -661,13 +671,16 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
# accept more data for the buffer and eventually the app will
|
||||
# call resume_reading() again, and things will flow again.
|
||||
|
||||
assert not self._closing, 'Cannot pause_reading() when closing'
|
||||
assert not self._paused, 'Already paused'
|
||||
if self._closing:
|
||||
raise RuntimeError('Cannot pause_reading() when closing')
|
||||
if self._paused:
|
||||
raise RuntimeError('Already paused')
|
||||
self._paused = True
|
||||
self._loop.remove_reader(self._sock_fd)
|
||||
|
||||
def resume_reading(self):
|
||||
assert self._paused, 'Not paused'
|
||||
if not self._paused:
|
||||
raise ('Not paused')
|
||||
self._paused = False
|
||||
if self._closing:
|
||||
return
|
||||
|
|
@ -712,10 +725,8 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
self._loop.add_reader(self._sock_fd, self._read_ready)
|
||||
|
||||
if self._buffer:
|
||||
data = b''.join(self._buffer)
|
||||
self._buffer.clear()
|
||||
try:
|
||||
n = self._sock.send(data)
|
||||
n = self._sock.send(self._buffer)
|
||||
except (BlockingIOError, InterruptedError,
|
||||
ssl.SSLWantWriteError):
|
||||
n = 0
|
||||
|
|
@ -725,11 +736,12 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
self._write_wants_read = True
|
||||
except Exception as exc:
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
self._buffer.clear()
|
||||
self._fatal_error(exc)
|
||||
return
|
||||
|
||||
if n < len(data):
|
||||
self._buffer.append(data[n:])
|
||||
if n:
|
||||
del self._buffer[:n]
|
||||
|
||||
self._maybe_resume_protocol() # May append to buffer.
|
||||
|
||||
|
|
@ -739,7 +751,9 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
self._call_connection_lost(None)
|
||||
|
||||
def write(self, data):
|
||||
assert isinstance(data, bytes), repr(type(data))
|
||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||
raise TypeError('data argument must be byte-ish (%r)',
|
||||
type(data))
|
||||
if not data:
|
||||
return
|
||||
|
||||
|
|
@ -753,7 +767,7 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
self._loop.add_writer(self._sock_fd, self._write_ready)
|
||||
|
||||
# Add it to the buffer.
|
||||
self._buffer.append(data)
|
||||
self._buffer.extend(data)
|
||||
self._maybe_pause_protocol()
|
||||
|
||||
def can_write_eof(self):
|
||||
|
|
@ -762,6 +776,8 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
|
||||
class _SelectorDatagramTransport(_SelectorTransport):
|
||||
|
||||
_buffer_factory = collections.deque
|
||||
|
||||
def __init__(self, loop, sock, protocol, address=None, extra=None):
|
||||
super().__init__(loop, sock, protocol, extra)
|
||||
self._address = address
|
||||
|
|
@ -784,12 +800,15 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
|||
self._protocol.datagram_received(data, addr)
|
||||
|
||||
def sendto(self, data, addr=None):
|
||||
assert isinstance(data, bytes), repr(type(data))
|
||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||
raise TypeError('data argument must be byte-ish (%r)',
|
||||
type(data))
|
||||
if not data:
|
||||
return
|
||||
|
||||
if self._address:
|
||||
assert addr in (None, self._address)
|
||||
if self._address and addr not in (None, self._address):
|
||||
raise ValueError('Invalid address: must be None or %s' %
|
||||
(self._address,))
|
||||
|
||||
if self._conn_lost and self._address:
|
||||
if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
|
||||
|
|
@ -814,7 +833,8 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
|||
self._fatal_error(exc)
|
||||
return
|
||||
|
||||
self._buffer.append((data, addr))
|
||||
# Ensure that what we buffer is immutable.
|
||||
self._buffer.append((bytes(data), addr))
|
||||
self._maybe_pause_protocol()
|
||||
|
||||
def _sendto_ready(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue