Align codestyle (#124)

* Add lint and format tools and config

* Format Python code

* Format Rust code

* Add lint CI workflow
This commit is contained in:
Giovanni Barillari 2023-09-25 18:00:09 +02:00 committed by GitHub
parent e28eb81db6
commit c95fd0cba7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
53 changed files with 1388 additions and 1991 deletions

32
.github/workflows/lint.yml vendored Normal file
View file

@ -0,0 +1,32 @@
name: lint
on:
pull_request:
types: [opened, synchronize]
branches:
- master
env:
MATURIN_VERSION: 1.2.3
PYTHON_VERSION: 3.11
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ env.PYTHON_VERSION }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install
run: |
python -m venv .venv
source .venv/bin/activate
pip install maturin==${{ env.MATURIN_VERSION }}
maturin develop --extras=lint
- name: Lint
run: |
source .venv/bin/activate
make lint

26
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,26 @@
fail_fast: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-added-large-files
- repo: local
hooks:
- id: lint-python
name: Lint Python
entry: make lint-python
types: [python]
language: system
pass_filenames: false
- id: lint-rust
name: Lint Rust
entry: make lint-rust
types: [rust]
language: system
pass_filenames: false

1
.rustfmt.toml Normal file
View file

@ -0,0 +1 @@
max_width = 120

67
Makefile Normal file
View file

