granian/tests/conftest.py
2025-05-07 20:26:20 +02:00

111 lines
3 KiB
Python

import asyncio
import multiprocessing as mp
import socket
from contextlib import asynccontextmanager, closing
from functools import partial
from pathlib import Path
import pytest
from granian import Granian
def _serve(**kwargs):
    """Child-process entry point: run Granian on the ``tests.apps.<interface>`` app.

    Every keyword argument (``interface`` included) is forwarded unchanged to the
    ``Granian`` constructor; ``interface`` additionally selects the app module.
    """
    target = f'tests.apps.{kwargs["interface"]}:app'
    Granian(target, **kwargs).serve()
def _stop_process(proc, timeout=2):
    """Stop *proc*: ``terminate()``, escalate to ``kill()`` if still alive, then reap it.

    The ``join`` after ``kill`` reaps the child so it does not linger as a zombie.
    """
    proc.terminate()
    proc.join(timeout=timeout)
    if proc.is_alive():
        proc.kill()
        # Reap the killed child so it is not left behind as a zombie process.
        proc.join(timeout=timeout)


def _accepts_connections(port, timeout=1):
    """Return True if a TCP connection to ``127.0.0.1:port`` can be opened.

    This is a blocking call made from async code; on loopback a refused
    connection typically fails immediately, so at worst it stalls for *timeout*.
    """
    try:
        with socket.create_connection(('127.0.0.1', port), timeout=timeout):
            return True
    except OSError:
        # Refused connections, timeouts (socket.timeout is an OSError subclass) and
        # resolution errors all mean "not ready yet"; anything else should surface.
        return False


async def _wait_until_listening(proc, port, attempts=3, delay=1.5):
    """Poll *port* every *delay* seconds; return False once *proc* dies or attempts run out."""
    for _ in range(attempts):
        await asyncio.sleep(delay)
        if _accepts_connections(port):
            return True
        if not proc.is_alive():
            # The child already exited (presumably it failed to start); more polling cannot help.
            return False
    return False


@asynccontextmanager
async def _server(interface, port, runtime_mode, tls=False, task_impl='asyncio', static_mount=False):
    """Run a Granian test server in a spawned child process for the duration of the context.

    Args:
        interface: forwarded to Granian; ``_serve`` also uses it to pick the
            ``tests.apps.<interface>`` application module.
        port: TCP port the server is started on and probed at on 127.0.0.1.
        runtime_mode: forwarded unchanged to Granian.
        tls: falsy disables TLS; ``'private'`` uses ``pcert.pem``/``pkey.pem`` with a
            key password; any other truthy value uses ``cert.pem``/``key.pem``.
        task_impl: forwarded unchanged to Granian.
        static_mount: when truthy, passes ``tests/fixtures`` as ``static_path_mount``.

    Fixture paths are resolved from the current working directory (presumably the
    repository root when running pytest — confirm).

    Yields:
        ``port``, once the server accepts TCP connections.

    Raises:
        RuntimeError: if no spawned server accepts connections after 3 attempts.
    """
    certs_path = Path.cwd() / 'tests' / 'fixtures' / 'tls'
    kwargs = {
        'interface': interface,
        'port': port,
        'loop': 'asyncio',
        'blocking_threads': 1,
        'runtime_mode': runtime_mode,
        'task_impl': task_impl,
    }
    if tls == 'private':
        kwargs['ssl_cert'] = certs_path / 'pcert.pem'
        kwargs['ssl_key'] = certs_path / 'pkey.pem'
        kwargs['ssl_key_password'] = 'foobar'  # noqa: S105
    elif tls:
        kwargs['ssl_cert'] = certs_path / 'cert.pem'
        kwargs['ssl_key'] = certs_path / 'key.pem'
    if static_mount:
        kwargs['static_path_mount'] = Path.cwd() / 'tests' / 'fixtures'

    ctx = mp.get_context('spawn')
    for _ in range(3):
        proc = ctx.Process(target=_serve, kwargs=kwargs)
        proc.start()
        try:
            ready = await _wait_until_listening(proc, port)
        except BaseException:
            # Cancellation (or any error) while waiting must not leak the child process.
            _stop_process(proc)
            raise
        if ready:
            break
        _stop_process(proc)
    else:
        # Loop finished without a successful start.
        raise RuntimeError('Cannot bind server')

    try:
        yield port
    finally:
        _stop_process(proc)
@pytest.fixture(scope='function')
def server_port():
    """Return a TCP port number the OS reported as free on 127.0.0.1.

    The probe socket is closed before the number is returned, so the port is not
    reserved: another process could still claim it before the server binds.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        # Port 0 asks the OS to choose an unused port; bind the same address the
        # readiness probe connects to. (SO_REUSEADDR was previously set *after*
        # bind on a socket that never listens, which had no effect, so it is gone.)
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]
@pytest.fixture(scope='function')
def asgi_server(server_port):
    """Factory fixture: ``_server`` pre-bound to the ASGI test app and a free port.

    Call the result with the remaining ``_server`` arguments (``runtime_mode`` and
    optionally ``tls``/``task_impl``/``static_mount``) and use it as an async
    context manager.
    """
    # pytest only injects named parameters, so the former ``**extras`` was always
    # empty; the signature now matches the sibling ``rsgi_server``/``wsgi_server``.
    return partial(_server, 'asgi', server_port)
@pytest.fixture(scope='function')
def rsgi_server(server_port):
    """Factory fixture: ``_server`` pre-bound to the RSGI test app and a free port."""
    start_rsgi = partial(_server, 'rsgi', server_port)
    return start_rsgi
@pytest.fixture(scope='function')
def wsgi_server(server_port):
    """Factory fixture: ``_server`` pre-bound to the WSGI test app and a free port."""
    start_wsgi = partial(_server, 'wsgi', server_port)
    return start_wsgi
@pytest.fixture(scope='function')
def server(server_port, request):
    """Factory fixture for the interface named by ``request.param``.

    Presumably supplied through (indirect) parametrization of this fixture.
    """
    interface = request.param
    return partial(_server, interface, server_port)
@pytest.fixture(scope='function')
def server_tls(server_port, request):
    """Like ``server`` but with ``tls=True`` (the ``cert.pem``/``key.pem`` pair)."""
    interface = request.param
    return partial(_server, interface, server_port, tls=True)
@pytest.fixture(scope='function')
def server_static_files(server_port, request):
    """Like ``server`` but with ``static_mount=True`` (serves ``tests/fixtures``)."""
    interface = request.param
    return partial(_server, interface, server_port, static_mount=True)