mirror of
https://github.com/python/cpython.git
synced 2025-08-29 13:15:11 +00:00

Some FreeBSD buildbots fail to run this test because the EOF was not being received by the server when the payload size was not big enough. This behaviour only appears when the client is using TLS 1.3.
694 lines
23 KiB
Python
694 lines
23 KiB
Python
"""Tests for asyncio/sslproto.py."""
|
|
|
|
import logging
|
|
import socket
|
|
import sys
|
|
import unittest
|
|
from unittest import mock
|
|
try:
|
|
import ssl
|
|
except ImportError:
|
|
ssl = None
|
|
|
|
import asyncio
|
|
from asyncio import log
|
|
from asyncio import protocols
|
|
from asyncio import sslproto
|
|
from asyncio import tasks
|
|
from test.test_asyncio import utils as test_utils
|
|
from test.test_asyncio import functional as func_tests
|
|
|
|
|
|
def tearDownModule():
    """Module-level unittest hook: drop any custom event loop policy.

    Passing ``None`` to ``asyncio.set_event_loop_policy`` reinstates the
    default policy, so tests in later modules start from a clean state.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
|
|
|
|
|
|
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Unit tests for ``sslproto.SSLProtocol`` driven by mock objects.

    ``connection_made()`` below patches ``asyncio.sslproto._SSLPipe`` with a
    mock, so these tests exercise SSLProtocol's bookkeeping without doing
    any real TLS work.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        """Return an SSLProtocol wrapping *proto* with a 0.1 s handshake timeout.

        When *proto* is omitted a plain ``asyncio.Protocol()`` is used as the
        application protocol.  The application transport is closed at test
        cleanup.
        """
        context = test_utils.dummy_ssl_context()
        app_proto = asyncio.Protocol() if proto is None else proto
        wrapper = sslproto.SSLProtocol(self.loop, app_proto, context, waiter,
                                       ssl_handshake_timeout=0.1)
        self.assertIs(wrapper._app_transport.get_protocol(), app_proto)
        self.addCleanup(wrapper._app_transport.close)
        return wrapper

    def connection_made(self, ssl_proto, *, do_handshake=None):
        """Feed a mock transport to ``ssl_proto.connection_made()``.

        ``_SSLPipe`` is patched with a mock whose ``shutdown()`` returns
        ``b''``.  If *do_handshake* is truthy it becomes the side effect of
        the mocked ``do_handshake``; otherwise the mocked handshake returns
        an empty list.  The mock transport is returned to the caller.
        """
        transport = mock.Mock()
        fake_pipe = mock.Mock()
        fake_pipe.shutdown.return_value = b''
        fake_pipe.do_handshake.side_effect = (
            do_handshake or (lambda callback: []))
        with mock.patch('asyncio.sslproto._SSLPipe', return_value=fake_pipe):
            ssl_proto.connection_made(transport)
        return transport

    def _assert_rejects_handshake_timeout(self, timeout):
        """Check that SSLProtocol refuses *timeout* with a ValueError."""
        context = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, context, waiter,
                                 ssl_handshake_timeout=timeout)

    def test_handshake_timeout_zero(self):
        self._assert_rejects_handshake_timeout(0)

    def test_handshake_timeout_negative(self):
        self._assert_rejects_handshake_timeout(-10)

    def test_eof_received_waiter(self):
        waiter = asyncio.Future(loop=self.loop)
        proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(proto)
        proto.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363: _fatal_error() raised NameError whenever
        # sslproto.py did not import base_events.
        waiter = asyncio.Future(loop=self.loop)
        proto = self.ssl_protocol(waiter=waiter)
        # Silence error logging for the call so the expected report does
        # not clutter the test output, then restore the previous level.
        saved_level = log.logger.getEffectiveLevel()
        log.logger.setLevel(logging.FATAL)
        try:
            proto._fatal_error(None)
        finally:
            log.logger.setLevel(saved_level)

    def test_connection_lost(self):
        # From issue #472: awaiting the waiter hung when connection_lost()
        # had been called.
        waiter = asyncio.Future(loop=self.loop)
        proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(proto)
        proto.connection_lost(ConnectionAbortedError)
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743: closing the transport during the handshake process
        # leaked the socket; the underlying transport must be aborted.
        waiter = asyncio.Future(loop=self.loop)
        proto = self.ssl_protocol(waiter=waiter)

        raw_transport = self.connection_made(proto)
        test_utils.run_briefly(self.loop)

        proto._app_transport.close()
        self.assertTrue(raw_transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        waiter = asyncio.Future(loop=self.loop)
        proto = self.ssl_protocol(waiter=waiter)
        # Before connection_made() there is no socket to report.
        self.assertIsNone(proto._get_extra_info('socket'))
        sentinel = object()
        self.assertIs(proto._get_extra_info('socket', sentinel), sentinel)
        self.connection_made(proto)
        self.assertIsNotNone(proto._get_extra_info('socket'))
        # ...and after connection_lost() the socket is gone again.
        proto.connection_lost(None)
        self.assertIsNone(proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        waiter = asyncio.Future(loop=self.loop)
        proto = self.ssl_protocol(waiter=waiter)
        replacement = asyncio.Protocol()
        proto._app_transport.set_protocol(replacement)
        self.assertIs(proto._app_transport.get_protocol(), replacement)
        self.assertIs(proto._app_protocol, replacement)

    def test_data_received_after_closing(self):
        proto = self.ssl_protocol()
        self.connection_made(proto)
        app_transport = proto._app_transport

        app_transport.close()

        # Data arriving after close() must not raise.
        self.assertIsNone(proto.data_received(b'data'))

    def test_write_after_closing(self):
        proto = self.ssl_protocol()
        self.connection_made(proto)
        app_transport = proto._app_transport
        app_transport.close()

        # Writing to a closed transport must not raise.
        self.assertIsNone(app_transport.write(b'data'))
|
|
|
|
|
|
##############################################################################
|
|
# Start TLS Tests
|
|
##############################################################################
|
|
|
|
|
|
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional ``loop.start_tls()`` / SSL connection tests for any loop.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and
    implement ``new_loop()`` to return the event loop under test.
    """

    # Size in bytes of the plaintext/TLS payloads exchanged by the tests.
    PAYLOAD_SIZE = 1024 * 100
    # Timeout in seconds for blocking socket operations and overall waits.
    TIMEOUT = 60

    def new_loop(self):
        """Return a fresh event loop to test; subclasses must override."""
        raise NotImplementedError

    def test_buf_feed_data(self):
        """_feed_data_to_buffered_proto() must deliver all bytes for any
        non-empty buffer size and reject an empty buffer."""

        class Proto(asyncio.BufferedProtocol):

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        """Upgrade a client using a regular Protocol from plaintext to TLS."""
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            # Plaintext phase: read the whole greeting before upgrading.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # The instance is named ``proto`` so that ``self`` still
                # refers to the enclosing TestCase.
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            # Use the class-wide TIMEOUT like the other start_tls tests; a
            # hard-coded 10 s may be too tight for the 100 KiB exchange on
            # slow machines.
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT))

    def test_start_tls_client_buf_proto_1(self):
        """Upgrade a client using a BufferedProtocol to TLS, then swap in a
        regular Protocol with set_protocol()."""
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                # One-byte buffer: the server's first TLS message is b'O'.
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                self.on_data.set_result(bytes(self.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        """A start_tls() whose handshake never completes must be
        cancellable via wait_for() and raise TimeoutError."""
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # serve() does not run on the event loop, hence
                # call_soon_threadsafe.  After signalling, the server never
                # answers the TLS handshake, so start_tls() stalls.
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            # Outer bound matches the class-wide TIMEOUT used elsewhere.
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT))

    def test_start_tls_server_1(self):
        """Upgrade a server-side connection to TLS with start_tls()."""
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        if sys.platform.startswith('freebsd'):
            # bpo-35031: Some FreeBSD buildbots fail to run this test
            # because the server did not receive EOF when the payload was
            # not big enough. This behaviour only appears if the client is
            # using TLS 1.3, so disable TLS 1.3 on the client there.
            client_context.options |= ssl.OP_NO_TLSv1_3

        def client(sock, addr):
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing may be delivered to the protocol before the upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_eof = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_eof, on_con_lost),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        """start_tls() rejects a non-SSLContext and an unsupported transport."""
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

    def test_create_connection_ssl_slow_handshake(self):
        """A handshake exceeding ssl_handshake_timeout aborts the connection
        without reporting anything to the exception handler."""
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        """A client that verifies certificates rejects the test server's."""
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except OSError:
                # ssl.SSLError is a subclass of OSError, so this covers both
                # the failed handshake and the client dropping the socket.
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop,
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        """Garbage injected into an established TLS stream surfaces as
        ssl.SSLError on the client's next read."""
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Keep a duplicate of the raw socket so plaintext bytes can be
            # written underneath the TLS layer.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
|
|
|
|
|
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared start_tls tests on ``asyncio.SelectorEventLoop``."""

    def new_loop(self):
        loop = asyncio.SelectorEventLoop()
        return loop
|
|
|
|
|
|
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared start_tls tests on ``asyncio.ProactorEventLoop``.

    Skipped when asyncio has no ProactorEventLoop attribute.
    """

    def new_loop(self):
        loop = asyncio.ProactorEventLoop()
        return loop
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow the module to be executed directly as a script.
    unittest.main()
|