@ -0,0 +1,67 @@
.DEFAULT_GOAL := all
black = black granian tests
ruff = ruff granian tests
.PHONY: build-dev
build-dev:
@rm -f granian/*.so
maturin develop --extras test
.PHONY: format
format:
$(black)
$(ruff) --fix --exit-zero
cargo fmt
.PHONY: lint-python
lint-python:
$(ruff)
$(black) --check --diff
.PHONY: lint-rust
lint-rust:
cargo fmt --version
cargo fmt --all -- --check
cargo clippy --version
cargo clippy --tests -- \
-D warnings \
-W clippy::pedantic \
-W clippy::dbg_macro \
-W clippy::print_stdout \
-A clippy::cast-possible-truncation \
-A clippy::cast-possible-wrap \
-A clippy::cast-precision-loss \
-A clippy::cast-sign-loss \
-A clippy::declare-interior-mutable-const \
-A clippy::float-cmp \
-A clippy::fn-params-excessive-bools \
-A clippy::if-not-else \
-A clippy::inline-always \
-A clippy::manual-let-else \
-A clippy::match-bool \
-A clippy::match-same-arms \
-A clippy::missing-errors-doc \
-A clippy::missing-panics-doc \
-A clippy::module-name-repetitions \
-A clippy::must-use-candidate \
-A clippy::needless-pass-by-value \
-A clippy::similar-names \
-A clippy::single-match-else \
-A clippy::struct-excessive-bools \
-A clippy::too-many-arguments \
-A clippy::too-many-lines \
-A clippy::type-complexity \
-A clippy::unnecessary-wraps \
-A clippy::unused-self \
-A clippy::used-underscore-binding \
-A clippy::wrong-self-convention
.PHONY: lint
lint: lint-python lint-rust
.PHONY: test
test:
pytest -v test
.PHONY: all
all: format build-dev lint test

View file

@ -1 +1 @@
from .server import Granian from .server import Granian # noqa

View file

@ -1,3 +1,4 @@
from granian.cli import cli from granian.cli import cli
cli() cli()

View file

@ -1 +1 @@
__version__ = "0.6.0" __version__ = '0.6.0'

View file

@ -1,12 +1,10 @@
from typing import Any, Dict, List, Tuple, Optional from typing import Any, Dict, List, Optional, Tuple
from ._types import WebsocketMessage from ._types import WebsocketMessage
class ASGIScope: class ASGIScope:
def as_dict(self, root_path: str) -> Dict[str, Any]: ... def as_dict(self, root_path: str) -> Dict[str, Any]: ...
class RSGIHeaders: class RSGIHeaders:
def __contains__(self, key: str) -> bool: ... def __contains__(self, key: str) -> bool: ...
def keys(self) -> List[str]: ... def keys(self) -> List[str]: ...
@ -14,7 +12,6 @@ class RSGIHeaders:
def items(self) -> List[Tuple[str]]: ... def items(self) -> List[Tuple[str]]: ...
def get(self, key: str, default: Any = None) -> Any: ... def get(self, key: str, default: Any = None) -> Any: ...
class RSGIScope: class RSGIScope:
proto: str proto: str
http_version: str http_version: str
@ -29,12 +26,10 @@ class RSGIScope:
@property @property
def headers(self) -> RSGIHeaders: ... def headers(self) -> RSGIHeaders: ...
class RSGIHTTPStreamTransport: class RSGIHTTPStreamTransport:
async def send_bytes(self, data: bytes): ... async def send_bytes(self, data: bytes): ...
async def send_str(self, data: str): ... async def send_str(self, data: str): ...
class RSGIHTTPProtocol: class RSGIHTTPProtocol:
async def __call__(self) -> bytes: ... async def __call__(self) -> bytes: ...
def response_empty(self, status: int, headers: List[Tuple[str, str]]): ... def response_empty(self, status: int, headers: List[Tuple[str, str]]): ...
@ -43,25 +38,17 @@ class RSGIHTTPProtocol:
def response_file(self, status: int, headers: List[Tuple[str, str]], file: str): ... def response_file(self, status: int, headers: List[Tuple[str, str]], file: str): ...
def response_stream(self, status: int, headers: List[Tuple[str, str]]) -> RSGIHTTPStreamTransport: ... def response_stream(self, status: int, headers: List[Tuple[str, str]]) -> RSGIHTTPStreamTransport: ...
class RSGIWebsocketTransport: class RSGIWebsocketTransport:
async def receive(self) -> WebsocketMessage: ... async def receive(self) -> WebsocketMessage: ...
async def send_bytes(self, data: bytes): ... async def send_bytes(self, data: bytes): ...
async def send_str(self, data: str): ... async def send_str(self, data: str): ...
class RSGIWebsocketProtocol: class RSGIWebsocketProtocol:
async def accept(self) -> RSGIWebsocketTransport: ... async def accept(self) -> RSGIWebsocketTransport: ...
def close(self, status: Optional[int]) -> Tuple[int, bool]: ... def close(self, status: Optional[int]) -> Tuple[int, bool]: ...
class RSGIProtocolError(RuntimeError): ...
class RSGIProtocolError(RuntimeError): class RSGIProtocolClosed(RuntimeError): ...
...
class RSGIProtocolClosed(RuntimeError):
...
class WSGIScope: class WSGIScope:
def to_environ(self, environ: Dict[str, Any]) -> Dict[str, Any]: ... def to_environ(self, environ: Dict[str, Any]) -> Dict[str, Any]: ...

View file

@ -2,22 +2,21 @@ import os
import re import re
import sys import sys
import traceback import traceback
from types import ModuleType from types import ModuleType
from typing import Callable, List, Optional from typing import Callable, List, Optional
def get_import_components(path: str) -> List[Optional[str]]: def get_import_components(path: str) -> List[Optional[str]]:
return (re.split(r":(?![\\/])", path, 1) + [None])[:2] return (re.split(r':(?![\\/])', path, 1) + [None])[:2]
def prepare_import(path: str) -> str: def prepare_import(path: str) -> str:
path = os.path.realpath(path) path = os.path.realpath(path)
fname, ext = os.path.splitext(path) fname, ext = os.path.splitext(path)
if ext == ".py": if ext == '.py':
path = fname path = fname
if os.path.basename(path) == "__init__": if os.path.basename(path) == '__init__':
path = os.path.dirname(path) path = os.path.dirname(path)
module_name = [] module_name = []
@ -27,26 +26,22 @@ def prepare_import(path: str) -> str:
path, name = os.path.split(path) path, name = os.path.split(path)
module_name.append(name) module_name.append(name)
if not os.path.exists(os.path.join(path, "__init__.py")): if not os.path.exists(os.path.join(path, '__init__.py')):
break break
if sys.path[0] != path: if sys.path[0] != path:
sys.path.insert(0, path) sys.path.insert(0, path)
return ".".join(module_name[::-1]) return '.'.join(module_name[::-1])
def load_module( def load_module(module_name: str, raise_on_failure: bool = True) -> Optional[ModuleType]:
module_name: str,
raise_on_failure: bool = True
) -> Optional[ModuleType]:
try: try:
__import__(module_name) __import__(module_name)
except ImportError: except ImportError:
if sys.exc_info()[-1].tb_next: if sys.exc_info()[-1].tb_next:
raise RuntimeError( raise RuntimeError(
f"While importing '{module_name}', an ImportError was raised:" f"While importing '{module_name}', an ImportError was raised:" f"\n\n{traceback.format_exc()}"
f"\n\n{traceback.format_exc()}"
) )
elif raise_on_failure: elif raise_on_failure:
raise RuntimeError(f"Could not import '{module_name}'.") raise RuntimeError(f"Could not import '{module_name}'.")
@ -58,9 +53,9 @@ def load_module(
def load_target(target: str) -> Callable[..., None]: def load_target(target: str) -> Callable[..., None]:
path, name = get_import_components(target) path, name = get_import_components(target)
path = prepare_import(path) if path else None path = prepare_import(path) if path else None
name = name or "app" name = name or 'app'
module = load_module(path) module = load_module(path)
rv = module rv = module
for element in name.split("."): for element in name.split('.'):
rv = getattr(rv, element) rv = getattr(rv, element)
return rv return rv

View file

@ -2,12 +2,11 @@ import asyncio
import os import os
import signal import signal
import sys import sys
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
class Registry: class Registry:
__slots__ = ["_data"] __slots__ = ['_data']
def __init__(self): def __init__(self):
self._data: Dict[str, Callable[..., Any]] = {} self._data: Dict[str, Callable[..., Any]] = {}
@ -22,6 +21,7 @@ class Registry:
def wrap(builder: Callable[..., Any]) -> Callable[..., Any]: def wrap(builder: Callable[..., Any]) -> Callable[..., Any]:
self._data[key] = builder self._data[key] = builder
return builder return builder
return wrap return wrap
def get(self, key: str) -> Callable[..., Any]: def get(self, key: str) -> Callable[..., Any]:
@ -31,18 +31,13 @@ class Registry:
raise RuntimeError(f"'{key}' implementation not available.") raise RuntimeError(f"'{key}' implementation not available.")
class BuilderRegistry(Registry): class BuilderRegistry(Registry):
__slots__ = [] __slots__ = []
def __init__(self): def __init__(self):
self._data: Dict[str, Tuple[Callable[..., Any], List[str]]] = {} self._data: Dict[str, Tuple[Callable[..., Any], List[str]]] = {}
def register( def register(self, key: str, packages: Optional[List[str]] = None) -> Callable[[], Callable[..., Any]]:
self,
key: str,
packages: Optional[List[str]] = None
) -> Callable[[], Callable[..., Any]]:
packages = packages or [] packages = packages or []
def wrap(builder: Callable[..., Any]) -> Callable[..., Any]: def wrap(builder: Callable[..., Any]) -> Callable[..., Any]:
@ -56,6 +51,7 @@ class BuilderRegistry(Registry):
if implemented: if implemented:
self._data[key] = (builder, loaded_packages) self._data[key] = (builder, loaded_packages)
return builder return builder
return wrap return wrap
def get(self, key: str) -> Callable[..., Any]: def get(self, key: str) -> Callable[..., Any]:

View file

@ -1,5 +1,4 @@
import asyncio import asyncio
from functools import wraps from functools import wraps
from ._granian import ASGIScope as Scope from ._granian import ASGIScope as Scope
@ -22,12 +21,7 @@ class LifespanProtocol:
async def handle(self): async def handle(self):
try: try:
await self.callable( await self.callable(
{ {'type': 'lifespan', 'asgi': {'version': '3.0', 'spec_version': '2.3'}}, self.receive, self.send
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.3"}
},
self.receive,
self.send
) )
except Exception: except Exception:
self.errored = True self.errored = True
@ -43,7 +37,7 @@ class LifespanProtocol:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
_handler_task = loop.create_task(self.handle()) _handler_task = loop.create_task(self.handle())
await self.event_queue.put({"type": "lifespan.startup"}) await self.event_queue.put({'type': 'lifespan.startup'})
await self.event_startup.wait() await self.event_startup.wait()
if self.failure_startup or (self.errored and not self.unsupported): if self.failure_startup or (self.errored and not self.unsupported):
@ -53,7 +47,7 @@ class LifespanProtocol:
if self.errored: if self.errored:
return return
await self.event_queue.put({"type": "lifespan.shutdown"}) await self.event_queue.put({'type': 'lifespan.shutdown'})
await self.event_shutdown.wait() await self.event_shutdown.wait()
if self.failure_shutdown or (self.errored and not self.unsupported): if self.failure_shutdown or (self.errored and not self.unsupported):
@ -89,14 +83,14 @@ class LifespanProtocol:
# self.logger.error(message["message"]) # self.logger.error(message["message"])
_event_handlers = { _event_handlers = {
"lifespan.startup.complete": _handle_startup_complete, 'lifespan.startup.complete': _handle_startup_complete,
"lifespan.startup.failed": _handle_startup_failed, 'lifespan.startup.failed': _handle_startup_failed,
"lifespan.shutdown.complete": _handle_shutdown_complete, 'lifespan.shutdown.complete': _handle_shutdown_complete,
"lifespan.shutdown.failed": _handle_shutdown_failed 'lifespan.shutdown.failed': _handle_shutdown_failed,
} }
async def send(self, message): async def send(self, message):
handler = self._event_handlers[message["type"]] handler = self._event_handlers[message['type']]
handler(self, message) handler(self, message)
@ -108,6 +102,7 @@ def _send_wrapper(proto):
@wraps(proto) @wraps(proto)
def send(data): def send(data):
return proto(_noop_coro, data) return proto(_noop_coro, data)
return send return send
@ -116,9 +111,6 @@ def _callback_wrapper(callback, scope_opts):
@wraps(callback) @wraps(callback)
def wrapper(scope: Scope, proto): def wrapper(scope: Scope, proto):
return callback( return callback(scope.as_dict(root_url_path), proto.receive, _send_wrapper(proto.send))
scope.as_dict(root_url_path),
proto.receive,
_send_wrapper(proto.send)
)
return wrapper return wrapper

View file

@ -1,107 +1,55 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import typer import typer
from .__version__ import __version__ from .__version__ import __version__
from .constants import Interfaces, HTTPModes, Loops, ThreadModes from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .log import LogLevels from .log import LogLevels
from .server import Granian from .server import Granian
cli = typer.Typer(name="granian", context_settings={"ignore_unknown_options": True}) cli = typer.Typer(name='granian', context_settings={'ignore_unknown_options': True})
def version_callback(value: bool): def version_callback(value: bool):
if value: if value:
typer.echo(f"{cli.info.name} {__version__}") typer.echo(f'{cli.info.name} {__version__}')
raise typer.Exit() raise typer.Exit()
@cli.command() @cli.command()
def main( def main(
app: str = typer.Argument(..., help="Application target to serve."), app: str = typer.Argument(..., help='Application target to serve.'),
host: str = typer.Option("127.0.0.1", help="Host address to bind to."), host: str = typer.Option('127.0.0.1', help='Host address to bind to.'),
port: int = typer.Option(8000, help="Port to bind to."), port: int = typer.Option(8000, help='Port to bind to.'),
interface: Interfaces = typer.Option( interface: Interfaces = typer.Option(Interfaces.RSGI.value, help='Application interface type.'),
Interfaces.RSGI.value, http: HTTPModes = typer.Option(HTTPModes.auto.value, help='HTTP version.'),
help="Application interface type." websockets: bool = typer.Option(True, '--ws/--no-ws', help='Enable websockets handling', show_default='enabled'),
), workers: int = typer.Option(1, min=1, help='Number of worker processes.'),
http: HTTPModes = typer.Option( threads: int = typer.Option(1, min=1, help='Number of threads.'),
HTTPModes.auto.value, threading_mode: ThreadModes = typer.Option(ThreadModes.workers.value, help='Threading mode to use.'),
help="HTTP version." loop: Loops = typer.Option(Loops.auto.value, help='Event loop implementation'),
), loop_opt: bool = typer.Option(False, '--opt/--no-opt', help='Enable loop optimizations', show_default='disabled'),
websockets: bool = typer.Option( backlog: int = typer.Option(1024, min=128, help='Maximum number of connections to hold in backlog.'),
True, log_level: LogLevels = typer.Option(LogLevels.info.value, help='Log level', case_sensitive=False),
"--ws/--no-ws",
help="Enable websockets handling",
show_default="enabled"
),
workers: int = typer.Option(1, min=1, help="Number of worker processes."),
threads: int = typer.Option(1, min=1, help="Number of threads."),
threading_mode: ThreadModes = typer.Option(
ThreadModes.workers.value,
help="Threading mode to use."
),
loop: Loops = typer.Option(Loops.auto.value, help="Event loop implementation"),
loop_opt: bool = typer.Option(
False,
"--opt/--no-opt",
help="Enable loop optimizations",
show_default="disabled"
),
backlog: int = typer.Option(
1024,
min=128,
help="Maximum number of connections to hold in backlog."
),
log_level: LogLevels = typer.Option(
LogLevels.info.value,
help="Log level",
case_sensitive=False
),
log_config: Optional[Path] = typer.Option( log_config: Optional[Path] = typer.Option(
None, None, help='Logging configuration file (json)', exists=True, file_okay=True, dir_okay=False, readable=True
help="Logging configuration file (json)",
exists=True,
file_okay=True,
dir_okay=False,
readable=True
), ),
ssl_keyfile: Optional[Path] = typer.Option( ssl_keyfile: Optional[Path] = typer.Option(
None, None, help='SSL key file', exists=True, file_okay=True, dir_okay=False, readable=True
help="SSL key file",
exists=True,
file_okay=True,
dir_okay=False,
readable=True
), ),
ssl_certificate: Optional[Path] = typer.Option( ssl_certificate: Optional[Path] = typer.Option(
None, None, help='SSL certificate file', exists=True, file_okay=True, dir_okay=False, readable=True
help="SSL certificate file",
exists=True,
file_okay=True,
dir_okay=False,
readable=True
),
url_path_prefix: Optional[str] = typer.Option(
None,
help="URL path prefix the app is mounted on"
), ),
url_path_prefix: Optional[str] = typer.Option(None, help='URL path prefix the app is mounted on'),
reload: bool = typer.Option( reload: bool = typer.Option(
False, False, '--reload/--no-reload', help="Enable auto reload on application's files changes"
"--reload/--no-reload",
help="Enable auto reload on application's files changes"
), ),
_: Optional[bool] = typer.Option( _: Optional[bool] = typer.Option(
None, None, '--version', callback=version_callback, is_eager=True, help='Shows the version and exit.'
"--version", ),
callback=version_callback,
is_eager=True,
help="Shows the version and exit."
)
): ):
log_dictconfig = None log_dictconfig = None
if log_config: if log_config:
@ -109,7 +57,7 @@ def main(
try: try:
log_dictconfig = json.loads(log_config_file.read()) log_dictconfig = json.loads(log_config_file.read())
except Exception: except Exception:
print("Unable to parse provided logging config.") print('Unable to parse provided logging config.')
raise typer.Exit(1) raise typer.Exit(1)
Granian( Granian(
@ -131,5 +79,5 @@ def main(
ssl_cert=ssl_certificate, ssl_cert=ssl_certificate,
ssl_key=ssl_keyfile, ssl_key=ssl_keyfile,
url_path_prefix=url_path_prefix, url_path_prefix=url_path_prefix,
reload=reload reload=reload,
).serve() ).serve()

View file

@ -2,23 +2,23 @@ from enum import Enum
class Interfaces(str, Enum): class Interfaces(str, Enum):
ASGI = "asgi" ASGI = 'asgi'
RSGI = "rsgi" RSGI = 'rsgi'
WSGI = "wsgi" WSGI = 'wsgi'
class HTTPModes(str, Enum): class HTTPModes(str, Enum):
auto = "auto" auto = 'auto'
http1 = "1" http1 = '1'
http2 = "2" http2 = '2'
class ThreadModes(str, Enum): class ThreadModes(str, Enum):
runtime = "runtime" runtime = 'runtime'
workers = "workers" workers = 'workers'
class Loops(str, Enum): class Loops(str, Enum):
auto = "auto" auto = 'auto'
asyncio = "asyncio" asyncio = 'asyncio'
uvloop = "uvloop" uvloop = 'uvloop'

View file

@ -1,18 +1,17 @@
import copy import copy
import logging import logging
import logging.config import logging.config
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
class LogLevels(str, Enum): class LogLevels(str, Enum):
critical = "critical" critical = 'critical'
error = "error" error = 'error'
warning = "warning" warning = 'warning'
warn = "warn" warn = 'warn'
info = "info" info = 'info'
debug = "debug" debug = 'debug'
log_levels_map = { log_levels_map = {
@ -21,27 +20,21 @@ log_levels_map = {
LogLevels.warning: logging.WARNING, LogLevels.warning: logging.WARNING,
LogLevels.warn: logging.WARN, LogLevels.warn: logging.WARN,
LogLevels.info: logging.INFO, LogLevels.info: logging.INFO,
LogLevels.debug: logging.DEBUG LogLevels.debug: logging.DEBUG,
} }
LOGGING_CONFIG = { LOGGING_CONFIG = {
"version": 1, 'version': 1,
"disable_existing_loggers": False, 'disable_existing_loggers': False,
"root": {"level": "INFO", "handlers": ["console"]}, 'root': {'level': 'INFO', 'handlers': ['console']},
"formatters": { 'formatters': {
"generic": { 'generic': {
"()": "logging.Formatter", '()': 'logging.Formatter',
"fmt": "[%(levelname)s] %(message)s", 'fmt': '[%(levelname)s] %(message)s',
"datefmt": "[%Y-%m-%d %H:%M:%S %z]" 'datefmt': '[%Y-%m-%d %H:%M:%S %z]',
} }
}, },
"handlers": { 'handlers': {'console': {'formatter': 'generic', 'class': 'logging.StreamHandler', 'stream': 'ext://sys.stdout'}},
"console": {
"formatter": "generic",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
}
}
} }
logger = logging.getLogger() logger = logging.getLogger()
@ -51,5 +44,5 @@ def configure_logging(level: LogLevels, config: Optional[Dict[str, Any]] = None)
log_config = copy.deepcopy(LOGGING_CONFIG) log_config = copy.deepcopy(LOGGING_CONFIG)
if config: if config:
log_config.update(config) log_config.update(config)
log_config["root"]["level"] = log_levels_map[level] log_config['root']['level'] = log_levels_map[level]
logging.config.dictConfig(log_config) logging.config.dictConfig(log_config)

View file

@ -2,7 +2,5 @@ import copyreg
from ._granian import ListenerHolder as SocketHolder from ._granian import ListenerHolder as SocketHolder
copyreg.pickle(
SocketHolder, copyreg.pickle(SocketHolder, lambda v: (SocketHolder, v.__getstate__()))
lambda v: (SocketHolder, v.__getstate__())
)

View file

@ -2,12 +2,12 @@ from enum import Enum
from typing import Union from typing import Union
from ._granian import ( from ._granian import (
RSGIHTTPProtocol as HTTPProtocol, RSGIHeaders as Headers, # noqa
RSGIWebsocketProtocol as WebsocketProtocol, RSGIHTTPProtocol as HTTPProtocol, # noqa
RSGIHeaders as Headers, RSGIProtocolClosed as ProtocolClosed, # noqa
RSGIScope as Scope, RSGIProtocolError as ProtocolError, # noqa
RSGIProtocolError as ProtocolError, RSGIScope as Scope, # noqa
RSGIProtocolClosed as ProtocolClosed RSGIWebsocketProtocol as WebsocketProtocol, # noqa
) )

View file

@ -5,7 +5,6 @@ import socket
import ssl import ssl
import sys import sys
import threading import threading
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
@ -16,11 +15,12 @@ from ._futures import future_watcher_wrapper
from ._granian import ASGIWorker, RSGIWorker, WSGIWorker from ._granian import ASGIWorker, RSGIWorker, WSGIWorker
from ._internal import load_target from ._internal import load_target
from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
from .constants import Interfaces, HTTPModes, Loops, ThreadModes from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .log import LogLevels, configure_logging, logger from .log import LogLevels, configure_logging, logger
from .net import SocketHolder from .net import SocketHolder
from .wsgi import _callback_wrapper as _wsgi_call_wrap from .wsgi import _callback_wrapper as _wsgi_call_wrap
multiprocessing.allow_connection_pickling() multiprocessing.allow_connection_pickling()
@ -30,7 +30,7 @@ class Granian:
def __init__( def __init__(
self, self,
target: str, target: str,
address: str = "127.0.0.1", address: str = '127.0.0.1',
port: int = 8000, port: int = 8000,
interface: Interfaces = Interfaces.RSGI, interface: Interfaces = Interfaces.RSGI,
workers: int = 1, workers: int = 1,
@ -48,7 +48,7 @@ class Granian:
ssl_cert: Optional[Path] = None, ssl_cert: Optional[Path] = None,
ssl_key: Optional[Path] = None, ssl_key: Optional[Path] = None,
url_path_prefix: Optional[str] = None, url_path_prefix: Optional[str] = None,
reload: bool = False reload: bool = False,
): ):
self.target = target self.target = target
self.bind_addr = address self.bind_addr = address
@ -75,11 +75,7 @@ class Granian:
self.procs: List[multiprocessing.Process] = [] self.procs: List[multiprocessing.Process] = []
self.exit_event = threading.Event() self.exit_event = threading.Event()
def build_ssl_context( def build_ssl_context(self, cert: Optional[Path], key: Optional[Path]):
self,
cert: Optional[Path],
key: Optional[Path]
):
if not (cert and key): if not (cert and key):
self.ssl_ctx = (False, None, None) self.ssl_ctx = (False, None, None)
return return
@ -108,7 +104,7 @@ class Granian:
log_level, log_level,
log_config, log_config,
ssl_ctx, ssl_ctx,
scope_opts scope_opts,
): ):
from granian._loops import loops, set_loop_signals from granian._loops import loops, set_loop_signals
@ -129,29 +125,12 @@ class Granian:
wcallback = future_watcher_wrapper(wcallback) wcallback = future_watcher_wrapper(wcallback)
worker = ASGIWorker( worker = ASGIWorker(
worker_id, worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx
sfd,
threads,
pthreads,
http_mode,
http1_buffer_size,
websockets,
loop_opt,
*ssl_ctx
)
serve = getattr(worker, {
ThreadModes.runtime: "serve_rth",
ThreadModes.workers: "serve_wth"
}[threading_mode])
serve(
wcallback,
loop,
contextvars.copy_context(),
shutdown_event.wait()
) )
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
serve(wcallback, loop, contextvars.copy_context(), shutdown_event.wait())
loop.run_until_complete(lifespan_handler.shutdown()) loop.run_until_complete(lifespan_handler.shutdown())
@staticmethod @staticmethod
def _spawn_rsgi_worker( def _spawn_rsgi_worker(
worker_id, worker_id,
@ -168,7 +147,7 @@ class Granian:
log_level, log_level,
log_config, log_config,
ssl_ctx, ssl_ctx,
scope_opts scope_opts,
): ):
from granian._loops import loops, set_loop_signals from granian._loops import loops, set_loop_signals
@ -176,38 +155,23 @@ class Granian:
loop = loops.get(loop_impl) loop = loops.get(loop_impl)
sfd = socket.fileno() sfd = socket.fileno()
target = callback_loader() target = callback_loader()
callback = ( callback = getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else target
getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else
target
)
callback_init = ( callback_init = (
getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else lambda *args, **kwargs: None
lambda *args, **kwargs: None
) )
shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT]) shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT])
callback_init(loop) callback_init(loop)
worker = RSGIWorker( worker = RSGIWorker(
worker_id, worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx
sfd,
threads,
pthreads,
http_mode,
http1_buffer_size,
websockets,
loop_opt,
*ssl_ctx
) )
serve = getattr(worker, { serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
ThreadModes.runtime: "serve_rth",
ThreadModes.workers: "serve_wth"
}[threading_mode])
serve( serve(
future_watcher_wrapper(callback) if not loop_opt else callback, future_watcher_wrapper(callback) if not loop_opt else callback,
loop, loop,
contextvars.copy_context(), contextvars.copy_context(),
shutdown_event.wait() shutdown_event.wait(),
) )
@staticmethod @staticmethod
@ -226,7 +190,7 @@ class Granian:
log_level, log_level,
log_config, log_config,
ssl_ctx, ssl_ctx,
scope_opts scope_opts,
): ):
from granian._loops import loops, set_loop_signals from granian._loops import loops, set_loop_signals
@ -237,46 +201,20 @@ class Granian:
shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT]) shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT])
worker = WSGIWorker( worker = WSGIWorker(worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, *ssl_ctx)
worker_id, serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
sfd, serve(_wsgi_call_wrap(callback, scope_opts), loop, contextvars.copy_context(), shutdown_event.wait())
threads,
pthreads,
http_mode,
http1_buffer_size,
*ssl_ctx
)
serve = getattr(worker, {
ThreadModes.runtime: "serve_rth",
ThreadModes.workers: "serve_wth"
}[threading_mode])
serve(
_wsgi_call_wrap(callback, scope_opts),
loop,
contextvars.copy_context(),
shutdown_event.wait()
)
def _init_shared_socket(self): def _init_shared_socket(self):
self._shd = SocketHolder.from_address( self._shd = SocketHolder.from_address(self.bind_addr, self.bind_port, self.backlog)
self.bind_addr,
self.bind_port,
self.backlog
)
self._sfd = self._shd.get_fd() self._sfd = self._shd.get_fd()
def signal_handler(self, *args, **kwargs): def signal_handler(self, *args, **kwargs):
self.exit_event.set() self.exit_event.set()
def _spawn_proc( def _spawn_proc(self, id, target, callback_loader, socket_loader) -> multiprocessing.Process:
self,
id,
target,
callback_loader,
socket_loader
) -> multiprocessing.Process:
return multiprocessing.get_context().Process( return multiprocessing.get_context().Process(
name="granian-worker", name='granian-worker',
target=target, target=target,
args=( args=(
id, id,
@ -293,10 +231,8 @@ class Granian:
self.log_level, self.log_level,
self.log_config, self.log_config,
self.ssl_ctx, self.ssl_ctx,
{ {'url_path_prefix': self.url_path_prefix},
"url_path_prefix": self.url_path_prefix ),
}
)
) )
def _spawn_workers(self, sock, spawn_target, target_loader): def _spawn_workers(self, sock, spawn_target, target_loader):
@ -305,14 +241,11 @@ class Granian:
for idx in range(self.workers): for idx in range(self.workers):
proc = self._spawn_proc( proc = self._spawn_proc(
id=idx + 1, id=idx + 1, target=spawn_target, callback_loader=target_loader, socket_loader=socket_loader
target=spawn_target,
callback_loader=target_loader,
socket_loader=socket_loader
) )
proc.start() proc.start()
self.procs.append(proc) self.procs.append(proc)
logger.info(f"Spawning worker-{idx + 1} with pid: {proc.pid}") logger.info(f'Spawning worker-{idx + 1} with pid: {proc.pid}')
def _stop_workers(self): def _stop_workers(self):
for proc in self.procs: for proc in self.procs:
@ -321,7 +254,7 @@ class Granian:
proc.join() proc.join()
def startup(self, spawn_target, target_loader): def startup(self, spawn_target, target_loader):
logger.info("Starting granian") logger.info('Starting granian')
for sig in self.SIGNALS: for sig in self.SIGNALS:
signal.signal(sig, self.signal_handler) signal.signal(sig, self.signal_handler)
@ -329,13 +262,13 @@ class Granian:
self._init_shared_socket() self._init_shared_socket()
sock = socket.socket(fileno=self._sfd) sock = socket.socket(fileno=self._sfd)
sock.set_inheritable(True) sock.set_inheritable(True)
logger.info(f"Listening at: {self.bind_addr}:{self.bind_port}") logger.info(f'Listening at: {self.bind_addr}:{self.bind_port}')
self._spawn_workers(sock, spawn_target, target_loader) self._spawn_workers(sock, spawn_target, target_loader)
return sock return sock
def shutdown(self): def shutdown(self):
logger.info("Shutting down granian") logger.info('Shutting down granian')
self._stop_workers() self._stop_workers()
def _serve(self, spawn_target, target_loader): def _serve(self, spawn_target, target_loader):
@ -360,12 +293,12 @@ class Granian:
self, self,
spawn_target: Optional[Callable[..., None]] = None, spawn_target: Optional[Callable[..., None]] = None,
target_loader: Optional[Callable[..., Callable[..., Any]]] = None, target_loader: Optional[Callable[..., Callable[..., Any]]] = None,
wrap_loader: bool = True wrap_loader: bool = True,
): ):
default_spawners = { default_spawners = {
Interfaces.ASGI: self._spawn_asgi_worker, Interfaces.ASGI: self._spawn_asgi_worker,
Interfaces.RSGI: self._spawn_rsgi_worker, Interfaces.RSGI: self._spawn_rsgi_worker,
Interfaces.WSGI: self._spawn_wsgi_worker Interfaces.WSGI: self._spawn_wsgi_worker,
} }
if target_loader: if target_loader:
if wrap_loader: if wrap_loader:
@ -383,8 +316,5 @@ class Granian:
"Number of workers will now fallback to 1." "Number of workers will now fallback to 1."
) )
serve_method = ( serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
self._serve_with_reloader if self.reload_on_changes else
self._serve
)
serve_method(spawn_target, target_loader) serve_method(spawn_target, target_loader)

View file

@ -1,6 +1,5 @@
import os import os
import sys import sys
from functools import wraps from functools import wraps
from typing import Any, List, Tuple from typing import Any, List, Tuple
@ -14,30 +13,27 @@ class Response:
self.status = 200 self.status = 200
self.headers = [] self.headers = []
def __call__( def __call__(self, status: str, headers: List[Tuple[str, str]], exc_info: Any = None):
self,
status: str,
headers: List[Tuple[str, str]],
exc_info: Any = None
):
self.status = int(status.split(' ', 1)[0]) self.status = int(status.split(' ', 1)[0])
self.headers = headers self.headers = headers
def _callback_wrapper(callback, scope_opts): def _callback_wrapper(callback, scope_opts):
basic_env = dict(os.environ) basic_env = dict(os.environ)
basic_env.update({ basic_env.update(
'GATEWAY_INTERFACE': 'CGI/1.1', {
'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '', 'GATEWAY_INTERFACE': 'CGI/1.1',
'SERVER_SOFTWARE': 'Granian', 'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '',
'wsgi.errors': sys.stderr, 'SERVER_SOFTWARE': 'Granian',
'wsgi.input_terminated': True, 'wsgi.errors': sys.stderr,
'wsgi.input': None, 'wsgi.input_terminated': True,
'wsgi.multiprocess': False, 'wsgi.input': None,
'wsgi.multithread': False, 'wsgi.multiprocess': False,
'wsgi.run_once': False, 'wsgi.multithread': False,
'wsgi.version': (1, 0) 'wsgi.run_once': False,
}) 'wsgi.version': (1, 0),
}
)
@wraps(callback) @wraps(callback)
def wrapper(scope: Scope) -> Tuple[int, List[Tuple[str, str]], bytes]: def wrapper(scope: Scope) -> Tuple[int, List[Tuple[str, str]], bytes]:
@ -46,7 +42,7 @@ def _callback_wrapper(callback, scope_opts):
if isinstance(rv, list): if isinstance(rv, list):
resp_type = 0 resp_type = 0
rv = b"".join(rv) rv = b''.join(rv)
else: else:
resp_type = 1 resp_type = 1
rv = iter(rv) rv = iter(rv)

View file

@ -1,65 +1,111 @@
[project] [project]
name = "granian" name = 'granian'
authors = [ authors = [
{name = "Giovanni Barillari", email = "g@baro.dev"} {name = 'Giovanni Barillari', email = 'g@baro.dev'}
] ]
classifiers = [ classifiers = [
"Development Status :: 5 - Production/Stable", 'Development Status :: 5 - Production/Stable',
"Intended Audience :: Developers", 'Intended Audience :: Developers',
"License :: OSI Approved :: BSD License", 'License :: OSI Approved :: BSD License',
"Operating System :: MacOS", 'Operating System :: MacOS',
"Operating System :: Microsoft :: Windows", 'Operating System :: Microsoft :: Windows',
"Operating System :: POSIX :: Linux", 'Operating System :: POSIX :: Linux',
"Programming Language :: Python :: 3", 'Programming Language :: Python :: 3',
"Programming Language :: Python :: 3.8", 'Programming Language :: Python :: 3.8',
"Programming Language :: Python :: 3.9", 'Programming Language :: Python :: 3.9',
"Programming Language :: Python :: 3.10", 'Programming Language :: Python :: 3.10',
"Programming Language :: Python :: 3.11", 'Programming Language :: Python :: 3.11',
"Programming Language :: Python :: Implementation :: CPython", 'Programming Language :: Python :: Implementation :: CPython',
"Programming Language :: Python :: Implementation :: PyPy", 'Programming Language :: Python :: Implementation :: PyPy',
"Programming Language :: Python", 'Programming Language :: Python',
"Programming Language :: Rust", 'Programming Language :: Rust',
"Topic :: Internet :: WWW/HTTP" 'Topic :: Internet :: WWW/HTTP'
] ]
dynamic = [ dynamic = [
"description", 'description',
"keywords", 'keywords',
"license", 'license',
"readme", 'readme',
"version" 'version'
] ]
requires-python = ">=3.8" requires-python = '>=3.8'
dependencies = [ dependencies = [
"watchfiles~=0.18", 'watchfiles~=0.18',
"typer~=0.4", 'typer~=0.4',
"uvloop~=0.17.0; sys_platform != 'win32' and platform_python_implementation == 'CPython'" 'uvloop~=0.17.0; sys_platform != "win32" and platform_python_implementation == "CPython"'
] ]
[project.optional-dependencies] [project.optional-dependencies]
lint = [
'black~=23.7.0',
'ruff~=0.0.287'
]
test = [ test = [
"httpx~=0.23.0", 'httpx~=0.23.0',
"pytest~=7.1.2", 'pytest~=7.1.2',
"pytest-asyncio~=0.18.3", 'pytest-asyncio~=0.18.3',
"websockets~=10.3" 'websockets~=10.3'
] ]
[project.urls] [project.urls]
Homepage = "https://github.com/emmett-framework/granian" Homepage = 'https://github.com/emmett-framework/granian'
Funding = "https://github.com/sponsors/gi0baro" Funding = 'https://github.com/sponsors/gi0baro'
Source = "https://github.com/emmett-framework/granian" Source = 'https://github.com/emmett-framework/granian'
[project.scripts] [project.scripts]
granian = "granian:cli.cli" granian = 'granian:cli.cli'
[build-system] [build-system]
requires = ["maturin>=1.1.0,<1.3.0"] requires = ['maturin>=1.1.0,<1.3.0']
build-backend = "maturin" build-backend = 'maturin'
[tool.maturin] [tool.maturin]
module-name = "granian._granian" module-name = 'granian._granian'
bindings = "pyo3" bindings = 'pyo3'
[tool.ruff]
line-length = 120
extend-select = [
# E and F are enabled by default
'B', # flake8-bugbear
'C4', # flake8-comprehensions
'C90', # mccabe
'I', # isort
'N', # pep8-naming
'Q', # flake8-quotes
'RUF100', # ruff (unused noqa)
'S', # flake8-bandit
'W' # pycodestyle
]
extend-ignore = [
'B008', # function calls in args defaults are fine
'B009', # getattr with constants is fine
'B034', # re.split won't confuse us
'B904', # rising without from is fine
'E501', # leave line length to black
'N818', # leave to us exceptions naming
'S101' # assert is fine
]
flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' }
mccabe = { max-complexity = 13 }
[tool.ruff.isort]
combine-as-imports = true
lines-after-imports = 2
known-first-party = ['granian', 'tests']
[tool.ruff.per-file-ignores]
'granian/_granian.pyi' = ['I001']
'tests/**' = ['B018', 'S110', 'S501']
[tool.black]
color = true
line-length = 120
target-version = ['py38', 'py39', 'py310', 'py311']
skip-string-normalization = true # leave this to ruff
skip-magic-trailing-comma = true # leave this to ruff
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = "auto" asyncio_mode = 'auto'

View file

@ -3,45 +3,33 @@ use pyo3::prelude::*;
use pyo3_asyncio::TaskLocals; use pyo3_asyncio::TaskLocals;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use crate::{
callbacks::{
CallbackWrapper,
callback_impl_run,
callback_impl_run_pytask,
callback_impl_loop_run,
callback_impl_loop_pytask,
callback_impl_loop_step,
callback_impl_loop_wake,
callback_impl_loop_err
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData}
};
use super::{ use super::{
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol}, io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol},
types::ASGIScope as Scope types::ASGIScope as Scope,
};
use crate::{
callbacks::{
callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step,
callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper,
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData},
}; };
#[pyclass] #[pyclass]
pub(crate) struct CallbackRunnerHTTP { pub(crate) struct CallbackRunnerHTTP {
proto: Py<HTTPProtocol>, proto: Py<HTTPProtocol>,
context: TaskLocals, context: TaskLocals,
cb: PyObject cb: PyObject,
} }
impl CallbackRunnerHTTP { impl CallbackRunnerHTTP {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
let pyproto = Py::new(py, proto).unwrap(); let pyproto = Py::new(py, proto).unwrap();
Self { Self {
proto: pyproto.clone(), proto: pyproto.clone(),
context: cb.context, context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap() cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
} }
} }
@ -64,14 +52,14 @@ macro_rules! callback_impl_done_http {
let _ = tx.send(res); let _ = tx.send(res);
} }
} }
} };
} }
macro_rules! callback_impl_done_err { macro_rules! callback_impl_done_err {
($self:expr, $py:expr) => { ($self:expr, $py:expr) => {
log::warn!("Application callable raised an exception"); log::warn!("Application callable raised an exception");
$self.done($py) $self.done($py)
} };
} }
#[pyclass] #[pyclass]
@ -79,22 +67,17 @@ pub(crate) struct CallbackTaskHTTP {
proto: Py<HTTPProtocol>, proto: Py<HTTPProtocol>,
context: TaskLocals, context: TaskLocals,
pycontext: PyObject, pycontext: PyObject,
cb: PyObject cb: PyObject,
} }
impl CallbackTaskHTTP { impl CallbackTaskHTTP {
pub fn new( pub fn new(py: Python, cb: PyObject, proto: Py<HTTPProtocol>, context: TaskLocals) -> PyResult<Self> {
py: Python,
cb: PyObject,
proto: Py<HTTPProtocol>,
context: TaskLocals
) -> PyResult<Self> {
let pyctx = context.context(py); let pyctx = context.context(py);
Ok(Self { Ok(Self {
proto, proto,
context, context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb cb,
}) })
} }
@ -128,21 +111,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
context: TaskLocals, context: TaskLocals,
cb: PyObject, cb: PyObject,
#[pyo3(get)] #[pyo3(get)]
scope: PyObject scope: PyObject,
} }
impl CallbackWrappedRunnerHTTP { impl CallbackWrappedRunnerHTTP {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
Self { Self {
proto: Py::new(py, proto).unwrap(), proto: Py::new(py, proto).unwrap(),
context: cb.context, context: cb.context,
cb: cb.callback, cb: cb.callback,
scope: scope.into_py(py) scope: scope.into_py(py),
} }
} }
@ -168,21 +146,16 @@ impl CallbackWrappedRunnerHTTP {
pub(crate) struct CallbackRunnerWebsocket { pub(crate) struct CallbackRunnerWebsocket {
proto: Py<WebsocketProtocol>, proto: Py<WebsocketProtocol>,
context: TaskLocals, context: TaskLocals,
cb: PyObject cb: PyObject,
} }
impl CallbackRunnerWebsocket { impl CallbackRunnerWebsocket {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
let pyproto = Py::new(py, proto).unwrap(); let pyproto = Py::new(py, proto).unwrap();
Self { Self {
proto: pyproto.clone(), proto: pyproto.clone(),
context: cb.context, context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap() cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
} }
} }
@ -203,7 +176,7 @@ macro_rules! callback_impl_done_ws {
let _ = tx.send(res); let _ = tx.send(res);
} }
} }
} };
} }
#[pyclass] #[pyclass]
@ -211,22 +184,17 @@ pub(crate) struct CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>, proto: Py<WebsocketProtocol>,
context: TaskLocals, context: TaskLocals,
pycontext: PyObject, pycontext: PyObject,
cb: PyObject cb: PyObject,
} }
impl CallbackTaskWebsocket { impl CallbackTaskWebsocket {
pub fn new( pub fn new(py: Python, cb: PyObject, proto: Py<WebsocketProtocol>, context: TaskLocals) -> PyResult<Self> {
py: Python,
cb: PyObject,
proto: Py<WebsocketProtocol>,
context: TaskLocals
) -> PyResult<Self> {
let pyctx = context.context(py); let pyctx = context.context(py);
Ok(Self { Ok(Self {
proto, proto,
context, context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb cb,
}) })
} }
@ -260,21 +228,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
context: TaskLocals, context: TaskLocals,
cb: PyObject, cb: PyObject,
#[pyo3(get)] #[pyo3(get)]
scope: PyObject scope: PyObject,
} }
impl CallbackWrappedRunnerWebsocket { impl CallbackWrappedRunnerWebsocket {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
Self { Self {
proto: Py::new(py, proto).unwrap(), proto: Py::new(py, proto).unwrap(),
context: cb.context, context: cb.context,
cb: cb.callback, cb: cb.callback,
scope: scope.into_py(py) scope: scope.into_py(py),
} }
} }
@ -325,7 +288,7 @@ macro_rules! call_impl_rtb_http {
cb: CallbackWrapper, cb: CallbackWrapper,
rt: RuntimeRef, rt: RuntimeRef,
req: Request<Body>, req: Request<Body>,
scope: Scope scope: Scope,
) -> oneshot::Receiver<Response<Body>> { ) -> oneshot::Receiver<Response<Body>> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, req, tx); let protocol = HTTPProtocol::new(rt, req, tx);
@ -345,7 +308,7 @@ macro_rules! call_impl_rtt_http {
cb: CallbackWrapper, cb: CallbackWrapper,
rt: RuntimeRef, rt: RuntimeRef,
req: Request<Body>, req: Request<Body>,
scope: Scope scope: Scope,
) -> oneshot::Receiver<Response<Body>> { ) -> oneshot::Receiver<Response<Body>> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, req, tx); let protocol = HTTPProtocol::new(rt, req, tx);
@ -368,7 +331,7 @@ macro_rules! call_impl_rtb_ws {
rt: RuntimeRef, rt: RuntimeRef,
ws: HyperWebsocket, ws: HyperWebsocket,
upgrade: UpgradeData, upgrade: UpgradeData,
scope: Scope scope: Scope,
) -> oneshot::Receiver<bool> { ) -> oneshot::Receiver<bool> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
@ -389,7 +352,7 @@ macro_rules! call_impl_rtt_ws {
rt: RuntimeRef, rt: RuntimeRef,
ws: HyperWebsocket, ws: HyperWebsocket,
upgrade: UpgradeData, upgrade: UpgradeData,
scope: Scope scope: Scope,
) -> oneshot::Receiver<bool> { ) -> oneshot::Receiver<bool> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);

View file

@ -88,5 +88,5 @@ macro_rules! error_message {
} }
pub(crate) use error_flow; pub(crate) use error_flow;
pub(crate) use error_transport;
pub(crate) use error_message; pub(crate) use error_message;
pub(crate) use error_transport;

View file

@ -1,34 +1,22 @@
use hyper::{ use hyper::{
Body, header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode,
Request,
Response,
StatusCode,
header::SERVER as HK_SERVER,
http::response::Builder as ResponseBuilder
}; };
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::{
callbacks::CallbackWrapper,
http::{HV_SERVER, response_500},
runtime::RuntimeRef,
ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade}
};
use super::{ use super::{
callbacks::{ callbacks::{
call_rtb_http, call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws,
call_rtb_http_pyw, call_rtt_ws_pyw,
call_rtb_ws,
call_rtb_ws_pyw,
call_rtt_http,
call_rtt_http_pyw,
call_rtt_ws,
call_rtt_ws_pyw
}, },
types::ASGIScope as Scope types::ASGIScope as Scope,
};
use crate::{
callbacks::CallbackWrapper,
http::{response_500, HV_SERVER},
runtime::RuntimeRef,
ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData},
}; };
macro_rules! default_scope { macro_rules! default_scope {
($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { ($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => {
@ -39,7 +27,7 @@ macro_rules! default_scope {
$req.method().as_ref(), $req.method().as_ref(),
$server_addr, $server_addr,
$client_addr, $client_addr,
$req.headers() $req.headers(),
) )
}; };
} }
@ -53,7 +41,7 @@ macro_rules! handle_http_response {
response_500() response_500()
} }
} }
} };
} }
macro_rules! handle_request { macro_rules! handle_request {
@ -64,7 +52,7 @@ macro_rules! handle_request {
server_addr: SocketAddr, server_addr: SocketAddr,
client_addr: SocketAddr, client_addr: SocketAddr,
req: Request<Body>, req: Request<Body>,
scheme: &str scheme: &str,
) -> Response<Body> { ) -> Response<Body> {
let scope = default_scope!(server_addr, client_addr, &req, scheme); let scope = default_scope!(server_addr, client_addr, &req, scheme);
handle_http_response!($handler, rt, callback, req, scope) handle_http_response!($handler, rt, callback, req, scope)
@ -80,7 +68,7 @@ macro_rules! handle_request_with_ws {
server_addr: SocketAddr, server_addr: SocketAddr,
client_addr: SocketAddr, client_addr: SocketAddr,
req: Request<Body>, req: Request<Body>,
scheme: &str scheme: &str,
) -> Response<Body> { ) -> Response<Body> {
let mut scope = default_scope!(server_addr, client_addr, &req, scheme); let mut scope = default_scope!(server_addr, client_addr, &req, scheme);
@ -95,24 +83,20 @@ macro_rules! handle_request_with_ws {
rt.inner.spawn(async move { rt.inner.spawn(async move {
let tx_ref = restx.clone(); let tx_ref = restx.clone();
match $handler_ws( match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await {
callback,
rth,
ws,
UpgradeData::new(res, restx),
scope
).await {
Ok(consumed) => { Ok(consumed) => {
if !consumed { if !consumed {
let _ = tx_ref.send( let _ = tx_ref
ResponseBuilder::new() .send(
.status(StatusCode::FORBIDDEN) ResponseBuilder::new()
.header(HK_SERVER, HV_SERVER) .status(StatusCode::FORBIDDEN)
.body(Body::from("")) .header(HK_SERVER, HV_SERVER)
.unwrap() .body(Body::from(""))
).await; .unwrap(),
)
.await;
}; };
}, }
_ => { _ => {
log::error!("ASGI protocol failure"); log::error!("ASGI protocol failure");
let _ = tx_ref.send(response_500()).await; let _ = tx_ref.send(response_500()).await;
@ -124,10 +108,10 @@ macro_rules! handle_request_with_ws {
Some(res) => { Some(res) => {
resrx.close(); resrx.close();
res res
}, }
_ => response_500() _ => response_500(),
} }
}, }
Err(err) => { Err(err) => {
return ResponseBuilder::new() return ResponseBuilder::new()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)

View file

@ -1,33 +1,34 @@
use bytes::Bytes; use bytes::Bytes;
use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}}; use futures::{
sink::SinkExt,
stream::{SplitSink, SplitStream, StreamExt},
};
use hyper::{ use hyper::{
Request,
Response,
body::{Body, HttpBody, Sender as BodySender}, body::{Body, HttpBody, Sender as BodySender},
header::{HeaderName, HeaderValue, HeaderMap, SERVER as HK_SERVER} header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
Request, Response,
}; };
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict}; use pyo3::types::{PyBytes, PyDict};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{oneshot, Mutex};
use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::WebSocketStream;
use tokio::sync::{Mutex, oneshot};
use tungstenite::Message; use tungstenite::Message;
use super::{
errors::{error_flow, error_message, error_transport, UnsupportedASGIMessage},
types::ASGIMessageType,
};
use crate::{ use crate::{
http::HV_SERVER, http::HV_SERVER,
runtime::{RuntimeRef, future_into_py_iter, future_into_py_futlike}, runtime::{future_into_py_futlike, future_into_py_iter, RuntimeRef},
ws::{HyperWebsocket, UpgradeData} ws::{HyperWebsocket, UpgradeData},
}; };
use super::{
errors::{UnsupportedASGIMessage, error_flow, error_transport, error_message},
types::ASGIMessageType
};
const EMPTY_BYTES: Vec<u8> = Vec::new(); const EMPTY_BYTES: Vec<u8> = Vec::new();
const EMPTY_STRING: String = String::new(); const EMPTY_STRING: String = String::new();
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct ASGIHTTPProtocol { pub(crate) struct ASGIHTTPProtocol {
rt: RuntimeRef, rt: RuntimeRef,
tx: Option<oneshot::Sender<Response<Body>>>, tx: Option<oneshot::Sender<Response<Body>>>,
@ -36,15 +37,11 @@ pub(crate) struct ASGIHTTPProtocol {
response_chunked: bool, response_chunked: bool,
response_status: Option<i16>, response_status: Option<i16>,
response_headers: Option<HeaderMap>, response_headers: Option<HeaderMap>,
body_tx: Option<Arc<Mutex<BodySender>>> body_tx: Option<Arc<Mutex<BodySender>>>,
} }
impl ASGIHTTPProtocol { impl ASGIHTTPProtocol {
pub fn new( pub fn new(rt: RuntimeRef, request: Request<Body>, tx: oneshot::Sender<Response<Body>>) -> Self {
rt: RuntimeRef,
request: Request<Body>,
tx: oneshot::Sender<Response<Body>>
) -> Self {
Self { Self {
rt, rt,
tx: Some(tx), tx: Some(tx),
@ -53,7 +50,7 @@ impl ASGIHTTPProtocol {
response_chunked: false, response_chunked: false,
response_status: None, response_status: None,
response_headers: None, response_headers: None,
body_tx: None body_tx: None,
} }
} }
@ -71,7 +68,7 @@ impl ASGIHTTPProtocol {
fn send_body<'p>(&self, py: Python<'p>, tx: Arc<Mutex<BodySender>>, body: Vec<u8>) -> PyResult<&'p PyAny> { fn send_body<'p>(&self, py: Python<'p>, tx: Arc<Mutex<BodySender>>, body: Vec<u8>) -> PyResult<&'p PyAny> {
future_into_py_futlike(self.rt.clone(), py, async move { future_into_py_futlike(self.rt.clone(), py, async move {
let mut tx = tx.lock().await; let mut tx = tx.lock().await;
match (&mut *tx).send_data(body.into()).await { match (*tx).send_data(body.into()).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => {
log::warn!("ASGI transport tx error: {:?}", err); log::warn!("ASGI transport tx error: {:?}", err);
@ -94,18 +91,18 @@ impl ASGIHTTPProtocol {
let mut bodym = body_ref.lock().await; let mut bodym = body_ref.lock().await;
let body = &mut *bodym; let body = &mut *bodym;
let mut more_body = false; let mut more_body = false;
let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| { let chunk = body.data().await.map_or_else(Bytes::new, |buf| {
buf.map_or_else(|_| Bytes::new(), |buf| { buf.map_or_else(
more_body = !body.is_end_stream(); |_| Bytes::new(),
buf |buf| {
}) more_body = !body.is_end_stream();
buf
},
)
}); });
Python::with_gil(|py| { Python::with_gil(|py| {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item( dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?;
pyo3::intern!(py, "type"),
pyo3::intern!(py, "http.request")
)?;
dict.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &chunk[..]))?; dict.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &chunk[..]))?;
dict.set_item(pyo3::intern!(py, "more_body"), more_body)?; dict.set_item(pyo3::intern!(py, "more_body"), more_body)?;
Ok(dict.to_object(py)) Ok(dict.to_object(py))
@ -115,16 +112,14 @@ impl ASGIHTTPProtocol {
fn send<'p>(&mut self, py: Python<'p>, asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> { fn send<'p>(&mut self, py: Python<'p>, asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> {
match adapt_message_type(data) { match adapt_message_type(data) {
Ok(ASGIMessageType::HTTPStart) => { Ok(ASGIMessageType::HTTPStart) => match self.response_started {
match self.response_started { false => {
false => { self.response_status = Some(adapt_status_code(data)?);
self.response_status = Some(adapt_status_code(data)?); self.response_headers = Some(adapt_headers(data));
self.response_headers = Some(adapt_headers(data)); self.response_started = true;
self.response_started = true; asyncw.call0()
asyncw.call0()
},
true => error_flow!()
} }
true => error_flow!(),
}, },
Ok(ASGIMessageType::HTTPBody) => { Ok(ASGIMessageType::HTTPBody) => {
let (body, more) = adapt_body(data); let (body, more) = adapt_body(data);
@ -133,7 +128,7 @@ impl ASGIHTTPProtocol {
let headers = self.response_headers.take().unwrap(); let headers = self.response_headers.take().unwrap();
self.send_response(self.response_status.unwrap(), headers, body.into()); self.send_response(self.response_status.unwrap(), headers, body.into());
asyncw.call0() asyncw.call0()
}, }
(true, true, false) => { (true, true, false) => {
self.response_chunked = true; self.response_chunked = true;
let headers = self.response_headers.take().unwrap(); let headers = self.response_headers.take().unwrap();
@ -142,37 +137,31 @@ impl ASGIHTTPProtocol {
self.body_tx = Some(tx.clone()); self.body_tx = Some(tx.clone());
self.send_response(self.response_status.unwrap(), headers, body_stream); self.send_response(self.response_status.unwrap(), headers, body_stream);
self.send_body(py, tx, body) self.send_body(py, tx, body)
}, }
(true, true, true) => { (true, true, true) => match self.body_tx.as_mut() {
match self.body_tx.as_mut() { Some(tx) => {
Some(tx) => { let tx = tx.clone();
let tx = tx.clone(); self.send_body(py, tx, body)
self.send_body(py, tx, body)
},
_ => error_flow!()
} }
_ => error_flow!(),
}, },
(true, false, true) => { (true, false, true) => match self.body_tx.take() {
match self.body_tx.take() { Some(tx) => match body.is_empty() {
Some(tx) => { false => self.send_body(py, tx, body),
match body.is_empty() { true => asyncw.call0(),
false => self.send_body(py, tx, body), },
true => asyncw.call0() _ => error_flow!(),
}
},
_ => error_flow!()
}
}, },
_ => error_flow!() _ => error_flow!(),
} }
}, }
Err(err) => Err(err.into()), Err(err) => Err(err.into()),
_ => error_message!() _ => error_message!(),
} }
} }
} }
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct ASGIWebsocketProtocol { pub(crate) struct ASGIWebsocketProtocol {
rt: RuntimeRef, rt: RuntimeRef,
tx: Option<oneshot::Sender<bool>>, tx: Option<oneshot::Sender<bool>>,
@ -181,16 +170,11 @@ pub(crate) struct ASGIWebsocketProtocol {
ws_tx: Arc<Mutex<Option<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>>, ws_tx: Arc<Mutex<Option<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>>,
ws_rx: Arc<Mutex<Option<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>>, ws_rx: Arc<Mutex<Option<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>>,
accepted: Arc<Mutex<bool>>, accepted: Arc<Mutex<bool>>,
closed: bool closed: bool,
} }
impl ASGIWebsocketProtocol { impl ASGIWebsocketProtocol {
pub fn new( pub fn new(rt: RuntimeRef, tx: oneshot::Sender<bool>, websocket: HyperWebsocket, upgrade: UpgradeData) -> Self {
rt: RuntimeRef,
tx: oneshot::Sender<bool>,
websocket: HyperWebsocket,
upgrade: UpgradeData
) -> Self {
Self { Self {
rt, rt,
tx: Some(tx), tx: Some(tx),
@ -199,7 +183,7 @@ impl ASGIWebsocketProtocol {
ws_tx: Arc::new(Mutex::new(None)), ws_tx: Arc::new(Mutex::new(None)),
ws_rx: Arc::new(Mutex::new(None)), ws_rx: Arc::new(Mutex::new(None)),
accepted: Arc::new(Mutex::new(false)), accepted: Arc::new(Mutex::new(false)),
closed: false closed: false,
} }
} }
@ -211,7 +195,7 @@ impl ASGIWebsocketProtocol {
let tx = self.ws_tx.clone(); let tx = self.ws_tx.clone();
let rx = self.ws_rx.clone(); let rx = self.ws_rx.clone();
future_into_py_iter(self.rt.clone(), py, async move { future_into_py_iter(self.rt.clone(), py, async move {
if let Ok(_) = upgrade.send().await { if (upgrade.send().await).is_ok() {
if let Ok(stream) = websocket.await { if let Ok(stream) = websocket.await {
let mut wtx = tx.lock().await; let mut wtx = tx.lock().await;
let mut wrx = rx.lock().await; let mut wrx = rx.lock().await;
@ -220,7 +204,7 @@ impl ASGIWebsocketProtocol {
*wtx = Some(tx); *wtx = Some(tx);
*wrx = Some(rx); *wrx = Some(rx);
*accepted = true; *accepted = true;
return Ok(()) return Ok(());
} }
} }
error_flow!() error_flow!()
@ -228,18 +212,14 @@ impl ASGIWebsocketProtocol {
} }
#[inline(always)] #[inline(always)]
fn send_message<'p>( fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
&self,
py: Python<'p>,
data: &'p PyDict
) -> PyResult<&'p PyAny> {
let transport = self.ws_tx.clone(); let transport = self.ws_tx.clone();
let message = ws_message_into_rs(data); let message = ws_message_into_rs(data);
future_into_py_iter(self.rt.clone(), py, async move { future_into_py_iter(self.rt.clone(), py, async move {
if let Ok(message) = message { if let Ok(message) = message {
if let Some(ws) = &mut *(transport.lock().await) { if let Some(ws) = &mut *(transport.lock().await) {
if let Ok(_) = ws.send(message).await { if (ws.send(message).await).is_ok() {
return Ok(()) return Ok(());
} }
}; };
}; };
@ -253,8 +233,8 @@ impl ASGIWebsocketProtocol {
let transport = self.ws_tx.clone(); let transport = self.ws_tx.clone();
future_into_py_iter(self.rt.clone(), py, async move { future_into_py_iter(self.rt.clone(), py, async move {
if let Some(ws) = &mut *(transport.lock().await) { if let Some(ws) = &mut *(transport.lock().await) {
if let Ok(_) = ws.close().await { if (ws.close().await).is_ok() {
return Ok(()) return Ok(());
} }
}; };
error_flow!() error_flow!()
@ -262,10 +242,7 @@ impl ASGIWebsocketProtocol {
} }
fn consumed(&self) -> bool { fn consumed(&self) -> bool {
match &self.upgrade { self.upgrade.is_none()
Some(_) => false,
_ => true
}
} }
pub fn tx(&mut self) -> (Option<oneshot::Sender<bool>>, bool) { pub fn tx(&mut self) -> (Option<oneshot::Sender<bool>>, bool) {
@ -278,44 +255,26 @@ impl ASGIWebsocketProtocol {
fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
let transport = self.ws_rx.clone(); let transport = self.ws_rx.clone();
let accepted = self.accepted.clone(); let accepted = self.accepted.clone();
let closed = self.closed.clone(); let closed = self.closed;
future_into_py_futlike(self.rt.clone(), py, async move { future_into_py_futlike(self.rt.clone(), py, async move {
let accepted = accepted.lock().await; let accepted = accepted.lock().await;
match (*accepted, closed) { match (*accepted, closed) {
(false, false) => { (false, false) => {
return Python::with_gil(|py| { return Python::with_gil(|py| {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item( dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?;
pyo3::intern!(py, "type"),
pyo3::intern!(py, "websocket.connect")
)?;
Ok(dict.to_object(py)) Ok(dict.to_object(py))
}) })
},
(true, false) => {},
_ => {
return error_flow!()
} }
(true, false) => {}
_ => return error_flow!(),
} }
if let Some(ws) = &mut *(transport.lock().await) { if let Some(ws) = &mut *(transport.lock().await) {
loop { while let Some(recv) = ws.next().await {
match ws.next().await { match recv {
Some(recv) => { Ok(Message::Ping(_)) => continue,
match recv { Ok(message) => return ws_message_into_py(message),
Ok(Message::Ping(_)) => { _ => break,
continue
},
Ok(message) => {
return ws_message_into_py(message)
},
_ => {
break
}
}
},
_ => {
break
}
} }
} }
} }
@ -325,25 +284,17 @@ impl ASGIWebsocketProtocol {
fn send<'p>(&mut self, py: Python<'p>, _asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> { fn send<'p>(&mut self, py: Python<'p>, _asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> {
match (adapt_message_type(data), self.closed) { match (adapt_message_type(data), self.closed) {
(Ok(ASGIMessageType::WSAccept), _) => { (Ok(ASGIMessageType::WSAccept), _) => self.accept(py),
self.accept(py) (Ok(ASGIMessageType::WSClose), false) => self.close(py),
}, (Ok(ASGIMessageType::WSMessage), false) => self.send_message(py, data),
(Ok(ASGIMessageType::WSClose), false) => {
self.close(py)
},
(Ok(ASGIMessageType::WSMessage), false) => {
self.send_message(py, data)
},
(Err(err), _) => Err(err.into()), (Err(err), _) => Err(err.into()),
_ => error_message!() _ => error_message!(),
} }
} }
} }
#[inline(never)] #[inline(never)]
fn adapt_message_type( fn adapt_message_type(message: &PyDict) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
message: &PyDict
) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
match message.get_item("type") { match message.get_item("type") {
Some(item) => { Some(item) => {
let message_type: &str = item.extract()?; let message_type: &str = item.extract()?;
@ -353,20 +304,18 @@ fn adapt_message_type(
"websocket.accept" => Ok(ASGIMessageType::WSAccept), "websocket.accept" => Ok(ASGIMessageType::WSAccept),
"websocket.close" => Ok(ASGIMessageType::WSClose), "websocket.close" => Ok(ASGIMessageType::WSClose),
"websocket.send" => Ok(ASGIMessageType::WSMessage), "websocket.send" => Ok(ASGIMessageType::WSMessage),
_ => error_message!() _ => error_message!(),
} }
}, }
_ => error_message!() _ => error_message!(),
} }
} }
#[inline(always)] #[inline(always)]
fn adapt_status_code(message: &PyDict) -> Result<i16, UnsupportedASGIMessage> { fn adapt_status_code(message: &PyDict) -> Result<i16, UnsupportedASGIMessage> {
match message.get_item("status") { match message.get_item("status") {
Some(item) => { Some(item) => Ok(item.extract()?),
Ok(item.extract()?) _ => error_message!(),
},
_ => error_message!()
} }
} }
@ -377,34 +326,26 @@ fn adapt_headers(message: &PyDict) -> HeaderMap {
match message.get_item("headers") { match message.get_item("headers") {
Some(item) => { Some(item) => {
let accum: Vec<Vec<&[u8]>> = item.extract().unwrap_or(Vec::new()); let accum: Vec<Vec<&[u8]>> = item.extract().unwrap_or(Vec::new());
for tup in accum.iter() { for tup in &accum {
match ( if let (Ok(key), Ok(val)) = (HeaderName::from_bytes(tup[0]), HeaderValue::from_bytes(tup[1])) {
HeaderName::from_bytes(tup[0]), ret.append(key, val);
HeaderValue::from_bytes(tup[1])
) {
(Ok(key), Ok(val)) => { ret.append(key, val); },
_ => {}
} }
}; }
ret ret
}, }
_ => ret _ => ret,
} }
} }
#[inline(always)] #[inline(always)]
fn adapt_body(message: &PyDict) -> (Vec<u8>, bool) { fn adapt_body(message: &PyDict) -> (Vec<u8>, bool) {
let body = match message.get_item("body") { let body = match message.get_item("body") {
Some(item) => { Some(item) => item.extract().unwrap_or(EMPTY_BYTES),
item.extract().unwrap_or(EMPTY_BYTES) _ => EMPTY_BYTES,
},
_ => EMPTY_BYTES
}; };
let more = match message.get_item("more_body") { let more = match message.get_item("more_body") {
Some(item) => { Some(item) => item.extract().unwrap_or(false),
item.extract().unwrap_or(false) _ => false,
},
_ => false
}; };
(body, more) (body, more)
} }
@ -412,22 +353,12 @@ fn adapt_body(message: &PyDict) -> (Vec<u8>, bool) {
#[inline(always)] #[inline(always)]
fn ws_message_into_rs(message: &PyDict) -> PyResult<Message> { fn ws_message_into_rs(message: &PyDict) -> PyResult<Message> {
match (message.get_item("bytes"), message.get_item("text")) { match (message.get_item("bytes"), message.get_item("text")) {
(Some(item), None) => { (Some(item), None) => Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES))),
Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES))) (None, Some(item)) => Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING))),
}, (Some(itemb), Some(itemt)) => match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
(None, Some(item)) => { (Some(msgb), None) => Ok(Message::Binary(msgb)),
Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING))) (None, Some(msgt)) => Ok(Message::Text(msgt)),
}, _ => error_flow!(),
(Some(itemb), Some(itemt)) => {
match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
(Some(msgb), None) => {
Ok(Message::Binary(msgb))
},
(None, Some(msgt)) => {
Ok(Message::Text(msgt))
},
_ => error_flow!()
}
}, },
_ => { _ => {
error_flow!() error_flow!()
@ -438,41 +369,23 @@ fn ws_message_into_rs(message: &PyDict) -> PyResult<Message> {
#[inline(always)] #[inline(always)]
fn ws_message_into_py(message: Message) -> PyResult<PyObject> { fn ws_message_into_py(message: Message) -> PyResult<PyObject> {
match message { match message {
Message::Binary(message) => { Message::Binary(message) => Python::with_gil(|py| {
Python::with_gil(|py| { let dict = PyDict::new(py);
let dict = PyDict::new(py); dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
dict.set_item( dict.set_item(pyo3::intern!(py, "bytes"), PyBytes::new(py, &message[..]))?;
pyo3::intern!(py, "type"), Ok(dict.to_object(py))
pyo3::intern!(py, "websocket.receive") }),
)?; Message::Text(message) => Python::with_gil(|py| {
dict.set_item( let dict = PyDict::new(py);
pyo3::intern!(py, "bytes"), dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
PyBytes::new(py, &message[..]) dict.set_item(pyo3::intern!(py, "text"), message)?;
)?; Ok(dict.to_object(py))
Ok(dict.to_object(py)) }),
}) Message::Close(_) => Python::with_gil(|py| {
}, let dict = PyDict::new(py);
Message::Text(message) => { dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.disconnect"))?;
Python::with_gil(|py| { Ok(dict.to_object(py))
let dict = PyDict::new(py); }),
dict.set_item(
pyo3::intern!(py, "type"),
pyo3::intern!(py, "websocket.receive")
)?;
dict.set_item(pyo3::intern!(py, "text"), message)?;
Ok(dict.to_object(py))
})
},
Message::Close(_) => {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(
pyo3::intern!(py, "type"),
pyo3::intern!(py, "websocket.disconnect")
)?;
Ok(dict.to_object(py))
})
},
v => { v => {
log::warn!("Unsupported websocket message received {:?}", v); log::warn!("Unsupported websocket message received {:?}", v);
error_flow!() error_flow!()

View file

@ -1,29 +1,14 @@
use pyo3::prelude::*; use pyo3::prelude::*;
use crate::{
workers::{
WorkerConfig,
serve_rth,
serve_wth,
serve_rth_ssl,
serve_wth_ssl
}
};
use super::http::{ use super::http::{
handle_rtb, handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws,
handle_rtb_pyw, handle_rtt_ws_pyw,
handle_rtt,
handle_rtt_pyw,
handle_rtb_ws,
handle_rtb_ws_pyw,
handle_rtt_ws,
handle_rtt_ws_pyw
}; };
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
#[pyclass(module = "granian._granian")]
#[pyclass(module="granian._granian")]
pub struct ASGIWorker { pub struct ASGIWorker {
config: WorkerConfig config: WorkerConfig,
} }
impl ASGIWorker { impl ASGIWorker {
@ -74,7 +59,7 @@ impl ASGIWorker {
opt_enabled: bool, opt_enabled: bool,
ssl_enabled: bool, ssl_enabled: bool,
ssl_cert: Option<&str>, ssl_cert: Option<&str>,
ssl_key: Option<&str> ssl_key: Option<&str>,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Self { Ok(Self {
config: WorkerConfig::new( config: WorkerConfig::new(
@ -88,22 +73,16 @@ impl ASGIWorker {
opt_enabled, opt_enabled,
ssl_enabled, ssl_enabled,
ssl_cert, ssl_cert,
ssl_key ssl_key,
) ),
}) })
} }
fn serve_rth( fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
match ( match (
self.config.websockets_enabled, self.config.websockets_enabled,
self.config.ssl_enabled, self.config.ssl_enabled,
self.config.opt_enabled self.config.opt_enabled,
) { ) {
(false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx), (false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx),
@ -112,21 +91,15 @@ impl ASGIWorker {
(false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx), (false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx) (true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
} }
} }
fn serve_wth( fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
match ( match (
self.config.websockets_enabled, self.config.websockets_enabled,
self.config.ssl_enabled, self.config.ssl_enabled,
self.config.opt_enabled self.config.opt_enabled,
) { ) {
(false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx), (false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx),
@ -135,7 +108,7 @@ impl ASGIWorker {
(false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx), (false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx) (true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
} }
} }
} }

View file

@ -1,10 +1,9 @@
use hyper::{Uri, Version, header::HeaderMap}; use hyper::{header::HeaderMap, Uri, Version};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList, PyString}; use pyo3::types::{PyBytes, PyDict, PyList, PyString};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
const SCHEME_HTTPS: &str = "https"; const SCHEME_HTTPS: &str = "https";
const SCHEME_WS: &str = "ws"; const SCHEME_WS: &str = "ws";
const SCHEME_WSS: &str = "wss"; const SCHEME_WSS: &str = "wss";
@ -17,10 +16,10 @@ pub(crate) enum ASGIMessageType {
HTTPBody, HTTPBody,
WSAccept, WSAccept,
WSClose, WSClose,
WSMessage WSMessage,
} }
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct ASGIScope { pub(crate) struct ASGIScope {
http_version: Version, http_version: Version,
scheme: String, scheme: String,
@ -31,7 +30,7 @@ pub(crate) struct ASGIScope {
client_ip: IpAddr, client_ip: IpAddr,
client_port: u16, client_port: u16,
headers: HeaderMap, headers: HeaderMap,
is_websocket: bool is_websocket: bool,
} }
impl ASGIScope { impl ASGIScope {
@ -42,31 +41,31 @@ impl ASGIScope {
method: &str, method: &str,
server: SocketAddr, server: SocketAddr,
client: SocketAddr, client: SocketAddr,
headers: &HeaderMap headers: &HeaderMap,
) -> Self { ) -> Self {
Self { Self {
http_version: http_version, http_version,
scheme: scheme.to_string(), scheme: scheme.to_string(),
method: method.to_string(), method: method.to_string(),
uri: uri, uri,
server_ip: server.ip(), server_ip: server.ip(),
server_port: server.port(), server_port: server.port(),
client_ip: client.ip(), client_ip: client.ip(),
client_port: client.port(), client_port: client.port(),
headers: headers.to_owned(), headers: headers.clone(),
is_websocket: false is_websocket: false,
} }
} }
pub fn set_websocket(&mut self) { pub fn set_websocket(&mut self) {
self.is_websocket = true self.is_websocket = true;
} }
#[inline(always)] #[inline(always)]
fn py_proto(&self) -> &str { fn py_proto(&self) -> &str {
match self.is_websocket { match self.is_websocket {
false => "http", false => "http",
true => "websocket" true => "websocket",
} }
} }
@ -76,7 +75,7 @@ impl ASGIScope {
Version::HTTP_10 => "1", Version::HTTP_10 => "1",
Version::HTTP_11 => "1.1", Version::HTTP_11 => "1.1",
Version::HTTP_2 => "2", Version::HTTP_2 => "2",
_ => "1" _ => "1",
} }
} }
@ -85,22 +84,20 @@ impl ASGIScope {
let scheme = &self.scheme[..]; let scheme = &self.scheme[..];
match self.is_websocket { match self.is_websocket {
false => scheme, false => scheme,
true => { true => match scheme {
match scheme { SCHEME_HTTPS => SCHEME_WSS,
SCHEME_HTTPS => SCHEME_WSS, _ => SCHEME_WS,
_ => SCHEME_WS },
}
}
} }
} }
#[inline(always)] #[inline(always)]
fn py_headers<'p>(&self, py: Python<'p>) -> PyResult<&'p PyList> { fn py_headers<'p>(&self, py: Python<'p>) -> PyResult<&'p PyList> {
let rv = PyList::empty(py); let rv = PyList::empty(py);
for (key, value) in self.headers.iter() { for (key, value) in &self.headers {
rv.append(( rv.append((
PyBytes::new(py, key.as_str().as_bytes()), PyBytes::new(py, key.as_str().as_bytes()),
PyBytes::new(py, value.as_bytes()) PyBytes::new(py, value.as_bytes()),
))?; ))?;
} }
Ok(rv) Ok(rv)
@ -110,17 +107,10 @@ impl ASGIScope {
#[pymethods] #[pymethods]
impl ASGIScope { impl ASGIScope {
fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str) -> PyResult<&'p PyAny> { fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str) -> PyResult<&'p PyAny> {
let ( let (path, query_string, proto, http_version, server, client, scheme, method) = py.allow_threads(|| {
path, let (path, query_string) = self
query_string, .uri
proto, .path_and_query()
http_version,
server,
client,
scheme,
method
) = py.allow_threads(|| {
let (path, query_string) = self.uri.path_and_query()
.map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or("")));
( (
path, path,
@ -136,19 +126,23 @@ impl ASGIScope {
let dict: &PyDict = PyDict::new(py); let dict: &PyDict = PyDict::new(py);
dict.set_item( dict.set_item(
pyo3::intern!(py, "asgi"), pyo3::intern!(py, "asgi"),
ASGI_VERSION.get_or_try_init(|| { ASGI_VERSION
let rv = PyDict::new(py); .get_or_try_init(|| {
rv.set_item("version", "3.0")?; let rv = PyDict::new(py);
rv.set_item("spec_version", "2.3")?; rv.set_item("version", "3.0")?;
Ok::<PyObject, PyErr>(rv.into()) rv.set_item("spec_version", "2.3")?;
})?.as_ref(py) Ok::<PyObject, PyErr>(rv.into())
})?
.as_ref(py),
)?; )?;
dict.set_item( dict.set_item(
pyo3::intern!(py, "extensions"), pyo3::intern!(py, "extensions"),
ASGI_EXTENSIONS.get_or_try_init(|| { ASGI_EXTENSIONS
let rv = PyDict::new(py); .get_or_try_init(|| {
Ok::<PyObject, PyErr>(rv.into()) let rv = PyDict::new(py);
})?.as_ref(py) Ok::<PyObject, PyErr>(rv.into())
})?
.as_ref(py),
)?; )?;
dict.set_item(pyo3::intern!(py, "type"), proto)?; dict.set_item(pyo3::intern!(py, "type"), proto)?;
dict.set_item(pyo3::intern!(py, "http_version"), http_version)?; dict.set_item(pyo3::intern!(py, "http_version"), http_version)?;
@ -160,17 +154,12 @@ impl ASGIScope {
dict.set_item(pyo3::intern!(py, "path"), path)?; dict.set_item(pyo3::intern!(py, "path"), path)?;
dict.set_item( dict.set_item(
pyo3::intern!(py, "raw_path"), pyo3::intern!(py, "raw_path"),
PyString::new(py, path) PyString::new(py, path).call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),))?,
.call_method1(
pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),)
)?
)?; )?;
dict.set_item( dict.set_item(
pyo3::intern!(py, "query_string"), pyo3::intern!(py, "query_string"),
PyString::new(py, query_string) PyString::new(py, query_string)
.call_method1( .call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),))?,
pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),)
)?
)?; )?;
dict.set_item(pyo3::intern!(py, "headers"), self.py_headers(py)?)?; dict.set_item(pyo3::intern!(py, "headers"), self.py_headers(py)?)?;
Ok(dict) Ok(dict)

View file

@ -2,32 +2,27 @@ use once_cell::sync::OnceCell;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::pyclass::IterNextOutput; use pyo3::pyclass::IterNextOutput;
static CONTEXTVARS: OnceCell<PyObject> = OnceCell::new(); static CONTEXTVARS: OnceCell<PyObject> = OnceCell::new();
static CONTEXT: OnceCell<PyObject> = OnceCell::new(); static CONTEXT: OnceCell<PyObject> = OnceCell::new();
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct CallbackWrapper { pub(crate) struct CallbackWrapper {
pub callback: PyObject, pub callback: PyObject,
pub context: pyo3_asyncio::TaskLocals pub context: pyo3_asyncio::TaskLocals,
} }
impl CallbackWrapper { impl CallbackWrapper {
pub(crate) fn new( pub(crate) fn new(callback: PyObject, event_loop: &PyAny, context: &PyAny) -> Self {
callback: PyObject,
event_loop: &PyAny,
context: &PyAny
) -> Self {
Self { Self {
callback, callback,
context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context) context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context),
} }
} }
} }
#[pyclass] #[pyclass]
pub(crate) struct PyIterAwaitable { pub(crate) struct PyIterAwaitable {
result: Option<PyResult<PyObject>> result: Option<PyResult<PyObject>>,
} }
impl PyIterAwaitable { impl PyIterAwaitable {
@ -36,7 +31,7 @@ impl PyIterAwaitable {
} }
pub(crate) fn set_result(&mut self, result: PyResult<PyObject>) { pub(crate) fn set_result(&mut self, result: PyResult<PyObject>) {
self.result = Some(result) self.result = Some(result);
} }
} }
@ -52,13 +47,11 @@ impl PyIterAwaitable {
fn __next__(&mut self, py: Python) -> PyResult<IterNextOutput<PyObject, PyObject>> { fn __next__(&mut self, py: Python) -> PyResult<IterNextOutput<PyObject, PyObject>> {
match self.result.take() { match self.result.take() {
Some(res) => { Some(res) => match res {
match res { Ok(v) => Ok(IterNextOutput::Return(v)),
Ok(v) => Ok(IterNextOutput::Return(v)), Err(err) => Err(err),
Err(err) => Err(err)
}
}, },
_ => Ok(IterNextOutput::Yield(py.None())) _ => Ok(IterNextOutput::Yield(py.None())),
} }
} }
} }
@ -68,7 +61,7 @@ pub(crate) struct PyFutureAwaitable {
py_block: bool, py_block: bool,
event_loop: PyObject, event_loop: PyObject,
result: Option<PyResult<PyObject>>, result: Option<PyResult<PyObject>>,
cb: Option<(PyObject, Py<pyo3::types::PyDict>)> cb: Option<(PyObject, Py<pyo3::types::PyDict>)>,
} }
impl PyFutureAwaitable { impl PyFutureAwaitable {
@ -77,7 +70,7 @@ impl PyFutureAwaitable {
event_loop, event_loop,
py_block: true, py_block: true,
result: None, result: None,
cb: None cb: None,
} }
} }
@ -85,12 +78,9 @@ impl PyFutureAwaitable {
pyself.result = Some(result); pyself.result = Some(result);
if let Some((cb, ctx)) = pyself.cb.take() { if let Some((cb, ctx)) = pyself.cb.take() {
let py = pyself.py(); let py = pyself.py();
let _ = pyself.event_loop.call_method( let _ = pyself
py, .event_loop
"call_soon_threadsafe", .call_method(py, "call_soon_threadsafe", (cb, &pyself), Some(ctx.as_ref(py)));
(cb, &pyself),
Some(ctx.as_ref(py))
);
} }
} }
} }
@ -108,25 +98,22 @@ impl PyFutureAwaitable {
#[setter(_asyncio_future_blocking)] #[setter(_asyncio_future_blocking)]
fn set_block(&mut self, val: bool) { fn set_block(&mut self, val: bool) {
self.py_block = val self.py_block = val;
} }
fn get_loop(&mut self) -> PyObject { fn get_loop(&mut self) -> PyObject {
self.event_loop.clone() self.event_loop.clone()
} }
fn add_done_callback( fn add_done_callback(mut pyself: PyRefMut<'_, Self>, py: Python, cb: PyObject, context: PyObject) -> PyResult<()> {
mut pyself: PyRefMut<'_, Self>,
py: Python,
cb: PyObject,
context: PyObject
) -> PyResult<()> {
let kwctx = pyo3::types::PyDict::new(py); let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item("context", context)?; kwctx.set_item("context", context)?;
match pyself.result { match pyself.result {
Some(_) => { Some(_) => {
pyself.event_loop.call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?; pyself
}, .event_loop
.call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?;
}
_ => { _ => {
pyself.cb = Some((cb, kwctx.into_py(py))); pyself.cb = Some((cb, kwctx.into_py(py)));
} }
@ -136,9 +123,9 @@ impl PyFutureAwaitable {
fn cancel(mut pyself: PyRefMut<'_, Self>, py: Python) -> bool { fn cancel(mut pyself: PyRefMut<'_, Self>, py: Python) -> bool {
if let Some((cb, kwctx)) = pyself.cb.take() { if let Some((cb, kwctx)) = pyself.cb.take() {
let _ = pyself.event_loop.call_method( let _ = pyself
py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py)) .event_loop
); .call_method(py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py)));
} }
false false
} }
@ -150,79 +137,69 @@ impl PyFutureAwaitable {
pyself pyself
} }
fn __next__( fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
mut pyself: PyRefMut<'_, Self>
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
match pyself.result { match pyself.result {
Some(_) => { Some(_) => match pyself.result.take().unwrap() {
match pyself.result.take().unwrap() { Ok(v) => Ok(IterNextOutput::Return(v)),
Ok(v) => Ok(IterNextOutput::Return(v)), Err(err) => Err(err),
Err(err) => Err(err)
}
}, },
_ => Ok(IterNextOutput::Yield(pyself)) _ => Ok(IterNextOutput::Yield(pyself)),
} }
} }
} }
fn contextvars(py: Python) -> PyResult<&PyAny> { fn contextvars(py: Python) -> PyResult<&PyAny> {
Ok(CONTEXTVARS Ok(CONTEXTVARS
.get_or_try_init(|| py.import("contextvars").map(|m| m.into()))? .get_or_try_init(|| py.import("contextvars").map(std::convert::Into::into))?
.as_ref(py)) .as_ref(py))
} }
pub fn empty_pycontext(py: Python) -> PyResult<&PyAny> { pub fn empty_pycontext(py: Python) -> PyResult<&PyAny> {
Ok(CONTEXT Ok(CONTEXT
.get_or_try_init(|| contextvars(py)?.getattr("Context")?.call0().map(|c| c.into()))? .get_or_try_init(|| {
contextvars(py)?
.getattr("Context")?
.call0()
.map(std::convert::Into::into)
})?
.as_ref(py)) .as_ref(py))
} }
macro_rules! callback_impl_run { macro_rules! callback_impl_run {
() => { () => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> { pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
let event_loop = self.context.event_loop(py); let event_loop = self.context.event_loop(py);
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?; let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
let kwctx = pyo3::types::PyDict::new(py); let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item( kwctx.set_item(
pyo3::intern!(py, "context"), pyo3::intern!(py, "context"),
crate::callbacks::empty_pycontext(py)? crate::callbacks::empty_pycontext(py)?,
)?; )?;
event_loop.call_method( event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx))
pyo3::intern!(py, "call_soon_threadsafe"),
(target,),
Some(kwctx)
)
} }
}; };
} }
macro_rules! callback_impl_run_pytask { macro_rules! callback_impl_run_pytask {
() => { () => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> { pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
let event_loop = self.context.event_loop(py); let event_loop = self.context.event_loop(py);
let context = self.context.context(py); let context = self.context.context(py);
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?; let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
let kwctx = pyo3::types::PyDict::new(py); let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item( kwctx.set_item(pyo3::intern!(py, "context"), context)?;
pyo3::intern!(py, "context"), event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx))
context
)?;
event_loop.call_method(
pyo3::intern!(py, "call_soon_threadsafe"),
(target,),
Some(kwctx)
)
} }
}; };
} }
macro_rules! callback_impl_loop_run { macro_rules! callback_impl_loop_run {
() => { () => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> { pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
let context = self.pycontext.clone().into_ref(py); let context = self.pycontext.clone().into_ref(py);
context.call_method1( context.call_method1(
pyo3::intern!(py, "run"), pyo3::intern!(py, "run"),
(self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,) (self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,),
) )
} }
}; };
@ -232,7 +209,7 @@ macro_rules! callback_impl_loop_pytask {
($pyself:expr, $py:expr) => { ($pyself:expr, $py:expr) => {
$pyself.context.event_loop($py).call_method1( $pyself.context.event_loop($py).call_method1(
pyo3::intern!($py, "create_task"), pyo3::intern!($py, "create_task"),
($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,) ($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,),
) )
}; };
} }
@ -241,12 +218,9 @@ macro_rules! callback_impl_loop_step {
($pyself:expr, $py:expr) => { ($pyself:expr, $py:expr) => {
match $pyself.cb.call_method1($py, pyo3::intern!($py, "send"), ($py.None(),)) { match $pyself.cb.call_method1($py, pyo3::intern!($py, "send"), ($py.None(),)) {
Ok(res) => { Ok(res) => {
let blocking: bool = match res.getattr( let blocking: bool = match res.getattr($py, pyo3::intern!($py, "_asyncio_future_blocking")) {
$py,
pyo3::intern!($py, "_asyncio_future_blocking")
) {
Ok(v) => v.extract($py)?, Ok(v) => v.extract($py)?,
_ => false _ => false,
}; };
let ctx = $pyself.pycontext.clone(); let ctx = $pyself.pycontext.clone();
@ -255,43 +229,30 @@ macro_rules! callback_impl_loop_step {
match blocking { match blocking {
true => { true => {
res.setattr( res.setattr($py, pyo3::intern!($py, "_asyncio_future_blocking"), false)?;
$py,
pyo3::intern!($py, "_asyncio_future_blocking"),
false
)?;
res.call_method( res.call_method(
$py, $py,
pyo3::intern!($py, "add_done_callback"), pyo3::intern!($py, "add_done_callback"),
( ($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_wake"))?,),
$pyself Some(kwctx),
.into_py($py)
.getattr($py, pyo3::intern!($py, "_loop_wake"))?,
),
Some(kwctx)
)?; )?;
Ok(()) Ok(())
}, }
false => { false => {
let event_loop = $pyself.context.event_loop($py); let event_loop = $pyself.context.event_loop($py);
event_loop.call_method( event_loop.call_method(
pyo3::intern!($py, "call_soon"), pyo3::intern!($py, "call_soon"),
( ($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_step"))?,),
$pyself Some(kwctx),
.into_py($py)
.getattr($py, pyo3::intern!($py, "_loop_step"))?,
),
Some(kwctx)
)?; )?;
Ok(()) Ok(())
} }
} }
}, }
Err(err) => { Err(err) => {
if ( if (err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py)
err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py) || || err.is_instance_of::<pyo3::exceptions::asyncio::CancelledError>($py))
err.is_instance_of::<pyo3::exceptions::asyncio::CancelledError>($py) {
) {
$pyself.done($py); $pyself.done($py);
Ok(()) Ok(())
} else { } else {
@ -307,7 +268,7 @@ macro_rules! callback_impl_loop_wake {
($pyself:expr, $py:expr, $fut:expr) => { ($pyself:expr, $py:expr, $fut:expr) => {
match $fut.call_method0($py, pyo3::intern!($py, "result")) { match $fut.call_method0($py, pyo3::intern!($py, "result")) {
Ok(_) => $pyself.into_py($py).call_method0($py, pyo3::intern!($py, "_loop_step")), Ok(_) => $pyself.into_py($py).call_method0($py, pyo3::intern!($py, "_loop_step")),
Err(err) => $pyself._loop_err($py, err) Err(err) => $pyself._loop_err($py, err),
} }
}; };
} }
@ -322,10 +283,10 @@ macro_rules! callback_impl_loop_err {
}; };
} }
pub(crate) use callback_impl_run; pub(crate) use callback_impl_loop_err;
pub(crate) use callback_impl_run_pytask;
pub(crate) use callback_impl_loop_run;
pub(crate) use callback_impl_loop_pytask; pub(crate) use callback_impl_loop_pytask;
pub(crate) use callback_impl_loop_run;
pub(crate) use callback_impl_loop_step; pub(crate) use callback_impl_loop_step;
pub(crate) use callback_impl_loop_wake; pub(crate) use callback_impl_loop_wake;
pub(crate) use callback_impl_loop_err; pub(crate) use callback_impl_run;
pub(crate) use callback_impl_run_pytask;

View file

@ -1,4 +1,7 @@
use hyper::{Body, Response, header::{HeaderValue, SERVER as HK_SERVER}}; use hyper::{
header::{HeaderValue, SERVER as HK_SERVER},
Body, Response,
};
pub(crate) const HV_SERVER: HeaderValue = HeaderValue::from_static("granian"); pub(crate) const HV_SERVER: HeaderValue = HeaderValue::from_static("granian");

View file

@ -8,8 +8,8 @@ mod callbacks;
mod http; mod http;
mod rsgi; mod rsgi;
mod runtime; mod runtime;
mod tls;
mod tcp; mod tcp;
mod tls;
mod utils; mod utils;
mod workers; mod workers;
mod ws; mod ws;

View file

@ -2,45 +2,33 @@ use pyo3::prelude::*;
use pyo3_asyncio::TaskLocals; use pyo3_asyncio::TaskLocals;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use crate::{
callbacks::{
CallbackWrapper,
callback_impl_run,
callback_impl_run_pytask,
callback_impl_loop_run,
callback_impl_loop_pytask,
callback_impl_loop_step,
callback_impl_loop_wake,
callback_impl_loop_err
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData}
};
use super::{ use super::{
io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol}, io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol},
types::{RSGIScope as Scope, PyResponse, PyResponseBody} types::{PyResponse, PyResponseBody, RSGIScope as Scope},
};
use crate::{
callbacks::{
callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step,
callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper,
},
runtime::RuntimeRef,
ws::{HyperWebsocket, UpgradeData},
}; };
#[pyclass] #[pyclass]
pub(crate) struct CallbackRunnerHTTP { pub(crate) struct CallbackRunnerHTTP {
proto: Py<HTTPProtocol>, proto: Py<HTTPProtocol>,
context: TaskLocals, context: TaskLocals,
cb: PyObject cb: PyObject,
} }
impl CallbackRunnerHTTP { impl CallbackRunnerHTTP {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
let pyproto = Py::new(py, proto).unwrap(); let pyproto = Py::new(py, proto).unwrap();
Self { Self {
proto: pyproto.clone(), proto: pyproto.clone(),
context: cb.context, context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap() cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
} }
} }
@ -58,19 +46,17 @@ macro_rules! callback_impl_done_http {
($self:expr, $py:expr) => { ($self:expr, $py:expr) => {
if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() { if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() {
if let Some(tx) = proto.tx() { if let Some(tx) = proto.tx() {
let _ = tx.send( let _ = tx.send(PyResponse::Body(PyResponseBody::empty(500, Vec::new())));
PyResponse::Body(PyResponseBody::empty(500, Vec::new()))
);
} }
} }
} };
} }
macro_rules! callback_impl_done_err { macro_rules! callback_impl_done_err {
($self:expr, $py:expr) => { ($self:expr, $py:expr) => {
log::warn!("Application callable raised an exception"); log::warn!("Application callable raised an exception");
$self.done($py) $self.done($py)
} };
} }
#[pyclass] #[pyclass]
@ -78,18 +64,18 @@ pub(crate) struct CallbackTaskHTTP {
proto: Py<HTTPProtocol>, proto: Py<HTTPProtocol>,
context: TaskLocals, context: TaskLocals,
pycontext: PyObject, pycontext: PyObject,
cb: PyObject cb: PyObject,
} }
impl CallbackTaskHTTP { impl CallbackTaskHTTP {
pub fn new( pub fn new(py: Python, cb: PyObject, proto: Py<HTTPProtocol>, context: TaskLocals) -> PyResult<Self> {
py: Python,
cb: PyObject,
proto: Py<HTTPProtocol>,
context: TaskLocals
) -> PyResult<Self> {
let pyctx = context.context(py); let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb }) Ok(Self {
proto,
context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb,
})
} }
fn done(&self, py: Python) { fn done(&self, py: Python) {
@ -122,21 +108,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
context: TaskLocals, context: TaskLocals,
cb: PyObject, cb: PyObject,
#[pyo3(get)] #[pyo3(get)]
scope: PyObject scope: PyObject,
} }
impl CallbackWrappedRunnerHTTP { impl CallbackWrappedRunnerHTTP {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: HTTPProtocol,
scope: Scope
) -> Self {
Self { Self {
proto: Py::new(py, proto).unwrap(), proto: Py::new(py, proto).unwrap(),
context: cb.context, context: cb.context,
cb: cb.callback, cb: cb.callback,
scope: scope.into_py(py) scope: scope.into_py(py),
} }
} }
@ -162,21 +143,16 @@ impl CallbackWrappedRunnerHTTP {
pub(crate) struct CallbackRunnerWebsocket { pub(crate) struct CallbackRunnerWebsocket {
proto: Py<WebsocketProtocol>, proto: Py<WebsocketProtocol>,
context: TaskLocals, context: TaskLocals,
cb: PyObject cb: PyObject,
} }
impl CallbackRunnerWebsocket { impl CallbackRunnerWebsocket {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
let pyproto = Py::new(py, proto).unwrap(); let pyproto = Py::new(py, proto).unwrap();
Self { Self {
proto: pyproto.clone(), proto: pyproto.clone(),
context: cb.context, context: cb.context,
cb: cb.callback.call1(py, (scope, pyproto)).unwrap() cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
} }
} }
@ -197,7 +173,7 @@ macro_rules! callback_impl_done_ws {
let _ = tx.send(res); let _ = tx.send(res);
} }
} }
} };
} }
#[pyclass] #[pyclass]
@ -205,18 +181,18 @@ pub(crate) struct CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>, proto: Py<WebsocketProtocol>,
context: TaskLocals, context: TaskLocals,
pycontext: PyObject, pycontext: PyObject,
cb: PyObject cb: PyObject,
} }
impl CallbackTaskWebsocket { impl CallbackTaskWebsocket {
pub fn new( pub fn new(py: Python, cb: PyObject, proto: Py<WebsocketProtocol>, context: TaskLocals) -> PyResult<Self> {
py: Python,
cb: PyObject,
proto: Py<WebsocketProtocol>,
context: TaskLocals
) -> PyResult<Self> {
let pyctx = context.context(py); let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb }) Ok(Self {
proto,
context,
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
cb,
})
} }
fn done(&self, py: Python) { fn done(&self, py: Python) {
@ -249,21 +225,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
context: TaskLocals, context: TaskLocals,
cb: PyObject, cb: PyObject,
#[pyo3(get)] #[pyo3(get)]
scope: PyObject scope: PyObject,
} }
impl CallbackWrappedRunnerWebsocket { impl CallbackWrappedRunnerWebsocket {
pub fn new( pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
py: Python,
cb: CallbackWrapper,
proto: WebsocketProtocol,
scope: Scope
) -> Self {
Self { Self {
proto: Py::new(py, proto).unwrap(), proto: Py::new(py, proto).unwrap(),
context: cb.context, context: cb.context,
cb: cb.callback, cb: cb.callback,
scope: scope.into_py(py) scope: scope.into_py(py),
} }
} }
@ -291,7 +262,7 @@ macro_rules! call_impl_rtb_http {
cb: CallbackWrapper, cb: CallbackWrapper,
rt: RuntimeRef, rt: RuntimeRef,
req: hyper::Request<hyper::Body>, req: hyper::Request<hyper::Body>,
scope: Scope scope: Scope,
) -> oneshot::Receiver<PyResponse> { ) -> oneshot::Receiver<PyResponse> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, tx, req); let protocol = HTTPProtocol::new(rt, tx, req);
@ -311,7 +282,7 @@ macro_rules! call_impl_rtt_http {
cb: CallbackWrapper, cb: CallbackWrapper,
rt: RuntimeRef, rt: RuntimeRef,
req: hyper::Request<hyper::Body>, req: hyper::Request<hyper::Body>,
scope: Scope scope: Scope,
) -> oneshot::Receiver<PyResponse> { ) -> oneshot::Receiver<PyResponse> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, tx, req); let protocol = HTTPProtocol::new(rt, tx, req);
@ -334,7 +305,7 @@ macro_rules! call_impl_rtb_ws {
rt: RuntimeRef, rt: RuntimeRef,
ws: HyperWebsocket, ws: HyperWebsocket,
upgrade: UpgradeData, upgrade: UpgradeData,
scope: Scope scope: Scope,
) -> oneshot::Receiver<(i32, bool)> { ) -> oneshot::Receiver<(i32, bool)> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
@ -355,7 +326,7 @@ macro_rules! call_impl_rtt_ws {
rt: RuntimeRef, rt: RuntimeRef,
ws: HyperWebsocket, ws: HyperWebsocket,
upgrade: UpgradeData, upgrade: UpgradeData,
scope: Scope scope: Scope,
) -> oneshot::Receiver<(i32, bool)> { ) -> oneshot::Receiver<(i32, bool)> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);

View file

@ -1,6 +1,5 @@
use pyo3::{create_exception, exceptions::PyRuntimeError}; use pyo3::{create_exception, exceptions::PyRuntimeError};
create_exception!(_granian, RSGIProtocolError, PyRuntimeError, "RSGIProtocolError"); create_exception!(_granian, RSGIProtocolError, PyRuntimeError, "RSGIProtocolError");
create_exception!(_granian, RSGIProtocolClosed, PyRuntimeError, "RSGIProtocolClosed"); create_exception!(_granian, RSGIProtocolClosed, PyRuntimeError, "RSGIProtocolClosed");

View file

@ -1,34 +1,22 @@
use hyper::{ use hyper::{
Body, header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode,
Request,
Response,
StatusCode,
header::SERVER as HK_SERVER,
http::response::Builder as ResponseBuilder
}; };
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::{
callbacks::CallbackWrapper,
http::{HV_SERVER, response_500},
runtime::RuntimeRef,
ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade}
};
use super::{ use super::{
callbacks::{ callbacks::{
call_rtb_http, call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws,
call_rtb_http_pyw, call_rtt_ws_pyw,
call_rtb_ws,
call_rtb_ws_pyw,
call_rtt_http,
call_rtt_http_pyw,
call_rtt_ws,
call_rtt_ws_pyw
}, },
types::{RSGIScope as Scope, PyResponse} types::{PyResponse, RSGIScope as Scope},
};
use crate::{
callbacks::CallbackWrapper,
http::{response_500, HV_SERVER},
runtime::RuntimeRef,
ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData},
}; };
macro_rules! default_scope { macro_rules! default_scope {
($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { ($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => {
@ -40,7 +28,7 @@ macro_rules! default_scope {
$req.method().as_ref(), $req.method().as_ref(),
$server_addr, $server_addr,
$client_addr, $client_addr,
$req.headers() $req.headers(),
) )
}; };
} }
@ -48,12 +36,8 @@ macro_rules! default_scope {
macro_rules! handle_http_response { macro_rules! handle_http_response {
($handler:expr, $rt:expr, $callback:expr, $req:expr, $scope:expr) => { ($handler:expr, $rt:expr, $callback:expr, $req:expr, $scope:expr) => {
match $handler($callback, $rt, $req, $scope).await { match $handler($callback, $rt, $req, $scope).await {
Ok(PyResponse::Body(pyres)) => { Ok(PyResponse::Body(pyres)) => pyres.to_response(),
pyres.to_response() Ok(PyResponse::File(pyres)) => pyres.to_response().await,
},
Ok(PyResponse::File(pyres)) => {
pyres.to_response().await
},
_ => { _ => {
log::error!("RSGI protocol failure"); log::error!("RSGI protocol failure");
response_500() response_500()
@ -70,7 +54,7 @@ macro_rules! handle_request {
server_addr: SocketAddr, server_addr: SocketAddr,
client_addr: SocketAddr, client_addr: SocketAddr,
req: Request<Body>, req: Request<Body>,
scheme: &str scheme: &str,
) -> Response<Body> { ) -> Response<Body> {
let scope = default_scope!(server_addr, client_addr, &req, scheme); let scope = default_scope!(server_addr, client_addr, &req, scheme);
handle_http_response!($handler, rt, callback, req, scope) handle_http_response!($handler, rt, callback, req, scope)
@ -86,7 +70,7 @@ macro_rules! handle_request_with_ws {
server_addr: SocketAddr, server_addr: SocketAddr,
client_addr: SocketAddr, client_addr: SocketAddr,
req: Request<Body>, req: Request<Body>,
scheme: &str scheme: &str,
) -> Response<Body> { ) -> Response<Body> {
let mut scope = default_scope!(server_addr, client_addr, &req, scheme); let mut scope = default_scope!(server_addr, client_addr, &req, scheme);
@ -101,27 +85,23 @@ macro_rules! handle_request_with_ws {
rt.inner.spawn(async move { rt.inner.spawn(async move {
let tx_ref = restx.clone(); let tx_ref = restx.clone();
match $handler_ws( match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await {
callback,
rth,
ws,
UpgradeData::new(res, restx),
scope
).await {
Ok((status, consumed)) => { Ok((status, consumed)) => {
if !consumed { if !consumed {
let _ = tx_ref.send( let _ = tx_ref
ResponseBuilder::new() .send(
.status( ResponseBuilder::new()
StatusCode::from_u16(status as u16) .status(
.unwrap_or(StatusCode::FORBIDDEN) StatusCode::from_u16(status as u16)
) .unwrap_or(StatusCode::FORBIDDEN),
.header(HK_SERVER, HV_SERVER) )
.body(Body::from("")) .header(HK_SERVER, HV_SERVER)
.unwrap() .body(Body::from(""))
).await; .unwrap(),
)
.await;
} }
}, }
_ => { _ => {
log::error!("RSGI protocol failure"); log::error!("RSGI protocol failure");
let _ = tx_ref.send(response_500()).await; let _ = tx_ref.send(response_500()).await;
@ -133,10 +113,10 @@ macro_rules! handle_request_with_ws {
Some(res) => { Some(res) => {
resrx.close(); resrx.close();
res res
}, }
_ => response_500() _ => response_500(),
} };
}, }
Err(err) => { Err(err) => {
return ResponseBuilder::new() return ResponseBuilder::new()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
@ -149,7 +129,6 @@ macro_rules! handle_request_with_ws {
handle_http_response!($handler_req, rt, callback, req, scope) handle_http_response!($handler_req, rt, callback, req, scope)
} }
}; };
} }

View file

@ -1,35 +1,40 @@
use bytes::Bytes; use bytes::Bytes;
use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}}; use futures::{
use hyper::{body::{Body, Sender as BodySender, HttpBody}, Request}; sink::SinkExt,
stream::{SplitSink, SplitStream, StreamExt},
};
use hyper::{
body::{Body, HttpBody, Sender as BodySender},
Request,
};
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyString}; use pyo3::types::{PyBytes, PyString};
use std::sync::Arc; use std::sync::Arc;
use tokio_tungstenite::WebSocketStream;
use tokio::sync::{oneshot, Mutex}; use tokio::sync::{oneshot, Mutex};
use tokio_tungstenite::WebSocketStream;
use tungstenite::Message; use tungstenite::Message;
use crate::{
runtime::{Runtime, RuntimeRef, future_into_py_iter, future_into_py_futlike},
ws::{HyperWebsocket, UpgradeData}
};
use super::{ use super::{
errors::{error_proto, error_stream}, errors::{error_proto, error_stream},
types::{PyResponse, PyResponseBody, PyResponseFile} types::{PyResponse, PyResponseBody, PyResponseFile},
};
use crate::{
runtime::{future_into_py_futlike, future_into_py_iter, Runtime, RuntimeRef},
ws::{HyperWebsocket, UpgradeData},
}; };
#[pyclass(module = "granian._granian")]
#[pyclass(module="granian._granian")]
pub(crate) struct RSGIHTTPStreamTransport { pub(crate) struct RSGIHTTPStreamTransport {
rt: RuntimeRef, rt: RuntimeRef,
tx: Arc<Mutex<BodySender>> tx: Arc<Mutex<BodySender>>,
} }
impl RSGIHTTPStreamTransport { impl RSGIHTTPStreamTransport {
pub fn new( pub fn new(rt: RuntimeRef, transport: BodySender) -> Self {
rt: RuntimeRef, Self {
transport: BodySender rt,
) -> Self { tx: Arc::new(Mutex::new(transport)),
Self { rt: rt, tx: Arc::new(Mutex::new(transport)) } }
} }
} }
@ -41,8 +46,8 @@ impl RSGIHTTPStreamTransport {
if let Ok(mut stream) = transport.try_lock() { if let Ok(mut stream) = transport.try_lock() {
return match stream.send_data(data.into()).await { return match stream.send_data(data.into()).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
_ => error_stream!() _ => error_stream!(),
} };
} }
error_proto!() error_proto!()
}) })
@ -54,31 +59,27 @@ impl RSGIHTTPStreamTransport {
if let Ok(mut stream) = transport.try_lock() { if let Ok(mut stream) = transport.try_lock() {
return match stream.send_data(data.into()).await { return match stream.send_data(data.into()).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
_ => error_stream!() _ => error_stream!(),
} };
} }
error_proto!() error_proto!()
}) })
} }
} }
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct RSGIHTTPProtocol { pub(crate) struct RSGIHTTPProtocol {
rt: RuntimeRef, rt: RuntimeRef,
tx: Option<oneshot::Sender<super::types::PyResponse>>, tx: Option<oneshot::Sender<super::types::PyResponse>>,
body: Arc<Mutex<Body>> body: Arc<Mutex<Body>>,
} }
impl RSGIHTTPProtocol { impl RSGIHTTPProtocol {
pub fn new( pub fn new(rt: RuntimeRef, tx: oneshot::Sender<super::types::PyResponse>, request: Request<Body>) -> Self {
rt: RuntimeRef,
tx: oneshot::Sender<super::types::PyResponse>,
request: Request<Body>
) -> Self {
Self { Self {
rt, rt,
tx: Some(tx), tx: Some(tx),
body: Arc::new(Mutex::new(request.into_body())) body: Arc::new(Mutex::new(request.into_body())),
} }
} }
@ -110,14 +111,13 @@ impl RSGIHTTPProtocol {
let mut bodym = body_ref.lock().await; let mut bodym = body_ref.lock().await;
let body = &mut *bodym; let body = &mut *bodym;
if body.is_end_stream() { if body.is_end_stream() {
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")) return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted"));
} }
let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| { let chunk = body
buf.unwrap_or_else(|_| Bytes::new()) .data()
}); .await
Ok(Python::with_gil(|py| { .map_or_else(Bytes::new, |buf| buf.unwrap_or_else(|_| Bytes::new()));
PyBytes::new(py, &chunk[..]).to_object(py) Ok(Python::with_gil(|py| PyBytes::new(py, &chunk[..]).to_object(py)))
}))
})?; })?;
Ok(Some(fut)) Ok(Some(fut))
} }
@ -125,36 +125,28 @@ impl RSGIHTTPProtocol {
#[pyo3(signature = (status=200, headers=vec![]))] #[pyo3(signature = (status=200, headers=vec![]))]
fn response_empty(&mut self, status: u16, headers: Vec<(String, String)>) { fn response_empty(&mut self, status: u16, headers: Vec<(String, String)>) {
if let Some(tx) = self.tx.take() { if let Some(tx) = self.tx.take() {
let _ = tx.send( let _ = tx.send(PyResponse::Body(PyResponseBody::empty(status, headers)));
PyResponse::Body(PyResponseBody::empty(status, headers))
);
} }
} }
#[pyo3(signature = (status=200, headers=vec![], body=vec![]))] #[pyo3(signature = (status=200, headers=vec![], body=vec![]))]
fn response_bytes(&mut self, status: u16, headers: Vec<(String, String)>, body: Vec<u8>) { fn response_bytes(&mut self, status: u16, headers: Vec<(String, String)>, body: Vec<u8>) {
if let Some(tx) = self.tx.take() { if let Some(tx) = self.tx.take() {
let _ = tx.send( let _ = tx.send(PyResponse::Body(PyResponseBody::from_bytes(status, headers, body)));
PyResponse::Body(PyResponseBody::from_bytes(status, headers, body))
);
} }
} }
#[pyo3(signature = (status=200, headers=vec![], body="".to_string()))] #[pyo3(signature = (status=200, headers=vec![], body=String::new()))]
fn response_str(&mut self, status: u16, headers: Vec<(String, String)>, body: String) { fn response_str(&mut self, status: u16, headers: Vec<(String, String)>, body: String) {
if let Some(tx) = self.tx.take() { if let Some(tx) = self.tx.take() {
let _ = tx.send( let _ = tx.send(PyResponse::Body(PyResponseBody::from_string(status, headers, body)));
PyResponse::Body(PyResponseBody::from_string(status, headers, body))
);
} }
} }
#[pyo3(signature = (status, headers, file))] #[pyo3(signature = (status, headers, file))]
fn response_file(&mut self, status: u16, headers: Vec<(String, String)>, file: String) { fn response_file(&mut self, status: u16, headers: Vec<(String, String)>, file: String) {
if let Some(tx) = self.tx.take() { if let Some(tx) = self.tx.take() {
let _ = tx.send( let _ = tx.send(PyResponse::File(PyResponseFile::new(status, headers, file)));
PyResponse::File(PyResponseFile::new(status, headers, file))
);
} }
} }
@ -163,34 +155,33 @@ impl RSGIHTTPProtocol {
&mut self, &mut self,
py: Python<'p>, py: Python<'p>,
status: u16, status: u16,
headers: Vec<(String, String)> headers: Vec<(String, String)>,
) -> PyResult<&'p PyAny> { ) -> PyResult<&'p PyAny> {
if let Some(tx) = self.tx.take() { if let Some(tx) = self.tx.take() {
let (body_tx, body_stream) = Body::channel(); let (body_tx, body_stream) = Body::channel();
let _ = tx.send( let _ = tx.send(PyResponse::Body(PyResponseBody::new(status, headers, body_stream)));
PyResponse::Body(PyResponseBody::new(status, headers, body_stream))
);
let trx = Py::new(py, RSGIHTTPStreamTransport::new(self.rt.clone(), body_tx))?; let trx = Py::new(py, RSGIHTTPStreamTransport::new(self.rt.clone(), body_tx))?;
return Ok(trx.into_ref(py)) return Ok(trx.into_ref(py));
} }
error_proto!() error_proto!()
} }
} }
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct RSGIWebsocketTransport { pub(crate) struct RSGIWebsocketTransport {
rt: RuntimeRef, rt: RuntimeRef,
tx: Arc<Mutex<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>, tx: Arc<Mutex<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>,
rx: Arc<Mutex<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>> rx: Arc<Mutex<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>,
} }
impl RSGIWebsocketTransport { impl RSGIWebsocketTransport {
pub fn new( pub fn new(rt: RuntimeRef, transport: WebSocketStream<hyper::upgrade::Upgraded>) -> Self {
rt: RuntimeRef,
transport: WebSocketStream<hyper::upgrade::Upgraded>
) -> Self {
let (tx, rx) = transport.split(); let (tx, rx) = transport.split();
Self { rt: rt, tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)) } Self {
rt,
tx: Arc::new(Mutex::new(tx)),
rx: Arc::new(Mutex::new(rx)),
}
} }
pub fn close(&self) { pub fn close(&self) {
@ -209,27 +200,14 @@ impl RSGIWebsocketTransport {
let transport = self.rx.clone(); let transport = self.rx.clone();
future_into_py_futlike(self.rt.clone(), py, async move { future_into_py_futlike(self.rt.clone(), py, async move {
if let Ok(mut stream) = transport.try_lock() { if let Ok(mut stream) = transport.try_lock() {
loop { while let Some(recv) = stream.next().await {
match stream.next().await { match recv {
Some(recv) => { Ok(Message::Ping(_)) => continue,
match recv { Ok(message) => return message_into_py(message),
Ok(Message::Ping(_)) => { _ => break,
continue
},
Ok(message) => {
return message_into_py(message)
},
_ => {
break
}
}
},
_ => {
break
}
} }
} }
return error_stream!() return error_stream!();
} }
error_proto!() error_proto!()
}) })
@ -241,8 +219,8 @@ impl RSGIWebsocketTransport {
if let Ok(mut stream) = transport.try_lock() { if let Ok(mut stream) = transport.try_lock() {
return match stream.send(Message::Binary(data)).await { return match stream.send(Message::Binary(data)).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
_ => error_stream!() _ => error_stream!(),
} };
} }
error_proto!() error_proto!()
}) })
@ -254,22 +232,22 @@ impl RSGIWebsocketTransport {
if let Ok(mut stream) = transport.try_lock() { if let Ok(mut stream) = transport.try_lock() {
return match stream.send(Message::Text(data)).await { return match stream.send(Message::Text(data)).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
_ => error_stream!() _ => error_stream!(),
} };
} }
error_proto!() error_proto!()
}) })
} }
} }
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct RSGIWebsocketProtocol { pub(crate) struct RSGIWebsocketProtocol {
rt: RuntimeRef, rt: RuntimeRef,
tx: Option<oneshot::Sender<(i32, bool)>>, tx: Option<oneshot::Sender<(i32, bool)>>,
websocket: Arc<Mutex<HyperWebsocket>>, websocket: Arc<Mutex<HyperWebsocket>>,
upgrade: Option<UpgradeData>, upgrade: Option<UpgradeData>,
transport: Arc<Mutex<Option<Py<RSGIWebsocketTransport>>>>, transport: Arc<Mutex<Option<Py<RSGIWebsocketTransport>>>>,
status: i32 status: i32,
} }
impl RSGIWebsocketProtocol { impl RSGIWebsocketProtocol {
@ -277,7 +255,7 @@ impl RSGIWebsocketProtocol {
rt: RuntimeRef, rt: RuntimeRef,
tx: oneshot::Sender<(i32, bool)>, tx: oneshot::Sender<(i32, bool)>,
websocket: HyperWebsocket, websocket: HyperWebsocket,
upgrade: UpgradeData upgrade: UpgradeData,
) -> Self { ) -> Self {
Self { Self {
rt, rt,
@ -285,15 +263,12 @@ impl RSGIWebsocketProtocol {
websocket: Arc::new(Mutex::new(websocket)), websocket: Arc::new(Mutex::new(websocket)),
upgrade: Some(upgrade), upgrade: Some(upgrade),
transport: Arc::new(Mutex::new(None)), transport: Arc::new(Mutex::new(None)),
status: 0 status: 0,
} }
} }
fn consumed(&self) -> bool { fn consumed(&self) -> bool {
match &self.upgrade { self.upgrade.is_none()
Some(_) => false,
_ => true
}
} }
pub fn tx(&mut self) -> (Option<oneshot::Sender<(i32, bool)>>, (i32, bool)) { pub fn tx(&mut self) -> (Option<oneshot::Sender<(i32, bool)>>, (i32, bool)) {
@ -304,18 +279,20 @@ impl RSGIWebsocketProtocol {
enum WebsocketMessageType { enum WebsocketMessageType {
Close = 0, Close = 0,
Bytes = 1, Bytes = 1,
Text = 2 Text = 2,
} }
#[pyclass] #[pyclass]
struct WebsocketInboundCloseMessage { struct WebsocketInboundCloseMessage {
#[pyo3(get)] #[pyo3(get)]
kind: usize kind: usize,
} }
impl WebsocketInboundCloseMessage { impl WebsocketInboundCloseMessage {
pub fn new() -> Self { pub fn new() -> Self {
Self { kind: WebsocketMessageType::Close as usize } Self {
kind: WebsocketMessageType::Close as usize,
}
} }
} }
@ -324,12 +301,15 @@ struct WebsocketInboundBytesMessage {
#[pyo3(get)] #[pyo3(get)]
kind: usize, kind: usize,
#[pyo3(get)] #[pyo3(get)]
data: Py<PyBytes> data: Py<PyBytes>,
} }
impl WebsocketInboundBytesMessage { impl WebsocketInboundBytesMessage {
pub fn new(data:Py<PyBytes>) -> Self { pub fn new(data: Py<PyBytes>) -> Self {
Self { kind: WebsocketMessageType::Bytes as usize, data: data } Self {
kind: WebsocketMessageType::Bytes as usize,
data,
}
} }
} }
@ -338,12 +318,15 @@ struct WebsocketInboundTextMessage {
#[pyo3(get)] #[pyo3(get)]
kind: usize, kind: usize,
#[pyo3(get)] #[pyo3(get)]
data: Py<PyString> data: Py<PyString>,
} }
impl WebsocketInboundTextMessage { impl WebsocketInboundTextMessage {
pub fn new(data: Py<PyString>) -> Self { pub fn new(data: Py<PyString>) -> Self {
Self { kind: WebsocketMessageType::Text as usize, data: data } Self {
kind: WebsocketMessageType::Text as usize,
data,
}
} }
} }
@ -374,23 +357,18 @@ impl RSGIWebsocketProtocol {
future_into_py_iter(self.rt.clone(), py, async move { future_into_py_iter(self.rt.clone(), py, async move {
let mut ws = transport.lock().await; let mut ws = transport.lock().await;
match upgrade.send().await { match upgrade.send().await {
Ok(_) => { Ok(_) => match (&mut *ws).await {
match (&mut *ws).await { Ok(stream) => {
Ok(stream) => { let mut trx = itransport.lock().await;
let mut trx = itransport.lock().await; Ok(Python::with_gil(|py| {
Ok(Python::with_gil(|py| { let pytransport = Py::new(py, RSGIWebsocketTransport::new(rth, stream)).unwrap();
let pytransport = Py::new( *trx = Some(pytransport.clone());
py, pytransport
RSGIWebsocketTransport::new(rth, stream) }))
).unwrap();
*trx = Some(pytransport.clone());
pytransport
}))
},
_ => error_proto!()
} }
_ => error_proto!(),
}, },
_ => error_proto!() _ => error_proto!(),
} }
}) })
} }
@ -399,25 +377,13 @@ impl RSGIWebsocketProtocol {
#[inline(always)] #[inline(always)]
fn message_into_py(message: Message) -> PyResult<PyObject> { fn message_into_py(message: Message) -> PyResult<PyObject> {
match message { match message {
Message::Binary(message) => { Message::Binary(message) => Ok(Python::with_gil(|py| {
Ok(Python::with_gil(|py| { WebsocketInboundBytesMessage::new(PyBytes::new(py, &message).into()).into_py(py)
WebsocketInboundBytesMessage::new( })),
PyBytes::new(py, &message).into() Message::Text(message) => Ok(Python::with_gil(|py| {
).into_py(py) WebsocketInboundTextMessage::new(PyString::new(py, &message).into()).into_py(py)
})) })),
}, Message::Close(_) => Ok(Python::with_gil(|py| WebsocketInboundCloseMessage::new().into_py(py))),
Message::Text(message) => {
Ok(Python::with_gil(|py| {
WebsocketInboundTextMessage::new(
PyString::new(py, &message).into()
).into_py(py)
}))
},
Message::Close(_) => {
Ok(Python::with_gil(|py| {
WebsocketInboundCloseMessage::new().into_py(py)
}))
}
v => { v => {
log::warn!("Unsupported websocket message received {:?}", v); log::warn!("Unsupported websocket message received {:?}", v);
error_proto!() error_proto!()

View file

@ -1,28 +1,14 @@
use pyo3::prelude::*; use pyo3::prelude::*;
use crate::{
workers::{
WorkerConfig,
serve_rth,
serve_wth,
serve_rth_ssl,
serve_wth_ssl
}
};
use super::http::{ use super::http::{
handle_rtb, handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws,
handle_rtb_pyw, handle_rtt_ws_pyw,
handle_rtt,
handle_rtt_pyw,
handle_rtb_ws,
handle_rtb_ws_pyw,
handle_rtt_ws,
handle_rtt_ws_pyw
}; };
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub struct RSGIWorker { pub struct RSGIWorker {
config: WorkerConfig config: WorkerConfig,
} }
impl RSGIWorker { impl RSGIWorker {
@ -73,7 +59,7 @@ impl RSGIWorker {
opt_enabled: bool, opt_enabled: bool,
ssl_enabled: bool, ssl_enabled: bool,
ssl_cert: Option<&str>, ssl_cert: Option<&str>,
ssl_key: Option<&str> ssl_key: Option<&str>,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Self { Ok(Self {
config: WorkerConfig::new( config: WorkerConfig::new(
@ -87,22 +73,16 @@ impl RSGIWorker {
opt_enabled, opt_enabled,
ssl_enabled, ssl_enabled,
ssl_cert, ssl_cert,
ssl_key ssl_key,
) ),
}) })
} }
fn serve_rth( fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
match ( match (
self.config.websockets_enabled, self.config.websockets_enabled,
self.config.ssl_enabled, self.config.ssl_enabled,
self.config.opt_enabled self.config.opt_enabled,
) { ) {
(false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx), (false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx),
@ -111,21 +91,15 @@ impl RSGIWorker {
(false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx), (false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx) (true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
} }
} }
fn serve_wth( fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
match ( match (
self.config.websockets_enabled, self.config.websockets_enabled,
self.config.ssl_enabled, self.config.ssl_enabled,
self.config.opt_enabled self.config.opt_enabled,
) { ) {
(false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx), (false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx),
(false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx),
@ -134,7 +108,7 @@ impl RSGIWorker {
(false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx), (false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
(false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx),
(true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx),
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx) (true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
} }
} }
} }

View file

@ -1,6 +1,6 @@
use hyper::{ use hyper::{
header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER}, header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
Body, Uri, Version Body, Uri, Version,
}; };
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::PyString; use pyo3::types::PyString;
@ -10,11 +10,10 @@ use tokio_util::codec::{BytesCodec, FramedRead};
use crate::http::HV_SERVER; use crate::http::HV_SERVER;
#[pyclass(module = "granian._granian")]
#[pyclass(module="granian._granian")]
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct RSGIHeaders { pub(crate) struct RSGIHeaders {
inner: HeaderMap inner: HeaderMap,
} }
impl RSGIHeaders { impl RSGIHeaders {
@ -29,7 +28,7 @@ impl RSGIHeaders {
let mut ret = Vec::with_capacity(self.inner.keys_len()); let mut ret = Vec::with_capacity(self.inner.keys_len());
for key in self.inner.keys() { for key in self.inner.keys() {
ret.push(key.as_str()); ret.push(key.as_str());
}; }
ret ret
} }
@ -37,15 +36,15 @@ impl RSGIHeaders {
let mut ret = Vec::with_capacity(self.inner.keys_len()); let mut ret = Vec::with_capacity(self.inner.keys_len());
for val in self.inner.values() { for val in self.inner.values() {
ret.push(val.to_str().unwrap()); ret.push(val.to_str().unwrap());
}; }
Ok(ret) Ok(ret)
} }
fn items(&self) -> PyResult<Vec<(&str, &str)>> { fn items(&self) -> PyResult<Vec<(&str, &str)>> {
let mut ret = Vec::with_capacity(self.inner.keys_len()); let mut ret = Vec::with_capacity(self.inner.keys_len());
for (key, val) in self.inner.iter() { for (key, val) in &self.inner {
ret.push((key.as_str(), val.to_str().unwrap())); ret.push((key.as_str(), val.to_str().unwrap()));
}; }
Ok(ret) Ok(ret)
} }
@ -56,18 +55,16 @@ impl RSGIHeaders {
#[pyo3(signature = (key, default=None))] #[pyo3(signature = (key, default=None))]
fn get(&self, py: Python, key: &str, default: Option<PyObject>) -> Option<PyObject> { fn get(&self, py: Python, key: &str, default: Option<PyObject>) -> Option<PyObject> {
match self.inner.get(key) { match self.inner.get(key) {
Some(val) => { Some(val) => match val.to_str() {
match val.to_str() { Ok(string) => Some(PyString::new(py, string).into()),
Ok(string) => Some(PyString::new(py, string).into()), _ => default,
_ => default
}
}, },
_ => default _ => default,
} }
} }
} }
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct RSGIScope { pub(crate) struct RSGIScope {
#[pyo3(get)] #[pyo3(get)]
proto: String, proto: String,
@ -84,7 +81,7 @@ pub(crate) struct RSGIScope {
#[pyo3(get)] #[pyo3(get)]
client: String, client: String,
#[pyo3(get)] #[pyo3(get)]
headers: RSGIHeaders headers: RSGIHeaders,
} }
impl RSGIScope { impl RSGIScope {
@ -96,23 +93,23 @@ impl RSGIScope {
method: &str, method: &str,
server: SocketAddr, server: SocketAddr,
client: SocketAddr, client: SocketAddr,
headers: &HeaderMap headers: &HeaderMap,
) -> Self { ) -> Self {
Self { Self {
proto: proto.to_string(), proto: proto.to_string(),
http_version: http_version, http_version,
rsgi_version: "1.2".to_string(), rsgi_version: "1.2".to_string(),
scheme: scheme.to_string(), scheme: scheme.to_string(),
method: method.to_string(), method: method.to_string(),
uri: uri, uri,
server: server.to_string(), server: server.to_string(),
client: client.to_string(), client: client.to_string(),
headers: RSGIHeaders::new(headers) headers: RSGIHeaders::new(headers),
} }
} }
pub fn set_proto(&mut self, value: &str) { pub fn set_proto(&mut self, value: &str) {
self.proto = value.to_string() self.proto = value.to_string();
} }
} }
@ -125,7 +122,7 @@ impl RSGIScope {
Version::HTTP_11 => "1.1", Version::HTTP_11 => "1.1",
Version::HTTP_2 => "2", Version::HTTP_2 => "2",
Version::HTTP_3 => "3", Version::HTTP_3 => "3",
_ => "1" _ => "1",
} }
} }
@ -142,38 +139,36 @@ impl RSGIScope {
pub(crate) enum PyResponse { pub(crate) enum PyResponse {
Body(PyResponseBody), Body(PyResponseBody),
File(PyResponseFile) File(PyResponseFile),
} }
pub(crate) struct PyResponseBody { pub(crate) struct PyResponseBody {
status: u16, status: u16,
headers: Vec<(String, String)>, headers: Vec<(String, String)>,
body: Body body: Body,
} }
pub(crate) struct PyResponseFile { pub(crate) struct PyResponseFile {
status: u16, status: u16,
headers: Vec<(String, String)>, headers: Vec<(String, String)>,
file_path: String file_path: String,
} }
macro_rules! response_head_from_py { macro_rules! response_head_from_py {
($status:expr, $headers:expr, $res:expr) => { ($status:expr, $headers:expr, $res:expr) => {{
{ let mut rh = hyper::http::HeaderMap::new();
let mut rh = hyper::http::HeaderMap::new();
rh.insert(HK_SERVER, HV_SERVER); rh.insert(HK_SERVER, HV_SERVER);
for (key, value) in $headers { for (key, value) in $headers {
rh.append( rh.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(), HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(&value).unwrap() HeaderValue::from_str(&value).unwrap(),
); );
}
*$res.status_mut() = $status.try_into().unwrap();
*$res.headers_mut() = rh;
} }
}
*$res.status_mut() = $status.try_into().unwrap();
*$res.headers_mut() = rh;
}};
} }
impl PyResponseBody { impl PyResponseBody {
@ -182,18 +177,30 @@ impl PyResponseBody {
} }
pub fn empty(status: u16, headers: Vec<(String, String)>) -> Self { pub fn empty(status: u16, headers: Vec<(String, String)>) -> Self {
Self { status, headers, body: Body::empty() } Self {
status,
headers,
body: Body::empty(),
}
} }
pub fn from_bytes(status: u16, headers: Vec<(String, String)>, body: Vec<u8>) -> Self { pub fn from_bytes(status: u16, headers: Vec<(String, String)>, body: Vec<u8>) -> Self {
Self { status, headers, body: Body::from(body) } Self {
status,
headers,
body: Body::from(body),
}
} }
pub fn from_string(status: u16, headers: Vec<(String, String)>, body: String) -> Self { pub fn from_string(status: u16, headers: Vec<(String, String)>, body: String) -> Self {
Self { status, headers, body: Body::from(body) } Self {
status,
headers,
body: Body::from(body),
}
} }
pub fn to_response(self) -> hyper::Response::<Body> { pub fn to_response(self) -> hyper::Response<Body> {
let mut res = hyper::Response::<Body>::new(self.body); let mut res = hyper::Response::<Body>::new(self.body);
response_head_from_py!(self.status, &self.headers, res); response_head_from_py!(self.status, &self.headers, res);
res res
@ -202,10 +209,14 @@ impl PyResponseBody {
impl PyResponseFile { impl PyResponseFile {
pub fn new(status: u16, headers: Vec<(String, String)>, file_path: String) -> Self { pub fn new(status: u16, headers: Vec<(String, String)>, file_path: String) -> Self {
Self { status, headers, file_path } Self {
status,
headers,
file_path,
}
} }
pub async fn to_response(&self) -> hyper::Response::<Body> { pub async fn to_response(&self) -> hyper::Response<Body> {
let file = File::open(&self.file_path).await.unwrap(); let file = File::open(&self.file_path).await.unwrap();
let stream = FramedRead::new(file, BytesCodec::new()); let stream = FramedRead::new(file, BytesCodec::new());
let mut res = hyper::Response::<Body>::new(Body::wrap_stream(stream)); let mut res = hyper::Response::<Body>::new(Body::wrap_stream(stream));

View file

@ -1,12 +1,19 @@
use once_cell::unsync::OnceCell as UnsyncOnceCell; use once_cell::unsync::OnceCell as UnsyncOnceCell;
use pyo3_asyncio::TaskLocals;
use pyo3::prelude::*; use pyo3::prelude::*;
use std::{future::Future, io, pin::Pin, sync::{Arc, Mutex}}; use pyo3_asyncio::TaskLocals;
use tokio::{runtime::Builder, task::{JoinHandle, LocalSet}}; use std::{
future::Future,
io,
pin::Pin,
sync::{Arc, Mutex},
};
use tokio::{
runtime::Builder,
task::{JoinHandle, LocalSet},
};
use super::callbacks::{PyFutureAwaitable, PyIterAwaitable}; use super::callbacks::{PyFutureAwaitable, PyIterAwaitable};
tokio::task_local! { tokio::task_local! {
static TASK_LOCALS: UnsyncOnceCell<TaskLocals>; static TASK_LOCALS: UnsyncOnceCell<TaskLocals>;
} }
@ -27,11 +34,7 @@ pub trait Runtime: Send + 'static {
} }
pub trait ContextExt: Runtime { pub trait ContextExt: Runtime {
fn scope<F, R>( fn scope<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R> + Send>>
where where
F: Future<Output = R> + Send + 'static; F: Future<Output = R> + Send + 'static;
@ -45,36 +48,34 @@ pub trait SpawnLocalExt: Runtime {
} }
pub trait LocalContextExt: Runtime { pub trait LocalContextExt: Runtime {
fn scope_local<F, R>( fn scope_local<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R>>>
where where
F: Future<Output = R> + 'static; F: Future<Output = R> + 'static;
} }
pub(crate) struct RuntimeWrapper { pub(crate) struct RuntimeWrapper {
rt: tokio::runtime::Runtime rt: tokio::runtime::Runtime,
} }
impl RuntimeWrapper { impl RuntimeWrapper {
pub fn new(blocking_threads: usize) -> Self { pub fn new(blocking_threads: usize) -> Self {
Self { rt: default_runtime(blocking_threads).unwrap() } Self {
rt: default_runtime(blocking_threads).unwrap(),
}
} }
pub fn with_runtime(rt: tokio::runtime::Runtime) -> Self { pub fn with_runtime(rt: tokio::runtime::Runtime) -> Self {
Self { rt: rt } Self { rt }
} }
pub fn handler(&self) -> RuntimeRef { pub fn handler(&self) -> RuntimeRef {
RuntimeRef::new(self.rt.handle().to_owned()) RuntimeRef::new(self.rt.handle().clone())
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct RuntimeRef { pub struct RuntimeRef {
pub inner: tokio::runtime::Handle pub inner: tokio::runtime::Handle,
} }
impl RuntimeRef { impl RuntimeRef {
@ -108,11 +109,7 @@ impl Runtime for RuntimeRef {
} }
impl ContextExt for RuntimeRef { impl ContextExt for RuntimeRef {
fn scope<F, R>( fn scope<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R> + Send>>
where where
F: Future<Output = R> + Send + 'static, F: Future<Output = R> + Send + 'static,
{ {
@ -123,7 +120,7 @@ impl ContextExt for RuntimeRef {
} }
fn get_task_locals() -> Option<TaskLocals> { fn get_task_locals() -> Option<TaskLocals> {
match TASK_LOCALS.try_with(|c| c.get().map(|locals| locals.clone())) { match TASK_LOCALS.try_with(|c| c.get().cloned()) {
Ok(locals) => locals, Ok(locals) => locals,
Err(_) => None, Err(_) => None,
} }
@ -140,11 +137,7 @@ impl SpawnLocalExt for RuntimeRef {
} }
impl LocalContextExt for RuntimeRef { impl LocalContextExt for RuntimeRef {
fn scope_local<F, R>( fn scope_local<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
&self,
locals: TaskLocals,
fut: F
) -> Pin<Box<dyn Future<Output = R>>>
where where
F: Future<Output = R> + 'static, F: Future<Output = R> + 'static,
{ {
@ -169,7 +162,7 @@ pub(crate) fn init_runtime_mt(threads: usize, blocking_threads: usize) -> Runtim
.max_blocking_threads(blocking_threads) .max_blocking_threads(blocking_threads)
.enable_all() .enable_all()
.build() .build()
.unwrap() .unwrap(),
) )
} }
@ -177,12 +170,8 @@ pub(crate) fn init_runtime_st(blocking_threads: usize) -> RuntimeWrapper {
RuntimeWrapper::new(blocking_threads) RuntimeWrapper::new(blocking_threads)
} }
pub(crate) fn into_future( pub(crate) fn into_future(awaitable: &PyAny) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send> {
awaitable: &PyAny, pyo3_asyncio::into_future_with_locals(&get_current_locals::<RuntimeRef>(awaitable.py())?, awaitable)
) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send> {
pyo3_asyncio::into_future_with_locals(
&get_current_locals::<RuntimeRef>(awaitable.py())?, awaitable
)
} }
#[inline] #[inline]
@ -241,10 +230,7 @@ where
rt.spawn(async move { rt.spawn(async move {
let result = fut.await; let result = fut.await;
Python::with_gil(move |py| { Python::with_gil(move |py| {
PyFutureAwaitable::set_result( PyFutureAwaitable::set_result(py_aw.as_ref(py).borrow_mut(), result.map(|v| v.into_py(py)));
py_aw.as_ref(py).borrow_mut(),
result.map(|v| v.into_py(py))
);
}); });
}); });
@ -269,11 +255,7 @@ where
let rth = rt.handler(); let rth = rt.handler();
rt.spawn(async move { rt.spawn(async move {
let val = rth.scope( let val = rth.scope(task_locals.clone(), fut).await;
task_locals.clone(),
fut
)
.await;
if let Ok(mut result) = result_tx.lock() { if let Ok(mut result) = result_tx.lock() {
*result = Some(val.unwrap()); *result = Some(val.unwrap());
} }
@ -292,7 +274,7 @@ where
pub(crate) fn block_on_local<F>(rt: RuntimeWrapper, local: LocalSet, fut: F) pub(crate) fn block_on_local<F>(rt: RuntimeWrapper, local: LocalSet, fut: F)
where where
F: Future + 'static F: Future + 'static,
{ {
local.block_on(&rt.rt, fut); local.block_on(&rt.rt, fut);
} }

View file

@ -9,10 +9,9 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket};
use socket2::{Domain, Protocol, Socket, Type}; use socket2::{Domain, Protocol, Socket, Type};
#[pyclass(module = "granian._granian")]
#[pyclass(module="granian._granian")]
pub struct ListenerHolder { pub struct ListenerHolder {
socket: TcpListener socket: TcpListener,
} }
#[pymethods] #[pymethods]
@ -20,28 +19,19 @@ impl ListenerHolder {
#[cfg(unix)] #[cfg(unix)]
#[new] #[new]
pub fn new(fd: i32) -> PyResult<Self> { pub fn new(fd: i32) -> PyResult<Self> {
let socket = unsafe { let socket = unsafe { TcpListener::from_raw_fd(fd) };
TcpListener::from_raw_fd(fd) Ok(Self { socket })
};
Ok(Self { socket: socket })
} }
#[cfg(windows)] #[cfg(windows)]
#[new] #[new]
pub fn new(fd: u64) -> PyResult<Self> { pub fn new(fd: u64) -> PyResult<Self> {
let socket = unsafe { let socket = unsafe { TcpListener::from_raw_socket(fd) };
TcpListener::from_raw_socket(fd) Ok(Self { socket })
};
Ok(Self { socket: socket })
} }
#[classmethod] #[classmethod]
pub fn from_address( pub fn from_address(_cls: &PyType, address: &str, port: u16, backlog: i32) -> PyResult<Self> {
_cls: &PyType,
address: &str,
port: u16,
backlog: i32
) -> PyResult<Self> {
let address: SocketAddr = (address.parse::<IpAddr>()?, port).into(); let address: SocketAddr = (address.parse::<IpAddr>()?, port).into();
let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?; let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
socket.set_reuse_address(true)?; socket.set_reuse_address(true)?;
@ -54,17 +44,13 @@ impl ListenerHolder {
#[cfg(unix)] #[cfg(unix)]
pub fn __getstate__(&self, py: Python) -> PyObject { pub fn __getstate__(&self, py: Python) -> PyObject {
let fd = self.socket.as_raw_fd(); let fd = self.socket.as_raw_fd();
( (fd.into_py(py),).to_object(py)
fd.into_py(py),
).to_object(py)
} }
#[cfg(windows)] #[cfg(windows)]
pub fn __getstate__(&self, py: Python) -> PyObject { pub fn __getstate__(&self, py: Python) -> PyObject {
let fd = self.socket.as_raw_socket(); let fd = self.socket.as_raw_socket();
( (fd.into_py(py),).to_object(py)
fd.into_py(py),
).to_object(py)
} }
#[cfg(unix)] #[cfg(unix)]
@ -84,7 +70,6 @@ impl ListenerHolder {
} }
} }
pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> { pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> {
module.add_class::<ListenerHolder>()?; module.add_class::<ListenerHolder>()?;

View file

@ -1,20 +1,22 @@
use futures::stream::StreamExt; use futures::stream::StreamExt;
use hyper::server::{accept, conn::{AddrIncoming, AddrStream}}; use hyper::server::{
accept,
conn::{AddrIncoming, AddrStream},
};
use std::{fs, future, io, iter::Iterator, net::TcpListener, sync::Arc}; use std::{fs, future, io, iter::Iterator, net::TcpListener, sync::Arc};
use tls_listener::{Error as TlsError, TlsListener}; use tls_listener::{Error as TlsError, TlsListener};
use tokio_rustls::{ use tokio_rustls::{
TlsAcceptor,
rustls::{Certificate, PrivateKey, ServerConfig}, rustls::{Certificate, PrivateKey, ServerConfig},
server::TlsStream server::TlsStream,
TlsAcceptor,
}; };
pub(crate) type TlsAddrStream = TlsStream<AddrStream>; pub(crate) type TlsAddrStream = TlsStream<AddrStream>;
pub(crate) fn tls_listen( pub(crate) fn tls_listen(
config: Arc<ServerConfig>, config: Arc<ServerConfig>,
tcp: TcpListener tcp: TcpListener,
) -> impl accept::Accept<Conn=TlsAddrStream, Error=TlsError<io::Error, io::Error>> { ) -> impl accept::Accept<Conn = TlsAddrStream, Error = TlsError<io::Error, io::Error>> {
tcp.set_nonblocking(true).unwrap(); tcp.set_nonblocking(true).unwrap();
let tcp_listener = tokio::net::TcpListener::from_std(tcp).unwrap(); let tcp_listener = tokio::net::TcpListener::from_std(tcp).unwrap();
let incoming = AddrIncoming::from_listener(tcp_listener).unwrap(); let incoming = AddrIncoming::from_listener(tcp_listener).unwrap();
@ -34,30 +36,26 @@ fn tls_error(err: String) -> io::Error {
} }
pub(crate) fn load_certs(filename: &str) -> io::Result<Vec<Certificate>> { pub(crate) fn load_certs(filename: &str) -> io::Result<Vec<Certificate>> {
let certfile = fs::File::open(filename) let certfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?;
.map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = io::BufReader::new(certfile); let mut reader = io::BufReader::new(certfile);
let certs = rustls_pemfile::certs(&mut reader) let certs = rustls_pemfile::certs(&mut reader).map_err(|_| tls_error("failed to load certificate".into()))?;
.map_err(|_| tls_error("failed to load certificate".into()))?;
Ok(certs.into_iter().map(Certificate).collect()) Ok(certs.into_iter().map(Certificate).collect())
} }
pub(crate) fn load_private_key(filename: &str) -> io::Result<PrivateKey> { pub(crate) fn load_private_key(filename: &str) -> io::Result<PrivateKey> {
let keyfile = fs::File::open(filename) let keyfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?;
.map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = io::BufReader::new(keyfile); let mut reader = io::BufReader::new(keyfile);
let keys = rustls_pemfile::read_all(&mut reader) let keys = rustls_pemfile::read_all(&mut reader).map_err(|_| tls_error("failed to load private key".into()))?;
.map_err(|_| tls_error("failed to load private key".into()))?;
if keys.len() != 1 { if keys.len() != 1 {
return Err(tls_error("expected a single private key".into())); return Err(tls_error("expected a single private key".into()));
} }
let key = match &keys[0] { let key = match &keys[0] {
rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.to_vec()), rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.clone()),
rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.to_vec()), rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.clone()),
rustls_pemfile::Item::ECKey(key) => PrivateKey(key.to_vec()), rustls_pemfile::Item::ECKey(key) => PrivateKey(key.clone()),
_ => { _ => {
return Err(tls_error("failed to load private key".into())); return Err(tls_error("failed to load private key".into()));
} }

View file

@ -1,13 +1,15 @@
pub(crate) fn header_contains_value( pub(crate) fn header_contains_value(
headers: &hyper::HeaderMap, headers: &hyper::HeaderMap,
header: impl hyper::header::AsHeaderName, header: impl hyper::header::AsHeaderName,
value: impl AsRef<[u8]> value: impl AsRef<[u8]>,
) -> bool { ) -> bool {
let value = value.as_ref(); let value = value.as_ref();
for header in headers.get_all(header) { for header in headers.get_all(header) {
if header.as_bytes().split(|&c| c == b',').any( if header
|x| trim(x).eq_ignore_ascii_case(value) .as_bytes()
) { .split(|&c| c == b',')
.any(|x| trim(x).eq_ignore_ascii_case(value))
{
return true; return true;
} }
} }
@ -30,7 +32,7 @@ fn trim_start(data: &[u8]) -> &[u8] {
#[inline] #[inline]
fn trim_end(data: &[u8]) -> &[u8] { fn trim_end(data: &[u8]) -> &[u8] {
if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) { if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
&data[..last + 1] &data[..=last]
} else { } else {
b"" b""
} }

View file

@ -8,8 +8,8 @@ use std::os::windows::io::FromRawSocket;
use super::asgi::serve::ASGIWorker; use super::asgi::serve::ASGIWorker;
use super::rsgi::serve::RSGIWorker; use super::rsgi::serve::RSGIWorker;
use super::wsgi::serve::WSGIWorker;
use super::tls::{load_certs as tls_load_certs, load_private_key as tls_load_pkey}; use super::tls::{load_certs as tls_load_certs, load_private_key as tls_load_pkey};
use super::wsgi::serve::WSGIWorker;
pub(crate) struct WorkerConfig { pub(crate) struct WorkerConfig {
pub id: i32, pub id: i32,
@ -22,7 +22,7 @@ pub(crate) struct WorkerConfig {
pub opt_enabled: bool, pub opt_enabled: bool,
pub ssl_enabled: bool, pub ssl_enabled: bool,
ssl_cert: Option<String>, ssl_cert: Option<String>,
ssl_key: Option<String> ssl_key: Option<String>,
} }
impl WorkerConfig { impl WorkerConfig {
@ -37,7 +37,7 @@ impl WorkerConfig {
opt_enabled: bool, opt_enabled: bool,
ssl_enabled: bool, ssl_enabled: bool,
ssl_cert: Option<&str>, ssl_cert: Option<&str>,
ssl_key: Option<&str> ssl_key: Option<&str>,
) -> Self { ) -> Self {
Self { Self {
id, id,
@ -49,23 +49,19 @@ impl WorkerConfig {
websockets_enabled, websockets_enabled,
opt_enabled, opt_enabled,
ssl_enabled, ssl_enabled,
ssl_cert: ssl_cert.map_or(None, |v| Some(v.into())), ssl_cert: ssl_cert.map(std::convert::Into::into),
ssl_key: ssl_key.map_or(None, |v| Some(v.into())) ssl_key: ssl_key.map(std::convert::Into::into),
} }
} }
#[cfg(unix)] #[cfg(unix)]
pub fn tcp_listener(&self) -> TcpListener { pub fn tcp_listener(&self) -> TcpListener {
unsafe { unsafe { TcpListener::from_raw_fd(self.socket_fd) }
TcpListener::from_raw_fd(self.socket_fd)
}
} }
#[cfg(windows)] #[cfg(windows)]
pub fn tcp_listener(&self) -> TcpListener { pub fn tcp_listener(&self) -> TcpListener {
unsafe { unsafe { TcpListener::from_raw_socket(self.socket_fd as u64) }
TcpListener::from_raw_socket(self.socket_fd as u64)
}
} }
pub fn tls_cfg(&self) -> tokio_rustls::rustls::ServerConfig { pub fn tls_cfg(&self) -> tokio_rustls::rustls::ServerConfig {
@ -74,13 +70,13 @@ impl WorkerConfig {
.with_no_client_auth() .with_no_client_auth()
.with_single_cert( .with_single_cert(
tls_load_certs(&self.ssl_cert.clone().unwrap()[..]).unwrap(), tls_load_certs(&self.ssl_cert.clone().unwrap()[..]).unwrap(),
tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap() tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap(),
) )
.unwrap(); .unwrap();
cfg.alpn_protocols = match &self.http_mode[..] { cfg.alpn_protocols = match &self.http_mode[..] {
"1" => vec![b"http/1.1".to_vec()], "1" => vec![b"http/1.1".to_vec()],
"2" => vec![b"h2".to_vec()], "2" => vec![b"h2".to_vec()],
_ => vec![b"h2".to_vec(), b"http/1.1".to_vec()] _ => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
}; };
cfg cfg
} }
@ -102,7 +98,7 @@ pub(crate) struct WorkerExecutor;
impl<F> hyper::rt::Executor<F> for WorkerExecutor impl<F> hyper::rt::Executor<F> for WorkerExecutor
where where
F: std::future::Future + 'static F: std::future::Future + 'static,
{ {
fn execute(&self, fut: F) { fn execute(&self, fut: F) {
tokio::task::spawn_local(fut); tokio::task::spawn_local(fut);
@ -123,14 +119,9 @@ macro_rules! build_service {
let rth = rth.clone(); let rth = rth.clone();
async move { async move {
Ok::<_, std::convert::Infallible>($target( Ok::<_, std::convert::Infallible>(
rth, $target(rth, callback_wrapper, local_addr, remote_addr, req, "http").await,
callback_wrapper, )
local_addr,
remote_addr,
req,
"http"
).await)
} }
})) }))
} }
@ -153,14 +144,9 @@ macro_rules! build_service_ssl {
let rth = rth.clone(); let rth = rth.clone();
async move { async move {
Ok::<_, std::convert::Infallible>($target( Ok::<_, std::convert::Infallible>(
rth, $target(rth, callback_wrapper, local_addr, remote_addr, req, "https").await,
callback_wrapper, )
local_addr,
remote_addr,
req,
"https"
).await)
} }
})) }))
} }
@ -170,13 +156,7 @@ macro_rules! build_service_ssl {
macro_rules! serve_rth { macro_rules! serve_rth {
($func_name:ident, $target:expr) => { ($func_name:ident, $target:expr) => {
fn $func_name( fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
pyo3_log::init(); pyo3_log::init();
let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads); let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads);
let rth = rt.handler(); let rth = rt.handler();
@ -184,34 +164,30 @@ macro_rules! serve_rth {
let http1_only = self.config.http_mode == "1"; let http1_only = self.config.http_mode == "1";
let http2_only = self.config.http_mode == "2"; let http2_only = self.config.http_mode == "2";
let http1_buffer_max = self.config.http1_buffer_max.clone(); let http1_buffer_max = self.config.http1_buffer_max.clone();
let callback_wrapper = crate::callbacks::CallbackWrapper::new( let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
callback, event_loop, context
);
let worker_id = self.config.id; let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id); log::info!("Started worker-{}", worker_id);
let svc_loop = crate::runtime::run_until_complete( let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move {
rt.handler(), let service = crate::workers::build_service!(callback_wrapper, rth, $target);
event_loop, let server = hyper::Server::from_tcp(tcp_listener)
async move { .unwrap()
let service = crate::workers::build_service!( .http1_only(http1_only)
callback_wrapper, rth, $target .http2_only(http2_only)
); .http1_max_buf_size(http1_buffer_max)
let server = hyper::Server::from_tcp(tcp_listener).unwrap() .serve(service);
.http1_only(http1_only) server
.http2_only(http2_only) .with_graceful_shutdown(async move {
.http1_max_buf_size(http1_buffer_max) Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
.serve(service); .await
server.with_graceful_shutdown(async move { .unwrap();
Python::with_gil(|py| { })
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap() .await
}).await.unwrap(); .unwrap();
}).await.unwrap(); log::info!("Stopping worker-{}", worker_id);
log::info!("Stopping worker-{}", worker_id); Ok(())
Ok(()) });
}
);
match svc_loop { match svc_loop {
Ok(_) => {} Ok(_) => {}
@ -226,13 +202,7 @@ macro_rules! serve_rth {
macro_rules! serve_rth_ssl { macro_rules! serve_rth_ssl {
($func_name:ident, $target:expr) => { ($func_name:ident, $target:expr) => {
fn $func_name( fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
pyo3_log::init(); pyo3_log::init();
let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads); let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads);
let rth = rt.handler(); let rth = rt.handler();
@ -241,38 +211,29 @@ macro_rules! serve_rth_ssl {
let http2_only = self.config.http_mode == "2"; let http2_only = self.config.http_mode == "2";
let http1_buffer_max = self.config.http1_buffer_max.clone(); let http1_buffer_max = self.config.http1_buffer_max.clone();
let tls_cfg = self.config.tls_cfg(); let tls_cfg = self.config.tls_cfg();
let callback_wrapper = crate::callbacks::CallbackWrapper::new( let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
callback, event_loop, context
);
let worker_id = self.config.id; let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id); log::info!("Started worker-{}", worker_id);
let svc_loop = crate::runtime::run_until_complete( let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move {
rt.handler(), let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target);
event_loop, let server = hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener))
async move { .http1_only(http1_only)
let service = crate::workers::build_service_ssl!( .http2_only(http2_only)
callback_wrapper, rth, $target .http1_max_buf_size(http1_buffer_max)
); .serve(service);
let server = hyper::Server::builder( server
crate::tls::tls_listen( .with_graceful_shutdown(async move {
std::sync::Arc::new(tls_cfg), tcp_listener Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
) .await
) .unwrap();
.http1_only(http1_only) })
.http2_only(http2_only) .await
.http1_max_buf_size(http1_buffer_max) .unwrap();
.serve(service); log::info!("Stopping worker-{}", worker_id);
server.with_graceful_shutdown(async move { Ok(())
Python::with_gil(|py| { });
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
}).await.unwrap();
}).await.unwrap();
log::info!("Stopping worker-{}", worker_id);
Ok(())
}
);
match svc_loop { match svc_loop {
Ok(_) => {} Ok(_) => {}
@ -287,22 +248,14 @@ macro_rules! serve_rth_ssl {
macro_rules! serve_wth { macro_rules! serve_wth {
($func_name: ident, $target:expr) => { ($func_name: ident, $target:expr) => {
fn $func_name( fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
pyo3_log::init(); pyo3_log::init();
let rtm = crate::runtime::init_runtime_mt(1, 1); let rtm = crate::runtime::init_runtime_mt(1, 1);
let worker_id = self.config.id; let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id); log::info!("Started worker-{}", worker_id);
let callback_wrapper = crate::callbacks::CallbackWrapper::new( let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
callback, event_loop, context
);
let mut workers = vec![]; let mut workers = vec![];
let (stx, srx) = tokio::sync::watch::channel(false); let (stx, srx) = tokio::sync::watch::channel(false);
@ -323,38 +276,36 @@ macro_rules! serve_wth {
let local = tokio::task::LocalSet::new(); let local = tokio::task::LocalSet::new();
crate::runtime::block_on_local(rt, local, async move { crate::runtime::block_on_local(rt, local, async move {
let service = crate::workers::build_service!( let service = crate::workers::build_service!(callback_wrapper, rth, $target);
callback_wrapper, rth, $target let server = hyper::Server::from_tcp(tcp_listener)
); .unwrap()
let server = hyper::Server::from_tcp(tcp_listener).unwrap()
.executor(crate::workers::WorkerExecutor) .executor(crate::workers::WorkerExecutor)
.http1_only(http1_only) .http1_only(http1_only)
.http2_only(http2_only) .http2_only(http2_only)
.http1_max_buf_size(http1_buffer_max) .http1_max_buf_size(http1_buffer_max)
.serve(service); .serve(service);
server.with_graceful_shutdown(async move { server
srx.changed().await.unwrap(); .with_graceful_shutdown(async move {
}).await.unwrap(); srx.changed().await.unwrap();
})
.await
.unwrap();
log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1); log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1);
}); });
})); }));
}; }
let main_loop = crate::runtime::run_until_complete( let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move {
rtm.handler(), Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
event_loop, .await
async move { .unwrap();
Python::with_gil(|py| { stx.send(true).unwrap();
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap() log::info!("Stopping worker-{}", worker_id);
}).await.unwrap(); while let Some(worker) = workers.pop() {
stx.send(true).unwrap(); worker.join().unwrap();
log::info!("Stopping worker-{}", worker_id);
while let Some(worker) = workers.pop() {
worker.join().unwrap();
}
Ok(())
} }
); Ok(())
});
match main_loop { match main_loop {
Ok(_) => {} Ok(_) => {}
@ -369,22 +320,14 @@ macro_rules! serve_wth {
macro_rules! serve_wth_ssl { macro_rules! serve_wth_ssl {
($func_name: ident, $target:expr) => { ($func_name: ident, $target:expr) => {
fn $func_name( fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
pyo3_log::init(); pyo3_log::init();
let rtm = crate::runtime::init_runtime_mt(1, 1); let rtm = crate::runtime::init_runtime_mt(1, 1);
let worker_id = self.config.id; let worker_id = self.config.id;
log::info!("Started worker-{}", worker_id); log::info!("Started worker-{}", worker_id);
let callback_wrapper = crate::callbacks::CallbackWrapper::new( let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
callback, event_loop, context
);
let mut workers = vec![]; let mut workers = vec![];
let (stx, srx) = tokio::sync::watch::channel(false); let (stx, srx) = tokio::sync::watch::channel(false);
@ -406,42 +349,36 @@ macro_rules! serve_wth_ssl {
let local = tokio::task::LocalSet::new(); let local = tokio::task::LocalSet::new();
crate::runtime::block_on_local(rt, local, async move { crate::runtime::block_on_local(rt, local, async move {
let service = crate::workers::build_service_ssl!( let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target);
callback_wrapper, rth, $target let server =
); hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener))
let server = hyper::Server::builder( .executor(crate::workers::WorkerExecutor)
crate::tls::tls_listen( .http1_only(http1_only)
std::sync::Arc::new(tls_cfg), tcp_listener .http2_only(http2_only)
) .http1_max_buf_size(http1_buffer_max)
) .serve(service);
.executor(crate::workers::WorkerExecutor) server
.http1_only(http1_only) .with_graceful_shutdown(async move {
.http2_only(http2_only) srx.changed().await.unwrap();
.http1_max_buf_size(http1_buffer_max) })
.serve(service); .await
server.with_graceful_shutdown(async move { .unwrap();
srx.changed().await.unwrap();
}).await.unwrap();
log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1); log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1);
}); });
})); }));
}; }
let main_loop = crate::runtime::run_until_complete( let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move {
rtm.handler(), Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
event_loop, .await
async move { .unwrap();
Python::with_gil(|py| { stx.send(true).unwrap();
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap() log::info!("Stopping worker-{}", worker_id);
}).await.unwrap(); while let Some(worker) = workers.pop() {
stx.send(true).unwrap(); worker.join().unwrap();
log::info!("Stopping worker-{}", worker_id);
while let Some(worker) = workers.pop() {
worker.join().unwrap();
}
Ok(())
} }
); Ok(())
});
match main_loop { match main_loop {
Ok(_) => {} Ok(_) => {}
@ -457,8 +394,8 @@ macro_rules! serve_wth_ssl {
pub(crate) use build_service; pub(crate) use build_service;
pub(crate) use build_service_ssl; pub(crate) use build_service_ssl;
pub(crate) use serve_rth; pub(crate) use serve_rth;
pub(crate) use serve_wth;
pub(crate) use serve_rth_ssl; pub(crate) use serve_rth_ssl;
pub(crate) use serve_wth;
pub(crate) use serve_wth_ssl; pub(crate) use serve_wth_ssl;
pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> { pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> {

View file

@ -1,24 +1,24 @@
use hyper::{ use hyper::{
Body,
Request,
Response,
StatusCode,
header::{CONNECTION, UPGRADE}, header::{CONNECTION, UPGRADE},
http::response::Builder http::response::Builder,
Body, Request, Response, StatusCode,
}; };
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use tokio_tungstenite::WebSocketStream;
use tungstenite::{ use tungstenite::{
error::ProtocolError, error::ProtocolError,
handshake::derive_accept_key, handshake::derive_accept_key,
protocol::{Role, WebSocketConfig} protocol::{Role, WebSocketConfig},
}; };
use pin_project::pin_project;
use std::{future::Future, pin::Pin, task::{Context, Poll}};
use tokio_tungstenite::WebSocketStream;
use tokio::sync::mpsc;
use super::utils::header_contains_value; use super::utils::header_contains_value;
#[pin_project] #[pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct HyperWebsocket { pub struct HyperWebsocket {
@ -37,15 +37,9 @@ impl Future for HyperWebsocket {
Poll::Ready(x) => x, Poll::Ready(x) => x,
}; };
let upgraded = upgraded.map_err(|_| let upgraded = upgraded.map_err(|_| tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete))?;
tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete)
)?;
let stream = WebSocketStream::from_raw_socket( let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, this.config.take());
upgraded,
Role::Server,
this.config.take(),
);
tokio::pin!(stream); tokio::pin!(stream);
match stream.as_mut().poll(cx) { match stream.as_mut().poll(cx) {
@ -58,18 +52,15 @@ impl Future for HyperWebsocket {
pub(crate) struct UpgradeData { pub(crate) struct UpgradeData {
response_builder: Option<Builder>, response_builder: Option<Builder>,
response_tx: Option<mpsc::Sender<Response<Body>>>, response_tx: Option<mpsc::Sender<Response<Body>>>,
pub consumed: bool pub consumed: bool,
} }
impl UpgradeData { impl UpgradeData {
pub fn new( pub fn new(response_builder: Builder, response_tx: mpsc::Sender<Response<Body>>) -> Self {
response_builder: Builder,
response_tx: mpsc::Sender<Response<Body>>)
-> Self {
Self { Self {
response_builder: Some(response_builder), response_builder: Some(response_builder),
response_tx: Some(response_tx), response_tx: Some(response_tx),
consumed: false consumed: false,
} }
} }
@ -79,19 +70,16 @@ impl UpgradeData {
Ok(_) => { Ok(_) => {
self.consumed = true; self.consumed = true;
Ok(()) Ok(())
}, }
err => err err => err,
} }
} }
} }
#[inline] #[inline]
pub(crate) fn is_upgrade_request<B>(request: &Request<B>) -> bool { pub(crate) fn is_upgrade_request<B>(request: &Request<B>) -> bool {
header_contains_value( header_contains_value(request.headers(), CONNECTION, "Upgrade")
request.headers(), CONNECTION, "Upgrade" && header_contains_value(request.headers(), UPGRADE, "websocket")
) && header_contains_value(
request.headers(), UPGRADE, "websocket"
)
} }
pub(crate) fn upgrade_intent<B>( pub(crate) fn upgrade_intent<B>(
@ -100,13 +88,17 @@ pub(crate) fn upgrade_intent<B>(
) -> Result<(Builder, HyperWebsocket), ProtocolError> { ) -> Result<(Builder, HyperWebsocket), ProtocolError> {
let request = request.borrow_mut(); let request = request.borrow_mut();
let key = request.headers() let key = request
.headers()
.get("Sec-WebSocket-Key") .get("Sec-WebSocket-Key")
.ok_or(ProtocolError::MissingSecWebSocketKey)?; .ok_or(ProtocolError::MissingSecWebSocketKey)?;
if request.headers().get("Sec-WebSocket-Version").map( if request
|v| v.as_bytes() .headers()
) != Some(b"13") { .get("Sec-WebSocket-Version")
.map(hyper::http::HeaderValue::as_bytes)
!= Some(b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader); return Err(ProtocolError::MissingSecWebSocketVersionHeader);
} }

View file

@ -2,44 +2,35 @@ use hyper::Body;
use pyo3::prelude::*; use pyo3::prelude::*;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use super::types::{WSGIResponseBodyIter, WSGIScope as Scope};
use crate::callbacks::CallbackWrapper; use crate::callbacks::CallbackWrapper;
use super::types::{WSGIScope as Scope, WSGIResponseBodyIter};
const WSGI_LIST_RESPONSE_BODY: i32 = 0; const WSGI_LIST_RESPONSE_BODY: i32 = 0;
const WSGI_ITER_RESPONSE_BODY: i32 = 1; const WSGI_ITER_RESPONSE_BODY: i32 = 1;
#[inline(always)] #[inline(always)]
fn run_callback( fn run_callback(callback: PyObject, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> {
callback: PyObject,
scope: Scope
) -> PyResult<(i32, Vec<(String, String)>, Body)> {
Python::with_gil(|py| { Python::with_gil(|py| {
let (status, headers, body_type, pybody) = callback.call1(py, (scope,))? let (status, headers, body_type, pybody) =
.extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?; callback
.call1(py, (scope,))?
.extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?;
let body = match body_type { let body = match body_type {
WSGI_LIST_RESPONSE_BODY => Body::from(pybody.extract::<Vec<u8>>(py)?), WSGI_LIST_RESPONSE_BODY => Body::from(pybody.extract::<Vec<u8>>(py)?),
WSGI_ITER_RESPONSE_BODY => Body::wrap_stream(WSGIResponseBodyIter::new(pybody)), WSGI_ITER_RESPONSE_BODY => Body::wrap_stream(WSGIResponseBodyIter::new(pybody)),
_ => Body::empty() _ => Body::empty(),
}; };
Ok((status, headers, body)) Ok((status, headers, body))
}) })
} }
pub(crate) fn call_rtb_http( pub(crate) fn call_rtb_http(cb: CallbackWrapper, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> {
cb: CallbackWrapper, run_callback(cb.callback, scope)
scope: Scope
) -> PyResult<(i32, Vec<(String, String)>, Body)> {
run_callback(cb.callback.clone(), scope)
} }
pub(crate) fn call_rtt_http( pub(crate) fn call_rtt_http(
cb: CallbackWrapper, cb: CallbackWrapper,
scope: Scope scope: Scope,
) -> JoinHandle<PyResult<(i32, Vec<(String, String)>, Body)>> { ) -> JoinHandle<PyResult<(i32, Vec<(String, String)>, Body)>> {
let callback = cb.callback.clone(); tokio::task::spawn_blocking(move || run_callback(cb.callback, scope))
tokio::task::spawn_blocking(move || {
run_callback(callback, scope)
})
} }

View file

@ -1,21 +1,18 @@
use hyper::{ use hyper::{
Body, header::{HeaderName, HeaderValue, SERVER as HK_SERVER},
Request, Body, Request, Response,
Response,
header::{SERVER as HK_SERVER, HeaderName, HeaderValue}
}; };
use std::net::SocketAddr; use std::net::SocketAddr;
use crate::{
callbacks::CallbackWrapper,
http::{HV_SERVER, response_500},
runtime::RuntimeRef,
};
use super::{ use super::{
callbacks::{call_rtb_http, call_rtt_http}, callbacks::{call_rtb_http, call_rtt_http},
types::WSGIScope as Scope types::WSGIScope as Scope,
};
use crate::{
callbacks::CallbackWrapper,
http::{response_500, HV_SERVER},
runtime::RuntimeRef,
}; };
#[inline(always)] #[inline(always)]
fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) -> Response<Body> { fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) -> Response<Body> {
@ -26,7 +23,7 @@ fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) ->
for (key, val) in pyheaders { for (key, val) in pyheaders {
headers.append( headers.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(), HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(&val).unwrap() HeaderValue::from_str(&val).unwrap(),
); );
} }
res res
@ -38,14 +35,11 @@ pub(crate) async fn handle_rtt(
server_addr: SocketAddr, server_addr: SocketAddr,
client_addr: SocketAddr, client_addr: SocketAddr,
req: Request<Body>, req: Request<Body>,
scheme: &str scheme: &str,
) -> Response<Body> { ) -> Response<Body> {
if let Ok(res) = call_rtt_http( if let Ok(res) = call_rtt_http(callback, Scope::new(scheme, server_addr, client_addr, req).await).await {
callback,
Scope::new(scheme, server_addr, client_addr, req).await
).await {
if let Ok((status, headers, body)) = res { if let Ok((status, headers, body)) = res {
return build_response(status, headers, body) return build_response(status, headers, body);
} }
log::warn!("Application callable raised an exception"); log::warn!("Application callable raised an exception");
} else { } else {
@ -60,12 +54,9 @@ pub(crate) async fn handle_rtb(
server_addr: SocketAddr, server_addr: SocketAddr,
client_addr: SocketAddr, client_addr: SocketAddr,
req: Request<Body>, req: Request<Body>,
scheme: &str scheme: &str,
) -> Response<Body> { ) -> Response<Body> {
match call_rtb_http( match call_rtb_http(callback, Scope::new(scheme, server_addr, client_addr, req).await) {
callback,
Scope::new(scheme, server_addr, client_addr, req).await
) {
Ok((status, headers, body)) => build_response(status, headers, body), Ok((status, headers, body)) => build_response(status, headers, body),
_ => { _ => {
log::warn!("Application callable raised an exception"); log::warn!("Application callable raised an exception");

View file

@ -1,17 +1,11 @@
use pyo3::prelude::*; use pyo3::prelude::*;
use crate::workers::{
WorkerConfig,
serve_rth,
serve_wth,
serve_rth_ssl,
serve_wth_ssl
};
use super::http::{handle_rtb, handle_rtt}; use super::http::{handle_rtb, handle_rtt};
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
#[pyclass(module="granian._granian")] #[pyclass(module = "granian._granian")]
pub struct WSGIWorker { pub struct WSGIWorker {
config: WorkerConfig config: WorkerConfig,
} }
impl WSGIWorker { impl WSGIWorker {
@ -46,7 +40,7 @@ impl WSGIWorker {
http1_buffer_max: usize, http1_buffer_max: usize,
ssl_enabled: bool, ssl_enabled: bool,
ssl_cert: Option<&str>, ssl_cert: Option<&str>,
ssl_key: Option<&str> ssl_key: Option<&str>,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Self { Ok(Self {
config: WorkerConfig::new( config: WorkerConfig::new(
@ -60,34 +54,22 @@ impl WSGIWorker {
true, true,
ssl_enabled, ssl_enabled,
ssl_cert, ssl_cert,
ssl_key ssl_key,
) ),
}) })
} }
fn serve_rth( fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
match self.config.ssl_enabled { match self.config.ssl_enabled {
false => self._serve_rth(callback, event_loop, context, signal_rx), false => self._serve_rth(callback, event_loop, context, signal_rx),
true => self._serve_rth_ssl(callback, event_loop, context, signal_rx) true => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
} }
} }
fn serve_wth( fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
&self,
callback: PyObject,
event_loop: &PyAny,
context: &PyAny,
signal_rx: PyObject
) {
match self.config.ssl_enabled { match self.config.ssl_enabled {
false => self._serve_wth(callback, event_loop, context, signal_rx), false => self._serve_wth(callback, event_loop, context, signal_rx),
true => self._serve_wth_ssl(callback, event_loop, context, signal_rx) true => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
} }
} }
} }

View file

@ -1,23 +1,21 @@
use futures::Stream; use futures::Stream;
use hyper::{ use hyper::{
body::Bytes, body::Bytes,
header::{CONTENT_TYPE, CONTENT_LENGTH, HeaderMap}, header::{HeaderMap, CONTENT_LENGTH, CONTENT_TYPE},
Body, Body, Method, Request, Uri, Version,
Method,
Request,
Uri,
Version
}; };
use pyo3::{prelude::*, types::IntoPyDict};
use pyo3::types::{PyBytes, PyDict, PyList}; use pyo3::types::{PyBytes, PyDict, PyList};
use std::{net::{IpAddr, SocketAddr}, task::{Context, Poll}}; use pyo3::{prelude::*, types::IntoPyDict};
use std::{
net::{IpAddr, SocketAddr},
task::{Context, Poll},
};
const LINE_SPLIT: u8 = u8::from_be_bytes(*b"\n"); const LINE_SPLIT: u8 = u8::from_be_bytes(*b"\n");
#[pyclass(module = "granian._granian")] #[pyclass(module = "granian._granian")]
pub(crate) struct WSGIBody { pub(crate) struct WSGIBody {
inner: Bytes inner: Bytes,
} }
impl WSGIBody { impl WSGIBody {
@ -36,9 +34,9 @@ impl WSGIBody {
match self.inner.iter().position(|&c| c == LINE_SPLIT) { match self.inner.iter().position(|&c| c == LINE_SPLIT) {
Some(next_split) => { Some(next_split) => {
let bytes = self.inner.split_to(next_split); let bytes = self.inner.split_to(next_split);
Some(PyBytes::new(py, &bytes[..])) Some(PyBytes::new(py, &bytes))
}, }
_ => None _ => None,
} }
} }
@ -48,18 +46,16 @@ impl WSGIBody {
None => { None => {
let bytes = self.inner.split_to(self.inner.len()); let bytes = self.inner.split_to(self.inner.len());
PyBytes::new(py, &bytes[..]) PyBytes::new(py, &bytes[..])
},
Some(size) => {
match size {
0 => PyBytes::new(py, b""),
size => {
let limit = self.inner.len();
let rsize = if size > limit { limit } else { size };
let bytes = self.inner.split_to(rsize);
PyBytes::new(py, &bytes[..])
}
}
} }
Some(size) => match size {
0 => PyBytes::new(py, b""),
size => {
let limit = self.inner.len();
let rsize = if size > limit { limit } else { size };
let bytes = self.inner.split_to(rsize);
PyBytes::new(py, &bytes[..])
}
},
} }
} }
@ -69,16 +65,17 @@ impl WSGIBody {
let bytes = self.inner.split_to(next_split); let bytes = self.inner.split_to(next_split);
self.inner = self.inner.slice(1..); self.inner = self.inner.slice(1..);
PyBytes::new(py, &bytes[..]) PyBytes::new(py, &bytes[..])
}, }
_ => PyBytes::new(py, b"") _ => PyBytes::new(py, b""),
} }
} }
#[pyo3(signature = (_hint=None))] #[pyo3(signature = (_hint=None))]
fn readlines<'p>(&mut self, py: Python<'p>, _hint: Option<PyObject>) -> &'p PyList { fn readlines<'p>(&mut self, py: Python<'p>, _hint: Option<PyObject>) -> &'p PyList {
let lines: Vec<&PyBytes> = self.inner let lines: Vec<&PyBytes> = self
.inner
.split(|&c| c == LINE_SPLIT) .split(|&c| c == LINE_SPLIT)
.map(|item| PyBytes::new(py, &item[..])) .map(|item| PyBytes::new(py, item))
.collect(); .collect();
self.inner.clear(); self.inner.clear();
PyList::new(py, lines) PyList::new(py, lines)
@ -95,28 +92,19 @@ pub(crate) struct WSGIScope {
server_port: u16, server_port: u16,
client: String, client: String,
headers: HeaderMap, headers: HeaderMap,
body: Bytes body: Bytes,
} }
impl WSGIScope { impl WSGIScope {
pub async fn new( pub async fn new(scheme: &str, server: SocketAddr, client: SocketAddr, request: Request<Body>) -> Self {
scheme: &str,
server: SocketAddr,
client: SocketAddr,
request: Request<Body>,
) -> Self {
let http_version = request.version(); let http_version = request.version();
let method = request.method().to_owned(); let method = request.method().clone();
let uri = request.uri().to_owned(); let uri = request.uri().clone();
let headers = request.headers().to_owned(); let headers = request.headers().clone();
let body = match method { let body = match method {
Method::HEAD | Method::GET | Method::OPTIONS => { Bytes::new() }, Method::HEAD | Method::GET | Method::OPTIONS => Bytes::new(),
_ => { _ => hyper::body::to_bytes(request).await.unwrap_or(Bytes::new()),
hyper::body::to_bytes(request)
.await
.unwrap_or(Bytes::new())
}
}; };
Self { Self {
@ -128,7 +116,7 @@ impl WSGIScope {
server_port: server.port(), server_port: server.port(),
client: client.to_string(), client: client.to_string(),
headers, headers,
body body,
} }
} }
@ -138,7 +126,7 @@ impl WSGIScope {
Version::HTTP_10 => "HTTP/1", Version::HTTP_10 => "HTTP/1",
Version::HTTP_11 => "HTTP/1.1", Version::HTTP_11 => "HTTP/1.1",
Version::HTTP_2 => "HTTP/2", Version::HTTP_2 => "HTTP/2",
_ => "HTTP/1" _ => "HTTP/1",
} }
} }
} }
@ -157,21 +145,21 @@ impl WSGIScope {
content_type, content_type,
content_len, content_len,
headers, headers,
body body,
) = py.allow_threads(|| { ) = py.allow_threads(|| {
let (path, query_string) = self.uri.path_and_query() let (path, query_string) = self
.uri
.path_and_query()
.map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or("")));
let content_type = self.headers.remove(CONTENT_TYPE); let content_type = self.headers.remove(CONTENT_TYPE);
let content_len = self.headers.remove(CONTENT_LENGTH); let content_len = self.headers.remove(CONTENT_LENGTH);
let mut headers = Vec::with_capacity(self.headers.len()); let mut headers = Vec::with_capacity(self.headers.len());
for (key, val) in self.headers.iter() { for (key, val) in &self.headers {
headers.push( headers.push((
( format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()),
format!("HTTP_{}", key.as_str().replace("-", "_").to_uppercase()), val.to_str().unwrap_or_default(),
val.to_str().unwrap_or_default() ));
)
);
} }
( (
@ -185,7 +173,7 @@ impl WSGIScope {
content_type, content_type,
content_len, content_len,
headers, headers,
WSGIBody::new(self.body.to_owned()) WSGIBody::new(self.body.clone()),
) )
}); });
@ -202,13 +190,13 @@ impl WSGIScope {
if let Some(content_type) = content_type { if let Some(content_type) = content_type {
ret.set_item( ret.set_item(
pyo3::intern!(py, "CONTENT_TYPE"), pyo3::intern!(py, "CONTENT_TYPE"),
content_type.to_str().unwrap_or_default() content_type.to_str().unwrap_or_default(),
)?; )?;
} }
if let Some(content_len) = content_len { if let Some(content_len) = content_len {
ret.set_item( ret.set_item(
pyo3::intern!(py, "CONTENT_LENGTH"), pyo3::intern!(py, "CONTENT_LENGTH"),
content_len.to_str().unwrap_or_default() content_len.to_str().unwrap_or_default(),
)?; )?;
} }
@ -219,7 +207,7 @@ impl WSGIScope {
} }
pub(crate) struct WSGIResponseBodyIter { pub(crate) struct WSGIResponseBodyIter {
inner: PyObject inner: PyObject,
} }
impl WSGIResponseBodyIter { impl WSGIResponseBodyIter {
@ -235,27 +223,20 @@ impl WSGIResponseBodyIter {
impl Stream for WSGIResponseBodyIter { impl Stream for WSGIResponseBodyIter {
type Item = PyResult<Vec<u8>>; type Item = PyResult<Vec<u8>>;
fn poll_next( fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self: std::pin::Pin<&mut Self>, Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) {
_cx: &mut Context<'_> Ok(chunk_obj) => match chunk_obj.extract::<Vec<u8>>(py) {
) -> Poll<Option<Self::Item>> { Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
Python::with_gil(|py| { _ => {
match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) { self.close_inner(py);
Ok(chunk_obj) => {
match chunk_obj.extract::<Vec<u8>>(py) {
Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
_ => {
self.close_inner(py);
Poll::Ready(None)
}
}
},
Err(err) => {
if err.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
self.close_inner(py);
}
Poll::Ready(None) Poll::Ready(None)
} }
},
Err(err) => {
if err.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
self.close_inner(py);
}
Poll::Ready(None)
} }
}) })
} }

View file

@ -1,55 +1,45 @@
import json import json
PLAINTEXT_RESPONSE = { PLAINTEXT_RESPONSE = {
'type': 'http.response.start', 'type': 'http.response.start',
'status': 200, 'status': 200,
'headers': [ 'headers': [[b'content-type', b'text/plain; charset=utf-8']],
[b'content-type', b'text/plain; charset=utf-8'],
]
}
JSON_RESPONSE = {
'type': 'http.response.start',
'status': 200,
'headers': [
[b'content-type', b'application/json'],
]
} }
JSON_RESPONSE = {'type': 'http.response.start', 'status': 200, 'headers': [[b'content-type', b'application/json']]}
async def info(scope, receive, send): async def info(scope, receive, send):
await send(JSON_RESPONSE) await send(JSON_RESPONSE)
await send({ await send(
'type': 'http.response.body', {
'body': json.dumps({ 'type': 'http.response.body',
'type': scope['type'], 'body': json.dumps(
'asgi': scope['asgi'], {
'http_version': scope['http_version'], 'type': scope['type'],
'scheme': scope['scheme'], 'asgi': scope['asgi'],
'method': scope['method'], 'http_version': scope['http_version'],
'path': scope['path'], 'scheme': scope['scheme'],
'query_string': scope['query_string'].decode("latin-1"), 'method': scope['method'],
'headers': { 'path': scope['path'],
k.decode("utf8"): v.decode("utf8") 'query_string': scope['query_string'].decode('latin-1'),
for k, v in scope['headers'] 'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
} }
}).encode("utf8"), ).encode('utf8'),
'more_body': False 'more_body': False,
}) }
)
async def echo(scope, receive, send): async def echo(scope, receive, send):
await send(PLAINTEXT_RESPONSE) await send(PLAINTEXT_RESPONSE)
more_body = True more_body = True
body = b"" body = b''
while more_body: while more_body:
msg = await receive() msg = await receive()
more_body = msg['more_body'] more_body = msg['more_body']
body += msg['body'] body += msg['body']
await send({ await send({'type': 'http.response.body', 'body': body, 'more_body': False})
'type': 'http.response.body',
'body': body,
'more_body': False
})
async def ws_reject(scope, receive, send): async def ws_reject(scope, receive, send):
@ -58,21 +48,22 @@ async def ws_reject(scope, receive, send):
async def ws_info(scope, receive, send): async def ws_info(scope, receive, send):
await send({'type': 'websocket.accept'}) await send({'type': 'websocket.accept'})
await send({ await send(
'type': 'websocket.send', {
'text': json.dumps({ 'type': 'websocket.send',
'type': scope['type'], 'text': json.dumps(
'asgi': scope['asgi'], {
'http_version': scope['http_version'], 'type': scope['type'],
'scheme': scope['scheme'], 'asgi': scope['asgi'],
'path': scope['path'], 'http_version': scope['http_version'],
'query_string': scope['query_string'].decode("latin-1"), 'scheme': scope['scheme'],
'headers': { 'path': scope['path'],
k.decode("utf8"): v.decode("utf8") 'query_string': scope['query_string'].decode('latin-1'),
for k, v in scope['headers'] 'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
} }
}) ),
}) }
)
await send({'type': 'websocket.close'}) await send({'type': 'websocket.close'})
@ -98,10 +89,7 @@ async def ws_push(scope, receive, send):
try: try:
while True: while True:
await send({ await send({'type': 'websocket.send', 'text': 'ping'})
'type': 'websocket.send',
'text': 'ping'
})
except Exception: except Exception:
pass pass
@ -116,12 +104,12 @@ async def err_proto(scope, receive, send):
def app(scope, receive, send): def app(scope, receive, send):
return { return {
"/info": info, '/info': info,
"/echo": echo, '/echo': echo,
"/ws_reject": ws_reject, '/ws_reject': ws_reject,
"/ws_info": ws_info, '/ws_info': ws_info,
"/ws_echo": ws_echo, '/ws_echo': ws_echo,
"/ws_push": ws_push, '/ws_push': ws_push,
"/err_app": err_app, '/err_app': err_app,
"/err_proto": err_proto '/err_proto': err_proto,
}[scope['path']](scope, receive, send) }[scope['path']](scope, receive, send)

View file

@ -1,55 +1,42 @@
import json import json
from granian.rsgi import ( from granian.rsgi import HTTPProtocol, Scope, WebsocketMessageType, WebsocketProtocol
HTTPProtocol,
Scope,
WebsocketMessageType,
WebsocketProtocol
)
async def info(scope: Scope, protocol: HTTPProtocol): async def info(scope: Scope, protocol: HTTPProtocol):
protocol.response_bytes( protocol.response_bytes(
200, 200,
[('content-type', 'application/json')], [('content-type', 'application/json')],
json.dumps({ json.dumps(
'proto': scope.proto, {
'http_version': scope.http_version, 'proto': scope.proto,
'rsgi_version': scope.rsgi_version, 'http_version': scope.http_version,
'scheme': scope.scheme, 'rsgi_version': scope.rsgi_version,
'method': scope.method, 'scheme': scope.scheme,
'path': scope.path, 'method': scope.method,
'query_string': scope.query_string, 'path': scope.path,
'headers': {k: v for k, v in scope.headers.items()} 'query_string': scope.query_string,
}).encode("utf8") 'headers': dict(scope.headers.items()),
}
).encode('utf8'),
) )
async def echo(_, protocol: HTTPProtocol): async def echo(_, protocol: HTTPProtocol):
msg = await protocol() msg = await protocol()
protocol.response_bytes( protocol.response_bytes(200, [('content-type', 'text/plain; charset=utf-8')], msg)
200,
[('content-type', 'text/plain; charset=utf-8')],
msg
)
async def echo_stream(_, protocol: HTTPProtocol): async def echo_stream(_, protocol: HTTPProtocol):
trx = protocol.response_stream( trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')])
200,
[('content-type', 'text/plain; charset=utf-8')]
)
async for msg in protocol: async for msg in protocol:
await trx.send_bytes(msg) await trx.send_bytes(msg)
async def stream(_, protocol: HTTPProtocol): async def stream(_, protocol: HTTPProtocol):
trx = protocol.response_stream( trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')])
200,
[('content-type', 'text/plain; charset=utf-8')]
)
for _ in range(0, 3): for _ in range(0, 3):
await trx.send_bytes(b"test") await trx.send_bytes(b'test')
async def ws_reject(_, protocol: WebsocketProtocol): async def ws_reject(_, protocol: WebsocketProtocol):
@ -59,16 +46,20 @@ async def ws_reject(_, protocol: WebsocketProtocol):
async def ws_info(scope: Scope, protocol: WebsocketProtocol): async def ws_info(scope: Scope, protocol: WebsocketProtocol):
trx = await protocol.accept() trx = await protocol.accept()
await trx.send_str(json.dumps({ await trx.send_str(
'proto': scope.proto, json.dumps(
'http_version': scope.http_version, {
'rsgi_version': scope.rsgi_version, 'proto': scope.proto,
'scheme': scope.scheme, 'http_version': scope.http_version,
'method': scope.method, 'rsgi_version': scope.rsgi_version,
'path': scope.path, 'scheme': scope.scheme,
'query_string': scope.query_string, 'method': scope.method,
'headers': {k: v for k, v in scope.headers.items()} 'path': scope.path,
})) 'query_string': scope.query_string,
'headers': dict(scope.headers.items()),
}
)
)
while True: while True:
message = await trx.receive() message = await trx.receive()
if message.kind == WebsocketMessageType.close: if message.kind == WebsocketMessageType.close:
@ -97,7 +88,7 @@ async def ws_push(_, protocol: WebsocketProtocol):
try: try:
while True: while True:
await trx.send_str("ping") await trx.send_str('ping')
except Exception: except Exception:
pass pass
@ -110,13 +101,13 @@ async def err_app(scope: Scope, protocol: HTTPProtocol):
def app(scope, protocol): def app(scope, protocol):
return { return {
"/info": info, '/info': info,
"/echo": echo, '/echo': echo,
"/echos": echo_stream, '/echos': echo_stream,
"/stream": stream, '/stream': stream,
"/ws_reject": ws_reject, '/ws_reject': ws_reject,
"/ws_info": ws_info, '/ws_info': ws_info,
"/ws_echo": ws_echo, '/ws_echo': ws_echo,
"/ws_push": ws_push, '/ws_push': ws_push,
"/err_app": err_app '/err_app': err_app,
}[scope.path](scope, protocol) }[scope.path](scope, protocol)

View file

@ -2,37 +2,32 @@ import json
def info(environ, protocol): def info(environ, protocol):
protocol( protocol('200 OK', [('content-type', 'application/json')])
"200 OK", return [
[('content-type', 'application/json')] json.dumps(
) {
return [json.dumps({ 'scheme': environ['wsgi.url_scheme'],
'scheme': environ['wsgi.url_scheme'], 'method': environ['REQUEST_METHOD'],
'method': environ['REQUEST_METHOD'], 'path': environ['PATH_INFO'],
'path': environ["PATH_INFO"], 'query_string': environ['QUERY_STRING'],
'query_string': environ["QUERY_STRING"], 'content_length': environ['CONTENT_LENGTH'],
'content_length': environ['CONTENT_LENGTH'], 'headers': {k: v for k, v in environ.items() if k.startswith('HTTP_')},
'headers': {k: v for k, v in environ.items() if k.startswith("HTTP_")} }
}).encode("utf8")] ).encode('utf8')
]
def echo(environ, protocol): def echo(environ, protocol):
protocol( protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')])
'200 OK',
[('content-type', 'text/plain; charset=utf-8')]
)
return [environ['wsgi.input'].read()] return [environ['wsgi.input'].read()]
def iterbody(environ, protocol): def iterbody(environ, protocol):
def response(): def response():
for _ in range(0, 3): for _ in range(0, 3):
yield b"test" yield b'test'
protocol( protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')])
'200 OK',
[('content-type', 'text/plain; charset=utf-8')]
)
return response() return response()
@ -41,9 +36,6 @@ def err_app(environ, protocol):
def app(environ, protocol): def app(environ, protocol):
return { return {'/info': info, '/echo': echo, '/iterbody': iterbody, '/err_app': err_app}[environ['PATH_INFO']](
"/info": info, environ, protocol
"/echo": echo, )
"/iterbody": iterbody,
"/err_app": err_app
}[environ["PATH_INFO"]](environ, protocol)

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import os import os
import socket import socket
from contextlib import asynccontextmanager, closing from contextlib import asynccontextmanager, closing
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@ -11,19 +10,20 @@ import pytest
@asynccontextmanager @asynccontextmanager
async def _server(interface, port, threading_mode, tls=False): async def _server(interface, port, threading_mode, tls=False):
certs_path = Path.cwd() / "tests" / "fixtures" / "tls" certs_path = Path.cwd() / 'tests' / 'fixtures' / 'tls'
tls_opts = ( tls_opts = (
f"--ssl-certificate {certs_path / 'cert.pem'} " (f"--ssl-certificate {certs_path / 'cert.pem'} " f"--ssl-keyfile {certs_path / 'key.pem'} ") if tls else ''
f"--ssl-keyfile {certs_path / 'key.pem'} " )
) if tls else ""
proc = await asyncio.create_subprocess_shell( proc = await asyncio.create_subprocess_shell(
"".join([ ''.join(
f"granian --interface {interface} --port {port} ", [
f"--threads 1 --threading-mode {threading_mode} ", f'granian --interface {interface} --port {port} ',
tls_opts, f'--threads 1 --threading-mode {threading_mode} ',
f"tests.apps.{interface}:app" tls_opts,
]), f'tests.apps.{interface}:app',
env=dict(os.environ) ]
),
env=dict(os.environ),
) )
await asyncio.sleep(1) await asyncio.sleep(1)
try: try:
@ -33,7 +33,7 @@ async def _server(interface, port, threading_mode, tls=False):
await proc.wait() await proc.wait()
@pytest.fixture(scope="function") @pytest.fixture(scope='function')
def server_port(): def server_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(('localhost', 0)) sock.bind(('localhost', 0))
@ -41,26 +41,26 @@ def server_port():
return sock.getsockname()[1] return sock.getsockname()[1]
@pytest.fixture(scope="function") @pytest.fixture(scope='function')
def asgi_server(server_port): def asgi_server(server_port):
return partial(_server, "asgi", server_port) return partial(_server, 'asgi', server_port)
@pytest.fixture(scope="function") @pytest.fixture(scope='function')
def rsgi_server(server_port): def rsgi_server(server_port):
return partial(_server, "rsgi", server_port) return partial(_server, 'rsgi', server_port)
@pytest.fixture(scope="function") @pytest.fixture(scope='function')
def wsgi_server(server_port): def wsgi_server(server_port):
return partial(_server, "wsgi", server_port) return partial(_server, 'wsgi', server_port)
@pytest.fixture(scope="function") @pytest.fixture(scope='function')
def server(server_port, request): def server(server_port, request):
return partial(_server, request.param, server_port) return partial(_server, request.param, server_port)
@pytest.fixture(scope="function") @pytest.fixture(scope='function')
def server_tls(server_port, request): def server_tls(server_port, request):
return partial(_server, request.param, server_port, tls=True) return partial(_server, request.param, server_port, tls=True)

View file

@ -3,92 +3,59 @@ import pytest
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_scope(asgi_server, threading_mode): async def test_scope(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port: async with asgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/info?test=true") res = httpx.get(f'http://localhost:{port}/info?test=true')
assert res.status_code == 200 assert res.status_code == 200
assert res.headers["content-type"] == "application/json" assert res.headers['content-type'] == 'application/json'
data = res.json() data = res.json()
assert data['asgi'] == { assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'}
'version': '3.0', assert data['type'] == 'http'
'spec_version': '2.3'
}
assert data['type'] == "http"
assert data['http_version'] == '1.1' assert data['http_version'] == '1.1'
assert data['scheme'] == 'http' assert data['scheme'] == 'http'
assert data['method'] == "GET" assert data['method'] == 'GET'
assert data['path'] == '/info' assert data['path'] == '/info'
assert data['query_string'] == 'test=true' assert data['query_string'] == 'test=true'
assert data['headers']['host'] == f'localhost:{port}' assert data['headers']['host'] == f'localhost:{port}'
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_body(asgi_server, threading_mode): async def test_body(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port: async with asgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content="test") res = httpx.post(f'http://localhost:{port}/echo', content='test')
assert res.status_code == 200 assert res.status_code == 200
assert res.text == "test" assert res.text == 'test'
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_body_large(asgi_server, threading_mode): async def test_body_large(asgi_server, threading_mode):
data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)]) data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)])
async with asgi_server(threading_mode) as port: async with asgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content=data) res = httpx.post(f'http://localhost:{port}/echo', content=data)
assert res.status_code == 200 assert res.status_code == 200
assert res.text == data assert res.text == data
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_app_error(asgi_server, threading_mode): async def test_app_error(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port: async with asgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_app") res = httpx.get(f'http://localhost:{port}/err_app')
assert res.status_code == 500 assert res.status_code == 500
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_protocol_error(asgi_server, threading_mode): async def test_protocol_error(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port: async with asgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_proto") res = httpx.get(f'http://localhost:{port}/err_proto')
assert res.status_code == 500 assert res.status_code == 500

View file

@ -1,20 +1,18 @@
import httpx
import json import json
import pathlib import pathlib
import pytest
import ssl import ssl
import httpx
import pytest
import websockets import websockets
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("server_tls", ["asgi", "rsgi"], indirect=True) @pytest.mark.parametrize('server_tls', ['asgi', 'rsgi'], indirect=True)
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_http_scope(server_tls, threading_mode): async def test_http_scope(server_tls, threading_mode):
async with server_tls(threading_mode) as port: async with server_tls(threading_mode) as port:
res = httpx.get( res = httpx.get(f'https://localhost:{port}/info?test=true', verify=False)
f"https://localhost:{port}/info?test=true",
verify=False
)
assert res.status_code == 200 assert res.status_code == 200
data = res.json() data = res.json()
@ -22,17 +20,14 @@ async def test_http_scope(server_tls, threading_mode):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_asgi_ws_scope(asgi_server, threading_mode): async def test_asgi_ws_scope(asgi_server, threading_mode):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem" localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem'
ssl_context.load_verify_locations(localhost_pem) ssl_context.load_verify_locations(localhost_pem)
async with asgi_server(threading_mode, tls=True) as port: async with asgi_server(threading_mode, tls=True) as port:
async with websockets.connect( async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws:
f"wss://localhost:{port}/ws_info?test=true",
ssl=ssl_context
) as ws:
res = await ws.recv() res = await ws.recv()
data = json.loads(res) data = json.loads(res)
@ -40,17 +35,14 @@ async def test_asgi_ws_scope(asgi_server, threading_mode):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_rsgi_ws_scope(rsgi_server, threading_mode): async def test_rsgi_ws_scope(rsgi_server, threading_mode):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem" localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem'
ssl_context.load_verify_locations(localhost_pem) ssl_context.load_verify_locations(localhost_pem)
async with rsgi_server(threading_mode, tls=True) as port: async with rsgi_server(threading_mode, tls=True) as port:
async with websockets.connect( async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws:
f"wss://localhost:{port}/ws_info?test=true",
ssl=ssl_context
) as ws:
res = await ws.recv() res = await ws.recv()
data = json.loads(res) data = json.loads(res)

View file

@ -3,90 +3,60 @@ import pytest
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_scope(rsgi_server, threading_mode): async def test_scope(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port: async with rsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/info?test=true") res = httpx.get(f'http://localhost:{port}/info?test=true')
assert res.status_code == 200 assert res.status_code == 200
assert res.headers["content-type"] == "application/json" assert res.headers['content-type'] == 'application/json'
data = res.json() data = res.json()
assert data['proto'] == "http" assert data['proto'] == 'http'
assert data['http_version'] == '1.1' assert data['http_version'] == '1.1'
assert data['rsgi_version'] == '1.2' assert data['rsgi_version'] == '1.2'
assert data['scheme'] == 'http' assert data['scheme'] == 'http'
assert data['method'] == "GET" assert data['method'] == 'GET'
assert data['path'] == '/info' assert data['path'] == '/info'
assert data['query_string'] == 'test=true' assert data['query_string'] == 'test=true'
assert data['headers']['host'] == f'localhost:{port}' assert data['headers']['host'] == f'localhost:{port}'
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_body(rsgi_server, threading_mode): async def test_body(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port: async with rsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content="test") res = httpx.post(f'http://localhost:{port}/echo', content='test')
assert res.status_code == 200 assert res.status_code == 200
assert res.text == "test" assert res.text == 'test'
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_body_stream_req(rsgi_server, threading_mode): async def test_body_stream_req(rsgi_server, threading_mode):
data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)]) data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)])
async with rsgi_server(threading_mode) as port: async with rsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echos", content=data) res = httpx.post(f'http://localhost:{port}/echos', content=data)
assert res.status_code == 200 assert res.status_code == 200
assert res.text == data assert res.text == data
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_body_stream_res(rsgi_server, threading_mode): async def test_body_stream_res(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port: async with rsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/stream") res = httpx.get(f'http://localhost:{port}/stream')
assert res.status_code == 200 assert res.status_code == 200
assert res.text == "test" * 3 assert res.text == 'test' * 3
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_app_error(rsgi_server, threading_mode): async def test_app_error(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port: async with rsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_app") res = httpx.get(f'http://localhost:{port}/err_app')
assert res.status_code == 500 assert res.status_code == 500

View file

@ -1,56 +1,48 @@
import json import json
import pytest
import sys import sys
import pytest
import websockets import websockets
@pytest.mark.skipif(sys.platform == "win32", reason="skip on windows") @pytest.mark.skipif(sys.platform == 'win32', reason='skip on windows')
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True) @pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_messages(server, threading_mode): async def test_messages(server, threading_mode):
async with server(threading_mode) as port: async with server(threading_mode) as port:
async with websockets.connect(f"ws://localhost:{port}/ws_echo") as ws: async with websockets.connect(f'ws://localhost:{port}/ws_echo') as ws:
await ws.send("foo") await ws.send('foo')
res_text = await ws.recv() res_text = await ws.recv()
await ws.send(b"foo") await ws.send(b'foo')
res_bytes = await ws.recv() res_bytes = await ws.recv()
assert res_text == "foo" assert res_text == 'foo'
assert res_bytes == b"foo" assert res_bytes == b'foo'
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True) @pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_reject(server, threading_mode): async def test_reject(server, threading_mode):
async with server(threading_mode) as port: async with server(threading_mode) as port:
with pytest.raises(websockets.InvalidStatusCode) as exc: with pytest.raises(websockets.InvalidStatusCode) as exc:
async with websockets.connect(f"ws://localhost:{port}/ws_reject") as ws: async with websockets.connect(f'ws://localhost:{port}/ws_reject'):
pass pass
assert exc.value.status_code == 403 assert exc.value.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_asgi_scope(asgi_server, threading_mode): async def test_asgi_scope(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port: async with asgi_server(threading_mode) as port:
async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws: async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws:
res = await ws.recv() res = await ws.recv()
data = json.loads(res) data = json.loads(res)
assert data['asgi'] == { assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'}
'version': '3.0', assert data['type'] == 'websocket'
'spec_version': '2.3'
}
assert data['type'] == "websocket"
assert data['http_version'] == '1.1' assert data['http_version'] == '1.1'
assert data['scheme'] == 'ws' assert data['scheme'] == 'ws'
assert data['path'] == '/ws_info' assert data['path'] == '/ws_info'
@ -59,16 +51,10 @@ async def test_asgi_scope(asgi_server, threading_mode):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_rsgi_scope(rsgi_server, threading_mode): async def test_rsgi_scope(rsgi_server, threading_mode):
async with rsgi_server(threading_mode) as port: async with rsgi_server(threading_mode) as port:
async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws: async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws:
res = await ws.recv() res = await ws.recv()
data = json.loads(res) data = json.loads(res)
@ -76,7 +62,7 @@ async def test_rsgi_scope(rsgi_server, threading_mode):
assert data['http_version'] == '1.1' assert data['http_version'] == '1.1'
assert data['rsgi_version'] == '1.2' assert data['rsgi_version'] == '1.2'
assert data['scheme'] == 'http' assert data['scheme'] == 'http'
assert data['method'] == "GET" assert data['method'] == 'GET'
assert data['path'] == '/ws_info' assert data['path'] == '/ws_info'
assert data['query_string'] == 'test=true' assert data['query_string'] == 'test=true'
assert data['headers']['host'] == f'localhost:{port}' assert data['headers']['host'] == f'localhost:{port}'

View file

@ -3,24 +3,18 @@ import pytest
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_scope(wsgi_server, threading_mode): async def test_scope(wsgi_server, threading_mode):
payload = "body_payload" payload = 'body_payload'
async with wsgi_server(threading_mode) as port: async with wsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/info?test=true", content=payload) res = httpx.post(f'http://localhost:{port}/info?test=true', content=payload)
assert res.status_code == 200 assert res.status_code == 200
assert res.headers["content-type"] == "application/json" assert res.headers['content-type'] == 'application/json'
data = res.json() data = res.json()
assert data['scheme'] == 'http' assert data['scheme'] == 'http'
assert data['method'] == "POST" assert data['method'] == 'POST'
assert data['path'] == '/info' assert data['path'] == '/info'
assert data['query_string'] == 'test=true' assert data['query_string'] == 'test=true'
assert data['headers']['HTTP_HOST'] == f'localhost:{port}' assert data['headers']['HTTP_HOST'] == f'localhost:{port}'
@ -28,47 +22,29 @@ async def test_scope(wsgi_server, threading_mode):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_body(wsgi_server, threading_mode): async def test_body(wsgi_server, threading_mode):
async with wsgi_server(threading_mode) as port: async with wsgi_server(threading_mode) as port:
res = httpx.post(f"http://localhost:{port}/echo", content="test") res = httpx.post(f'http://localhost:{port}/echo', content='test')
assert res.status_code == 200 assert res.status_code == 200
assert res.text == "test" assert res.text == 'test'
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_iterbody(wsgi_server, threading_mode): async def test_iterbody(wsgi_server, threading_mode):
async with wsgi_server(threading_mode) as port: async with wsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/iterbody") res = httpx.get(f'http://localhost:{port}/iterbody')
assert res.status_code == 200 assert res.status_code == 200
assert res.text == "test" * 3 assert res.text == 'test' * 3
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
"threading_mode",
[
"runtime",
"workers"
]
)
async def test_app_error(wsgi_server, threading_mode): async def test_app_error(wsgi_server, threading_mode):
async with wsgi_server(threading_mode) as port: async with wsgi_server(threading_mode) as port:
res = httpx.get(f"http://localhost:{port}/err_app") res = httpx.get(f'http://localhost:{port}/err_app')
assert res.status_code == 500 assert res.status_code == 500