mirror of
https://github.com/python/cpython.git
synced 2025-12-04 08:34:25 +00:00
asyncio: Change write buffer use to avoid O(N**2). Make write()/sendto() accept bytearray/memoryview too. Change some asserts with proper exceptions.
This commit is contained in:
parent
f28ce60441
commit
a5062c5d81
2 changed files with 207 additions and 78 deletions
|
|
@ -340,6 +340,8 @@ class _SelectorTransport(transports.Transport):
|
||||||
|
|
||||||
max_size = 256 * 1024 # Buffer size passed to recv().
|
max_size = 256 * 1024 # Buffer size passed to recv().
|
||||||
|
|
||||||
|
_buffer_factory = bytearray # Constructs initial value for self._buffer.
|
||||||
|
|
||||||
def __init__(self, loop, sock, protocol, extra, server=None):
|
def __init__(self, loop, sock, protocol, extra, server=None):
|
||||||
super().__init__(extra)
|
super().__init__(extra)
|
||||||
self._extra['socket'] = sock
|
self._extra['socket'] = sock
|
||||||
|
|
@ -354,7 +356,7 @@ class _SelectorTransport(transports.Transport):
|
||||||
self._sock_fd = sock.fileno()
|
self._sock_fd = sock.fileno()
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
self._server = server
|
self._server = server
|
||||||
self._buffer = collections.deque()
|
self._buffer = self._buffer_factory()
|
||||||
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
||||||
self._closing = False # Set when close() called.
|
self._closing = False # Set when close() called.
|
||||||
self._protocol_paused = False
|
self._protocol_paused = False
|
||||||
|
|
@ -433,12 +435,14 @@ class _SelectorTransport(transports.Transport):
|
||||||
high = 4*low
|
high = 4*low
|
||||||
if low is None:
|
if low is None:
|
||||||
low = high // 4
|
low = high // 4
|
||||||
assert 0 <= low <= high, repr((low, high))
|
if not high >= low >= 0:
|
||||||
|
raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
|
||||||
|
(high, low))
|
||||||
self._high_water = high
|
self._high_water = high
|
||||||
self._low_water = low
|
self._low_water = low
|
||||||
|
|
||||||
def get_write_buffer_size(self):
|
def get_write_buffer_size(self):
|
||||||
return sum(len(data) for data in self._buffer)
|
return len(self._buffer)
|
||||||
|
|
||||||
|
|
||||||
class _SelectorSocketTransport(_SelectorTransport):
|
class _SelectorSocketTransport(_SelectorTransport):
|
||||||
|
|
@ -455,13 +459,16 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
self._loop.call_soon(waiter.set_result, None)
|
self._loop.call_soon(waiter.set_result, None)
|
||||||
|
|
||||||
def pause_reading(self):
|
def pause_reading(self):
|
||||||
assert not self._closing, 'Cannot pause_reading() when closing'
|
if self._closing:
|
||||||
assert not self._paused, 'Already paused'
|
raise RuntimeError('Cannot pause_reading() when closing')
|
||||||
|
if self._paused:
|
||||||
|
raise RuntimeError('Already paused')
|
||||||
self._paused = True
|
self._paused = True
|
||||||
self._loop.remove_reader(self._sock_fd)
|
self._loop.remove_reader(self._sock_fd)
|
||||||
|
|
||||||
def resume_reading(self):
|
def resume_reading(self):
|
||||||
assert self._paused, 'Not paused'
|
if not self._paused:
|
||||||
|
raise RuntimeError('Not paused')
|
||||||
self._paused = False
|
self._paused = False
|
||||||
if self._closing:
|
if self._closing:
|
||||||
return
|
return
|
||||||
|
|
@ -488,8 +495,11 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
assert isinstance(data, bytes), repr(type(data))
|
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||||
assert not self._eof, 'Cannot call write() after write_eof()'
|
raise TypeError('data argument must be byte-ish (%r)',
|
||||||
|
type(data))
|
||||||
|
if self._eof:
|
||||||
|
raise RuntimeError('Cannot call write() after write_eof()')
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -516,25 +526,23 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
self._loop.add_writer(self._sock_fd, self._write_ready)
|
self._loop.add_writer(self._sock_fd, self._write_ready)
|
||||||
|
|
||||||
# Add it to the buffer.
|
# Add it to the buffer.
|
||||||
self._buffer.append(data)
|
self._buffer.extend(data)
|
||||||
self._maybe_pause_protocol()
|
self._maybe_pause_protocol()
|
||||||
|
|
||||||
def _write_ready(self):
|
def _write_ready(self):
|
||||||
data = b''.join(self._buffer)
|
assert self._buffer, 'Data should not be empty'
|
||||||
assert data, 'Data should not be empty'
|
|
||||||
|
|
||||||
self._buffer.clear() # Optimistically; may have to put it back later.
|
|
||||||
try:
|
try:
|
||||||
n = self._sock.send(data)
|
n = self._sock.send(self._buffer)
|
||||||
except (BlockingIOError, InterruptedError):
|
except (BlockingIOError, InterruptedError):
|
||||||
self._buffer.append(data) # Still need to write this.
|
pass
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._loop.remove_writer(self._sock_fd)
|
self._loop.remove_writer(self._sock_fd)
|
||||||
|
self._buffer.clear()
|
||||||
self._fatal_error(exc)
|
self._fatal_error(exc)
|
||||||
else:
|
else:
|
||||||
data = data[n:]
|
if n:
|
||||||
if data:
|
del self._buffer[:n]
|
||||||
self._buffer.append(data) # Still need to write this.
|
|
||||||
self._maybe_resume_protocol() # May append to buffer.
|
self._maybe_resume_protocol() # May append to buffer.
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
self._loop.remove_writer(self._sock_fd)
|
self._loop.remove_writer(self._sock_fd)
|
||||||
|
|
@ -556,6 +564,8 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
|
|
||||||
class _SelectorSslTransport(_SelectorTransport):
|
class _SelectorSslTransport(_SelectorTransport):
|
||||||
|
|
||||||
|
_buffer_factory = bytearray
|
||||||
|
|
||||||
def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
|
def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
|
||||||
server_side=False, server_hostname=None,
|
server_side=False, server_hostname=None,
|
||||||
extra=None, server=None):
|
extra=None, server=None):
|
||||||
|
|
@ -661,13 +671,16 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
# accept more data for the buffer and eventually the app will
|
# accept more data for the buffer and eventually the app will
|
||||||
# call resume_reading() again, and things will flow again.
|
# call resume_reading() again, and things will flow again.
|
||||||
|
|
||||||
assert not self._closing, 'Cannot pause_reading() when closing'
|
if self._closing:
|
||||||
assert not self._paused, 'Already paused'
|
raise RuntimeError('Cannot pause_reading() when closing')
|
||||||
|
if self._paused:
|
||||||
|
raise RuntimeError('Already paused')
|
||||||
self._paused = True
|
self._paused = True
|
||||||
self._loop.remove_reader(self._sock_fd)
|
self._loop.remove_reader(self._sock_fd)
|
||||||
|
|
||||||
def resume_reading(self):
|
def resume_reading(self):
|
||||||
assert self._paused, 'Not paused'
|
if not self._paused:
|
||||||
|
raise ('Not paused')
|
||||||
self._paused = False
|
self._paused = False
|
||||||
if self._closing:
|
if self._closing:
|
||||||
return
|
return
|
||||||
|
|
@ -712,10 +725,8 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
self._loop.add_reader(self._sock_fd, self._read_ready)
|
self._loop.add_reader(self._sock_fd, self._read_ready)
|
||||||
|
|
||||||
if self._buffer:
|
if self._buffer:
|
||||||
data = b''.join(self._buffer)
|
|
||||||
self._buffer.clear()
|
|
||||||
try:
|
try:
|
||||||
n = self._sock.send(data)
|
n = self._sock.send(self._buffer)
|
||||||
except (BlockingIOError, InterruptedError,
|
except (BlockingIOError, InterruptedError,
|
||||||
ssl.SSLWantWriteError):
|
ssl.SSLWantWriteError):
|
||||||
n = 0
|
n = 0
|
||||||
|
|
@ -725,11 +736,12 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
self._write_wants_read = True
|
self._write_wants_read = True
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._loop.remove_writer(self._sock_fd)
|
self._loop.remove_writer(self._sock_fd)
|
||||||
|
self._buffer.clear()
|
||||||
self._fatal_error(exc)
|
self._fatal_error(exc)
|
||||||
return
|
return
|
||||||
|
|
||||||
if n < len(data):
|
if n:
|
||||||
self._buffer.append(data[n:])
|
del self._buffer[:n]
|
||||||
|
|
||||||
self._maybe_resume_protocol() # May append to buffer.
|
self._maybe_resume_protocol() # May append to buffer.
|
||||||
|
|
||||||
|
|
@ -739,7 +751,9 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
self._call_connection_lost(None)
|
self._call_connection_lost(None)
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
assert isinstance(data, bytes), repr(type(data))
|
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||||
|
raise TypeError('data argument must be byte-ish (%r)',
|
||||||
|
type(data))
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -753,7 +767,7 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
self._loop.add_writer(self._sock_fd, self._write_ready)
|
self._loop.add_writer(self._sock_fd, self._write_ready)
|
||||||
|
|
||||||
# Add it to the buffer.
|
# Add it to the buffer.
|
||||||
self._buffer.append(data)
|
self._buffer.extend(data)
|
||||||
self._maybe_pause_protocol()
|
self._maybe_pause_protocol()
|
||||||
|
|
||||||
def can_write_eof(self):
|
def can_write_eof(self):
|
||||||
|
|
@ -762,6 +776,8 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
|
|
||||||
class _SelectorDatagramTransport(_SelectorTransport):
|
class _SelectorDatagramTransport(_SelectorTransport):
|
||||||
|
|
||||||
|
_buffer_factory = collections.deque
|
||||||
|
|
||||||
def __init__(self, loop, sock, protocol, address=None, extra=None):
|
def __init__(self, loop, sock, protocol, address=None, extra=None):
|
||||||
super().__init__(loop, sock, protocol, extra)
|
super().__init__(loop, sock, protocol, extra)
|
||||||
self._address = address
|
self._address = address
|
||||||
|
|
@ -784,12 +800,15 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
||||||
self._protocol.datagram_received(data, addr)
|
self._protocol.datagram_received(data, addr)
|
||||||
|
|
||||||
def sendto(self, data, addr=None):
|
def sendto(self, data, addr=None):
|
||||||
assert isinstance(data, bytes), repr(type(data))
|
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||||
|
raise TypeError('data argument must be byte-ish (%r)',
|
||||||
|
type(data))
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._address:
|
if self._address and addr not in (None, self._address):
|
||||||
assert addr in (None, self._address)
|
raise ValueError('Invalid address: must be None or %s' %
|
||||||
|
(self._address,))
|
||||||
|
|
||||||
if self._conn_lost and self._address:
|
if self._conn_lost and self._address:
|
||||||
if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
|
if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
|
||||||
|
|
@ -814,7 +833,8 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
||||||
self._fatal_error(exc)
|
self._fatal_error(exc)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._buffer.append((data, addr))
|
# Ensure that what we buffer is immutable.
|
||||||
|
self._buffer.append((bytes(data), addr))
|
||||||
self._maybe_pause_protocol()
|
self._maybe_pause_protocol()
|
||||||
|
|
||||||
def _sendto_ready(self):
|
def _sendto_ready(self):
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,10 @@ class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
|
||||||
self._internal_fds += 1
|
self._internal_fds += 1
|
||||||
|
|
||||||
|
|
||||||
|
def list_to_buffer(l=()):
|
||||||
|
return bytearray().join(l)
|
||||||
|
|
||||||
|
|
||||||
class BaseSelectorEventLoopTests(unittest.TestCase):
|
class BaseSelectorEventLoopTests(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
@ -613,7 +617,7 @@ class SelectorTransportTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_close_write_buffer(self):
|
def test_close_write_buffer(self):
|
||||||
tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
|
tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
|
||||||
tr._buffer.append(b'data')
|
tr._buffer.extend(b'data')
|
||||||
tr.close()
|
tr.close()
|
||||||
|
|
||||||
self.assertFalse(self.loop.readers)
|
self.assertFalse(self.loop.readers)
|
||||||
|
|
@ -622,13 +626,13 @@ class SelectorTransportTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_force_close(self):
|
def test_force_close(self):
|
||||||
tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
|
tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
|
||||||
tr._buffer.append(b'1')
|
tr._buffer.extend(b'1')
|
||||||
self.loop.add_reader(7, unittest.mock.sentinel)
|
self.loop.add_reader(7, unittest.mock.sentinel)
|
||||||
self.loop.add_writer(7, unittest.mock.sentinel)
|
self.loop.add_writer(7, unittest.mock.sentinel)
|
||||||
tr._force_close(None)
|
tr._force_close(None)
|
||||||
|
|
||||||
self.assertTrue(tr._closing)
|
self.assertTrue(tr._closing)
|
||||||
self.assertEqual(tr._buffer, collections.deque())
|
self.assertEqual(tr._buffer, list_to_buffer())
|
||||||
self.assertFalse(self.loop.readers)
|
self.assertFalse(self.loop.readers)
|
||||||
self.assertFalse(self.loop.writers)
|
self.assertFalse(self.loop.writers)
|
||||||
|
|
||||||
|
|
@ -783,21 +787,40 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
transport.write(data)
|
transport.write(data)
|
||||||
self.sock.send.assert_called_with(data)
|
self.sock.send.assert_called_with(data)
|
||||||
|
|
||||||
|
def test_write_bytearray(self):
|
||||||
|
data = bytearray(b'data')
|
||||||
|
self.sock.send.return_value = len(data)
|
||||||
|
|
||||||
|
transport = _SelectorSocketTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport.write(data)
|
||||||
|
self.sock.send.assert_called_with(data)
|
||||||
|
self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated.
|
||||||
|
|
||||||
|
def test_write_memoryview(self):
|
||||||
|
data = memoryview(b'data')
|
||||||
|
self.sock.send.return_value = len(data)
|
||||||
|
|
||||||
|
transport = _SelectorSocketTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport.write(data)
|
||||||
|
self.sock.send.assert_called_with(data)
|
||||||
|
|
||||||
def test_write_no_data(self):
|
def test_write_no_data(self):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._buffer.append(b'data')
|
transport._buffer.extend(b'data')
|
||||||
transport.write(b'')
|
transport.write(b'')
|
||||||
self.assertFalse(self.sock.send.called)
|
self.assertFalse(self.sock.send.called)
|
||||||
self.assertEqual(collections.deque([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
def test_write_buffer(self):
|
def test_write_buffer(self):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._buffer.append(b'data1')
|
transport._buffer.extend(b'data1')
|
||||||
transport.write(b'data2')
|
transport.write(b'data2')
|
||||||
self.assertFalse(self.sock.send.called)
|
self.assertFalse(self.sock.send.called)
|
||||||
self.assertEqual(collections.deque([b'data1', b'data2']),
|
self.assertEqual(list_to_buffer([b'data1', b'data2']),
|
||||||
transport._buffer)
|
transport._buffer)
|
||||||
|
|
||||||
def test_write_partial(self):
|
def test_write_partial(self):
|
||||||
|
|
@ -809,7 +832,30 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
transport.write(data)
|
transport.write(data)
|
||||||
|
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(collections.deque([b'ta']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
|
||||||
|
|
||||||
|
def test_write_partial_bytearray(self):
|
||||||
|
data = bytearray(b'data')
|
||||||
|
self.sock.send.return_value = 2
|
||||||
|
|
||||||
|
transport = _SelectorSocketTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport.write(data)
|
||||||
|
|
||||||
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
|
self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
|
||||||
|
self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated.
|
||||||
|
|
||||||
|
def test_write_partial_memoryview(self):
|
||||||
|
data = memoryview(b'data')
|
||||||
|
self.sock.send.return_value = 2
|
||||||
|
|
||||||
|
transport = _SelectorSocketTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport.write(data)
|
||||||
|
|
||||||
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
|
self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
|
||||||
|
|
||||||
def test_write_partial_none(self):
|
def test_write_partial_none(self):
|
||||||
data = b'data'
|
data = b'data'
|
||||||
|
|
@ -821,7 +867,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
transport.write(data)
|
transport.write(data)
|
||||||
|
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(collections.deque([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
def test_write_tryagain(self):
|
def test_write_tryagain(self):
|
||||||
self.sock.send.side_effect = BlockingIOError
|
self.sock.send.side_effect = BlockingIOError
|
||||||
|
|
@ -832,7 +878,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
transport.write(data)
|
transport.write(data)
|
||||||
|
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(collections.deque([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
@unittest.mock.patch('asyncio.selector_events.logger')
|
@unittest.mock.patch('asyncio.selector_events.logger')
|
||||||
def test_write_exception(self, m_log):
|
def test_write_exception(self, m_log):
|
||||||
|
|
@ -859,7 +905,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
def test_write_str(self):
|
def test_write_str(self):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
self.assertRaises(AssertionError, transport.write, 'str')
|
self.assertRaises(TypeError, transport.write, 'str')
|
||||||
|
|
||||||
def test_write_closing(self):
|
def test_write_closing(self):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
|
|
@ -875,11 +921,10 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
|
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._buffer.append(data)
|
transport._buffer.extend(data)
|
||||||
self.loop.add_writer(7, transport._write_ready)
|
self.loop.add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertTrue(self.sock.send.called)
|
self.assertTrue(self.sock.send.called)
|
||||||
self.assertEqual(self.sock.send.call_args[0], (data,))
|
|
||||||
self.assertFalse(self.loop.writers)
|
self.assertFalse(self.loop.writers)
|
||||||
|
|
||||||
def test_write_ready_closing(self):
|
def test_write_ready_closing(self):
|
||||||
|
|
@ -889,10 +934,10 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._closing = True
|
transport._closing = True
|
||||||
transport._buffer.append(data)
|
transport._buffer.extend(data)
|
||||||
self.loop.add_writer(7, transport._write_ready)
|
self.loop.add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.sock.send.assert_called_with(data)
|
self.assertTrue(self.sock.send.called)
|
||||||
self.assertFalse(self.loop.writers)
|
self.assertFalse(self.loop.writers)
|
||||||
self.sock.close.assert_called_with()
|
self.sock.close.assert_called_with()
|
||||||
self.protocol.connection_lost.assert_called_with(None)
|
self.protocol.connection_lost.assert_called_with(None)
|
||||||
|
|
@ -900,6 +945,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
def test_write_ready_no_data(self):
|
def test_write_ready_no_data(self):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
|
# This is an internal error.
|
||||||
self.assertRaises(AssertionError, transport._write_ready)
|
self.assertRaises(AssertionError, transport._write_ready)
|
||||||
|
|
||||||
def test_write_ready_partial(self):
|
def test_write_ready_partial(self):
|
||||||
|
|
@ -908,11 +954,11 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
|
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._buffer.append(data)
|
transport._buffer.extend(data)
|
||||||
self.loop.add_writer(7, transport._write_ready)
|
self.loop.add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(collections.deque([b'ta']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_partial_none(self):
|
def test_write_ready_partial_none(self):
|
||||||
data = b'data'
|
data = b'data'
|
||||||
|
|
@ -920,23 +966,23 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
|
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._buffer.append(data)
|
transport._buffer.extend(data)
|
||||||
self.loop.add_writer(7, transport._write_ready)
|
self.loop.add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(collections.deque([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_tryagain(self):
|
def test_write_ready_tryagain(self):
|
||||||
self.sock.send.side_effect = BlockingIOError
|
self.sock.send.side_effect = BlockingIOError
|
||||||
|
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._buffer = collections.deque([b'data1', b'data2'])
|
transport._buffer = list_to_buffer([b'data1', b'data2'])
|
||||||
self.loop.add_writer(7, transport._write_ready)
|
self.loop.add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
|
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_exception(self):
|
def test_write_ready_exception(self):
|
||||||
err = self.sock.send.side_effect = OSError()
|
err = self.sock.send.side_effect = OSError()
|
||||||
|
|
@ -944,7 +990,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport._fatal_error = unittest.mock.Mock()
|
transport._fatal_error = unittest.mock.Mock()
|
||||||
transport._buffer.append(b'data')
|
transport._buffer.extend(b'data')
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
transport._fatal_error.assert_called_with(err)
|
transport._fatal_error.assert_called_with(err)
|
||||||
|
|
||||||
|
|
@ -956,7 +1002,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
transport = _SelectorSocketTransport(
|
transport = _SelectorSocketTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
transport.close()
|
transport.close()
|
||||||
transport._buffer.append(b'data')
|
transport._buffer.extend(b'data')
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
remove_writer.assert_called_with(self.sock_fd)
|
remove_writer.assert_called_with(self.sock_fd)
|
||||||
|
|
||||||
|
|
@ -976,12 +1022,12 @@ class SelectorSocketTransportTests(unittest.TestCase):
|
||||||
self.sock.send.side_effect = BlockingIOError
|
self.sock.send.side_effect = BlockingIOError
|
||||||
tr.write(b'data')
|
tr.write(b'data')
|
||||||
tr.write_eof()
|
tr.write_eof()
|
||||||
self.assertEqual(tr._buffer, collections.deque([b'data']))
|
self.assertEqual(tr._buffer, list_to_buffer([b'data']))
|
||||||
self.assertTrue(tr._eof)
|
self.assertTrue(tr._eof)
|
||||||
self.assertFalse(self.sock.shutdown.called)
|
self.assertFalse(self.sock.shutdown.called)
|
||||||
self.sock.send.side_effect = lambda _: 4
|
self.sock.send.side_effect = lambda _: 4
|
||||||
tr._write_ready()
|
tr._write_ready()
|
||||||
self.sock.send.assert_called_with(b'data')
|
self.assertTrue(self.sock.send.called)
|
||||||
self.sock.shutdown.assert_called_with(socket.SHUT_WR)
|
self.sock.shutdown.assert_called_with(socket.SHUT_WR)
|
||||||
tr.close()
|
tr.close()
|
||||||
|
|
||||||
|
|
@ -1065,15 +1111,34 @@ class SelectorSslTransportTests(unittest.TestCase):
|
||||||
self.assertFalse(tr._paused)
|
self.assertFalse(tr._paused)
|
||||||
self.loop.assert_reader(1, tr._read_ready)
|
self.loop.assert_reader(1, tr._read_ready)
|
||||||
|
|
||||||
|
def test_write(self):
|
||||||
|
transport = self._make_one()
|
||||||
|
transport.write(b'data')
|
||||||
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
|
def test_write_bytearray(self):
|
||||||
|
transport = self._make_one()
|
||||||
|
data = bytearray(b'data')
|
||||||
|
transport.write(data)
|
||||||
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated.
|
||||||
|
self.assertIsNot(data, transport._buffer) # Hasn't been incorporated.
|
||||||
|
|
||||||
|
def test_write_memoryview(self):
|
||||||
|
transport = self._make_one()
|
||||||
|
data = memoryview(b'data')
|
||||||
|
transport.write(data)
|
||||||
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
def test_write_no_data(self):
|
def test_write_no_data(self):
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer.append(b'data')
|
transport._buffer.extend(b'data')
|
||||||
transport.write(b'')
|
transport.write(b'')
|
||||||
self.assertEqual(collections.deque([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
def test_write_str(self):
|
def test_write_str(self):
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
self.assertRaises(AssertionError, transport.write, 'str')
|
self.assertRaises(TypeError, transport.write, 'str')
|
||||||
|
|
||||||
def test_write_closing(self):
|
def test_write_closing(self):
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
|
|
@ -1087,7 +1152,7 @@ class SelectorSslTransportTests(unittest.TestCase):
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._conn_lost = 1
|
transport._conn_lost = 1
|
||||||
transport.write(b'data')
|
transport.write(b'data')
|
||||||
self.assertEqual(transport._buffer, collections.deque())
|
self.assertEqual(transport._buffer, list_to_buffer())
|
||||||
transport.write(b'data')
|
transport.write(b'data')
|
||||||
transport.write(b'data')
|
transport.write(b'data')
|
||||||
transport.write(b'data')
|
transport.write(b'data')
|
||||||
|
|
@ -1107,7 +1172,7 @@ class SelectorSslTransportTests(unittest.TestCase):
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._write_wants_read = True
|
transport._write_wants_read = True
|
||||||
transport._write_ready = unittest.mock.Mock()
|
transport._write_ready = unittest.mock.Mock()
|
||||||
transport._buffer.append(b'data')
|
transport._buffer.extend(b'data')
|
||||||
transport._read_ready()
|
transport._read_ready()
|
||||||
|
|
||||||
self.assertFalse(transport._write_wants_read)
|
self.assertFalse(transport._write_wants_read)
|
||||||
|
|
@ -1168,31 +1233,31 @@ class SelectorSslTransportTests(unittest.TestCase):
|
||||||
def test_write_ready_send(self):
|
def test_write_ready_send(self):
|
||||||
self.sslsock.send.return_value = 4
|
self.sslsock.send.return_value = 4
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer = collections.deque([b'data'])
|
transport._buffer = list_to_buffer([b'data'])
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertEqual(collections.deque(), transport._buffer)
|
self.assertEqual(list_to_buffer(), transport._buffer)
|
||||||
self.assertTrue(self.sslsock.send.called)
|
self.assertTrue(self.sslsock.send.called)
|
||||||
|
|
||||||
def test_write_ready_send_none(self):
|
def test_write_ready_send_none(self):
|
||||||
self.sslsock.send.return_value = 0
|
self.sslsock.send.return_value = 0
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer = collections.deque([b'data1', b'data2'])
|
transport._buffer = list_to_buffer([b'data1', b'data2'])
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertTrue(self.sslsock.send.called)
|
self.assertTrue(self.sslsock.send.called)
|
||||||
self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_send_partial(self):
|
def test_write_ready_send_partial(self):
|
||||||
self.sslsock.send.return_value = 2
|
self.sslsock.send.return_value = 2
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer = collections.deque([b'data1', b'data2'])
|
transport._buffer = list_to_buffer([b'data1', b'data2'])
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertTrue(self.sslsock.send.called)
|
self.assertTrue(self.sslsock.send.called)
|
||||||
self.assertEqual(collections.deque([b'ta1data2']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'ta1data2']), transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_send_closing_partial(self):
|
def test_write_ready_send_closing_partial(self):
|
||||||
self.sslsock.send.return_value = 2
|
self.sslsock.send.return_value = 2
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer = collections.deque([b'data1', b'data2'])
|
transport._buffer = list_to_buffer([b'data1', b'data2'])
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertTrue(self.sslsock.send.called)
|
self.assertTrue(self.sslsock.send.called)
|
||||||
self.assertFalse(self.sslsock.close.called)
|
self.assertFalse(self.sslsock.close.called)
|
||||||
|
|
@ -1201,7 +1266,7 @@ class SelectorSslTransportTests(unittest.TestCase):
|
||||||
self.sslsock.send.return_value = 4
|
self.sslsock.send.return_value = 4
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport.close()
|
transport.close()
|
||||||
transport._buffer = collections.deque([b'data'])
|
transport._buffer = list_to_buffer([b'data'])
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertFalse(self.loop.writers)
|
self.assertFalse(self.loop.writers)
|
||||||
self.protocol.connection_lost.assert_called_with(None)
|
self.protocol.connection_lost.assert_called_with(None)
|
||||||
|
|
@ -1210,26 +1275,26 @@ class SelectorSslTransportTests(unittest.TestCase):
|
||||||
self.sslsock.send.return_value = 4
|
self.sslsock.send.return_value = 4
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport.close()
|
transport.close()
|
||||||
transport._buffer = collections.deque()
|
transport._buffer = list_to_buffer()
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertFalse(self.loop.writers)
|
self.assertFalse(self.loop.writers)
|
||||||
self.protocol.connection_lost.assert_called_with(None)
|
self.protocol.connection_lost.assert_called_with(None)
|
||||||
|
|
||||||
def test_write_ready_send_retry(self):
|
def test_write_ready_send_retry(self):
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer = collections.deque([b'data'])
|
transport._buffer = list_to_buffer([b'data'])
|
||||||
|
|
||||||
self.sslsock.send.side_effect = ssl.SSLWantWriteError
|
self.sslsock.send.side_effect = ssl.SSLWantWriteError
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertEqual(collections.deque([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
self.sslsock.send.side_effect = BlockingIOError()
|
self.sslsock.send.side_effect = BlockingIOError()
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertEqual(collections.deque([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_send_read(self):
|
def test_write_ready_send_read(self):
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer = collections.deque([b'data'])
|
transport._buffer = list_to_buffer([b'data'])
|
||||||
|
|
||||||
self.loop.remove_writer = unittest.mock.Mock()
|
self.loop.remove_writer = unittest.mock.Mock()
|
||||||
self.sslsock.send.side_effect = ssl.SSLWantReadError
|
self.sslsock.send.side_effect = ssl.SSLWantReadError
|
||||||
|
|
@ -1242,11 +1307,11 @@ class SelectorSslTransportTests(unittest.TestCase):
|
||||||
err = self.sslsock.send.side_effect = OSError()
|
err = self.sslsock.send.side_effect = OSError()
|
||||||
|
|
||||||
transport = self._make_one()
|
transport = self._make_one()
|
||||||
transport._buffer = collections.deque([b'data'])
|
transport._buffer = list_to_buffer([b'data'])
|
||||||
transport._fatal_error = unittest.mock.Mock()
|
transport._fatal_error = unittest.mock.Mock()
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
transport._fatal_error.assert_called_with(err)
|
transport._fatal_error.assert_called_with(err)
|
||||||
self.assertEqual(collections.deque(), transport._buffer)
|
self.assertEqual(list_to_buffer(), transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_read_wants_write(self):
|
def test_write_ready_read_wants_write(self):
|
||||||
self.loop.add_reader = unittest.mock.Mock()
|
self.loop.add_reader = unittest.mock.Mock()
|
||||||
|
|
@ -1355,6 +1420,24 @@ class SelectorDatagramTransportTests(unittest.TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
|
self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
|
||||||
|
|
||||||
|
def test_sendto_bytearray(self):
|
||||||
|
data = bytearray(b'data')
|
||||||
|
transport = _SelectorDatagramTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport.sendto(data, ('0.0.0.0', 1234))
|
||||||
|
self.assertTrue(self.sock.sendto.called)
|
||||||
|
self.assertEqual(
|
||||||
|
self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
|
||||||
|
|
||||||
|
def test_sendto_memoryview(self):
|
||||||
|
data = memoryview(b'data')
|
||||||
|
transport = _SelectorDatagramTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport.sendto(data, ('0.0.0.0', 1234))
|
||||||
|
self.assertTrue(self.sock.sendto.called)
|
||||||
|
self.assertEqual(
|
||||||
|
self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
|
||||||
|
|
||||||
def test_sendto_no_data(self):
|
def test_sendto_no_data(self):
|
||||||
transport = _SelectorDatagramTransport(
|
transport = _SelectorDatagramTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
|
|
@ -1375,6 +1458,32 @@ class SelectorDatagramTransportTests(unittest.TestCase):
|
||||||
(b'data2', ('0.0.0.0', 12345))],
|
(b'data2', ('0.0.0.0', 12345))],
|
||||||
list(transport._buffer))
|
list(transport._buffer))
|
||||||
|
|
||||||
|
def test_sendto_buffer_bytearray(self):
|
||||||
|
data2 = bytearray(b'data2')
|
||||||
|
transport = _SelectorDatagramTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
|
||||||
|
transport.sendto(data2, ('0.0.0.0', 12345))
|
||||||
|
self.assertFalse(self.sock.sendto.called)
|
||||||
|
self.assertEqual(
|
||||||
|
[(b'data1', ('0.0.0.0', 12345)),
|
||||||
|
(b'data2', ('0.0.0.0', 12345))],
|
||||||
|
list(transport._buffer))
|
||||||
|
self.assertIsInstance(transport._buffer[1][0], bytes)
|
||||||
|
|
||||||
|
def test_sendto_buffer_memoryview(self):
|
||||||
|
data2 = memoryview(b'data2')
|
||||||
|
transport = _SelectorDatagramTransport(
|
||||||
|
self.loop, self.sock, self.protocol)
|
||||||
|
transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
|
||||||
|
transport.sendto(data2, ('0.0.0.0', 12345))
|
||||||
|
self.assertFalse(self.sock.sendto.called)
|
||||||
|
self.assertEqual(
|
||||||
|
[(b'data1', ('0.0.0.0', 12345)),
|
||||||
|
(b'data2', ('0.0.0.0', 12345))],
|
||||||
|
list(transport._buffer))
|
||||||
|
self.assertIsInstance(transport._buffer[1][0], bytes)
|
||||||
|
|
||||||
def test_sendto_tryagain(self):
|
def test_sendto_tryagain(self):
|
||||||
data = b'data'
|
data = b'data'
|
||||||
|
|
||||||
|
|
@ -1439,13 +1548,13 @@ class SelectorDatagramTransportTests(unittest.TestCase):
|
||||||
def test_sendto_str(self):
|
def test_sendto_str(self):
|
||||||
transport = _SelectorDatagramTransport(
|
transport = _SelectorDatagramTransport(
|
||||||
self.loop, self.sock, self.protocol)
|
self.loop, self.sock, self.protocol)
|
||||||
self.assertRaises(AssertionError, transport.sendto, 'str', ())
|
self.assertRaises(TypeError, transport.sendto, 'str', ())
|
||||||
|
|
||||||
def test_sendto_connected_addr(self):
|
def test_sendto_connected_addr(self):
|
||||||
transport = _SelectorDatagramTransport(
|
transport = _SelectorDatagramTransport(
|
||||||
self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
|
self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AssertionError, transport.sendto, b'str', ('0.0.0.0', 2))
|
ValueError, transport.sendto, b'str', ('0.0.0.0', 2))
|
||||||
|
|
||||||
def test_sendto_closing(self):
|
def test_sendto_closing(self):
|
||||||
transport = _SelectorDatagramTransport(
|
transport = _SelectorDatagramTransport(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue