mirror of
https://github.com/python/cpython.git
synced 2025-09-26 10:19:53 +00:00
bpo-23749: Implement loop.start_tls() (#5039)
This commit is contained in:
parent
bbdb17d19b
commit
f111b3dcb4
10 changed files with 580 additions and 54 deletions
|
@ -537,6 +537,38 @@ Creating listening connections
|
||||||
.. versionadded:: 3.5.3
|
.. versionadded:: 3.5.3
|
||||||
|
|
||||||
|
|
||||||
|
TLS Upgrade
|
||||||
|
-----------
|
||||||
|
|
||||||
|
.. coroutinemethod:: AbstractEventLoop.start_tls(transport, protocol, sslcontext, \*, server_side=False, server_hostname=None, ssl_handshake_timeout=None)
|
||||||
|
|
||||||
|
Upgrades an existing connection to TLS.
|
||||||
|
|
||||||
|
Returns a new transport instance, that the *protocol* must start using
|
||||||
|
immediately after the *await*. The *transport* instance passed to
|
||||||
|
the *start_tls* method should never be used again.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
|
||||||
|
* *transport* and *protocol* instances that methods like
|
||||||
|
:meth:`~AbstractEventLoop.create_server` and
|
||||||
|
:meth:`~AbstractEventLoop.create_connection` return.
|
||||||
|
|
||||||
|
* *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
|
||||||
|
|
||||||
|
* *server_side* pass ``True`` when a server-side connection is being
|
||||||
|
upgraded (like the one created by :meth:`~AbstractEventLoop.create_server`).
|
||||||
|
|
||||||
|
* *server_hostname*: sets or overrides the host name that the target
|
||||||
|
server's certificate will be matched against.
|
||||||
|
|
||||||
|
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
|
||||||
|
wait for the SSL handshake to complete before aborting the connection.
|
||||||
|
``10.0`` seconds if ``None`` (default).
|
||||||
|
|
||||||
|
.. versionadded:: 3.7
|
||||||
|
|
||||||
|
|
||||||
Watch file descriptors
|
Watch file descriptors
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,15 @@ import sys
|
||||||
import warnings
|
import warnings
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ssl
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
ssl = None
|
||||||
|
|
||||||
from . import coroutines
|
from . import coroutines
|
||||||
from . import events
|
from . import events
|
||||||
from . import futures
|
from . import futures
|
||||||
|
from . import sslproto
|
||||||
from . import tasks
|
from . import tasks
|
||||||
from .log import logger
|
from .log import logger
|
||||||
|
|
||||||
|
@ -279,7 +285,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
self, rawsock, protocol, sslcontext, waiter=None,
|
self, rawsock, protocol, sslcontext, waiter=None,
|
||||||
*, server_side=False, server_hostname=None,
|
*, server_side=False, server_hostname=None,
|
||||||
extra=None, server=None,
|
extra=None, server=None,
|
||||||
ssl_handshake_timeout=None):
|
ssl_handshake_timeout=None,
|
||||||
|
call_connection_made=True):
|
||||||
"""Create SSL transport."""
|
"""Create SSL transport."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -795,6 +802,42 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
|
|
||||||
return transport, protocol
|
return transport, protocol
|
||||||
|
|
||||||
|
async def start_tls(self, transport, protocol, sslcontext, *,
|
||||||
|
server_side=False,
|
||||||
|
server_hostname=None,
|
||||||
|
ssl_handshake_timeout=None):
|
||||||
|
"""Upgrade transport to TLS.
|
||||||
|
|
||||||
|
Return a new transport that *protocol* should start using
|
||||||
|
immediately.
|
||||||
|
"""
|
||||||
|
if ssl is None:
|
||||||
|
raise RuntimeError('Python ssl module is not available')
|
||||||
|
|
||||||
|
if not isinstance(sslcontext, ssl.SSLContext):
|
||||||
|
raise TypeError(
|
||||||
|
f'sslcontext is expected to be an instance of ssl.SSLContext, '
|
||||||
|
f'got {sslcontext!r}')
|
||||||
|
|
||||||
|
if not getattr(transport, '_start_tls_compatible', False):
|
||||||
|
raise TypeError(
|
||||||
|
f'transport {self!r} is not supported by start_tls()')
|
||||||
|
|
||||||
|
waiter = self.create_future()
|
||||||
|
ssl_protocol = sslproto.SSLProtocol(
|
||||||
|
self, protocol, sslcontext, waiter,
|
||||||
|
server_side, server_hostname,
|
||||||
|
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||||
|
call_connection_made=False)
|
||||||
|
|
||||||
|
transport.set_protocol(ssl_protocol)
|
||||||
|
self.call_soon(ssl_protocol.connection_made, transport)
|
||||||
|
if not transport.is_reading():
|
||||||
|
self.call_soon(transport.resume_reading)
|
||||||
|
|
||||||
|
await waiter
|
||||||
|
return ssl_protocol._app_transport
|
||||||
|
|
||||||
async def create_datagram_endpoint(self, protocol_factory,
|
async def create_datagram_endpoint(self, protocol_factory,
|
||||||
local_addr=None, remote_addr=None, *,
|
local_addr=None, remote_addr=None, *,
|
||||||
family=0, proto=0, flags=0,
|
family=0, proto=0, flags=0,
|
||||||
|
|
|
@ -305,6 +305,17 @@ class AbstractEventLoop:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def start_tls(self, transport, protocol, sslcontext, *,
|
||||||
|
server_side=False,
|
||||||
|
server_hostname=None,
|
||||||
|
ssl_handshake_timeout=None):
|
||||||
|
"""Upgrade a transport to TLS.
|
||||||
|
|
||||||
|
Return a new transport that *protocol* should start using
|
||||||
|
immediately.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def create_unix_connection(
|
async def create_unix_connection(
|
||||||
self, protocol_factory, path=None, *,
|
self, protocol_factory, path=None, *,
|
||||||
ssl=None, sock=None,
|
ssl=None, sock=None,
|
||||||
|
|
|
@ -223,6 +223,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
|
||||||
transports.WriteTransport):
|
transports.WriteTransport):
|
||||||
"""Transport for write pipes."""
|
"""Transport for write pipes."""
|
||||||
|
|
||||||
|
_start_tls_compatible = True
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
|
|
@ -694,6 +694,8 @@ class _SelectorTransport(transports._FlowControlMixin,
|
||||||
|
|
||||||
class _SelectorSocketTransport(_SelectorTransport):
|
class _SelectorSocketTransport(_SelectorTransport):
|
||||||
|
|
||||||
|
_start_tls_compatible = True
|
||||||
|
|
||||||
def __init__(self, loop, sock, protocol, waiter=None,
|
def __init__(self, loop, sock, protocol, waiter=None,
|
||||||
extra=None, server=None):
|
extra=None, server=None):
|
||||||
super().__init__(loop, sock, protocol, extra, server)
|
super().__init__(loop, sock, protocol, extra, server)
|
||||||
|
|
279
Lib/test/test_asyncio/functional.py
Normal file
279
Lib/test/test_asyncio/functional.py
Normal file
|
@ -0,0 +1,279 @@
|
||||||
|
import asyncio
|
||||||
|
import asyncio.events
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import pprint
|
||||||
|
import select
|
||||||
|
import socket
|
||||||
|
import ssl
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionalTestCaseMixin:
|
||||||
|
|
||||||
|
def new_loop(self):
|
||||||
|
return asyncio.new_event_loop()
|
||||||
|
|
||||||
|
def run_loop_briefly(self, *, delay=0.01):
|
||||||
|
self.loop.run_until_complete(asyncio.sleep(delay, loop=self.loop))
|
||||||
|
|
||||||
|
def loop_exception_handler(self, loop, context):
|
||||||
|
self.__unhandled_exceptions.append(context)
|
||||||
|
self.loop.default_exception_handler(context)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.loop = self.new_loop()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
self.loop.set_exception_handler(self.loop_exception_handler)
|
||||||
|
self.__unhandled_exceptions = []
|
||||||
|
|
||||||
|
# Disable `_get_running_loop`.
|
||||||
|
self._old_get_running_loop = asyncio.events._get_running_loop
|
||||||
|
asyncio.events._get_running_loop = lambda: None
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
try:
|
||||||
|
self.loop.close()
|
||||||
|
|
||||||
|
if self.__unhandled_exceptions:
|
||||||
|
print('Unexpected calls to loop.call_exception_handler():')
|
||||||
|
pprint.pprint(self.__unhandled_exceptions)
|
||||||
|
self.fail('unexpected calls to loop.call_exception_handler()')
|
||||||
|
|
||||||
|
finally:
|
||||||
|
asyncio.events._get_running_loop = self._old_get_running_loop
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
self.loop = None
|
||||||
|
|
||||||
|
def tcp_server(self, server_prog, *,
|
||||||
|
family=socket.AF_INET,
|
||||||
|
addr=None,
|
||||||
|
timeout=5,
|
||||||
|
backlog=1,
|
||||||
|
max_clients=10):
|
||||||
|
|
||||||
|
if addr is None:
|
||||||
|
if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
|
||||||
|
with tempfile.NamedTemporaryFile() as tmp:
|
||||||
|
addr = tmp.name
|
||||||
|
else:
|
||||||
|
addr = ('127.0.0.1', 0)
|
||||||
|
|
||||||
|
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||||
|
|
||||||
|
if timeout is None:
|
||||||
|
raise RuntimeError('timeout is required')
|
||||||
|
if timeout <= 0:
|
||||||
|
raise RuntimeError('only blocking sockets are supported')
|
||||||
|
sock.settimeout(timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
|
sock.bind(addr)
|
||||||
|
sock.listen(backlog)
|
||||||
|
except OSError as ex:
|
||||||
|
sock.close()
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
return TestThreadedServer(
|
||||||
|
self, sock, server_prog, timeout, max_clients)
|
||||||
|
|
||||||
|
def tcp_client(self, client_prog,
|
||||||
|
family=socket.AF_INET,
|
||||||
|
timeout=10):
|
||||||
|
|
||||||
|
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||||
|
|
||||||
|
if timeout is None:
|
||||||
|
raise RuntimeError('timeout is required')
|
||||||
|
if timeout <= 0:
|
||||||
|
raise RuntimeError('only blocking sockets are supported')
|
||||||
|
sock.settimeout(timeout)
|
||||||
|
|
||||||
|
return TestThreadedClient(
|
||||||
|
self, sock, client_prog, timeout)
|
||||||
|
|
||||||
|
def unix_server(self, *args, **kwargs):
|
||||||
|
if not hasattr(socket, 'AF_UNIX'):
|
||||||
|
raise NotImplementedError
|
||||||
|
return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
|
||||||
|
|
||||||
|
def unix_client(self, *args, **kwargs):
|
||||||
|
if not hasattr(socket, 'AF_UNIX'):
|
||||||
|
raise NotImplementedError
|
||||||
|
return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def unix_sock_name(self):
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
fn = os.path.join(td, 'sock')
|
||||||
|
try:
|
||||||
|
yield fn
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.unlink(fn)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _abort_socket_test(self, ex):
|
||||||
|
try:
|
||||||
|
self.loop.stop()
|
||||||
|
finally:
|
||||||
|
self.fail(ex)
|
||||||
|
|
||||||
|
|
||||||
|
##############################################################################
|
||||||
|
# Socket Testing Utilities
|
||||||
|
##############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class TestSocketWrapper:
|
||||||
|
|
||||||
|
def __init__(self, sock):
|
||||||
|
self.__sock = sock
|
||||||
|
|
||||||
|
def recv_all(self, n):
|
||||||
|
buf = b''
|
||||||
|
while len(buf) < n:
|
||||||
|
data = self.recv(n - len(buf))
|
||||||
|
if data == b'':
|
||||||
|
raise ConnectionAbortedError
|
||||||
|
buf += data
|
||||||
|
return buf
|
||||||
|
|
||||||
|
def start_tls(self, ssl_context, *,
|
||||||
|
server_side=False,
|
||||||
|
server_hostname=None):
|
||||||
|
|
||||||
|
assert isinstance(ssl_context, ssl.SSLContext)
|
||||||
|
|
||||||
|
ssl_sock = ssl_context.wrap_socket(
|
||||||
|
self.__sock, server_side=server_side,
|
||||||
|
server_hostname=server_hostname,
|
||||||
|
do_handshake_on_connect=False)
|
||||||
|
|
||||||
|
ssl_sock.do_handshake()
|
||||||
|
|
||||||
|
self.__sock.close()
|
||||||
|
self.__sock = ssl_sock
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.__sock, name)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<{} {!r}>'.format(type(self).__name__, self.__sock)
|
||||||
|
|
||||||
|
|
||||||
|
class SocketThread(threading.Thread):
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._active = False
|
||||||
|
self.join()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadedClient(SocketThread):
|
||||||
|
|
||||||
|
def __init__(self, test, sock, prog, timeout):
|
||||||
|
threading.Thread.__init__(self, None, None, 'test-client')
|
||||||
|
self.daemon = True
|
||||||
|
|
||||||
|
self._timeout = timeout
|
||||||
|
self._sock = sock
|
||||||
|
self._active = True
|
||||||
|
self._prog = prog
|
||||||
|
self._test = test
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
self._prog(TestSocketWrapper(self._sock))
|
||||||
|
except Exception as ex:
|
||||||
|
self._test._abort_socket_test(ex)
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadedServer(SocketThread):
|
||||||
|
|
||||||
|
def __init__(self, test, sock, prog, timeout, max_clients):
|
||||||
|
threading.Thread.__init__(self, None, None, 'test-server')
|
||||||
|
self.daemon = True
|
||||||
|
|
||||||
|
self._clients = 0
|
||||||
|
self._finished_clients = 0
|
||||||
|
self._max_clients = max_clients
|
||||||
|
self._timeout = timeout
|
||||||
|
self._sock = sock
|
||||||
|
self._active = True
|
||||||
|
|
||||||
|
self._prog = prog
|
||||||
|
|
||||||
|
self._s1, self._s2 = socket.socketpair()
|
||||||
|
self._s1.setblocking(False)
|
||||||
|
|
||||||
|
self._test = test
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
try:
|
||||||
|
if self._s2 and self._s2.fileno() != -1:
|
||||||
|
try:
|
||||||
|
self._s2.send(b'stop')
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
super().stop()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
with self._sock:
|
||||||
|
self._sock.setblocking(0)
|
||||||
|
self._run()
|
||||||
|
finally:
|
||||||
|
self._s1.close()
|
||||||
|
self._s2.close()
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
while self._active:
|
||||||
|
if self._clients >= self._max_clients:
|
||||||
|
return
|
||||||
|
|
||||||
|
r, w, x = select.select(
|
||||||
|
[self._sock, self._s1], [], [], self._timeout)
|
||||||
|
|
||||||
|
if self._s1 in r:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._sock in r:
|
||||||
|
try:
|
||||||
|
conn, addr = self._sock.accept()
|
||||||
|
except BlockingIOError:
|
||||||
|
continue
|
||||||
|
except socket.timeout:
|
||||||
|
if not self._active:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
self._clients += 1
|
||||||
|
conn.settimeout(self._timeout)
|
||||||
|
try:
|
||||||
|
with conn:
|
||||||
|
self._handle_client(conn)
|
||||||
|
except Exception as ex:
|
||||||
|
self._active = False
|
||||||
|
try:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._test._abort_socket_test(ex)
|
||||||
|
|
||||||
|
def _handle_client(self, sock):
|
||||||
|
self._prog(TestSocketWrapper(sock))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def addr(self):
|
||||||
|
return self._sock.getsockname()
|
|
@ -31,21 +31,7 @@ from asyncio import events
|
||||||
from asyncio import proactor_events
|
from asyncio import proactor_events
|
||||||
from asyncio import selector_events
|
from asyncio import selector_events
|
||||||
from test.test_asyncio import utils as test_utils
|
from test.test_asyncio import utils as test_utils
|
||||||
try:
|
from test import support
|
||||||
from test import support
|
|
||||||
except ImportError:
|
|
||||||
from asyncio import test_support as support
|
|
||||||
|
|
||||||
|
|
||||||
def data_file(filename):
|
|
||||||
if hasattr(support, 'TEST_HOME_DIR'):
|
|
||||||
fullname = os.path.join(support.TEST_HOME_DIR, filename)
|
|
||||||
if os.path.isfile(fullname):
|
|
||||||
return fullname
|
|
||||||
fullname = os.path.join(os.path.dirname(__file__), filename)
|
|
||||||
if os.path.isfile(fullname):
|
|
||||||
return fullname
|
|
||||||
raise FileNotFoundError(filename)
|
|
||||||
|
|
||||||
|
|
||||||
def osx_tiger():
|
def osx_tiger():
|
||||||
|
@ -80,23 +66,6 @@ class CoroLike:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
ONLYCERT = data_file('ssl_cert.pem')
|
|
||||||
ONLYKEY = data_file('ssl_key.pem')
|
|
||||||
SIGNED_CERTFILE = data_file('keycert3.pem')
|
|
||||||
SIGNING_CA = data_file('pycacert.pem')
|
|
||||||
PEERCERT = {'serialNumber': 'B09264B1F2DA21D1',
|
|
||||||
'version': 1,
|
|
||||||
'subject': ((('countryName', 'XY'),),
|
|
||||||
(('localityName', 'Castle Anthrax'),),
|
|
||||||
(('organizationName', 'Python Software Foundation'),),
|
|
||||||
(('commonName', 'localhost'),)),
|
|
||||||
'issuer': ((('countryName', 'XY'),),
|
|
||||||
(('organizationName', 'Python Software Foundation CA'),),
|
|
||||||
(('commonName', 'our-ca-server'),)),
|
|
||||||
'notAfter': 'Nov 13 19:47:07 2022 GMT',
|
|
||||||
'notBefore': 'Jan 4 19:47:07 2013 GMT'}
|
|
||||||
|
|
||||||
|
|
||||||
class MyBaseProto(asyncio.Protocol):
|
class MyBaseProto(asyncio.Protocol):
|
||||||
connected = None
|
connected = None
|
||||||
done = None
|
done = None
|
||||||
|
@ -853,16 +822,8 @@ class EventLoopTestsMixin:
|
||||||
'SSL not supported with proactor event loops before Python 3.5'
|
'SSL not supported with proactor event loops before Python 3.5'
|
||||||
)
|
)
|
||||||
|
|
||||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
server_context = test_utils.simple_server_sslcontext()
|
||||||
server_context.load_cert_chain(ONLYCERT, ONLYKEY)
|
client_context = test_utils.simple_client_sslcontext()
|
||||||
if hasattr(server_context, 'check_hostname'):
|
|
||||||
server_context.check_hostname = False
|
|
||||||
server_context.verify_mode = ssl.CERT_NONE
|
|
||||||
|
|
||||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
|
||||||
if hasattr(server_context, 'check_hostname'):
|
|
||||||
client_context.check_hostname = False
|
|
||||||
client_context.verify_mode = ssl.CERT_NONE
|
|
||||||
|
|
||||||
self.test_connect_accepted_socket(server_context, client_context)
|
self.test_connect_accepted_socket(server_context, client_context)
|
||||||
|
|
||||||
|
@ -1048,7 +1009,7 @@ class EventLoopTestsMixin:
|
||||||
def test_create_server_ssl(self):
|
def test_create_server_ssl(self):
|
||||||
proto = MyProto(loop=self.loop)
|
proto = MyProto(loop=self.loop)
|
||||||
server, host, port = self._make_ssl_server(
|
server, host, port = self._make_ssl_server(
|
||||||
lambda: proto, ONLYCERT, ONLYKEY)
|
lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)
|
||||||
|
|
||||||
f_c = self.loop.create_connection(MyBaseProto, host, port,
|
f_c = self.loop.create_connection(MyBaseProto, host, port,
|
||||||
ssl=test_utils.dummy_ssl_context())
|
ssl=test_utils.dummy_ssl_context())
|
||||||
|
@ -1081,7 +1042,7 @@ class EventLoopTestsMixin:
|
||||||
def test_create_unix_server_ssl(self):
|
def test_create_unix_server_ssl(self):
|
||||||
proto = MyProto(loop=self.loop)
|
proto = MyProto(loop=self.loop)
|
||||||
server, path = self._make_ssl_unix_server(
|
server, path = self._make_ssl_unix_server(
|
||||||
lambda: proto, ONLYCERT, ONLYKEY)
|
lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)
|
||||||
|
|
||||||
f_c = self.loop.create_unix_connection(
|
f_c = self.loop.create_unix_connection(
|
||||||
MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
|
MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
|
||||||
|
@ -1111,7 +1072,7 @@ class EventLoopTestsMixin:
|
||||||
def test_create_server_ssl_verify_failed(self):
|
def test_create_server_ssl_verify_failed(self):
|
||||||
proto = MyProto(loop=self.loop)
|
proto = MyProto(loop=self.loop)
|
||||||
server, host, port = self._make_ssl_server(
|
server, host, port = self._make_ssl_server(
|
||||||
lambda: proto, SIGNED_CERTFILE)
|
lambda: proto, test_utils.SIGNED_CERTFILE)
|
||||||
|
|
||||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||||
|
@ -1141,7 +1102,7 @@ class EventLoopTestsMixin:
|
||||||
def test_create_unix_server_ssl_verify_failed(self):
|
def test_create_unix_server_ssl_verify_failed(self):
|
||||||
proto = MyProto(loop=self.loop)
|
proto = MyProto(loop=self.loop)
|
||||||
server, path = self._make_ssl_unix_server(
|
server, path = self._make_ssl_unix_server(
|
||||||
lambda: proto, SIGNED_CERTFILE)
|
lambda: proto, test_utils.SIGNED_CERTFILE)
|
||||||
|
|
||||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||||
|
@ -1170,13 +1131,13 @@ class EventLoopTestsMixin:
|
||||||
def test_create_server_ssl_match_failed(self):
|
def test_create_server_ssl_match_failed(self):
|
||||||
proto = MyProto(loop=self.loop)
|
proto = MyProto(loop=self.loop)
|
||||||
server, host, port = self._make_ssl_server(
|
server, host, port = self._make_ssl_server(
|
||||||
lambda: proto, SIGNED_CERTFILE)
|
lambda: proto, test_utils.SIGNED_CERTFILE)
|
||||||
|
|
||||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||||
sslcontext_client.load_verify_locations(
|
sslcontext_client.load_verify_locations(
|
||||||
cafile=SIGNING_CA)
|
cafile=test_utils.SIGNING_CA)
|
||||||
if hasattr(sslcontext_client, 'check_hostname'):
|
if hasattr(sslcontext_client, 'check_hostname'):
|
||||||
sslcontext_client.check_hostname = True
|
sslcontext_client.check_hostname = True
|
||||||
|
|
||||||
|
@ -1199,12 +1160,12 @@ class EventLoopTestsMixin:
|
||||||
def test_create_unix_server_ssl_verified(self):
|
def test_create_unix_server_ssl_verified(self):
|
||||||
proto = MyProto(loop=self.loop)
|
proto = MyProto(loop=self.loop)
|
||||||
server, path = self._make_ssl_unix_server(
|
server, path = self._make_ssl_unix_server(
|
||||||
lambda: proto, SIGNED_CERTFILE)
|
lambda: proto, test_utils.SIGNED_CERTFILE)
|
||||||
|
|
||||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||||
sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
|
sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
|
||||||
if hasattr(sslcontext_client, 'check_hostname'):
|
if hasattr(sslcontext_client, 'check_hostname'):
|
||||||
sslcontext_client.check_hostname = True
|
sslcontext_client.check_hostname = True
|
||||||
|
|
||||||
|
@ -1224,12 +1185,12 @@ class EventLoopTestsMixin:
|
||||||
def test_create_server_ssl_verified(self):
|
def test_create_server_ssl_verified(self):
|
||||||
proto = MyProto(loop=self.loop)
|
proto = MyProto(loop=self.loop)
|
||||||
server, host, port = self._make_ssl_server(
|
server, host, port = self._make_ssl_server(
|
||||||
lambda: proto, SIGNED_CERTFILE)
|
lambda: proto, test_utils.SIGNED_CERTFILE)
|
||||||
|
|
||||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||||
sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
|
sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
|
||||||
if hasattr(sslcontext_client, 'check_hostname'):
|
if hasattr(sslcontext_client, 'check_hostname'):
|
||||||
sslcontext_client.check_hostname = True
|
sslcontext_client.check_hostname = True
|
||||||
|
|
||||||
|
@ -1241,7 +1202,7 @@ class EventLoopTestsMixin:
|
||||||
|
|
||||||
# extra info is available
|
# extra info is available
|
||||||
self.check_ssl_extra_info(client,peername=(host, port),
|
self.check_ssl_extra_info(client,peername=(host, port),
|
||||||
peercert=PEERCERT)
|
peercert=test_utils.PEERCERT)
|
||||||
|
|
||||||
# close connection
|
# close connection
|
||||||
proto.transport.close()
|
proto.transport.close()
|
||||||
|
|
|
@ -13,6 +13,7 @@ from asyncio import log
|
||||||
from asyncio import sslproto
|
from asyncio import sslproto
|
||||||
from asyncio import tasks
|
from asyncio import tasks
|
||||||
from test.test_asyncio import utils as test_utils
|
from test.test_asyncio import utils as test_utils
|
||||||
|
from test.test_asyncio import functional as func_tests
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||||
|
@ -158,5 +159,156 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
self.assertIs(ssl_proto._app_protocol, new_app_proto)
|
self.assertIs(ssl_proto._app_protocol, new_app_proto)
|
||||||
|
|
||||||
|
|
||||||
|
##############################################################################
|
||||||
|
# Start TLS Tests
|
||||||
|
##############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||||
|
|
||||||
|
def new_loop(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_start_tls_client_1(self):
|
||||||
|
HELLO_MSG = b'1' * 1024 * 1024 * 5
|
||||||
|
|
||||||
|
server_context = test_utils.simple_server_sslcontext()
|
||||||
|
client_context = test_utils.simple_client_sslcontext()
|
||||||
|
|
||||||
|
def serve(sock):
|
||||||
|
data = sock.recv_all(len(HELLO_MSG))
|
||||||
|
self.assertEqual(len(data), len(HELLO_MSG))
|
||||||
|
|
||||||
|
sock.start_tls(server_context, server_side=True)
|
||||||
|
|
||||||
|
sock.sendall(b'O')
|
||||||
|
data = sock.recv_all(len(HELLO_MSG))
|
||||||
|
self.assertEqual(len(data), len(HELLO_MSG))
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
class ClientProto(asyncio.Protocol):
|
||||||
|
def __init__(self, on_data, on_eof):
|
||||||
|
self.on_data = on_data
|
||||||
|
self.on_eof = on_eof
|
||||||
|
self.con_made_cnt = 0
|
||||||
|
|
||||||
|
def connection_made(proto, tr):
|
||||||
|
proto.con_made_cnt += 1
|
||||||
|
# Ensure connection_made gets called only once.
|
||||||
|
self.assertEqual(proto.con_made_cnt, 1)
|
||||||
|
|
||||||
|
def data_received(self, data):
|
||||||
|
self.on_data.set_result(data)
|
||||||
|
|
||||||
|
def eof_received(self):
|
||||||
|
self.on_eof.set_result(True)
|
||||||
|
|
||||||
|
async def client(addr):
|
||||||
|
on_data = self.loop.create_future()
|
||||||
|
on_eof = self.loop.create_future()
|
||||||
|
|
||||||
|
tr, proto = await self.loop.create_connection(
|
||||||
|
lambda: ClientProto(on_data, on_eof), *addr)
|
||||||
|
|
||||||
|
tr.write(HELLO_MSG)
|
||||||
|
new_tr = await self.loop.start_tls(tr, proto, client_context)
|
||||||
|
|
||||||
|
self.assertEqual(await on_data, b'O')
|
||||||
|
new_tr.write(HELLO_MSG)
|
||||||
|
await on_eof
|
||||||
|
|
||||||
|
new_tr.close()
|
||||||
|
|
||||||
|
with self.tcp_server(serve) as srv:
|
||||||
|
self.loop.run_until_complete(
|
||||||
|
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
|
||||||
|
|
||||||
|
def test_start_tls_server_1(self):
|
||||||
|
HELLO_MSG = b'1' * 1024 * 1024 * 5
|
||||||
|
|
||||||
|
server_context = test_utils.simple_server_sslcontext()
|
||||||
|
client_context = test_utils.simple_client_sslcontext()
|
||||||
|
|
||||||
|
def client(sock, addr):
|
||||||
|
sock.connect(addr)
|
||||||
|
data = sock.recv_all(len(HELLO_MSG))
|
||||||
|
self.assertEqual(len(data), len(HELLO_MSG))
|
||||||
|
|
||||||
|
sock.start_tls(client_context)
|
||||||
|
sock.sendall(HELLO_MSG)
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
class ServerProto(asyncio.Protocol):
|
||||||
|
def __init__(self, on_con, on_eof):
|
||||||
|
self.on_con = on_con
|
||||||
|
self.on_eof = on_eof
|
||||||
|
self.data = b''
|
||||||
|
|
||||||
|
def connection_made(self, tr):
|
||||||
|
self.on_con.set_result(tr)
|
||||||
|
|
||||||
|
def data_received(self, data):
|
||||||
|
self.data += data
|
||||||
|
|
||||||
|
def eof_received(self):
|
||||||
|
self.on_eof.set_result(1)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
tr = await on_con
|
||||||
|
tr.write(HELLO_MSG)
|
||||||
|
|
||||||
|
self.assertEqual(proto.data, b'')
|
||||||
|
|
||||||
|
new_tr = await self.loop.start_tls(
|
||||||
|
tr, proto, server_context,
|
||||||
|
server_side=True)
|
||||||
|
|
||||||
|
await on_eof
|
||||||
|
self.assertEqual(proto.data, HELLO_MSG)
|
||||||
|
new_tr.close()
|
||||||
|
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
|
||||||
|
on_con = self.loop.create_future()
|
||||||
|
on_eof = self.loop.create_future()
|
||||||
|
proto = ServerProto(on_con, on_eof)
|
||||||
|
|
||||||
|
server = self.loop.run_until_complete(
|
||||||
|
self.loop.create_server(
|
||||||
|
lambda: proto, '127.0.0.1', 0))
|
||||||
|
addr = server.sockets[0].getsockname()
|
||||||
|
|
||||||
|
with self.tcp_client(lambda sock: client(sock, addr)):
|
||||||
|
self.loop.run_until_complete(
|
||||||
|
asyncio.wait_for(main(), loop=self.loop, timeout=10))
|
||||||
|
|
||||||
|
def test_start_tls_wrong_args(self):
|
||||||
|
async def main():
|
||||||
|
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
|
||||||
|
await self.loop.start_tls(None, None, None)
|
||||||
|
|
||||||
|
sslctx = test_utils.simple_server_sslcontext()
|
||||||
|
with self.assertRaisesRegex(TypeError, 'is not supported'):
|
||||||
|
await self.loop.start_tls(None, None, sslctx)
|
||||||
|
|
||||||
|
self.loop.run_until_complete(main())
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||||
|
class SelectorStartTLS(BaseStartTLS, unittest.TestCase):
|
||||||
|
|
||||||
|
def new_loop(self):
|
||||||
|
return asyncio.SelectorEventLoop()
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||||
|
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
|
||||||
|
class ProactorStartTLS(BaseStartTLS, unittest.TestCase):
|
||||||
|
|
||||||
|
def new_loop(self):
|
||||||
|
return asyncio.ProactorEventLoop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -35,6 +35,49 @@ from asyncio.log import logger
|
||||||
from test import support
|
from test import support
|
||||||
|
|
||||||
|
|
||||||
|
def data_file(filename):
|
||||||
|
if hasattr(support, 'TEST_HOME_DIR'):
|
||||||
|
fullname = os.path.join(support.TEST_HOME_DIR, filename)
|
||||||
|
if os.path.isfile(fullname):
|
||||||
|
return fullname
|
||||||
|
fullname = os.path.join(os.path.dirname(__file__), filename)
|
||||||
|
if os.path.isfile(fullname):
|
||||||
|
return fullname
|
||||||
|
raise FileNotFoundError(filename)
|
||||||
|
|
||||||
|
|
||||||
|
ONLYCERT = data_file('ssl_cert.pem')
|
||||||
|
ONLYKEY = data_file('ssl_key.pem')
|
||||||
|
SIGNED_CERTFILE = data_file('keycert3.pem')
|
||||||
|
SIGNING_CA = data_file('pycacert.pem')
|
||||||
|
PEERCERT = {'serialNumber': 'B09264B1F2DA21D1',
|
||||||
|
'version': 1,
|
||||||
|
'subject': ((('countryName', 'XY'),),
|
||||||
|
(('localityName', 'Castle Anthrax'),),
|
||||||
|
(('organizationName', 'Python Software Foundation'),),
|
||||||
|
(('commonName', 'localhost'),)),
|
||||||
|
'issuer': ((('countryName', 'XY'),),
|
||||||
|
(('organizationName', 'Python Software Foundation CA'),),
|
||||||
|
(('commonName', 'our-ca-server'),)),
|
||||||
|
'notAfter': 'Nov 13 19:47:07 2022 GMT',
|
||||||
|
'notBefore': 'Jan 4 19:47:07 2013 GMT'}
|
||||||
|
|
||||||
|
|
||||||
|
def simple_server_sslcontext():
|
||||||
|
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
server_context.load_cert_chain(ONLYCERT, ONLYKEY)
|
||||||
|
server_context.check_hostname = False
|
||||||
|
server_context.verify_mode = ssl.CERT_NONE
|
||||||
|
return server_context
|
||||||
|
|
||||||
|
|
||||||
|
def simple_client_sslcontext():
|
||||||
|
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
client_context.check_hostname = False
|
||||||
|
client_context.verify_mode = ssl.CERT_NONE
|
||||||
|
return client_context
|
||||||
|
|
||||||
|
|
||||||
def dummy_ssl_context():
|
def dummy_ssl_context():
|
||||||
if ssl is None:
|
if ssl is None:
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
asyncio: Implement loop.start_tls()
|
Loading…
Add table
Add a link
Reference in a new issue