asyncio: Better-looking errors when ssl module cannot be imported. In part by Arnaud Faure.

This commit is contained in:
Guido van Rossum 2013-11-01 14:22:30 -07:00
parent a8d630a6e6
commit 28dff0d823
3 changed files with 41 additions and 12 deletions

View file

@ -466,6 +466,8 @@ class BaseEventLoop(events.AbstractEventLoop):
ssl=None, ssl=None,
reuse_address=None): reuse_address=None):
"""XXX""" """XXX"""
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
if host is not None or port is not None: if host is not None or port is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(

View file

@ -90,12 +90,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
except (BlockingIOError, InterruptedError): except (BlockingIOError, InterruptedError):
pass pass
def _start_serving(self, protocol_factory, sock, ssl=None, server=None): def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None):
self.add_reader(sock.fileno(), self._accept_connection, self.add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, ssl, server) protocol_factory, sock, sslcontext, server)
def _accept_connection(self, protocol_factory, sock, ssl=None, def _accept_connection(self, protocol_factory, sock,
server=None): sslcontext=None, server=None):
try: try:
conn, addr = sock.accept() conn, addr = sock.accept()
conn.setblocking(False) conn.setblocking(False)
@ -113,13 +114,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self.remove_reader(sock.fileno()) self.remove_reader(sock.fileno())
self.call_later(constants.ACCEPT_RETRY_DELAY, self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving, self._start_serving,
protocol_factory, sock, ssl, server) protocol_factory, sock, sslcontext, server)
else: else:
raise # The event loop will catch, log and ignore it. raise # The event loop will catch, log and ignore it.
else: else:
if ssl: if sslcontext:
self._make_ssl_transport( self._make_ssl_transport(
conn, protocol_factory(), ssl, None, conn, protocol_factory(), sslcontext, None,
server_side=True, extra={'peername': addr}, server=server) server_side=True, extra={'peername': addr}, server=server)
else: else:
self._make_socket_transport( self._make_socket_transport(
@ -558,17 +559,23 @@ class _SelectorSslTransport(_SelectorTransport):
def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
server_side=False, server_hostname=None, server_side=False, server_hostname=None,
extra=None, server=None): extra=None, server=None):
if ssl is None:
raise RuntimeError('stdlib ssl module not available')
if server_side: if server_side:
assert isinstance( if not sslcontext:
sslcontext, ssl.SSLContext), 'Must pass an SSLContext' raise ValueError('Server side ssl needs a valid SSLContext')
else: else:
# Client-side may pass ssl=True to use a default context. if not sslcontext:
# The default is the same as used by urllib. # Client side may pass ssl=True to use a default
if sslcontext is None: # context; in that case the sslcontext passed is None.
# The default is the same as used by urllib with
# cadefault=True.
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.set_default_verify_paths() sslcontext.set_default_verify_paths()
sslcontext.verify_mode = ssl.CERT_REQUIRED sslcontext.verify_mode = ssl.CERT_REQUIRED
wrap_kwargs = { wrap_kwargs = {
'server_side': server_side, 'server_side': server_side,
'do_handshake_on_connect': False, 'do_handshake_on_connect': False,

View file

@ -43,6 +43,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.assertIsInstance( self.assertIsInstance(
self.loop._make_socket_transport(m, m), _SelectorSocketTransport) self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
@unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self): def test_make_ssl_transport(self):
m = unittest.mock.Mock() m = unittest.mock.Mock()
self.loop.add_reader = unittest.mock.Mock() self.loop.add_reader = unittest.mock.Mock()
@ -52,6 +53,16 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.assertIsInstance( self.assertIsInstance(
self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None)
def test_make_ssl_transport_without_ssl_error(self):
m = unittest.mock.Mock()
self.loop.add_reader = unittest.mock.Mock()
self.loop.add_writer = unittest.mock.Mock()
self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock()
with self.assertRaises(RuntimeError):
self.loop._make_ssl_transport(m, m, m, m)
def test_close(self): def test_close(self):
ssock = self.loop._ssock ssock = self.loop._ssock
ssock.fileno.return_value = 7 ssock.fileno.return_value = 7
@ -1277,6 +1288,15 @@ class SelectorSslTransportTests(unittest.TestCase):
server_hostname='localhost') server_hostname='localhost')
class SelectorSslWithoutSslTransportTests(unittest.TestCase):
@unittest.mock.patch('asyncio.selector_events.ssl', None)
def test_ssl_transport_requires_ssl_module(self):
Mock = unittest.mock.Mock
with self.assertRaises(RuntimeError):
transport = _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
class SelectorDatagramTransportTests(unittest.TestCase): class SelectorDatagramTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):