bpo-32622: Native sendfile on windows (GH-5565)

* Support sendfile on Windows Proactor event loop naively.
(cherry picked from commit a19fb3c6aa)

Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
Miss Islington (bot) 2018-02-25 09:10:58 -08:00 committed by GitHub
parent b6b6669cfd
commit 632c1cb571
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 431 additions and 93 deletions

View file

@ -15,6 +15,7 @@ except ImportError:
ssl = None
import subprocess
import sys
import tempfile
import threading
import time
import errno
@ -2092,22 +2093,7 @@ class SubprocessTestsMixin:
self.loop.run_until_complete(connect(shell=False))
class MySendfileProto(MyBaseProto):
def __init__(self, loop=None, close_after=0):
super().__init__(loop)
self.data = bytearray()
self.close_after = close_after
def data_received(self, data):
self.data.extend(data)
super().data_received(data)
if self.close_after and self.nbytes >= self.close_after:
self.transport.close()
class SendfileMixin:
# Note: sendfile via SSL transport is equal to sendfile fallback
class SendfileBase:
DATA = b"12345abcde" * 160 * 1024 # 160 KiB
@ -2130,9 +2116,134 @@ class SendfileMixin:
def run_loop(self, coro):
return self.loop.run_until_complete(coro)
def prepare(self, *, is_ssl=False, close_after=0):
class SockSendfileMixin(SendfileBase):
class MyProto(asyncio.Protocol):
def __init__(self, loop):
self.started = False
self.closed = False
self.data = bytearray()
self.fut = loop.create_future()
self.transport = None
def connection_made(self, transport):
self.started = True
self.transport = transport
def data_received(self, data):
self.data.extend(data)
def connection_lost(self, exc):
self.closed = True
self.fut.set_result(None)
async def wait_closed(self):
await self.fut
def make_socket(self, cleanup=True):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setblocking(False)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
if cleanup:
self.addCleanup(sock.close)
return sock
def prepare_socksendfile(self):
sock = self.make_socket()
proto = self.MyProto(self.loop)
port = support.find_unused_port()
srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
srv_sock = self.make_socket(cleanup=False)
srv_sock.bind((support.HOST, port))
server = self.run_loop(self.loop.create_server(
lambda: proto, sock=srv_sock))
self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
def cleanup():
if proto.transport is not None:
# can be None if the task was cancelled before
# connection_made callback
proto.transport.close()
self.run_loop(proto.wait_closed())
server.close()
self.run_loop(server.wait_closed())
self.addCleanup(cleanup)
return sock, proto
def test_sock_sendfile_success(self):
sock, proto = self.prepare_socksendfile()
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, len(self.DATA))
self.assertEqual(proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sock_sendfile_with_offset_and_count(self):
sock, proto = self.prepare_socksendfile()
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
1000, 2000))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(proto.data, self.DATA[1000:3000])
self.assertEqual(self.file.tell(), 3000)
self.assertEqual(ret, 2000)
def test_sock_sendfile_zero_size(self):
sock, proto = self.prepare_socksendfile()
with tempfile.TemporaryFile() as f:
ret = self.run_loop(self.loop.sock_sendfile(sock, f,
0, None))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, 0)
self.assertEqual(self.file.tell(), 0)
def test_sock_sendfile_mix_with_regular_send(self):
buf = b'1234567890' * 1024 * 1024 # 10 MB
sock, proto = self.prepare_socksendfile()
self.run_loop(self.loop.sock_sendall(sock, buf))
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
self.run_loop(self.loop.sock_sendall(sock, buf))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, len(self.DATA))
expected = buf + self.DATA + buf
self.assertEqual(proto.data, expected)
self.assertEqual(self.file.tell(), len(self.DATA))
class SendfileMixin(SendfileBase):
class MySendfileProto(MyBaseProto):
def __init__(self, loop=None, close_after=0):
super().__init__(loop)
self.data = bytearray()
self.close_after = close_after
def data_received(self, data):
self.data.extend(data)
super().data_received(data)
if self.close_after and self.nbytes >= self.close_after:
self.transport.close()
# Note: sendfile via SSL transport is equal to sendfile fallback
def prepare_sendfile(self, *, is_ssl=False, close_after=0):
port = support.find_unused_port()
srv_proto = self.MySendfileProto(loop=self.loop,
close_after=close_after)
if is_ssl:
if not ssl:
self.skipTest("No ssl module")
@ -2156,7 +2267,7 @@ class SendfileMixin:
# reduce send socket buffer size to test on relative small data sets
cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
cli_sock.connect((support.HOST, port))
cli_proto = MySendfileProto(loop=self.loop)
cli_proto = self.MySendfileProto(loop=self.loop)
tr, pr = self.run_loop(self.loop.create_connection(
lambda: cli_proto, sock=cli_sock,
ssl=cli_ctx, server_hostname=server_hostname))
@ -2189,7 +2300,7 @@ class SendfileMixin:
tr.close()
def test_sendfile(self):
srv_proto, cli_proto = self.prepare()
srv_proto, cli_proto = self.prepare_sendfile()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
@ -2200,7 +2311,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_force_fallback(self):
srv_proto, cli_proto = self.prepare()
srv_proto, cli_proto = self.prepare_sendfile()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
@ -2222,7 +2333,7 @@ class SendfileMixin:
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
self.skipTest("Fails on proactor event loop")
srv_proto, cli_proto = self.prepare()
srv_proto, cli_proto = self.prepare_sendfile()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
@ -2243,7 +2354,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), 0)
def test_sendfile_ssl(self):
srv_proto, cli_proto = self.prepare(is_ssl=True)
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
@ -2254,7 +2365,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_for_closing_transp(self):
srv_proto, cli_proto = self.prepare()
srv_proto, cli_proto = self.prepare_sendfile()
cli_proto.transport.close()
with self.assertRaisesRegex(RuntimeError, "is closing"):
self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
@ -2263,7 +2374,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), 0)
def test_sendfile_pre_and_post_data(self):
srv_proto, cli_proto = self.prepare()
srv_proto, cli_proto = self.prepare_sendfile()
PREFIX = b'zxcvbnm' * 1024
SUFFIX = b'0987654321' * 1024
cli_proto.transport.write(PREFIX)
@ -2277,7 +2388,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_pre_and_post_data(self):
srv_proto, cli_proto = self.prepare(is_ssl=True)
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
PREFIX = b'zxcvbnm' * 1024
SUFFIX = b'0987654321' * 1024
cli_proto.transport.write(PREFIX)
@ -2291,7 +2402,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_partial(self):
srv_proto, cli_proto = self.prepare()
srv_proto, cli_proto = self.prepare_sendfile()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
@ -2302,7 +2413,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_ssl_partial(self):
srv_proto, cli_proto = self.prepare(is_ssl=True)
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
@ -2313,7 +2424,8 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_close_peer_after_receiving(self):
srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
srv_proto, cli_proto = self.prepare_sendfile(
close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
@ -2324,8 +2436,8 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_close_peer_after_receiving(self):
srv_proto, cli_proto = self.prepare(is_ssl=True,
close_after=len(self.DATA))
srv_proto, cli_proto = self.prepare_sendfile(
is_ssl=True, close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
@ -2335,7 +2447,7 @@ class SendfileMixin:
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_close_peer_in_middle_of_receiving(self):
srv_proto, cli_proto = self.prepare(close_after=1024)
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
@ -2345,6 +2457,7 @@ class SendfileMixin:
srv_proto.nbytes)
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
self.file.tell())
self.assertTrue(cli_proto.transport.is_closing())
def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
@ -2355,7 +2468,7 @@ class SendfileMixin:
self.loop._sendfile_native = sendfile_native
srv_proto, cli_proto = self.prepare(close_after=1024)
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
@ -2369,7 +2482,7 @@ class SendfileMixin:
@unittest.skipIf(not hasattr(os, 'sendfile'),
"Don't have native sendfile support")
def test_sendfile_prevents_bare_write(self):
srv_proto, cli_proto = self.prepare()
srv_proto, cli_proto = self.prepare_sendfile()
fut = self.loop.create_future()
async def coro():
@ -2397,6 +2510,7 @@ if sys.platform == 'win32':
class SelectEventLoopTests(EventLoopTestsMixin,
SendfileMixin,
SockSendfileMixin,
test_utils.TestCase):
def create_event_loop(self):
@ -2404,6 +2518,7 @@ if sys.platform == 'win32':
class ProactorEventLoopTests(EventLoopTestsMixin,
SendfileMixin,
SockSendfileMixin,
SubprocessTestsMixin,
test_utils.TestCase):
@ -2431,7 +2546,9 @@ if sys.platform == 'win32':
else:
import selectors
class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
class UnixEventLoopTestsMixin(EventLoopTestsMixin,
SendfileMixin,
SockSendfileMixin):
def setUp(self):
super().setUp()
watcher = asyncio.SafeChildWatcher()