# Mirror of https://github.com/python/cpython.git
# (synced 2025-11-04 11:49:12 +00:00).
#
# Commit note: To make sure there is no issue with code that is both
# Python 2 and 3 compatible, there are no plans to remove the module any
# sooner than Python 4 (unless the community moves to Python 3 solidly
# before then).
#
# File info: 307 lines, 10 KiB, Python.

"""
Test suite for socketserver.
"""
 | 
						|
 | 
						|
import _imp as imp
 | 
						|
import contextlib
 | 
						|
import os
 | 
						|
import select
 | 
						|
import signal
 | 
						|
import socket
 | 
						|
import select
 | 
						|
import errno
 | 
						|
import tempfile
 | 
						|
import unittest
 | 
						|
import socketserver
 | 
						|
 | 
						|
import test.support
 | 
						|
from test.support import reap_children, reap_threads, verbose
 | 
						|
try:
 | 
						|
    import threading
 | 
						|
except ImportError:
 | 
						|
    threading = None
 | 
						|
 | 
						|
# Every test below talks to real sockets, so require the "network"
# resource before anything else is set up.
test.support.requires("network")

# Payload each client sends and expects to receive back unchanged.
TEST_STR = b"hello world\n"
# Host used for AF_INET servers (see SocketServerTest.pickaddr).
HOST = test.support.HOST

# Platform capabilities used to decide which server variants can be tested.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = hasattr(os, "fork")
 | 
						|
 | 
						|
def signal_alarm(n):
    """Call signal.alarm when it exists (i.e. not on Windows).

    *n* is passed straight through to signal.alarm().  When the signal
    module has no ``alarm`` attribute the call is silently skipped.
    Always returns None.
    """
    if hasattr(signal, 'alarm'):
        signal.alarm(n)
 | 
						|
 | 
						|
# Remember real select() to avoid interferences with mocking
_real_select = select.select


def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for *sock* to become readable, using
    the select() implementation captured at import time so that a mocked
    select.select does not affect the client side of the tests.

    Raises RuntimeError if *sock* is not readable within *timeout*.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
 | 
						|
 | 
						|
if HAVE_UNIX_SOCKETS:
    # These combinations are only defined here when AF_UNIX is available.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""
 | 
						|
 | 
						|
 | 
						|
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Context manager that keeps a forked child alive-or-exited around a block.

    A child process is forked on entry and exits immediately with status
    72.  On exit from the ``with`` block the child is waited for; when the
    block completed normally, *testcase* asserts that the waited pid is the
    forked child and that its wait status encodes exit code 72.

    The child is reaped even when the block raises, so a failing test does
    not leave it behind; in that case the assertions are skipped so they
    cannot mask the original exception.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    # A normal exit is reported by waitpid() as exit_code << 8.
    testcase.assertEqual(72 << 8, status)
 | 
						|
 | 
						|
 | 
						|
@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        """Arm the deadlock watchdog and reset per-test bookkeeping."""
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        """Disarm the watchdog, reap children and remove socket files."""
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                # Deliberately best effort: a recorded path that cannot
                # be removed must not fail the test.
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address suitable for binding a server of family *proto*.

        For AF_INET this is ``(HOST, 0)`` so the OS picks the port.  For
        any other family a fresh temporary file name is returned and
        recorded in ``self.test_files`` for removal in tearDown().
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = tempfile.mktemp(prefix='unix_socket.', dir=None)
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create and return a line-echoing *svrcls* server bound to *addr*.

        The handler, derived from *hdlrbase*, reads one line from the
        request and writes it back.  The server's ``server_address`` is
        checked against the bound socket's name before returning.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                # Release the request and the listening socket, then
                # re-raise the exception that is currently being handled.
                self.close_request(request)
                self.server_close()
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose:
            print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve *svrcls* on a background thread and exercise it.

        *testfunc* is called three times with ``(address_family, addr)``;
        afterwards the server is shut down, its thread joined and its
        socket closed.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose:
            print("server running")
        for i in range(3):
            if verbose:
                print("test client", i)
            testfunc(svrcls.address_family, addr)
        if verbose:
            print("waiting for server")
        server.shutdown()
        t.join()
        server.server_close()
        if verbose:
            print("done")

    def _read_reply(self, sock):
        """Read from *sock* until a newline is seen or no more data arrives."""
        buf = data = receive(sock, 100)
        while data and b'\n' not in buf:
            data = receive(sock, 100)
            buf += data
        return buf

    def stream_examine(self, proto, addr):
        """Connect to the stream server at *addr* and check it echoes TEST_STR."""
        # The with-block closes the socket even if an assertion fails.
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            self.assertEqual(self._read_reply(s), TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR to the datagram server at *addr* and check the echo."""
        # The with-block closes the socket even if an assertion fails.
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            s.sendto(TEST_STR, addr)
            self.assertEqual(self._read_reply(s), TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @unittest.skipUnless(HAVE_FORKING, 'os.fork() required for this test.')
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @unittest.skipUnless(HAVE_UNIX_SOCKETS,
                         'Unix sockets required for this test.')
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @unittest.skipUnless(HAVE_UNIX_SOCKETS,
                         'Unix sockets required for this test.')
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @unittest.skipUnless(HAVE_UNIX_SOCKETS,
                         'Unix sockets required for this test.')
    @unittest.skipUnless(HAVE_FORKING, 'os.fork() required for this test.')
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @unittest.skipUnless(HAVE_FORKING, 'os.fork() required for this test.')
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @contextlib.contextmanager
    def mocked_select_module(self):
        """Mocks the select.select() call to raise EINTR for first call"""
        old_select = select.select

        class MockSelect:
            def __init__(self):
                # Number of times the mock has been invoked.
                self.called = 0

            def __call__(self, *args):
                self.called += 1
                if self.called == 1:
                    # raise the exception on first call
                    raise OSError(errno.EINTR, os.strerror(errno.EINTR))
                else:
                    # Return real select value for consecutive calls
                    return old_select(*args)

        select.select = MockSelect()
        try:
            yield select.select
        finally:
            # Always restore the real function, even if the test failed.
            select.select = old_select

    def test_InterruptServerSelectCall(self):
        with self.mocked_select_module() as mock_select:
            self.run_server(socketserver.TCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)
            # Make sure select was called again:
            self.assertGreater(mock_select.called, 1)

    # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
    # client address so this cannot work:

    # if HAVE_UNIX_SOCKETS:
    #     def test_UnixDatagramServer(self):
    #         self.run_server(socketserver.UnixDatagramServer,
    #                         socketserver.DatagramRequestHandler,
    #                         self.dgram_examine)
    #
    #     def test_ThreadingUnixDatagramServer(self):
    #         self.run_server(socketserver.ThreadingUnixDatagramServer,
    #                         socketserver.DatagramRequestHandler,
    #                         self.dgram_examine)
    #
    #     if HAVE_FORKING:
    #         def test_ForkingUnixDatagramServer(self):
    #             self.run_server(socketserver.ForkingUnixDatagramServer,
    #                             socketserver.DatagramRequestHandler,
    #                             self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        # Request shutdown right after starting each serving thread.
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()
 | 
						|
 | 
						|
 | 
						|
def test_main():
    """Run SocketServerTest through test.support.run_unittest().

    Raises unittest.SkipTest when the import lock is currently held.
    """
    if imp.lock_held():
        # If the import lock is held, the threads will hang
        raise unittest.SkipTest("can't run when import lock is held")

    test.support.run_unittest(SocketServerTest)
 | 
						|
 | 
						|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_main()
 |