bpo-23749: Implement loop.start_tls() (#5039)

This commit is contained in:
Yury Selivanov 2017-12-30 00:35:36 -05:00 committed by GitHub
parent bbdb17d19b
commit f111b3dcb4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 580 additions and 54 deletions

View file

@ -13,6 +13,7 @@ from asyncio import log
from asyncio import sslproto
from asyncio import tasks
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
@unittest.skipIf(ssl is None, 'No ssl module')
@ -158,5 +159,156 @@ class SslProtoHandshakeTests(test_utils.TestCase):
self.assertIs(ssl_proto._app_protocol, new_app_proto)
##############################################################################
# Start TLS Tests
##############################################################################
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
def new_loop(self):
raise NotImplementedError
def test_start_tls_client_1(self):
HELLO_MSG = b'1' * 1024 * 1024 * 5
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def serve(sock):
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.start_tls(server_context, server_side=True)
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.close()
class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
self.assertEqual(await on_data, b'O')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
with self.tcp_server(serve) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
def test_start_tls_server_1(self):
HELLO_MSG = b'1' * 1024 * 1024 * 5
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def client(sock, addr):
sock.connect(addr)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.start_tls(client_context)
sock.sendall(HELLO_MSG)
sock.close()
class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_eof):
self.on_con = on_con
self.on_eof = on_eof
self.data = b''
def connection_made(self, tr):
self.on_con.set_result(tr)
def data_received(self, data):
self.data += data
def eof_received(self):
self.on_eof.set_result(1)
async def main():
tr = await on_con
tr.write(HELLO_MSG)
self.assertEqual(proto.data, b'')
new_tr = await self.loop.start_tls(
tr, proto, server_context,
server_side=True)
await on_eof
self.assertEqual(proto.data, HELLO_MSG)
new_tr.close()
server.close()
await server.wait_closed()
on_con = self.loop.create_future()
on_eof = self.loop.create_future()
proto = ServerProto(on_con, on_eof)
server = self.loop.run_until_complete(
self.loop.create_server(
lambda: proto, '127.0.0.1', 0))
addr = server.sockets[0].getsockname()
with self.tcp_client(lambda sock: client(sock, addr)):
self.loop.run_until_complete(
asyncio.wait_for(main(), loop=self.loop, timeout=10))
def test_start_tls_wrong_args(self):
async def main():
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
await self.loop.start_tls(None, None, None)
sslctx = test_utils.simple_server_sslcontext()
with self.assertRaisesRegex(TypeError, 'is not supported'):
await self.loop.start_tls(None, None, sslctx)
self.loop.run_until_complete(main())
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLS(BaseStartTLS, unittest.TestCase):
def new_loop(self):
return asyncio.SelectorEventLoop()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLS(BaseStartTLS, unittest.TestCase):
def new_loop(self):
return asyncio.ProactorEventLoop()
if __name__ == '__main__':
unittest.main()