mirror of
https://github.com/python/cpython.git
synced 2025-10-10 00:43:41 +00:00
bpo-29970: Add timeout for SSL handshake in asyncio
10 seconds by default.
This commit is contained in:
parent
4b965930e8
commit
f7686c1f55
12 changed files with 207 additions and 83 deletions
|
@ -29,6 +29,7 @@ import sys
|
|||
import warnings
|
||||
import weakref
|
||||
|
||||
from . import constants
|
||||
from . import coroutines
|
||||
from . import events
|
||||
from . import futures
|
||||
|
@ -275,9 +276,11 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
"""Create socket transport."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None):
|
||||
def _make_ssl_transport(
|
||||
self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Create SSL transport."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -635,10 +638,12 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
return await self.run_in_executor(
|
||||
None, socket.getnameinfo, sockaddr, flags)
|
||||
|
||||
async def create_connection(self, protocol_factory, host=None, port=None,
|
||||
*, ssl=None, family=0,
|
||||
proto=0, flags=0, sock=None,
|
||||
local_addr=None, server_hostname=None):
|
||||
async def create_connection(
|
||||
self, protocol_factory, host=None, port=None,
|
||||
*, ssl=None, family=0,
|
||||
proto=0, flags=0, sock=None,
|
||||
local_addr=None, server_hostname=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Connect to a TCP server.
|
||||
|
||||
Create a streaming transport connection to a given Internet host and
|
||||
|
@ -751,7 +756,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
f'A Stream Socket was expected, got {sock!r}')
|
||||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, server_hostname)
|
||||
sock, protocol_factory, ssl, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
if self._debug:
|
||||
# Get the socket from the transport because SSL transport closes
|
||||
# the old socket and creates a new SSL socket
|
||||
|
@ -760,8 +766,10 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
sock, host, port, transport, protocol)
|
||||
return transport, protocol
|
||||
|
||||
async def _create_connection_transport(self, sock, protocol_factory, ssl,
|
||||
server_hostname, server_side=False):
|
||||
async def _create_connection_transport(
|
||||
self, sock, protocol_factory, ssl,
|
||||
server_hostname, server_side=False,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
|
||||
sock.setblocking(False)
|
||||
|
||||
|
@ -771,7 +779,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
sslcontext = None if isinstance(ssl, bool) else ssl
|
||||
transport = self._make_ssl_transport(
|
||||
sock, protocol, sslcontext, waiter,
|
||||
server_side=server_side, server_hostname=server_hostname)
|
||||
server_side=server_side, server_hostname=server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
else:
|
||||
transport = self._make_socket_transport(sock, protocol, waiter)
|
||||
|
||||
|
@ -929,15 +938,17 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
raise OSError(f'getaddrinfo({host!r}) returned empty list')
|
||||
return infos
|
||||
|
||||
async def create_server(self, protocol_factory, host=None, port=None,
|
||||
*,
|
||||
family=socket.AF_UNSPEC,
|
||||
flags=socket.AI_PASSIVE,
|
||||
sock=None,
|
||||
backlog=100,
|
||||
ssl=None,
|
||||
reuse_address=None,
|
||||
reuse_port=None):
|
||||
async def create_server(
|
||||
self, protocol_factory, host=None, port=None,
|
||||
*,
|
||||
family=socket.AF_UNSPEC,
|
||||
flags=socket.AI_PASSIVE,
|
||||
sock=None,
|
||||
backlog=100,
|
||||
ssl=None,
|
||||
reuse_address=None,
|
||||
reuse_port=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Create a TCP server.
|
||||
|
||||
The host parameter can be a string, in that case the TCP server is
|
||||
|
@ -1026,13 +1037,16 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
for sock in sockets:
|
||||
sock.listen(backlog)
|
||||
sock.setblocking(False)
|
||||
self._start_serving(protocol_factory, sock, ssl, server, backlog)
|
||||
self._start_serving(protocol_factory, sock, ssl, server, backlog,
|
||||
ssl_handshake_timeout)
|
||||
if self._debug:
|
||||
logger.info("%r is serving", server)
|
||||
return server
|
||||
|
||||
async def connect_accepted_socket(self, protocol_factory, sock,
|
||||
*, ssl=None):
|
||||
async def connect_accepted_socket(
|
||||
self, protocol_factory, sock,
|
||||
*, ssl=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Handle an accepted connection.
|
||||
|
||||
This is used by servers that accept connections outside of
|
||||
|
@ -1045,7 +1059,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
|
||||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, '', server_side=True)
|
||||
sock, protocol_factory, ssl, '', server_side=True,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
if self._debug:
|
||||
# Get the socket from the transport because SSL transport closes
|
||||
# the old socket and creates a new SSL socket
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue