mirror of
				https://github.com/python/cpython.git
				synced 2025-10-26 08:19:20 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			281 lines
		
	
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			281 lines
		
	
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import asyncio
 | |
| import asyncio.events
 | |
| import contextlib
 | |
| import os
 | |
| import pprint
 | |
| import select
 | |
| import socket
 | |
| import tempfile
 | |
| import threading
 | |
| 
 | |
| 
 | |
class FunctionalTestCaseMixin:
    """Mixin for TestCase classes that drive a private asyncio event loop
    against real sockets handled by helper threads.

    setUp() creates a fresh loop (without installing it as the current
    loop) and records every call to the loop's exception handler;
    tearDown() closes the loop and fails the test if anything was recorded.
    """

    def new_loop(self):
        """Return the event loop used by the test (override to customise)."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run the test loop until an ``asyncio.sleep(delay)`` completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        # Remember the context so tearDown() can fail the test, then let the
        # loop's default handler report it as usual.
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)

        # Create the record before installing the handler that appends to it.
        self.__unhandled_exceptions = []
        self.loop.set_exception_handler(self.loop_exception_handler)

        # Replace `_get_running_loop` with a stub returning None; the
        # original is restored in tearDown().
        self._old_get_running_loop = asyncio.events._get_running_loop
        asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.events._get_running_loop = self._old_get_running_loop
            asyncio.set_event_loop(None)
            self.loop = None

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Bind and listen on a stream socket; return an unstarted
        TestThreadedServer that runs *server_prog* for each client.

        If *addr* is None, a fresh temporary path is used for AF_UNIX,
        otherwise ``('127.0.0.1', 0)``.

        Raises RuntimeError if *timeout* is None or not positive (only
        blocking sockets with a timeout are supported). Any error while
        setting up the socket closes it before propagating.
        """
        # Validate first so that a bad argument cannot leak a socket.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Borrow a unique temp path; the file is removed when the
                # with-block exits, leaving the name free for bind().
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            sock.bind(addr)
            sock.listen(backlog)
            return TestThreadedServer(
                self, sock, server_prog, timeout, max_clients)
        except BaseException:
            sock.close()
            raise

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Create an unconnected stream socket and return an unstarted
        TestThreadedClient that runs *client_prog* with it.

        Raises RuntimeError if *timeout* is None or not positive.
        """
        # Validate first so that a bad argument cannot leak a socket.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            return TestThreadedClient(
                self, sock, client_prog, timeout)
        except BaseException:
            sock.close()
            raise

    def unix_server(self, *args, **kwargs):
        """tcp_server() with ``family=AF_UNIX``.

        Raises NotImplementedError where AF_UNIX is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """tcp_client() with ``family=AF_UNIX``.

        Raises NotImplementedError where AF_UNIX is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a path named ``sock`` inside a new temporary directory.

        The path (if it was created) and the directory are removed on exit.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                # Best effort: the socket file may never have been created.
                with contextlib.suppress(OSError):
                    os.unlink(fn)

    def _abort_socket_test(self, ex):
        """Stop the test loop and fail the test with *ex*.

        Called from the helper threads, so the stop request goes through
        call_soon_threadsafe(): unlike a direct loop.stop(), that is safe
        from another thread and wakes a loop blocked waiting for I/O.
        """
        try:
            loop = self.loop
            if loop is not None and not loop.is_closed():
                loop.call_soon_threadsafe(loop.stop)
        finally:
            self.fail(ex)
 | |
| 
 | |
| 
 | |
| ##############################################################################
 | |
| # Socket Testing Utilities
 | |
| ##############################################################################
 | |
| 
 | |
| 
 | |
class TestSocketWrapper:
    """Proxy around a blocking socket handed to client/server test programs.

    Attributes not defined here are forwarded to the wrapped socket, so the
    wrapper can be used like the socket itself. start_tls() replaces the
    wrapped socket with an SSL-wrapped one in place.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if the peer closes the connection
        before *n* bytes have arrived.
        """
        # A bytearray avoids the quadratic cost of repeated bytes concatenation.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Wrap the current socket with *ssl_context* and do the handshake.

        On success the wrapper proxies the new SSL socket. The previous
        socket object is closed whether or not the handshake succeeds
        (presumably the SSL socket has taken over its file descriptor; see
        ``ssl.SSLContext.wrap_socket``). On handshake failure the SSL socket
        is closed and the exception propagates.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only called when normal lookup fails. If the wrapped-socket slot
        # itself is missing (e.g. an instance built without __init__),
        # raise AttributeError instead of recursing into __getattr__ forever.
        if name == '_TestSocketWrapper__sock':
            raise AttributeError(name)
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
 | |
| 
 | |
| 
 | |
class SocketThread(threading.Thread):
    """Base for the socket helper threads; usable as a context manager.

    Entering the ``with`` block starts the thread. Leaving it (or calling
    stop() directly) clears the ``_active`` flag, which subclasses may poll,
    and then waits for the thread to finish.
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()

    def stop(self):
        # Ask the run loop to wind down, then block until the thread exits.
        self._active = False
        self.join()
 | |
| 
 | |
| 
 | |
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a client program against a socket.

    The program receives the socket wrapped in a TestSocketWrapper. Any
    Exception it raises is passed to the owning test's
    ``_abort_socket_test()``.
    """

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client', daemon=True)

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._active = True

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except Exception as exc:
            self._test._abort_socket_test(exc)
 | |
| 
 | |
| 
 | |
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections on a listening socket and
    runs a server program for each one, sequentially.

    The listening socket is switched to non-blocking mode and polled with
    select(). stop() writes to an internal socket pair so that the select()
    wakes up immediately instead of waiting for the timeout.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        super().__init__(name='test-server', daemon=True)

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._max_clients = max_clients

        # _clients counts accepted connections (see _run()).
        self._clients = 0
        self._finished_clients = 0

        self._active = True

        # Wake-up channel: stop() sends on _s2, _run() selects on _s1.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

    def stop(self):
        try:
            waker = self._s2
            # fileno() is -1 once run() has closed the pair.
            if waker and waker.fileno() != -1:
                with contextlib.suppress(OSError):
                    waker.send(b'stop')
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in readable:
                # stop() asked us to shut down.
                return
            if self._sock not in readable:
                # select() timed out; loop round and re-check _active.
                continue

            try:
                conn, _peer = self._sock.accept()
            except BlockingIOError:
                continue
            except socket.timeout:
                if self._active:
                    raise
                return

            self._clients += 1
            conn.settimeout(self._timeout)
            try:
                with conn:
                    self._handle_client(conn)
            except Exception as exc:
                self._active = False
                # Re-raise the client's error, but always notify the test
                # case on the way out.
                try:
                    raise
                finally:
                    self._test._abort_socket_test(exc)

    def _handle_client(self, sock):
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """The address the listening socket is bound to."""
        return self._sock.getsockname()
 | 
