diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6de785a..138e7d3 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,18 +13,18 @@ jobs: lint: runs-on: ubuntu-latest + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ env.PYTHON_VERSION }} - uses: actions/setup-python@v5 + - uses: astral-sh/setup-uv@v5 with: - python-version: ${{ env.PYTHON_VERSION }} + enable-cache: false - name: Install run: | - python -m venv .venv - source .venv/bin/activate - pip install maturin - maturin develop --extras=lint + uv python install ${{ env.UV_PYTHON }} + uv venv .venv + uv sync --group lint - name: Lint run: | source .venv/bin/activate diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2d2bb81..7041c9a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,24 +21,25 @@ jobs: - '3.11' - '3.12' - '3.13' + - '3.13t' + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - uses: astral-sh/setup-uv@v5 with: - python-version: ${{ matrix.python-version }} - allow-prereleases: true + enable-cache: false - name: Install run: | - python -m venv .venv - source .venv/bin/activate - pip install maturin - maturin develop --extras=test + uv python install ${{ env.UV_PYTHON }} + uv venv .venv + uv sync --group all + uv run --no-sync maturin develop --uv - name: Test run: | source .venv/bin/activate - py.test -v tests + make test macos: runs-on: macos-latest @@ -51,24 +52,25 @@ jobs: - '3.11' - '3.12' - '3.13' + - '3.13t' + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - uses: astral-sh/setup-uv@v5 with: - python-version: ${{ matrix.python-version }} - allow-prereleases: true + enable-cache: false - name: Install run: | - python -m venv .venv - source .venv/bin/activate - pip install maturin - maturin develop --extras=test + uv python install ${{ env.UV_PYTHON }} + uv venv .venv + uv sync --group all + uv run --no-sync maturin develop --uv - name: Test run: | source .venv/bin/activate - py.test -v tests + make test windows: runs-on: windows-latest @@ -81,21 +83,21 @@ jobs: - '3.11' - '3.12' - '3.13' + - '3.13t' + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - uses: astral-sh/setup-uv@v5 with: - python-version: ${{ matrix.python-version }} - allow-prereleases: true + enable-cache: false - name: Install run: | - python -m venv venv - venv/Scripts/Activate.ps1 - pip install maturin - maturin develop --extras=test + uv python install ${{ env.UV_PYTHON }} + uv venv .venv + uv sync --group all + uv run --no-sync maturin develop --uv - name: Test run: | - venv/Scripts/Activate.ps1 - py.test -v tests + uv run --no-sync pytest -v tests diff --git a/Makefile b/Makefile index 304c0ee..37f342e 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,8 @@ pysources = granian tests .PHONY: build-dev build-dev: @rm -f granian/*.so - maturin develop --extras lint,test + uv sync --group all + maturin develop .PHONY: format format: diff --git a/granian/__init__.py b/granian/__init__.py index be65c04..074353f 100644 --- a/granian/__init__.py +++ b/granian/__init__.py @@ -1,2 +1,2 @@ from ._granian import __version__ # noqa: F401 -from .server import Granian as Granian +from .server import Server as Granian # noqa: F401 diff --git a/granian/_granian.pyi b/granian/_granian.pyi index e095f60..6321a75 100644 --- a/granian/_granian.pyi +++ b/granian/_granian.pyi @@ -5,6 +5,7 @@ from ._types import WebsocketMessage from .http import HTTP1Settings, HTTP2Settings __version__: str +BUILD_GIL: bool class RSGIHeaders: def __contains__(self, key: str) -> bool: ... diff --git a/granian/_internal.py b/granian/_internal.py index 89bee15..ba8f44e 100644 --- a/granian/_internal.py +++ b/granian/_internal.py @@ -41,7 +41,7 @@ def load_module(module_name: str, raise_on_failure: bool = True) -> Optional[Mod except ImportError: if sys.exc_info()[-1].tb_next: raise RuntimeError( - f"While importing '{module_name}', an ImportError was raised:" f'\n\n{traceback.format_exc()}' + f"While importing '{module_name}', an ImportError was raised:\n\n{traceback.format_exc()}" ) elif raise_on_failure: raise RuntimeError(f"Could not import '{module_name}'.") diff --git a/granian/_loops.py b/granian/_loops.py index 4e19cd8..c9be535 100644 --- a/granian/_loops.py +++ b/granian/_loops.py @@ -77,16 +77,14 @@ def build_asyncio_loop(): @loops.register('uvloop', packages=['uvloop']) def build_uv_loop(uvloop): - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - loop = asyncio.new_event_loop() + loop = uvloop.new_event_loop() asyncio.set_event_loop(loop) return loop @loops.register('rloop', packages=['rloop']) def build_rloop(rloop): - asyncio.set_event_loop_policy(rloop.EventLoopPolicy()) - loop = asyncio.new_event_loop() + loop = rloop.new_event_loop() asyncio.set_event_loop(loop) return loop diff --git a/granian/cli.py b/granian/cli.py index 882dda0..8211e0a 100644 --- a/granian/cli.py +++ b/granian/cli.py @@ -9,7 +9,7 @@ from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes from .errors import FatalError from .http import HTTP1Settings, HTTP2Settings from .log import LogLevels -from .server import Granian +from .server import Server _AnyCallable = Callable[..., Any] @@ -70,6 +70,11 @@ def option(*param_decls: str, cls: Optional[Type[click.Option]] = None, **attrs: type=click.IntRange(1), help='Number of blocking threads (per worker)', ) +@option( + '--io-blocking-threads', + type=click.IntRange(1), + help='Number of I/O blocking threads (per worker)', +) @option( '--threading-mode', type=EnumType(ThreadModes), @@ -265,6 +270,7 @@ def cli( workers: int, threads: int, blocking_threads: Optional[int], + io_blocking_threads: Optional[int], threading_mode: ThreadModes, loop: Loops, task_impl: TaskImpl, @@ -313,13 +319,14 @@ def cli( print('Unable to parse provided logging config.') raise click.exceptions.Exit(1) - server = Granian( + server = Server( app, address=host, port=port, interface=interface, workers=workers, threads=threads, + io_blocking_threads=io_blocking_threads, blocking_threads=blocking_threads, threading_mode=threading_mode, loop=loop, diff --git a/granian/server/__init__.py b/granian/server/__init__.py new file mode 100644 index 0000000..50cb725 --- /dev/null +++ b/granian/server/__init__.py @@ -0,0 +1,7 @@ +from .._granian import BUILD_GIL + + +if BUILD_GIL: + from .mp import MPServer as Server +else: + from .mt import MTServer as Server # noqa: F401 diff --git a/granian/server.py b/granian/server/common.py similarity index 57% rename from granian/server.py rename to granian/server/common.py index be3db0d..e271571 100644 --- a/granian/server.py +++ b/granian/server/common.py @@ -10,28 +10,25 @@ import threading import time from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type +from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Type, TypeVar -from ._futures import _future_watcher_wrapper, _new_cbscheduler -from ._granian import ASGIWorker, RSGIWorker, WSGIWorker -from ._imports import setproctitle, watchfiles -from ._internal import load_target -from ._signals import set_main_signals -from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap -from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes -from .errors import ConfigurationError, PidFileError -from .http import HTTP1Settings, HTTP2Settings -from .log import DEFAULT_ACCESSLOG_FMT, LogLevels, configure_logging, logger -from .net import SocketHolder -from .rsgi import _callback_wrapper as _rsgi_call_wrap -from .wsgi import _callback_wrapper as _wsgi_call_wrap +from .._imports import setproctitle, watchfiles +from .._internal import load_target +from .._signals import set_main_signals +from ..constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes +from ..errors import ConfigurationError, PidFileError +from ..http import HTTP1Settings, HTTP2Settings +from ..log import DEFAULT_ACCESSLOG_FMT, LogLevels, configure_logging, logger +from ..net import SocketHolder -multiprocessing.allow_connection_pickling() +WT = TypeVar('WT') -class Worker: - def __init__(self, parent: Granian, idx: int, target: Any, args: Any): +class AbstractWorker: + _idl = 'id' + + def __init__(self, parent: AbstractServer, idx: int, target: Any, args: Any): self.parent = parent self.idx = idx self.interrupt_by_parent = False @@ -39,10 +36,13 @@ class Worker: self._spawn(target, args) def _spawn(self, target, args): - self.proc = multiprocessing.get_context().Process(name='granian-worker', target=target, args=args) + raise NotImplementedError + + def _id(self): + raise NotImplementedError def _watcher(self): - self.proc.join() + self.inner.join() if not self.interrupt_by_parent: logger.error(f'Unexpected exit from worker-{self.idx + 1}') self.parent.interrupt_children.append(self.idx) @@ -53,23 +53,24 @@ class Worker: watcher.start() def start(self): - self.proc.start() - logger.info(f'Spawning worker-{self.idx + 1} with pid: {self.proc.pid}') + self.inner.start() + logger.info(f'Spawning worker-{self.idx + 1} with {self._idl}: {self._id()}') self._watch() + def is_alive(self): + return self.inner.is_alive() + def terminate(self): - self.interrupt_by_parent = True - self.proc.terminate() + raise NotImplementedError def kill(self): - self.interrupt_by_parent = True - self.proc.kill() + raise NotImplementedError def join(self, timeout=None): - self.proc.join(timeout=timeout) + self.inner.join(timeout=timeout) -class Granian: +class AbstractServer(Generic[WT]): def __init__( self, target: str, @@ -78,6 +79,7 @@ class Granian: interface: Interfaces = Interfaces.RSGI, workers: int = 1, threads: int = 1, + io_blocking_threads: Optional[int] = None, blocking_threads: Optional[int] = None, threading_mode: ThreadModes = ThreadModes.workers, loop: Loops = Loops.auto, @@ -117,6 +119,7 @@ class Granian: self.interface = interface self.workers = max(1, workers) self.threads = max(1, threads) + self.io_blocking_threads = 512 if io_blocking_threads is None else max(1, io_blocking_threads) self.threading_mode = threading_mode self.loop = loop self.task_impl = task_impl @@ -127,9 +130,7 @@ class Granian: self.blocking_threads = ( blocking_threads if blocking_threads is not None - else max( - 1, (self.backpressure if self.interface == Interfaces.WSGI else min(2, multiprocessing.cpu_count())) - ) + else max(1, (multiprocessing.cpu_count() * 2 - 1) if self.interface == Interfaces.WSGI else 1) ) self.http1_settings = http1_settings self.http2_settings = http2_settings @@ -158,11 +159,11 @@ class Granian: self.build_ssl_context(ssl_cert, ssl_key, ssl_key_password) self._shd = None self._sfd = None - self.procs: List[Worker] = [] + self.wrks: List[WT] = [] self.main_loop_interrupt = threading.Event() self.interrupt_signal = False self.interrupt_children = [] - self.respawned_procs = {} + self.respawned_wrks = {} self.reload_signal = False self.lifetime_signal = False self.pid = None @@ -179,230 +180,6 @@ class Granian: # key_contents = f.read() self.ssl_ctx = (True, str(cert.resolve()), str(key.resolve()), password) - @staticmethod - def _spawn_asgi_worker( - worker_id: int, - process_name: Optional[str], - callback_loader: Callable[..., Any], - socket: socket.socket, - loop_impl: Loops, - threads: int, - blocking_threads: int, - backpressure: int, - threading_mode: ThreadModes, - task_impl: TaskImpl, - http_mode: HTTPModes, - http1_settings: Optional[HTTP1Settings], - http2_settings: Optional[HTTP2Settings], - websockets: bool, - log_enabled: bool, - log_level: LogLevels, - log_config: Dict[str, Any], - log_access_fmt: Optional[str], - ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], - scope_opts: Dict[str, Any], - ): - from granian._loops import loops - from granian._signals import set_loop_signals - - if process_name: - setproctitle.setproctitle(f'{process_name} worker-{worker_id}') - configure_logging(log_level, log_config, log_enabled) - - loop = loops.get(loop_impl) - sfd = socket.fileno() - callback = callback_loader() - shutdown_event = set_loop_signals(loop) - wcallback = _asgi_call_wrap(callback, scope_opts, {}, log_access_fmt) - - worker = ASGIWorker( - worker_id, - sfd, - threads, - blocking_threads, - backpressure, - http_mode, - http1_settings, - http2_settings, - websockets, - *ssl_ctx, - ) - serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler( - loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio - ) - serve(scheduler, loop, shutdown_event) - - @staticmethod - def _spawn_asgi_lifespan_worker( - worker_id: int, - process_name: Optional[str], - callback_loader: Callable[..., Any], - socket: socket.socket, - loop_impl: Loops, - threads: int, - blocking_threads: int, - backpressure: int, - threading_mode: ThreadModes, - task_impl: TaskImpl, - http_mode: HTTPModes, - http1_settings: Optional[HTTP1Settings], - http2_settings: Optional[HTTP2Settings], - websockets: bool, - log_enabled: bool, - log_level: LogLevels, - log_config: Dict[str, Any], - log_access_fmt: Optional[str], - ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], - scope_opts: Dict[str, Any], - ): - from granian._loops import loops - from granian._signals import set_loop_signals - - if process_name: - setproctitle.setproctitle(f'{process_name} worker-{worker_id}') - configure_logging(log_level, log_config, log_enabled) - - loop = loops.get(loop_impl) - sfd = socket.fileno() - callback = callback_loader() - lifespan_handler = LifespanProtocol(callback) - - loop.run_until_complete(lifespan_handler.startup()) - if lifespan_handler.interrupt: - logger.error('ASGI lifespan startup failed', exc_info=lifespan_handler.exc) - sys.exit(1) - - shutdown_event = set_loop_signals(loop) - wcallback = _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt) - - worker = ASGIWorker( - worker_id, - sfd, - threads, - blocking_threads, - backpressure, - http_mode, - http1_settings, - http2_settings, - websockets, - *ssl_ctx, - ) - serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler( - loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio - ) - serve(scheduler, loop, shutdown_event) - loop.run_until_complete(lifespan_handler.shutdown()) - - @staticmethod - def _spawn_rsgi_worker( - worker_id: int, - process_name: Optional[str], - callback_loader: Callable[..., Any], - socket: socket.socket, - loop_impl: Loops, - threads: int, - blocking_threads: int, - backpressure: int, - threading_mode: ThreadModes, - task_impl: TaskImpl, - http_mode: HTTPModes, - http1_settings: Optional[HTTP1Settings], - http2_settings: Optional[HTTP2Settings], - websockets: bool, - log_enabled: bool, - log_level: LogLevels, - log_config: Dict[str, Any], - log_access_fmt: Optional[str], - ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], - scope_opts: Dict[str, Any], - ): - from granian._loops import loops - from granian._signals import set_loop_signals - - if process_name: - setproctitle.setproctitle(f'{process_name} worker-{worker_id}') - configure_logging(log_level, log_config, log_enabled) - - loop = loops.get(loop_impl) - sfd = socket.fileno() - target = callback_loader() - callback = getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else target - callback_init = ( - getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else lambda *args, **kwargs: None - ) - callback_del = ( - getattr(target, '__rsgi_del__') if hasattr(target, '__rsgi_del__') else lambda *args, **kwargs: None - ) - callback = _rsgi_call_wrap(callback, log_access_fmt) - shutdown_event = set_loop_signals(loop) - callback_init(loop) - - worker = RSGIWorker( - worker_id, - sfd, - threads, - blocking_threads, - backpressure, - http_mode, - http1_settings, - http2_settings, - websockets, - *ssl_ctx, - ) - serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler( - loop, _future_watcher_wrapper(callback), impl_asyncio=task_impl == TaskImpl.asyncio - ) - serve(scheduler, loop, shutdown_event) - callback_del(loop) - - @staticmethod - def _spawn_wsgi_worker( - worker_id: int, - process_name: Optional[str], - callback_loader: Callable[..., Any], - socket: socket.socket, - loop_impl: Loops, - threads: int, - blocking_threads: int, - backpressure: int, - threading_mode: ThreadModes, - task_impl: TaskImpl, - http_mode: HTTPModes, - http1_settings: Optional[HTTP1Settings], - http2_settings: Optional[HTTP2Settings], - websockets: bool, - log_enabled: bool, - log_level: LogLevels, - log_config: Dict[str, Any], - log_access_fmt: Optional[str], - ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], - scope_opts: Dict[str, Any], - ): - from granian._loops import loops - from granian._signals import set_sync_signals - - if process_name: - setproctitle.setproctitle(f'{process_name} worker-{worker_id}') - configure_logging(log_level, log_config, log_enabled) - - loop = loops.get(loop_impl) - sfd = socket.fileno() - callback = callback_loader() - shutdown_event = set_sync_signals() - - worker = WSGIWorker( - worker_id, sfd, threads, blocking_threads, backpressure, http_mode, http1_settings, http2_settings, *ssl_ctx - ) - serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) - scheduler = _new_cbscheduler( - loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt), impl_asyncio=task_impl == TaskImpl.asyncio - ) - serve(scheduler, loop, shutdown_event) - shutdown_event.qs.wait() - def _init_shared_socket(self): self._shd = SocketHolder.from_address(self.bind_addr, self.bind_port, self.backlog) self._sfd = self._shd.get_fd() @@ -415,88 +192,62 @@ class Granian: self.reload_signal = True self.main_loop_interrupt.set() - def _spawn_proc(self, idx, target, callback_loader, socket_loader) -> Worker: - return Worker( - parent=self, - idx=idx, - target=target, - args=( - idx + 1, - self.process_name, - callback_loader, - socket_loader(), - self.loop, - self.threads, - self.blocking_threads, - self.backpressure, - self.threading_mode, - self.task_impl, - self.http, - self.http1_settings, - self.http2_settings, - self.websockets, - self.log_enabled, - self.log_level, - self.log_config, - self.log_access_format if self.log_access else None, - self.ssl_ctx, - {'url_path_prefix': self.url_path_prefix}, - ), - ) + def _spawn_worker(self, idx, target, callback_loader, socket_loader) -> WT: + raise NotImplementedError def _spawn_workers(self, sock, spawn_target, target_loader): def socket_loader(): return sock for idx in range(self.workers): - proc = self._spawn_proc( + wrk = self._spawn_worker( idx=idx, target=spawn_target, callback_loader=target_loader, socket_loader=socket_loader ) - proc.start() - self.procs.append(proc) + wrk.start() + self.wrks.append(wrk) def _respawn_workers(self, workers, sock, spawn_target, target_loader, delay: float = 0): def socket_loader(): return sock for idx in workers: - self.respawned_procs[idx] = time.time() + self.respawned_wrks[idx] = time.time() logger.info(f'Respawning worker-{idx + 1}') - old_proc = self.procs.pop(idx) - proc = self._spawn_proc( + old_wrk = self.wrks.pop(idx) + wrk = self._spawn_worker( idx=idx, target=spawn_target, callback_loader=target_loader, socket_loader=socket_loader ) - proc.start() - self.procs.insert(idx, proc) + wrk.start() + self.wrks.insert(idx, wrk) time.sleep(delay) logger.info(f'Stopping old worker-{idx + 1}') - old_proc.terminate() - old_proc.join(self.workers_kill_timeout) + old_wrk.terminate() + old_wrk.join(self.workers_kill_timeout) if self.workers_kill_timeout: - # the process might still be reported alive after `join`, let's context switch - if old_proc.proc.is_alive(): + # the worker might still be reported alive after `join`, let's context switch + if old_wrk.is_alive(): time.sleep(0.001) - if old_proc.proc.is_alive(): + if old_wrk.is_alive(): logger.warning(f'Killing old worker-{idx + 1} after it refused to gracefully stop') - old_proc.kill() - old_proc.join() + old_wrk.kill() + old_wrk.join() def _stop_workers(self): - for proc in self.procs: - proc.terminate() + for wrk in self.wrks: + wrk.terminate() - for proc in self.procs: - proc.join(self.workers_kill_timeout) + for wrk in self.wrks: + wrk.join(self.workers_kill_timeout) if self.workers_kill_timeout: - # the process might still be reported after `join`, let's context switch - if proc.proc.is_alive(): + # the worker might still be reported after `join`, let's context switch + if wrk.is_alive(): time.sleep(0.001) - if proc.proc.is_alive(): - logger.warning(f'Killing worker-{proc.idx} after it refused to gracefully stop') - proc.kill() - proc.join() + if wrk.is_alive(): + logger.warning(f'Killing worker-{wrk.idx} after it refused to gracefully stop') + wrk.kill() + wrk.join() - self.procs.clear() + self.wrks.clear() def _workers_lifetime_watcher(self, ttl): time.sleep(ttl) @@ -584,7 +335,7 @@ class Granian: logger.info('HUP signal received, gracefully respawning workers..') workers = list(range(self.workers)) self.reload_signal = False - self.respawned_procs.clear() + self.respawned_wrks.clear() self.main_loop_interrupt.clear() self._respawn_workers(workers, sock, spawn_target, target_loader, delay=self.respawn_interval) @@ -599,13 +350,13 @@ class Granian: break cycle = time.time() - if any(cycle - self.respawned_procs.get(idx, 0) <= 5.5 for idx in self.interrupt_children): + if any(cycle - self.respawned_wrks.get(idx, 0) <= 5.5 for idx in self.interrupt_children): logger.error('Worker crash loop detected, exiting') break workers = list(self.interrupt_children) self.interrupt_children.clear() - self.respawned_procs.clear() + self.respawned_wrks.clear() self.main_loop_interrupt.clear() self._respawn_workers(workers, sock, spawn_target, target_loader) @@ -618,7 +369,7 @@ class Granian: ttl = self.workers_lifetime * 0.95 now = time.time() etas = [self.workers_lifetime] - for worker in list(self.procs): + for worker in list(self.wrks): if (now - worker.birth) >= ttl: logger.info(f'worker-{worker.idx + 1} lifetime expired, gracefully respawning..') self._respawn_workers( @@ -706,6 +457,10 @@ class Granian: 'Number of workers will now fallback to 1.' ) + if self.interface != Interfaces.WSGI and self.blocking_threads > 1: + logger.error('Blocking threads > 1 is not supported on ASGI and RSGI') + raise ConfigurationError('blocking_threads') + if self.websockets: if self.interface == Interfaces.WSGI: logger.info('Websockets are not supported on WSGI, ignoring') diff --git a/granian/server/mp.py b/granian/server/mp.py new file mode 100644 index 0000000..1299bf6 --- /dev/null +++ b/granian/server/mp.py @@ -0,0 +1,316 @@ +import multiprocessing +import socket +import sys +from typing import Any, Callable, Dict, Optional, Tuple + +from .._futures import _future_watcher_wrapper, _new_cbscheduler +from .._granian import ASGIWorker, RSGIWorker, WSGIWorker +from ..asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap +from ..rsgi import _callback_wrapper as _rsgi_call_wrap +from ..wsgi import _callback_wrapper as _wsgi_call_wrap +from .common import ( + AbstractServer, + AbstractWorker, + HTTP1Settings, + HTTP2Settings, + HTTPModes, + LogLevels, + Loops, + TaskImpl, + ThreadModes, + configure_logging, + logger, + setproctitle, +) + + +multiprocessing.allow_connection_pickling() + + +class WorkerProcess(AbstractWorker): + _idl = 'PID' + + def _spawn(self, target, args): + self.inner = multiprocessing.get_context().Process(name='granian-worker', target=target, args=args) + + def _id(self): + return self.inner.pid + + def terminate(self): + self.interrupt_by_parent = True + self.inner.terminate() + + def kill(self): + self.interrupt_by_parent = True + self.inner.kill() + + +class MPServer(AbstractServer[WorkerProcess]): + @staticmethod + def _spawn_asgi_worker( + worker_id: int, + process_name: Optional[str], + callback_loader: Callable[..., Any], + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_enabled: bool, + log_level: LogLevels, + log_config: Dict[str, Any], + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + from granian._loops import loops + from granian._signals import set_loop_signals + + if process_name: + setproctitle.setproctitle(f'{process_name} worker-{worker_id}') + configure_logging(log_level, log_config, log_enabled) + + loop = loops.get(loop_impl) + sfd = socket.fileno() + callback = callback_loader() + shutdown_event = set_loop_signals(loop) + wcallback = _asgi_call_wrap(callback, scope_opts, {}, log_access_fmt) + + worker = ASGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + websockets, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + + @staticmethod + def _spawn_asgi_lifespan_worker( + worker_id: int, + process_name: Optional[str], + callback_loader: Callable[..., Any], + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_enabled: bool, + log_level: LogLevels, + log_config: Dict[str, Any], + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + from granian._loops import loops + from granian._signals import set_loop_signals + + if process_name: + setproctitle.setproctitle(f'{process_name} worker-{worker_id}') + configure_logging(log_level, log_config, log_enabled) + + loop = loops.get(loop_impl) + sfd = socket.fileno() + callback = callback_loader() + lifespan_handler = LifespanProtocol(callback) + + loop.run_until_complete(lifespan_handler.startup()) + if lifespan_handler.interrupt: + logger.error('ASGI lifespan startup failed', exc_info=lifespan_handler.exc) + sys.exit(1) + + shutdown_event = set_loop_signals(loop) + wcallback = _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt) + + worker = ASGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + websockets, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + loop.run_until_complete(lifespan_handler.shutdown()) + + @staticmethod + def _spawn_rsgi_worker( + worker_id: int, + process_name: Optional[str], + callback_loader: Callable[..., Any], + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_enabled: bool, + log_level: LogLevels, + log_config: Dict[str, Any], + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + from granian._loops import loops + from granian._signals import set_loop_signals + + if process_name: + setproctitle.setproctitle(f'{process_name} worker-{worker_id}') + configure_logging(log_level, log_config, log_enabled) + + loop = loops.get(loop_impl) + sfd = socket.fileno() + target = callback_loader() + callback = getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else target + callback_init = ( + getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else lambda *args, **kwargs: None + ) + callback_del = ( + getattr(target, '__rsgi_del__') if hasattr(target, '__rsgi_del__') else lambda *args, **kwargs: None + ) + callback = _rsgi_call_wrap(callback, log_access_fmt) + shutdown_event = set_loop_signals(loop) + callback_init(loop) + + worker = RSGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + websockets, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(callback), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + callback_del(loop) + + @staticmethod + def _spawn_wsgi_worker( + worker_id: int, + process_name: Optional[str], + callback_loader: Callable[..., Any], + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_enabled: bool, + log_level: LogLevels, + log_config: Dict[str, Any], + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + from granian._loops import loops + from granian._signals import set_sync_signals + + if process_name: + setproctitle.setproctitle(f'{process_name} worker-{worker_id}') + configure_logging(log_level, log_config, log_enabled) + + loop = loops.get(loop_impl) + sfd = socket.fileno() + callback = callback_loader() + shutdown_event = set_sync_signals() + + worker = WSGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + + def _spawn_worker(self, idx, target, callback_loader, socket_loader) -> WorkerProcess: + return WorkerProcess( + parent=self, + idx=idx, + target=target, + args=( + idx + 1, + self.process_name, + callback_loader, + socket_loader(), + self.loop, + self.threads, + self.io_blocking_threads, + self.blocking_threads, + self.backpressure, + self.threading_mode, + self.task_impl, + self.http, + self.http1_settings, + self.http2_settings, + self.websockets, + self.log_enabled, + self.log_level, + self.log_config, + self.log_access_format if self.log_access else None, + self.ssl_ctx, + {'url_path_prefix': self.url_path_prefix}, + ), + ) diff --git a/granian/server/mt.py b/granian/server/mt.py new file mode 100644 index 0000000..0526ac9 --- /dev/null +++ b/granian/server/mt.py @@ -0,0 +1,311 @@ +import socket +import sys +import threading +from typing import Any, Callable, Dict, Optional, Tuple + +from .._futures import _future_watcher_wrapper, _new_cbscheduler +from .._granian import ASGIWorker, RSGIWorker, WorkerSignal, WorkerSignalSync, WSGIWorker +from .._loops import loops +from ..asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap +from ..errors import ConfigurationError, FatalError +from ..rsgi import _callback_wrapper as _rsgi_call_wrap +from ..wsgi import _callback_wrapper as _wsgi_call_wrap +from .common import ( + AbstractServer, + AbstractWorker, + HTTP1Settings, + HTTP2Settings, + HTTPModes, + Interfaces, + Loops, + TaskImpl, + ThreadModes, + logger, +) + + +class WorkerThread(AbstractWorker): + _idl = 'TID' + + def __init__(self, parent, idx, target, args, sig): + self._sig = sig + super().__init__(parent, idx, target, args) + + def _spawn(self, target, args): + self.inner = threading.Thread(name='granian-worker', target=target, args=args) + self._alive = True + + def _id(self): + return self.inner.native_id + + def _watcher(self): + self.inner.join() + self._alive = False + if not self.interrupt_by_parent: + logger.error(f'Unexpected exit from worker-{self.idx + 1}') + self.parent.interrupt_children.append(self.idx) + self.parent.main_loop_interrupt.set() + + def terminate(self): + self._alive = False + self.interrupt_by_parent = True + self._sig.set() + + def is_alive(self): + if not self._alive: + return False + return self.inner.is_alive() + + +class MTServer(AbstractServer[WorkerThread]): + @staticmethod + def _spawn_asgi_worker( + worker_id: int, + shutdown_event: Any, + callback: Any, + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + loop = loops.get(loop_impl) + sfd = socket.fileno() + wcallback = _asgi_call_wrap(callback, scope_opts, {}, log_access_fmt) + + worker = ASGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + websockets, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + + @staticmethod + def _spawn_asgi_lifespan_worker( + worker_id: int, + shutdown_event: Any, + callback: Any, + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + loop = loops.get(loop_impl) + sfd = socket.fileno() + + lifespan_handler = LifespanProtocol(callback) + loop.run_until_complete(lifespan_handler.startup()) + if lifespan_handler.interrupt: + logger.error('ASGI lifespan startup failed', exc_info=lifespan_handler.exc) + sys.exit(1) + + wcallback = _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt) + + worker = ASGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + websockets, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + loop.run_until_complete(lifespan_handler.shutdown()) + + @staticmethod + def _spawn_rsgi_worker( + worker_id: int, + shutdown_event: Any, + callback: Any, + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + loop = loops.get(loop_impl) + sfd = socket.fileno() + callback_init = ( + getattr(callback, '__rsgi_init__') if hasattr(callback, '__rsgi_init__') else lambda *args, **kwargs: None + ) + callback_del = ( + getattr(callback, '__rsgi_del__') if hasattr(callback, '__rsgi_del__') else lambda *args, **kwargs: None + ) + callback = getattr(callback, '__rsgi__') if hasattr(callback, '__rsgi__') else callback + callback = _rsgi_call_wrap(callback, log_access_fmt) + callback_init(loop) + + worker = RSGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + websockets, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _future_watcher_wrapper(callback), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + callback_del(loop) + + @staticmethod + def _spawn_wsgi_worker( + worker_id: int, + shutdown_event: Any, + callback: Any, + socket: socket.socket, + loop_impl: Loops, + threads: int, + io_blocking_threads: Optional[int], + blocking_threads: int, + backpressure: int, + threading_mode: ThreadModes, + task_impl: TaskImpl, + http_mode: HTTPModes, + http1_settings: Optional[HTTP1Settings], + http2_settings: Optional[HTTP2Settings], + websockets: bool, + log_access_fmt: Optional[str], + ssl_ctx: Tuple[bool, Optional[str], Optional[str], Optional[str]], + scope_opts: Dict[str, Any], + ): + loop = loops.get(loop_impl) + sfd = socket.fileno() + + worker = WSGIWorker( + worker_id, + sfd, + threads, + io_blocking_threads, + blocking_threads, + backpressure, + http_mode, + http1_settings, + http2_settings, + *ssl_ctx, + ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + scheduler = _new_cbscheduler( + loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt), impl_asyncio=task_impl == TaskImpl.asyncio + ) + serve(scheduler, loop, shutdown_event) + + def _spawn_worker(self, idx, target, callback_loader, socket_loader) -> WorkerThread: + sig = WorkerSignalSync(threading.Event()) if self.interface == Interfaces.WSGI else WorkerSignal() + + return WorkerThread( + parent=self, + idx=idx, + target=target, + args=( + idx + 1, + sig, + callback_loader, + socket_loader(), + self.loop, + self.threads, + self.io_blocking_threads, + self.blocking_threads, + self.backpressure, + self.threading_mode, + self.task_impl, + self.http, + self.http1_settings, + self.http2_settings, + self.websockets, + self.log_access_format if self.log_access else None, + self.ssl_ctx, + {'url_path_prefix': self.url_path_prefix}, + ), + sig=sig, + ) + + def _check_gil(self): + try: + assert sys._is_gil_enabled() is False + except Exception: + logger.error('Cannot run a free-threaded Granian build with GIL enabled') + raise FatalError('GIL enabled on free-threaded build') + + def _serve(self, spawn_target, target_loader): + target = target_loader() + self._check_gil() + sock = self.startup(spawn_target, target) + self._serve_loop(sock, spawn_target, target) + self.shutdown() + + def _serve_with_reloader(self, spawn_target, target_loader): + raise NotImplementedError + + def serve( + self, + spawn_target: Optional[Callable[..., None]] = None, + target_loader: Optional[Callable[..., Callable[..., Any]]] = None, + wrap_loader: bool = True, + ): + logger.warning('free-threaded Python support is experimental!') + + if self.reload_on_changes: + logger.error('The changes reloader is not supported on the free-threaded build') + raise ConfigurationError('reload') + + super().serve(spawn_target, target_loader, wrap_loader) diff --git a/pyproject.toml b/pyproject.toml index 07543b1..16bfdfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,6 @@ dynamic = [ requires-python = '>=3.9' dependencies = [ 'click>=8.0.0', - 'uvloop>=0.18.0; sys_platform != "win32" and platform_python_implementation == "CPython"', ] [project.optional-dependencies] @@ -44,6 +43,24 @@ pname = [ reload = [ 'watchfiles>=0.21,<2', ] +rloop = [ + 'rloop; sys_platform != "win32"', +] +uvloop = [ + 'uvloop>=0.18.0; sys_platform != "win32" and platform_python_implementation == "CPython"', +] + +all = ['granian[pname,reload]'] + +[project.urls] +Homepage = 'https://github.com/emmett-framework/granian' +Funding = 'https://github.com/sponsors/gi0baro' +Source = 'https://github.com/emmett-framework/granian' + +[dependency-groups] +build = [ + 'maturin~=1.8', +] lint = [ 'ruff~=0.5.0', ] @@ -52,21 +69,20 @@ test = [ 'pytest~=7.4.2', 'pytest-asyncio~=0.21.1', 'sniffio~=1.3', - 'websockets~=11.0', + 'websockets~=14.2', ] -all = ['granian[pname,reload]'] -dev = ['granian[all,lint,test]'] -[project.urls] -Homepage = 'https://github.com/emmett-framework/granian' -Funding = 'https://github.com/sponsors/gi0baro' -Source = 'https://github.com/emmett-framework/granian' +all = [ + { include-group = 'build' }, + { include-group = 'lint' }, + { include-group = 'test' }, +] [project.scripts] granian = 'granian:cli.entrypoint' [build-system] -requires = ['maturin>=1.1.0,<2'] +requires = ['maturin>=1.8.0,<2'] build-backend = 'maturin' [tool.maturin] @@ -98,7 +114,7 @@ extend-ignore = [ 'S110', # except pass is fine ] flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' } -mccabe = { max-complexity = 15 } +mccabe = { max-complexity = 16 } [tool.ruff.format] quote-style = 'single' diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index 053fea6..c0c5f4a 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -13,7 +13,7 @@ use super::{ use crate::{ callbacks::ArcCBScheduler, http::{response_500, HTTPResponse}, - runtime::RuntimeRef, + runtime::{Runtime, RuntimeRef}, utils::log_application_callable_exception, ws::{HyperWebsocket, UpgradeData}, }; @@ -35,9 +35,9 @@ macro_rules! callback_impl_done_ws { } macro_rules! callback_impl_done_err { - ($self:expr, $err:expr) => { + ($self:expr, $py:expr, $err:expr) => { $self.done(); - log_application_callable_exception($err); + log_application_callable_exception($py, $err); }; } @@ -72,8 +72,8 @@ impl CallbackWatcherHTTP { callback_impl_done_http!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -106,8 +106,8 @@ impl CallbackWatcherWebsocket { callback_impl_done_ws!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -148,12 +148,11 @@ pub(crate) fn call_http( req: hyper::http::request::Parts, body: hyper::body::Incoming, ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, body, tx); + let protocol = HTTPProtocol::new(rt.clone(), body, tx); let scheme: Arc = scheme.into(); - let _ = brt.run(move || { + rt.spawn_blocking(move |py| { scope_native_parts!( req, server_addr, @@ -164,11 +163,18 @@ pub(crate) fn call_http( server, client ); - Python::with_gil(|py| { - let scope = build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); + cb.get().schedule( + py, + Py::new( + py, + CallbackWatcherHTTP::new( + py, + protocol, + build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(), + ), + ) + .unwrap(), + ); }); rx @@ -185,12 +191,11 @@ pub(crate) fn call_ws( req: hyper::http::request::Parts, upgrade: UpgradeData, ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); + let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade); let scheme: Arc = scheme.into(); - let _ = brt.run(move || { + rt.spawn_blocking(move |py| { scope_native_parts!( req, server_addr, @@ -201,11 +206,18 @@ pub(crate) fn call_ws( server, client ); - Python::with_gil(|py| { - let scope = build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); + cb.get().schedule( + py, + Py::new( + py, + CallbackWatcherWebsocket::new( + py, + protocol, + build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(), + ), + ) + .unwrap(), + ); }); rx diff --git a/src/asgi/serve.rs b/src/asgi/serve.rs index dd7d138..0bed473 100644 --- a/src/asgi/serve.rs +++ b/src/asgi/serve.rs @@ -4,7 +4,7 @@ use super::http::{handle, handle_ws}; use crate::callbacks::CallbackScheduler; use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; -use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal, WorkerSignals}; +use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal}; #[pyclass(frozen, module = "granian._granian")] pub struct ASGIWorker { @@ -30,7 +30,8 @@ impl ASGIWorker { worker_id, socket_fd, threads=1, - blocking_threads=512, + io_blocking_threads=512, + blocking_threads=1, backpressure=256, http_mode="1", http1_opts=None, @@ -47,6 +48,7 @@ impl ASGIWorker { worker_id: i32, socket_fd: i32, threads: usize, + io_blocking_threads: usize, blocking_threads: usize, backpressure: usize, http_mode: &str, @@ -63,6 +65,7 @@ impl ASGIWorker { worker_id, socket_fd, threads, + io_blocking_threads, blocking_threads, backpressure, http_mode, @@ -79,19 +82,19 @@ impl ASGIWorker { fn serve_rth(&self, callback: Py, event_loop: &Bound, signal: Py) { match (self.config.websockets_enabled, self.config.ssl_enabled) { - (false, false) => self._serve_rth(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, false) => self._serve_rth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), - (false, true) => self._serve_rth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, true) => self._serve_rth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, false) => self._serve_rth(callback, event_loop, signal), + (true, false) => self._serve_rth_ws(callback, event_loop, signal), + (false, true) => self._serve_rth_ssl(callback, event_loop, signal), + (true, true) => self._serve_rth_ssl_ws(callback, event_loop, signal), } } fn serve_wth(&self, callback: Py, event_loop: &Bound, signal: Py) { match (self.config.websockets_enabled, self.config.ssl_enabled) { - (false, false) => self._serve_wth(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, false) => self._serve_wth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), - (false, true) => self._serve_wth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, true) => self._serve_wth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, false) => self._serve_wth(callback, event_loop, signal), + (true, false) => self._serve_wth_ws(callback, event_loop, signal), + (false, true) => self._serve_wth_ssl(callback, event_loop, signal), + (true, true) => self._serve_wth_ssl_ws(callback, event_loop, signal), } } } diff --git a/src/blocking.rs b/src/blocking.rs index 0fc36e9..358519a 100644 --- a/src/blocking.rs +++ b/src/blocking.rs @@ -1,50 +1,110 @@ use crossbeam_channel as channel; +use pyo3::prelude::*; use std::thread; pub(crate) struct BlockingTask { - inner: Box, + inner: Box, } impl BlockingTask { pub fn new(inner: T) -> BlockingTask where - T: FnOnce() + Send + 'static, + T: FnOnce(Python) + Send + 'static, { Self { inner: Box::new(inner) } } - pub fn run(self) { - (self.inner)(); + pub fn run(self, py: Python) { + (self.inner)(py); } } -#[derive(Clone)] pub(crate) struct BlockingRunner { queue: channel::Sender, + #[cfg(Py_GIL_DISABLED)] + sig: channel::Sender<()>, } impl BlockingRunner { - pub fn new() -> Self { - let queue = blocking_thread(); + #[cfg(not(Py_GIL_DISABLED))] + pub fn new(threads: usize) -> Self { + let queue = blocking_pool(threads); Self { queue } } + #[cfg(Py_GIL_DISABLED)] + pub fn new(threads: usize) -> Self { + let (sigtx, sigrx) = channel::bounded(1); + let queue = blocking_pool(threads, sigrx); + Self { queue, sig: sigtx } + } + pub fn run(&self, task: T) -> Result<(), channel::SendError> where - T: FnOnce() + Send + 'static, + T: FnOnce(Python) + Send + 'static, { self.queue.send(BlockingTask::new(task)) } -} -fn bloking_loop(queue: channel::Receiver) { - while let Ok(task) = queue.recv() { - task.run(); + #[cfg(Py_GIL_DISABLED)] + pub fn shutdown(&self) { + _ = self.sig.send(()); } } -fn blocking_thread() -> channel::Sender { +#[cfg(not(Py_GIL_DISABLED))] +fn blocking_loop(queue: channel::Receiver) { + while let Ok(task) = queue.recv() { + Python::with_gil(|py| task.run(py)); + } +} + +// NOTE: for some reason, on no-gil callback watchers are not GCd until following req. +// It's not clear atm wether this is an issue with pyo3, CPython itself, or smth +// different in terms of pointers due to multi-threaded environment. +// Thus, we need a signal to manually stop the loop and let the server shutdown. +// The following function would be the intended one if we hadn't the issue just described. +// +// #[cfg(Py_GIL_DISABLED)] +// fn blocking_loop(queue: channel::Receiver) { +// Python::with_gil(|py| { +// while let Ok(task) = queue.recv() { +// task.run(py); +// } +// }); +// } +#[cfg(Py_GIL_DISABLED)] +fn blocking_loop(queue: channel::Receiver, sig: channel::Receiver<()>) { + Python::with_gil(|py| loop { + crossbeam_channel::select! { + recv(queue) -> task => match task { + Ok(task) => task.run(py), + _ => break, + }, + recv(sig) -> _ => break + } + }); +} + +#[cfg(not(Py_GIL_DISABLED))] +fn blocking_pool(threads: usize) -> channel::Sender { let (qtx, qrx) = channel::unbounded(); - thread::spawn(|| bloking_loop(qrx)); + for _ in 0..threads { + let tqrx = qrx.clone(); + thread::spawn(|| blocking_loop(tqrx)); + } + + qtx +} + +#[cfg(Py_GIL_DISABLED)] +fn blocking_pool(threads: usize, sig: channel::Receiver<()>) -> channel::Sender { + let (qtx, qrx) = channel::unbounded(); + for _ in 0..threads { + let tqrx = qrx.clone(); + let tsig = sig.clone(); + thread::spawn(|| blocking_loop(tqrx, tsig)); + } + qtx } diff --git a/src/callbacks.rs b/src/callbacks.rs index 59c0524..ba00736 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -31,13 +31,15 @@ pub(crate) struct CallbackScheduler { #[cfg(not(PyPy))] impl CallbackScheduler { #[inline] - pub(crate) fn schedule(&self, _py: Python, watcher: &PyObject) { + pub(crate) fn schedule(&self, py: Python, watcher: Py) { let cbarg = watcher.as_ptr(); let sched = self.schedule_fn.get().unwrap().as_ptr(); unsafe { pyo3::ffi::PyObject_CallOneArg(sched, cbarg); } + + watcher.drop_ref(py); } #[inline] @@ -130,13 +132,15 @@ impl CallbackScheduler { #[cfg(PyPy)] impl CallbackScheduler { #[inline] - pub(crate) fn schedule(&self, py: Python, watcher: &PyObject) { + pub(crate) fn schedule(&self, py: Python, watcher: Py) { let cbarg = (watcher,).into_pyobject(py).unwrap().into_ptr(); let sched = self.schedule_fn.get().unwrap().as_ptr(); unsafe { pyo3::ffi::PyObject_CallObject(sched, cbarg); } + + watcher.drop_ref(py); } #[inline] @@ -508,8 +512,9 @@ impl PyIterAwaitable { } #[inline] - pub(crate) fn set_result(&self, py: Python, result: FutureResultToPy) { - let _ = self.result.set(result.into_pyobject(py).map(Bound::unbind)); + pub(crate) fn set_result(pyself: Py, py: Python, result: FutureResultToPy) { + _ = pyself.get().result.set(result.into_pyobject(py).map(Bound::unbind)); + pyself.drop_ref(py); } } @@ -524,7 +529,7 @@ impl PyIterAwaitable { } fn __next__(&self, py: Python) -> PyResult> { - if let Some(res) = py.allow_threads(|| self.result.get()) { + if let Some(res) = self.result.get() { return res .as_ref() .map_err(|err| err.clone_ref(py)) @@ -583,18 +588,22 @@ impl PyFutureAwaitable { ) .is_err() { + pyself.drop_ref(py); return; } - let ack = rself.ack.read().unwrap(); - if let Some((cb, ctx)) = &*ack { - let _ = rself.event_loop.clone_ref(py).call_method( - py, - pyo3::intern!(py, "call_soon_threadsafe"), - (cb, pyself.clone_ref(py)), - Some(ctx.bind(py)), - ); + { + let ack = rself.ack.read().unwrap(); + if let Some((cb, ctx)) = &*ack { + _ = rself.event_loop.clone_ref(py).call_method( + py, + pyo3::intern!(py, "call_soon_threadsafe"), + (cb, pyself.clone_ref(py)), + Some(ctx.bind(py)), + ); + } } + pyself.drop_ref(py); } } diff --git a/src/conversion.rs b/src/conversion.rs index 1f5c32f..aceed55 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -31,8 +31,8 @@ pub(crate) enum FutureResultToPy { Bytes(hyper::body::Bytes), ASGIMessage(crate::asgi::types::ASGIMessageType), ASGIWSMessage(tokio_tungstenite::tungstenite::Message), + RSGIWSAccept(crate::rsgi::io::RSGIWebsocketTransport), RSGIWSMessage(tokio_tungstenite::tungstenite::Message), - Py(PyObject), } impl<'p> IntoPyObject<'p> for FutureResultToPy { @@ -47,8 +47,8 @@ impl<'p> IntoPyObject<'p> for FutureResultToPy { Self::Bytes(inner) => inner.into_pyobject(py), Self::ASGIMessage(message) => crate::asgi::conversion::message_into_py(py, message), Self::ASGIWSMessage(message) => crate::asgi::conversion::ws_message_into_py(py, message), + Self::RSGIWSAccept(obj) => obj.into_bound_py_any(py), Self::RSGIWSMessage(message) => crate::rsgi::conversion::ws_message_into_py(py, message), - Self::Py(obj) => Ok(obj.into_bound(py)), } } } diff --git a/src/lib.rs b/src/lib.rs index 8fc3080..e1c3d2c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,11 @@ mod workers; mod ws; mod wsgi; +#[cfg(not(Py_GIL_DISABLED))] +const BUILD_GIL: bool = true; +#[cfg(Py_GIL_DISABLED)] +const BUILD_GIL: bool = false; + pub fn get_granian_version() -> &'static str { static GRANIAN_VERSION: OnceLock = OnceLock::new(); @@ -26,9 +31,10 @@ pub fn get_granian_version() -> &'static str { }) } -#[pymodule] +#[pymodule(gil_used = false)] fn _granian(py: Python, module: &Bound) -> PyResult<()> { module.add("__version__", get_granian_version())?; + module.add("BUILD_GIL", BUILD_GIL)?; module.add_class::()?; asgi::init_pymodule(module)?; rsgi::init_pymodule(py, module)?; diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index d00ef51..686de1f 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -8,7 +8,7 @@ use super::{ }; use crate::{ callbacks::ArcCBScheduler, - runtime::RuntimeRef, + runtime::{Runtime, RuntimeRef}, utils::log_application_callable_exception, ws::{HyperWebsocket, UpgradeData}, }; @@ -23,14 +23,14 @@ macro_rules! callback_impl_done_http { macro_rules! callback_impl_done_ws { ($self:expr) => { - let _ = $self.proto.get().close(None); + $self.proto.get().close(None); }; } macro_rules! callback_impl_done_err { - ($self:expr, $err:expr) => { + ($self:expr, $py:expr, $err:expr) => { $self.done(); - log_application_callable_exception($err); + log_application_callable_exception($py, $err); }; } @@ -65,8 +65,8 @@ impl CallbackWatcherHTTP { callback_impl_done_http!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -99,8 +99,8 @@ impl CallbackWatcherWebsocket { callback_impl_done_ws!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -115,15 +115,12 @@ pub(crate) fn call_http( body: hyper::body::Incoming, scope: HTTPScope, ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, tx, body); + let protocol = HTTPProtocol::new(rt.clone(), tx, body); - let _ = brt.run(move || { - Python::with_gil(|py| { - let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); + rt.spawn_blocking(move |py| { + cb.get() + .schedule(py, Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap()); }); rx @@ -137,15 +134,14 @@ pub(crate) fn call_ws( upgrade: UpgradeData, scope: WebsocketScope, ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); + let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade); - let _ = brt.run(move || { - Python::with_gil(|py| { - let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); + rt.spawn_blocking(move |py| { + cb.get().schedule( + py, + Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(), + ); }); rx diff --git a/src/rsgi/http.rs b/src/rsgi/http.rs index ee3278b..400cb75 100644 --- a/src/rsgi/http.rs +++ b/src/rsgi/http.rs @@ -1,3 +1,4 @@ +use futures::sink::SinkExt; use http_body_util::BodyExt; use hyper::{header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, StatusCode}; use std::net::SocketAddr; @@ -99,7 +100,7 @@ macro_rules! handle_request_with_ws { let tx_ref = restx.clone(); match $handler_ws(callback, rt, ws, UpgradeData::new(res, restx), scope).await { - Ok((status, consumed, handle)) => match (consumed, handle) { + Ok((status, consumed, stream)) => match (consumed, stream) { (false, _) => { let _ = tx_ref .send( @@ -111,8 +112,8 @@ macro_rules! handle_request_with_ws { ) .await; } - (true, Some(handle)) => { - let _ = handle.await; + (true, Some(mut stream)) => { + let _ = stream.close().await; } _ => {} }, diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index 3728d02..65321d7 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -4,7 +4,7 @@ use hyper::body; use pyo3::{prelude::*, pybacked::PyBackedStr}; use std::{ borrow::Cow, - sync::{atomic, Arc, Mutex, RwLock}, + sync::{Arc, Mutex, RwLock}, }; use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex}; use tokio_tungstenite::tungstenite::Message; @@ -15,11 +15,11 @@ use super::{ }; use crate::{ conversion::FutureResultToPy, - runtime::{future_into_py_futlike, Runtime, RuntimeRef}, - ws::{HyperWebsocket, UpgradeData, WSRxStream, WSStream, WSTxStream}, + runtime::{future_into_py_futlike, RuntimeRef}, + ws::{HyperWebsocket, UpgradeData, WSRxStream, WSTxStream}, }; -pub(crate) type WebsocketDetachedTransport = (i32, bool, Option>); +pub(crate) type WebsocketDetachedTransport = (i32, bool, Option); #[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIHTTPStreamTransport { @@ -183,38 +183,18 @@ impl RSGIHTTPProtocol { #[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIWebsocketTransport { rt: RuntimeRef, - tx: Arc>, + tx: Arc>>, rx: Arc>, - closed: atomic::AtomicBool, } impl RSGIWebsocketTransport { - pub fn new(rt: RuntimeRef, transport: WSStream) -> Self { - let (tx, rx) = transport.split(); + pub fn new(rt: RuntimeRef, tx: Arc>>, rx: WSRxStream) -> Self { Self { rt, - tx: Arc::new(AsyncMutex::new(tx)), + tx, rx: Arc::new(AsyncMutex::new(rx)), - closed: false.into(), } } - - pub fn close(&self) -> Option> { - if self.closed.load(atomic::Ordering::Relaxed) { - return None; - } - self.closed.store(true, atomic::Ordering::Relaxed); - - let tx = self.tx.clone(); - let handle = self.rt.spawn(async move { - if let Ok(mut tx) = tx.try_lock() { - if let Err(err) = tx.close().await { - log::info!("Failed to close websocket with error {:?}", err); - } - } - }); - Some(handle) - } } #[pymethods] @@ -241,11 +221,13 @@ impl RSGIWebsocketTransport { let bdata: Box<[u8]> = data.into(); future_into_py_futlike(self.rt.clone(), py, async move { - if let Ok(mut stream) = transport.try_lock() { - return match stream.send(bdata[..].into()).await { - Ok(()) => FutureResultToPy::None, - _ => FutureResultToPy::Err(error_stream!()), - }; + if let Ok(mut guard) = transport.try_lock() { + if let Some(stream) = &mut *guard { + return match stream.send(bdata[..].into()).await { + Ok(()) => FutureResultToPy::None, + _ => FutureResultToPy::Err(error_stream!()), + }; + } } FutureResultToPy::Err(error_proto!()) }) @@ -255,11 +237,13 @@ impl RSGIWebsocketTransport { let transport = self.tx.clone(); future_into_py_futlike(self.rt.clone(), py, async move { - if let Ok(mut stream) = transport.try_lock() { - return match stream.send(data.into()).await { - Ok(()) => FutureResultToPy::None, - _ => FutureResultToPy::Err(error_stream!()), - }; + if let Ok(mut guard) = transport.try_lock() { + if let Some(stream) = &mut *guard { + return match stream.send(data.into()).await { + Ok(()) => FutureResultToPy::None, + _ => FutureResultToPy::Err(error_stream!()), + }; + } } FutureResultToPy::Err(error_proto!()) }) @@ -272,7 +256,7 @@ pub(crate) struct RSGIWebsocketProtocol { tx: Mutex>>, websocket: Arc>, upgrade: RwLock>, - transport: Arc>>>, + transport: Arc>>, } impl RSGIWebsocketProtocol { @@ -287,7 +271,7 @@ impl RSGIWebsocketProtocol { tx: Mutex::new(Some(tx)), websocket: Arc::new(AsyncMutex::new(websocket)), upgrade: RwLock::new(Some(upgrade)), - transport: Arc::new(Mutex::new(None)), + transport: Arc::new(AsyncMutex::new(None)), } } @@ -304,7 +288,7 @@ impl RSGIWebsocketProtocol { let mut handle = None; if let Ok(mut transport) = self.transport.try_lock() { if let Some(transport) = transport.take() { - handle = transport.get().close(); + handle = Some(transport); } } @@ -322,12 +306,16 @@ impl RSGIWebsocketProtocol { match upgrade.send(None).await { Ok(()) => match (&mut *ws).await { Ok(stream) => { - let mut trx = itransport.lock().unwrap(); - Python::with_gil(|py| { - let pytransport = Py::new(py, RSGIWebsocketTransport::new(rth, stream)).unwrap(); - *trx = Some(pytransport.clone_ref(py)); - FutureResultToPy::Py(pytransport.into_any()) - }) + let (stx, srx) = stream.split(); + { + let mut guard = itransport.lock().await; + *guard = Some(stx); + } + FutureResultToPy::RSGIWSAccept(RSGIWebsocketTransport::new( + rth.clone(), + itransport.clone(), + srx, + )) } _ => FutureResultToPy::Err(error_proto!()), }, diff --git a/src/rsgi/mod.rs b/src/rsgi/mod.rs index 4e8e24a..94a16a8 100644 --- a/src/rsgi/mod.rs +++ b/src/rsgi/mod.rs @@ -4,7 +4,7 @@ mod callbacks; pub(crate) mod conversion; mod errors; mod http; -mod io; +pub(crate) mod io; pub(crate) mod serve; mod types; diff --git a/src/rsgi/serve.rs b/src/rsgi/serve.rs index bc904b9..944d98f 100644 --- a/src/rsgi/serve.rs +++ b/src/rsgi/serve.rs @@ -4,7 +4,7 @@ use super::http::{handle, handle_ws}; use crate::callbacks::CallbackScheduler; use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; -use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal, WorkerSignals}; +use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal}; #[pyclass(frozen, module = "granian._granian")] pub struct RSGIWorker { @@ -30,7 +30,8 @@ impl RSGIWorker { worker_id, socket_fd, threads=1, - blocking_threads=512, + io_blocking_threads=512, + blocking_threads=1, backpressure=256, http_mode="1", http1_opts=None, @@ -47,6 +48,7 @@ impl RSGIWorker { worker_id: i32, socket_fd: i32, threads: usize, + io_blocking_threads: usize, blocking_threads: usize, backpressure: usize, http_mode: &str, @@ -63,6 +65,7 @@ impl RSGIWorker { worker_id, socket_fd, threads, + io_blocking_threads, blocking_threads, backpressure, http_mode, @@ -79,19 +82,19 @@ impl RSGIWorker { fn serve_rth(&self, callback: Py, event_loop: &Bound, signal: Py) { match (self.config.websockets_enabled, self.config.ssl_enabled) { - (false, false) => self._serve_rth(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, false) => self._serve_rth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), - (false, true) => self._serve_rth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, true) => self._serve_rth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, false) => self._serve_rth(callback, event_loop, signal), + (true, false) => self._serve_rth_ws(callback, event_loop, signal), + (false, true) => self._serve_rth_ssl(callback, event_loop, signal), + (true, true) => self._serve_rth_ssl_ws(callback, event_loop, signal), } } fn serve_wth(&self, callback: Py, event_loop: &Bound, signal: Py) { match (self.config.websockets_enabled, self.config.ssl_enabled) { - (false, false) => self._serve_wth(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, false) => self._serve_wth_ws(callback, event_loop, WorkerSignals::Tokio(signal)), - (false, true) => self._serve_wth_ssl(callback, event_loop, WorkerSignals::Tokio(signal)), - (true, true) => self._serve_wth_ssl_ws(callback, event_loop, WorkerSignals::Tokio(signal)), + (false, false) => self._serve_wth(callback, event_loop, signal), + (true, false) => self._serve_wth_ws(callback, event_loop, signal), + (false, true) => self._serve_wth_ssl(callback, event_loop, signal), + (true, true) => self._serve_wth_ssl_ws(callback, event_loop, signal), } } } diff --git a/src/runtime.rs b/src/runtime.rs index 3fae92c..3fc05ca 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -34,7 +34,9 @@ pub trait Runtime: Send + 'static { where F: Future + Send + 'static; - fn blocking(&self) -> BlockingRunner; + fn spawn_blocking(&self, task: F) + where + F: FnOnce(Python) + Send + 'static; } pub trait ContextExt: Runtime { @@ -42,42 +44,49 @@ pub trait ContextExt: Runtime { } pub(crate) struct RuntimeWrapper { - rt: tokio::runtime::Runtime, - br: BlockingRunner, + pub inner: tokio::runtime::Runtime, + br: Arc, pr: Arc, } impl RuntimeWrapper { - pub fn new(blocking_threads: usize, py_loop: Arc) -> Self { + pub fn new(blocking_threads: usize, py_blocking_threads: usize, py_loop: Arc) -> Self { Self { - rt: default_runtime(blocking_threads), - br: BlockingRunner::new(), + inner: default_runtime(blocking_threads), + br: BlockingRunner::new(py_blocking_threads).into(), pr: py_loop, } } - pub fn with_runtime(rt: tokio::runtime::Runtime, py_loop: Arc) -> Self { + pub fn with_runtime(rt: tokio::runtime::Runtime, py_blocking_threads: usize, py_loop: Arc) -> Self { Self { - rt, - br: BlockingRunner::new(), + inner: rt, + br: BlockingRunner::new(py_blocking_threads).into(), pr: py_loop, } } pub fn handler(&self) -> RuntimeRef { - RuntimeRef::new(self.rt.handle().clone(), self.br.clone(), self.pr.clone()) + RuntimeRef::new(self.inner.handle().clone(), self.br.clone(), self.pr.clone()) + } +} + +#[cfg(Py_GIL_DISABLED)] +impl Drop for RuntimeWrapper { + fn drop(&mut self) { + self.br.shutdown(); } } #[derive(Clone)] pub struct RuntimeRef { pub inner: tokio::runtime::Handle, - pub innerb: BlockingRunner, + innerb: Arc, innerp: Arc, } impl RuntimeRef { - pub fn new(rt: tokio::runtime::Handle, br: BlockingRunner, pyloop: Arc) -> Self { + pub fn new(rt: tokio::runtime::Handle, br: Arc, pyloop: Arc) -> Self { Self { inner: rt, innerb: br, @@ -103,8 +112,11 @@ impl Runtime for RuntimeRef { self.inner.spawn(fut) } - fn blocking(&self) -> BlockingRunner { - self.innerb.clone() + fn spawn_blocking(&self, task: F) + where + F: FnOnce(Python) + Send + 'static, + { + _ = self.innerb.run(task); } } @@ -122,7 +134,12 @@ fn default_runtime(blocking_threads: usize) -> tokio::runtime::Runtime { .unwrap() } -pub(crate) fn init_runtime_mt(threads: usize, blocking_threads: usize, py_loop: Arc) -> RuntimeWrapper { +pub(crate) fn init_runtime_mt( + threads: usize, + blocking_threads: usize, + py_blocking_threads: usize, + py_loop: Arc, +) -> RuntimeWrapper { RuntimeWrapper::with_runtime( RuntimeBuilder::new_multi_thread() .worker_threads(threads) @@ -130,12 +147,17 @@ pub(crate) fn init_runtime_mt(threads: usize, blocking_threads: usize, py_loop: .enable_all() .build() .unwrap(), + py_blocking_threads, py_loop, ) } -pub(crate) fn init_runtime_st(blocking_threads: usize, py_loop: Arc) -> RuntimeWrapper { - RuntimeWrapper::new(blocking_threads, py_loop) +pub(crate) fn init_runtime_st( + blocking_threads: usize, + py_blocking_threads: usize, + py_loop: Arc, +) -> RuntimeWrapper { + RuntimeWrapper::new(blocking_threads, py_blocking_threads, py_loop) } // NOTE: @@ -151,16 +173,11 @@ where { let aw = Py::new(py, PyIterAwaitable::new())?; let py_fut = aw.clone_ref(py); - let rb = rt.blocking(); + let rth = rt.clone(); rt.spawn(async move { let result = fut.await; - let _ = rb.run(move || { - Python::with_gil(|py| { - aw.get().set_result(py, result); - drop(aw); - }); - }); + rth.spawn_blocking(move |py| PyIterAwaitable::set_result(aw, py, result)); }); Ok(py_fut.into_any().into_bound(py)) @@ -181,16 +198,12 @@ where let event_loop = rt.py_event_loop(py); let (aw, cancel_tx) = PyFutureAwaitable::new(event_loop).to_spawn(py)?; let py_fut = aw.clone_ref(py); - let rb = rt.blocking(); + let rth = rt.clone(); rt.spawn(async move { tokio::select! { - result = fut => { - let _ = rb.run(move || Python::with_gil(|py| PyFutureAwaitable::set_result(aw, py, result))); - }, - () = cancel_tx.notified() => { - let _ = rb.run(move || Python::with_gil(|_| drop(aw))); - } + result = fut => rth.spawn_blocking(move |py| PyFutureAwaitable::set_result(aw, py, result)), + () = cancel_tx.notified() => rth.spawn_blocking(move |py| aw.drop_ref(py)), } }); @@ -207,7 +220,7 @@ where let event_loop = rt.py_event_loop(py); let event_loop_ref = event_loop.clone_ref(py); let cancel_tx = Arc::new(tokio::sync::Notify::new()); - let rb = rt.blocking(); + let rth = rt.clone(); let py_fut = event_loop.call_method0(py, pyo3::intern!(py, "create_future"))?; py_fut.call_method1( @@ -222,25 +235,21 @@ where rt.spawn(async move { tokio::select! { result = fut => { - let _ = rb.run(move || { - Python::with_gil(|py| { - let pyres = result.into_pyobject(py).map(Bound::unbind); - let (cb, value) = match pyres { - Ok(val) => (fut_ref.getattr(py, pyo3::intern!(py, "set_result")).unwrap(), val), - Err(err) => (fut_ref.getattr(py, pyo3::intern!(py, "set_exception")).unwrap(), err.into_py_any(py).unwrap()) - }; - let _ = event_loop_ref.call_method1(py, pyo3::intern!(py, "call_soon_threadsafe"), (PyFutureResultSetter, cb, value)); - drop(fut_ref); - drop(event_loop_ref); - }); + rth.spawn_blocking(move |py| { + let pyres = result.into_pyobject(py).map(Bound::unbind); + let (cb, value) = match pyres { + Ok(val) => (fut_ref.getattr(py, pyo3::intern!(py, "set_result")).unwrap(), val), + Err(err) => (fut_ref.getattr(py, pyo3::intern!(py, "set_exception")).unwrap(), err.into_py_any(py).unwrap()) + }; + let _ = event_loop_ref.call_method1(py, pyo3::intern!(py, "call_soon_threadsafe"), (PyFutureResultSetter, cb, value)); + fut_ref.drop_ref(py); + event_loop_ref.drop_ref(py); }); }, () = cancel_tx.notified() => { - let _ = rb.run(move || { - Python::with_gil(|_| { - drop(fut_ref); - drop(event_loop_ref); - }); + rth.spawn_blocking(move |py| { + fut_ref.drop_ref(py); + event_loop_ref.drop_ref(py); }); } } @@ -256,9 +265,8 @@ pub(crate) fn empty_future_into_py(py: Python) -> PyResult> { } #[allow(unused_must_use)] -pub(crate) fn run_until_complete(rt: R, event_loop: Bound, fut: F) -> PyResult<()> +pub(crate) fn run_until_complete(rt: RuntimeWrapper, event_loop: Bound, fut: F) -> PyResult<()> where - R: Runtime + ContextExt + Clone, F: Future> + Send + 'static, { let result_tx = Arc::new(Mutex::new(None)); @@ -268,7 +276,7 @@ where let loop_tx = event_loop.clone().unbind(); let future_tx = py_fut.clone().unbind(); - rt.spawn(async move { + rt.inner.spawn(async move { let _ = fut.await; if let Ok(mut result) = result_tx.lock() { *result = Some(()); @@ -279,8 +287,8 @@ where Python::with_gil(move |py| { let res_method = future_tx.getattr(py, "set_result").unwrap(); let _ = loop_tx.call_method(py, "call_soon_threadsafe", (res_method, py.None()), None); - drop(future_tx); - drop(loop_tx); + future_tx.drop_ref(py); + loop_tx.drop_ref(py); }); }); @@ -294,5 +302,5 @@ pub(crate) fn block_on_local(rt: &RuntimeWrapper, local: LocalSet, fut: F) where F: Future + 'static, { - local.block_on(&rt.rt, fut); + local.block_on(&rt.inner, fut); } diff --git a/src/utils.rs b/src/utils.rs index 4ad10e8..0fee64f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,4 @@ -use pyo3::types::PyTracebackMethods; +use pyo3::{prelude::*, types::PyTracebackMethods}; pub(crate) fn header_contains_value( headers: &hyper::HeaderMap, @@ -41,13 +41,11 @@ fn trim_end(data: &[u8]) -> &[u8] { } #[inline] -pub(crate) fn log_application_callable_exception(err: &pyo3::PyErr) { - let tb = pyo3::Python::with_gil(|py| { - let tb = match err.traceback(py).map(|t| t.format()) { - Some(Ok(tb)) => tb, - _ => String::new(), - }; - format!("{tb}{err}") - }); - log::error!("Application callable raised an exception\n{tb}"); +pub(crate) fn log_application_callable_exception(py: Python, err: &pyo3::PyErr) { + let tb = match err.traceback(py).map(|t| t.format()) { + Some(Ok(tb)) => tb, + _ => String::new(), + }; + let errs = format!("{tb}{err}"); + log::error!("Application callable raised an exception\n{errs}"); } diff --git a/src/workers.rs b/src/workers.rs index ebb63a5..fcaf6f9 100644 --- a/src/workers.rs +++ b/src/workers.rs @@ -12,11 +12,6 @@ use super::rsgi::serve::RSGIWorker; use super::tls::{load_certs as tls_load_certs, load_private_key as tls_load_pkey}; use super::wsgi::serve::WSGIWorker; -pub(crate) enum WorkerSignals { - Tokio(Py), - Crossbeam(Py), -} - #[pyclass(frozen, module = "granian._granian")] pub(crate) struct WorkerSignal { pub rx: Mutex>>, @@ -95,6 +90,7 @@ pub(crate) struct WorkerConfig { socket_fd: i32, pub threads: usize, pub blocking_threads: usize, + pub io_blocking_threads: usize, pub backpressure: usize, pub http_mode: String, pub http1_opts: HTTP1Config, @@ -111,6 +107,7 @@ impl WorkerConfig { id: i32, socket_fd: i32, threads: usize, + io_blocking_threads: usize, blocking_threads: usize, backpressure: usize, http_mode: &str, @@ -127,6 +124,7 @@ impl WorkerConfig { socket_fd, threads, blocking_threads, + io_blocking_threads, backpressure, http_mode: http_mode.into(), http1_opts, @@ -592,9 +590,9 @@ macro_rules! serve_rth { &self, callback: Py, event_loop: &Bound, - signal: crate::workers::WorkerSignals, + signal: Py, ) { - pyo3_log::init(); + _ = pyo3_log::try_init(); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); @@ -609,29 +607,14 @@ macro_rules! serve_rth { let rt = crate::runtime::init_runtime_mt( self.config.threads, + self.config.io_blocking_threads, self.config.blocking_threads, std::sync::Arc::new(event_loop.clone().unbind()), ); let rth = rt.handler(); - let mut srx = match signal { - crate::workers::WorkerSignals::Crossbeam(sig) => { - let (stx, srx) = tokio::sync::watch::channel(false); - std::thread::spawn(move || { - let pyrx = sig.get().rx.lock().unwrap().take().unwrap(); - let _ = pyrx.recv(); - stx.send(true).unwrap(); + let mut srx = signal.get().rx.lock().unwrap().take().unwrap(); - Python::with_gil(|py| { - let _ = sig.get().release(py); - drop(sig); - }); - }); - srx - } - crate::workers::WorkerSignals::Tokio(sig) => sig.get().rx.lock().unwrap().take().unwrap(), - }; - - let main_loop = crate::runtime::run_until_complete(rt.handler(), event_loop.clone(), async move { + let main_loop = crate::runtime::run_until_complete(rt, event_loop.clone(), async move { crate::workers::loop_match!( http_mode, http_upgrades, @@ -654,13 +637,10 @@ macro_rules! serve_rth { Ok(()) }); - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); + } } }; } @@ -671,10 +651,9 @@ macro_rules! serve_rth_ssl { &self, callback: Py, event_loop: &Bound, - // context: Bound, - signal: crate::workers::WorkerSignals, + signal: Py, ) { - pyo3_log::init(); + _ = pyo3_log::try_init(); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); @@ -686,34 +665,18 @@ macro_rules! serve_rth_ssl { let http2_opts = self.config.http2_opts.clone(); let backpressure = self.config.backpressure.clone(); let tls_cfg = self.config.tls_cfg(); - // let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop.clone(), context); let callback_wrapper = std::sync::Arc::new(callback); let rt = crate::runtime::init_runtime_mt( self.config.threads, + self.config.io_blocking_threads, self.config.blocking_threads, std::sync::Arc::new(event_loop.clone().unbind()), ); let rth = rt.handler(); - let mut srx = match signal { - crate::workers::WorkerSignals::Crossbeam(sig) => { - let (stx, srx) = tokio::sync::watch::channel(false); - std::thread::spawn(move || { - let pyrx = sig.get().rx.lock().unwrap().take().unwrap(); - let _ = pyrx.recv(); - stx.send(true).unwrap(); + let mut srx = signal.get().rx.lock().unwrap().take().unwrap(); - Python::with_gil(|py| { - let _ = sig.get().release(py); - drop(sig); - }); - }); - srx - } - crate::workers::WorkerSignals::Tokio(sig) => sig.get().rx.lock().unwrap().take().unwrap(), - }; - - let main_loop = crate::runtime::run_until_complete(rt.handler(), event_loop.clone(), async move { + let main_loop = crate::runtime::run_until_complete(rt, event_loop.clone(), async move { crate::workers::loop_match_tls!( http_mode, http_upgrades, @@ -737,20 +700,16 @@ macro_rules! serve_rth_ssl { Ok(()) }); - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); + } } }; } macro_rules! serve_wth_inner { ($self:expr, $target:expr, $callback:expr, $event_loop:expr, $wid:expr, $workers:expr, $srx:expr) => { - // let callback_wrapper = crate::callbacks::CallbackWrapper::new($callback, $event_loop.clone(), $context); let callback_wrapper = std::sync::Arc::new($callback); let py_loop = std::sync::Arc::new($event_loop.clone().unbind()); @@ -762,6 +721,7 @@ macro_rules! serve_wth_inner { let http_upgrades = $self.config.websockets_enabled; let http1_opts = $self.config.http1_opts.clone(); let http2_opts = $self.config.http2_opts.clone(); + let io_blocking_threads = $self.config.io_blocking_threads.clone(); let blocking_threads = $self.config.blocking_threads.clone(); let backpressure = $self.config.backpressure.clone(); let callback_wrapper = callback_wrapper.clone(); @@ -769,7 +729,7 @@ macro_rules! serve_wth_inner { let mut srx = $srx.clone(); $workers.push(std::thread::spawn(move || { - let rt = crate::runtime::init_runtime_st(blocking_threads, py_loop); + let rt = crate::runtime::init_runtime_st(io_blocking_threads, blocking_threads, py_loop); let rth = rt.handler(); let local = tokio::task::LocalSet::new(); @@ -791,8 +751,6 @@ macro_rules! serve_wth_inner { ); log::info!("Stopping worker-{} runtime-{}", $wid, thread_id + 1); - - Python::with_gil(|_| drop(callback_wrapper)); }); Python::with_gil(|_| drop(rt)); @@ -802,15 +760,14 @@ macro_rules! serve_wth_inner { } macro_rules! serve_wth { - ($func_name: ident, $target:expr) => { + ($func_name:ident, $target:expr) => { fn $func_name( &self, callback: Py, event_loop: &Bound, - // context: Bound, - signal: crate::workers::WorkerSignals, + signal: Py, ) { - pyo3_log::init(); + _ = pyo3_log::try_init(); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); @@ -819,44 +776,21 @@ macro_rules! serve_wth { let mut workers = vec![]; crate::workers::serve_wth_inner!(self, $target, callback, event_loop, worker_id, workers, srx); - match signal { - crate::workers::WorkerSignals::Tokio(sig) => { - let rtm = crate::runtime::init_runtime_mt(1, 1, std::sync::Arc::new(event_loop.clone().unbind())); - let mut pyrx = sig.get().rx.lock().unwrap().take().unwrap(); - let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop.clone(), async move { - let _ = pyrx.changed().await; - stx.send(true).unwrap(); - log::info!("Stopping worker-{}", worker_id); - while let Some(worker) = workers.pop() { - worker.join().unwrap(); - } - Ok(()) - }); - - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + let rtm = crate::runtime::init_runtime_mt(1, 1, 1, std::sync::Arc::new(event_loop.clone().unbind())); + let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); + let main_loop = crate::runtime::run_until_complete(rtm, event_loop.clone(), async move { + let _ = pyrx.changed().await; + stx.send(true).unwrap(); + log::info!("Stopping worker-{}", worker_id); + while let Some(worker) = workers.pop() { + worker.join().unwrap(); } - crate::workers::WorkerSignals::Crossbeam(sig) => { - std::thread::spawn(move || { - let pyrx = sig.get().rx.lock().unwrap().take().unwrap(); - let _ = pyrx.recv(); - stx.send(true).unwrap(); - log::info!("Stopping worker-{}", worker_id); - while let Some(worker) = workers.pop() { - worker.join().unwrap(); - } + Ok(()) + }); - Python::with_gil(|py| { - let _ = sig.get().release(py); - drop(sig); - }); - }); - } + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); } } }; @@ -864,7 +798,6 @@ macro_rules! serve_wth { macro_rules! serve_wth_ssl_inner { ($self:expr, $target:expr, $callback:expr, $event_loop:expr, $wid:expr, $workers:expr, $srx:expr) => { - // let callback_wrapper = crate::callbacks::CallbackWrapper::new($callback, $event_loop.clone(), $context); let callback_wrapper = std::sync::Arc::new($callback); let py_loop = std::sync::Arc::new($event_loop.clone().unbind()); @@ -877,6 +810,7 @@ macro_rules! serve_wth_ssl_inner { let http1_opts = $self.config.http1_opts.clone(); let http2_opts = $self.config.http2_opts.clone(); let tls_cfg = $self.config.tls_cfg(); + let io_blocking_threads = $self.config.io_blocking_threads.clone(); let blocking_threads = $self.config.blocking_threads.clone(); let backpressure = $self.config.backpressure.clone(); let callback_wrapper = callback_wrapper.clone(); @@ -884,7 +818,7 @@ macro_rules! serve_wth_ssl_inner { let mut srx = $srx.clone(); $workers.push(std::thread::spawn(move || { - let rt = crate::runtime::init_runtime_st(blocking_threads, py_loop); + let rt = crate::runtime::init_runtime_st(io_blocking_threads, blocking_threads, py_loop); let rth = rt.handler(); let local = tokio::task::LocalSet::new(); @@ -914,15 +848,14 @@ macro_rules! serve_wth_ssl_inner { } macro_rules! serve_wth_ssl { - ($func_name: ident, $target:expr) => { + ($func_name:ident, $target:expr) => { fn $func_name( &self, callback: Py, event_loop: &Bound, - // context: Bound, - signal: crate::workers::WorkerSignals, + signal: Py, ) { - pyo3_log::init(); + _ = pyo3_log::try_init(); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); @@ -931,41 +864,21 @@ macro_rules! serve_wth_ssl { let mut workers = vec![]; crate::workers::serve_wth_ssl_inner!(self, $target, callback, event_loop, worker_id, workers, srx); - match signal { - crate::workers::WorkerSignals::Tokio(sig) => { - let rtm = crate::runtime::init_runtime_mt(1, 1, std::sync::Arc::new(event_loop.clone().unbind())); - let mut pyrx = sig.get().rx.lock().unwrap().take().unwrap(); - let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop.clone(), async move { - let _ = pyrx.changed().await; - stx.send(true).unwrap(); - log::info!("Stopping worker-{}", worker_id); - while let Some(worker) = workers.pop() { - worker.join().unwrap(); - } - Ok(()) - }); - - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + let rtm = crate::runtime::init_runtime_mt(1, 1, 1, std::sync::Arc::new(event_loop.clone().unbind())); + let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); + let main_loop = crate::runtime::run_until_complete(rtm, event_loop.clone(), async move { + let _ = pyrx.changed().await; + stx.send(true).unwrap(); + log::info!("Stopping worker-{}", worker_id); + while let Some(worker) = workers.pop() { + worker.join().unwrap(); } - crate::workers::WorkerSignals::Crossbeam(sig) => { - let py = event_loop.py(); - let pyrx = sig.get().rx.lock().unwrap().take().unwrap(); + Ok(()) + }); - py.allow_threads(|| { - let _ = pyrx.recv(); - stx.send(true).unwrap(); - log::info!("Stopping worker-{}", worker_id); - while let Some(worker) = workers.pop() { - worker.join().unwrap(); - } - }); - } + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); } } }; diff --git a/src/wsgi/callbacks.rs b/src/wsgi/callbacks.rs index 5946292..1c48e0a 100644 --- a/src/wsgi/callbacks.rs +++ b/src/wsgi/callbacks.rs @@ -7,7 +7,7 @@ use itertools::Itertools; use percent_encoding::percent_decode_str; use pyo3::{ prelude::*, - types::{IntoPyDict, PyBytes, PyDict}, + types::{PyBytes, PyDict}, }; use std::net::SocketAddr; use tokio::sync::oneshot; @@ -16,93 +16,21 @@ use super::{io::WSGIProtocol, types::WSGIBody}; use crate::{ callbacks::ArcCBScheduler, http::{empty_body, HTTPResponseBody}, - runtime::RuntimeRef, + runtime::{Runtime, RuntimeRef}, utils::log_application_callable_exception, }; -#[inline] -fn run_callback( - rt: RuntimeRef, - tx: oneshot::Sender<(u16, HeaderMap, HTTPResponseBody)>, - cbs: ArcCBScheduler, - mut parts: request::Parts, - server_addr: SocketAddr, - client_addr: SocketAddr, - scheme: &str, - body: body::Incoming, -) { - let (path_raw, query_string) = parts - .uri - .path_and_query() - .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); - let path = percent_decode_str(path_raw).collect_vec(); - let version = match parts.version { - Version::HTTP_10 => "HTTP/1", - Version::HTTP_11 => "HTTP/1.1", - Version::HTTP_2 => "HTTP/2", - Version::HTTP_3 => "HTTP/3", - _ => "HTTP/1", +macro_rules! environ_set { + ($py:expr, $env:expr, $key:expr, $val:expr) => { + $env.set_item(pyo3::intern!($py, $key), $val).unwrap() }; - let server = (server_addr.ip().to_string(), server_addr.port().to_string()); - let client = client_addr.to_string(); - let content_type = parts.headers.remove(header::CONTENT_TYPE); - let content_len = parts.headers.remove(header::CONTENT_LENGTH); - let mut headers = Vec::with_capacity(parts.headers.len()); - for key in parts.headers.keys() { - headers.push(( - format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()), - parts - .headers - .get_all(key) - .iter() - .map(|v| v.to_str().unwrap_or_default()) - .join(","), - )); - } - if !parts.headers.contains_key(header::HOST) { - let host = parts.uri.authority().map_or("", Authority::as_str); - headers.push(("HTTP_HOST".to_string(), host.to_string())); - } +} - let _ = Python::with_gil(|py| -> PyResult<()> { - let proto = Py::new(py, WSGIProtocol::new(tx))?; - let callback = cbs.get().cb.clone_ref(py); - let environ = PyDict::new(py); - environ.set_item(pyo3::intern!(py, "SERVER_PROTOCOL"), version)?; - environ.set_item(pyo3::intern!(py, "SERVER_NAME"), server.0)?; - environ.set_item(pyo3::intern!(py, "SERVER_PORT"), server.1)?; - environ.set_item(pyo3::intern!(py, "REMOTE_ADDR"), client)?; - environ.set_item(pyo3::intern!(py, "REQUEST_METHOD"), parts.method.as_str())?; - environ.set_item( - pyo3::intern!(py, "PATH_INFO"), - PyBytes::new(py, &path).call_method1(pyo3::intern!(py, "decode"), (pyo3::intern!(py, "latin1"),))?, - )?; - environ.set_item(pyo3::intern!(py, "QUERY_STRING"), query_string)?; - environ.set_item(pyo3::intern!(py, "wsgi.url_scheme"), scheme)?; - environ.set_item(pyo3::intern!(py, "wsgi.input"), Py::new(py, WSGIBody::new(rt, body))?)?; - if let Some(content_type) = content_type { - environ.set_item( - pyo3::intern!(py, "CONTENT_TYPE"), - content_type.to_str().unwrap_or_default(), - )?; - } - if let Some(content_len) = content_len { - environ.set_item( - pyo3::intern!(py, "CONTENT_LENGTH"), - content_len.to_str().unwrap_or_default(), - )?; - } - environ.update(headers.into_py_dict(py).unwrap().as_mapping())?; - - if let Err(err) = callback.call1(py, (proto.clone_ref(py), environ)) { - log_application_callable_exception(&err); - if let Some(tx) = proto.get().tx() { - let _ = tx.send((500, HeaderMap::new(), empty_body())); - } - } - - Ok(()) - }); +macro_rules! environ_set_header { + ($py:expr, $env:expr, $key:expr, $val:expr) => { + $env.set_item(format!("HTTP_{}", $key.as_str().replace('-', "_").to_uppercase()), $val) + .unwrap() + }; } #[inline(always)] @@ -112,13 +40,93 @@ pub(crate) fn call_http( server_addr: SocketAddr, client_addr: SocketAddr, scheme: &str, - req: request::Parts, + mut req: request::Parts, body: body::Incoming, ) -> oneshot::Receiver<(u16, HeaderMap, HTTPResponseBody)> { - let scheme: std::sync::Arc = scheme.into(); let (tx, rx) = oneshot::channel(); - tokio::task::spawn_blocking(move || { - run_callback(rt, tx, cb, req, server_addr, client_addr, &scheme, body); + let proto = WSGIProtocol::new(tx); + let body_wrapper = WSGIBody::new(rt.clone(), body); + + let scheme: Box = scheme.into(); + let version = match req.version { + Version::HTTP_10 => "HTTP/1", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2", + Version::HTTP_3 => "HTTP/3", + _ => "HTTP/1", + }; + let (path, query_string): (Vec, Box) = req.uri.path_and_query().map_or_else( + || (vec![], "".into()), + |pq| { + ( + percent_decode_str(pq.path()).collect_vec(), + pq.query().unwrap_or("").into(), + ) + }, + ); + let server = (server_addr.ip().to_string(), server_addr.port().to_string()); + + rt.spawn_blocking(move |py| { + let callback = cb.get().cb.clone_ref(py); + let proto = Py::new(py, proto).unwrap(); + let body = Py::new(py, body_wrapper).unwrap(); + + let environ = PyDict::new(py); + environ_set!(py, environ, "SERVER_PROTOCOL", version); + environ_set!(py, environ, "SERVER_NAME", server.0); + environ_set!(py, environ, "SERVER_PORT", server.1); + environ_set!(py, environ, "REMOTE_ADDR", client_addr.to_string()); + environ_set!(py, environ, "REQUEST_METHOD", req.method.as_str()); + environ_set!( + py, + environ, + "PATH_INFO", + PyBytes::new(py, &path) + .call_method1(pyo3::intern!(py, "decode"), (pyo3::intern!(py, "latin1"),)) + .unwrap() + ); + environ_set!(py, environ, "QUERY_STRING", &query_string[..]); + environ_set!(py, environ, "wsgi.url_scheme", &scheme[..]); + environ_set!(py, environ, "wsgi.input", body); + + if let Some(content_type) = req.headers.remove(header::CONTENT_TYPE) { + environ_set!(py, environ, "CONTENT_TYPE", content_type.to_str().unwrap_or_default()); + } + if let Some(content_len) = req.headers.remove(header::CONTENT_LENGTH) { + environ_set!(py, environ, "CONTENT_LENGTH", content_len.to_str().unwrap_or_default()); + } + + for key in req.headers.keys() { + environ_set_header!( + py, + environ, + key, + req.headers + .get_all(key) + .iter() + .map(|v| v.to_str().unwrap_or_default()) + .join(",") + ); + } + if !req.headers.contains_key(header::HOST) { + environ_set!( + py, + environ, + "HTTP_HOST", + req.uri.authority().map_or("", Authority::as_str) + ); + } + + if let Err(err) = callback.call1(py, (proto.clone_ref(py), environ)) { + log_application_callable_exception(py, &err); + if let Some(tx) = proto.get().tx() { + let _ = tx.send((500, HeaderMap::new(), empty_body())); + } + } + + proto.drop_ref(py); + callback.drop_ref(py); }); + rx } diff --git a/src/wsgi/io.rs b/src/wsgi/io.rs index bccc9fa..06ec0ee 100644 --- a/src/wsgi/io.rs +++ b/src/wsgi/io.rs @@ -8,6 +8,7 @@ use pyo3::{prelude::*, pybacked::PyBackedStr}; use std::{borrow::Cow, sync::Mutex}; use tokio::sync::{mpsc, oneshot}; +use super::utils::py_allow_threads; use crate::{ http::{HTTPResponseBody, HV_SERVER}, utils::log_application_callable_exception, @@ -82,14 +83,14 @@ impl WSGIProtocol { }, Err(err) => { if !err.is_instance_of::(py) { - log_application_callable_exception(&err); + log_application_callable_exception(py, &err); } let _ = body.call_method0(pyo3::intern!(py, "close")); closed = true; None } } { - if py.allow_threads(|| body_tx.blocking_send(Ok(frame))).is_ok() { + if py_allow_threads!(py, { body_tx.blocking_send(Ok(frame)) }).is_ok() { continue; } } diff --git a/src/wsgi/mod.rs b/src/wsgi/mod.rs index 66dcd7b..ea3a6c4 100644 --- a/src/wsgi/mod.rs +++ b/src/wsgi/mod.rs @@ -3,3 +3,4 @@ mod http; mod io; pub(crate) mod serve; mod types; +mod utils; diff --git a/src/wsgi/serve.rs b/src/wsgi/serve.rs index f10922e..b31bfa5 100644 --- a/src/wsgi/serve.rs +++ b/src/wsgi/serve.rs @@ -4,9 +4,7 @@ use super::http::handle; use crate::callbacks::CallbackScheduler; use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; -use crate::workers::{ - serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignalSync, WorkerSignals, -}; +use crate::workers::{WorkerConfig, WorkerSignalSync}; #[pyclass(frozen, module = "granian._granian")] pub struct WSGIWorker { @@ -14,10 +12,213 @@ pub struct WSGIWorker { } impl WSGIWorker { - serve_rth!(_serve_rth, handle); - serve_wth!(_serve_wth, handle); - serve_rth_ssl!(_serve_rth_ssl, handle); - serve_wth_ssl!(_serve_wth_ssl, handle); + fn _serve_rth( + &self, + py: Python, + callback: Py, + event_loop: &Bound, + signal: Py, + ) { + _ = pyo3_log::try_init(); + + let worker_id = self.config.id; + log::info!("Started worker-{}", worker_id); + + let tcp_listener = self.config.tcp_listener(); + let http_mode = self.config.http_mode.clone(); + let http_upgrades = self.config.websockets_enabled; + let http1_opts = self.config.http1_opts.clone(); + let http2_opts = self.config.http2_opts.clone(); + let backpressure = self.config.backpressure; + let callback_wrapper = std::sync::Arc::new(callback); + + let rt = crate::runtime::init_runtime_mt( + self.config.threads, + self.config.io_blocking_threads, + self.config.blocking_threads, + std::sync::Arc::new(event_loop.clone().unbind()), + ); + let rth = rt.handler(); + + let (stx, mut srx) = tokio::sync::watch::channel(false); + let main_loop = rt.inner.spawn(async move { + crate::workers::loop_match!( + http_mode, + http_upgrades, + tcp_listener, + srx, + backpressure, + rth, + callback_wrapper, + tokio::spawn, + hyper_util::rt::TokioExecutor::new, + http1_opts, + http2_opts, + hyper_util::rt::TokioIo::new, + handle + ); + + log::info!("Stopping worker-{}", worker_id); + + Python::with_gil(|_| drop(callback_wrapper)); + }); + + let pysig = signal.clone_ref(py); + std::thread::spawn(move || { + let pyrx = pysig.get().rx.lock().unwrap().take().unwrap(); + _ = pyrx.recv(); + stx.send(true).unwrap(); + + while !main_loop.is_finished() { + std::thread::sleep(std::time::Duration::from_millis(1)); + } + + Python::with_gil(|py| { + _ = pysig.get().release(py); + drop(pysig); + }); + }); + + _ = signal.get().qs.call_method0(py, pyo3::intern!(py, "wait")); + } + + fn _serve_wth( + &self, + py: Python, + callback: Py, + event_loop: &Bound, + signal: Py, + ) { + _ = pyo3_log::try_init(); + + let worker_id = self.config.id; + log::info!("Started worker-{}", worker_id); + + let (stx, srx) = tokio::sync::watch::channel(false); + let mut workers = vec![]; + crate::workers::serve_wth_inner!(self, handle, callback, event_loop, worker_id, workers, srx); + + let pysig = signal.clone_ref(py); + std::thread::spawn(move || { + let pyrx = pysig.get().rx.lock().unwrap().take().unwrap(); + _ = pyrx.recv(); + stx.send(true).unwrap(); + log::info!("Stopping worker-{}", worker_id); + while let Some(worker) = workers.pop() { + worker.join().unwrap(); + } + + Python::with_gil(|py| { + _ = pysig.get().release(py); + drop(pysig); + }); + }); + + _ = signal.get().qs.call_method0(py, pyo3::intern!(py, "wait")); + } + + fn _serve_rth_ssl( + &self, + py: Python, + callback: Py, + event_loop: &Bound, + signal: Py, + ) { + _ = pyo3_log::try_init(); + + let worker_id = self.config.id; + log::info!("Started worker-{}", worker_id); + + let tcp_listener = self.config.tcp_listener(); + let http_mode = self.config.http_mode.clone(); + let http_upgrades = self.config.websockets_enabled; + let http1_opts = self.config.http1_opts.clone(); + let http2_opts = self.config.http2_opts.clone(); + let backpressure = self.config.backpressure; + let tls_cfg = self.config.tls_cfg(); + let callback_wrapper = std::sync::Arc::new(callback); + + let rt = crate::runtime::init_runtime_mt( + self.config.threads, + self.config.io_blocking_threads, + self.config.blocking_threads, + std::sync::Arc::new(event_loop.clone().unbind()), + ); + let rth = rt.handler(); + + let (stx, mut srx) = tokio::sync::watch::channel(false); + rt.inner.spawn(async move { + crate::workers::loop_match_tls!( + http_mode, + http_upgrades, + tcp_listener, + tls_cfg, + srx, + backpressure, + rth, + callback_wrapper, + tokio::spawn, + hyper_util::rt::TokioExecutor::new, + http1_opts, + http2_opts, + hyper_util::rt::TokioIo::new, + handle + ); + + log::info!("Stopping worker-{}", worker_id); + + Python::with_gil(|_| drop(callback_wrapper)); + }); + + let pysig = signal.clone_ref(py); + std::thread::spawn(move || { + let pyrx = pysig.get().rx.lock().unwrap().take().unwrap(); + _ = pyrx.recv(); + stx.send(true).unwrap(); + + Python::with_gil(|py| { + _ = pysig.get().release(py); + drop(pysig); + }); + }); + + _ = signal.get().qs.call_method0(py, pyo3::intern!(py, "wait")); + } + + fn _serve_wth_ssl( + &self, + py: Python, + callback: Py, + event_loop: &Bound, + signal: Py, + ) { + _ = pyo3_log::try_init(); + + let worker_id = self.config.id; + log::info!("Started worker-{}", worker_id); + + let (stx, srx) = tokio::sync::watch::channel(false); + let mut workers = vec![]; + crate::workers::serve_wth_ssl_inner!(self, handle, callback, event_loop, worker_id, workers, srx); + + let pysig = signal.clone_ref(py); + std::thread::spawn(move || { + let pyrx = pysig.get().rx.lock().unwrap().take().unwrap(); + _ = pyrx.recv(); + stx.send(true).unwrap(); + log::info!("Stopping worker-{}", worker_id); + while let Some(worker) = workers.pop() { + worker.join().unwrap(); + } + + Python::with_gil(|py| { + _ = pysig.get().release(py); + drop(pysig); + }); + }); + + _ = signal.get().qs.call_method0(py, pyo3::intern!(py, "wait")); + } } #[pymethods] @@ -28,7 +229,8 @@ impl WSGIWorker { worker_id, socket_fd, threads=1, - blocking_threads=512, + io_blocking_threads=512, + blocking_threads=1, backpressure=128, http_mode="1", http1_opts=None, @@ -44,6 +246,7 @@ impl WSGIWorker { worker_id: i32, socket_fd: i32, threads: usize, + io_blocking_threads: usize, blocking_threads: usize, backpressure: usize, http_mode: &str, @@ -59,6 +262,7 @@ impl WSGIWorker { worker_id, socket_fd, threads, + io_blocking_threads, blocking_threads, backpressure, http_mode, @@ -73,17 +277,29 @@ impl WSGIWorker { }) } - fn serve_rth(&self, callback: Py, event_loop: &Bound, signal: Py) { + fn serve_rth( + &self, + py: Python, + callback: Py, + event_loop: &Bound, + signal: Py, + ) { match self.config.ssl_enabled { - false => self._serve_rth(callback, event_loop, WorkerSignals::Crossbeam(signal)), - true => self._serve_rth_ssl(callback, event_loop, WorkerSignals::Crossbeam(signal)), + false => self._serve_rth(py, callback, event_loop, signal), + true => self._serve_rth_ssl(py, callback, event_loop, signal), } } - fn serve_wth(&self, callback: Py, event_loop: &Bound, signal: Py) { + fn serve_wth( + &self, + py: Python, + callback: Py, + event_loop: &Bound, + signal: Py, + ) { match self.config.ssl_enabled { - false => self._serve_wth(callback, event_loop, WorkerSignals::Crossbeam(signal)), - true => self._serve_wth_ssl(callback, event_loop, WorkerSignals::Crossbeam(signal)), + false => self._serve_wth(py, callback, event_loop, signal), + true => self._serve_wth_ssl(py, callback, event_loop, signal), } } } diff --git a/src/wsgi/types.rs b/src/wsgi/types.rs index 9a26233..9ad2400 100644 --- a/src/wsgi/types.rs +++ b/src/wsgi/types.rs @@ -9,6 +9,7 @@ use std::sync::{Arc, Mutex}; use tokio::sync::Mutex as AsyncMutex; use tokio_util::bytes::{BufMut, BytesMut}; +use super::utils::py_allow_threads; use crate::{conversion::BytesToPy, runtime::RuntimeRef}; const LINE_SPLIT: u8 = u8::from_be_bytes(*b"\n"); @@ -74,7 +75,7 @@ impl WSGIBody { #[allow(clippy::map_unwrap_or)] fn _readline(&self, py: Python) -> Bytes { let inner = self.inner.clone(); - py.allow_threads(|| { + py_allow_threads!(py, { self.rt.inner.block_on(async move { WSGIBody::fill_buffer(inner, self.buffer.clone(), WSGIBodyBuffering::Line).await; }); @@ -111,7 +112,7 @@ impl WSGIBody { match size { None => { let inner = self.inner.clone(); - let data = py.allow_threads(|| { + let data = py_allow_threads!(py, { self.rt.inner.block_on(async move { let mut inner = inner.lock().await; BodyExt::collect(&mut *inner) @@ -125,7 +126,7 @@ impl WSGIBody { 0 => BytesToPy(Bytes::new()), size => { let inner = self.inner.clone(); - py.allow_threads(|| { + py_allow_threads!(py, { self.rt.inner.block_on(async move { WSGIBody::fill_buffer(inner, self.buffer.clone(), WSGIBodyBuffering::Size(size)).await; }); @@ -149,7 +150,7 @@ impl WSGIBody { #[pyo3(signature = (_hint=None))] fn readlines<'p>(&self, py: Python<'p>, _hint: Option) -> PyResult> { let inner = self.inner.clone(); - let data = py.allow_threads(|| { + let data = py_allow_threads!(py, { self.rt.inner.block_on(async move { let mut inner = inner.lock().await; BodyExt::collect(&mut *inner) diff --git a/src/wsgi/utils.rs b/src/wsgi/utils.rs new file mode 100644 index 0000000..603edb9 --- /dev/null +++ b/src/wsgi/utils.rs @@ -0,0 +1,15 @@ +#[cfg(not(Py_GIL_DISABLED))] +macro_rules! py_allow_threads { + ($py:expr, $func:tt) => { + $py.allow_threads(|| $func) + }; +} + +#[cfg(Py_GIL_DISABLED)] +macro_rules! py_allow_threads { + ($py:expr, $func:tt) => { + $func + }; +} + +pub(super) use py_allow_threads; diff --git a/tests/conftest.py b/tests/conftest.py index 537599c..0c0a970 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ async def _server(interface, port, threading_mode, tls=False): kwargs = { 'interface': interface, 'port': port, + 'blocking_threads': 1, 'threading_mode': threading_mode, } if tls: diff --git a/tests/test_ws.py b/tests/test_ws.py index 2006cae..080a2e5 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -25,11 +25,11 @@ async def test_messages(server, threading_mode): @pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_reject(server, threading_mode): async with server(threading_mode) as port: - with pytest.raises(websockets.InvalidStatusCode) as exc: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: async with websockets.connect(f'ws://localhost:{port}/ws_reject'): pass - assert exc.value.status_code == 403 + assert exc.value.response.status_code == 403 @pytest.mark.asyncio