Align codestyle (#124)

* Add lint and format tools and config

* Format Python code

* Format Rust code

* Add lint CI workflow
This commit is contained in:
Giovanni Barillari 2023-09-25 18:00:09 +02:00 committed by GitHub
parent e28eb81db6
commit c95fd0cba7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
53 changed files with 1388 additions and 1991 deletions

32
.github/workflows/lint.yml vendored Normal file
View file

@ -0,0 +1,32 @@
name: lint
on:
pull_request:
types: [opened, synchronize]
branches:
- master
env:
MATURIN_VERSION: 1.2.3
PYTHON_VERSION: 3.11
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ env.PYTHON_VERSION }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install
run: |
python -m venv .venv
source .venv/bin/activate
pip install maturin==${{ env.MATURIN_VERSION }}
maturin develop --extras=lint
- name: Lint
run: |
source .venv/bin/activate
make lint

26
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,26 @@
fail_fast: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-added-large-files
- repo: local
hooks:
- id: lint-python
name: Lint Python
entry: make lint-python
types: [python]
language: system
pass_filenames: false
- id: lint-rust
name: Lint Rust
entry: make lint-rust
types: [rust]
language: system
pass_filenames: false

1
.rustfmt.toml Normal file
View file

@ -0,0 +1 @@
max_width = 120

67
Makefile Normal file
View file

