mirror of
https://github.com/python/cpython.git
synced 2025-07-24 11:44:31 +00:00
bpo-46805: Add low level UDP socket functions to asyncio (GH-31455)
This commit is contained in:
parent
7e473e94a5
commit
9f04ee569c
12 changed files with 489 additions and 7 deletions
|
@ -5,11 +5,11 @@ import unittest
|
|||
|
||||
from asyncio import proactor_events
|
||||
from itertools import cycle, islice
|
||||
from unittest.mock import patch, Mock
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from test import support
|
||||
from test.support import socket_helper
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
@ -380,6 +380,79 @@ class BaseSockTestsMixin:
|
|||
self.loop.run_until_complete(
|
||||
self._basetest_huge_content_recvinto(httpd.address))
|
||||
|
||||
async def _basetest_datagram_recvfrom(self, server_address):
|
||||
# Happy path, sock.sendto() returns immediately
|
||||
data = b'\x01' * 4096
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.setblocking(False)
|
||||
await self.loop.sock_sendto(sock, data, server_address)
|
||||
received_data, from_addr = await self.loop.sock_recvfrom(
|
||||
sock, 4096)
|
||||
self.assertEqual(received_data, data)
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
def test_recvfrom(self):
|
||||
with test_utils.run_udp_echo_server() as server_address:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_recvfrom(server_address))
|
||||
|
||||
async def _basetest_datagram_recvfrom_into(self, server_address):
|
||||
# Happy path, sock.sendto() returns immediately
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.setblocking(False)
|
||||
|
||||
buf = bytearray(4096)
|
||||
data = b'\x01' * 4096
|
||||
await self.loop.sock_sendto(sock, data, server_address)
|
||||
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
|
||||
sock, buf)
|
||||
self.assertEqual(num_bytes, 4096)
|
||||
self.assertEqual(buf, data)
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
buf = bytearray(8192)
|
||||
await self.loop.sock_sendto(sock, data, server_address)
|
||||
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
|
||||
sock, buf, 4096)
|
||||
self.assertEqual(num_bytes, 4096)
|
||||
self.assertEqual(buf[:4096], data[:4096])
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
def test_recvfrom_into(self):
|
||||
with test_utils.run_udp_echo_server() as server_address:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_recvfrom_into(server_address))
|
||||
|
||||
async def _basetest_datagram_sendto_blocking(self, server_address):
|
||||
# Sad path, sock.sendto() raises BlockingIOError
|
||||
# This involves patching sock.sendto() to raise BlockingIOError but
|
||||
# sendto() is not used by the proactor event loop
|
||||
data = b'\x01' * 4096
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.setblocking(False)
|
||||
mock_sock = Mock(sock)
|
||||
mock_sock.gettimeout = sock.gettimeout
|
||||
mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
|
||||
mock_sock.fileno = sock.fileno
|
||||
self.loop.call_soon(
|
||||
lambda: setattr(mock_sock, 'sendto', sock.sendto)
|
||||
)
|
||||
await self.loop.sock_sendto(mock_sock, data, server_address)
|
||||
|
||||
received_data, from_addr = await self.loop.sock_recvfrom(
|
||||
sock, 4096)
|
||||
self.assertEqual(received_data, data)
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
def test_sendto_blocking(self):
|
||||
if sys.platform == 'win32':
|
||||
if isinstance(self.loop, asyncio.ProactorEventLoop):
|
||||
raise unittest.SkipTest('Not relevant to ProactorEventLoop')
|
||||
|
||||
with test_utils.run_udp_echo_server() as server_address:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_sendto_blocking(server_address))
|
||||
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
def test_unix_sock_client_ops(self):
|
||||
with test_utils.run_test_unix_server() as httpd:
|
||||
|
|
|
@ -281,6 +281,31 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
|||
server_ssl_cls=SSLWSGIServer)
|
||||
|
||||
|
||||
def echo_datagrams(sock):
|
||||
while True:
|
||||
data, addr = sock.recvfrom(4096)
|
||||
if data == b'STOP':
|
||||
sock.close()
|
||||
break
|
||||
else:
|
||||
sock.sendto(data, addr)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_udp_echo_server(*, host='127.0.0.1', port=0):
|
||||
addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
|
||||
family, type, proto, _, sockaddr = addr_info[0]
|
||||
sock = socket.socket(family, type, proto)
|
||||
sock.bind((host, port))
|
||||
thread = threading.Thread(target=lambda: echo_datagrams(sock))
|
||||
thread.start()
|
||||
try:
|
||||
yield sock.getsockname()
|
||||
finally:
|
||||
sock.sendto(b'STOP', sock.getsockname())
|
||||
thread.join()
|
||||
|
||||
|
||||
def make_test_protocol(base):
|
||||
dct = {}
|
||||
for name in dir(base):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue