mirror of
https://github.com/python/cpython.git
synced 2025-10-09 08:31:26 +00:00
bpo-32622: Native sendfile on windows (GH-5565)
* Support sendfile on Windows Proactor event loop naively.
(cherry picked from commit a19fb3c6aa
)
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
b6b6669cfd
commit
632c1cb571
7 changed files with 431 additions and 93 deletions
|
@ -15,6 +15,7 @@ except ImportError:
|
|||
ssl = None
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import errno
|
||||
|
@ -2092,22 +2093,7 @@ class SubprocessTestsMixin:
|
|||
self.loop.run_until_complete(connect(shell=False))
|
||||
|
||||
|
||||
class MySendfileProto(MyBaseProto):
|
||||
|
||||
def __init__(self, loop=None, close_after=0):
|
||||
super().__init__(loop)
|
||||
self.data = bytearray()
|
||||
self.close_after = close_after
|
||||
|
||||
def data_received(self, data):
|
||||
self.data.extend(data)
|
||||
super().data_received(data)
|
||||
if self.close_after and self.nbytes >= self.close_after:
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class SendfileMixin:
|
||||
# Note: sendfile via SSL transport is equal to sendfile fallback
|
||||
class SendfileBase:
|
||||
|
||||
DATA = b"12345abcde" * 160 * 1024 # 160 KiB
|
||||
|
||||
|
@ -2130,9 +2116,134 @@ class SendfileMixin:
|
|||
def run_loop(self, coro):
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
def prepare(self, *, is_ssl=False, close_after=0):
|
||||
|
||||
class SockSendfileMixin(SendfileBase):
|
||||
|
||||
class MyProto(asyncio.Protocol):
|
||||
|
||||
def __init__(self, loop):
|
||||
self.started = False
|
||||
self.closed = False
|
||||
self.data = bytearray()
|
||||
self.fut = loop.create_future()
|
||||
self.transport = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.started = True
|
||||
self.transport = transport
|
||||
|
||||
def data_received(self, data):
|
||||
self.data.extend(data)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.closed = True
|
||||
self.fut.set_result(None)
|
||||
|
||||
async def wait_closed(self):
|
||||
await self.fut
|
||||
|
||||
def make_socket(self, cleanup=True):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setblocking(False)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
|
||||
if cleanup:
|
||||
self.addCleanup(sock.close)
|
||||
return sock
|
||||
|
||||
def prepare_socksendfile(self):
|
||||
sock = self.make_socket()
|
||||
proto = self.MyProto(self.loop)
|
||||
port = support.find_unused_port()
|
||||
srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
|
||||
srv_sock = self.make_socket(cleanup=False)
|
||||
srv_sock.bind((support.HOST, port))
|
||||
server = self.run_loop(self.loop.create_server(
|
||||
lambda: proto, sock=srv_sock))
|
||||
self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
|
||||
|
||||
def cleanup():
|
||||
if proto.transport is not None:
|
||||
# can be None if the task was cancelled before
|
||||
# connection_made callback
|
||||
proto.transport.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
server.close()
|
||||
self.run_loop(server.wait_closed())
|
||||
|
||||
self.addCleanup(cleanup)
|
||||
|
||||
return sock, proto
|
||||
|
||||
def test_sock_sendfile_success(self):
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(proto.data, self.DATA)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sock_sendfile_with_offset_and_count(self):
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
|
||||
1000, 2000))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(proto.data, self.DATA[1000:3000])
|
||||
self.assertEqual(self.file.tell(), 3000)
|
||||
self.assertEqual(ret, 2000)
|
||||
|
||||
def test_sock_sendfile_zero_size(self):
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
with tempfile.TemporaryFile() as f:
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, f,
|
||||
0, None))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(ret, 0)
|
||||
self.assertEqual(self.file.tell(), 0)
|
||||
|
||||
def test_sock_sendfile_mix_with_regular_send(self):
|
||||
buf = b'1234567890' * 1024 * 1024 # 10 MB
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
self.run_loop(self.loop.sock_sendall(sock, buf))
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
|
||||
self.run_loop(self.loop.sock_sendall(sock, buf))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
expected = buf + self.DATA + buf
|
||||
self.assertEqual(proto.data, expected)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
|
||||
class SendfileMixin(SendfileBase):
|
||||
|
||||
class MySendfileProto(MyBaseProto):
|
||||
|
||||
def __init__(self, loop=None, close_after=0):
|
||||
super().__init__(loop)
|
||||
self.data = bytearray()
|
||||
self.close_after = close_after
|
||||
|
||||
def data_received(self, data):
|
||||
self.data.extend(data)
|
||||
super().data_received(data)
|
||||
if self.close_after and self.nbytes >= self.close_after:
|
||||
self.transport.close()
|
||||
|
||||
|
||||
# Note: sendfile via SSL transport is equal to sendfile fallback
|
||||
|
||||
def prepare_sendfile(self, *, is_ssl=False, close_after=0):
|
||||
port = support.find_unused_port()
|
||||
srv_proto = self.MySendfileProto(loop=self.loop,
|
||||
close_after=close_after)
|
||||
if is_ssl:
|
||||
if not ssl:
|
||||
self.skipTest("No ssl module")
|
||||
|
@ -2156,7 +2267,7 @@ class SendfileMixin:
|
|||
# reduce send socket buffer size to test on relative small data sets
|
||||
cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
|
||||
cli_sock.connect((support.HOST, port))
|
||||
cli_proto = MySendfileProto(loop=self.loop)
|
||||
cli_proto = self.MySendfileProto(loop=self.loop)
|
||||
tr, pr = self.run_loop(self.loop.create_connection(
|
||||
lambda: cli_proto, sock=cli_sock,
|
||||
ssl=cli_ctx, server_hostname=server_hostname))
|
||||
|
@ -2189,7 +2300,7 @@ class SendfileMixin:
|
|||
tr.close()
|
||||
|
||||
def test_sendfile(self):
|
||||
srv_proto, cli_proto = self.prepare()
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.close()
|
||||
|
@ -2200,7 +2311,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_force_fallback(self):
|
||||
srv_proto, cli_proto = self.prepare()
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
|
||||
def sendfile_native(transp, file, offset, count):
|
||||
# to raise SendfileNotAvailableError
|
||||
|
@ -2222,7 +2333,7 @@ class SendfileMixin:
|
|||
if sys.platform == 'win32':
|
||||
if isinstance(self.loop, asyncio.ProactorEventLoop):
|
||||
self.skipTest("Fails on proactor event loop")
|
||||
srv_proto, cli_proto = self.prepare()
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
|
||||
def sendfile_native(transp, file, offset, count):
|
||||
# to raise SendfileNotAvailableError
|
||||
|
@ -2243,7 +2354,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), 0)
|
||||
|
||||
def test_sendfile_ssl(self):
|
||||
srv_proto, cli_proto = self.prepare(is_ssl=True)
|
||||
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.close()
|
||||
|
@ -2254,7 +2365,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_for_closing_transp(self):
|
||||
srv_proto, cli_proto = self.prepare()
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
cli_proto.transport.close()
|
||||
with self.assertRaisesRegex(RuntimeError, "is closing"):
|
||||
self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
|
||||
|
@ -2263,7 +2374,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), 0)
|
||||
|
||||
def test_sendfile_pre_and_post_data(self):
|
||||
srv_proto, cli_proto = self.prepare()
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
PREFIX = b'zxcvbnm' * 1024
|
||||
SUFFIX = b'0987654321' * 1024
|
||||
cli_proto.transport.write(PREFIX)
|
||||
|
@ -2277,7 +2388,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_ssl_pre_and_post_data(self):
|
||||
srv_proto, cli_proto = self.prepare(is_ssl=True)
|
||||
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
|
||||
PREFIX = b'zxcvbnm' * 1024
|
||||
SUFFIX = b'0987654321' * 1024
|
||||
cli_proto.transport.write(PREFIX)
|
||||
|
@ -2291,7 +2402,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_partial(self):
|
||||
srv_proto, cli_proto = self.prepare()
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
|
||||
cli_proto.transport.close()
|
||||
|
@ -2302,7 +2413,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), 1100)
|
||||
|
||||
def test_sendfile_ssl_partial(self):
|
||||
srv_proto, cli_proto = self.prepare(is_ssl=True)
|
||||
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
|
||||
cli_proto.transport.close()
|
||||
|
@ -2313,7 +2424,8 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), 1100)
|
||||
|
||||
def test_sendfile_close_peer_after_receiving(self):
|
||||
srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
|
||||
srv_proto, cli_proto = self.prepare_sendfile(
|
||||
close_after=len(self.DATA))
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.close()
|
||||
|
@ -2324,8 +2436,8 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_ssl_close_peer_after_receiving(self):
|
||||
srv_proto, cli_proto = self.prepare(is_ssl=True,
|
||||
close_after=len(self.DATA))
|
||||
srv_proto, cli_proto = self.prepare_sendfile(
|
||||
is_ssl=True, close_after=len(self.DATA))
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
self.run_loop(srv_proto.done)
|
||||
|
@ -2335,7 +2447,7 @@ class SendfileMixin:
|
|||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_close_peer_in_middle_of_receiving(self):
|
||||
srv_proto, cli_proto = self.prepare(close_after=1024)
|
||||
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
|
||||
with self.assertRaises(ConnectionError):
|
||||
self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
|
@ -2345,6 +2457,7 @@ class SendfileMixin:
|
|||
srv_proto.nbytes)
|
||||
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
|
||||
self.file.tell())
|
||||
self.assertTrue(cli_proto.transport.is_closing())
|
||||
|
||||
def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
|
||||
|
||||
|
@ -2355,7 +2468,7 @@ class SendfileMixin:
|
|||
|
||||
self.loop._sendfile_native = sendfile_native
|
||||
|
||||
srv_proto, cli_proto = self.prepare(close_after=1024)
|
||||
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
|
||||
with self.assertRaises(ConnectionError):
|
||||
self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
|
@ -2369,7 +2482,7 @@ class SendfileMixin:
|
|||
@unittest.skipIf(not hasattr(os, 'sendfile'),
|
||||
"Don't have native sendfile support")
|
||||
def test_sendfile_prevents_bare_write(self):
|
||||
srv_proto, cli_proto = self.prepare()
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
fut = self.loop.create_future()
|
||||
|
||||
async def coro():
|
||||
|
@ -2397,6 +2510,7 @@ if sys.platform == 'win32':
|
|||
|
||||
class SelectEventLoopTests(EventLoopTestsMixin,
|
||||
SendfileMixin,
|
||||
SockSendfileMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
|
@ -2404,6 +2518,7 @@ if sys.platform == 'win32':
|
|||
|
||||
class ProactorEventLoopTests(EventLoopTestsMixin,
|
||||
SendfileMixin,
|
||||
SockSendfileMixin,
|
||||
SubprocessTestsMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
|
@ -2431,7 +2546,9 @@ if sys.platform == 'win32':
|
|||
else:
|
||||
import selectors
|
||||
|
||||
class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
|
||||
class UnixEventLoopTestsMixin(EventLoopTestsMixin,
|
||||
SendfileMixin,
|
||||
SockSendfileMixin):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
watcher = asyncio.SafeChildWatcher()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue