mirror of
https://github.com/python/cpython.git
synced 2025-09-19 07:00:59 +00:00
bpo-44011: New asyncio ssl implementation (#31275)
* bpo-44011: New asyncio ssl implementation Co-Authored-By: Andrew Svetlov <andrew.svetlov@gmail.com> * fix warning * fix typo Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
3be1a443ca
commit
13c10bfb77
12 changed files with 2478 additions and 527 deletions
|
@ -269,7 +269,7 @@ class _SendfileFallbackProtocol(protocols.Protocol):
|
|||
class Server(events.AbstractServer):
|
||||
|
||||
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
|
||||
ssl_handshake_timeout):
|
||||
ssl_handshake_timeout, ssl_shutdown_timeout=None):
|
||||
self._loop = loop
|
||||
self._sockets = sockets
|
||||
self._active_count = 0
|
||||
|
@ -278,6 +278,7 @@ class Server(events.AbstractServer):
|
|||
self._backlog = backlog
|
||||
self._ssl_context = ssl_context
|
||||
self._ssl_handshake_timeout = ssl_handshake_timeout
|
||||
self._ssl_shutdown_timeout = ssl_shutdown_timeout
|
||||
self._serving = False
|
||||
self._serving_forever_fut = None
|
||||
|
||||
|
@ -309,7 +310,8 @@ class Server(events.AbstractServer):
|
|||
sock.listen(self._backlog)
|
||||
self._loop._start_serving(
|
||||
self._protocol_factory, sock, self._ssl_context,
|
||||
self, self._backlog, self._ssl_handshake_timeout)
|
||||
self, self._backlog, self._ssl_handshake_timeout,
|
||||
self._ssl_shutdown_timeout)
|
||||
|
||||
def get_loop(self):
|
||||
return self._loop
|
||||
|
@ -463,6 +465,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
call_connection_made=True):
|
||||
"""Create SSL transport."""
|
||||
raise NotImplementedError
|
||||
|
@ -965,6 +968,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
proto=0, flags=0, sock=None,
|
||||
local_addr=None, server_hostname=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
happy_eyeballs_delay=None, interleave=None):
|
||||
"""Connect to a TCP server.
|
||||
|
||||
|
@ -1000,6 +1004,10 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
raise ValueError(
|
||||
'ssl_handshake_timeout is only meaningful with ssl')
|
||||
|
||||
if ssl_shutdown_timeout is not None and not ssl:
|
||||
raise ValueError(
|
||||
'ssl_shutdown_timeout is only meaningful with ssl')
|
||||
|
||||
if happy_eyeballs_delay is not None and interleave is None:
|
||||
# If using happy eyeballs, default to interleave addresses by family
|
||||
interleave = 1
|
||||
|
@ -1075,7 +1083,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
if self._debug:
|
||||
# Get the socket from the transport because SSL transport closes
|
||||
# the old socket and creates a new SSL socket
|
||||
|
@ -1087,7 +1096,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
async def _create_connection_transport(
|
||||
self, sock, protocol_factory, ssl,
|
||||
server_hostname, server_side=False,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
|
||||
sock.setblocking(False)
|
||||
|
||||
|
@ -1098,7 +1108,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
transport = self._make_ssl_transport(
|
||||
sock, protocol, sslcontext, waiter,
|
||||
server_side=server_side, server_hostname=server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
else:
|
||||
transport = self._make_socket_transport(sock, protocol, waiter)
|
||||
|
||||
|
@ -1189,7 +1200,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
async def start_tls(self, transport, protocol, sslcontext, *,
|
||||
server_side=False,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
"""Upgrade transport to TLS.
|
||||
|
||||
Return a new transport that *protocol* should start using
|
||||
|
@ -1212,6 +1224,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout,
|
||||
call_connection_made=False)
|
||||
|
||||
# Pause early so that "ssl_protocol.data_received()" doesn't
|
||||
|
@ -1397,6 +1410,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
reuse_address=None,
|
||||
reuse_port=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
start_serving=True):
|
||||
"""Create a TCP server.
|
||||
|
||||
|
@ -1420,6 +1434,10 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
raise ValueError(
|
||||
'ssl_handshake_timeout is only meaningful with ssl')
|
||||
|
||||
if ssl_shutdown_timeout is not None and ssl is None:
|
||||
raise ValueError(
|
||||
'ssl_shutdown_timeout is only meaningful with ssl')
|
||||
|
||||
if host is not None or port is not None:
|
||||
if sock is not None:
|
||||
raise ValueError(
|
||||
|
@ -1492,7 +1510,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
sock.setblocking(False)
|
||||
|
||||
server = Server(self, sockets, protocol_factory,
|
||||
ssl, backlog, ssl_handshake_timeout)
|
||||
ssl, backlog, ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout)
|
||||
if start_serving:
|
||||
server._start_serving()
|
||||
# Skip one loop iteration so that all 'loop.add_reader'
|
||||
|
@ -1506,7 +1525,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
async def connect_accepted_socket(
|
||||
self, protocol_factory, sock,
|
||||
*, ssl=None,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
if sock.type != socket.SOCK_STREAM:
|
||||
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
|
||||
|
||||
|
@ -1514,9 +1534,14 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
raise ValueError(
|
||||
'ssl_handshake_timeout is only meaningful with ssl')
|
||||
|
||||
if ssl_shutdown_timeout is not None and not ssl:
|
||||
raise ValueError(
|
||||
'ssl_shutdown_timeout is only meaningful with ssl')
|
||||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, '', server_side=True,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
if self._debug:
|
||||
# Get the socket from the transport because SSL transport closes
|
||||
# the old socket and creates a new SSL socket
|
||||
|
|
|
@ -15,10 +15,17 @@ DEBUG_STACK_DEPTH = 10
|
|||
# The default timeout matches that of Nginx.
|
||||
SSL_HANDSHAKE_TIMEOUT = 60.0
|
||||
|
||||
# Number of seconds to wait for SSL shutdown to complete
|
||||
# The default timeout mimics lingering_time
|
||||
SSL_SHUTDOWN_TIMEOUT = 30.0
|
||||
|
||||
# Used in sendfile fallback code. We use fallback for platforms
|
||||
# that don't support sendfile, or for TLS connections.
|
||||
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256
|
||||
|
||||
FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
|
||||
FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB
|
||||
|
||||
# The enum should be here to break circular dependencies between
|
||||
# base_events and sslproto
|
||||
class _SendfileMode(enum.Enum):
|
||||
|
|
|
@ -303,6 +303,7 @@ class AbstractEventLoop:
|
|||
flags=0, sock=None, local_addr=None,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
happy_eyeballs_delay=None, interleave=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -312,6 +313,7 @@ class AbstractEventLoop:
|
|||
flags=socket.AI_PASSIVE, sock=None, backlog=100,
|
||||
ssl=None, reuse_address=None, reuse_port=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
start_serving=True):
|
||||
"""A coroutine which creates a TCP server bound to host and port.
|
||||
|
||||
|
@ -352,6 +354,10 @@ class AbstractEventLoop:
|
|||
will wait for completion of the SSL handshake before aborting the
|
||||
connection. Default is 60s.
|
||||
|
||||
ssl_shutdown_timeout is the time in seconds that an SSL server
|
||||
will wait for completion of the SSL shutdown procedure
|
||||
before aborting the connection. Default is 30s.
|
||||
|
||||
start_serving set to True (default) causes the created server
|
||||
to start accepting connections immediately. When set to False,
|
||||
the user should await Server.start_serving() or Server.serve_forever()
|
||||
|
@ -370,7 +376,8 @@ class AbstractEventLoop:
|
|||
async def start_tls(self, transport, protocol, sslcontext, *,
|
||||
server_side=False,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
"""Upgrade a transport to TLS.
|
||||
|
||||
Return a new transport that *protocol* should start using
|
||||
|
@ -382,13 +389,15 @@ class AbstractEventLoop:
|
|||
self, protocol_factory, path=None, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_unix_server(
|
||||
self, protocol_factory, path=None, *,
|
||||
sock=None, backlog=100, ssl=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
start_serving=True):
|
||||
"""A coroutine which creates a UNIX Domain Socket server.
|
||||
|
||||
|
@ -410,6 +419,9 @@ class AbstractEventLoop:
|
|||
ssl_handshake_timeout is the time in seconds that an SSL server
|
||||
will wait for the SSL handshake to complete (defaults to 60s).
|
||||
|
||||
ssl_shutdown_timeout is the time in seconds that an SSL server
|
||||
will wait for the SSL shutdown to finish (defaults to 30s).
|
||||
|
||||
start_serving set to True (default) causes the created server
|
||||
to start accepting connections immediately. When set to False,
|
||||
the user should await Server.start_serving() or Server.serve_forever()
|
||||
|
@ -420,7 +432,8 @@ class AbstractEventLoop:
|
|||
async def connect_accepted_socket(
|
||||
self, protocol_factory, sock,
|
||||
*, ssl=None,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
"""Handle an accepted connection.
|
||||
|
||||
This is used by servers that accept connections outside of
|
||||
|
|
|
@ -642,11 +642,13 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
|
|||
self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
ssl_protocol = sslproto.SSLProtocol(
|
||||
self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
_ProactorSocketTransport(self, rawsock, ssl_protocol,
|
||||
extra=extra, server=server)
|
||||
return ssl_protocol._app_transport
|
||||
|
@ -812,7 +814,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
|
|||
|
||||
def _start_serving(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
|
||||
def loop(f=None):
|
||||
try:
|
||||
|
@ -826,7 +829,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
|
|||
self._make_ssl_transport(
|
||||
conn, protocol, sslcontext, server_side=True,
|
||||
extra={'peername': addr}, server=server,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
else:
|
||||
self._make_socket_transport(
|
||||
conn, protocol,
|
||||
|
|
|
@ -70,11 +70,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
|
||||
):
|
||||
ssl_protocol = sslproto.SSLProtocol(
|
||||
self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout
|
||||
)
|
||||
_SelectorSocketTransport(self, rawsock, ssl_protocol,
|
||||
extra=extra, server=server)
|
||||
return ssl_protocol._app_transport
|
||||
|
@ -146,15 +150,17 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
|
||||
def _start_serving(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
|
||||
self._add_reader(sock.fileno(), self._accept_connection,
|
||||
protocol_factory, sock, sslcontext, server, backlog,
|
||||
ssl_handshake_timeout)
|
||||
ssl_handshake_timeout, ssl_shutdown_timeout)
|
||||
|
||||
def _accept_connection(
|
||||
self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
|
||||
# This method is only called once for each event loop tick where the
|
||||
# listening socket has triggered an EVENT_READ. There may be multiple
|
||||
# connections waiting for an .accept() so it is called in a loop.
|
||||
|
@ -185,20 +191,22 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
self.call_later(constants.ACCEPT_RETRY_DELAY,
|
||||
self._start_serving,
|
||||
protocol_factory, sock, sslcontext, server,
|
||||
backlog, ssl_handshake_timeout)
|
||||
backlog, ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout)
|
||||
else:
|
||||
raise # The event loop will catch, log and ignore it.
|
||||
else:
|
||||
extra = {'peername': addr}
|
||||
accept = self._accept_connection2(
|
||||
protocol_factory, conn, extra, sslcontext, server,
|
||||
ssl_handshake_timeout)
|
||||
ssl_handshake_timeout, ssl_shutdown_timeout)
|
||||
self.create_task(accept)
|
||||
|
||||
async def _accept_connection2(
|
||||
self, protocol_factory, conn, extra,
|
||||
sslcontext=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
|
||||
protocol = None
|
||||
transport = None
|
||||
try:
|
||||
|
@ -208,7 +216,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
transport = self._make_ssl_transport(
|
||||
conn, protocol, sslcontext, waiter=waiter,
|
||||
server_side=True, extra=extra, server=server,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
else:
|
||||
transport = self._make_socket_transport(
|
||||
conn, protocol, waiter=waiter, extra=extra,
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -229,7 +229,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
self, protocol_factory, path=None, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=None):
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
assert server_hostname is None or isinstance(server_hostname, str)
|
||||
if ssl:
|
||||
if server_hostname is None:
|
||||
|
@ -241,6 +242,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
if ssl_handshake_timeout is not None:
|
||||
raise ValueError(
|
||||
'ssl_handshake_timeout is only meaningful with ssl')
|
||||
if ssl_shutdown_timeout is not None:
|
||||
raise ValueError(
|
||||
'ssl_shutdown_timeout is only meaningful with ssl')
|
||||
|
||||
if path is not None:
|
||||
if sock is not None:
|
||||
|
@ -267,13 +271,15 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
return transport, protocol
|
||||
|
||||
async def create_unix_server(
|
||||
self, protocol_factory, path=None, *,
|
||||
sock=None, backlog=100, ssl=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
start_serving=True):
|
||||
if isinstance(ssl, bool):
|
||||
raise TypeError('ssl argument must be an SSLContext or None')
|
||||
|
@ -282,6 +288,10 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
raise ValueError(
|
||||
'ssl_handshake_timeout is only meaningful with ssl')
|
||||
|
||||
if ssl_shutdown_timeout is not None and not ssl:
|
||||
raise ValueError(
|
||||
'ssl_shutdown_timeout is only meaningful with ssl')
|
||||
|
||||
if path is not None:
|
||||
if sock is not None:
|
||||
raise ValueError(
|
||||
|
@ -328,7 +338,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
|
||||
sock.setblocking(False)
|
||||
server = base_events.Server(self, [sock], protocol_factory,
|
||||
ssl, backlog, ssl_handshake_timeout)
|
||||
ssl, backlog, ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout)
|
||||
if start_serving:
|
||||
server._start_serving()
|
||||
# Skip one loop iteration so that all 'loop.add_reader'
|
||||
|
|
|
@ -1451,44 +1451,51 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
|
||||
ANY = mock.ANY
|
||||
handshake_timeout = object()
|
||||
shutdown_timeout = object()
|
||||
# First try the default server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(
|
||||
MyProto, 'python.org', 80, ssl=True,
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
transport, _ = self.loop.run_until_complete(coro)
|
||||
transport.close()
|
||||
self.loop._make_ssl_transport.assert_called_with(
|
||||
ANY, ANY, ANY, ANY,
|
||||
server_side=False,
|
||||
server_hostname='python.org',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
# Next try an explicit server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(
|
||||
MyProto, 'python.org', 80, ssl=True,
|
||||
server_hostname='perl.com',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
transport, _ = self.loop.run_until_complete(coro)
|
||||
transport.close()
|
||||
self.loop._make_ssl_transport.assert_called_with(
|
||||
ANY, ANY, ANY, ANY,
|
||||
server_side=False,
|
||||
server_hostname='perl.com',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
# Finally try an explicit empty server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(
|
||||
MyProto, 'python.org', 80, ssl=True,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
transport, _ = self.loop.run_until_complete(coro)
|
||||
transport.close()
|
||||
self.loop._make_ssl_transport.assert_called_with(
|
||||
ANY, ANY, ANY, ANY,
|
||||
server_side=False,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
|
||||
def test_create_connection_no_ssl_server_hostname_errors(self):
|
||||
# When not using ssl, server_hostname must be None.
|
||||
|
@ -1869,7 +1876,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
constants.ACCEPT_RETRY_DELAY,
|
||||
# self.loop._start_serving
|
||||
mock.ANY,
|
||||
MyProto, sock, None, None, mock.ANY, mock.ANY)
|
||||
MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
|
||||
|
||||
def test_call_coroutine(self):
|
||||
async def simple_coroutine():
|
||||
|
|
|
@ -71,44 +71,6 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
|
|||
|
||||
close_transport(transport)
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_make_ssl_transport(self):
|
||||
m = mock.Mock()
|
||||
self.loop._add_reader = mock.Mock()
|
||||
self.loop._add_reader._is_coroutine = False
|
||||
self.loop._add_writer = mock.Mock()
|
||||
self.loop._remove_reader = mock.Mock()
|
||||
self.loop._remove_writer = mock.Mock()
|
||||
waiter = self.loop.create_future()
|
||||
with test_utils.disable_logger():
|
||||
transport = self.loop._make_ssl_transport(
|
||||
m, asyncio.Protocol(), m, waiter)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r'SSL transport.*not.*initialized'):
|
||||
transport.is_reading()
|
||||
|
||||
# execute the handshake while the logger is disabled
|
||||
# to ignore SSL handshake failure
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
self.assertTrue(transport.is_reading())
|
||||
transport.pause_reading()
|
||||
transport.pause_reading()
|
||||
self.assertFalse(transport.is_reading())
|
||||
transport.resume_reading()
|
||||
transport.resume_reading()
|
||||
self.assertTrue(transport.is_reading())
|
||||
|
||||
# Sanity check
|
||||
class_name = transport.__class__.__name__
|
||||
self.assertIn("ssl", class_name.lower())
|
||||
self.assertIn("transport", class_name.lower())
|
||||
|
||||
transport.close()
|
||||
# execute pending callbacks to close the socket transport
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
@mock.patch('asyncio.selector_events.ssl', None)
|
||||
@mock.patch('asyncio.sslproto.ssl', None)
|
||||
def test_make_ssl_transport_without_ssl_error(self):
|
||||
|
|
1721
Lib/test/test_asyncio/test_ssl.py
Normal file
1721
Lib/test/test_asyncio/test_ssl.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -15,7 +15,6 @@ import asyncio
|
|||
from asyncio import log
|
||||
from asyncio import protocols
|
||||
from asyncio import sslproto
|
||||
from test import support
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from test.test_asyncio import functional as func_tests
|
||||
|
||||
|
@ -44,15 +43,12 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
|
||||
def connection_made(self, ssl_proto, *, do_handshake=None):
|
||||
transport = mock.Mock()
|
||||
sslpipe = mock.Mock()
|
||||
sslpipe.shutdown.return_value = b''
|
||||
if do_handshake:
|
||||
sslpipe.do_handshake.side_effect = do_handshake
|
||||
else:
|
||||
def mock_handshake(callback):
|
||||
return []
|
||||
sslpipe.do_handshake.side_effect = mock_handshake
|
||||
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
|
||||
sslobj = mock.Mock()
|
||||
# emulate reading decompressed data
|
||||
sslobj.read.side_effect = ssl.SSLWantReadError
|
||||
if do_handshake is not None:
|
||||
sslobj.do_handshake = do_handshake
|
||||
ssl_proto._sslobj = sslobj
|
||||
ssl_proto.connection_made(transport)
|
||||
return transport
|
||||
|
||||
|
@ -75,7 +71,10 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
def test_eof_received_waiter(self):
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
self.connection_made(ssl_proto)
|
||||
self.connection_made(
|
||||
ssl_proto,
|
||||
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
|
||||
)
|
||||
ssl_proto.eof_received()
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertIsInstance(waiter.exception(), ConnectionResetError)
|
||||
|
@ -100,7 +99,10 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
# yield from waiter hang if lost_connection was called.
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
self.connection_made(ssl_proto)
|
||||
self.connection_made(
|
||||
ssl_proto,
|
||||
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
|
||||
)
|
||||
ssl_proto.connection_lost(ConnectionAbortedError)
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
|
||||
|
@ -110,7 +112,10 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
|
||||
transport = self.connection_made(ssl_proto)
|
||||
transport = self.connection_made(
|
||||
ssl_proto,
|
||||
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
|
||||
)
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
ssl_proto._app_transport.close()
|
||||
|
@ -143,7 +148,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
transp.close()
|
||||
|
||||
# should not raise
|
||||
self.assertIsNone(ssl_proto.data_received(b'data'))
|
||||
self.assertIsNone(ssl_proto.buffer_updated(5))
|
||||
|
||||
def test_write_after_closing(self):
|
||||
ssl_proto = self.ssl_protocol()
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Reimplement SSL/TLS support in asyncio, borrow the implementation from
|
||||
uvloop library.
|
Loading…
Add table
Add a link
Reference in a new issue