bpo-44011: Revert "New asyncio ssl implementation (GH-17975)" (GH-25848)

This reverts commit 5fb06edbbb and all
subsequent dependent commits.
This commit is contained in:
Pablo Galindo 2021-05-03 16:21:59 +01:00 committed by GitHub
parent 39494285e1
commit 7719953b30
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 524 additions and 2477 deletions

View file

@ -273,7 +273,7 @@ class _SendfileFallbackProtocol(protocols.Protocol):
class Server(events.AbstractServer): class Server(events.AbstractServer):
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
ssl_handshake_timeout, ssl_shutdown_timeout=None): ssl_handshake_timeout):
self._loop = loop self._loop = loop
self._sockets = sockets self._sockets = sockets
self._active_count = 0 self._active_count = 0
@ -282,7 +282,6 @@ class Server(events.AbstractServer):
self._backlog = backlog self._backlog = backlog
self._ssl_context = ssl_context self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout self._ssl_handshake_timeout = ssl_handshake_timeout
self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False self._serving = False
self._serving_forever_fut = None self._serving_forever_fut = None
@ -314,8 +313,7 @@ class Server(events.AbstractServer):
sock.listen(self._backlog) sock.listen(self._backlog)
self._loop._start_serving( self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context, self._protocol_factory, sock, self._ssl_context,
self, self._backlog, self._ssl_handshake_timeout, self, self._backlog, self._ssl_handshake_timeout)
self._ssl_shutdown_timeout)
def get_loop(self): def get_loop(self):
return self._loop return self._loop
@ -469,7 +467,6 @@ class BaseEventLoop(events.AbstractEventLoop):
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None, extra=None, server=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
call_connection_made=True): call_connection_made=True):
"""Create SSL transport.""" """Create SSL transport."""
raise NotImplementedError raise NotImplementedError
@ -972,7 +969,6 @@ class BaseEventLoop(events.AbstractEventLoop):
proto=0, flags=0, sock=None, proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None, local_addr=None, server_hostname=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None): happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server. """Connect to a TCP server.
@ -1008,10 +1004,6 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if happy_eyeballs_delay is not None and interleave is None: if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family # If using happy eyeballs, default to interleave addresses by family
interleave = 1 interleave = 1
@ -1087,8 +1079,7 @@ class BaseEventLoop(events.AbstractEventLoop):
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname, sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug: if self._debug:
# Get the socket from the transport because SSL transport closes # Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket # the old socket and creates a new SSL socket
@ -1100,8 +1091,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def _create_connection_transport( async def _create_connection_transport(
self, sock, protocol_factory, ssl, self, sock, protocol_factory, ssl,
server_hostname, server_side=False, server_hostname, server_side=False,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
sock.setblocking(False) sock.setblocking(False)
@ -1112,8 +1102,7 @@ class BaseEventLoop(events.AbstractEventLoop):
transport = self._make_ssl_transport( transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter, sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname, server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout)
else: else:
transport = self._make_socket_transport(sock, protocol, waiter) transport = self._make_socket_transport(sock, protocol, waiter)
@ -1204,8 +1193,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def start_tls(self, transport, protocol, sslcontext, *, async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False, server_side=False,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
"""Upgrade transport to TLS. """Upgrade transport to TLS.
Return a new transport that *protocol* should start using Return a new transport that *protocol* should start using
@ -1228,7 +1216,6 @@ class BaseEventLoop(events.AbstractEventLoop):
self, protocol, sslcontext, waiter, self, protocol, sslcontext, waiter,
server_side, server_hostname, server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False) call_connection_made=False)
# Pause early so that "ssl_protocol.data_received()" doesn't # Pause early so that "ssl_protocol.data_received()" doesn't
@ -1427,7 +1414,6 @@ class BaseEventLoop(events.AbstractEventLoop):
reuse_address=None, reuse_address=None,
reuse_port=None, reuse_port=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True): start_serving=True):
"""Create a TCP server. """Create a TCP server.
@ -1451,10 +1437,6 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and ssl is None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if host is not None or port is not None: if host is not None or port is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(
@ -1527,8 +1509,7 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False) sock.setblocking(False)
server = Server(self, sockets, protocol_factory, server = Server(self, sockets, protocol_factory,
ssl, backlog, ssl_handshake_timeout, ssl, backlog, ssl_handshake_timeout)
ssl_shutdown_timeout)
if start_serving: if start_serving:
server._start_serving() server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader' # Skip one loop iteration so that all 'loop.add_reader'
@ -1542,8 +1523,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def connect_accepted_socket( async def connect_accepted_socket(
self, protocol_factory, sock, self, protocol_factory, sock,
*, ssl=None, *, ssl=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
if sock.type != socket.SOCK_STREAM: if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}') raise ValueError(f'A Stream Socket was expected, got {sock!r}')
@ -1551,14 +1531,9 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True, sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug: if self._debug:
# Get the socket from the transport because SSL transport closes # Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket # the old socket and creates a new SSL socket

View file

@ -15,17 +15,10 @@ DEBUG_STACK_DEPTH = 10
# The default timeout matches that of Nginx. # The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0 SSL_HANDSHAKE_TIMEOUT = 60.0
# Number of seconds to wait for SSL shutdown to complete
# The default timeout mimics lingering_time
SSL_SHUTDOWN_TIMEOUT = 30.0
# Used in sendfile fallback code. We use fallback for platforms # Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections. # that don't support sendfile, or for TLS connections.
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256 SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256
FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB
# The enum should be here to break circular dependencies between # The enum should be here to break circular dependencies between
# base_events and sslproto # base_events and sslproto
class _SendfileMode(enum.Enum): class _SendfileMode(enum.Enum):

View file

@ -304,7 +304,6 @@ class AbstractEventLoop:
flags=0, sock=None, local_addr=None, flags=0, sock=None, local_addr=None,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None): happy_eyeballs_delay=None, interleave=None):
raise NotImplementedError raise NotImplementedError
@ -314,7 +313,6 @@ class AbstractEventLoop:
flags=socket.AI_PASSIVE, sock=None, backlog=100, flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None, ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True): start_serving=True):
"""A coroutine which creates a TCP server bound to host and port. """A coroutine which creates a TCP server bound to host and port.
@ -355,10 +353,6 @@ class AbstractEventLoop:
will wait for completion of the SSL handshake before aborting the will wait for completion of the SSL handshake before aborting the
connection. Default is 60s. connection. Default is 60s.
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for completion of the SSL shutdown procedure
before aborting the connection. Default is 30s.
start_serving set to True (default) causes the created server start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False, to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever() the user should await Server.start_serving() or Server.serve_forever()
@ -377,8 +371,7 @@ class AbstractEventLoop:
async def start_tls(self, transport, protocol, sslcontext, *, async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False, server_side=False,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
"""Upgrade a transport to TLS. """Upgrade a transport to TLS.
Return a new transport that *protocol* should start using Return a new transport that *protocol* should start using
@ -390,15 +383,13 @@ class AbstractEventLoop:
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
ssl=None, sock=None, ssl=None, sock=None,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
raise NotImplementedError raise NotImplementedError
async def create_unix_server( async def create_unix_server(
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None, sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True): start_serving=True):
"""A coroutine which creates a UNIX Domain Socket server. """A coroutine which creates a UNIX Domain Socket server.
@ -420,9 +411,6 @@ class AbstractEventLoop:
ssl_handshake_timeout is the time in seconds that an SSL server ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 60s). will wait for the SSL handshake to complete (defaults to 60s).
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for the SSL shutdown to finish (defaults to 30s).
start_serving set to True (default) causes the created server start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False, to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever() the user should await Server.start_serving() or Server.serve_forever()
@ -433,8 +421,7 @@ class AbstractEventLoop:
async def connect_accepted_socket( async def connect_accepted_socket(
self, protocol_factory, sock, self, protocol_factory, sock,
*, ssl=None, *, ssl=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
"""Handle an accepted connection. """Handle an accepted connection.
This is used by servers that accept connections outside of This is used by servers that accept connections outside of

View file

@ -642,13 +642,11 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self, rawsock, protocol, sslcontext, waiter=None, self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None, extra=None, server=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
ssl_protocol = sslproto.SSLProtocol( ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter, self, protocol, sslcontext, waiter,
server_side, server_hostname, server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol, _ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server) extra=extra, server=server)
return ssl_protocol._app_transport return ssl_protocol._app_transport
@ -814,8 +812,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
def _start_serving(self, protocol_factory, sock, def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100, sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
def loop(f=None): def loop(f=None):
try: try:
@ -829,8 +826,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._make_ssl_transport( self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True, conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server, extra={'peername': addr}, server=server,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout)
else: else:
self._make_socket_transport( self._make_socket_transport(
conn, protocol, conn, protocol,

View file

@ -70,15 +70,11 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self, rawsock, protocol, sslcontext, waiter=None, self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None, extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
):
ssl_protocol = sslproto.SSLProtocol( ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter, self, protocol, sslcontext, waiter,
server_side, server_hostname, server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout
)
_SelectorSocketTransport(self, rawsock, ssl_protocol, _SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server) extra=extra, server=server)
return ssl_protocol._app_transport return ssl_protocol._app_transport
@ -150,17 +146,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _start_serving(self, protocol_factory, sock, def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100, sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection, self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog, protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout, ssl_shutdown_timeout) ssl_handshake_timeout)
def _accept_connection( def _accept_connection(
self, protocol_factory, sock, self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100, sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
# This method is only called once for each event loop tick where the # This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple # listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop. # connections waiting for an .accept() so it is called in a loop.
@ -191,22 +185,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self.call_later(constants.ACCEPT_RETRY_DELAY, self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving, self._start_serving,
protocol_factory, sock, sslcontext, server, protocol_factory, sock, sslcontext, server,
backlog, ssl_handshake_timeout, backlog, ssl_handshake_timeout)
ssl_shutdown_timeout)
else: else:
raise # The event loop will catch, log and ignore it. raise # The event loop will catch, log and ignore it.
else: else:
extra = {'peername': addr} extra = {'peername': addr}
accept = self._accept_connection2( accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server, protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout, ssl_shutdown_timeout) ssl_handshake_timeout)
self.create_task(accept) self.create_task(accept)
async def _accept_connection2( async def _accept_connection2(
self, protocol_factory, conn, extra, self, protocol_factory, conn, extra,
sslcontext=None, server=None, sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
protocol = None protocol = None
transport = None transport = None
try: try:
@ -216,8 +208,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
transport = self._make_ssl_transport( transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter, conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server, server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout)
else: else:
transport = self._make_socket_transport( transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra, conn, protocol, waiter=waiter, extra=extra,

File diff suppressed because it is too large Load diff

View file

@ -229,8 +229,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
ssl=None, sock=None, ssl=None, sock=None,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None):
ssl_shutdown_timeout=None):
assert server_hostname is None or isinstance(server_hostname, str) assert server_hostname is None or isinstance(server_hostname, str)
if ssl: if ssl:
if server_hostname is None: if server_hostname is None:
@ -242,9 +241,6 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if ssl_handshake_timeout is not None: if ssl_handshake_timeout is not None:
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if path is not None: if path is not None:
if sock is not None: if sock is not None:
@ -271,15 +267,13 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname, sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout, ssl_handshake_timeout=ssl_handshake_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout)
return transport, protocol return transport, protocol
async def create_unix_server( async def create_unix_server(
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None, sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None, ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True): start_serving=True):
if isinstance(ssl, bool): if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None') raise TypeError('ssl argument must be an SSLContext or None')
@ -288,10 +282,6 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
raise ValueError( raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl') 'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if path is not None: if path is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(
@ -338,8 +328,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
sock.setblocking(False) sock.setblocking(False)
server = base_events.Server(self, [sock], protocol_factory, server = base_events.Server(self, [sock], protocol_factory,
ssl, backlog, ssl_handshake_timeout, ssl, backlog, ssl_handshake_timeout)
ssl_shutdown_timeout)
if start_serving: if start_serving:
server._start_serving() server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader' # Skip one loop iteration so that all 'loop.add_reader'

View file

@ -1437,51 +1437,44 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY ANY = mock.ANY
handshake_timeout = object() handshake_timeout = object()
shutdown_timeout = object()
# First try the default server_hostname. # First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock() self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection( coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True, MyProto, 'python.org', 80, ssl=True,
ssl_handshake_timeout=handshake_timeout, ssl_handshake_timeout=handshake_timeout)
ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro) transport, _ = self.loop.run_until_complete(coro)
transport.close() transport.close()
self.loop._make_ssl_transport.assert_called_with( self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY,
server_side=False, server_side=False,
server_hostname='python.org', server_hostname='python.org',
ssl_handshake_timeout=handshake_timeout, ssl_handshake_timeout=handshake_timeout)
ssl_shutdown_timeout=shutdown_timeout)
# Next try an explicit server_hostname. # Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock() self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection( coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True, MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com', server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout, ssl_handshake_timeout=handshake_timeout)
ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro) transport, _ = self.loop.run_until_complete(coro)
transport.close() transport.close()
self.loop._make_ssl_transport.assert_called_with( self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY,
server_side=False, server_side=False,
server_hostname='perl.com', server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout, ssl_handshake_timeout=handshake_timeout)
ssl_shutdown_timeout=shutdown_timeout)
# Finally try an explicit empty server_hostname. # Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock() self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection( coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True, MyProto, 'python.org', 80, ssl=True,
server_hostname='', server_hostname='',
ssl_handshake_timeout=handshake_timeout, ssl_handshake_timeout=handshake_timeout)
ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro) transport, _ = self.loop.run_until_complete(coro)
transport.close() transport.close()
self.loop._make_ssl_transport.assert_called_with( self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY, ANY, ANY, ANY, ANY,
server_side=False, server_side=False,
server_hostname='', server_hostname='',
ssl_handshake_timeout=handshake_timeout, ssl_handshake_timeout=handshake_timeout)
ssl_shutdown_timeout=shutdown_timeout)
def test_create_connection_no_ssl_server_hostname_errors(self): def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None. # When not using ssl, server_hostname must be None.
@ -1888,7 +1881,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
constants.ACCEPT_RETRY_DELAY, constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving # self.loop._start_serving
mock.ANY, mock.ANY,
MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY) MyProto, sock, None, None, mock.ANY, mock.ANY)
def test_call_coroutine(self): def test_call_coroutine(self):
with self.assertWarns(DeprecationWarning): with self.assertWarns(DeprecationWarning):

View file

@ -70,6 +70,44 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
close_transport(transport) close_transport(transport)
@unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self):
m = mock.Mock()
self.loop._add_reader = mock.Mock()
self.loop._add_reader._is_coroutine = False
self.loop._add_writer = mock.Mock()
self.loop._remove_reader = mock.Mock()
self.loop._remove_writer = mock.Mock()
waiter = self.loop.create_future()
with test_utils.disable_logger():
transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter)
with self.assertRaisesRegex(RuntimeError,
r'SSL transport.*not.*initialized'):
transport.is_reading()
# execute the handshake while the logger is disabled
# to ignore SSL handshake failure
test_utils.run_briefly(self.loop)
self.assertTrue(transport.is_reading())
transport.pause_reading()
transport.pause_reading()
self.assertFalse(transport.is_reading())
transport.resume_reading()
transport.resume_reading()
self.assertTrue(transport.is_reading())
# Sanity check
class_name = transport.__class__.__name__
self.assertIn("ssl", class_name.lower())
self.assertIn("transport", class_name.lower())
transport.close()
# execute pending callbacks to close the socket transport
test_utils.run_briefly(self.loop)
@mock.patch('asyncio.selector_events.ssl', None) @mock.patch('asyncio.selector_events.ssl', None)
@mock.patch('asyncio.sslproto.ssl', None) @mock.patch('asyncio.sslproto.ssl', None)
def test_make_ssl_transport_without_ssl_error(self): def test_make_ssl_transport_without_ssl_error(self):

File diff suppressed because it is too large Load diff

View file

@ -15,6 +15,7 @@ import asyncio
from asyncio import log from asyncio import log
from asyncio import protocols from asyncio import protocols
from asyncio import sslproto from asyncio import sslproto
from test import support
from test.test_asyncio import utils as test_utils from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests from test.test_asyncio import functional as func_tests
@ -43,13 +44,16 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def connection_made(self, ssl_proto, *, do_handshake=None): def connection_made(self, ssl_proto, *, do_handshake=None):
transport = mock.Mock() transport = mock.Mock()
sslobj = mock.Mock() sslpipe = mock.Mock()
# emulate reading decompressed data sslpipe.shutdown.return_value = b''
sslobj.read.side_effect = ssl.SSLWantReadError if do_handshake:
if do_handshake is not None: sslpipe.do_handshake.side_effect = do_handshake
sslobj.do_handshake = do_handshake else:
ssl_proto._sslobj = sslobj def mock_handshake(callback):
ssl_proto.connection_made(transport) return []
sslpipe.do_handshake.side_effect = mock_handshake
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
ssl_proto.connection_made(transport)
return transport return transport
def test_handshake_timeout_zero(self): def test_handshake_timeout_zero(self):
@ -71,10 +75,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def test_eof_received_waiter(self): def test_eof_received_waiter(self):
waiter = self.loop.create_future() waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter) ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made( self.connection_made(ssl_proto)
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
ssl_proto.eof_received() ssl_proto.eof_received()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError) self.assertIsInstance(waiter.exception(), ConnectionResetError)
@ -99,10 +100,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
# yield from waiter hang if lost_connection was called. # yield from waiter hang if lost_connection was called.
waiter = self.loop.create_future() waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter) ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made( self.connection_made(ssl_proto)
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
ssl_proto.connection_lost(ConnectionAbortedError) ssl_proto.connection_lost(ConnectionAbortedError)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError) self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
@ -112,10 +110,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
waiter = self.loop.create_future() waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter) ssl_proto = self.ssl_protocol(waiter=waiter)
transport = self.connection_made( transport = self.connection_made(ssl_proto)
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
ssl_proto._app_transport.close() ssl_proto._app_transport.close()
@ -148,7 +143,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
transp.close() transp.close()
# should not raise # should not raise
self.assertIsNone(ssl_proto.buffer_updated(5)) self.assertIsNone(ssl_proto.data_received(b'data'))
def test_write_after_closing(self): def test_write_after_closing(self):
ssl_proto = self.ssl_protocol() ssl_proto = self.ssl_protocol()

View file

@ -1,2 +0,0 @@
Reimplement SSL/TLS support in asyncio, borrow the impelementation from
uvloop library.