asyncio: Add support for UNIX Domain Sockets.

This commit is contained in:
Yury Selivanov 2014-02-18 12:15:06 -05:00
parent c36e504c53
commit 88a5bf0b2e
10 changed files with 750 additions and 205 deletions

View file

@ -4,12 +4,18 @@ import collections
import contextlib
import io
import os
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import unittest.mock
from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
try:
import ssl
except ImportError: # pragma: no cover
@ -70,42 +76,51 @@ def run_once(loop):
loop.run_forever()
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
class SilentWSGIRequestHandler(WSGIRequestHandler):
class SilentWSGIRequestHandler(WSGIRequestHandler):
def get_stderr(self):
return io.StringIO()
def get_stderr(self):
return io.StringIO()
def log_message(self, format, *args):
def log_message(self, format, *args):
pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServerMixin:
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServer(SilentWSGIServer):
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
pass
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def app(environ, start_response):
status = '200 OK'
@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
httpd = make_server(host, port, app,
server_class, SilentWSGIRequestHandler)
server_class = server_ssl_cls if use_ssl else server_cls
httpd = server_class(address, SilentWSGIRequestHandler)
httpd.set_app(app)
httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start()
@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread.join()
if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
def server_bind(self):
socketserver.UnixStreamServer.server_bind(self)
self.server_name = '127.0.0.1'
self.server_port = 80
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
def server_bind(self):
UnixHTTPServer.server_bind(self)
self.setup_environ()
def get_request(self):
request, client_addr = super().get_request()
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
# However, this isn't true for UNIX sockets,
# as the second return value will be a path;
# hence we return some fake data sufficient
# to get the tests going
return request, ('127.0.0.1', '')
class SilentUnixWSGIServer(UnixWSGIServer):
def handle_error(self, request, client_address):
pass
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
pass
def gen_unix_socket_path():
with tempfile.NamedTemporaryFile() as file:
return file.name
@contextlib.contextmanager
def unix_socket_path():
path = gen_unix_socket_path()
try:
yield path
finally:
try:
os.unlink(path)
except OSError:
pass
@contextlib.contextmanager
def run_test_unix_server(*, use_ssl=False):
with unix_socket_path() as path:
yield from _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer)
def make_test_protocol(base):
dct = {}
for name in dir(base):
@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self):
pass
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)