mirror of
https://github.com/python/cpython.git
synced 2025-08-31 14:07:50 +00:00
asyncio: Add support for UNIX Domain Sockets.
This commit is contained in:
parent
c36e504c53
commit
88a5bf0b2e
10 changed files with 750 additions and 205 deletions
|
@ -4,12 +4,18 @@ import collections
|
|||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import socket
|
||||
import socketserver
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from http.server import HTTPServer
|
||||
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError: # pragma: no cover
|
||||
|
@ -70,42 +76,51 @@ def run_once(loop):
|
|||
loop.run_forever()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
||||
class SilentWSGIRequestHandler(WSGIRequestHandler):
|
||||
|
||||
class SilentWSGIRequestHandler(WSGIRequestHandler):
|
||||
def get_stderr(self):
|
||||
return io.StringIO()
|
||||
def get_stderr(self):
|
||||
return io.StringIO()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
|
||||
class SilentWSGIServer(WSGIServer):
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
|
||||
class SSLWSGIServerMixin:
|
||||
|
||||
def finish_request(self, request, client_address):
|
||||
# The relative location of our test directory (which
|
||||
# contains the ssl key and certificate files) differs
|
||||
# between the stdlib and stand-alone asyncio.
|
||||
# Prefer our own if we can find it.
|
||||
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
|
||||
if not os.path.isdir(here):
|
||||
here = os.path.join(os.path.dirname(os.__file__),
|
||||
'test', 'test_asyncio')
|
||||
keyfile = os.path.join(here, 'ssl_key.pem')
|
||||
certfile = os.path.join(here, 'ssl_cert.pem')
|
||||
ssock = ssl.wrap_socket(request,
|
||||
keyfile=keyfile,
|
||||
certfile=certfile,
|
||||
server_side=True)
|
||||
try:
|
||||
self.RequestHandlerClass(ssock, client_address, self)
|
||||
ssock.close()
|
||||
except OSError:
|
||||
# maybe socket has been closed by peer
|
||||
pass
|
||||
|
||||
class SilentWSGIServer(WSGIServer):
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
class SSLWSGIServer(SilentWSGIServer):
|
||||
def finish_request(self, request, client_address):
|
||||
# The relative location of our test directory (which
|
||||
# contains the ssl key and certificate files) differs
|
||||
# between the stdlib and stand-alone asyncio.
|
||||
# Prefer our own if we can find it.
|
||||
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
|
||||
if not os.path.isdir(here):
|
||||
here = os.path.join(os.path.dirname(os.__file__),
|
||||
'test', 'test_asyncio')
|
||||
keyfile = os.path.join(here, 'ssl_key.pem')
|
||||
certfile = os.path.join(here, 'ssl_cert.pem')
|
||||
ssock = ssl.wrap_socket(request,
|
||||
keyfile=keyfile,
|
||||
certfile=certfile,
|
||||
server_side=True)
|
||||
try:
|
||||
self.RequestHandlerClass(ssock, client_address, self)
|
||||
ssock.close()
|
||||
except OSError:
|
||||
# maybe socket has been closed by peer
|
||||
pass
|
||||
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
|
||||
pass
|
||||
|
||||
|
||||
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
|
||||
|
||||
def app(environ, start_response):
|
||||
status = '200 OK'
|
||||
|
@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
|||
|
||||
# Run the test WSGI server in a separate thread in order not to
|
||||
# interfere with event handling in the main thread
|
||||
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
|
||||
httpd = make_server(host, port, app,
|
||||
server_class, SilentWSGIRequestHandler)
|
||||
server_class = server_ssl_cls if use_ssl else server_cls
|
||||
httpd = server_class(address, SilentWSGIRequestHandler)
|
||||
httpd.set_app(app)
|
||||
httpd.address = httpd.server_address
|
||||
server_thread = threading.Thread(target=httpd.serve_forever)
|
||||
server_thread.start()
|
||||
|
@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
|||
server_thread.join()
|
||||
|
||||
|
||||
if hasattr(socket, 'AF_UNIX'):
|
||||
|
||||
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
|
||||
|
||||
def server_bind(self):
|
||||
socketserver.UnixStreamServer.server_bind(self)
|
||||
self.server_name = '127.0.0.1'
|
||||
self.server_port = 80
|
||||
|
||||
|
||||
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
|
||||
|
||||
def server_bind(self):
|
||||
UnixHTTPServer.server_bind(self)
|
||||
self.setup_environ()
|
||||
|
||||
def get_request(self):
|
||||
request, client_addr = super().get_request()
|
||||
# Code in the stdlib expects that get_request
|
||||
# will return a socket and a tuple (host, port).
|
||||
# However, this isn't true for UNIX sockets,
|
||||
# as the second return value will be a path;
|
||||
# hence we return some fake data sufficient
|
||||
# to get the tests going
|
||||
return request, ('127.0.0.1', '')
|
||||
|
||||
|
||||
class SilentUnixWSGIServer(UnixWSGIServer):
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
|
||||
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
|
||||
pass
|
||||
|
||||
|
||||
def gen_unix_socket_path():
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
return file.name
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unix_socket_path():
|
||||
path = gen_unix_socket_path()
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
try:
|
||||
os.unlink(path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_unix_server(*, use_ssl=False):
|
||||
with unix_socket_path() as path:
|
||||
yield from _run_test_server(address=path, use_ssl=use_ssl,
|
||||
server_cls=SilentUnixWSGIServer,
|
||||
server_ssl_cls=UnixSSLWSGIServer)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
||||
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
|
||||
server_cls=SilentWSGIServer,
|
||||
server_ssl_cls=SSLWSGIServer)
|
||||
|
||||
|
||||
def make_test_protocol(base):
|
||||
dct = {}
|
||||
for name in dir(base):
|
||||
|
@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
|
|||
def _write_to_self(self):
|
||||
pass
|
||||
|
||||
|
||||
def MockCallback(**kwargs):
|
||||
return unittest.mock.Mock(spec=['__call__'], **kwargs)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue