mirror of
https://github.com/python/cpython.git
synced 2025-09-27 02:39:58 +00:00
GH-91166: Implement zero copy writes for SelectorSocketTransport
in asyncio (#31871)
Co-authored-by: Guido van Rossum <gvanrossum@gmail.com>
This commit is contained in:
parent
0f6420640c
commit
c122390a55
3 changed files with 176 additions and 30 deletions
|
@ -9,6 +9,8 @@ __all__ = 'BaseSelectorEventLoop',
|
||||||
import collections
|
import collections
|
||||||
import errno
|
import errno
|
||||||
import functools
|
import functools
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
import selectors
|
import selectors
|
||||||
import socket
|
import socket
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -28,6 +30,14 @@ from . import transports
|
||||||
from . import trsock
|
from . import trsock
|
||||||
from .log import logger
|
from .log import logger
|
||||||
|
|
||||||
|
_HAS_SENDMSG = hasattr(socket.socket, 'sendmsg')
|
||||||
|
|
||||||
|
if _HAS_SENDMSG:
|
||||||
|
try:
|
||||||
|
SC_IOV_MAX = os.sysconf('SC_IOV_MAX')
|
||||||
|
except OSError:
|
||||||
|
# Fallback to send
|
||||||
|
_HAS_SENDMSG = False
|
||||||
|
|
||||||
def _test_selector_event(selector, fd, event):
|
def _test_selector_event(selector, fd, event):
|
||||||
# Test if the selector is monitoring 'event' events
|
# Test if the selector is monitoring 'event' events
|
||||||
|
@ -757,8 +767,6 @@ class _SelectorTransport(transports._FlowControlMixin,
|
||||||
|
|
||||||
max_size = 256 * 1024 # Buffer size passed to recv().
|
max_size = 256 * 1024 # Buffer size passed to recv().
|
||||||
|
|
||||||
_buffer_factory = bytearray # Constructs initial value for self._buffer.
|
|
||||||
|
|
||||||
# Attribute used in the destructor: it must be set even if the constructor
|
# Attribute used in the destructor: it must be set even if the constructor
|
||||||
# is not called (see _SelectorSslTransport which may start by raising an
|
# is not called (see _SelectorSslTransport which may start by raising an
|
||||||
# exception)
|
# exception)
|
||||||
|
@ -783,7 +791,7 @@ class _SelectorTransport(transports._FlowControlMixin,
|
||||||
self.set_protocol(protocol)
|
self.set_protocol(protocol)
|
||||||
|
|
||||||
self._server = server
|
self._server = server
|
||||||
self._buffer = self._buffer_factory()
|
self._buffer = collections.deque()
|
||||||
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
||||||
self._closing = False # Set when close() called.
|
self._closing = False # Set when close() called.
|
||||||
if self._server is not None:
|
if self._server is not None:
|
||||||
|
@ -887,7 +895,7 @@ class _SelectorTransport(transports._FlowControlMixin,
|
||||||
self._server = None
|
self._server = None
|
||||||
|
|
||||||
def get_write_buffer_size(self):
|
def get_write_buffer_size(self):
|
||||||
return len(self._buffer)
|
return sum(map(len, self._buffer))
|
||||||
|
|
||||||
def _add_reader(self, fd, callback, *args):
|
def _add_reader(self, fd, callback, *args):
|
||||||
if self._closing:
|
if self._closing:
|
||||||
|
@ -909,7 +917,10 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
self._eof = False
|
self._eof = False
|
||||||
self._paused = False
|
self._paused = False
|
||||||
self._empty_waiter = None
|
self._empty_waiter = None
|
||||||
|
if _HAS_SENDMSG:
|
||||||
|
self._write_ready = self._write_sendmsg
|
||||||
|
else:
|
||||||
|
self._write_ready = self._write_send
|
||||||
# Disable the Nagle algorithm -- small writes will be
|
# Disable the Nagle algorithm -- small writes will be
|
||||||
# sent without waiting for the TCP ACK. This generally
|
# sent without waiting for the TCP ACK. This generally
|
||||||
# decreases the latency (in some cases significantly.)
|
# decreases the latency (in some cases significantly.)
|
||||||
|
@ -1066,23 +1077,68 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
self._fatal_error(exc, 'Fatal write error on socket transport')
|
self._fatal_error(exc, 'Fatal write error on socket transport')
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
data = data[n:]
|
data = memoryview(data)[n:]
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
# Not all was written; register write handler.
|
# Not all was written; register write handler.
|
||||||
self._loop._add_writer(self._sock_fd, self._write_ready)
|
self._loop._add_writer(self._sock_fd, self._write_ready)
|
||||||
|
|
||||||
# Add it to the buffer.
|
# Add it to the buffer.
|
||||||
self._buffer.extend(data)
|
self._buffer.append(data)
|
||||||
self._maybe_pause_protocol()
|
self._maybe_pause_protocol()
|
||||||
|
|
||||||
def _write_ready(self):
|
def _get_sendmsg_buffer(self):
|
||||||
assert self._buffer, 'Data should not be empty'
|
return itertools.islice(self._buffer, SC_IOV_MAX)
|
||||||
|
|
||||||
|
def _write_sendmsg(self):
|
||||||
|
assert self._buffer, 'Data should not be empty'
|
||||||
if self._conn_lost:
|
if self._conn_lost:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
n = self._sock.send(self._buffer)
|
nbytes = self._sock.sendmsg(self._get_sendmsg_buffer())
|
||||||
|
self._adjust_leftover_buffer(nbytes)
|
||||||
|
except (BlockingIOError, InterruptedError):
|
||||||
|
pass
|
||||||
|
except (SystemExit, KeyboardInterrupt):
|
||||||
|
raise
|
||||||
|
except BaseException as exc:
|
||||||
|
self._loop._remove_writer(self._sock_fd)
|
||||||
|
self._buffer.clear()
|
||||||
|
self._fatal_error(exc, 'Fatal write error on socket transport')
|
||||||
|
if self._empty_waiter is not None:
|
||||||
|
self._empty_waiter.set_exception(exc)
|
||||||
|
else:
|
||||||
|
self._maybe_resume_protocol() # May append to buffer.
|
||||||
|
if not self._buffer:
|
||||||
|
self._loop._remove_writer(self._sock_fd)
|
||||||
|
if self._empty_waiter is not None:
|
||||||
|
self._empty_waiter.set_result(None)
|
||||||
|
if self._closing:
|
||||||
|
self._call_connection_lost(None)
|
||||||
|
elif self._eof:
|
||||||
|
self._sock.shutdown(socket.SHUT_WR)
|
||||||
|
|
||||||
|
def _adjust_leftover_buffer(self, nbytes: int) -> None:
|
||||||
|
buffer = self._buffer
|
||||||
|
while nbytes:
|
||||||
|
b = buffer.popleft()
|
||||||
|
b_len = len(b)
|
||||||
|
if b_len <= nbytes:
|
||||||
|
nbytes -= b_len
|
||||||
|
else:
|
||||||
|
buffer.appendleft(b[nbytes:])
|
||||||
|
break
|
||||||
|
|
||||||
|
def _write_send(self):
|
||||||
|
assert self._buffer, 'Data should not be empty'
|
||||||
|
if self._conn_lost:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
buffer = self._buffer.popleft()
|
||||||
|
n = self._sock.send(buffer)
|
||||||
|
if n != len(buffer):
|
||||||
|
# Not all data was written
|
||||||
|
self._buffer.appendleft(buffer[n:])
|
||||||
except (BlockingIOError, InterruptedError):
|
except (BlockingIOError, InterruptedError):
|
||||||
pass
|
pass
|
||||||
except (SystemExit, KeyboardInterrupt):
|
except (SystemExit, KeyboardInterrupt):
|
||||||
|
@ -1094,8 +1150,6 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
if self._empty_waiter is not None:
|
if self._empty_waiter is not None:
|
||||||
self._empty_waiter.set_exception(exc)
|
self._empty_waiter.set_exception(exc)
|
||||||
else:
|
else:
|
||||||
if n:
|
|
||||||
del self._buffer[:n]
|
|
||||||
self._maybe_resume_protocol() # May append to buffer.
|
self._maybe_resume_protocol() # May append to buffer.
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
self._loop._remove_writer(self._sock_fd)
|
self._loop._remove_writer(self._sock_fd)
|
||||||
|
@ -1113,6 +1167,16 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
self._sock.shutdown(socket.SHUT_WR)
|
self._sock.shutdown(socket.SHUT_WR)
|
||||||
|
|
||||||
|
def writelines(self, list_of_data):
|
||||||
|
if self._eof:
|
||||||
|
raise RuntimeError('Cannot call writelines() after write_eof()')
|
||||||
|
if self._empty_waiter is not None:
|
||||||
|
raise RuntimeError('unable to writelines; sendfile is in progress')
|
||||||
|
if not list_of_data:
|
||||||
|
return
|
||||||
|
self._buffer.extend([memoryview(data) for data in list_of_data])
|
||||||
|
self._write_ready()
|
||||||
|
|
||||||
def can_write_eof(self):
|
def can_write_eof(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -1,23 +1,25 @@
|
||||||
"""Tests for selector_events.py"""
|
"""Tests for selector_events.py"""
|
||||||
|
|
||||||
import sys
|
import collections
|
||||||
import selectors
|
import selectors
|
||||||
import socket
|
import socket
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
from asyncio import selector_events
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ssl
|
import ssl
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ssl = None
|
ssl = None
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio.selector_events import BaseSelectorEventLoop
|
from asyncio.selector_events import (BaseSelectorEventLoop,
|
||||||
from asyncio.selector_events import _SelectorTransport
|
_SelectorDatagramTransport,
|
||||||
from asyncio.selector_events import _SelectorSocketTransport
|
_SelectorSocketTransport,
|
||||||
from asyncio.selector_events import _SelectorDatagramTransport
|
_SelectorTransport)
|
||||||
from test.test_asyncio import utils as test_utils
|
from test.test_asyncio import utils as test_utils
|
||||||
|
|
||||||
|
|
||||||
MOCK_ANY = mock.ANY
|
MOCK_ANY = mock.ANY
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,7 +39,10 @@ class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
|
||||||
|
|
||||||
|
|
||||||
def list_to_buffer(l=()):
|
def list_to_buffer(l=()):
|
||||||
return bytearray().join(l)
|
buffer = collections.deque()
|
||||||
|
buffer.extend((memoryview(i) for i in l))
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def close_transport(transport):
|
def close_transport(transport):
|
||||||
|
@ -493,9 +498,13 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
self.sock = mock.Mock(socket.socket)
|
self.sock = mock.Mock(socket.socket)
|
||||||
self.sock_fd = self.sock.fileno.return_value = 7
|
self.sock_fd = self.sock.fileno.return_value = 7
|
||||||
|
|
||||||
def socket_transport(self, waiter=None):
|
def socket_transport(self, waiter=None, sendmsg=False):
|
||||||
transport = _SelectorSocketTransport(self.loop, self.sock,
|
transport = _SelectorSocketTransport(self.loop, self.sock,
|
||||||
self.protocol, waiter=waiter)
|
self.protocol, waiter=waiter)
|
||||||
|
if sendmsg:
|
||||||
|
transport._write_ready = transport._write_sendmsg
|
||||||
|
else:
|
||||||
|
transport._write_ready = transport._write_send
|
||||||
self.addCleanup(close_transport, transport)
|
self.addCleanup(close_transport, transport)
|
||||||
return transport
|
return transport
|
||||||
|
|
||||||
|
@ -664,14 +673,14 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
|
|
||||||
def test_write_no_data(self):
|
def test_write_no_data(self):
|
||||||
transport = self.socket_transport()
|
transport = self.socket_transport()
|
||||||
transport._buffer.extend(b'data')
|
transport._buffer.append(memoryview(b'data'))
|
||||||
transport.write(b'')
|
transport.write(b'')
|
||||||
self.assertFalse(self.sock.send.called)
|
self.assertFalse(self.sock.send.called)
|
||||||
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
def test_write_buffer(self):
|
def test_write_buffer(self):
|
||||||
transport = self.socket_transport()
|
transport = self.socket_transport()
|
||||||
transport._buffer.extend(b'data1')
|
transport._buffer.append(b'data1')
|
||||||
transport.write(b'data2')
|
transport.write(b'data2')
|
||||||
self.assertFalse(self.sock.send.called)
|
self.assertFalse(self.sock.send.called)
|
||||||
self.assertEqual(list_to_buffer([b'data1', b'data2']),
|
self.assertEqual(list_to_buffer([b'data1', b'data2']),
|
||||||
|
@ -729,6 +738,77 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
|
def test_write_sendmsg_no_data(self):
|
||||||
|
self.sock.sendmsg = mock.Mock()
|
||||||
|
self.sock.sendmsg.return_value = 0
|
||||||
|
transport = self.socket_transport(sendmsg=True)
|
||||||
|
transport._buffer.append(memoryview(b'data'))
|
||||||
|
transport.write(b'')
|
||||||
|
self.assertFalse(self.sock.sendmsg.called)
|
||||||
|
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
|
||||||
|
|
||||||
|
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
|
||||||
|
def test_write_sendmsg_full(self):
|
||||||
|
data = memoryview(b'data')
|
||||||
|
self.sock.sendmsg = mock.Mock()
|
||||||
|
self.sock.sendmsg.return_value = len(data)
|
||||||
|
|
||||||
|
transport = self.socket_transport(sendmsg=True)
|
||||||
|
transport._buffer.append(data)
|
||||||
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
|
transport._write_ready()
|
||||||
|
self.assertTrue(self.sock.sendmsg.called)
|
||||||
|
self.assertFalse(self.loop.writers)
|
||||||
|
|
||||||
|
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
|
||||||
|
def test_write_sendmsg_partial(self):
|
||||||
|
|
||||||
|
data = memoryview(b'data')
|
||||||
|
self.sock.sendmsg = mock.Mock()
|
||||||
|
# Sent partial data
|
||||||
|
self.sock.sendmsg.return_value = 2
|
||||||
|
|
||||||
|
transport = self.socket_transport(sendmsg=True)
|
||||||
|
transport._buffer.append(data)
|
||||||
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
|
transport._write_ready()
|
||||||
|
self.assertTrue(self.sock.sendmsg.called)
|
||||||
|
self.assertTrue(self.loop.writers)
|
||||||
|
self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
|
||||||
|
|
||||||
|
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
|
||||||
|
def test_write_sendmsg_half_buffer(self):
|
||||||
|
data = [memoryview(b'data1'), memoryview(b'data2')]
|
||||||
|
self.sock.sendmsg = mock.Mock()
|
||||||
|
# Sent partial data
|
||||||
|
self.sock.sendmsg.return_value = 2
|
||||||
|
|
||||||
|
transport = self.socket_transport(sendmsg=True)
|
||||||
|
transport._buffer.extend(data)
|
||||||
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
|
transport._write_ready()
|
||||||
|
self.assertTrue(self.sock.sendmsg.called)
|
||||||
|
self.assertTrue(self.loop.writers)
|
||||||
|
self.assertEqual(list_to_buffer([b'ta1', b'data2']), transport._buffer)
|
||||||
|
|
||||||
|
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
|
||||||
|
def test_write_sendmsg_OSError(self):
|
||||||
|
data = memoryview(b'data')
|
||||||
|
self.sock.sendmsg = mock.Mock()
|
||||||
|
err = self.sock.sendmsg.side_effect = OSError()
|
||||||
|
|
||||||
|
transport = self.socket_transport(sendmsg=True)
|
||||||
|
transport._fatal_error = mock.Mock()
|
||||||
|
transport._buffer.extend(data)
|
||||||
|
# Calls _fatal_error and clears the buffer
|
||||||
|
transport._write_ready()
|
||||||
|
self.assertTrue(self.sock.sendmsg.called)
|
||||||
|
self.assertFalse(self.loop.writers)
|
||||||
|
self.assertEqual(list_to_buffer([]), transport._buffer)
|
||||||
|
transport._fatal_error.assert_called_with(
|
||||||
|
err,
|
||||||
|
'Fatal write error on socket transport')
|
||||||
|
|
||||||
@mock.patch('asyncio.selector_events.logger')
|
@mock.patch('asyncio.selector_events.logger')
|
||||||
def test_write_exception(self, m_log):
|
def test_write_exception(self, m_log):
|
||||||
err = self.sock.send.side_effect = OSError()
|
err = self.sock.send.side_effect = OSError()
|
||||||
|
@ -768,19 +848,19 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
self.sock.send.return_value = len(data)
|
self.sock.send.return_value = len(data)
|
||||||
|
|
||||||
transport = self.socket_transport()
|
transport = self.socket_transport()
|
||||||
transport._buffer.extend(data)
|
transport._buffer.append(data)
|
||||||
self.loop._add_writer(7, transport._write_ready)
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertTrue(self.sock.send.called)
|
self.assertTrue(self.sock.send.called)
|
||||||
self.assertFalse(self.loop.writers)
|
self.assertFalse(self.loop.writers)
|
||||||
|
|
||||||
def test_write_ready_closing(self):
|
def test_write_ready_closing(self):
|
||||||
data = b'data'
|
data = memoryview(b'data')
|
||||||
self.sock.send.return_value = len(data)
|
self.sock.send.return_value = len(data)
|
||||||
|
|
||||||
transport = self.socket_transport()
|
transport = self.socket_transport()
|
||||||
transport._closing = True
|
transport._closing = True
|
||||||
transport._buffer.extend(data)
|
transport._buffer.append(data)
|
||||||
self.loop._add_writer(7, transport._write_ready)
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.assertTrue(self.sock.send.called)
|
self.assertTrue(self.sock.send.called)
|
||||||
|
@ -795,11 +875,11 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
self.assertRaises(AssertionError, transport._write_ready)
|
self.assertRaises(AssertionError, transport._write_ready)
|
||||||
|
|
||||||
def test_write_ready_partial(self):
|
def test_write_ready_partial(self):
|
||||||
data = b'data'
|
data = memoryview(b'data')
|
||||||
self.sock.send.return_value = 2
|
self.sock.send.return_value = 2
|
||||||
|
|
||||||
transport = self.socket_transport()
|
transport = self.socket_transport()
|
||||||
transport._buffer.extend(data)
|
transport._buffer.append(data)
|
||||||
self.loop._add_writer(7, transport._write_ready)
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
|
@ -810,7 +890,7 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
self.sock.send.return_value = 0
|
self.sock.send.return_value = 0
|
||||||
|
|
||||||
transport = self.socket_transport()
|
transport = self.socket_transport()
|
||||||
transport._buffer.extend(data)
|
transport._buffer.append(data)
|
||||||
self.loop._add_writer(7, transport._write_ready)
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
|
@ -820,12 +900,13 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
self.sock.send.side_effect = BlockingIOError
|
self.sock.send.side_effect = BlockingIOError
|
||||||
|
|
||||||
transport = self.socket_transport()
|
transport = self.socket_transport()
|
||||||
transport._buffer = list_to_buffer([b'data1', b'data2'])
|
buffer = list_to_buffer([b'data1', b'data2'])
|
||||||
|
transport._buffer = buffer
|
||||||
self.loop._add_writer(7, transport._write_ready)
|
self.loop._add_writer(7, transport._write_ready)
|
||||||
transport._write_ready()
|
transport._write_ready()
|
||||||
|
|
||||||
self.loop.assert_writer(7, transport._write_ready)
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
|
self.assertEqual(buffer, transport._buffer)
|
||||||
|
|
||||||
def test_write_ready_exception(self):
|
def test_write_ready_exception(self):
|
||||||
err = self.sock.send.side_effect = OSError()
|
err = self.sock.send.side_effect = OSError()
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
:mod:`asyncio` is optimized to avoid excessive copying when writing to socket and use :meth:`~socket.socket.sendmsg` if the platform supports it. Patch by Kumar Aditya.
|
Loading…
Add table
Add a link
Reference in a new issue