asyncio: Skip getaddrinfo if host is already resolved.

getaddrinfo takes an exclusive lock on some platforms, causing clients to queue
up waiting for the lock if many names are being resolved concurrently. Users
may want to handle name resolution in their own code, for the sake of caching,
using an alternate resolver, or to measure DNS duration separately from
connection duration. Skip getaddrinfo if the "host" passed into
create_connection is already resolved.

See https://github.com/python/asyncio/pull/302 for details.

Patch by A. Jesse Jiryu Davis.
This commit is contained in:
Yury Selivanov 2015-12-16 19:31:17 -05:00
parent 8c084eb77d
commit d5c2a62100
7 changed files with 283 additions and 67 deletions

View file

@ -32,6 +32,120 @@ MOCK_ANY = mock.ANY
PY34 = sys.version_info >= (3, 4)
def mock_socket_module():
m_socket = mock.MagicMock(spec=socket)
for name in (
'AF_INET', 'AF_INET6', 'AF_UNSPEC', 'IPPROTO_TCP', 'IPPROTO_UDP',
'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton'
):
if hasattr(socket, name):
setattr(m_socket, name, getattr(socket, name))
else:
delattr(m_socket, name)
m_socket.socket = mock.MagicMock()
m_socket.socket.return_value = test_utils.mock_nonblocking_socket()
return m_socket
def patch_socket(f):
return mock.patch('asyncio.base_events.socket',
new_callable=mock_socket_module)(f)
class BaseEventTests(test_utils.TestCase):
def setUp(self):
super().setUp()
base_events._ipaddr_info.cache_clear()
def tearDown(self):
base_events._ipaddr_info.cache_clear()
super().tearDown()
def test_ipaddr_info(self):
UNSPEC = socket.AF_UNSPEC
INET = socket.AF_INET
INET6 = socket.AF_INET6
STREAM = socket.SOCK_STREAM
DGRAM = socket.SOCK_DGRAM
TCP = socket.IPPROTO_TCP
UDP = socket.IPPROTO_UDP
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP))
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP))
self.assertEqual(
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, UDP))
# Socket type STREAM implies TCP protocol.
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, 0))
# Socket type DGRAM implies UDP protocol.
self.assertEqual(
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, 0))
# No socket type.
self.assertIsNone(
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, 0, 0))
# IPv4 address with family IPv6.
self.assertIsNone(
base_events._ipaddr_info('1.2.3.4', 1, INET6, STREAM, TCP))
self.assertEqual(
(INET6, STREAM, TCP, '', ('::3', 1)),
base_events._ipaddr_info('::3', 1, INET6, STREAM, TCP))
self.assertEqual(
(INET6, STREAM, TCP, '', ('::3', 1)),
base_events._ipaddr_info('::3', 1, UNSPEC, STREAM, TCP))
# IPv6 address with family IPv4.
self.assertIsNone(
base_events._ipaddr_info('::3', 1, INET, STREAM, TCP))
# IPv6 address with zone index.
self.assertEqual(
(INET6, STREAM, TCP, '', ('::3%lo0', 1)),
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
@patch_socket
def test_ipaddr_info_no_inet_pton(self, m_socket):
del m_socket.inet_pton
self.test_ipaddr_info()
def test_check_resolved_address(self):
sock = socket.socket(socket.AF_INET)
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
sock = socket.socket(socket.AF_INET6)
base_events._check_resolved_address(sock, ('::3', 1))
base_events._check_resolved_address(sock, ('::3%lo0', 1))
self.assertRaises(ValueError,
base_events._check_resolved_address, sock, ('foo', 1))
def test_check_resolved_sock_type(self):
# Ensure we ignore extra flags in sock.type.
if hasattr(socket, 'SOCK_NONBLOCK'):
sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
if hasattr(socket, 'SOCK_CLOEXEC'):
sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
class BaseEventLoopTests(test_utils.TestCase):
def setUp(self):
@ -875,7 +989,12 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)
@mock.patch('asyncio.base_events.socket')
def tearDown(self):
# Clear mocked constants like AF_INET from the cache.
base_events._ipaddr_info.cache_clear()
super().tearDown()
@patch_socket
def test_create_connection_multiple_errors(self, m_socket):
class MyProto(asyncio.Protocol):
@ -908,7 +1027,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_connection_timeout(self, m_socket):
# Ensure that the socket is closed on timeout
sock = mock.Mock()
@ -986,7 +1105,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
with self.assertRaises(OSError):
self.loop.run_until_complete(coro)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_connection_multiple_errors_local_addr(self, m_socket):
def bind(addr):
@ -1018,6 +1137,46 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
self.assertTrue(m_socket.socket.return_value.close.called)
def _test_create_connection_ip_addr(self, m_socket, allow_inet_pton):
# Test the fallback code, even if this system has inet_pton.
if not allow_inet_pton:
del m_socket.inet_pton
def getaddrinfo(*args, **kw):
self.fail('should not have called getaddrinfo')
m_socket.getaddrinfo = getaddrinfo
sock = m_socket.socket.return_value
self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
self.loop.add_writer = mock.Mock()
self.loop.add_writer._is_coroutine = False
coro = self.loop.create_connection(MyProto, '1.2.3.4', 80)
self.loop.run_until_complete(coro)
sock.connect.assert_called_with(('1.2.3.4', 80))
m_socket.socket.assert_called_with(family=m_socket.AF_INET,
proto=m_socket.IPPROTO_TCP,
type=m_socket.SOCK_STREAM)
sock.family = socket.AF_INET6
coro = self.loop.create_connection(MyProto, '::2', 80)
self.loop.run_until_complete(coro)
sock.connect.assert_called_with(('::2', 80))
m_socket.socket.assert_called_with(family=m_socket.AF_INET6,
proto=m_socket.IPPROTO_TCP,
type=m_socket.SOCK_STREAM)
@patch_socket
def test_create_connection_ip_addr(self, m_socket):
self._test_create_connection_ip_addr(m_socket, True)
@patch_socket
def test_create_connection_no_inet_pton(self, m_socket):
self._test_create_connection_ip_addr(m_socket, False)
def test_create_connection_no_local_addr(self):
@asyncio.coroutine
def getaddrinfo(host, *args, **kw):
@ -1153,11 +1312,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
f = self.loop.create_server(MyProto, '0.0.0.0', 0)
self.assertRaises(OSError, self.loop.run_until_complete, f)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_server_nosoreuseport(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo
m_socket.SOCK_STREAM = socket.SOCK_STREAM
m_socket.SOL_SOCKET = socket.SOL_SOCKET
del m_socket.SO_REUSEPORT
m_socket.socket.return_value = mock.Mock()
@ -1166,7 +1323,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(ValueError, self.loop.run_until_complete, f)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_server_cant_bind(self, m_socket):
class Err(OSError):
@ -1182,7 +1339,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(OSError, self.loop.run_until_complete, fut)
self.assertTrue(m_sock.close.called)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
m_socket.getaddrinfo.return_value = []
m_socket.getaddrinfo._is_coroutine = False
@ -1211,7 +1368,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(
OSError, self.loop.run_until_complete, coro)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_datagram_endpoint_socket_err(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo
m_socket.socket.side_effect = OSError
@ -1234,7 +1391,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(
ValueError, self.loop.run_until_complete, coro)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_datagram_endpoint_setblk_err(self, m_socket):
m_socket.socket.return_value.setblocking.side_effect = OSError
@ -1250,12 +1407,11 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
asyncio.DatagramProtocol)
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_datagram_endpoint_cant_bind(self, m_socket):
class Err(OSError):
pass
m_socket.AF_INET6 = socket.AF_INET6
m_socket.getaddrinfo = socket.getaddrinfo
m_sock = m_socket.socket.return_value = mock.Mock()
m_sock.bind.side_effect = Err
@ -1369,11 +1525,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop.run_until_complete(protocol.done)
self.assertEqual('CLOSED', protocol.state)
@mock.patch('asyncio.base_events.socket')
@patch_socket
def test_create_datagram_endpoint_nosoreuseport(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo
m_socket.SOCK_DGRAM = socket.SOCK_DGRAM
m_socket.SOL_SOCKET = socket.SOL_SOCKET
del m_socket.SO_REUSEPORT
m_socket.socket.return_value = mock.Mock()
@ -1385,6 +1538,29 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@patch_socket
def test_create_datagram_endpoint_ip_addr(self, m_socket):
def getaddrinfo(*args, **kw):
self.fail('should not have called getaddrinfo')
m_socket.getaddrinfo = getaddrinfo
m_socket.socket.return_value.bind = bind = mock.Mock()
self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
local_addr=('1.2.3.4', 0),
reuse_address=False,
reuse_port=reuseport_supported)
self.loop.run_until_complete(coro)
bind.assert_called_with(('1.2.3.4', 0))
m_socket.socket.assert_called_with(family=m_socket.AF_INET,
proto=m_socket.IPPROTO_UDP,
type=m_socket.SOCK_DGRAM)
def test_accept_connection_retry(self):
sock = mock.Mock()
sock.accept.side_effect = BlockingIOError()

View file

@ -1573,10 +1573,6 @@ class EventLoopTestsMixin:
'selector': self.loop._selector.__class__.__name__})
def test_sock_connect_address(self):
# In debug mode, sock_connect() must ensure that the address is already
# resolved (call _check_resolved_address())
self.loop.set_debug(True)
addresses = [(socket.AF_INET, ('www.python.org', 80))]
if support.IPV6_ENABLED:
addresses.extend((

View file

@ -436,7 +436,7 @@ class ProactorSocketTransportTests(test_utils.TestCase):
class BaseProactorEventLoopTests(test_utils.TestCase):
def setUp(self):
self.sock = mock.Mock(socket.socket)
self.sock = test_utils.mock_nonblocking_socket()
self.proactor = mock.Mock()
self.ssock, self.csock = mock.Mock(), mock.Mock()
@ -491,8 +491,8 @@ class BaseProactorEventLoopTests(test_utils.TestCase):
self.proactor.send.assert_called_with(self.sock, b'data')
def test_sock_connect(self):
self.loop.sock_connect(self.sock, 123)
self.proactor.connect.assert_called_with(self.sock, 123)
self.loop.sock_connect(self.sock, ('1.2.3.4', 123))
self.proactor.connect.assert_called_with(self.sock, ('1.2.3.4', 123))
def test_sock_accept(self):
self.loop.sock_accept(self.sock)