@ -0,0 +1,67 @@
.DEFAULT_GOAL := all
black = black granian tests
ruff = ruff granian tests
.PHONY: build-dev
build-dev:
@rm -f granian/*.so
maturin develop --extras test
.PHONY: format
format:
$(black)
$(ruff) --fix --exit-zero
cargo fmt
.PHONY: lint-python
lint-python:
$(ruff)
$(black) --check --diff
.PHONY: lint-rust
lint-rust:
cargo fmt --version
cargo fmt --all -- --check
cargo clippy --version
cargo clippy --tests -- \
-D warnings \
-W clippy::pedantic \
-W clippy::dbg_macro \
-W clippy::print_stdout \
-A clippy::cast-possible-truncation \
-A clippy::cast-possible-wrap \
-A clippy::cast-precision-loss \
-A clippy::cast-sign-loss \
-A clippy::declare-interior-mutable-const \
-A clippy::float-cmp \
-A clippy::fn-params-excessive-bools \
-A clippy::if-not-else \
-A clippy::inline-always \
-A clippy::manual-let-else \
-A clippy::match-bool \
-A clippy::match-same-arms \
-A clippy::missing-errors-doc \
-A clippy::missing-panics-doc \
-A clippy::module-name-repetitions \
-A clippy::must-use-candidate \
-A clippy::needless-pass-by-value \
-A clippy::similar-names \
-A clippy::single-match-else \
-A clippy::struct-excessive-bools \
-A clippy::too-many-arguments \
-A clippy::too-many-lines \
-A clippy::type-complexity \
-A clippy::unnecessary-wraps \
-A clippy::unused-self \
-A clippy::used-underscore-binding \
-A clippy::wrong-self-convention
.PHONY: lint
lint: lint-python lint-rust
.PHONY: test
test:
pytest -v test
.PHONY: all
all: format build-dev lint test

View file

@ -1 +1 @@
from .server import Granian
from .server import Granian # noqa

View file

@ -1,3 +1,4 @@
from granian.cli import cli
cli()

View file

@ -1 +1 @@
__version__ = "0.6.0"
__version__ = '0.6.0'

View file

@ -1,12 +1,10 @@
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Optional, Tuple
from ._types import WebsocketMessage
class ASGIScope:
def as_dict(self, root_path: str) -> Dict[str, Any]: ...
class RSGIHeaders:
def __contains__(self, key: str) -> bool: ...
def keys(self) -> List[str]: ...
@ -14,7 +12,6 @@ class RSGIHeaders:
def items(self) -> List[Tuple[str]]: ...
def get(self, key: str, default: Any = None) -> Any: ...
class RSGIScope:
proto: str
http_version: str
@ -29,12 +26,10 @@ class RSGIScope:
@property
def headers(self) -> RSGIHeaders: ...
class RSGIHTTPStreamTransport:
async def send_bytes(self, data: bytes): ...
async def send_str(self, data: str): ...
class RSGIHTTPProtocol:
async def __call__(self) -> bytes: ...
def response_empty(self, status: int, headers: List[Tuple[str, str]]): ...
@ -43,25 +38,17 @@ class RSGIHTTPProtocol:
def response_file(self, status: int, headers: List[Tuple[str, str]], file: str): ...
def response_stream(self, status: int, headers: List[Tuple[str, str]]) -> RSGIHTTPStreamTransport: ...
class RSGIWebsocketTransport:
async def receive(self) -> WebsocketMessage: ...
async def send_bytes(self, data: bytes): ...
async def send_str(self, data: str): ...
class RSGIWebsocketProtocol:
async def accept(self) -> RSGIWebsocketTransport: ...
def close(self, status: Optional[int]) -> Tuple[int, bool]: ...
class RSGIProtocolError(RuntimeError):
...
class RSGIProtocolClosed(RuntimeError):
...
class RSGIProtocolError(RuntimeError): ...
class RSGIProtocolClosed(RuntimeError): ...
class WSGIScope:
def to_environ(self, environ: Dict[str, Any]) -> Dict[str, Any]: ...

View file

@ -2,22 +2,21 @@ import os
import re
import sys
import traceback
from types import ModuleType
from typing import Callable, List, Optional
def get_import_components(path: str) -> List[Optional[str]]:
return (re.split(r":(?![\\/])", path, 1) + [None])[:2]
return (re.split(r':(?![\\/])', path, 1) + [None])[:2]
def prepare_import(path: str) -> str:
path = os.path.realpath(path)
fname, ext = os.path.splitext(path)
if ext == ".py":
if ext == '.py':
path = fname
if os.path.basename(path) == "__init__":
if os.path.basename(path) == '__init__':
path = os.path.dirname(path)
module_name = []
@ -27,26 +26,22 @@ def prepare_import(path: str) -> str:
path, name = os.path.split(path)
module_name.append(name)
if not os.path.exists(os.path.join(path, "__init__.py")):
if not os.path.exists(os.path.join(path, '__init__.py')):
break
if sys.path[0] != path:
sys.path.insert(0, path)
return ".".join(module_name[::-1])
return '.'.join(module_name[::-1])
def load_module(
module_name: str,
raise_on_failure: bool = True
) -> Optional[ModuleType]:
def load_module(module_name: str, raise_on_failure: bool = True) -> Optional[ModuleType]:
try:
__import__(module_name)
except ImportError:
if sys.exc_info()[-1].tb_next:
raise RuntimeError(
f"While importing '{module_name}', an ImportError was raised:"
f"\n\n{traceback.format_exc()}"
f"While importing '{module_name}', an ImportError was raised:" f"\n\n{traceback.format_exc()}"
)
elif raise_on_failure:
raise RuntimeError(f"Could not import '{module_name}'.")
@ -58,9 +53,9 @@ def load_module(
def load_target(target: str) -> Callable[..., None]:
path, name = get_import_components(target)
path = prepare_import(path) if path else None
name = name or "app"
name = name or 'app'
module = load_module(path)
rv = module
for element in name.split("."):
for element in name.split('.'):
rv = getattr(rv, element)
return rv

View file

@ -2,12 +2,11 @@ import asyncio
import os
import signal
import sys
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
class Registry:
__slots__ = ["_data"]
__slots__ = ['_data']
def __init__(self):
self._data: Dict[str, Callable[..., Any]] = {}
@ -22,6 +21,7 @@ class Registry:
def wrap(builder: Callable[..., Any]) -> Callable[..., Any]:
self._data[key] = builder
return builder
return wrap
def get(self, key: str) -> Callable[..., Any]:
@ -31,18 +31,13 @@ class Registry:
raise RuntimeError(f"'{key}' implementation not available.")
class BuilderRegistry(Registry):
__slots__ = []
def __init__(self):
self._data: Dict[str, Tuple[Callable[..., Any], List[str]]] = {}
def register(
self,
key: str,
packages: Optional[List[str]] = None
) -> Callable[[], Callable[..., Any]]:
def register(self, key: str, packages: Optional[List[str]] = None) -> Callable[[], Callable[..., Any]]:
packages = packages or []
def wrap(builder: Callable[..., Any]) -> Callable[..., Any]:
@ -56,6 +51,7 @@ class BuilderRegistry(Registry):
if implemented:
self._data[key] = (builder, loaded_packages)
return builder
return wrap
def get(self, key: str) -> Callable[..., Any]:

View file

@ -1,5 +1,4 @@
import asyncio
from functools import wraps
from ._granian import ASGIScope as Scope
@ -22,12 +21,7 @@ class LifespanProtocol:
async def handle(self):
try:
await self.callable(
{
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.3"}
},
self.receive,
self.send
{'type': 'lifespan', 'asgi': {'version': '3.0', 'spec_version': '2.3'}}, self.receive, self.send
)
except Exception:
self.errored = True
@ -43,7 +37,7 @@ class LifespanProtocol:
loop = asyncio.get_event_loop()
_handler_task = loop.create_task(self.handle())
await self.event_queue.put({"type": "lifespan.startup"})
await self.event_queue.put({'type': 'lifespan.startup'})
await self.event_startup.wait()
if self.failure_startup or (self.errored and not self.unsupported):
@ -53,7 +47,7 @@ class LifespanProtocol:
if self.errored:
return
await self.event_queue.put({"type": "lifespan.shutdown"})
await self.event_queue.put({'type': 'lifespan.shutdown'})
await self.event_shutdown.wait()
if self.failure_shutdown or (self.errored and not self.unsupported):
@ -89,14 +83,14 @@ class LifespanProtocol:
# self.logger.error(message["message"])
_event_handlers = {
"lifespan.startup.complete": _handle_startup_complete,
"lifespan.startup.failed": _handle_startup_failed,
"lifespan.shutdown.complete": _handle_shutdown_complete,
"lifespan.shutdown.failed": _handle_shutdown_failed
'lifespan.startup.complete': _handle_startup_complete,
'lifespan.startup.failed': _handle_startup_failed,
'lifespan.shutdown.complete': _handle_shutdown_complete,
'lifespan.shutdown.failed': _handle_shutdown_failed,
}
async def send(self, message):
handler = self._event_handlers[message["type"]]
handler = self._event_handlers[message['type']]
handler(self, message)
@ -108,6 +102,7 @@ def _send_wrapper(proto):
@wraps(proto)
def send(data):
return proto(_noop_coro, data)
return send
@ -116,9 +111,6 @@ def _callback_wrapper(callback, scope_opts):
@wraps(callback)
def wrapper(scope: Scope, proto):
return callback(
scope.as_dict(root_url_path),
proto.receive,
_send_wrapper(proto.send)
)
return callback(scope.as_dict(root_url_path), proto.receive, _send_wrapper(proto.send))
return wrapper

View file

@ -1,107 +1,55 @@
import json
from pathlib import Path
from typing import Optional
import typer
from .__version__ import __version__
from .constants import Interfaces, HTTPModes, Loops, ThreadModes
from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .log import LogLevels
from .server import Granian
cli = typer.Typer(name="granian", context_settings={"ignore_unknown_options": True})
cli = typer.Typer(name='granian', context_settings={'ignore_unknown_options': True})
def version_callback(value: bool):
if value:
typer.echo(f"{cli.info.name} {__version__}")
typer.echo(f'{cli.info.name} {__version__}')
raise typer.Exit()
@cli.command()
def main(
app: str = typer.Argument(..., help="Application target to serve."),
host: str = typer.Option("127.0.0.1", help="Host address to bind to."),
port: int = typer.Option(8000, help="Port to bind to."),
interface: Interfaces = typer.Option(
Interfaces.RSGI.value,
help="Application interface type."
),
http: HTTPModes = typer.Option(
HTTPModes.auto.value,
help="HTTP version."
),
websockets: bool = typer.Option(
True,
"--ws/--no-ws",
help="Enable websockets handling",
show_default="enabled"
),
workers: int = typer.Option(1, min=1, help="Number of worker processes."),
threads: int = typer.Option(1, min=1, help="Number of threads."),
threading_mode: ThreadModes = typer.Option(
ThreadModes.workers.value,
help="Threading mode to use."
),
loop: Loops = typer.Option(Loops.auto.value, help="Event loop implementation"),
loop_opt: bool = typer.Option(
False,
"--opt/--no-opt",
help="Enable loop optimizations",
show_default="disabled"
),
backlog: int = typer.Option(
1024,
min=128,
help="Maximum number of connections to hold in backlog."
),
log_level: LogLevels = typer.Option(
LogLevels.info.value,
help="Log level",
case_sensitive=False
),
app: str = typer.Argument(..., help='Application target to serve.'),
host: str = typer.Option('127.0.0.1', help='Host address to bind to.'),
port: int = typer.Option(8000, help='Port to bind to.'),
interface: Interfaces = typer.Option(Interfaces.RSGI.value, help='Application interface type.'),
http: HTTPModes = typer.Option(HTTPModes.auto.value, help='HTTP version.'),
websockets: bool = typer.Option(True, '--ws/--no-ws', help='Enable websockets handling', show_default='enabled'),
workers: int = typer.Option(1, min=1, help='Number of worker processes.'),
threads: int = typer.Option(1, min=1, help='Number of threads.'),
threading_mode: ThreadModes = typer.Option(ThreadModes.workers.value, help='Threading mode to use.'),
loop: Loops = typer.Option(Loops.auto.value, help='Event loop implementation'),
loop_opt: bool = typer.Option(False, '--opt/--no-opt', help='Enable loop optimizations', show_default='disabled'),
backlog: int = typer.Option(1024, min=128, help='Maximum number of connections to hold in backlog.'),
log_level: LogLevels = typer.Option(LogLevels.info.value, help='Log level', case_sensitive=False),
log_config: Optional[Path] = typer.Option(
None,
help="Logging configuration file (json)",
exists=True,
file_okay=True,
dir_okay=False,
readable=True
None, help='Logging configuration file (json)', exists=True, file_okay=True, dir_okay=False, readable=True
),
ssl_keyfile: Optional[Path] = typer.Option(
None,
help="SSL key file",
exists=True,
file_okay=True,
dir_okay=False,
readable=True
None, help='SSL key file', exists=True, file_okay=True, dir_okay=False, readable=True
),
ssl_certificate: Optional[Path] = typer.Option(
None,
help="SSL certificate file",
exists=True,
file_okay=True,
dir_okay=False,
readable=True
),
url_path_prefix: Optional[str] = typer.Option(
None,
help="URL path prefix the app is mounted on"
None, help='SSL certificate file', exists=True, file_okay=True, dir_okay=False, readable=True
),
url_path_prefix: Optional[str] = typer.Option(None, help='URL path prefix the app is mounted on'),
reload: bool = typer.Option(
False,
"--reload/--no-reload",
help="Enable auto reload on application's files changes"
False, '--reload/--no-reload', help="Enable auto reload on application's files changes"
),
_: Optional[bool] = typer.Option(
None,
"--version",
callback=version_callback,
is_eager=True,
help="Shows the version and exit."
)
None, '--version', callback=version_callback, is_eager=True, help='Shows the version and exit.'
),
):
log_dictconfig = None
if log_config:
@ -109,7 +57,7 @@ def main(
try:
log_dictconfig = json.loads(log_config_file.read())
except Exception:
print("Unable to parse provided logging config.")
print('Unable to parse provided logging config.')
raise typer.Exit(1)
Granian(
@ -131,5 +79,5 @@ def main(
ssl_cert=ssl_certificate,
ssl_key=ssl_keyfile,
url_path_prefix=url_path_prefix,
reload=reload
reload=reload,
).serve()

View file

@ -2,23 +2,23 @@ from enum import Enum
class Interfaces(str, Enum):
ASGI = "asgi"
RSGI = "rsgi"
WSGI = "wsgi"
ASGI = 'asgi'
RSGI = 'rsgi'
WSGI = 'wsgi'
class HTTPModes(str, Enum):
auto = "auto"
http1 = "1"
http2 = "2"
auto = 'auto'
http1 = '1'
http2 = '2'
class ThreadModes(str, Enum):
runtime = "runtime"
workers = "workers"
runtime = 'runtime'
workers = 'workers'
class Loops(str, Enum):
auto = "auto"
asyncio = "asyncio"
uvloop = "uvloop"
auto = 'auto'
asyncio = 'asyncio'
uvloop = 'uvloop'

View file

@ -1,18 +1,17 @@
import copy
import logging
import logging.config
from enum import Enum
from typing import Any, Dict, Optional
class LogLevels(str, Enum):
critical = "critical"
error = "error"
warning = "warning"
warn = "warn"
info = "info"
debug = "debug"
critical = 'critical'
error = 'error'
warning = 'warning'
warn = 'warn'
info = 'info'
debug = 'debug'
log_levels_map = {
@ -21,27 +20,21 @@ log_levels_map = {
LogLevels.warning: logging.WARNING,
LogLevels.warn: logging.WARN,
LogLevels.info: logging.INFO,
LogLevels.debug: logging.DEBUG
LogLevels.debug: logging.DEBUG,
}
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"root": {"level": "INFO", "handlers": ["console"]},
"formatters": {
"generic": {
"()": "logging.Formatter",
"fmt": "[%(levelname)s] %(message)s",
"datefmt": "[%Y-%m-%d %H:%M:%S %z]"
'version': 1,
'disable_existing_loggers': False,
'root': {'level': 'INFO', 'handlers': ['console']},
'formatters': {
'generic': {
'()': 'logging.Formatter',
'fmt': '[%(levelname)s] %(message)s',
'datefmt': '[%Y-%m-%d %H:%M:%S %z]',
}
},
"handlers": {
"console": {
"formatter": "generic",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
}
}
'handlers': {'console': {'formatter': 'generic', 'class': 'logging.StreamHandler', 'stream': 'ext://sys.stdout'}},
}
logger = logging.getLogger()
@ -51,5 +44,5 @@ def configure_logging(level: LogLevels, config: Optional[Dict[str, Any]] = None)
log_config = copy.deepcopy(LOGGING_CONFIG)
if config:
log_config.update(config)
log_config["root"]["level"] = log_levels_map[level]
log_config['root']['level'] = log_levels_map[level]
logging.config.dictConfig(log_config)

View file

@ -2,7 +2,5 @@ import copyreg
from ._granian import ListenerHolder as SocketHolder
copyreg.pickle(
SocketHolder,
lambda v: (SocketHolder, v.__getstate__())
)
copyreg.pickle(SocketHolder, lambda v: (SocketHolder, v.__getstate__()))

View file

@ -2,12 +2,12 @@ from enum import Enum
from typing import Union
from ._granian import (
RSGIHTTPProtocol as HTTPProtocol,
RSGIWebsocketProtocol as WebsocketProtocol,
RSGIHeaders as Headers,
RSGIScope as Scope,
RSGIProtocolError as ProtocolError,
RSGIProtocolClosed as ProtocolClosed
RSGIHeaders as Headers, # noqa
RSGIHTTPProtocol as HTTPProtocol, # noqa
RSGIProtocolClosed as ProtocolClosed, # noqa
RSGIProtocolError as ProtocolError, # noqa
RSGIScope as Scope, # noqa
RSGIWebsocketProtocol as WebsocketProtocol, # noqa
)

View file

@ -5,7 +5,6 @@ import socket
import ssl
import sys
import threading
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
@ -16,11 +15,12 @@ from ._futures import future_watcher_wrapper
from ._granian import ASGIWorker, RSGIWorker, WSGIWorker
from ._internal import load_target
from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
from .constants import Interfaces, HTTPModes, Loops, ThreadModes
from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .log import LogLevels, configure_logging, logger
from .net import SocketHolder
from .wsgi import _callback_wrapper as _wsgi_call_wrap
multiprocessing.allow_connection_pickling()
@ -30,7 +30,7 @@ class Granian:
def __init__(
self,
target: str,
address: str = "127.0.0.1",
address: str = '127.0.0.1',
port: int = 8000,
interface: Interfaces = Interfaces.RSGI,
workers: int = 1,
@ -48,7 +48,7 @@ class Granian:
ssl_cert: Optional[Path] = None,
ssl_key: Optional[Path] = None,
url_path_prefix: Optional[str] = None,
reload: bool = False
reload: bool = False,
):
self.target = target
self.bind_addr = address
@ -75,11 +75,7 @@ class Granian:
self.procs: List[multiprocessing.Process] = []
self.exit_event = threading.Event()
def build_ssl_context(
self,
cert: Optional[Path],
key: Optional[Path]
):
def build_ssl_context(self, cert: Optional[Path], key: Optional[Path]):
if not (cert and key):
self.ssl_ctx = (False, None, None)
return
@ -108,7 +104,7 @@ class Granian:
log_level,
log_config,
ssl_ctx,
scope_opts
scope_opts,
):
from granian._loops import loops, set_loop_signals
@ -129,29 +125,12 @@ class Granian:
wcallback = future_watcher_wrapper(wcallback)
worker = ASGIWorker(
worker_id,
sfd,
threads,
pthreads,
http_mode,
http1_buffer_size,
websockets,
loop_opt,
*ssl_ctx
)
serve = getattr(worker, {
ThreadModes.runtime: "serve_rth",
ThreadModes.workers: "serve_wth"
}[threading_mode])
serve(
wcallback,
loop,
contextvars.copy_context(),
shutdown_event.wait()
worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
serve(wcallback, loop, contextvars.copy_context(), shutdown_event.wait())
loop.run_until_complete(lifespan_handler.shutdown())
@staticmethod
def _spawn_rsgi_worker(
worker_id,
@ -168,7 +147,7 @@ class Granian:
log_level,
log_config,
ssl_ctx,
scope_opts
scope_opts,
):
from granian._loops import loops, set_loop_signals
@ -176,38 +155,23 @@ class Granian:
loop = loops.get(loop_impl)
sfd = socket.fileno()
target = callback_loader()
callback = (
getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else
target
)
callback = getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else target
callback_init = (
getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else
lambda *args, **kwargs: None
getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else lambda *args, **kwargs: None
)
shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT])
callback_init(loop)
worker = RSGIWorker(
worker_id,
sfd,
threads,
pthreads,
http_mode,
http1_buffer_size,
websockets,
loop_opt,
*ssl_ctx
worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx
)
serve = getattr(worker, {
ThreadModes.runtime: "serve_rth",
ThreadModes.workers: "serve_wth"
}[threading_mode])
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
serve(
future_watcher_wrapper(callback) if not loop_opt else callback,
loop,
contextvars.copy_context(),
shutdown_event.wait()
shutdown_event.wait(),
)
@staticmethod
@ -226,7 +190,7 @@ class Granian:
log_level,
log_config,
ssl_ctx,
scope_opts
scope_opts,
):
from granian._loops import loops, set_loop_signals
@ -237,46 +201,20 @@ class Granian:
shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT])
worker = WSGIWorker(
worker_id,
sfd,
threads,
pthreads,
http_mode,
http1_buffer_size,
*ssl_ctx
)
serve = getattr(worker, {
ThreadModes.runtime: "serve_rth",
ThreadModes.workers: "serve_wth"
}[threading_mode])
serve(
_wsgi_call_wrap(callback, scope_opts),
loop,
contextvars.copy_context(),
shutdown_event.wait()
)
worker = WSGIWorker(worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, *ssl_ctx)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
serve(_wsgi_call_wrap(callback, scope_opts), loop, contextvars.copy_context(), shutdown_event.wait())
def _init_shared_socket(self):
self._shd = SocketHolder.from_address(
self.bind_addr,
self.bind_port,
self.backlog
)
self._shd = SocketHolder.from_address(self.bind_addr, self.bind_port, self.backlog)
self._sfd = self._shd.get_fd()
def signal_handler(self, *args, **kwargs):
self.exit_event.set()
def _spawn_proc(
self,
id,
target,
callback_loader,
socket_loader
) -> multiprocessing.Process:
def _spawn_proc(self, id, target, callback_loader, socket_loader) -> multiprocessing.Process:
return multiprocessing.get_context().Process(
name="granian-worker",
name='granian-worker',
target=target,
args=(
id,
@ -293,10 +231,8 @@ class Granian:
self.log_level,
self.log_config,
self.ssl_ctx,
{
"url_path_prefix": self.url_path_prefix
}
)
{'url_path_prefix': self.url_path_prefix},
),
)
def _spawn_workers(self, sock, spawn_target, target_loader):
@ -305,14 +241,11 @@ class Granian:
for idx in range(self.workers):
proc = self._spawn_proc(
id=idx + 1,
target=spawn_target,
callback_loader=target_loader,
socket_loader=socket_loader
id=idx + 1, target=spawn_target, callback_loader=target_loader, socket_loader=socket_loader
)
proc.start()
self.procs.append(proc)
logger.info(f"Spawning worker-{idx + 1} with pid: {proc.pid}")
logger.info(f'Spawning worker-{idx + 1} with pid: {proc.pid}')
def _stop_workers(self):
for proc in self.procs:
@ -321,7 +254,7 @@ class Granian:
proc.join()
def startup(self, spawn_target, target_loader):
logger.info("Starting granian")
logger.info('Starting granian')
for sig in self.SIGNALS:
signal.signal(sig, self.signal_handler)
@ -329,13 +262,13 @@ class Granian:
self._init_shared_socket()
sock = socket.socket(fileno=self._sfd)
sock.set_inheritable(True)
logger.info(f"Listening at: {self.bind_addr}:{self.bind_port}")
logger.info(f'Listening at: {self.bind_addr}:{self.bind_port}')
self._spawn_workers(sock, spawn_target, target_loader)
return sock
def shutdown(self):
logger.info("Shutting down granian")
logger.info('Shutting down granian')
self._stop_workers()
def _serve(self, spawn_target, target_loader):
@ -360,12 +293,12 @@ class Granian:
self,
spawn_target: Optional[Callable[..., None]] = None,
target_loader: Optional[Callable[..., Callable[..., Any]]] = None,
wrap_loader: bool = True
wrap_loader: bool = True,
):
default_spawners = {
Interfaces.ASGI: self._spawn_asgi_worker,
Interfaces.RSGI: self._spawn_rsgi_worker,
Interfaces.WSGI: self._spawn_wsgi_worker
Interfaces.WSGI: self._spawn_wsgi_worker,
}
if target_loader:
if wrap_loader:
@ -383,8 +316,5 @@ class Granian:
"Number of workers will now fallback to 1."
)
serve_method = (
self._serve_with_reloader if self.reload_on_changes else
self._serve
)
serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
serve_method(spawn_target, target_loader)

View file

@ -1,6 +1,5 @@
import os
import sys
from functools import wraps
from typing import Any, List, Tuple
@ -14,30 +13,27 @@ class Response:
self.status = 200
self.headers = []
def __call__(
self,
status: str,
headers: List[Tuple[str, str]],
exc_info: Any = None
):
def __call__(self, status: str, headers: List[Tuple[str, str]], exc_info: Any = None):
self.status = int(status.split(' ', 1)[0])
self.headers = headers
def _callback_wrapper(callback, scope_opts):
basic_env = dict(os.environ)
basic_env.update({
'GATEWAY_INTERFACE': 'CGI/1.1',
'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '',
'SERVER_SOFTWARE': 'Granian',
'wsgi.errors': sys.stderr,
'wsgi.input_terminated': True,
'wsgi.input': None,
'wsgi.multiprocess': False,
'wsgi.multithread': False,
'wsgi.run_once': False,
'wsgi.version': (1, 0)
})
basic_env.update(
{
'GATEWAY_INTERFACE': 'CGI/1.1',
'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '',
'SERVER_SOFTWARE': 'Granian',
'wsgi.errors': sys.stderr,
'wsgi.input_terminated': True,
'wsgi.input': None,
'wsgi.multiprocess': False,
'wsgi.multithread': False,
'wsgi.run_once': False,
'wsgi.version': (1, 0),
}
)
@wraps(callback)
def wrapper(scope: Scope) -> Tuple[int, List[Tuple[str, str]], bytes]:
@ -46,7 +42,7 @@ def _callback_wrapper(callback, scope_opts):
if isinstance(rv, list):
resp_type = 0
rv = b"".join(rv)
rv = b''.join(rv)
else:
resp_type = 1
rv = iter(rv)

View file

@ -1,65 +1,111 @@
[project]
name = "granian"
name = 'granian'
authors = [
{name = "Giovanni Barillari", email = "g@baro.dev"}
{name = 'Giovanni Barillari', email = 'g@baro.dev'}
]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: MacOS",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Programming Language :: Python",
"Programming Language :: Rust",
"Topic :: Internet :: WWW/HTTP"
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License',
'Operating System :: MacOS',
'Operating System :: Microsoft :: Windows',
'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
'Programming Language :: Python',
'Programming Language :: Rust',
'Topic :: Internet :: WWW/HTTP'
]
dynamic = [
"description",
"keywords",
"license",
"readme",
"version"
'description',
'keywords',
'license',
'readme',
'version'
]
requires-python = ">=3.8"
requires-python = '>=3.8'
dependencies = [
"watchfiles~=0.18",
"typer~=0.4",
"uvloop~=0.17.0; sys_platform != 'win32' and platform_python_implementation == 'CPython'"
'watchfiles~=0.18',
'typer~=0.4',
'uvloop~=0.17.0; sys_platform != "win32" and platform_python_implementation == "CPython"'
]
[project.optional-dependencies]
lint = [
'black~=23.7.0',
'ruff~=0.0.287'
]
test = [
"httpx~=0.23.0",
"pytest~=7.1.2",
"pytest-asyncio~=0.18.3",
"websockets~=10.3"
'httpx~=0.23.0',
'pytest~=7.1.2',
'pytest-asyncio~=0.18.3',
'websockets~=10.3'
]
[project.urls]
Homepage = "https://github.com/emmett-framework/granian"
Funding = "https://github.com/sponsors/gi0baro"
Source = "https://github.com/emmett-framework/granian"
Homepage = 'https://github.com/emmett-framework/granian'
Funding = 'https://github.com/sponsors/gi0baro'
Source = 'https://github.com/emmett-framework/granian'
[project.scripts]
granian = "granian:cli.cli"
granian = 'granian:cli.cli'
[build-system]
requires = ["maturin>=1.1.0,<1.3.0"]
build-backend = "maturin"
requires = ['maturin>=1.1.0,<1.3.0']
build-backend = 'maturin'
[tool.maturin]
module-name = "granian._granian"
bindings = "pyo3"
module-name = 'granian._granian'
bindings = 'pyo3'
[tool.ruff]
line-length = 120
extend-select = [
# E and F are enabled by default
'B', # flake8-bugbear
'C4', # flake8-comprehensions
'C90', # mccabe
'I', # isort
'N', # pep8-naming
'Q', # flake8-quotes
'RUF100', # ruff (unused noqa)
'S', # flake8-bandit
'W' # pycodestyle
]
extend-ignore = [
'B008', # function calls in args defaults are fine
'B009', # getattr with constants is fine
'B034', # re.split won't confuse us
'B904', # rising without from is fine
'E501', # leave line length to black
'N818', # leave to us exceptions naming
'S101' # assert is fine
]
flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' }
mccabe = { max-complexity = 13 }
[tool.ruff.isort]
combine-as-imports = true
lines-after-imports = 2
known-first-party = ['granian', 'tests']
[tool.ruff.per-file-ignores]
'granian/_granian.pyi' = ['I001']
'tests/**' = ['B018', 'S110', 'S501']
[tool.black]
color = true
line-length = 120
target-version = ['py38', 'py39', 'py310', 'py311']
skip-string-normalization = true # leave this to ruff
skip-magic-trailing-comma = true # leave this to ruff
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_mode = 'auto'

View file

@ -3,45 +3,33 @@ use pyo3::prelude::*;
use pyo3_asyncio::TaskLocals;
use tokio::sync::oneshot;
use crate::{
callbacks::{
CallbackWrapper,
callback_impl_run,
callback_impl_run_pytask,
callback_impl_loop_run,
callback_impl_loop_pytask,
callback_impl_loop_step,
callback_impl_loop_wake,
callback_impl_loop_err
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData}
};
use super::{
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol},
types::ASGIScope as Scope
types::ASGIScope as Scope,
};
use crate::{
callbacks::{
callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step,
callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper,
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData},
};
#[pyclass]
pub(crate) struct CallbackRunnerHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals,
cb: PyObject
cb: PyObject,
}
impl CallbackRunnerHTTP {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
let pyproto = Py::new(py, proto).unwrap();
Self {
proto: pyproto.clone(),
context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
}
}
@ -64,14 +52,14 @@ macro_rules! callback_impl_done_http {
let _ = tx.send(res);
}
}
}
};
}
macro_rules! callback_impl_done_err {
($self:expr, $py:expr) => {
log::warn!("Application callable raised an exception");
$self.done($py)
}
};
}
#[pyclass]
@ -79,22 +67,17 @@ pub(crate) struct CallbackTaskHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
cb: PyObject,
}
impl CallbackTaskHTTP {
pub fn new(
py: Python,
cb: PyObject,
proto: Py<HTTPProtocol>,
context: TaskLocals
) -> PyResult<Self> {
pub fn new(py: Python, cb: PyObject, proto: Py<HTTPProtocol>, context: TaskLocals) -> PyResult<Self> {
let pyctx = context.context(py);
Ok(Self {
proto,
context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb
cb,
})
}
@ -128,21 +111,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
context: TaskLocals,
cb: PyObject,
#[pyo3(get)]
scope: PyObject
scope: PyObject,
}
impl CallbackWrappedRunnerHTTP {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
context: cb.context,
cb: cb.callback,
scope: scope.into_py(py)
scope: scope.into_py(py),
}
}
@ -168,21 +146,16 @@ impl CallbackWrappedRunnerHTTP {
pub(crate) struct CallbackRunnerWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals,
cb: PyObject
cb: PyObject,
}
impl CallbackRunnerWebsocket {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
let pyproto = Py::new(py, proto).unwrap();
Self {
proto: pyproto.clone(),
context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
}
}
@ -203,7 +176,7 @@ macro_rules! callback_impl_done_ws {
let _ = tx.send(res);
}
}
}
};
}
#[pyclass]
@ -211,22 +184,17 @@ pub(crate) struct CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
cb: PyObject,
}
impl CallbackTaskWebsocket {
pub fn new(
py: Python,
cb: PyObject,
proto: Py<WebsocketProtocol>,
context: TaskLocals
) -> PyResult<Self> {
pub fn new(py: Python, cb: PyObject, proto: Py<WebsocketProtocol>, context: TaskLocals) -> PyResult<Self> {
let pyctx = context.context(py);
Ok(Self {
proto,
context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb
cb,
})
}
@ -260,21 +228,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
context: TaskLocals,
cb: PyObject,
#[pyo3(get)]
scope: PyObject
scope: PyObject,
}
impl CallbackWrappedRunnerWebsocket {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
context: cb.context,
cb: cb.callback,
scope: scope.into_py(py)
scope: scope.into_py(py),
}
}
@ -325,7 +288,7 @@ macro_rules! call_impl_rtb_http {
cb: CallbackWrapper,
rt: RuntimeRef,
req: Request<Body>,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<Response<Body>> {
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, req, tx);
@ -345,7 +308,7 @@ macro_rules! call_impl_rtt_http {
cb: CallbackWrapper,
rt: RuntimeRef,
req: Request<Body>,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<Response<Body>> {
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, req, tx);
@ -368,7 +331,7 @@ macro_rules! call_impl_rtb_ws {
rt: RuntimeRef,
ws: HyperWebsocket,
upgrade: UpgradeData,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<bool> {
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
@ -389,7 +352,7 @@ macro_rules! call_impl_rtt_ws {
rt: RuntimeRef,
ws: HyperWebsocket,
upgrade: UpgradeData,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<bool> {
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);

View file

@ -88,5 +88,5 @@ macro_rules! error_message {
}
pub(crate) use error_flow;
pub(crate) use error_transport;
pub(crate) use error_message;
pub(crate) use error_transport;

View file

@ -1,34 +1,22 @@
use hyper::{
Body,
Request,
Response,
StatusCode,
header::SERVER as HK_SERVER,
http::response::Builder as ResponseBuilder
header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode,
};
use std::net::SocketAddr;
use tokio::sync::mpsc;
use crate::{
callbacks::CallbackWrapper,
http::{HV_SERVER, response_500},
runtime::RuntimeRef,
ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade}
};
use super::{
callbacks::{
call_rtb_http,
call_rtb_http_pyw,
call_rtb_ws,
call_rtb_ws_pyw,
call_rtt_http,
call_rtt_http_pyw,
call_rtt_ws,
call_rtt_ws_pyw
call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws,
call_rtt_ws_pyw,
},
types::ASGIScope as Scope
types::ASGIScope as Scope,
};
use crate::{
callbacks::CallbackWrapper,
http::{response_500, HV_SERVER},
runtime::RuntimeRef,
ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData},
};
macro_rules! default_scope {
($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => {
@ -39,7 +27,7 @@ macro_rules! default_scope {
$req.method().as_ref(),
$server_addr,
$client_addr,
$req.headers()
$req.headers(),
)
};
}
@ -53,7 +41,7 @@ macro_rules! handle_http_response {
response_500()
}
}
}
};
}
macro_rules! handle_request {
@ -64,7 +52,7 @@ macro_rules! handle_request {
server_addr: SocketAddr,
client_addr: SocketAddr,
req: Request<Body>,
scheme: &str
scheme: &str,
) -> Response<Body> {
let scope = default_scope!(server_addr, client_addr, &req, scheme);
handle_http_response!($handler, rt, callback, req, scope)
@ -80,7 +68,7 @@ macro_rules! handle_request_with_ws {
server_addr: SocketAddr,
client_addr: SocketAddr,
req: Request<Body>,
scheme: &str
scheme: &str,
) -> Response<Body> {
let mut scope = default_scope!(server_addr, client_addr, &req, scheme);
@ -95,24 +83,20 @@ macro_rules! handle_request_with_ws {
rt.inner.spawn(async move {
let tx_ref = restx.clone();
match $handler_ws(
callback,
rth,
ws,
UpgradeData::new(res, restx),
scope
).await {
match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await {
Ok(consumed) => {
if !consumed {
let _ = tx_ref.send(
ResponseBuilder::new()
.status(StatusCode::FORBIDDEN)
.header(HK_SERVER, HV_SERVER)
.body(Body::from(""))
.unwrap()
).await;
let _ = tx_ref
.send(
ResponseBuilder::new()
.status(StatusCode::FORBIDDEN)
.header(HK_SERVER, HV_SERVER)
.body(Body::from(""))
.unwrap(),
)
.await;
};
},
}
_ => {
log::error!("ASGI protocol failure");
let _ = tx_ref.send(response_500()).await;
@ -124,10 +108,10 @@ macro_rules! handle_request_with_ws {
Some(res) => {
resrx.close();
res
},
_ => response_500()
}
_ => response_500(),
}
},
}
Err(err) => {
return ResponseBuilder::new()
.status(StatusCode::BAD_REQUEST)

View file

@ -1,33 +1,34 @@
use bytes::Bytes;
use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}};
use futures::{
sink::SinkExt,
stream::{SplitSink, SplitStream, StreamExt},
};
use hyper::{
Request,
Response,
body::{Body, HttpBody, Sender as BodySender},
header::{HeaderName, HeaderValue, HeaderMap, SERVER as HK_SERVER}
header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
Request, Response,
};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use std::sync::Arc;
use tokio::sync::{oneshot, Mutex};
use tokio_tungstenite::WebSocketStream;
use tokio::sync::{Mutex, oneshot};
use tungstenite::Message;
use super::{
errors::{error_flow, error_message, error_transport, UnsupportedASGIMessage},
types::ASGIMessageType,
};
use crate::{
http::HV_SERVER,
runtime::{RuntimeRef, future_into_py_iter, future_into_py_futlike},
ws::{HyperWebsocket, UpgradeData}
runtime::{future_into_py_futlike, future_into_py_iter, RuntimeRef},
ws::{HyperWebsocket, UpgradeData},
};
use super::{
errors::{UnsupportedASGIMessage, error_flow, error_transport, error_message},
types::ASGIMessageType
};
const EMPTY_BYTES: Vec<u8> = Vec::new();
const EMPTY_STRING: String = String::new();
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct ASGIHTTPProtocol {
rt: RuntimeRef,
tx: Option<oneshot::Sender<Response<Body>>>,
@ -36,15 +37,11 @@ pub(crate) struct ASGIHTTPProtocol {
response_chunked: bool,
response_status: Option<i16>,
response_headers: Option<HeaderMap>,
body_tx: Option<Arc<Mutex<BodySender>>>
body_tx: Option<Arc<Mutex<BodySender>>>,
}
impl ASGIHTTPProtocol {
pub fn new(
rt: RuntimeRef,
request: Request<Body>,
tx: oneshot::Sender<Response<Body>>
) -> Self {
pub fn new(rt: RuntimeRef, request: Request<Body>, tx: oneshot::Sender<Response<Body>>) -> Self {
Self {
rt,
tx: Some(tx),
@ -53,7 +50,7 @@ impl ASGIHTTPProtocol {
response_chunked: false,
response_status: None,
response_headers: None,
body_tx: None
body_tx: None,
}
}
@ -71,7 +68,7 @@ impl ASGIHTTPProtocol {
fn send_body<'p>(&self, py: Python<'p>, tx: Arc<Mutex<BodySender>>, body: Vec<u8>) -> PyResult<&'p PyAny> {
future_into_py_futlike(self.rt.clone(), py, async move {
let mut tx = tx.lock().await;
match (&mut *tx).send_data(body.into()).await {
match (*tx).send_data(body.into()).await {
Ok(_) => Ok(()),
Err(err) => {
log::warn!("ASGI transport tx error: {:?}", err);
@ -94,18 +91,18 @@ impl ASGIHTTPProtocol {
let mut bodym = body_ref.lock().await;
let body = &mut *bodym;
let mut more_body = false;
let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| {
buf.map_or_else(|_| Bytes::new(), |buf| {
more_body = !body.is_end_stream();
buf
})
let chunk = body.data().await.map_or_else(Bytes::new, |buf| {
buf.map_or_else(
|_| Bytes::new(),
|buf| {
more_body = !body.is_end_stream();
buf
},
)
});
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(
pyo3::intern!(py, "type"),
pyo3::intern!(py, "http.request")
)?;
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?;
dict.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &chunk[..]))?;
dict.set_item(pyo3::intern!(py, "more_body"), more_body)?;
Ok(dict.to_object(py))
@ -115,16 +112,14 @@ impl ASGIHTTPProtocol {
fn send<'p>(&mut self, py: Python<'p>, asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> {
match adapt_message_type(data) {
Ok(ASGIMessageType::HTTPStart) => {
match self.response_started {
false => {
self.response_status = Some(adapt_status_code(data)?);
self.response_headers = Some(adapt_headers(data));
self.response_started = true;
asyncw.call0()
},
true => error_flow!()
Ok(ASGIMessageType::HTTPStart) => match self.response_started {
false => {
self.response_status = Some(adapt_status_code(data)?);
self.response_headers = Some(adapt_headers(data));
self.response_started = true;
asyncw.call0()
}
true => error_flow!(),
},
Ok(ASGIMessageType::HTTPBody) => {
let (body, more) = adapt_body(data);
@ -133,7 +128,7 @@ impl ASGIHTTPProtocol {
let headers = self.response_headers.take().unwrap();
self.send_response(self.response_status.unwrap(), headers, body.into());
asyncw.call0()
},
}
(true, true, false) => {
self.response_chunked = true;
let headers = self.response_headers.take().unwrap();
@ -142,37 +137,31 @@ impl ASGIHTTPProtocol {
self.body_tx = Some(tx.clone());
self.send_response(self.response_status.unwrap(), headers, body_stream);
self.send_body(py, tx, body)
},
(true, true, true) => {
match self.body_tx.as_mut() {
Some(tx) => {
let tx = tx.clone();
self.send_body(py, tx, body)
},
_ => error_flow!()
}
(true, true, true) => match self.body_tx.as_mut() {
Some(tx) => {
let tx = tx.clone();
self.send_body(py, tx, body)
}
_ => error_flow!(),
},
(true, false, true) => {
match self.body_tx.take() {
Some(tx) => {
match body.is_empty() {
false => self.send_body(py, tx, body),
true => asyncw.call0()
}
},
_ => error_flow!()
}
(true, false, true) => match self.body_tx.take() {
Some(tx) => match body.is_empty() {
false => self.send_body(py, tx, body),
true => asyncw.call0(),
},
_ => error_flow!(),
},
_ => error_flow!()
_ => error_flow!(),
}
},
}
Err(err) => Err(err.into()),
_ => error_message!()
_ => error_message!(),
}
}
}
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct ASGIWebsocketProtocol {
rt: RuntimeRef,
tx: Option<oneshot::Sender<bool>>,
@ -181,16 +170,11 @@ pub(crate) struct ASGIWebsocketProtocol {
ws_tx: Arc<Mutex<Option<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>>,
ws_rx: Arc<Mutex<Option<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>>,
accepted: Arc<Mutex<bool>>,
closed: bool
closed: bool,
}
impl ASGIWebsocketProtocol {
pub fn new(
rt: RuntimeRef,
tx: oneshot::Sender<bool>,
websocket: HyperWebsocket,
upgrade: UpgradeData
) -> Self {
pub fn new(rt: RuntimeRef, tx: oneshot::Sender<bool>, websocket: HyperWebsocket, upgrade: UpgradeData) -> Self {
Self {
rt,
tx: Some(tx),
@ -199,7 +183,7 @@ impl ASGIWebsocketProtocol {
ws_tx: Arc::new(Mutex::new(None)),
ws_rx: Arc::new(Mutex::new(None)),
accepted: Arc::new(Mutex::new(false)),
closed: false
closed: false,
}
}
@ -211,7 +195,7 @@ impl ASGIWebsocketProtocol {
let tx = self.ws_tx.clone();
let rx = self.ws_rx.clone();
future_into_py_iter(self.rt.clone(), py, async move {
if let Ok(_) = upgrade.send().await {
if (upgrade.send().await).is_ok() {
if let Ok(stream) = websocket.await {
let mut wtx = tx.lock().await;
let mut wrx = rx.lock().await;
@ -220,7 +204,7 @@ impl ASGIWebsocketProtocol {
*wtx = Some(tx);
*wrx = Some(rx);
*accepted = true;
return Ok(())
return Ok(());
}
}
error_flow!()
@ -228,18 +212,14 @@ impl ASGIWebsocketProtocol {
}
#[inline(always)]
fn send_message<'p>(
&self,
py: Python<'p>,
data: &'p PyDict
) -> PyResult<&'p PyAny> {
fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
let transport = self.ws_tx.clone();
let message = ws_message_into_rs(data);
future_into_py_iter(self.rt.clone(), py, async move {
if let Ok(message) = message {
if let Some(ws) = &mut *(transport.lock().await) {
if let Ok(_) = ws.send(message).await {
return Ok(())
if (ws.send(message).await).is_ok() {
return Ok(());
}
};
};
@ -253,8 +233,8 @@ impl ASGIWebsocketProtocol {
let transport = self.ws_tx.clone();
future_into_py_iter(self.rt.clone(), py, async move {
if let Some(ws) = &mut *(transport.lock().await) {
if let Ok(_) = ws.close().await {
return Ok(())
if (ws.close().await).is_ok() {
return Ok(());
}
};
error_flow!()
@ -262,10 +242,7 @@ impl ASGIWebsocketProtocol {
}
fn consumed(&self) -> bool {
match &self.upgrade {
Some(_) => false,
_ => true
}
self.upgrade.is_none()
}
pub fn tx(&mut self) -> (Option<oneshot::Sender<bool>>, bool) {
@ -278,44 +255,26 @@ impl ASGIWebsocketProtocol {
fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
let transport = self.ws_rx.clone();
let accepted = self.accepted.clone();
let closed = self.closed.clone();
let closed = self.closed;
future_into_py_futlike(self.rt.clone(), py, async move {
let accepted = accepted.lock().await;
match (*accepted, closed) {
(false, false) => {
return Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(
pyo3::intern!(py, "type"),
pyo3::intern!(py, "websocket.connect")
)?;
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?;
Ok(dict.to_object(py))
})
},
(true, false) => {},
_ => {
return error_flow!()
}
(true, false) => {}
_ => return error_flow!(),
}
if let Some(ws) = &mut *(transport.lock().await) {
loop {
match ws.next().await {
Some(recv) => {
match recv {
Ok(Message::Ping(_)) => {
continue
},
Ok(message) => {
return ws_message_into_py(message)
},
_ => {
break
}
}
},
_ => {
break
}
while let Some(recv) = ws.next().await {
match recv {
Ok(Message::Ping(_)) => continue,
Ok(message) => return ws_message_into_py(message),
_ => break,
}
}
}
@ -325,25 +284,17 @@ impl ASGIWebsocketProtocol {
fn send<'p>(&mut self, py: Python<'p>, _asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> {
match (adapt_message_type(data), self.closed) {
(Ok(ASGIMessageType::WSAccept), _) => {
self.accept(py)
},
(Ok(ASGIMessageType::WSClose), false) => {
self.close(py)
},
(Ok(ASGIMessageType::WSMessage), false) => {
self.send_message(py, data)
},
(Ok(ASGIMessageType::WSAccept), _) => self.accept(py),
(Ok(ASGIMessageType::WSClose), false) => self.close(py),
(Ok(ASGIMessageType::WSMessage), false) => self.send_message(py, data),
(Err(err), _) => Err(err.into()),
_ => error_message!()
_ => error_message!(),
}
}
}
#[inline(never)]
fn adapt_message_type(
message: &PyDict
) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
fn adapt_message_type(message: &PyDict) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
match message.get_item("type") {
Some(item) => {
let message_type: &str = item.extract()?;
@ -353,20 +304,18 @@ fn adapt_message_type(
"websocket.accept" => Ok(ASGIMessageType::WSAccept),
"websocket.close" => Ok(ASGIMessageType::WSClose),
"websocket.send" => Ok(ASGIMessageType::WSMessage),
_ => error_message!()
_ => error_message!(),
}
},
_ => error_message!()
}
_ => error_message!(),
}
}
#[inline(always)]
fn adapt_status_code(message: &PyDict) -> Result<i16, UnsupportedASGIMessage> {
match message.get_item("status") {
Some(item) => {
Ok(item.extract()?)
},
_ => error_message!()
Some(item) => Ok(item.extract()?),
_ => error_message!(),
}
}
@ -377,34 +326,26 @@ fn adapt_headers(message: &PyDict) -> HeaderMap {
match message.get_item("headers") {
Some(item) => {
let accum: Vec<Vec<&[u8]>> = item.extract().unwrap_or(Vec::new());
for tup in accum.iter() {
match (
HeaderName::from_bytes(tup[0]),
HeaderValue::from_bytes(tup[1])
) {
(Ok(key), Ok(val)) => { ret.append(key, val); },
_ => {}
for tup in &accum {
if let (Ok(key), Ok(val)) = (HeaderName::from_bytes(tup[0]), HeaderValue::from_bytes(tup[1])) {
ret.append(key, val);
}
};
}
ret
},
_ => ret
}
_ => ret,
}
}
#[inline(always)]
fn adapt_body(message: &PyDict) -> (Vec<u8>, bool) {
let body = match message.get_item("body") {
Some(item) => {
item.extract().unwrap_or(EMPTY_BYTES)
},
_ => EMPTY_BYTES
Some(item) => item.extract().unwrap_or(EMPTY_BYTES),
_ => EMPTY_BYTES,
};
let more = match message.get_item("more_body") {
Some(item) => {
item.extract().unwrap_or(false)
},
_ => false
Some(item) => item.extract().unwrap_or(false),
_ => false,
};
(body, more)
}
@ -412,22 +353,12 @@ fn adapt_body(message: &PyDict) -> (Vec<u8>, bool) {
#[inline(always)]
fn ws_message_into_rs(message: &PyDict) -> PyResult<Message> {
match (message.get_item("bytes"), message.get_item("text")) {
(Some(item), None) => {
Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES)))
},
(None, Some(item)) => {
Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING)))
},
(Some(itemb), Some(itemt)) => {
match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
(Some(msgb), None) => {
Ok(Message::Binary(msgb))
},
(None, Some(msgt)) => {
Ok(Message::Text(msgt))
},
_ => error_flow!()
}
(Some(item), None) => Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES))),
(None, Some(item)) => Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING))),
(Some(itemb), Some(itemt)) => match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
(Some(msgb), None) => Ok(Message::Binary(msgb)),
(None, Some(msgt)) => Ok(Message::Text(msgt)),
_ => error_flow!(),
},
_ => {
error_flow!()
@ -438,41 +369,23 @@ fn ws_message_into_rs(message: &PyDict) -> PyResult<Message> {
#[inline(always)]
fn ws_message_into_py(message: Message) -> PyResult<PyObject> {
match message {
Message::Binary(message) => {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(
pyo3::intern!(py, "type"),
pyo3::intern!(py, "websocket.receive")
)?;
dict.set_item(
pyo3::intern!(py, "bytes"),
PyBytes::new(py, &message[..])
)?;
Ok(dict.to_object(py))
})
},
Message::Text(message) => {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(
pyo3::intern!(py, "type"),
pyo3::intern!(py, "websocket.receive")
)?;
dict.set_item(pyo3::intern!(py, "text"), message)?;
Ok(dict.to_object(py))
})
},
Message::Close(_) => {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(
pyo3::intern!(py, "type"),
pyo3::intern!(py, "websocket.disconnect")
)?;
Ok(dict.to_object(py))
})
},
Message::Binary(message) => Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
dict.set_item(pyo3::intern!(py, "bytes"), PyBytes::new(py, &message[..]))?;
Ok(dict.to_object(py))
}),
Message::Text(message) => Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
dict.set_item(pyo3::intern!(py, "text"), message)?;
Ok(dict.to_object(py))
}),
Message::Close(_) => Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.disconnect"))?;
Ok(dict.to_object(py))
}),
v => {
log::warn!("Unsupported websocket message received {:?}", v);
error_flow!()

View file

@ -1,29 +1,14 @@
use pyo3::prelude::*;
use crate::{
workers::{
WorkerConfig,
serve_rth,
serve_wth,
serve_rth_ssl,
serve_wth_ssl
}
};
use super::http::{
handle_rtb,
handle_rtb_pyw,
handle_rtt,
handle_rtt_pyw,
handle_rtb_ws,
handle_rtb_ws_pyw,
handle_rtt_ws,
handle_rtt_ws_pyw
handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws,
handle_rtt_ws_pyw,
};
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub struct ASGIWorker {
config: WorkerConfig
config: WorkerConfig,
}
impl ASGIWorker {
@ -74,7 +59,7 @@ impl ASGIWorker {
opt_enabled: bool,
ssl_enabled: bool,
ssl_cert: Option<&str>,
ssl_key: Option<&str>
ssl_key: Option<&str>,
) -> PyResult<Self> {
Ok(Self {
config: WorkerConfig::new(
@ -88,22 +73,16 @@ impl ASGIWorker {
opt_enabled,
ssl_enabled,
ssl_cert,
ssl_key
)
ssl_key,
),
})
}
fn serve_rth(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
match (
self.config.websockets_enabled,
self.config.ssl_enabled,
self.config.opt_enabled
self.config.opt_enabled,
) {
(false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx),
@ -112,21 +91,15 @@ impl ASGIWorker {
(false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
}
}
fn serve_wth(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
match (
self.config.websockets_enabled,
self.config.ssl_enabled,
self.config.opt_enabled
self.config.opt_enabled,
) {
(false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx),
@ -135,7 +108,7 @@ impl ASGIWorker {
(false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
}
}
}

View file

@ -1,10 +1,9 @@
use hyper::{Uri, Version, header::HeaderMap};
use hyper::{header::HeaderMap, Uri, Version};
use once_cell::sync::OnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList, PyString};
use std::net::{IpAddr, SocketAddr};
const SCHEME_HTTPS: &str = "https";
const SCHEME_WS: &str = "ws";
const SCHEME_WSS: &str = "wss";
@ -17,10 +16,10 @@ pub(crate) enum ASGIMessageType {
HTTPBody,
WSAccept,
WSClose,
WSMessage
WSMessage,
}
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct ASGIScope {
http_version: Version,
scheme: String,
@ -31,7 +30,7 @@ pub(crate) struct ASGIScope {
client_ip: IpAddr,
client_port: u16,
headers: HeaderMap,
is_websocket: bool
is_websocket: bool,
}
impl ASGIScope {
@ -42,31 +41,31 @@ impl ASGIScope {
method: &str,
server: SocketAddr,
client: SocketAddr,
headers: &HeaderMap
headers: &HeaderMap,
) -> Self {
Self {
http_version: http_version,
http_version,
scheme: scheme.to_string(),
method: method.to_string(),
uri: uri,
uri,
server_ip: server.ip(),
server_port: server.port(),
client_ip: client.ip(),
client_port: client.port(),
headers: headers.to_owned(),
is_websocket: false
headers: headers.clone(),
is_websocket: false,
}
}
pub fn set_websocket(&mut self) {
self.is_websocket = true
self.is_websocket = true;
}
#[inline(always)]
fn py_proto(&self) -> &str {
match self.is_websocket {
false => "http",
true => "websocket"
true => "websocket",
}
}
@ -76,7 +75,7 @@ impl ASGIScope {
Version::HTTP_10 => "1",
Version::HTTP_11 => "1.1",
Version::HTTP_2 => "2",
_ => "1"
_ => "1",
}
}
@ -85,22 +84,20 @@ impl ASGIScope {
let scheme = &self.scheme[..];
match self.is_websocket {
false => scheme,
true => {
match scheme {
SCHEME_HTTPS => SCHEME_WSS,
_ => SCHEME_WS
}
}
true => match scheme {
SCHEME_HTTPS => SCHEME_WSS,
_ => SCHEME_WS,
},
}
}
#[inline(always)]
fn py_headers<'p>(&self, py: Python<'p>) -> PyResult<&'p PyList> {
let rv = PyList::empty(py);
for (key, value) in self.headers.iter() {
for (key, value) in &self.headers {
rv.append((
PyBytes::new(py, key.as_str().as_bytes()),
PyBytes::new(py, value.as_bytes())
PyBytes::new(py, value.as_bytes()),
))?;
}
Ok(rv)
@ -110,17 +107,10 @@ impl ASGIScope {
#[pymethods]
impl ASGIScope {
fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str) -> PyResult<&'p PyAny> {
let (
path,
query_string,
proto,
http_version,
server,
client,
scheme,
method
) = py.allow_threads(|| {
let (path, query_string) = self.uri.path_and_query()
let (path, query_string, proto, http_version, server, client, scheme, method) = py.allow_threads(|| {
let (path, query_string) = self
.uri
.path_and_query()
.map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or("")));
(
path,
@ -136,19 +126,23 @@ impl ASGIScope {
let dict: &PyDict = PyDict::new(py);
dict.set_item(
pyo3::intern!(py, "asgi"),
ASGI_VERSION.get_or_try_init(|| {
let rv = PyDict::new(py);
rv.set_item("version", "3.0")?;
rv.set_item("spec_version", "2.3")?;
Ok::<PyObject, PyErr>(rv.into())
})?.as_ref(py)
ASGI_VERSION
.get_or_try_init(|| {
let rv = PyDict::new(py);
rv.set_item("version", "3.0")?;
rv.set_item("spec_version", "2.3")?;
Ok::<PyObject, PyErr>(rv.into())
})?
.as_ref(py),
)?;
dict.set_item(
pyo3::intern!(py, "extensions"),
ASGI_EXTENSIONS.get_or_try_init(|| {
let rv = PyDict::new(py);
Ok::<PyObject, PyErr>(rv.into())
})?.as_ref(py)
ASGI_EXTENSIONS
.get_or_try_init(|| {
let rv = PyDict::new(py);
Ok::<PyObject, PyErr>(rv.into())
})?
.as_ref(py),
)?;
dict.set_item(pyo3::intern!(py, "type"), proto)?;
dict.set_item(pyo3::intern!(py, "http_version"), http_version)?;
@ -160,17 +154,12 @@ impl ASGIScope {
dict.set_item(pyo3::intern!(py, "path"), path)?;
dict.set_item(
pyo3::intern!(py, "raw_path"),
PyString::new(py, path)
.call_method1(
pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),)
)?
PyString::new(py, path).call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),))?,
)?;
dict.set_item(
pyo3::intern!(py, "query_string"),
PyString::new(py, query_string)
.call_method1(
pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),)
)?
.call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),))?,
)?;
dict.set_item(pyo3::intern!(py, "headers"), self.py_headers(py)?)?;
Ok(dict)

View file

@ -2,32 +2,27 @@ use once_cell::sync::OnceCell;
use pyo3::prelude::*;
use pyo3::pyclass::IterNextOutput;
static CONTEXTVARS: OnceCell<PyObject> = OnceCell::new();
static CONTEXT: OnceCell<PyObject> = OnceCell::new();
#[derive(Clone)]
pub(crate) struct CallbackWrapper {
pub callback: PyObject,
pub context: pyo3_asyncio::TaskLocals
pub context: pyo3_asyncio::TaskLocals,
}
impl CallbackWrapper {
pub(crate) fn new(
callback: PyObject,
event_loop: &PyAny,
context: &PyAny
) -> Self {
pub(crate) fn new(callback: PyObject, event_loop: &PyAny, context: &PyAny) -> Self {
Self {
callback,
context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context)
context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context),
}
}
}
#[pyclass]
pub(crate) struct PyIterAwaitable {
result: Option<PyResult<PyObject>>
result: Option<PyResult<PyObject>>,
}
impl PyIterAwaitable {
@ -36,7 +31,7 @@ impl PyIterAwaitable {
}
pub(crate) fn set_result(&mut self, result: PyResult<PyObject>) {
self.result = Some(result)
self.result = Some(result);
}
}
@ -52,13 +47,11 @@ impl PyIterAwaitable {
fn __next__(&mut self, py: Python) -> PyResult<IterNextOutput<PyObject, PyObject>> {
match self.result.take() {
Some(res) => {
match res {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err)
}
Some(res) => match res {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(py.None()))
_ => Ok(IterNextOutput::Yield(py.None())),
}
}
}
@ -68,7 +61,7 @@ pub(crate) struct PyFutureAwaitable {
py_block: bool,
event_loop: PyObject,
result: Option<PyResult<PyObject>>,
cb: Option<(PyObject, Py<pyo3::types::PyDict>)>
cb: Option<(PyObject, Py<pyo3::types::PyDict>)>,
}
impl PyFutureAwaitable {
@ -77,7 +70,7 @@ impl PyFutureAwaitable {
event_loop,
py_block: true,
result: None,
cb: None
cb: None,
}
}
@ -85,12 +78,9 @@ impl PyFutureAwaitable {
pyself.result = Some(result);
if let Some((cb, ctx)) = pyself.cb.take() {
let py = pyself.py();
let _ = pyself.event_loop.call_method(
py,
"call_soon_threadsafe",
(cb, &pyself),
Some(ctx.as_ref(py))
);
let _ = pyself
.event_loop
.call_method(py, "call_soon_threadsafe", (cb, &pyself), Some(ctx.as_ref(py)));
}
}
}
@ -108,25 +98,22 @@ impl PyFutureAwaitable {
#[setter(_asyncio_future_blocking)]
fn set_block(&mut self, val: bool) {
self.py_block = val
self.py_block = val;
}
fn get_loop(&mut self) -> PyObject {
self.event_loop.clone()
}
fn add_done_callback(
mut pyself: PyRefMut<'_, Self>,
py: Python,
cb: PyObject,
context: PyObject
) -> PyResult<()> {
fn add_done_callback(mut pyself: PyRefMut<'_, Self>, py: Python, cb: PyObject, context: PyObject) -> PyResult<()> {
let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item("context", context)?;
match pyself.result {
Some(_) => {
pyself.event_loop.call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?;
},
pyself
.event_loop
.call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?;
}
_ => {
pyself.cb = Some((cb, kwctx.into_py(py)));
}
@ -136,9 +123,9 @@ impl PyFutureAwaitable {
fn cancel(mut pyself: PyRefMut<'_, Self>, py: Python) -> bool {
if let Some((cb, kwctx)) = pyself.cb.take() {
let _ = pyself.event_loop.call_method(
py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py))
);
let _ = pyself
.event_loop
.call_method(py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py)));
}
false
}
@ -150,79 +137,69 @@ impl PyFutureAwaitable {
pyself
}
fn __next__(
mut pyself: PyRefMut<'_, Self>
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
match pyself.result {
Some(_) => {
match pyself.result.take().unwrap() {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err)
}
Some(_) => match pyself.result.take().unwrap() {
Ok(v) => Ok(IterNextOutput::Return(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(pyself))
_ => Ok(IterNextOutput::Yield(pyself)),
}
}
}
fn contextvars(py: Python) -> PyResult<&PyAny> {
Ok(CONTEXTVARS
.get_or_try_init(|| py.import("contextvars").map(|m| m.into()))?
.get_or_try_init(|| py.import("contextvars").map(std::convert::Into::into))?
.as_ref(py))
}
pub fn empty_pycontext(py: Python) -> PyResult<&PyAny> {
Ok(CONTEXT
.get_or_try_init(|| contextvars(py)?.getattr("Context")?.call0().map(|c| c.into()))?
.get_or_try_init(|| {
contextvars(py)?
.getattr("Context")?
.call0()
.map(std::convert::Into::into)
})?
.as_ref(py))
}
macro_rules! callback_impl_run {
() => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
let event_loop = self.context.event_loop(py);
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item(
pyo3::intern!(py, "context"),
crate::callbacks::empty_pycontext(py)?
crate::callbacks::empty_pycontext(py)?,
)?;
event_loop.call_method(
pyo3::intern!(py, "call_soon_threadsafe"),
(target,),
Some(kwctx)
)
event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx))
}
};
}
macro_rules! callback_impl_run_pytask {
() => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
let event_loop = self.context.event_loop(py);
let context = self.context.context(py);
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item(
pyo3::intern!(py, "context"),
context
)?;
event_loop.call_method(
pyo3::intern!(py, "call_soon_threadsafe"),
(target,),
Some(kwctx)
)
kwctx.set_item(pyo3::intern!(py, "context"), context)?;
event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx))
}
};
}
macro_rules! callback_impl_loop_run {
() => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
let context = self.pycontext.clone().into_ref(py);
context.call_method1(
pyo3::intern!(py, "run"),
(self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,)
(self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,),
)
}
};
@ -232,7 +209,7 @@ macro_rules! callback_impl_loop_pytask {
($pyself:expr, $py:expr) => {
$pyself.context.event_loop($py).call_method1(
pyo3::intern!($py, "create_task"),
($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,)
($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,),
)
};
}
@ -241,12 +218,9 @@ macro_rules! callback_impl_loop_step {
($pyself:expr, $py:expr) => {
match $pyself.cb.call_method1($py, pyo3::intern!($py, "send"), ($py.None(),)) {
Ok(res) => {
let blocking: bool = match res.getattr(
$py,
pyo3::intern!($py, "_asyncio_future_blocking")
) {
let blocking: bool = match res.getattr($py, pyo3::intern!($py, "_asyncio_future_blocking")) {
Ok(v) => v.extract($py)?,
_ => false
_ => false,
};
let ctx = $pyself.pycontext.clone();
@ -255,43 +229,30 @@ macro_rules! callback_impl_loop_step {
match blocking {
true => {
res.setattr(
$py,
pyo3::intern!($py, "_asyncio_future_blocking"),
false
)?;
res.setattr($py, pyo3::intern!($py, "_asyncio_future_blocking"), false)?;
res.call_method(
$py,
pyo3::intern!($py, "add_done_callback"),
(
$pyself
.into_py($py)
.getattr($py, pyo3::intern!($py, "_loop_wake"))?,
),
Some(kwctx)
($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_wake"))?,),
Some(kwctx),
)?;
Ok(())
},
}
false => {
let event_loop = $pyself.context.event_loop($py);
event_loop.call_method(
pyo3::intern!($py, "call_soon"),
(
$pyself
.into_py($py)
.getattr($py, pyo3::intern!($py, "_loop_step"))?,
),
Some(kwctx)
($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_step"))?,),
Some(kwctx),
)?;
Ok(())
}
}
},
}
Err(err) => {
if (
err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py) ||
err.is_instance_of::<pyo3::exceptions::asyncio::CancelledError>($py)
) {
if (err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py)
|| err.is_instance_of::<pyo3::exceptions::asyncio::CancelledError>($py))
{
$pyself.done($py);
Ok(())
} else {
@ -307,7 +268,7 @@ macro_rules! callback_impl_loop_wake {
($pyself:expr, $py:expr, $fut:expr) => {
match $fut.call_method0($py, pyo3::intern!($py, "result")) {
Ok(_) => $pyself.into_py($py).call_method0($py, pyo3::intern!($py, "_loop_step")),
Err(err) => $pyself._loop_err($py, err)
Err(err) => $pyself._loop_err($py, err),
}
};
}
@ -322,10 +283,10 @@ macro_rules! callback_impl_loop_err {
};
}
pub(crate) use callback_impl_run;
pub(crate) use callback_impl_run_pytask;
pub(crate) use callback_impl_loop_run;
pub(crate) use callback_impl_loop_err;
pub(crate) use callback_impl_loop_pytask;
pub(crate) use callback_impl_loop_run;
pub(crate) use callback_impl_loop_step;
pub(crate) use callback_impl_loop_wake;
pub(crate) use callback_impl_loop_err;
pub(crate) use callback_impl_run;
pub(crate) use callback_impl_run_pytask;

View file

@ -1,4 +1,7 @@
use hyper::{Body, Response, header::{HeaderValue, SERVER as HK_SERVER}};
use hyper::{
header::{HeaderValue, SERVER as HK_SERVER},
Body, Response,
};
pub(crate) const HV_SERVER: HeaderValue = HeaderValue::from_static("granian");

View file

@ -8,8 +8,8 @@ mod callbacks;
mod http;
mod rsgi;
mod runtime;
mod tls;
mod tcp;
mod tls;
mod utils;
mod workers;
mod ws;

View file

@ -2,45 +2,33 @@ use pyo3::prelude::*;
use pyo3_asyncio::TaskLocals;
use tokio::sync::oneshot;
use crate::{
callbacks::{
CallbackWrapper,
callback_impl_run,
callback_impl_run_pytask,
callback_impl_loop_run,
callback_impl_loop_pytask,
callback_impl_loop_step,
callback_impl_loop_wake,
callback_impl_loop_err
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData}
};
use super::{
io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol},
types::{RSGIScope as Scope, PyResponse, PyResponseBody}
types::{PyResponse, PyResponseBody, RSGIScope as Scope},
};
use crate::{
callbacks::{
callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step,
callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper,
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData},
};
#[pyclass]
pub(crate) struct CallbackRunnerHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals,
cb: PyObject
cb: PyObject,
}
impl CallbackRunnerHTTP {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
let pyproto = Py::new(py, proto).unwrap();
Self {
proto: pyproto.clone(),
context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
}
}
@ -58,19 +46,17 @@ macro_rules! callback_impl_done_http {
($self:expr, $py:expr) => {
if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() {
if let Some(tx) = proto.tx() {
let _ = tx.send(
PyResponse::Body(PyResponseBody::empty(500, Vec::new()))
);
let _ = tx.send(PyResponse::Body(PyResponseBody::empty(500, Vec::new())));
}
}
}
};
}
macro_rules! callback_impl_done_err {
($self:expr, $py:expr) => {
log::warn!("Application callable raised an exception");
$self.done($py)
}
};
}
#[pyclass]
@ -78,18 +64,18 @@ pub(crate) struct CallbackTaskHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
cb: PyObject,
}
impl CallbackTaskHTTP {
pub fn new(
py: Python,
cb: PyObject,
proto: Py<HTTPProtocol>,
context: TaskLocals
) -> PyResult<Self> {
pub fn new(py: Python, cb: PyObject, proto: Py<HTTPProtocol>, context: TaskLocals) -> PyResult<Self> {
let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb })
Ok(Self {
proto,
context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb,
})
}
fn done(&self, py: Python) {
@ -122,21 +108,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
context: TaskLocals,
cb: PyObject,
#[pyo3(get)]
scope: PyObject
scope: PyObject,
}
impl CallbackWrappedRunnerHTTP {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
context: cb.context,
cb: cb.callback,
scope: scope.into_py(py)
scope: scope.into_py(py),
}
}
@ -162,21 +143,16 @@ impl CallbackWrappedRunnerHTTP {
pub(crate) struct CallbackRunnerWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals,
cb: PyObject
cb: PyObject,
}
impl CallbackRunnerWebsocket {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
let pyproto = Py::new(py, proto).unwrap();
Self {
proto: pyproto.clone(),
context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
}
}
@ -197,7 +173,7 @@ macro_rules! callback_impl_done_ws {
let _ = tx.send(res);
}
}
}
};
}
#[pyclass]
@ -205,18 +181,18 @@ pub(crate) struct CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
cb: PyObject,
}
impl CallbackTaskWebsocket {
pub fn new(
py: Python,
cb: PyObject,
proto: Py<WebsocketProtocol>,
context: TaskLocals
) -> PyResult<Self> {
pub fn new(py: Python, cb: PyObject, proto: Py<WebsocketProtocol>, context: TaskLocals) -> PyResult<Self> {
let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb })
Ok(Self {
proto,
context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb,
})
}
fn done(&self, py: Python) {
@ -249,21 +225,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
context: TaskLocals,
cb: PyObject,
#[pyo3(get)]
scope: PyObject
scope: PyObject,
}
impl CallbackWrappedRunnerWebsocket {
pub fn new(
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
context: cb.context,
cb: cb.callback,
scope: scope.into_py(py)
scope: scope.into_py(py),
}
}
@ -291,7 +262,7 @@ macro_rules! call_impl_rtb_http {
cb: CallbackWrapper,
rt: RuntimeRef,
req: hyper::Request<hyper::Body>,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<PyResponse> {
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, tx, req);
@ -311,7 +282,7 @@ macro_rules! call_impl_rtt_http {
cb: CallbackWrapper,
rt: RuntimeRef,
req: hyper::Request<hyper::Body>,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<PyResponse> {
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, tx, req);
@ -334,7 +305,7 @@ macro_rules! call_impl_rtb_ws {
rt: RuntimeRef,
ws: HyperWebsocket,
upgrade: UpgradeData,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<(i32, bool)> {
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
@ -355,7 +326,7 @@ macro_rules! call_impl_rtt_ws {
rt: RuntimeRef,
ws: HyperWebsocket,
upgrade: UpgradeData,
scope: Scope
scope: Scope,
) -> oneshot::Receiver<(i32, bool)> {
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);

View file

@ -1,6 +1,5 @@
use pyo3::{create_exception, exceptions::PyRuntimeError};
create_exception!(_granian, RSGIProtocolError, PyRuntimeError, "RSGIProtocolError");
create_exception!(_granian, RSGIProtocolClosed, PyRuntimeError, "RSGIProtocolClosed");

View file

@ -1,34 +1,22 @@
use hyper::{
Body,
Request,
Response,
StatusCode,
header::SERVER as HK_SERVER,
http::response::Builder as ResponseBuilder
header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode,
};
use std::net::SocketAddr;
use tokio::sync::mpsc;
use crate::{
callbacks::CallbackWrapper,
http::{HV_SERVER, response_500},
runtime::RuntimeRef,
ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade}
};
use super::{
callbacks::{
call_rtb_http,
call_rtb_http_pyw,
call_rtb_ws,
call_rtb_ws_pyw,
call_rtt_http,
call_rtt_http_pyw,
call_rtt_ws,
call_rtt_ws_pyw
call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws,
call_rtt_ws_pyw,
},
types::{RSGIScope as Scope, PyResponse}
types::{PyResponse, RSGIScope as Scope},
};
use crate::{
callbacks::CallbackWrapper,
http::{response_500, HV_SERVER},
runtime::RuntimeRef,
ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData},
};
macro_rules! default_scope {
($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => {
@ -40,7 +28,7 @@ macro_rules! default_scope {
$req.method().as_ref(),
$server_addr,
$client_addr,
$req.headers()
$req.headers(),
)
};
}
@ -48,12 +36,8 @@ macro_rules! default_scope {
macro_rules! handle_http_response {
($handler:expr, $rt:expr, $callback:expr, $req:expr, $scope:expr) => {
match $handler($callback, $rt, $req, $scope).await {
Ok(PyResponse::Body(pyres)) => {
pyres.to_response()
},
Ok(PyResponse::File(pyres)) => {
pyres.to_response().await
},
Ok(PyResponse::Body(pyres)) => pyres.to_response(),
Ok(PyResponse::File(pyres)) => pyres.to_response().await,
_ => {
log::error!("RSGI protocol failure");
response_500()
@ -70,7 +54,7 @@ macro_rules! handle_request {
server_addr: SocketAddr,
client_addr: SocketAddr,
req: Request<Body>,
scheme: &str
scheme: &str,
) -> Response<Body> {
let scope = default_scope!(server_addr, client_addr, &req, scheme);
handle_http_response!($handler, rt, callback, req, scope)
@ -86,7 +70,7 @@ macro_rules! handle_request_with_ws {
server_addr: SocketAddr,
client_addr: SocketAddr,
req: Request<Body>,
scheme: &str
scheme: &str,
) -> Response<Body> {
let mut scope = default_scope!(server_addr, client_addr, &req, scheme);
@ -101,27 +85,23 @@ macro_rules! handle_request_with_ws {
rt.inner.spawn(async move {
let tx_ref = restx.clone();
match $handler_ws(
callback,
rth,
ws,
UpgradeData::new(res, restx),
scope
).await {
match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await {
Ok((status, consumed)) => {
if !consumed {
let _ = tx_ref.send(
ResponseBuilder::new()
.status(
StatusCode::from_u16(status as u16)
.unwrap_or(StatusCode::FORBIDDEN)
)
.header(HK_SERVER, HV_SERVER)
.body(Body::from(""))
.unwrap()
).await;
let _ = tx_ref
.send(
ResponseBuilder::new()
.status(
StatusCode::from_u16(status as u16)
.unwrap_or(StatusCode::FORBIDDEN),
)
.header(HK_SERVER, HV_SERVER)
.body(Body::from(""))
.unwrap(),
)
.await;
}
},
}
_ => {
log::error!("RSGI protocol failure");
let _ = tx_ref.send(response_500()).await;
@ -133,10 +113,10 @@ macro_rules! handle_request_with_ws {
Some(res) => {
resrx.close();
res
},
_ => response_500()
}
},
}
_ => response_500(),
};
}
Err(err) => {
return ResponseBuilder::new()
.status(StatusCode::BAD_REQUEST)
@ -149,7 +129,6 @@ macro_rules! handle_request_with_ws {
handle_http_response!($handler_req, rt, callback, req, scope)
}
};
}

View file

@ -1,35 +1,40 @@
use bytes::Bytes;
use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}};
use hyper::{body::{Body, Sender as BodySender, HttpBody}, Request};
use futures::{
sink::SinkExt,
stream::{SplitSink, SplitStream, StreamExt},
};
use hyper::{
body::{Body, HttpBody, Sender as BodySender},
Request,
};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyString};
use std::sync::Arc;
use tokio_tungstenite::WebSocketStream;
use tokio::sync::{oneshot, Mutex};
use tokio_tungstenite::WebSocketStream;
use tungstenite::Message;
use crate::{
runtime::{Runtime, RuntimeRef, future_into_py_iter, future_into_py_futlike},
ws::{HyperWebsocket, UpgradeData}
};
use super::{
errors::{error_proto, error_stream},
types::{PyResponse, PyResponseBody, PyResponseFile}
types::{PyResponse, PyResponseBody, PyResponseFile},
};
use crate::{
runtime::{future_into_py_futlike, future_into_py_iter, Runtime, RuntimeRef},
ws::{HyperWebsocket, UpgradeData},
};
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct RSGIHTTPStreamTransport {
rt: RuntimeRef,
tx: Arc<Mutex<BodySender>>
tx: Arc<Mutex<BodySender>>,
}
impl RSGIHTTPStreamTransport {
pub fn new(
rt: RuntimeRef,
transport: BodySender
) -> Self {
Self { rt: rt, tx: Arc::new(Mutex::new(transport)) }
pub fn new(rt: RuntimeRef, transport: BodySender) -> Self {
Self {
rt,
tx: Arc::new(Mutex::new(transport)),
}
}
}
@ -41,8 +46,8 @@ impl RSGIHTTPStreamTransport {
if let Ok(mut stream) = transport.try_lock() {
return match stream.send_data(data.into()).await {
Ok(_) => Ok(()),
_ => error_stream!()
}
_ => error_stream!(),
};
}
error_proto!()
})
@ -54,31 +59,27 @@ impl RSGIHTTPStreamTransport {
if let Ok(mut stream) = transport.try_lock() {
return match stream.send_data(data.into()).await {
Ok(_) => Ok(()),
_ => error_stream!()
}
_ => error_stream!(),
};
}
error_proto!()
})
}
}
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct RSGIHTTPProtocol {
rt: RuntimeRef,
tx: Option<oneshot::Sender<super::types::PyResponse>>,
body: Arc<Mutex<Body>>
body: Arc<Mutex<Body>>,
}
impl RSGIHTTPProtocol {
pub fn new(
rt: RuntimeRef,
tx: oneshot::Sender<super::types::PyResponse>,
request: Request<Body>
) -> Self {
pub fn new(rt: RuntimeRef, tx: oneshot::Sender<super::types::PyResponse>, request: Request<Body>) -> Self {
Self {
rt,
tx: Some(tx),
body: Arc::new(Mutex::new(request.into_body()))
body: Arc::new(Mutex::new(request.into_body())),
}
}
@ -110,14 +111,13 @@ impl RSGIHTTPProtocol {
let mut bodym = body_ref.lock().await;
let body = &mut *bodym;
if body.is_end_stream() {
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted"))
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted"));
}
let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| {
buf.unwrap_or_else(|_| Bytes::new())
});
Ok(Python::with_gil(|py| {
PyBytes::new(py, &chunk[..]).to_object(py)
}))
let chunk = body
.data()
.await
.map_or_else(Bytes::new, |buf| buf.unwrap_or_else(|_| Bytes::new()));
Ok(Python::with_gil(|py| PyBytes::new(py, &chunk[..]).to_object(py)))
})?;
Ok(Some(fut))
}
@ -125,36 +125,28 @@ impl RSGIHTTPProtocol {
#[pyo3(signature = (status=200, headers=vec![]))]
fn response_empty(&mut self, status: u16, headers: Vec<(String, String)>) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::Body(PyResponseBody::empty(status, headers))
);
let _ = tx.send(PyResponse::Body(PyResponseBody::empty(status, headers)));
}
}
#[pyo3(signature = (status=200, headers=vec![], body=vec![]))]
fn response_bytes(&mut self, status: u16, headers: Vec<(String, String)>, body: Vec<u8>) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::Body(PyResponseBody::from_bytes(status, headers, body))
);
let _ = tx.send(PyResponse::Body(PyResponseBody::from_bytes(status, headers, body)));
}
}
#[pyo3(signature = (status=200, headers=vec![], body="".to_string()))]
#[pyo3(signature = (status=200, headers=vec![], body=String::new()))]
fn response_str(&mut self, status: u16, headers: Vec<(String, String)>, body: String) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::Body(PyResponseBody::from_string(status, headers, body))
);
let _ = tx.send(PyResponse::Body(PyResponseBody::from_string(status, headers, body)));
}
}
#[pyo3(signature = (status, headers, file))]
fn response_file(&mut self, status: u16, headers: Vec<(String, String)>, file: String) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::File(PyResponseFile::new(status, headers, file))
);
let _ = tx.send(PyResponse::File(PyResponseFile::new(status, headers, file)));
}
}
@ -163,34 +155,33 @@ impl RSGIHTTPProtocol {
&mut self,
py: Python<'p>,
status: u16,
headers: Vec<(String, String)>
headers: Vec<(String, String)>,
) -> PyResult<&'p PyAny> {
if let Some(tx) = self.tx.take() {
let (body_tx, body_stream) = Body::channel();
let _ = tx.send(
PyResponse::Body(PyResponseBody::new(status, headers, body_stream))
);
let _ = tx.send(PyResponse::Body(PyResponseBody::new(status, headers, body_stream)));
let trx = Py::new(py, RSGIHTTPStreamTransport::new(self.rt.clone(), body_tx))?;
return Ok(trx.into_ref(py))
return Ok(trx.into_ref(py));
}
error_proto!()
}
}
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct RSGIWebsocketTransport {
rt: RuntimeRef,
tx: Arc<Mutex<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>,
rx: Arc<Mutex<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>
rx: Arc<Mutex<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>,
}
impl RSGIWebsocketTransport {
pub fn new(
rt: RuntimeRef,
transport: WebSocketStream<hyper::upgrade::Upgraded>
) -> Self {
pub fn new(rt: RuntimeRef, transport: WebSocketStream<hyper::upgrade::Upgraded>) -> Self {
let (tx, rx) = transport.split();
Self { rt: rt, tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)) }
Self {
rt,
tx: Arc::new(Mutex::new(tx)),
rx: Arc::new(Mutex::new(rx)),
}
}
pub fn close(&self) {
@ -209,27 +200,14 @@ impl RSGIWebsocketTransport {
let transport = self.rx.clone();
future_into_py_futlike(self.rt.clone(), py, async move {
if let Ok(mut stream) = transport.try_lock() {
loop {
match stream.next().await {
Some(recv) => {
match recv {
Ok(Message::Ping(_)) => {
continue
},
Ok(message) => {
return message_into_py(message)
},
_ => {
break
}
}
},
_ => {
break
}
while let Some(recv) = stream.next().await {
match recv {
Ok(Message::Ping(_)) => continue,
Ok(message) => return message_into_py(message),
_ => break,
}
}
return error_stream!()
return error_stream!();
}
error_proto!()
})
@ -241,8 +219,8 @@ impl RSGIWebsocketTransport {
if let Ok(mut stream) = transport.try_lock() {
return match stream.send(Message::Binary(data)).await {
Ok(_) => Ok(()),
_ => error_stream!()
}
_ => error_stream!(),
};
}
error_proto!()
})
@ -254,22 +232,22 @@ impl RSGIWebsocketTransport {
if let Ok(mut stream) = transport.try_lock() {
return match stream.send(Message::Text(data)).await {
Ok(_) => Ok(()),
_ => error_stream!()
}
_ => error_stream!(),
};
}
error_proto!()
})
}
}
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct RSGIWebsocketProtocol {
rt: RuntimeRef,
tx: Option<oneshot::Sender<(i32, bool)>>,
websocket: Arc<Mutex<HyperWebsocket>>,
upgrade: Option<UpgradeData>,
transport: Arc<Mutex<Option<Py<RSGIWebsocketTransport>>>>,
status: i32
status: i32,
}
impl RSGIWebsocketProtocol {
@ -277,7 +255,7 @@ impl RSGIWebsocketProtocol {
rt: RuntimeRef,
tx: oneshot::Sender<(i32, bool)>,
websocket: HyperWebsocket,
upgrade: UpgradeData
upgrade: UpgradeData,
) -> Self {
Self {
rt,
@ -285,15 +263,12 @@ impl RSGIWebsocketProtocol {
websocket: Arc::new(Mutex::new(websocket)),
upgrade: Some(upgrade),
transport: Arc::new(Mutex::new(None)),
status: 0
status: 0,
}
}
fn consumed(&self) -> bool {
match &self.upgrade {
Some(_) => false,
_ => true
}
self.upgrade.is_none()
}
pub fn tx(&mut self) -> (Option<oneshot::Sender<(i32, bool)>>, (i32, bool)) {
@ -304,18 +279,20 @@ impl RSGIWebsocketProtocol {
enum WebsocketMessageType {
Close = 0,
Bytes = 1,
Text = 2
Text = 2,
}
#[pyclass]
struct WebsocketInboundCloseMessage {
#[pyo3(get)]
kind: usize
kind: usize,
}
impl WebsocketInboundCloseMessage {
pub fn new() -> Self {
Self { kind: WebsocketMessageType::Close as usize }
Self {
kind: WebsocketMessageType::Close as usize,
}
}
}
@ -324,12 +301,15 @@ struct WebsocketInboundBytesMessage {
#[pyo3(get)]
kind: usize,
#[pyo3(get)]
data: Py<PyBytes>
data: Py<PyBytes>,
}
impl WebsocketInboundBytesMessage {
pub fn new(data:Py<PyBytes>) -> Self {
Self { kind: WebsocketMessageType::Bytes as usize, data: data }
pub fn new(data: Py<PyBytes>) -> Self {
Self {
kind: WebsocketMessageType::Bytes as usize,
data,
}
}
}
@ -338,12 +318,15 @@ struct WebsocketInboundTextMessage {
#[pyo3(get)]
kind: usize,
#[pyo3(get)]
data: Py<PyString>
data: Py<PyString>,
}
impl WebsocketInboundTextMessage {
pub fn new(data: Py<PyString>) -> Self {
Self { kind: WebsocketMessageType::Text as usize, data: data }
Self {
kind: WebsocketMessageType::Text as usize,
data,
}
}
}
@ -374,23 +357,18 @@ impl RSGIWebsocketProtocol {
future_into_py_iter(self.rt.clone(), py, async move {
let mut ws = transport.lock().await;
match upgrade.send().await {
Ok(_) => {
match (&mut *ws).await {
Ok(stream) => {
let mut trx = itransport.lock().await;
Ok(Python::with_gil(|py| {
let pytransport = Py::new(
py,
RSGIWebsocketTransport::new(rth, stream)
).unwrap();
*trx = Some(pytransport.clone());
pytransport
}))
},
_ => error_proto!()
Ok(_) => match (&mut *ws).await {
Ok(stream) => {
let mut trx = itransport.lock().await;
Ok(Python::with_gil(|py| {
let pytransport = Py::new(py, RSGIWebsocketTransport::new(rth, stream)).unwrap();
*trx = Some(pytransport.clone());
pytransport
}))
}
_ => error_proto!(),
},
_ => error_proto!()
_ => error_proto!(),
}
})
}
@ -399,25 +377,13 @@ impl RSGIWebsocketProtocol {
#[inline(always)]
fn message_into_py(message: Message) -> PyResult<PyObject> {
match message {
Message::Binary(message) => {
Ok(Python::with_gil(|py| {
WebsocketInboundBytesMessage::new(
PyBytes::new(py, &message).into()
).into_py(py)
}))
},
Message::Text(message) => {
Ok(Python::with_gil(|py| {
WebsocketInboundTextMessage::new(
PyString::new(py, &message).into()
).into_py(py)
}))
},
Message::Close(_) => {
Ok(Python::with_gil(|py| {
WebsocketInboundCloseMessage::new().into_py(py)
}))
}
Message::Binary(message) => Ok(Python::with_gil(|py| {
WebsocketInboundBytesMessage::new(PyBytes::new(py, &message).into()).into_py(py)
})),
Message::Text(message) => Ok(Python::with_gil(|py| {
WebsocketInboundTextMessage::new(PyString::new(py, &message).into()).into_py(py)
})),
Message::Close(_) => Ok(Python::with_gil(|py| WebsocketInboundCloseMessage::new().into_py(py))),
v => {
log::warn!("Unsupported websocket message received {:?}", v);
error_proto!()

View file

@ -1,28 +1,14 @@
use pyo3::prelude::*;
use crate::{
workers::{
WorkerConfig,
serve_rth,
serve_wth,
serve_rth_ssl,
serve_wth_ssl
}
};
use super::http::{
handle_rtb,
handle_rtb_pyw,
handle_rtt,
handle_rtt_pyw,
handle_rtb_ws,
handle_rtb_ws_pyw,
handle_rtt_ws,
handle_rtt_ws_pyw
handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws,
handle_rtt_ws_pyw,
};
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub struct RSGIWorker {
config: WorkerConfig
config: WorkerConfig,
}
impl RSGIWorker {
@ -73,7 +59,7 @@ impl RSGIWorker {
opt_enabled: bool,
ssl_enabled: bool,
ssl_cert: Option<&str>,
ssl_key: Option<&str>
ssl_key: Option<&str>,
) -> PyResult<Self> {
Ok(Self {
config: WorkerConfig::new(
@ -87,22 +73,16 @@ impl RSGIWorker {
opt_enabled,
ssl_enabled,
ssl_cert,
ssl_key
)
ssl_key,
),
})
}
fn serve_rth(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
match (
self.config.websockets_enabled,
self.config.ssl_enabled,
self.config.opt_enabled
self.config.opt_enabled,
) {
(false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx),
@ -111,21 +91,15 @@ impl RSGIWorker {
(false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
}
}
fn serve_wth(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
match (
self.config.websockets_enabled,
self.config.ssl_enabled,
self.config.opt_enabled
self.config.opt_enabled,
) {
(false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx),
@ -134,7 +108,7 @@ impl RSGIWorker {
(false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
}
}
}

View file

@ -1,6 +1,6 @@
use hyper::{
header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
Body, Uri, Version
Body, Uri, Version,
};
use pyo3::prelude::*;
use pyo3::types::PyString;
@ -10,11 +10,10 @@ use tokio_util::codec::{BytesCodec, FramedRead};
use crate::http::HV_SERVER;
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
#[derive(Clone)]
pub(crate) struct RSGIHeaders {
inner: HeaderMap
inner: HeaderMap,
}
impl RSGIHeaders {
@ -29,7 +28,7 @@ impl RSGIHeaders {
let mut ret = Vec::with_capacity(self.inner.keys_len());
for key in self.inner.keys() {
ret.push(key.as_str());
};
}
ret
}
@ -37,15 +36,15 @@ impl RSGIHeaders {
let mut ret = Vec::with_capacity(self.inner.keys_len());
for val in self.inner.values() {
ret.push(val.to_str().unwrap());
};
}
Ok(ret)
}
fn items(&self) -> PyResult<Vec<(&str, &str)>> {
let mut ret = Vec::with_capacity(self.inner.keys_len());
for (key, val) in self.inner.iter() {
for (key, val) in &self.inner {
ret.push((key.as_str(), val.to_str().unwrap()));
};
}
Ok(ret)
}
@ -56,18 +55,16 @@ impl RSGIHeaders {
#[pyo3(signature = (key, default=None))]
fn get(&self, py: Python, key: &str, default: Option<PyObject>) -> Option<PyObject> {
match self.inner.get(key) {
Some(val) => {
match val.to_str() {
Ok(string) => Some(PyString::new(py, string).into()),
_ => default
}
Some(val) => match val.to_str() {
Ok(string) => Some(PyString::new(py, string).into()),
_ => default,
},
_ => default
_ => default,
}
}
}
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub(crate) struct RSGIScope {
#[pyo3(get)]
proto: String,
@ -84,7 +81,7 @@ pub(crate) struct RSGIScope {
#[pyo3(get)]
client: String,
#[pyo3(get)]
headers: RSGIHeaders
headers: RSGIHeaders,
}
impl RSGIScope {
@ -96,23 +93,23 @@ impl RSGIScope {
method: &str,
server: SocketAddr,
client: SocketAddr,
headers: &HeaderMap
headers: &HeaderMap,
) -> Self {
Self {
proto: proto.to_string(),
http_version: http_version,
http_version,
rsgi_version: "1.2".to_string(),
scheme: scheme.to_string(),
method: method.to_string(),
uri: uri,
uri,
server: server.to_string(),
client: client.to_string(),
headers: RSGIHeaders::new(headers)
headers: RSGIHeaders::new(headers),
}
}
pub fn set_proto(&mut self, value: &str) {
self.proto = value.to_string()
self.proto = value.to_string();
}
}
@ -125,7 +122,7 @@ impl RSGIScope {
Version::HTTP_11 => "1.1",
Version::HTTP_2 => "2",
Version::HTTP_3 => "3",
_ => "1"
_ => "1",
}
}
@ -142,38 +139,36 @@ impl RSGIScope {
pub(crate) enum PyResponse {
Body(PyResponseBody),
File(PyResponseFile)
File(PyResponseFile),
}
pub(crate) struct PyResponseBody {
status: u16,
headers: Vec<(String, String)>,
body: Body
body: Body,
}
pub(crate) struct PyResponseFile {
status: u16,
headers: Vec<(String, String)>,
file_path: String
file_path: String,
}
macro_rules! response_head_from_py {
($status:expr, $headers:expr, $res:expr) => {
{
let mut rh = hyper::http::HeaderMap::new();
($status:expr, $headers:expr, $res:expr) => {{
let mut rh = hyper::http::HeaderMap::new();
rh.insert(HK_SERVER, HV_SERVER);
for (key, value) in $headers {
rh.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(&value).unwrap()
);
}
*$res.status_mut() = $status.try_into().unwrap();
*$res.headers_mut() = rh;
rh.insert(HK_SERVER, HV_SERVER);
for (key, value) in $headers {
rh.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(&value).unwrap(),
);
}
}
*$res.status_mut() = $status.try_into().unwrap();
*$res.headers_mut() = rh;
}};
}
impl PyResponseBody {
@ -182,18 +177,30 @@ impl PyResponseBody {
}
pub fn empty(status: u16, headers: Vec<(String, String)>) -> Self {
Self { status, headers, body: Body::empty() }
Self {
status,
headers,
body: Body::empty(),
}
}
pub fn from_bytes(status: u16, headers: Vec<(String, String)>, body: Vec<u8>) -> Self {
Self { status, headers, body: Body::from(body) }
Self {
status,
headers,
body: Body::from(body),
}
}
pub fn from_string(status: u16, headers: Vec<(String, String)>, body: String) -> Self {
Self { status, headers, body: Body::from(body) }
Self {
status,
headers,
body: Body::from(body),
}
}
pub fn to_response(self) -> hyper::Response::<Body> {
pub fn to_response(self) -> hyper::Response<Body> {
let mut res = hyper::Response::<Body>::new(self.body);
response_head_from_py!(self.status, &self.headers, res);
res
@ -202,10 +209,14 @@ impl PyResponseBody {
impl PyResponseFile {
pub fn new(status: u16, headers: Vec<(String, String)>, file_path: String) -> Self {
Self { status, headers, file_path }
Self {
status,
headers,
file_path,
}
}
pub async fn to_response(&self) -> hyper::Response::<Body> {
pub async fn to_response(&self) -> hyper::Response<Body> {
let file = File::open(&self.file_path).await.unwrap();
let stream = FramedRead::new(file, BytesCodec::new());
let mut res = hyper::Response::<Body>::new(Body::wrap_stream(stream));

View file

@ -1,12 +1,19 @@
use once_cell::unsync::OnceCell as UnsyncOnceCell;
use pyo3_asyncio::TaskLocals;
use pyo3::prelude::*;
use std::{future::Future, io, pin::Pin, sync::{Arc, Mutex}};
use tokio::{runtime::Builder, task::{JoinHandle, LocalSet}};
use pyo3_asyncio::TaskLocals;
use std::{
future::Future,
io,
pin::Pin,
sync::{Arc, Mutex},
};
use tokio::{
runtime::Builder,
task::{JoinHandle, LocalSet},
};
use super::callbacks::{PyFutureAwaitable, PyIterAwaitable};
tokio::task_local! {
static TASK_LOCALS: UnsyncOnceCell<TaskLocals>;
}
@ -27,11 +34,7 @@ pub trait Runtime: Send + 'static {
}
pub trait ContextExt: Runtime {
fn scope<F, R>(
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R> + Send>>
fn scope<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
where
F: Future<Output = R> + Send + 'static;
@ -45,36 +48,34 @@ pub trait SpawnLocalExt: Runtime {
}
pub trait LocalContextExt: Runtime {
fn scope_local<F, R>(
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R>>>
fn scope_local<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
where
F: Future<Output = R> + 'static;
}
pub(crate) struct RuntimeWrapper {
rt: tokio::runtime::Runtime
rt: tokio::runtime::Runtime,
}
impl RuntimeWrapper {
pub fn new(blocking_threads: usize) -> Self {
Self { rt: default_runtime(blocking_threads).unwrap() }
Self {
rt: default_runtime(blocking_threads).unwrap(),
}
}
pub fn with_runtime(rt: tokio::runtime::Runtime) -> Self {
Self { rt: rt }
Self { rt }
}
pub fn handler(&self) -> RuntimeRef {
RuntimeRef::new(self.rt.handle().to_owned())
RuntimeRef::new(self.rt.handle().clone())
}
}
#[derive(Clone)]
pub struct RuntimeRef {
pub inner: tokio::runtime::Handle
pub inner: tokio::runtime::Handle,
}
impl RuntimeRef {
@ -108,11 +109,7 @@ impl Runtime for RuntimeRef {
}
impl ContextExt for RuntimeRef {
fn scope<F, R>(
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R> + Send>>
fn scope<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
where
F: Future<Output = R> + Send + 'static,
{
@ -123,7 +120,7 @@ impl ContextExt for RuntimeRef {
}
fn get_task_locals() -> Option<TaskLocals> {
match TASK_LOCALS.try_with(|c| c.get().map(|locals| locals.clone())) {
match TASK_LOCALS.try_with(|c| c.get().cloned()) {
Ok(locals) => locals,
Err(_) => None,
}
@ -140,11 +137,7 @@ impl SpawnLocalExt for RuntimeRef {
}
impl LocalContextExt for RuntimeRef {
fn scope_local<F, R>(
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R>>>
fn scope_local<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
where
F: Future<Output = R> + 'static,
{
@ -169,7 +162,7 @@ pub(crate) fn init_runtime_mt(threads: usize, blocking_threads: usize) -> Runtim
.max_blocking_threads(blocking_threads)
.enable_all()
.build()
.unwrap()
.unwrap(),
)
}
@ -177,12 +170,8 @@ pub(crate) fn init_runtime_st(blocking_threads: usize) -> RuntimeWrapper {
RuntimeWrapper::new(blocking_threads)
}
pub(crate) fn into_future(
awaitable: &PyAny,
) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send> {
pyo3_asyncio::into_future_with_locals(
&get_current_locals::<RuntimeRef>(awaitable.py())?, awaitable
)
pub(crate) fn into_future(awaitable: &PyAny) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send> {
pyo3_asyncio::into_future_with_locals(&get_current_locals::<RuntimeRef>(awaitable.py())?, awaitable)
}
#[inline]
@ -241,10 +230,7 @@ where
rt.spawn(async move {
let result = fut.await;
Python::with_gil(move |py| {
PyFutureAwaitable::set_result(
py_aw.as_ref(py).borrow_mut(),
result.map(|v| v.into_py(py))
);
PyFutureAwaitable::set_result(py_aw.as_ref(py).borrow_mut(), result.map(|v| v.into_py(py)));
});
});
@ -269,11 +255,7 @@ where
let rth = rt.handler();
rt.spawn(async move {
let val = rth.scope(
task_locals.clone(),
fut
)
.await;
let val = rth.scope(task_locals.clone(), fut).await;
if let Ok(mut result) = result_tx.lock() {
*result = Some(val.unwrap());
}
@ -292,7 +274,7 @@ where
pub(crate) fn block_on_local<F>(rt: RuntimeWrapper, local: LocalSet, fut: F)
where
F: Future + 'static
F: Future + 'static,
{
local.block_on(&rt.rt, fut);
}

View file

@ -9,10 +9,9 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket};
use socket2::{Domain, Protocol, Socket, Type};
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub struct ListenerHolder {
socket: TcpListener
socket: TcpListener,
}
#[pymethods]
@ -20,28 +19,19 @@ impl ListenerHolder {
#[cfg(unix)]
#[new]
pub fn new(fd: i32) -> PyResult<Self> {
let socket = unsafe {
TcpListener::from_raw_fd(fd)
};
Ok(Self { socket: socket })
let socket = unsafe { TcpListener::from_raw_fd(fd) };
Ok(Self { socket })
}
#[cfg(windows)]
#[new]
pub fn new(fd: u64) -> PyResult<Self> {
let socket = unsafe {
TcpListener::from_raw_socket(fd)
};
Ok(Self { socket: socket })
let socket = unsafe { TcpListener::from_raw_socket(fd) };
Ok(Self { socket })
}
#[classmethod]
pub fn from_address(
_cls: &PyType,
address: &str,
port: u16,
backlog: i32
) -> PyResult<Self> {
pub fn from_address(_cls: &PyType, address: &str, port: u16, backlog: i32) -> PyResult<Self> {
let address: SocketAddr = (address.parse::<IpAddr>()?, port).into();
let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
socket.set_reuse_address(true)?;
@ -54,17 +44,13 @@ impl ListenerHolder {
#[cfg(unix)]
pub fn __getstate__(&self, py: Python) -> PyObject {
let fd = self.socket.as_raw_fd();
(
fd.into_py(py),
).to_object(py)
(fd.into_py(py),).to_object(py)
}
#[cfg(windows)]
pub fn __getstate__(&self, py: Python) -> PyObject {
let fd = self.socket.as_raw_socket();
(
fd.into_py(py),
).to_object(py)
(fd.into_py(py),).to_object(py)
}
#[cfg(unix)]
@ -84,7 +70,6 @@ impl ListenerHolder {
}
}
pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> {
module.add_class::<ListenerHolder>()?;

View file

@ -1,20 +1,22 @@
use futures::stream::StreamExt;
use hyper::server::{accept, conn::{AddrIncoming, AddrStream}};
use hyper::server::{
accept,
conn::{AddrIncoming, AddrStream},
};
use std::{fs, future, io, iter::Iterator, net::TcpListener, sync::Arc};
use tls_listener::{Error as TlsError, TlsListener};
use tokio_rustls::{
TlsAcceptor,
rustls::{Certificate, PrivateKey, ServerConfig},
server::TlsStream
server::TlsStream,
TlsAcceptor,
};
pub(crate) type TlsAddrStream = TlsStream<AddrStream>;
pub(crate) fn tls_listen(
config: Arc<ServerConfig>,
tcp: TcpListener
) -> impl accept::Accept<Conn=TlsAddrStream, Error=TlsError<io::Error, io::Error>> {
tcp: TcpListener,
) -> impl accept::Accept<Conn = TlsAddrStream, Error = TlsError<io::Error, io::Error>> {
tcp.set_nonblocking(true).unwrap();
let tcp_listener = tokio::net::TcpListener::from_std(tcp).unwrap();
let incoming = AddrIncoming::from_listener(tcp_listener).unwrap();
@ -34,30 +36,26 @@ fn tls_error(err: String) -> io::Error {
}
pub(crate) fn load_certs(filename: &str) -> io::Result<Vec<Certificate>> {
let certfile = fs::File::open(filename)
.map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?;
let certfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?;
let mut reader = io::BufReader::new(certfile);
let certs = rustls_pemfile::certs(&mut reader)
.map_err(|_| tls_error("failed to load certificate".into()))?;
let certs = rustls_pemfile::certs(&mut reader).map_err(|_| tls_error("failed to load certificate".into()))?;
Ok(certs.into_iter().map(Certificate).collect())
}
pub(crate) fn load_private_key(filename: &str) -> io::Result<PrivateKey> {
let keyfile = fs::File::open(filename)
.map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?;
let keyfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?;
let mut reader = io::BufReader::new(keyfile);
let keys = rustls_pemfile::read_all(&mut reader)
.map_err(|_| tls_error("failed to load private key".into()))?;
let keys = rustls_pemfile::read_all(&mut reader).map_err(|_| tls_error("failed to load private key".into()))?;
if keys.len() != 1 {
return Err(tls_error("expected a single private key".into()));
}
let key = match &keys[0] {
rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.to_vec()),
rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.to_vec()),
rustls_pemfile::Item::ECKey(key) => PrivateKey(key.to_vec()),
rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.clone()),
rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.clone()),
rustls_pemfile::Item::ECKey(key) => PrivateKey(key.clone()),
_ => {
return Err(tls_error("failed to load private key".into()));
}

View file

@ -1,13 +1,15 @@
pub(crate) fn header_contains_value(
headers: &hyper::HeaderMap,
header: impl hyper::header::AsHeaderName,
value: impl AsRef<[u8]>
value: impl AsRef<[u8]>,
) -> bool {
let value = value.as_ref();
for header in headers.get_all(header) {
if header.as_bytes().split(|&c| c == b',').any(
|x| trim(x).eq_ignore_ascii_case(value)
) {
if header
.as_bytes()
.split(|&c| c == b',')
.any(|x| trim(x).eq_ignore_ascii_case(value))
{
return true;
}
}
@ -30,7 +32,7 @@ fn trim_start(data: &[u8]) -> &[u8] {
#[inline]
fn trim_end(data: &[u8]) -> &[u8] {
if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
&data[..last + 1]
&data[..=last]
} else {
b""
}

View file

@ -8,8 +8,8 @@ use std::os::windows::io::FromRawSocket;
use super::asgi::serve::ASGIWorker;
use super::rsgi::serve::RSGIWorker;
use super::wsgi::serve::WSGIWorker;
use super::tls::{load_certs as tls_load_certs, load_private_key as tls_load_pkey};
use super::wsgi::serve::WSGIWorker;
pub(crate) struct WorkerConfig {
pub id: i32,
@ -22,7 +22,7 @@ pub(crate) struct WorkerConfig {
pub opt_enabled: bool,
pub ssl_enabled: bool,
ssl_cert: Option<String>,
ssl_key: Option<String>
ssl_key: Option<String>,
}
impl WorkerConfig {
@ -37,7 +37,7 @@ impl WorkerConfig {
opt_enabled: bool,
ssl_enabled: bool,
ssl_cert: Option<&str>,
ssl_key: Option<&str>
ssl_key: Option<&str>,
) -> Self {
Self {
id,
@ -49,23 +49,19 @@ impl WorkerConfig {
websockets_enabled,
opt_enabled,
ssl_enabled,
ssl_cert: ssl_cert.map_or(None, |v| Some(v.into())),
ssl_key: ssl_key.map_or(None, |v| Some(v.into()))
ssl_cert: ssl_cert.map(std::convert::Into::into),
ssl_key: ssl_key.map(std::convert::Into::into),
}
}
#[cfg(unix)]
pub fn tcp_listener(&self) -> TcpListener {
unsafe {
TcpListener::from_raw_fd(self.socket_fd)
}
unsafe { TcpListener::from_raw_fd(self.socket_fd) }
}
#[cfg(windows)]
pub fn tcp_listener(&self) -> TcpListener {
unsafe {
TcpListener::from_raw_socket(self.socket_fd as u64)
}
unsafe { TcpListener::from_raw_socket(self.socket_fd as u64) }
}
pub fn tls_cfg(&self) -> tokio_rustls::rustls::ServerConfig {
@ -74,13 +70,13 @@ impl WorkerConfig {
.with_no_client_auth()
.with_single_cert(
tls_load_certs(&self.ssl_cert.clone().unwrap()[..]).unwrap(),
tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap()
tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap(),
)
.unwrap();
cfg.alpn_protocols = match &self.http_mode[..] {
"1" => vec![b"http/1.1".to_vec()],
"2" => vec![b"h2".to_vec()],
_ => vec![b"h2".to_vec(), b"http/1.1".to_vec()]
_ => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
};
cfg
}
@ -102,7 +98,7 @@ pub(crate) struct WorkerExecutor;
impl<F> hyper::rt::Executor<F> for WorkerExecutor
where
F: std::future::Future + 'static
F: std::future::Future + 'static,
{
fn execute(&self, fut: F) {
tokio::task::spawn_local(fut);
@ -123,14 +119,9 @@ macro_rules! build_service {
let rth = rth.clone();
async move {
Ok::<_, std::convert::Infallible>($target(
rth,
callback_wrapper,
local_addr,
remote_addr,
req,
"http"
).await)
Ok::<_, std::convert::Infallible>(
$target(rth, callback_wrapper, local_addr, remote_addr, req, "http").await,
)
}
}))
}
@ -153,14 +144,9 @@ macro_rules! build_service_ssl {
let rth = rth.clone();
async move {
Ok::<_, std::convert::Infallible>($target(
rth,
callback_wrapper,
local_addr,
remote_addr,
req,
"https"
).await)
Ok::<_, std::convert::Infallible>(
$target(rth, callback_wrapper, local_addr, remote_addr, req, "https").await,
)
}
}))
}
@ -170,13 +156,7 @@ macro_rules! build_service_ssl {
macro_rules! serve_rth {
($func_name:ident, $target:expr) => {
fn $func_name(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
pyo3_log::init();
let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads);
let rth = rt.handler();
@ -184,34 +164,30 @@ macro_rules! serve_rth {
let http1_only = self.config.http_mode == "1";
let http2_only = self.config.http_mode == "2";
let http1_buffer_max = self.config.http1_buffer_max.clone();
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
callback, event_loop, context
);
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id);
let svc_loop = crate::runtime::run_until_complete(
rt.handler(),
event_loop,
async move {
let service = crate::workers::build_service!(
callback_wrapper, rth, $target
);
let server = hyper::Server::from_tcp(tcp_listener).unwrap()
.http1_only(http1_only)
.http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max)
.serve(service);
server.with_graceful_shutdown(async move {
Python::with_gil(|py| {
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
}).await.unwrap();
}).await.unwrap();
log::info!("Stopping worker-{}", worker_id);
Ok(())
}
);
let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move {
let service = crate::workers::build_service!(callback_wrapper, rth, $target);
let server = hyper::Server::from_tcp(tcp_listener)
.unwrap()
.http1_only(http1_only)
.http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max)
.serve(service);
server
.with_graceful_shutdown(async move {
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
.await
.unwrap();
})
.await
.unwrap();
log::info!("Stopping worker-{}", worker_id);
Ok(())
});
match svc_loop {
Ok(_) => {}
@ -226,13 +202,7 @@ macro_rules! serve_rth {
macro_rules! serve_rth_ssl {
($func_name:ident, $target:expr) => {
fn $func_name(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
pyo3_log::init();
let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads);
let rth = rt.handler();
@ -241,38 +211,29 @@ macro_rules! serve_rth_ssl {
let http2_only = self.config.http_mode == "2";
let http1_buffer_max = self.config.http1_buffer_max.clone();
let tls_cfg = self.config.tls_cfg();
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
callback, event_loop, context
);
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id);
let svc_loop = crate::runtime::run_until_complete(
rt.handler(),
event_loop,
async move {
let service = crate::workers::build_service_ssl!(
callback_wrapper, rth, $target
);
let server = hyper::Server::builder(
crate::tls::tls_listen(
std::sync::Arc::new(tls_cfg), tcp_listener
)
)
.http1_only(http1_only)
.http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max)
.serve(service);
server.with_graceful_shutdown(async move {
Python::with_gil(|py| {
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
}).await.unwrap();
}).await.unwrap();
log::info!("Stopping worker-{}", worker_id);
Ok(())
}
);
let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move {
let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target);
let server = hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener))
.http1_only(http1_only)
.http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max)
.serve(service);
server
.with_graceful_shutdown(async move {
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
.await
.unwrap();
})
.await
.unwrap();
log::info!("Stopping worker-{}", worker_id);
Ok(())
});
match svc_loop {
Ok(_) => {}
@ -287,22 +248,14 @@ macro_rules! serve_rth_ssl {
macro_rules! serve_wth {
($func_name: ident, $target:expr) => {
fn $func_name(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
pyo3_log::init();
let rtm = crate::runtime::init_runtime_mt(1, 1);
let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id);
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
callback, event_loop, context
);
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
let mut workers = vec![];
let (stx, srx) = tokio::sync::watch::channel(false);
@ -323,38 +276,36 @@ macro_rules! serve_wth {
let local = tokio::task::LocalSet::new();
crate::runtime::block_on_local(rt, local, async move {
let service = crate::workers::build_service!(
callback_wrapper, rth, $target
);
let server = hyper::Server::from_tcp(tcp_listener).unwrap()
let service = crate::workers::build_service!(callback_wrapper, rth, $target);
let server = hyper::Server::from_tcp(tcp_listener)
.unwrap()
.executor(crate::workers::WorkerExecutor)
.http1_only(http1_only)
.http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max)
.serve(service);
server.with_graceful_shutdown(async move {
srx.changed().await.unwrap();
}).await.unwrap();
server
.with_graceful_shutdown(async move {
srx.changed().await.unwrap();
})
.await
.unwrap();
log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1);
});
}));
};
}
let main_loop = crate::runtime::run_until_complete(
rtm.handler(),
event_loop,
async move {
Python::with_gil(|py| {
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
}).await.unwrap();
stx.send(true).unwrap();
log::info!("Stopping worker-{}", worker_id);
while let Some(worker) = workers.pop() {
worker.join().unwrap();
}
Ok(())
let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move {
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
.await
.unwrap();
stx.send(true).unwrap();
log::info!("Stopping worker-{}", worker_id);
while let Some(worker) = workers.pop() {
worker.join().unwrap();
}
);
Ok(())
});
match main_loop {
Ok(_) => {}
@ -369,22 +320,14 @@ macro_rules! serve_wth {
macro_rules! serve_wth_ssl {
($func_name: ident, $target:expr) => {
fn $func_name(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
pyo3_log::init();
let rtm = crate::runtime::init_runtime_mt(1, 1);
let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id);
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
callback, event_loop, context
);
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
let mut workers = vec![];
let (stx, srx) = tokio::sync::watch::channel(false);
@ -406,42 +349,36 @@ macro_rules! serve_wth_ssl {
let local = tokio::task::LocalSet::new();
crate::runtime::block_on_local(rt, local, async move {
let service = crate::workers::build_service_ssl!(
callback_wrapper, rth, $target
);
let server = hyper::Server::builder(
crate::tls::tls_listen(
std::sync::Arc::new(tls_cfg), tcp_listener
)
)
.executor(crate::workers::WorkerExecutor)
.http1_only(http1_only)
.http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max)
.serve(service);
server.with_graceful_shutdown(async move {
srx.changed().await.unwrap();
}).await.unwrap();
let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target);
let server =
hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener))
.executor(crate::workers::WorkerExecutor)
.http1_only(http1_only)
.http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max)
.serve(service);
server
.with_graceful_shutdown(async move {
srx.changed().await.unwrap();
})
.await
.unwrap();
log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1);
});
}));
};
}
let main_loop = crate::runtime::run_until_complete(
rtm.handler(),
event_loop,
async move {
Python::with_gil(|py| {
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
}).await.unwrap();
stx.send(true).unwrap();
log::info!("Stopping worker-{}", worker_id);
while let Some(worker) = workers.pop() {
worker.join().unwrap();
}
Ok(())
let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move {
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
.await
.unwrap();
stx.send(true).unwrap();
log::info!("Stopping worker-{}", worker_id);
while let Some(worker) = workers.pop() {
worker.join().unwrap();
}
);
Ok(())
});
match main_loop {
Ok(_) => {}
@ -457,8 +394,8 @@ macro_rules! serve_wth_ssl {
pub(crate) use build_service;
pub(crate) use build_service_ssl;
pub(crate) use serve_rth;
pub(crate) use serve_wth;
pub(crate) use serve_rth_ssl;
pub(crate) use serve_wth;
pub(crate) use serve_wth_ssl;
pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> {

View file

@ -1,24 +1,24 @@
use hyper::{
Body,
Request,
Response,
StatusCode,
header::{CONNECTION, UPGRADE},
http::response::Builder
http::response::Builder,
Body, Request, Response, StatusCode,
};
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use tokio_tungstenite::WebSocketStream;
use tungstenite::{
error::ProtocolError,
handshake::derive_accept_key,
protocol::{Role, WebSocketConfig}
protocol::{Role, WebSocketConfig},
};
use pin_project::pin_project;
use std::{future::Future, pin::Pin, task::{Context, Poll}};
use tokio_tungstenite::WebSocketStream;
use tokio::sync::mpsc;
use super::utils::header_contains_value;
#[pin_project]
#[derive(Debug)]
pub struct HyperWebsocket {
@ -37,15 +37,9 @@ impl Future for HyperWebsocket {
Poll::Ready(x) => x,
};
let upgraded = upgraded.map_err(|_|
tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete)
)?;
let upgraded = upgraded.map_err(|_| tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete))?;
let stream = WebSocketStream::from_raw_socket(
upgraded,
Role::Server,
this.config.take(),
);
let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, this.config.take());
tokio::pin!(stream);
match stream.as_mut().poll(cx) {
@ -58,18 +52,15 @@ impl Future for HyperWebsocket {
pub(crate) struct UpgradeData {
response_builder: Option<Builder>,
response_tx: Option<mpsc::Sender<Response<Body>>>,
pub consumed: bool
pub consumed: bool,
}
impl UpgradeData {
pub fn new(
response_builder: Builder,
response_tx: mpsc::Sender<Response<Body>>)
-> Self {
pub fn new(response_builder: Builder, response_tx: mpsc::Sender<Response<Body>>) -> Self {
Self {
response_builder: Some(response_builder),
response_tx: Some(response_tx),
consumed: false
consumed: false,
}
}
@ -79,19 +70,16 @@ impl UpgradeData {
Ok(_) => {
self.consumed = true;
Ok(())
},
err => err
}
err => err,
}
}
}
#[inline]
pub(crate) fn is_upgrade_request<B>(request: &Request<B>) -> bool {
header_contains_value(
request.headers(), CONNECTION, "Upgrade"
) && header_contains_value(
request.headers(), UPGRADE, "websocket"
)
header_contains_value(request.headers(), CONNECTION, "Upgrade")
&& header_contains_value(request.headers(), UPGRADE, "websocket")
}
pub(crate) fn upgrade_intent<B>(
@ -100,13 +88,17 @@ pub(crate) fn upgrade_intent<B>(
) -> Result<(Builder, HyperWebsocket), ProtocolError> {
let request = request.borrow_mut();
let key = request.headers()
let key = request
.headers()
.get("Sec-WebSocket-Key")
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
if request.headers().get("Sec-WebSocket-Version").map(
|v| v.as_bytes()
) != Some(b"13") {
if request
.headers()
.get("Sec-WebSocket-Version")
.map(hyper::http::HeaderValue::as_bytes)
!= Some(b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}

View file

@ -2,44 +2,35 @@ use hyper::Body;
use pyo3::prelude::*;
use tokio::task::JoinHandle;
use super::types::{WSGIResponseBodyIter, WSGIScope as Scope};
use crate::callbacks::CallbackWrapper;
use super::types::{WSGIScope as Scope, WSGIResponseBodyIter};
const WSGI_LIST_RESPONSE_BODY: i32 = 0;
const WSGI_ITER_RESPONSE_BODY: i32 = 1;
#[inline(always)]
fn run_callback(
callback: PyObject,
scope: Scope
) -> PyResult<(i32, Vec<(String, String)>, Body)> {
fn run_callback(callback: PyObject, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> {
Python::with_gil(|py| {
let (status, headers, body_type, pybody) = callback.call1(py, (scope,))?
.extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?;
let (status, headers, body_type, pybody) =
callback
.call1(py, (scope,))?
.extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?;
let body = match body_type {
WSGI_LIST_RESPONSE_BODY => Body::from(pybody.extract::<Vec<u8>>(py)?),
WSGI_ITER_RESPONSE_BODY => Body::wrap_stream(WSGIResponseBodyIter::new(pybody)),
_ => Body::empty()
_ => Body::empty(),
};
Ok((status, headers, body))
})
}
pub(crate) fn call_rtb_http(
cb: CallbackWrapper,
scope: Scope
) -> PyResult<(i32, Vec<(String, String)>, Body)> {
run_callback(cb.callback.clone(), scope)
pub(crate) fn call_rtb_http(cb: CallbackWrapper, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> {
run_callback(cb.callback, scope)
}
pub(crate) fn call_rtt_http(
cb: CallbackWrapper,
scope: Scope
scope: Scope,
) -> JoinHandle<PyResult<(i32, Vec<(String, String)>, Body)>> {
let callback = cb.callback.clone();
tokio::task::spawn_blocking(move || {
run_callback(callback, scope)
})
tokio::task::spawn_blocking(move || run_callback(cb.callback, scope))
}

View file

@ -1,21 +1,18 @@
use hyper::{
Body,
Request,
Response,
header::{SERVER as HK_SERVER, HeaderName, HeaderValue}
header::{HeaderName, HeaderValue, SERVER as HK_SERVER},
Body, Request, Response,
};
use std::net::SocketAddr;
use crate::{
callbacks::CallbackWrapper,
http::{HV_SERVER, response_500},
runtime::RuntimeRef,
};
use super::{
callbacks::{call_rtb_http, call_rtt_http},
types::WSGIScope as Scope
types::WSGIScope as Scope,
};
use crate::{
callbacks::CallbackWrapper,
http::{response_500, HV_SERVER},
runtime::RuntimeRef,
};
#[inline(always)]
fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) -> Response<Body> {
@ -26,7 +23,7 @@ fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) ->
for (key, val) in pyheaders {
headers.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(&val).unwrap()
HeaderValue::from_str(&val).unwrap(),
);
}
res
@ -38,14 +35,11 @@ pub(crate) async fn handle_rtt(
server_addr: SocketAddr,
client_addr: SocketAddr,
req: Request<Body>,
scheme: &str
scheme: &str,
) -> Response<Body> {
if let Ok(res) = call_rtt_http(
callback,
Scope::new(scheme, server_addr, client_addr, req).await
).await {
if let Ok(res) = call_rtt_http(callback, Scope::new(scheme, server_addr, client_addr, req).await).await {
if let Ok((status, headers, body)) = res {
return build_response(status, headers, body)
return build_response(status, headers, body);
}
log::warn!("Application callable raised an exception");
} else {
@ -60,12 +54,9 @@ pub(crate) async fn handle_rtb(
server_addr: SocketAddr,
client_addr: SocketAddr,
req: Request<Body>,
scheme: &str
scheme: &str,
) -> Response<Body> {
match call_rtb_http(
callback,
Scope::new(scheme, server_addr, client_addr, req).await
) {
match call_rtb_http(callback, Scope::new(scheme, server_addr, client_addr, req).await) {
Ok((status, headers, body)) => build_response(status, headers, body),
_ => {
log::warn!("Application callable raised an exception");

View file

@ -1,17 +1,11 @@
use pyo3::prelude::*;
use crate::workers::{
WorkerConfig,
serve_rth,
serve_wth,
serve_rth_ssl,
serve_wth_ssl
};
use super::http::{handle_rtb, handle_rtt};
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
#[pyclass(module="granian._granian")]
#[pyclass(module = "granian._granian")]
pub struct WSGIWorker {
config: WorkerConfig
config: WorkerConfig,
}
impl WSGIWorker {
@ -46,7 +40,7 @@ impl WSGIWorker {
http1_buffer_max: usize,
ssl_enabled: bool,
ssl_cert: Option<&str>,
ssl_key: Option<&str>
ssl_key: Option<&str>,
) -> PyResult<Self> {
Ok(Self {
config: WorkerConfig::new(
@ -60,34 +54,22 @@ impl WSGIWorker {
true,
ssl_enabled,
ssl_cert,
ssl_key
)
ssl_key,
),
})
}
fn serve_rth(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
match self.config.ssl_enabled {
false => self._serve_rth(callback, event_loop, context, signal_rx),
true => self._serve_rth_ssl(callback, event_loop, context, signal_rx)
true => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
}
}
fn serve_wth(
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
match self.config.ssl_enabled {
false => self._serve_wth(callback, event_loop, context, signal_rx),
true => self._serve_wth_ssl(callback, event_loop, context, signal_rx)
true => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
}
}
}

View file

@ -1,23 +1,21 @@
use futures::Stream;
use hyper::{
body::Bytes,
header::{CONTENT_TYPE, CONTENT_LENGTH, HeaderMap},
Body,
Method,
Request,
Uri,
Version
header::{HeaderMap, CONTENT_LENGTH, CONTENT_TYPE},
Body, Method, Request, Uri, Version,
};
use pyo3::{prelude::*, types::IntoPyDict};
use pyo3::types::{PyBytes, PyDict, PyList};
use std::{net::{IpAddr, SocketAddr}, task::{Context, Poll}};
use pyo3::{prelude::*, types::IntoPyDict};
use std::{
net::{IpAddr, SocketAddr},
task::{Context, Poll},
};
const LINE_SPLIT: u8 = u8::from_be_bytes(*b"\n");
#[pyclass(module = "granian._granian")]
pub(crate) struct WSGIBody {
inner: Bytes
inner: Bytes,
}
impl WSGIBody {
@ -36,9 +34,9 @@ impl WSGIBody {
match self.inner.iter().position(|&c| c == LINE_SPLIT) {
Some(next_split) => {
let bytes = self.inner.split_to(next_split);
Some(PyBytes::new(py, &bytes[..]))
},
_ => None
Some(PyBytes::new(py, &bytes))
}
_ => None,
}
}
@ -48,18 +46,16 @@ impl WSGIBody {
None => {
let bytes = self.inner.split_to(self.inner.len());
PyBytes::new(py, &bytes[..])
},
Some(size) => {
match size {
0 => PyBytes::new(py, b""),
size => {
let limit = self.inner.len();
let rsize = if size > limit { limit } else { size };
let bytes = self.inner.split_to(rsize);
PyBytes::new(py, &bytes[..])
}
}
}
Some(size) => match size {
0 => PyBytes::new(py, b""),
size => {
let limit = self.inner.len();
let rsize = if size > limit { limit } else { size };
let bytes = self.inner.split_to(rsize);
PyBytes::new(py, &bytes[..])
}
},
}
}
@ -69,16 +65,17 @@ impl WSGIBody {
let bytes = self.inner.split_to(next_split);
self.inner = self.inner.slice(1..);
PyBytes::new(py, &bytes[..])
},
_ => PyBytes::new(py, b"")
}
_ => PyBytes::new(py, b""),
}
}
#[pyo3(signature = (_hint=None))]
fn readlines<'p>(&mut self, py: Python<'p>, _hint: Option<PyObject>) -> &'p PyList {
let lines: Vec<&PyBytes> = self.inner
let lines: Vec<&PyBytes> = self
.inner
.split(|&c| c == LINE_SPLIT)
.map(|item| PyBytes::new(py, &item[..]))
.map(|item| PyBytes::new(py, item))
.collect();
self.inner.clear();
PyList::new(py, lines)
@ -95,28 +92,19 @@ pub(crate) struct WSGIScope {
server_port: u16,
client: String,
headers: HeaderMap,
body: Bytes
body: Bytes,
}
impl WSGIScope {
pub async fn new(
scheme: &str,
server: SocketAddr,
client: SocketAddr,
request: Request<Body>,
) -> Self {
pub async fn new(scheme: &str, server: SocketAddr, client: SocketAddr, request: Request<Body>) -> Self {
let http_version = request.version();
let method = request.method().to_owned();
let uri = request.uri().to_owned();
let headers = request.headers().to_owned();
let method = request.method().clone();
let uri = request.uri().clone();
let headers = request.headers().clone();
let body = match method {
Method::HEAD | Method::GET | Method::OPTIONS => { Bytes::new() },
_ => {
hyper::body::to_bytes(request)
.await
.unwrap_or(Bytes::new())
}
Method::HEAD | Method::GET | Method::OPTIONS => Bytes::new(),
_ => hyper::body::to_bytes(request).await.unwrap_or(Bytes::new()),
};
Self {
@ -128,7 +116,7 @@ impl WSGIScope {
server_port: server.port(),
client: client.to_string(),
headers,
body
body,
}
}
@ -138,7 +126,7 @@ impl WSGIScope {
Version::HTTP_10 => "HTTP/1",
Version::HTTP_11 => "HTTP/1.1",
Version::HTTP_2 => "HTTP/2",
_ => "HTTP/1"
_ => "HTTP/1",
}
}
}
@ -157,21 +145,21 @@ impl WSGIScope {
content_type,
content_len,
headers,
body
body,
) = py.allow_threads(|| {
let (path, query_string) = self.uri.path_and_query()
let (path, query_string) = self
.uri
.path_and_query()
.map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or("")));
let content_type = self.headers.remove(CONTENT_TYPE);
let content_len = self.headers.remove(CONTENT_LENGTH);
let mut headers = Vec::with_capacity(self.headers.len());
for (key, val) in self.headers.iter() {
headers.push(
(
format!("HTTP_{}", key.as_str().replace("-", "_").to_uppercase()),
val.to_str().unwrap_or_default()
)
);
for (key, val) in &self.headers {
headers.push((
format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()),
val.to_str().unwrap_or_default(),
));
}
(
@ -185,7 +173,7 @@ impl WSGIScope {
content_type,
content_len,
headers,
WSGIBody::new(self.body.to_owned())
WSGIBody::new(self.body.clone()),
)
});
@ -202,13 +190,13 @@ impl WSGIScope {
if let Some(content_type) = content_type {
ret.set_item(
pyo3::intern!(py, "CONTENT_TYPE"),
content_type.to_str().unwrap_or_default()
content_type.to_str().unwrap_or_default(),
)?;
}
if let Some(content_len) = content_len {
ret.set_item(
pyo3::intern!(py, "CONTENT_LENGTH"),
content_len.to_str().unwrap_or_default()
content_len.to_str().unwrap_or_default(),
)?;
}
@ -219,7 +207,7 @@ impl WSGIScope {
}
pub(crate) struct WSGIResponseBodyIter {
inner: PyObject
inner: PyObject,
}
impl WSGIResponseBodyIter {
@ -235,27 +223,20 @@ impl WSGIResponseBodyIter {
impl Stream for WSGIResponseBodyIter {
type Item = PyResult<Vec<u8>>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>
) -> Poll<Option<Self::Item>> {
Python::with_gil(|py| {
match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) {
Ok(chunk_obj) => {
match chunk_obj.extract::<Vec<u8>>(py) {
Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
_ => {
self.close_inner(py);
Poll::Ready(None)
}
}
},
Err(err) => {
if err.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
self.close_inner(py);
}
fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) {
Ok(chunk_obj) => match chunk_obj.extract::<Vec<u8>>(py) {
Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
_ => {
self.close_inner(py);
Poll::Ready(None)
}
},
Err(err) => {
if err.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
self.close_inner(py);
}
Poll::Ready(None)
}
})
}

View file

@ -1,55 +1,45 @@
import json
PLAINTEXT_RESPONSE = {
'type': 'http.response.start',
'status': 200,
'headers': [
[b'content-type', b'text/plain; charset=utf-8'],
]
}
JSON_RESPONSE = {
'type': 'http.response.start',
'status': 200,
'headers': [
[b'content-type', b'application/json'],
]
'headers': [[b'content-type', b'text/plain; charset=utf-8']],
}
JSON_RESPONSE = {'type': 'http.response.start', 'status': 200, 'headers': [[b'content-type', b'application/json']]}
async def info(scope, receive, send):
await send(JSON_RESPONSE)
await send({
'type': 'http.response.body',
'body': json.dumps({
'type': scope['type'],
'asgi': scope['asgi'],
'http_version': scope['http_version'],
'scheme': scope['scheme'],
'method': scope['method'],
'path': scope['path'],
'query_string': scope['query_string'].decode("latin-1"),
'headers': {
k.decode("utf8"): v.decode("utf8")
for k, v in scope['headers']
}
}).encode("utf8"),
'more_body': False
})
await send(
{
'type': 'http.response.body',
'body': json.dumps(
{
'type': scope['type'],
'asgi': scope['asgi'],
'http_version': scope['http_version'],
'scheme': scope['scheme'],
'method': scope['method'],
'path': scope['path'],
'query_string': scope['query_string'].decode('latin-1'),
'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
}
).encode('utf8'),
'more_body': False,
}
)
async def echo(scope, receive, send):
await send(PLAINTEXT_RESPONSE)
more_body = True
body = b""
body = b''
while more_body:
msg = await receive()
more_body = msg['more_body']
body += msg['body']
await send({
'type': 'http.response.body',
'body': body,
'more_body': False
})
await send({'type': 'http.response.body', 'body': body, 'more_body': False})
async def ws_reject(scope, receive, send):
@ -58,21 +48,22 @@ async def ws_reject(scope, receive, send):
async def ws_info(scope, receive, send):
await send({'type': 'websocket.accept'})
await send({
'type': 'websocket.send',
'text': json.dumps({
'type': scope['type'],
'asgi': scope['asgi'],
'http_version': scope['http_version'],
'scheme': scope['scheme'],
'path': scope['path'],
'query_string': scope['query_string'].decode("latin-1"),
'headers': {
k.decode("utf8"): v.decode("utf8")
for k, v in scope['headers']
}
})
})
await send(
{
'type': 'websocket.send',
'text': json.dumps(
{
'type': scope['type'],
'asgi': scope['asgi'],
'http_version': scope['http_version'],
'scheme': scope['scheme'],
'path': scope['path'],
'query_string': scope['query_string'].decode('latin-1'),
'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
}
),
}
)
await send({'type': 'websocket.close'})
@ -98,10 +89,7 @@ async def ws_push(scope, receive, send):
try:
while True:
await send({
'type': 'websocket.send',
'text': 'ping'
})
await send({'type': 'websocket.send', 'text': 'ping'})
except Exception:
pass
@ -116,12 +104,12 @@ async def err_proto(scope, receive, send):
def app(scope, receive, send):
return {
"/info": info,
"/echo": echo,
"/ws_reject": ws_reject,
"/ws_info": ws_info,
"/ws_echo": ws_echo,
"/ws_push": ws_push,
"/err_app": err_app,
"/err_proto": err_proto
'/info': info,
'/echo': echo,
'/ws_reject': ws_reject,
'/ws_info': ws_info,
'/ws_echo': ws_echo,
'/ws_push': ws_push,
'/err_app': err_app,
'/err_proto': err_proto,
}[scope['path']](scope, receive, send)

View file

@ -1,55 +1,42 @@
import json
from granian.rsgi import (
HTTPProtocol,
Scope,
WebsocketMessageType,
WebsocketProtocol
)
from granian.rsgi import HTTPProtocol, Scope, WebsocketMessageType, WebsocketProtocol
async def info(scope: Scope, protocol: HTTPProtocol):
protocol.response_bytes(
200,
[('content-type', 'application/json')],
json.dumps({
'proto': scope.proto,
'http_version': scope.http_version,
'rsgi_version': scope.rsgi_version,
'scheme': scope.scheme,
'method': scope.method,
'path': scope.path,
'query_string': scope.query_string,
'headers': {k: v for k, v in scope.headers.items()}
}).encode("utf8")
json.dumps(
{
'proto': scope.proto,
'http_version': scope.http_version,
'rsgi_version': scope.rsgi_version,
'scheme': scope.scheme,
'method': scope.method,
'path': scope.path,
'query_string': scope.query_string,
'headers': dict(scope.headers.items()),
}
).encode('utf8'),
)
async def echo(_, protocol: HTTPProtocol):
msg = await protocol()
protocol.response_bytes(
200,
[('content-type', 'text/plain; charset=utf-8')],
msg
)
protocol.response_bytes(200, [('content-type', 'text/plain; charset=utf-8')], msg)
async def echo_stream(_, protocol: HTTPProtocol):
trx = protocol.response_stream(
200,
[('content-type', 'text/plain; charset=utf-8')]
)
trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')])
async for msg in protocol:
await trx.send_bytes(msg)
async def stream(_, protocol: HTTPProtocol):
trx = protocol.response_stream(
200,
[('content-type', 'text/plain; charset=utf-8')]
)
trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')])
for _ in range(0, 3):
await trx.send_bytes(b"test")
await trx.send_bytes(b'test')
async def ws_reject(_, protocol: WebsocketProtocol):
@ -59,16 +46,20 @@ async def ws_reject(_, protocol: WebsocketProtocol):
async def ws_info(scope: Scope, protocol: WebsocketProtocol):
trx = await protocol.accept()
await trx.send_str(json.dumps({
'proto': scope.proto,
'http_version': scope.http_version,
'rsgi_version': scope.rsgi_version,
'scheme': scope.scheme,
'method': scope.method,
'path': scope.path,
'query_string': scope.query_string,
'headers': {k: v for k, v in scope.headers.items()}
}))
await trx.send_str(
json.dumps(
{
'proto': scope.proto,
'http_version': scope.http_version,
'rsgi_version': scope.rsgi_version,
'scheme': scope.scheme,
'method': scope.method,
'path': scope.path,
'query_string': scope.query_string,
'headers': dict(scope.headers.items()),
}
)
)
while True:
message = await trx.receive()
if message.kind == WebsocketMessageType.close:
@ -97,7 +88,7 @@ async def ws_push(_, protocol: WebsocketProtocol):
try:
while True:
await trx.send_str("ping")
await trx.send_str('ping')
except Exception:
pass
@ -110,13 +101,13 @@ async def err_app(scope: Scope, protocol: HTTPProtocol):
def app(scope, protocol):
return {
"/info": info,
"/echo": echo,
"/echos": echo_stream,
"/stream": stream,
"/ws_reject": ws_reject,
"/ws_info": ws_info,
"/ws_echo": ws_echo,
"/ws_push": ws_push,
"/err_app": err_app
'/info': info,
'/echo': echo,
'/echos': echo_stream,
'/stream': stream,
'/ws_reject': ws_reject,
'/ws_info': ws_info,
'/ws_echo': ws_echo,
'/ws_push': ws_push,
'/err_app': err_app,
}[scope.path](scope, protocol)

View file

@ -2,37 +2,32 @@ import json
def info(environ, protocol):
protocol(
"200 OK",
[('content-type', 'application/json')]
)
return [json.dumps({
'scheme': environ['wsgi.url_scheme'],
'method': environ['REQUEST_METHOD'],
'path': environ["PATH_INFO"],
'query_string': environ["QUERY_STRING"],
'content_length': environ['CONTENT_LENGTH'],
'headers': {k: v for k, v in environ.items() if k.startswith("HTTP_")}
}).encode("utf8")]
protocol('200 OK', [('content-type', 'application/json')])
return [
json.dumps(
{
'scheme': environ['wsgi.url_scheme'],
'method': environ['REQUEST_METHOD'],
'path': environ['PATH_INFO'],
'query_string': environ['QUERY_STRING'],
'content_length': environ['CONTENT_LENGTH'],
'headers': {k: v for k, v in environ.items() if k.startswith('HTTP_')},
}
).encode('utf8')
]
def echo(environ, protocol):
protocol(
'200 OK',
[('content-type', 'text/plain; charset=utf-8')]
)
protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')])
return [environ['wsgi.input'].read()]
def iterbody(environ, protocol):
def response():
for _ in range(0, 3):
yield b"test"
yield b'test'
protocol(
'200 OK',
[('content-type', 'text/plain; charset=utf-8')]
)
protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')])
return response()
@ -41,9 +36,6 @@ def err_app(environ, protocol):
def app(environ, protocol):
return {
"/info": info,
"/echo": echo,
"/iterbody": iterbody,
"/err_app": err_app
}[environ["PATH_INFO"]](environ, protocol)
return {'/info': info, '/echo': echo, '/iterbody': iterbody, '/err_app': err_app}[environ['PATH_INFO']](
environ, protocol
)

View file

@ -1,7 +1,6 @@
import asyncio
import os
import socket
from contextlib import asynccontextmanager, closing
from functools import partial
from pathlib import Path
@ -11,19 +10,20 @@ import pytest
@asynccontextmanager
async def _server(interface, port, threading_mode, tls=False):
certs_path = Path.cwd() / "tests" / "fixtures" / "tls"
certs_path = Path.cwd() / 'tests' / 'fixtures' / 'tls'
tls_opts = (
f"--ssl-certificate {certs_path / 'cert.pem'} "
f"--ssl-keyfile {certs_path / 'key.pem'} "
) if tls else ""
(f"--ssl-certificate {certs_path / 'cert.pem'} " f"--ssl-keyfile {certs_path / 'key.pem'} ") if tls else ''
)
proc = await asyncio.create_subprocess_shell(
"".join([
f"granian --interface {interface} --port {port} ",
f"--threads 1 --threading-mode {threading_mode} ",
tls_opts,
f"tests.apps.{interface}:app"
]),
env=dict(os.environ)
''.join(
[
f'granian --interface {interface} --port {port} ',
f'--threads 1 --threading-mode {threading_mode} ',
tls_opts,
f'tests.apps.{interface}:app',
]
),
env=dict(os.environ),
)
await asyncio.sleep(1)
try:
@ -33,7 +33,7 @@ async def _server(interface, port, threading_mode, tls=False):
await proc.wait()
@pytest.fixture(scope="function")
@pytest.fixture(scope='function')
def server_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(('localhost', 0))
@ -41,26 +41,26 @@ def server_port():
return sock.getsockname()[1]
@pytest.fixture(scope="function")
@pytest.fixture(scope='function')
def asgi_server(server_port):
return partial(_server, "asgi", server_port)
return partial(_server, 'asgi', server_port)
@pytest.fixture(scope="function")
@pytest.fixture(scope='function')
def rsgi_server(server_port):
return partial(_server, "rsgi", server_port)
return partial(_server, 'rsgi', server_port)
@pytest.fixture(scope="function")
@pytest.fixture(scope='function')
def wsgi_server(server_port):
return partial(_server, "wsgi", server_port)
return partial(_server, 'wsgi', server_port)
@pytest.fixture(scope="function")
@pytest.fixture(scope='function')
def server(server_port, request):
return partial(_server, request.param, server_port)
@pytest.fixture(scope="function")
@pytest.fixture(scope='function')
def server_tls(server_port, request):
return partial(_server, request.param, server_port, tls=True)

View file

@ -3,92 +3,59 @@ import pytest
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_scope(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/info?test=true")
res = httpx.get(f'http://localhost:{port}/info?test=true')
assert res.status_code == 200
assert res.headers["content-type"] == "application/json"
assert res.headers['content-type'] == 'application/json'
data = res.json()
assert data['asgi'] == {
'version': '3.0',
'spec_version': '2.3'
}
assert data['type'] == "http"
assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'}
assert data['type'] == 'http'
assert data['http_version'] == '1.1'
assert data['scheme'] == 'http'
assert data['method'] == "GET"
assert data['method'] == 'GET'
assert data['path'] == '/info'
assert data['query_string'] == 'test=true'
assert data['headers']['host'] == f'localhost:{port}'
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_body(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content="test")
res = httpx.post(f'http://localhost:{port}/echo', content='test')
assert res.status_code == 200
assert res.text == "test"
assert res.text == 'test'
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_body_large(asgi_server, threading_mode):
data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)])
data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)])
async with asgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content=data)
res = httpx.post(f'http://localhost:{port}/echo', content=data)
assert res.status_code == 200
assert res.text == data
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_app_error(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_app")
res = httpx.get(f'http://localhost:{port}/err_app')
assert res.status_code == 500
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_protocol_error(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_proto")
res = httpx.get(f'http://localhost:{port}/err_proto')
assert res.status_code == 500

View file

@ -1,20 +1,18 @@
import httpx
import json
import pathlib
import pytest
import ssl
import httpx
import pytest
import websockets
@pytest.mark.asyncio
@pytest.mark.parametrize("server_tls", ["asgi", "rsgi"], indirect=True)
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
@pytest.mark.parametrize('server_tls', ['asgi', 'rsgi'], indirect=True)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_http_scope(server_tls, threading_mode):
async with server_tls(threading_mode) as port:
res = httpx.get(
f"https://localhost:{port}/info?test=true",
verify=False
)
res = httpx.get(f'https://localhost:{port}/info?test=true', verify=False)
assert res.status_code == 200
data = res.json()
@ -22,17 +20,14 @@ async def test_http_scope(server_tls, threading_mode):
@pytest.mark.asyncio
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_asgi_ws_scope(asgi_server, threading_mode):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem"
localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem'
ssl_context.load_verify_locations(localhost_pem)
async with asgi_server(threading_mode, tls=True) as port:
async with websockets.connect(
f"wss://localhost:{port}/ws_info?test=true",
ssl=ssl_context
) as ws:
async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws:
res = await ws.recv()
data = json.loads(res)
@ -40,17 +35,14 @@ async def test_asgi_ws_scope(asgi_server, threading_mode):
@pytest.mark.asyncio
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_rsgi_ws_scope(rsgi_server, threading_mode):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem"
localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem'
ssl_context.load_verify_locations(localhost_pem)
async with rsgi_server(threading_mode, tls=True) as port:
async with websockets.connect(
f"wss://localhost:{port}/ws_info?test=true",
ssl=ssl_context
) as ws:
async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws:
res = await ws.recv()
data = json.loads(res)

View file

@ -3,90 +3,60 @@ import pytest
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_scope(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/info?test=true")
res = httpx.get(f'http://localhost:{port}/info?test=true')
assert res.status_code == 200
assert res.headers["content-type"] == "application/json"
assert res.headers['content-type'] == 'application/json'
data = res.json()
assert data['proto'] == "http"
assert data['proto'] == 'http'
assert data['http_version'] == '1.1'
assert data['rsgi_version'] == '1.2'
assert data['scheme'] == 'http'
assert data['method'] == "GET"
assert data['method'] == 'GET'
assert data['path'] == '/info'
assert data['query_string'] == 'test=true'
assert data['headers']['host'] == f'localhost:{port}'
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_body(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content="test")
res = httpx.post(f'http://localhost:{port}/echo', content='test')
assert res.status_code == 200
assert res.text == "test"
assert res.text == 'test'
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_body_stream_req(rsgi_server, threading_mode):
data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)])
data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)])
async with rsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echos", content=data)
res = httpx.post(f'http://localhost:{port}/echos', content=data)
assert res.status_code == 200
assert res.text == data
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_body_stream_res(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/stream")
res = httpx.get(f'http://localhost:{port}/stream')
assert res.status_code == 200
assert res.text == "test" * 3
assert res.text == 'test' * 3
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_app_error(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_app")
res = httpx.get(f'http://localhost:{port}/err_app')
assert res.status_code == 500

View file

@ -1,56 +1,48 @@
import json
import pytest
import sys
import pytest
import websockets
@pytest.mark.skipif(sys.platform == "win32", reason="skip on windows")
@pytest.mark.skipif(sys.platform == 'win32', reason='skip on windows')
@pytest.mark.asyncio
@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True)
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_messages(server, threading_mode):
async with server(threading_mode) as port:
async with websockets.connect(f"ws://localhost:{port}/ws_echo") as ws:
await ws.send("foo")
async with websockets.connect(f'ws://localhost:{port}/ws_echo') as ws:
await ws.send('foo')
res_text = await ws.recv()
await ws.send(b"foo")
await ws.send(b'foo')
res_bytes = await ws.recv()
assert res_text == "foo"
assert res_bytes == b"foo"
assert res_text == 'foo'
assert res_bytes == b'foo'
@pytest.mark.asyncio
@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True)
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_reject(server, threading_mode):
async with server(threading_mode) as port:
with pytest.raises(websockets.InvalidStatusCode) as exc:
async with websockets.connect(f"ws://localhost:{port}/ws_reject") as ws:
async with websockets.connect(f'ws://localhost:{port}/ws_reject'):
pass
assert exc.value.status_code == 403
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_asgi_scope(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws:
async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws:
res = await ws.recv()
data = json.loads(res)
assert data['asgi'] == {
'version': '3.0',
'spec_version': '2.3'
}
assert data['type'] == "websocket"
assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'}
assert data['type'] == 'websocket'
assert data['http_version'] == '1.1'
assert data['scheme'] == 'ws'
assert data['path'] == '/ws_info'
@ -59,16 +51,10 @@ async def test_asgi_scope(asgi_server, threading_mode):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_rsgi_scope(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port:
async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws:
async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws:
res = await ws.recv()
data = json.loads(res)
@ -76,7 +62,7 @@ async def test_rsgi_scope(rsgi_server, threading_mode):
assert data['http_version'] == '1.1'
assert data['rsgi_version'] == '1.2'
assert data['scheme'] == 'http'
assert data['method'] == "GET"
assert data['method'] == 'GET'
assert data['path'] == '/ws_info'
assert data['query_string'] == 'test=true'
assert data['headers']['host'] == f'localhost:{port}'

View file

@ -3,24 +3,18 @@ import pytest
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_scope(wsgi_server, threading_mode):
payload = "body_payload"
payload = 'body_payload'
async with wsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/info?test=true", content=payload)
res = httpx.post(f'http://localhost:{port}/info?test=true', content=payload)
assert res.status_code == 200
assert res.headers["content-type"] == "application/json"
assert res.headers['content-type'] == 'application/json'
data = res.json()
assert data['scheme'] == 'http'
assert data['method'] == "POST"
assert data['method'] == 'POST'
assert data['path'] == '/info'
assert data['query_string'] == 'test=true'
assert data['headers']['HTTP_HOST'] == f'localhost:{port}'
@ -28,47 +22,29 @@ async def test_scope(wsgi_server, threading_mode):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_body(wsgi_server, threading_mode):
async with wsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content="test")
res = httpx.post(f'http://localhost:{port}/echo', content='test')
assert res.status_code == 200
assert res.text == "test"
assert res.text == 'test'
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_iterbody(wsgi_server, threading_mode):
async with wsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/iterbody")
res = httpx.get(f'http://localhost:{port}/iterbody')
assert res.status_code == 200
assert res.text == "test" * 3
assert res.text == 'test' * 3
@pytest.mark.asyncio
@pytest.mark.parametrize(
"threading_mode",
[
"runtime",
"workers"
]
)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_app_error(wsgi_server, threading_mode):
async with wsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_app")
res = httpx.get(f'http://localhost:{port}/err_app')
assert res.status_code == 500