mirror of
https://github.com/python/cpython.git
synced 2025-09-27 10:50:04 +00:00
asyncio: Add server_hostname as create_connection() argument, with secure default.
This commit is contained in:
parent
2b430b8720
commit
21c85a7124
4 changed files with 78 additions and 5 deletions
|
@ -275,8 +275,27 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
@tasks.coroutine
|
@tasks.coroutine
|
||||||
def create_connection(self, protocol_factory, host=None, port=None, *,
|
def create_connection(self, protocol_factory, host=None, port=None, *,
|
||||||
ssl=None, family=0, proto=0, flags=0, sock=None,
|
ssl=None, family=0, proto=0, flags=0, sock=None,
|
||||||
local_addr=None):
|
local_addr=None, server_hostname=None):
|
||||||
"""XXX"""
|
"""XXX"""
|
||||||
|
if server_hostname is not None and not ssl:
|
||||||
|
raise ValueError('server_hostname is only meaningful with ssl')
|
||||||
|
|
||||||
|
if server_hostname is None and ssl:
|
||||||
|
# Use host as default for server_hostname. It is an error
|
||||||
|
# if host is empty or not set, e.g. when an
|
||||||
|
# already-connected socket was passed or when only a port
|
||||||
|
# is given. To avoid this error, you can pass
|
||||||
|
# server_hostname='' -- this will bypass the hostname
|
||||||
|
# check. (This also means that if host is a numeric
|
||||||
|
# IP/IPv6 address, we will attempt to verify that exact
|
||||||
|
# address; this will probably fail, but it is possible to
|
||||||
|
# create a certificate for a specific IP address, so we
|
||||||
|
# don't judge it here.)
|
||||||
|
if not host:
|
||||||
|
raise ValueError('You must set server_hostname '
|
||||||
|
'when using ssl without a host')
|
||||||
|
server_hostname = host
|
||||||
|
|
||||||
if host is not None or port is not None:
|
if host is not None or port is not None:
|
||||||
if sock is not None:
|
if sock is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -357,7 +376,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
sslcontext = None if isinstance(ssl, bool) else ssl
|
sslcontext = None if isinstance(ssl, bool) else ssl
|
||||||
transport = self._make_ssl_transport(
|
transport = self._make_ssl_transport(
|
||||||
sock, protocol, sslcontext, waiter,
|
sock, protocol, sslcontext, waiter,
|
||||||
server_side=False, server_hostname=host)
|
server_side=False, server_hostname=server_hostname)
|
||||||
else:
|
else:
|
||||||
transport = self._make_socket_transport(sock, protocol, waiter)
|
transport = self._make_socket_transport(sock, protocol, waiter)
|
||||||
|
|
||||||
|
|
|
@ -172,7 +172,7 @@ class AbstractEventLoop:
|
||||||
|
|
||||||
def create_connection(self, protocol_factory, host=None, port=None, *,
|
def create_connection(self, protocol_factory, host=None, port=None, *,
|
||||||
ssl=None, family=0, proto=0, flags=0, sock=None,
|
ssl=None, family=0, proto=0, flags=0, sock=None,
|
||||||
local_addr=None):
|
local_addr=None, server_hostname=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def create_server(self, protocol_factory, host=None, port=None, *,
|
def create_server(self, protocol_factory, host=None, port=None, *,
|
||||||
|
|
|
@ -573,7 +573,7 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
'server_side': server_side,
|
'server_side': server_side,
|
||||||
'do_handshake_on_connect': False,
|
'do_handshake_on_connect': False,
|
||||||
}
|
}
|
||||||
if server_hostname is not None and not server_side and ssl.HAS_SNI:
|
if server_hostname and not server_side and ssl.HAS_SNI:
|
||||||
wrap_kwargs['server_hostname'] = server_hostname
|
wrap_kwargs['server_hostname'] = server_hostname
|
||||||
sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
|
sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
|
||||||
|
|
||||||
|
@ -619,7 +619,7 @@ class _SelectorSslTransport(_SelectorTransport):
|
||||||
|
|
||||||
# Verify hostname if requested.
|
# Verify hostname if requested.
|
||||||
peercert = self._sock.getpeercert()
|
peercert = self._sock.getpeercert()
|
||||||
if (self._server_hostname is not None and
|
if (self._server_hostname and
|
||||||
self._sslcontext.verify_mode != ssl.CERT_NONE):
|
self._sslcontext.verify_mode != ssl.CERT_NONE):
|
||||||
try:
|
try:
|
||||||
ssl.match_hostname(peercert, self._server_hostname)
|
ssl.match_hostname(peercert, self._server_hostname)
|
||||||
|
|
|
@ -444,6 +444,60 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
OSError, self.loop.run_until_complete, coro)
|
OSError, self.loop.run_until_complete, coro)
|
||||||
|
|
||||||
|
def test_create_connection_server_hostname_default(self):
|
||||||
|
self.loop.getaddrinfo = unittest.mock.Mock()
|
||||||
|
def mock_getaddrinfo(*args, **kwds):
|
||||||
|
f = futures.Future(loop=self.loop)
|
||||||
|
f.set_result([(socket.AF_INET, socket.SOCK_STREAM,
|
||||||
|
socket.SOL_TCP, '', ('1.2.3.4', 80))])
|
||||||
|
return f
|
||||||
|
self.loop.getaddrinfo.side_effect = mock_getaddrinfo
|
||||||
|
self.loop.sock_connect = unittest.mock.Mock()
|
||||||
|
self.loop.sock_connect.return_value = ()
|
||||||
|
self.loop._make_ssl_transport = unittest.mock.Mock()
|
||||||
|
def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds):
|
||||||
|
waiter.set_result(None)
|
||||||
|
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
|
||||||
|
ANY = unittest.mock.ANY
|
||||||
|
# First try the default server_hostname.
|
||||||
|
self.loop._make_ssl_transport.reset_mock()
|
||||||
|
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
|
||||||
|
self.loop.run_until_complete(coro)
|
||||||
|
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
|
||||||
|
server_side=False,
|
||||||
|
server_hostname='python.org')
|
||||||
|
# Next try an explicit server_hostname.
|
||||||
|
self.loop._make_ssl_transport.reset_mock()
|
||||||
|
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
|
||||||
|
server_hostname='perl.com')
|
||||||
|
self.loop.run_until_complete(coro)
|
||||||
|
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
|
||||||
|
server_side=False,
|
||||||
|
server_hostname='perl.com')
|
||||||
|
# Finally try an explicit empty server_hostname.
|
||||||
|
self.loop._make_ssl_transport.reset_mock()
|
||||||
|
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
|
||||||
|
server_hostname='')
|
||||||
|
self.loop.run_until_complete(coro)
|
||||||
|
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
|
||||||
|
server_side=False,
|
||||||
|
server_hostname='')
|
||||||
|
|
||||||
|
def test_create_connection_server_hostname_errors(self):
|
||||||
|
# When not using ssl, server_hostname must be None (but '' is OK).
|
||||||
|
coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='')
|
||||||
|
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
|
||||||
|
coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='python.org')
|
||||||
|
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
|
||||||
|
|
||||||
|
# When using ssl, server_hostname may be None if host is non-empty.
|
||||||
|
coro = self.loop.create_connection(MyProto, '', 80, ssl=True)
|
||||||
|
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
|
||||||
|
coro = self.loop.create_connection(MyProto, None, 80, ssl=True)
|
||||||
|
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
|
||||||
|
coro = self.loop.create_connection(MyProto, None, None, ssl=True, sock=socket.socket())
|
||||||
|
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
|
||||||
|
|
||||||
def test_create_server_empty_host(self):
|
def test_create_server_empty_host(self):
|
||||||
# if host is empty string use None instead
|
# if host is empty string use None instead
|
||||||
host = object()
|
host = object()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue