mirror of
https://github.com/python/cpython.git
synced 2025-08-04 00:48:58 +00:00
asyncio: Skip getaddrinfo if host is already resolved.
getaddrinfo takes an exclusive lock on some platforms, causing clients to queue up waiting for the lock if many names are being resolved concurrently. Users may want to handle name resolution in their own code, for the sake of caching, using an alternate resolver, or to measure DNS duration separately from connection duration. Skip getaddrinfo if the "host" passed into create_connection is already resolved. See https://github.com/python/asyncio/pull/302 for details. Patch by A. Jesse Jiryu Davis.
This commit is contained in:
parent
8c084eb77d
commit
d5c2a62100
7 changed files with 283 additions and 67 deletions
|
@ -32,6 +32,120 @@ MOCK_ANY = mock.ANY
|
|||
PY34 = sys.version_info >= (3, 4)
|
||||
|
||||
|
||||
def mock_socket_module():
|
||||
m_socket = mock.MagicMock(spec=socket)
|
||||
for name in (
|
||||
'AF_INET', 'AF_INET6', 'AF_UNSPEC', 'IPPROTO_TCP', 'IPPROTO_UDP',
|
||||
'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton'
|
||||
):
|
||||
if hasattr(socket, name):
|
||||
setattr(m_socket, name, getattr(socket, name))
|
||||
else:
|
||||
delattr(m_socket, name)
|
||||
|
||||
m_socket.socket = mock.MagicMock()
|
||||
m_socket.socket.return_value = test_utils.mock_nonblocking_socket()
|
||||
|
||||
return m_socket
|
||||
|
||||
|
||||
def patch_socket(f):
|
||||
return mock.patch('asyncio.base_events.socket',
|
||||
new_callable=mock_socket_module)(f)
|
||||
|
||||
|
||||
class BaseEventTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
base_events._ipaddr_info.cache_clear()
|
||||
|
||||
def tearDown(self):
|
||||
base_events._ipaddr_info.cache_clear()
|
||||
super().tearDown()
|
||||
|
||||
def test_ipaddr_info(self):
|
||||
UNSPEC = socket.AF_UNSPEC
|
||||
INET = socket.AF_INET
|
||||
INET6 = socket.AF_INET6
|
||||
STREAM = socket.SOCK_STREAM
|
||||
DGRAM = socket.SOCK_DGRAM
|
||||
TCP = socket.IPPROTO_TCP
|
||||
UDP = socket.IPPROTO_UDP
|
||||
|
||||
self.assertEqual(
|
||||
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
|
||||
base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP))
|
||||
|
||||
self.assertEqual(
|
||||
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
|
||||
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP))
|
||||
|
||||
self.assertEqual(
|
||||
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
|
||||
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, UDP))
|
||||
|
||||
# Socket type STREAM implies TCP protocol.
|
||||
self.assertEqual(
|
||||
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
|
||||
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, 0))
|
||||
|
||||
# Socket type DGRAM implies UDP protocol.
|
||||
self.assertEqual(
|
||||
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
|
||||
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, 0))
|
||||
|
||||
# No socket type.
|
||||
self.assertIsNone(
|
||||
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, 0, 0))
|
||||
|
||||
# IPv4 address with family IPv6.
|
||||
self.assertIsNone(
|
||||
base_events._ipaddr_info('1.2.3.4', 1, INET6, STREAM, TCP))
|
||||
|
||||
self.assertEqual(
|
||||
(INET6, STREAM, TCP, '', ('::3', 1)),
|
||||
base_events._ipaddr_info('::3', 1, INET6, STREAM, TCP))
|
||||
|
||||
self.assertEqual(
|
||||
(INET6, STREAM, TCP, '', ('::3', 1)),
|
||||
base_events._ipaddr_info('::3', 1, UNSPEC, STREAM, TCP))
|
||||
|
||||
# IPv6 address with family IPv4.
|
||||
self.assertIsNone(
|
||||
base_events._ipaddr_info('::3', 1, INET, STREAM, TCP))
|
||||
|
||||
# IPv6 address with zone index.
|
||||
self.assertEqual(
|
||||
(INET6, STREAM, TCP, '', ('::3%lo0', 1)),
|
||||
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
|
||||
|
||||
@patch_socket
|
||||
def test_ipaddr_info_no_inet_pton(self, m_socket):
|
||||
del m_socket.inet_pton
|
||||
self.test_ipaddr_info()
|
||||
|
||||
def test_check_resolved_address(self):
|
||||
sock = socket.socket(socket.AF_INET)
|
||||
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
|
||||
|
||||
sock = socket.socket(socket.AF_INET6)
|
||||
base_events._check_resolved_address(sock, ('::3', 1))
|
||||
base_events._check_resolved_address(sock, ('::3%lo0', 1))
|
||||
self.assertRaises(ValueError,
|
||||
base_events._check_resolved_address, sock, ('foo', 1))
|
||||
|
||||
def test_check_resolved_sock_type(self):
|
||||
# Ensure we ignore extra flags in sock.type.
|
||||
if hasattr(socket, 'SOCK_NONBLOCK'):
|
||||
sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
|
||||
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
|
||||
|
||||
if hasattr(socket, 'SOCK_CLOEXEC'):
|
||||
sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
|
||||
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
|
||||
|
||||
|
||||
class BaseEventLoopTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -875,7 +989,12 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
self.loop = asyncio.new_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
def tearDown(self):
|
||||
# Clear mocked constants like AF_INET from the cache.
|
||||
base_events._ipaddr_info.cache_clear()
|
||||
super().tearDown()
|
||||
|
||||
@patch_socket
|
||||
def test_create_connection_multiple_errors(self, m_socket):
|
||||
|
||||
class MyProto(asyncio.Protocol):
|
||||
|
@ -908,7 +1027,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
|
||||
self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_connection_timeout(self, m_socket):
|
||||
# Ensure that the socket is closed on timeout
|
||||
sock = mock.Mock()
|
||||
|
@ -986,7 +1105,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
with self.assertRaises(OSError):
|
||||
self.loop.run_until_complete(coro)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_connection_multiple_errors_local_addr(self, m_socket):
|
||||
|
||||
def bind(addr):
|
||||
|
@ -1018,6 +1137,46 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
|
||||
self.assertTrue(m_socket.socket.return_value.close.called)
|
||||
|
||||
def _test_create_connection_ip_addr(self, m_socket, allow_inet_pton):
|
||||
# Test the fallback code, even if this system has inet_pton.
|
||||
if not allow_inet_pton:
|
||||
del m_socket.inet_pton
|
||||
|
||||
def getaddrinfo(*args, **kw):
|
||||
self.fail('should not have called getaddrinfo')
|
||||
|
||||
m_socket.getaddrinfo = getaddrinfo
|
||||
sock = m_socket.socket.return_value
|
||||
|
||||
self.loop.add_reader = mock.Mock()
|
||||
self.loop.add_reader._is_coroutine = False
|
||||
self.loop.add_writer = mock.Mock()
|
||||
self.loop.add_writer._is_coroutine = False
|
||||
|
||||
coro = self.loop.create_connection(MyProto, '1.2.3.4', 80)
|
||||
self.loop.run_until_complete(coro)
|
||||
sock.connect.assert_called_with(('1.2.3.4', 80))
|
||||
m_socket.socket.assert_called_with(family=m_socket.AF_INET,
|
||||
proto=m_socket.IPPROTO_TCP,
|
||||
type=m_socket.SOCK_STREAM)
|
||||
|
||||
sock.family = socket.AF_INET6
|
||||
coro = self.loop.create_connection(MyProto, '::2', 80)
|
||||
|
||||
self.loop.run_until_complete(coro)
|
||||
sock.connect.assert_called_with(('::2', 80))
|
||||
m_socket.socket.assert_called_with(family=m_socket.AF_INET6,
|
||||
proto=m_socket.IPPROTO_TCP,
|
||||
type=m_socket.SOCK_STREAM)
|
||||
|
||||
@patch_socket
|
||||
def test_create_connection_ip_addr(self, m_socket):
|
||||
self._test_create_connection_ip_addr(m_socket, True)
|
||||
|
||||
@patch_socket
|
||||
def test_create_connection_no_inet_pton(self, m_socket):
|
||||
self._test_create_connection_ip_addr(m_socket, False)
|
||||
|
||||
def test_create_connection_no_local_addr(self):
|
||||
@asyncio.coroutine
|
||||
def getaddrinfo(host, *args, **kw):
|
||||
|
@ -1153,11 +1312,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
f = self.loop.create_server(MyProto, '0.0.0.0', 0)
|
||||
self.assertRaises(OSError, self.loop.run_until_complete, f)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_server_nosoreuseport(self, m_socket):
|
||||
m_socket.getaddrinfo = socket.getaddrinfo
|
||||
m_socket.SOCK_STREAM = socket.SOCK_STREAM
|
||||
m_socket.SOL_SOCKET = socket.SOL_SOCKET
|
||||
del m_socket.SO_REUSEPORT
|
||||
m_socket.socket.return_value = mock.Mock()
|
||||
|
||||
|
@ -1166,7 +1323,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
|
||||
self.assertRaises(ValueError, self.loop.run_until_complete, f)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_server_cant_bind(self, m_socket):
|
||||
|
||||
class Err(OSError):
|
||||
|
@ -1182,7 +1339,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
self.assertRaises(OSError, self.loop.run_until_complete, fut)
|
||||
self.assertTrue(m_sock.close.called)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
|
||||
m_socket.getaddrinfo.return_value = []
|
||||
m_socket.getaddrinfo._is_coroutine = False
|
||||
|
@ -1211,7 +1368,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
self.assertRaises(
|
||||
OSError, self.loop.run_until_complete, coro)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_datagram_endpoint_socket_err(self, m_socket):
|
||||
m_socket.getaddrinfo = socket.getaddrinfo
|
||||
m_socket.socket.side_effect = OSError
|
||||
|
@ -1234,7 +1391,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
self.assertRaises(
|
||||
ValueError, self.loop.run_until_complete, coro)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_datagram_endpoint_setblk_err(self, m_socket):
|
||||
m_socket.socket.return_value.setblocking.side_effect = OSError
|
||||
|
||||
|
@ -1250,12 +1407,11 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
asyncio.DatagramProtocol)
|
||||
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_datagram_endpoint_cant_bind(self, m_socket):
|
||||
class Err(OSError):
|
||||
pass
|
||||
|
||||
m_socket.AF_INET6 = socket.AF_INET6
|
||||
m_socket.getaddrinfo = socket.getaddrinfo
|
||||
m_sock = m_socket.socket.return_value = mock.Mock()
|
||||
m_sock.bind.side_effect = Err
|
||||
|
@ -1369,11 +1525,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
self.loop.run_until_complete(protocol.done)
|
||||
self.assertEqual('CLOSED', protocol.state)
|
||||
|
||||
@mock.patch('asyncio.base_events.socket')
|
||||
@patch_socket
|
||||
def test_create_datagram_endpoint_nosoreuseport(self, m_socket):
|
||||
m_socket.getaddrinfo = socket.getaddrinfo
|
||||
m_socket.SOCK_DGRAM = socket.SOCK_DGRAM
|
||||
m_socket.SOL_SOCKET = socket.SOL_SOCKET
|
||||
del m_socket.SO_REUSEPORT
|
||||
m_socket.socket.return_value = mock.Mock()
|
||||
|
||||
|
@ -1385,6 +1538,29 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
|
||||
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
|
||||
|
||||
@patch_socket
|
||||
def test_create_datagram_endpoint_ip_addr(self, m_socket):
|
||||
def getaddrinfo(*args, **kw):
|
||||
self.fail('should not have called getaddrinfo')
|
||||
|
||||
m_socket.getaddrinfo = getaddrinfo
|
||||
m_socket.socket.return_value.bind = bind = mock.Mock()
|
||||
self.loop.add_reader = mock.Mock()
|
||||
self.loop.add_reader._is_coroutine = False
|
||||
|
||||
reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
|
||||
coro = self.loop.create_datagram_endpoint(
|
||||
lambda: MyDatagramProto(loop=self.loop),
|
||||
local_addr=('1.2.3.4', 0),
|
||||
reuse_address=False,
|
||||
reuse_port=reuseport_supported)
|
||||
|
||||
self.loop.run_until_complete(coro)
|
||||
bind.assert_called_with(('1.2.3.4', 0))
|
||||
m_socket.socket.assert_called_with(family=m_socket.AF_INET,
|
||||
proto=m_socket.IPPROTO_UDP,
|
||||
type=m_socket.SOCK_DGRAM)
|
||||
|
||||
def test_accept_connection_retry(self):
|
||||
sock = mock.Mock()
|
||||
sock.accept.side_effect = BlockingIOError()
|
||||
|
|
|
@ -1573,10 +1573,6 @@ class EventLoopTestsMixin:
|
|||
'selector': self.loop._selector.__class__.__name__})
|
||||
|
||||
def test_sock_connect_address(self):
|
||||
# In debug mode, sock_connect() must ensure that the address is already
|
||||
# resolved (call _check_resolved_address())
|
||||
self.loop.set_debug(True)
|
||||
|
||||
addresses = [(socket.AF_INET, ('www.python.org', 80))]
|
||||
if support.IPV6_ENABLED:
|
||||
addresses.extend((
|
||||
|
|
|
@ -436,7 +436,7 @@ class ProactorSocketTransportTests(test_utils.TestCase):
|
|||
class BaseProactorEventLoopTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.sock = mock.Mock(socket.socket)
|
||||
self.sock = test_utils.mock_nonblocking_socket()
|
||||
self.proactor = mock.Mock()
|
||||
|
||||
self.ssock, self.csock = mock.Mock(), mock.Mock()
|
||||
|
@ -491,8 +491,8 @@ class BaseProactorEventLoopTests(test_utils.TestCase):
|
|||
self.proactor.send.assert_called_with(self.sock, b'data')
|
||||
|
||||
def test_sock_connect(self):
|
||||
self.loop.sock_connect(self.sock, 123)
|
||||
self.proactor.connect.assert_called_with(self.sock, 123)
|
||||
self.loop.sock_connect(self.sock, ('1.2.3.4', 123))
|
||||
self.proactor.connect.assert_called_with(self.sock, ('1.2.3.4', 123))
|
||||
|
||||
def test_sock_accept(self):
|
||||
self.loop.sock_accept(self.sock)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue