asyncio: Add server_hostname as create_connection() argument, with secure default.

This commit is contained in:
Guido van Rossum 2013-11-01 14:16:54 -07:00
parent 2b430b8720
commit 21c85a7124
4 changed files with 78 additions and 5 deletions

View file

@ -444,6 +444,60 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
self.assertRaises(
OSError, self.loop.run_until_complete, coro)
def test_create_connection_server_hostname_default(self):
self.loop.getaddrinfo = unittest.mock.Mock()
def mock_getaddrinfo(*args, **kwds):
f = futures.Future(loop=self.loop)
f.set_result([(socket.AF_INET, socket.SOCK_STREAM,
socket.SOL_TCP, '', ('1.2.3.4', 80))])
return f
self.loop.getaddrinfo.side_effect = mock_getaddrinfo
self.loop.sock_connect = unittest.mock.Mock()
self.loop.sock_connect.return_value = ()
self.loop._make_ssl_transport = unittest.mock.Mock()
def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds):
waiter.set_result(None)
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = unittest.mock.ANY
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
self.loop.run_until_complete(coro)
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org')
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com')
self.loop.run_until_complete(coro)
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com')
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
server_hostname='')
self.loop.run_until_complete(coro)
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='')
def test_create_connection_server_hostname_errors(self):
# When not using ssl, server_hostname must be None (but '' is OK).
coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='')
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='python.org')
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
# When using ssl, server_hostname may be None if host is non-empty.
coro = self.loop.create_connection(MyProto, '', 80, ssl=True)
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
coro = self.loop.create_connection(MyProto, None, 80, ssl=True)
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
coro = self.loop.create_connection(MyProto, None, None, ssl=True, sock=socket.socket())
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
def test_create_server_empty_host(self):
# if host is empty string use None instead
host = object()