mirror of
https://github.com/emmett-framework/granian.git
synced 2025-07-07 19:35:33 +00:00
111 lines
3 KiB
Python
111 lines
3 KiB
Python
import asyncio
|
|
import multiprocessing as mp
|
|
import socket
|
|
from contextlib import asynccontextmanager, closing
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from granian import Granian
|
|
|
|
|
|
def _serve(**kwargs):
    """Build and run a Granian server for ``tests.apps.<interface>:app``.

    Used as the ``Process`` target in ``_server``. All of ``kwargs``
    (including ``interface``) are forwarded to ``Granian``.
    """
    interface = kwargs['interface']
    target = f'tests.apps.{interface}:app'
    Granian(target, **kwargs).serve()
|
|
|
|
|
|
def _server_kwargs(interface, port, runtime_mode, tls, task_impl, static_mount):
    """Build the ``Granian`` keyword arguments used by ``_server``.

    ``tls == 'private'`` selects the password-protected pair
    (``pcert.pem``/``pkey.pem``); any other truthy ``tls`` selects
    ``cert.pem``/``key.pem``. A truthy ``static_mount`` mounts
    ``tests/fixtures`` via ``static_path_mount``. Paths are resolved
    against the current working directory.
    """
    fixtures_path = Path.cwd() / 'tests' / 'fixtures'
    certs_path = fixtures_path / 'tls'
    kwargs = {
        'interface': interface,
        'port': port,
        'loop': 'asyncio',
        'blocking_threads': 1,
        'runtime_mode': runtime_mode,
        'task_impl': task_impl,
    }
    if tls == 'private':
        kwargs['ssl_cert'] = certs_path / 'pcert.pem'
        kwargs['ssl_key'] = certs_path / 'pkey.pem'
        kwargs['ssl_key_password'] = 'foobar'  # noqa: S105
    elif tls:
        kwargs['ssl_cert'] = certs_path / 'cert.pem'
        kwargs['ssl_key'] = certs_path / 'key.pem'
    if static_mount:
        kwargs['static_path_mount'] = fixtures_path
    return kwargs


def _stop_process(proc, timeout=2):
    """Stop ``proc``: ``terminate()``, then ``kill()`` if still alive after ``timeout`` s."""
    proc.terminate()
    proc.join(timeout=timeout)
    if proc.is_alive():
        proc.kill()
        # Reap the killed child so it is not left behind unjoined.
        proc.join(timeout=timeout)


def _port_accepts_connections(port):
    """Return whether a TCP connection to ``127.0.0.1:port`` succeeds within 1 s."""
    try:
        with socket.create_connection(('127.0.0.1', port), timeout=1):
            return True
    except OSError:
        return False


async def _wait_for_server(proc, port, attempts=3, delay=1.5):
    """Probe ``port`` up to ``attempts`` times, ``delay`` s apart; True once it accepts.

    Gives up early if ``proc`` has exited: anything still answering on the
    port would then not be the server under test.
    """
    for _ in range(attempts):
        await asyncio.sleep(delay)
        if not proc.is_alive():
            return False
        if _port_accepts_connections(port):
            return True
    return False


@asynccontextmanager
async def _server(interface, port, runtime_mode, tls=False, task_impl='asyncio', static_mount=False):
    """Run a Granian test server in a spawned process for the body of ``async with``.

    Up to 3 server processes are started in turn. Each one gets up to 3
    connection probes, 1.5 s apart, before it is stopped and the next
    attempt begins.

    Args:
        interface: app interface name; selects ``tests.apps.<interface>:app``.
        port: TCP port the server is configured with and probed on.
        runtime_mode: forwarded to ``Granian`` as ``runtime_mode``.
        tls: falsy for plain TCP, ``'private'`` for the password-protected
            key pair, any other truthy value for the default pair.
        task_impl: forwarded to ``Granian`` as ``task_impl``.
        static_mount: when truthy, mount ``tests/fixtures`` as static files.

    Yields:
        ``port``, once the server accepts connections.

    Raises:
        RuntimeError: if no attempt produced a server accepting connections.
    """
    kwargs = _server_kwargs(interface, port, runtime_mode, tls, task_impl, static_mount)
    ctx = mp.get_context('spawn')
    proc = None
    # One try/finally covers startup too, so cancellation or an interrupt
    # while waiting for the server still stops the child process.
    try:
        for _ in range(3):
            candidate = ctx.Process(target=_serve, kwargs=kwargs)
            candidate.start()
            # Only track a process that actually started; terminate() on an
            # unstarted Process would fail and mask the original error.
            proc = candidate
            if await _wait_for_server(proc, port):
                break
            _stop_process(proc)
            proc = None
        if proc is None:
            raise RuntimeError('Cannot bind server')
        yield port
    finally:
        if proc is not None:
            _stop_process(proc)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def server_port():
    """Return a TCP port on 127.0.0.1 that was free when the fixture ran.

    The OS assigns an ephemeral port to a throwaway socket, which is closed
    before the number is returned. Nothing reserves the port afterwards, so
    another process could in principle take it before the server binds.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        # Bind the same address ``_server`` probes, so the port is known to be
        # free there. SO_REUSEADDR is not set: it only influences bind() when
        # set before it, and a socket that never listens or connects leaves
        # no TIME_WAIT state behind.
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]
|
|
|
|
|
|
@pytest.fixture(scope='function')
def asgi_server(server_port, **extras):
    """Provide a ``_server`` factory preset to the ASGI interface and ``server_port``.

    Any ``extras`` are forwarded to ``_server`` as keyword arguments.
    """
    # NOTE(review): pytest presumably never injects ``**kwargs`` into a fixture,
    # so ``extras`` is likely always empty; confirm, and consider dropping it to
    # match ``rsgi_server`` / ``wsgi_server``.
    preset = ('asgi', server_port)
    return partial(_server, *preset, **extras)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def rsgi_server(server_port):
    """Provide a ``_server`` factory preset to the RSGI interface and ``server_port``."""
    preset = ('rsgi', server_port)
    return partial(_server, *preset)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def wsgi_server(server_port):
    """Provide a ``_server`` factory preset to the WSGI interface and ``server_port``."""
    preset = ('wsgi', server_port)
    return partial(_server, *preset)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def server(server_port, request):
    """Provide a ``_server`` factory for the interface named by ``request.param``.

    Meant for indirect parametrization; each param is presumably an
    interface name such as ``'asgi'``, ``'rsgi'`` or ``'wsgi'``.
    """
    interface = request.param
    return partial(_server, interface, server_port)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def server_tls(server_port, request):
    """Provide a TLS-enabled ``_server`` factory for the interface in ``request.param``.

    Passes ``tls=True``, which selects the default ``cert.pem``/``key.pem``
    pair rather than the password-protected ``'private'`` one.
    """
    interface = request.param
    return partial(_server, interface, server_port, tls=True)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def server_static_files(server_port, request):
    """Provide a ``_server`` factory with static files mounted, for ``request.param``.

    Passes ``static_mount=True``, which mounts ``tests/fixtures`` as the
    server's ``static_path_mount``.
    """
    interface = request.param
    return partial(_server, interface, server_port, static_mount=True)
|