mirror of
https://github.com/python/cpython.git
synced 2025-11-01 18:51:43 +00:00
bpo-32622: Implement loop.sendfile() (#5271)
This commit is contained in:
parent
f13f12d8da
commit
7c684073f9
12 changed files with 560 additions and 8 deletions
|
|
@ -38,8 +38,10 @@ from . import constants
|
|||
from . import coroutines
|
||||
from . import events
|
||||
from . import futures
|
||||
from . import protocols
|
||||
from . import sslproto
|
||||
from . import tasks
|
||||
from . import transports
|
||||
from .log import logger
|
||||
|
||||
|
||||
|
|
@ -155,6 +157,75 @@ def _run_until_complete_cb(fut):
|
|||
futures._get_loop(fut).stop()
|
||||
|
||||
|
||||
|
||||
class _SendfileFallbackProtocol(protocols.Protocol):
|
||||
def __init__(self, transp):
|
||||
if not isinstance(transp, transports._FlowControlMixin):
|
||||
raise TypeError("transport should be _FlowControlMixin instance")
|
||||
self._transport = transp
|
||||
self._proto = transp.get_protocol()
|
||||
self._should_resume_reading = transp.is_reading()
|
||||
self._should_resume_writing = transp._protocol_paused
|
||||
transp.pause_reading()
|
||||
transp.set_protocol(self)
|
||||
if self._should_resume_writing:
|
||||
self._write_ready_fut = self._transport._loop.create_future()
|
||||
else:
|
||||
self._write_ready_fut = None
|
||||
|
||||
async def drain(self):
|
||||
if self._transport.is_closing():
|
||||
raise ConnectionError("Connection closed by peer")
|
||||
fut = self._write_ready_fut
|
||||
if fut is None:
|
||||
return
|
||||
await fut
|
||||
|
||||
def connection_made(self, transport):
|
||||
raise RuntimeError("Invalid state: "
|
||||
"connection should have been established already.")
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if self._write_ready_fut is not None:
|
||||
# Never happens if peer disconnects after sending the whole content
|
||||
# Thus disconnection is always an exception from user perspective
|
||||
if exc is None:
|
||||
self._write_ready_fut.set_exception(
|
||||
ConnectionError("Connection is closed by peer"))
|
||||
else:
|
||||
self._write_ready_fut.set_exception(exc)
|
||||
self._proto.connection_lost(exc)
|
||||
|
||||
def pause_writing(self):
|
||||
if self._write_ready_fut is not None:
|
||||
return
|
||||
self._write_ready_fut = self._transport._loop.create_future()
|
||||
|
||||
def resume_writing(self):
|
||||
if self._write_ready_fut is None:
|
||||
return
|
||||
self._write_ready_fut.set_result(False)
|
||||
self._write_ready_fut = None
|
||||
|
||||
def data_received(self, data):
|
||||
raise RuntimeError("Invalid state: reading should be paused")
|
||||
|
||||
def eof_received(self):
|
||||
raise RuntimeError("Invalid state: reading should be paused")
|
||||
|
||||
async def restore(self):
|
||||
self._transport.set_protocol(self._proto)
|
||||
if self._should_resume_reading:
|
||||
self._transport.resume_reading()
|
||||
if self._write_ready_fut is not None:
|
||||
# Cancel the future.
|
||||
# Basically it has no effect because protocol is switched back,
|
||||
# no code should wait for it anymore.
|
||||
self._write_ready_fut.cancel()
|
||||
if self._should_resume_writing:
|
||||
self._proto.resume_writing()
|
||||
|
||||
|
||||
class Server(events.AbstractServer):
|
||||
|
||||
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
|
||||
|
|
@ -926,6 +997,77 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
|
||||
return transport, protocol
|
||||
|
||||
async def sendfile(self, transport, file, offset=0, count=None,
|
||||
*, fallback=True):
|
||||
"""Send a file to transport.
|
||||
|
||||
Return the total number of bytes which were sent.
|
||||
|
||||
The method uses high-performance os.sendfile if available.
|
||||
|
||||
file must be a regular file object opened in binary mode.
|
||||
|
||||
offset tells from where to start reading the file. If specified,
|
||||
count is the total number of bytes to transmit as opposed to
|
||||
sending the file until EOF is reached. File position is updated on
|
||||
return or also in case of error in which case file.tell()
|
||||
can be used to figure out the number of bytes
|
||||
which were sent.
|
||||
|
||||
fallback set to True makes asyncio to manually read and send
|
||||
the file when the platform does not support the sendfile syscall
|
||||
(e.g. Windows or SSL socket on Unix).
|
||||
|
||||
Raise SendfileNotAvailableError if the system does not support
|
||||
sendfile syscall and fallback is False.
|
||||
"""
|
||||
if transport.is_closing():
|
||||
raise RuntimeError("Transport is closing")
|
||||
mode = getattr(transport, '_sendfile_compatible',
|
||||
constants._SendfileMode.UNSUPPORTED)
|
||||
if mode is constants._SendfileMode.UNSUPPORTED:
|
||||
raise RuntimeError(
|
||||
f"sendfile is not supported for transport {transport!r}")
|
||||
if mode is constants._SendfileMode.TRY_NATIVE:
|
||||
try:
|
||||
return await self._sendfile_native(transport, file,
|
||||
offset, count)
|
||||
except events.SendfileNotAvailableError as exc:
|
||||
if not fallback:
|
||||
raise
|
||||
# the mode is FALLBACK or fallback is True
|
||||
return await self._sendfile_fallback(transport, file,
|
||||
offset, count)
|
||||
|
||||
async def _sendfile_native(self, transp, file, offset, count):
|
||||
raise events.SendfileNotAvailableError(
|
||||
"sendfile syscall is not supported")
|
||||
|
||||
async def _sendfile_fallback(self, transp, file, offset, count):
|
||||
if offset:
|
||||
file.seek(offset)
|
||||
blocksize = min(count, 16384) if count else 16384
|
||||
buf = bytearray(blocksize)
|
||||
total_sent = 0
|
||||
proto = _SendfileFallbackProtocol(transp)
|
||||
try:
|
||||
while True:
|
||||
if count:
|
||||
blocksize = min(count - total_sent, blocksize)
|
||||
if blocksize <= 0:
|
||||
return total_sent
|
||||
view = memoryview(buf)[:blocksize]
|
||||
read = file.readinto(view)
|
||||
if not read:
|
||||
return total_sent # EOF
|
||||
await proto.drain()
|
||||
transp.write(view)
|
||||
total_sent += read
|
||||
finally:
|
||||
if total_sent > 0 and hasattr(file, 'seek'):
|
||||
file.seek(offset + total_sent)
|
||||
await proto.restore()
|
||||
|
||||
async def start_tls(self, transport, protocol, sslcontext, *,
|
||||
server_side=False,
|
||||
server_hostname=None,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import enum
|
||||
|
||||
# After the connection is lost, log warnings after this many write()s.
|
||||
LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5
|
||||
|
||||
|
|
@ -11,3 +13,10 @@ DEBUG_STACK_DEPTH = 10
|
|||
|
||||
# Number of seconds to wait for SSL handshake to complete
|
||||
SSL_HANDSHAKE_TIMEOUT = 10.0
|
||||
|
||||
# The enum should be here to break circular dependencies between
|
||||
# base_events and sslproto
|
||||
class _SendfileMode(enum.Enum):
|
||||
UNSUPPORTED = enum.auto()
|
||||
TRY_NATIVE = enum.auto()
|
||||
FALLBACK = enum.auto()
|
||||
|
|
|
|||
|
|
@ -354,6 +354,14 @@ class AbstractEventLoop:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def sendfile(self, transport, file, offset=0, count=None,
|
||||
*, fallback=True):
|
||||
"""Send a file through a transport.
|
||||
|
||||
Return an amount of sent bytes.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def start_tls(self, transport, protocol, sslcontext, *,
|
||||
server_side=False,
|
||||
server_hostname=None,
|
||||
|
|
|
|||
|
|
@ -180,7 +180,12 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
|
|||
assert self._read_fut is fut or (self._read_fut is None and
|
||||
self._closing)
|
||||
self._read_fut = None
|
||||
data = fut.result() # deliver data later in "finally" clause
|
||||
if fut.done():
|
||||
# deliver data later in "finally" clause
|
||||
data = fut.result()
|
||||
else:
|
||||
# the future will be replaced by next proactor.recv call
|
||||
fut.cancel()
|
||||
|
||||
if self._closing:
|
||||
# since close() has been called we ignore any read data
|
||||
|
|
@ -345,6 +350,8 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
|
|||
transports.Transport):
|
||||
"""Transport for connected sockets."""
|
||||
|
||||
_sendfile_compatible = constants._SendfileMode.FALLBACK
|
||||
|
||||
def _set_extra(self, sock):
|
||||
self._extra['socket'] = sock
|
||||
|
||||
|
|
|
|||
|
|
@ -540,6 +540,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
else:
|
||||
fut.set_result((conn, address))
|
||||
|
||||
async def _sendfile_native(self, transp, file, offset, count):
|
||||
del self._transports[transp._sock_fd]
|
||||
resume_reading = transp.is_reading()
|
||||
transp.pause_reading()
|
||||
await transp._make_empty_waiter()
|
||||
try:
|
||||
return await self.sock_sendfile(transp._sock, file, offset, count,
|
||||
fallback=False)
|
||||
finally:
|
||||
transp._reset_empty_waiter()
|
||||
if resume_reading:
|
||||
transp.resume_reading()
|
||||
self._transports[transp._sock_fd] = transp
|
||||
|
||||
def _process_events(self, event_list):
|
||||
for key, mask in event_list:
|
||||
fileobj, (reader, writer) = key.fileobj, key.data
|
||||
|
|
@ -695,12 +709,14 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
class _SelectorSocketTransport(_SelectorTransport):
|
||||
|
||||
_start_tls_compatible = True
|
||||
_sendfile_compatible = constants._SendfileMode.TRY_NATIVE
|
||||
|
||||
def __init__(self, loop, sock, protocol, waiter=None,
|
||||
extra=None, server=None):
|
||||
super().__init__(loop, sock, protocol, extra, server)
|
||||
self._eof = False
|
||||
self._paused = False
|
||||
self._empty_waiter = None
|
||||
|
||||
# Disable the Nagle algorithm -- small writes will be
|
||||
# sent without waiting for the TCP ACK. This generally
|
||||
|
|
@ -765,6 +781,8 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
f'not {type(data).__name__!r}')
|
||||
if self._eof:
|
||||
raise RuntimeError('Cannot call write() after write_eof()')
|
||||
if self._empty_waiter is not None:
|
||||
raise RuntimeError('unable to write; sendfile is in progress')
|
||||
if not data:
|
||||
return
|
||||
|
||||
|
|
@ -807,12 +825,16 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
self._loop._remove_writer(self._sock_fd)
|
||||
self._buffer.clear()
|
||||
self._fatal_error(exc, 'Fatal write error on socket transport')
|
||||
if self._empty_waiter is not None:
|
||||
self._empty_waiter.set_exception(exc)
|
||||
else:
|
||||
if n:
|
||||
del self._buffer[:n]
|
||||
self._maybe_resume_protocol() # May append to buffer.
|
||||
if not self._buffer:
|
||||
self._loop._remove_writer(self._sock_fd)
|
||||
if self._empty_waiter is not None:
|
||||
self._empty_waiter.set_result(None)
|
||||
if self._closing:
|
||||
self._call_connection_lost(None)
|
||||
elif self._eof:
|
||||
|
|
@ -828,6 +850,23 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
def can_write_eof(self):
|
||||
return True
|
||||
|
||||
def _call_connection_lost(self, exc):
|
||||
super()._call_connection_lost(exc)
|
||||
if self._empty_waiter is not None:
|
||||
self._empty_waiter.set_exception(
|
||||
ConnectionError("Connection is closed by peer"))
|
||||
|
||||
def _make_empty_waiter(self):
|
||||
if self._empty_waiter is not None:
|
||||
raise RuntimeError("Empty waiter is already set")
|
||||
self._empty_waiter = self._loop.create_future()
|
||||
if not self._buffer:
|
||||
self._empty_waiter.set_result(None)
|
||||
return self._empty_waiter
|
||||
|
||||
def _reset_empty_waiter(self):
|
||||
self._empty_waiter = None
|
||||
|
||||
|
||||
class _SelectorDatagramTransport(_SelectorTransport):
|
||||
|
||||
|
|
|
|||
|
|
@ -282,6 +282,8 @@ class _SSLPipe(object):
|
|||
class _SSLProtocolTransport(transports._FlowControlMixin,
|
||||
transports.Transport):
|
||||
|
||||
_sendfile_compatible = constants._SendfileMode.FALLBACK
|
||||
|
||||
def __init__(self, loop, ssl_protocol):
|
||||
self._loop = loop
|
||||
# SSLProtocol instance
|
||||
|
|
@ -365,6 +367,11 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
|
|||
"""Return the current size of the write buffer."""
|
||||
return self._ssl_protocol._transport.get_write_buffer_size()
|
||||
|
||||
@property
|
||||
def _protocol_paused(self):
|
||||
# Required for sendfile fallback pause_writing/resume_writing logic
|
||||
return self._ssl_protocol._transport._protocol_paused
|
||||
|
||||
def write(self, data):
|
||||
"""Write some data bytes to the transport.
|
||||
|
||||
|
|
|
|||
|
|
@ -425,7 +425,8 @@ class IocpProactor:
|
|||
try:
|
||||
return ov.getresult()
|
||||
except OSError as exc:
|
||||
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
|
||||
if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
|
||||
_overlapped.ERROR_OPERATION_ABORTED):
|
||||
raise ConnectionResetError(*exc.args)
|
||||
else:
|
||||
raise
|
||||
|
|
@ -447,7 +448,8 @@ class IocpProactor:
|
|||
try:
|
||||
return ov.getresult()
|
||||
except OSError as exc:
|
||||
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
|
||||
if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
|
||||
_overlapped.ERROR_OPERATION_ABORTED):
|
||||
raise ConnectionResetError(*exc.args)
|
||||
else:
|
||||
raise
|
||||
|
|
@ -466,7 +468,8 @@ class IocpProactor:
|
|||
try:
|
||||
return ov.getresult()
|
||||
except OSError as exc:
|
||||
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
|
||||
if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
|
||||
_overlapped.ERROR_OPERATION_ABORTED):
|
||||
raise ConnectionResetError(*exc.args)
|
||||
else:
|
||||
raise
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue