mirror of
https://github.com/python/cpython.git
synced 2025-10-09 16:34:44 +00:00
bpo-23749: Implement loop.start_tls() (#5039)
This commit is contained in:
parent
bbdb17d19b
commit
f111b3dcb4
10 changed files with 580 additions and 54 deletions
|
@ -13,6 +13,7 @@ from asyncio import log
|
|||
from asyncio import sslproto
|
||||
from asyncio import tasks
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from test.test_asyncio import functional as func_tests
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
|
@ -158,5 +159,156 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
self.assertIs(ssl_proto._app_protocol, new_app_proto)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Start TLS Tests
|
||||
##############################################################################
|
||||
|
||||
|
||||
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||
|
||||
def new_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_start_tls_client_1(self):
|
||||
HELLO_MSG = b'1' * 1024 * 1024 * 5
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
def serve(sock):
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.start_tls(server_context, server_side=True)
|
||||
|
||||
sock.sendall(b'O')
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
sock.close()
|
||||
|
||||
class ClientProto(asyncio.Protocol):
|
||||
def __init__(self, on_data, on_eof):
|
||||
self.on_data = on_data
|
||||
self.on_eof = on_eof
|
||||
self.con_made_cnt = 0
|
||||
|
||||
def connection_made(proto, tr):
|
||||
proto.con_made_cnt += 1
|
||||
# Ensure connection_made gets called only once.
|
||||
self.assertEqual(proto.con_made_cnt, 1)
|
||||
|
||||
def data_received(self, data):
|
||||
self.on_data.set_result(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(True)
|
||||
|
||||
async def client(addr):
|
||||
on_data = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
|
||||
tr, proto = await self.loop.create_connection(
|
||||
lambda: ClientProto(on_data, on_eof), *addr)
|
||||
|
||||
tr.write(HELLO_MSG)
|
||||
new_tr = await self.loop.start_tls(tr, proto, client_context)
|
||||
|
||||
self.assertEqual(await on_data, b'O')
|
||||
new_tr.write(HELLO_MSG)
|
||||
await on_eof
|
||||
|
||||
new_tr.close()
|
||||
|
||||
with self.tcp_server(serve) as srv:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
|
||||
|
||||
def test_start_tls_server_1(self):
|
||||
HELLO_MSG = b'1' * 1024 * 1024 * 5
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
def client(sock, addr):
|
||||
sock.connect(addr)
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.start_tls(client_context)
|
||||
sock.sendall(HELLO_MSG)
|
||||
sock.close()
|
||||
|
||||
class ServerProto(asyncio.Protocol):
|
||||
def __init__(self, on_con, on_eof):
|
||||
self.on_con = on_con
|
||||
self.on_eof = on_eof
|
||||
self.data = b''
|
||||
|
||||
def connection_made(self, tr):
|
||||
self.on_con.set_result(tr)
|
||||
|
||||
def data_received(self, data):
|
||||
self.data += data
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(1)
|
||||
|
||||
async def main():
|
||||
tr = await on_con
|
||||
tr.write(HELLO_MSG)
|
||||
|
||||
self.assertEqual(proto.data, b'')
|
||||
|
||||
new_tr = await self.loop.start_tls(
|
||||
tr, proto, server_context,
|
||||
server_side=True)
|
||||
|
||||
await on_eof
|
||||
self.assertEqual(proto.data, HELLO_MSG)
|
||||
new_tr.close()
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
on_con = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
proto = ServerProto(on_con, on_eof)
|
||||
|
||||
server = self.loop.run_until_complete(
|
||||
self.loop.create_server(
|
||||
lambda: proto, '127.0.0.1', 0))
|
||||
addr = server.sockets[0].getsockname()
|
||||
|
||||
with self.tcp_client(lambda sock: client(sock, addr)):
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(main(), loop=self.loop, timeout=10))
|
||||
|
||||
def test_start_tls_wrong_args(self):
|
||||
async def main():
|
||||
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
|
||||
await self.loop.start_tls(None, None, None)
|
||||
|
||||
sslctx = test_utils.simple_server_sslcontext()
|
||||
with self.assertRaisesRegex(TypeError, 'is not supported'):
|
||||
await self.loop.start_tls(None, None, sslctx)
|
||||
|
||||
self.loop.run_until_complete(main())
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
class SelectorStartTLS(BaseStartTLS, unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.SelectorEventLoop()
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
|
||||
class ProactorStartTLS(BaseStartTLS, unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.ProactorEventLoop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue