bpo-33654: Support BufferedProtocol in set_protocol() and start_tls() (GH-7130)

In this commit:

* Support BufferedProtocol in set_protocol() and start_tls()
* Fix proactor to cancel readers reliably
* Update tests to be compatible with OpenSSL 1.1.1
* Clarify BufferedProtocol docs
* Bump TLS tests timeouts to 60 seconds; eliminate possible race from start_serving
* Rewrite test_start_tls_server_1
This commit is contained in:
Yury Selivanov 2018-05-28 14:31:28 -04:00 committed by GitHub
parent e549c4be5f
commit dbf102271f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 382 additions and 69 deletions

View file

@ -1,8 +1,7 @@
"""Tests for asyncio/sslproto.py."""
import os
import logging
import time
import socket
import unittest
from unittest import mock
try:
@ -185,17 +184,67 @@ class SslProtoHandshakeTests(test_utils.TestCase):
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
PAYLOAD_SIZE = 1024 * 100
TIMEOUT = 60
def new_loop(self):
raise NotImplementedError
def test_start_tls_client_1(self):
HELLO_MSG = b'1' * 1024 * 1024
def test_buf_feed_data(self):
class Proto(asyncio.BufferedProtocol):
def __init__(self, bufsize, usemv):
self.buf = bytearray(bufsize)
self.mv = memoryview(self.buf)
self.data = b''
self.usemv = usemv
def get_buffer(self, sizehint):
if self.usemv:
return self.mv
else:
return self.buf
def buffer_updated(self, nsize):
if self.usemv:
self.data += self.mv[:nsize]
else:
self.data += self.buf[:nsize]
for usemv in [False, True]:
proto = Proto(1, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(2, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(2, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'1234')
self.assertEqual(proto.data, b'1234')
proto = Proto(4, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'1234')
self.assertEqual(proto.data, b'1234')
proto = Proto(100, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(0, usemv)
with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
def test_start_tls_client_reg_proto_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def serve(sock):
sock.settimeout(5)
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
@ -205,6 +254,8 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ClientProto(asyncio.Protocol):
@ -246,17 +297,80 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
def test_start_tls_server_1(self):
HELLO_MSG = b'1' * 1024 * 1024
def test_start_tls_client_buf_proto_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.start_tls(server_context, server_side=True)
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ClientProto(asyncio.BufferedProtocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
self.buf = bytearray(1)
def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def get_buffer(self, sizehint):
return self.buf
def buffer_updated(self, nsize):
assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize]))
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
self.assertEqual(await on_data, b'O')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
with self.tcp_server(serve) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
loop=self.loop, timeout=self.TIMEOUT))
def test_start_tls_server_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
# TODO: fix TLSv1.3 support
client_context.options |= ssl.OP_NO_TLSv1_3
def client(sock, addr):
time.sleep(0.5)
sock.settimeout(5)
sock.settimeout(self.TIMEOUT)
sock.connect(addr)
data = sock.recv_all(len(HELLO_MSG))
@ -264,12 +378,15 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.start_tls(client_context)
sock.sendall(HELLO_MSG)
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_eof):
def __init__(self, on_con, on_eof, on_con_lost):
self.on_con = on_con
self.on_eof = on_eof
self.on_con_lost = on_con_lost
self.data = b''
def connection_made(self, tr):
@ -281,7 +398,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
def eof_received(self):
self.on_eof.set_result(1)
async def main():
def connection_lost(self, exc):
if exc is None:
self.on_con_lost.set_result(None)
else:
self.on_con_lost.set_exception(exc)
async def main(proto, on_con, on_eof, on_con_lost):
tr = await on_con
tr.write(HELLO_MSG)
@ -292,24 +415,29 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
server_side=True)
await on_eof
await on_con_lost
self.assertEqual(proto.data, HELLO_MSG)
new_tr.close()
async def run_main():
on_con = self.loop.create_future()
on_eof = self.loop.create_future()
on_con_lost = self.loop.create_future()
proto = ServerProto(on_con, on_eof, on_con_lost)
server = await self.loop.create_server(
lambda: proto, '127.0.0.1', 0)
addr = server.sockets[0].getsockname()
with self.tcp_client(lambda sock: client(sock, addr)):
await asyncio.wait_for(
main(proto, on_con, on_eof, on_con_lost),
loop=self.loop, timeout=self.TIMEOUT)
server.close()
await server.wait_closed()
on_con = self.loop.create_future()
on_eof = self.loop.create_future()
proto = ServerProto(on_con, on_eof)
server = self.loop.run_until_complete(
self.loop.create_server(
lambda: proto, '127.0.0.1', 0))
addr = server.sockets[0].getsockname()
with self.tcp_client(lambda sock: client(sock, addr)):
self.loop.run_until_complete(
asyncio.wait_for(main(), loop=self.loop, timeout=10))
self.loop.run_until_complete(run_main())
def test_start_tls_wrong_args(self):
async def main():
@ -332,7 +460,6 @@ class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
@unittest.skipIf(os.environ.get('APPVEYOR'), 'XXX: issue 32458')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
def new_loop(self):