mirror of
https://github.com/emmett-framework/granian.git
synced 2025-07-07 11:25:36 +00:00
Align codestyle (#124)
* Add lint and format tools and config * Format Python code * Format Rust code * Add lint CI workflow
This commit is contained in:
parent
e28eb81db6
commit
c95fd0cba7
53 changed files with 1388 additions and 1991 deletions
32
.github/workflows/lint.yml
vendored
Normal file
32
.github/workflows/lint.yml
vendored
Normal file
|
@ -0,0 +1,32 @@
|
|||
name: lint
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
branches:
|
||||
- master
|
||||
|
||||
env:
|
||||
MATURIN_VERSION: 1.2.3
|
||||
PYTHON_VERSION: 3.11
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ env.PYTHON_VERSION }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
- name: Install
|
||||
run: |
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install maturin==${{ env.MATURIN_VERSION }}
|
||||
maturin develop --extras=lint
|
||||
- name: Lint
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make lint
|
26
.pre-commit-config.yaml
Normal file
26
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,26 @@
|
|||
fail_fast: true
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.0.1
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- id: check-added-large-files
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: lint-python
|
||||
name: Lint Python
|
||||
entry: make lint-python
|
||||
types: [python]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
- id: lint-rust
|
||||
name: Lint Rust
|
||||
entry: make lint-rust
|
||||
types: [rust]
|
||||
language: system
|
||||
pass_filenames: false
|
1
.rustfmt.toml
Normal file
1
.rustfmt.toml
Normal file
|
@ -0,0 +1 @@
|
|||
max_width = 120
|
67
Makefile
Normal file
67
Makefile
Normal file
|
@ -0,0 +1,67 @@
|
|||
.DEFAULT_GOAL := all
|
||||
black = black granian tests
|
||||
ruff = ruff granian tests
|
||||
|
||||
.PHONY: build-dev
|
||||
build-dev:
|
||||
@rm -f granian/*.so
|
||||
maturin develop --extras test
|
||||
|
||||
.PHONY: format
|
||||
format:
|
||||
$(black)
|
||||
$(ruff) --fix --exit-zero
|
||||
cargo fmt
|
||||
|
||||
.PHONY: lint-python
|
||||
lint-python:
|
||||
$(ruff)
|
||||
$(black) --check --diff
|
||||
|
||||
.PHONY: lint-rust
|
||||
lint-rust:
|
||||
cargo fmt --version
|
||||
cargo fmt --all -- --check
|
||||
cargo clippy --version
|
||||
cargo clippy --tests -- \
|
||||
-D warnings \
|
||||
-W clippy::pedantic \
|
||||
-W clippy::dbg_macro \
|
||||
-W clippy::print_stdout \
|
||||
-A clippy::cast-possible-truncation \
|
||||
-A clippy::cast-possible-wrap \
|
||||
-A clippy::cast-precision-loss \
|
||||
-A clippy::cast-sign-loss \
|
||||
-A clippy::declare-interior-mutable-const \
|
||||
-A clippy::float-cmp \
|
||||
-A clippy::fn-params-excessive-bools \
|
||||
-A clippy::if-not-else \
|
||||
-A clippy::inline-always \
|
||||
-A clippy::manual-let-else \
|
||||
-A clippy::match-bool \
|
||||
-A clippy::match-same-arms \
|
||||
-A clippy::missing-errors-doc \
|
||||
-A clippy::missing-panics-doc \
|
||||
-A clippy::module-name-repetitions \
|
||||
-A clippy::must-use-candidate \
|
||||
-A clippy::needless-pass-by-value \
|
||||
-A clippy::similar-names \
|
||||
-A clippy::single-match-else \
|
||||
-A clippy::struct-excessive-bools \
|
||||
-A clippy::too-many-arguments \
|
||||
-A clippy::too-many-lines \
|
||||
-A clippy::type-complexity \
|
||||
-A clippy::unnecessary-wraps \
|
||||
-A clippy::unused-self \
|
||||
-A clippy::used-underscore-binding \
|
||||
-A clippy::wrong-self-convention
|
||||
|
||||
.PHONY: lint
|
||||
lint: lint-python lint-rust
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
pytest -v test
|
||||
|
||||
.PHONY: all
|
||||
all: format build-dev lint test
|
|
@ -1 +1 @@
|
|||
from .server import Granian
|
||||
from .server import Granian # noqa
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from granian.cli import cli
|
||||
|
||||
|
||||
cli()
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = "0.6.0"
|
||||
__version__ = '0.6.0'
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ._types import WebsocketMessage
|
||||
|
||||
|
||||
class ASGIScope:
|
||||
def as_dict(self, root_path: str) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
class RSGIHeaders:
|
||||
def __contains__(self, key: str) -> bool: ...
|
||||
def keys(self) -> List[str]: ...
|
||||
|
@ -14,7 +12,6 @@ class RSGIHeaders:
|
|||
def items(self) -> List[Tuple[str]]: ...
|
||||
def get(self, key: str, default: Any = None) -> Any: ...
|
||||
|
||||
|
||||
class RSGIScope:
|
||||
proto: str
|
||||
http_version: str
|
||||
|
@ -29,12 +26,10 @@ class RSGIScope:
|
|||
@property
|
||||
def headers(self) -> RSGIHeaders: ...
|
||||
|
||||
|
||||
class RSGIHTTPStreamTransport:
|
||||
async def send_bytes(self, data: bytes): ...
|
||||
async def send_str(self, data: str): ...
|
||||
|
||||
|
||||
class RSGIHTTPProtocol:
|
||||
async def __call__(self) -> bytes: ...
|
||||
def response_empty(self, status: int, headers: List[Tuple[str, str]]): ...
|
||||
|
@ -43,25 +38,17 @@ class RSGIHTTPProtocol:
|
|||
def response_file(self, status: int, headers: List[Tuple[str, str]], file: str): ...
|
||||
def response_stream(self, status: int, headers: List[Tuple[str, str]]) -> RSGIHTTPStreamTransport: ...
|
||||
|
||||
|
||||
class RSGIWebsocketTransport:
|
||||
async def receive(self) -> WebsocketMessage: ...
|
||||
async def send_bytes(self, data: bytes): ...
|
||||
async def send_str(self, data: str): ...
|
||||
|
||||
|
||||
class RSGIWebsocketProtocol:
|
||||
async def accept(self) -> RSGIWebsocketTransport: ...
|
||||
def close(self, status: Optional[int]) -> Tuple[int, bool]: ...
|
||||
|
||||
|
||||
class RSGIProtocolError(RuntimeError):
|
||||
...
|
||||
|
||||
|
||||
class RSGIProtocolClosed(RuntimeError):
|
||||
...
|
||||
|
||||
class RSGIProtocolError(RuntimeError): ...
|
||||
class RSGIProtocolClosed(RuntimeError): ...
|
||||
|
||||
class WSGIScope:
|
||||
def to_environ(self, environ: Dict[str, Any]) -> Dict[str, Any]: ...
|
||||
|
|
|
@ -2,22 +2,21 @@ import os
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from types import ModuleType
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
|
||||
def get_import_components(path: str) -> List[Optional[str]]:
|
||||
return (re.split(r":(?![\\/])", path, 1) + [None])[:2]
|
||||
return (re.split(r':(?![\\/])', path, 1) + [None])[:2]
|
||||
|
||||
|
||||
def prepare_import(path: str) -> str:
|
||||
path = os.path.realpath(path)
|
||||
|
||||
fname, ext = os.path.splitext(path)
|
||||
if ext == ".py":
|
||||
if ext == '.py':
|
||||
path = fname
|
||||
if os.path.basename(path) == "__init__":
|
||||
if os.path.basename(path) == '__init__':
|
||||
path = os.path.dirname(path)
|
||||
|
||||
module_name = []
|
||||
|
@ -27,26 +26,22 @@ def prepare_import(path: str) -> str:
|
|||
path, name = os.path.split(path)
|
||||
module_name.append(name)
|
||||
|
||||
if not os.path.exists(os.path.join(path, "__init__.py")):
|
||||
if not os.path.exists(os.path.join(path, '__init__.py')):
|
||||
break
|
||||
|
||||
if sys.path[0] != path:
|
||||
sys.path.insert(0, path)
|
||||
|
||||
return ".".join(module_name[::-1])
|
||||
return '.'.join(module_name[::-1])
|
||||
|
||||
|
||||
def load_module(
|
||||
module_name: str,
|
||||
raise_on_failure: bool = True
|
||||
) -> Optional[ModuleType]:
|
||||
def load_module(module_name: str, raise_on_failure: bool = True) -> Optional[ModuleType]:
|
||||
try:
|
||||
__import__(module_name)
|
||||
except ImportError:
|
||||
if sys.exc_info()[-1].tb_next:
|
||||
raise RuntimeError(
|
||||
f"While importing '{module_name}', an ImportError was raised:"
|
||||
f"\n\n{traceback.format_exc()}"
|
||||
f"While importing '{module_name}', an ImportError was raised:" f"\n\n{traceback.format_exc()}"
|
||||
)
|
||||
elif raise_on_failure:
|
||||
raise RuntimeError(f"Could not import '{module_name}'.")
|
||||
|
@ -58,9 +53,9 @@ def load_module(
|
|||
def load_target(target: str) -> Callable[..., None]:
|
||||
path, name = get_import_components(target)
|
||||
path = prepare_import(path) if path else None
|
||||
name = name or "app"
|
||||
name = name or 'app'
|
||||
module = load_module(path)
|
||||
rv = module
|
||||
for element in name.split("."):
|
||||
for element in name.split('.'):
|
||||
rv = getattr(rv, element)
|
||||
return rv
|
||||
|
|
|
@ -2,12 +2,11 @@ import asyncio
|
|||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
|
||||
class Registry:
|
||||
__slots__ = ["_data"]
|
||||
__slots__ = ['_data']
|
||||
|
||||
def __init__(self):
|
||||
self._data: Dict[str, Callable[..., Any]] = {}
|
||||
|
@ -22,6 +21,7 @@ class Registry:
|
|||
def wrap(builder: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self._data[key] = builder
|
||||
return builder
|
||||
|
||||
return wrap
|
||||
|
||||
def get(self, key: str) -> Callable[..., Any]:
|
||||
|
@ -31,18 +31,13 @@ class Registry:
|
|||
raise RuntimeError(f"'{key}' implementation not available.")
|
||||
|
||||
|
||||
|
||||
class BuilderRegistry(Registry):
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self):
|
||||
self._data: Dict[str, Tuple[Callable[..., Any], List[str]]] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
key: str,
|
||||
packages: Optional[List[str]] = None
|
||||
) -> Callable[[], Callable[..., Any]]:
|
||||
def register(self, key: str, packages: Optional[List[str]] = None) -> Callable[[], Callable[..., Any]]:
|
||||
packages = packages or []
|
||||
|
||||
def wrap(builder: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
@ -56,6 +51,7 @@ class BuilderRegistry(Registry):
|
|||
if implemented:
|
||||
self._data[key] = (builder, loaded_packages)
|
||||
return builder
|
||||
|
||||
return wrap
|
||||
|
||||
def get(self, key: str) -> Callable[..., Any]:
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from ._granian import ASGIScope as Scope
|
||||
|
@ -22,12 +21,7 @@ class LifespanProtocol:
|
|||
async def handle(self):
|
||||
try:
|
||||
await self.callable(
|
||||
{
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.3"}
|
||||
},
|
||||
self.receive,
|
||||
self.send
|
||||
{'type': 'lifespan', 'asgi': {'version': '3.0', 'spec_version': '2.3'}}, self.receive, self.send
|
||||
)
|
||||
except Exception:
|
||||
self.errored = True
|
||||
|
@ -43,7 +37,7 @@ class LifespanProtocol:
|
|||
loop = asyncio.get_event_loop()
|
||||
_handler_task = loop.create_task(self.handle())
|
||||
|
||||
await self.event_queue.put({"type": "lifespan.startup"})
|
||||
await self.event_queue.put({'type': 'lifespan.startup'})
|
||||
await self.event_startup.wait()
|
||||
|
||||
if self.failure_startup or (self.errored and not self.unsupported):
|
||||
|
@ -53,7 +47,7 @@ class LifespanProtocol:
|
|||
if self.errored:
|
||||
return
|
||||
|
||||
await self.event_queue.put({"type": "lifespan.shutdown"})
|
||||
await self.event_queue.put({'type': 'lifespan.shutdown'})
|
||||
await self.event_shutdown.wait()
|
||||
|
||||
if self.failure_shutdown or (self.errored and not self.unsupported):
|
||||
|
@ -89,14 +83,14 @@ class LifespanProtocol:
|
|||
# self.logger.error(message["message"])
|
||||
|
||||
_event_handlers = {
|
||||
"lifespan.startup.complete": _handle_startup_complete,
|
||||
"lifespan.startup.failed": _handle_startup_failed,
|
||||
"lifespan.shutdown.complete": _handle_shutdown_complete,
|
||||
"lifespan.shutdown.failed": _handle_shutdown_failed
|
||||
'lifespan.startup.complete': _handle_startup_complete,
|
||||
'lifespan.startup.failed': _handle_startup_failed,
|
||||
'lifespan.shutdown.complete': _handle_shutdown_complete,
|
||||
'lifespan.shutdown.failed': _handle_shutdown_failed,
|
||||
}
|
||||
|
||||
async def send(self, message):
|
||||
handler = self._event_handlers[message["type"]]
|
||||
handler = self._event_handlers[message['type']]
|
||||
handler(self, message)
|
||||
|
||||
|
||||
|
@ -108,6 +102,7 @@ def _send_wrapper(proto):
|
|||
@wraps(proto)
|
||||
def send(data):
|
||||
return proto(_noop_coro, data)
|
||||
|
||||
return send
|
||||
|
||||
|
||||
|
@ -116,9 +111,6 @@ def _callback_wrapper(callback, scope_opts):
|
|||
|
||||
@wraps(callback)
|
||||
def wrapper(scope: Scope, proto):
|
||||
return callback(
|
||||
scope.as_dict(root_url_path),
|
||||
proto.receive,
|
||||
_send_wrapper(proto.send)
|
||||
)
|
||||
return callback(scope.as_dict(root_url_path), proto.receive, _send_wrapper(proto.send))
|
||||
|
||||
return wrapper
|
||||
|
|
102
granian/cli.py
102
granian/cli.py
|
@ -1,107 +1,55 @@
|
|||
import json
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
|
||||
from .__version__ import __version__
|
||||
from .constants import Interfaces, HTTPModes, Loops, ThreadModes
|
||||
from .constants import HTTPModes, Interfaces, Loops, ThreadModes
|
||||
from .log import LogLevels
|
||||
from .server import Granian
|
||||
|
||||
|
||||
cli = typer.Typer(name="granian", context_settings={"ignore_unknown_options": True})
|
||||
cli = typer.Typer(name='granian', context_settings={'ignore_unknown_options': True})
|
||||
|
||||
|
||||
def version_callback(value: bool):
|
||||
if value:
|
||||
typer.echo(f"{cli.info.name} {__version__}")
|
||||
typer.echo(f'{cli.info.name} {__version__}')
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
@cli.command()
|
||||
def main(
|
||||
app: str = typer.Argument(..., help="Application target to serve."),
|
||||
host: str = typer.Option("127.0.0.1", help="Host address to bind to."),
|
||||
port: int = typer.Option(8000, help="Port to bind to."),
|
||||
interface: Interfaces = typer.Option(
|
||||
Interfaces.RSGI.value,
|
||||
help="Application interface type."
|
||||
),
|
||||
http: HTTPModes = typer.Option(
|
||||
HTTPModes.auto.value,
|
||||
help="HTTP version."
|
||||
),
|
||||
websockets: bool = typer.Option(
|
||||
True,
|
||||
"--ws/--no-ws",
|
||||
help="Enable websockets handling",
|
||||
show_default="enabled"
|
||||
),
|
||||
workers: int = typer.Option(1, min=1, help="Number of worker processes."),
|
||||
threads: int = typer.Option(1, min=1, help="Number of threads."),
|
||||
threading_mode: ThreadModes = typer.Option(
|
||||
ThreadModes.workers.value,
|
||||
help="Threading mode to use."
|
||||
),
|
||||
loop: Loops = typer.Option(Loops.auto.value, help="Event loop implementation"),
|
||||
loop_opt: bool = typer.Option(
|
||||
False,
|
||||
"--opt/--no-opt",
|
||||
help="Enable loop optimizations",
|
||||
show_default="disabled"
|
||||
),
|
||||
backlog: int = typer.Option(
|
||||
1024,
|
||||
min=128,
|
||||
help="Maximum number of connections to hold in backlog."
|
||||
),
|
||||
log_level: LogLevels = typer.Option(
|
||||
LogLevels.info.value,
|
||||
help="Log level",
|
||||
case_sensitive=False
|
||||
),
|
||||
app: str = typer.Argument(..., help='Application target to serve.'),
|
||||
host: str = typer.Option('127.0.0.1', help='Host address to bind to.'),
|
||||
port: int = typer.Option(8000, help='Port to bind to.'),
|
||||
interface: Interfaces = typer.Option(Interfaces.RSGI.value, help='Application interface type.'),
|
||||
http: HTTPModes = typer.Option(HTTPModes.auto.value, help='HTTP version.'),
|
||||
websockets: bool = typer.Option(True, '--ws/--no-ws', help='Enable websockets handling', show_default='enabled'),
|
||||
workers: int = typer.Option(1, min=1, help='Number of worker processes.'),
|
||||
threads: int = typer.Option(1, min=1, help='Number of threads.'),
|
||||
threading_mode: ThreadModes = typer.Option(ThreadModes.workers.value, help='Threading mode to use.'),
|
||||
loop: Loops = typer.Option(Loops.auto.value, help='Event loop implementation'),
|
||||
loop_opt: bool = typer.Option(False, '--opt/--no-opt', help='Enable loop optimizations', show_default='disabled'),
|
||||
backlog: int = typer.Option(1024, min=128, help='Maximum number of connections to hold in backlog.'),
|
||||
log_level: LogLevels = typer.Option(LogLevels.info.value, help='Log level', case_sensitive=False),
|
||||
log_config: Optional[Path] = typer.Option(
|
||||
None,
|
||||
help="Logging configuration file (json)",
|
||||
exists=True,
|
||||
file_okay=True,
|
||||
dir_okay=False,
|
||||
readable=True
|
||||
None, help='Logging configuration file (json)', exists=True, file_okay=True, dir_okay=False, readable=True
|
||||
),
|
||||
ssl_keyfile: Optional[Path] = typer.Option(
|
||||
None,
|
||||
help="SSL key file",
|
||||
exists=True,
|
||||
file_okay=True,
|
||||
dir_okay=False,
|
||||
readable=True
|
||||
None, help='SSL key file', exists=True, file_okay=True, dir_okay=False, readable=True
|
||||
),
|
||||
ssl_certificate: Optional[Path] = typer.Option(
|
||||
None,
|
||||
help="SSL certificate file",
|
||||
exists=True,
|
||||
file_okay=True,
|
||||
dir_okay=False,
|
||||
readable=True
|
||||
),
|
||||
url_path_prefix: Optional[str] = typer.Option(
|
||||
None,
|
||||
help="URL path prefix the app is mounted on"
|
||||
None, help='SSL certificate file', exists=True, file_okay=True, dir_okay=False, readable=True
|
||||
),
|
||||
url_path_prefix: Optional[str] = typer.Option(None, help='URL path prefix the app is mounted on'),
|
||||
reload: bool = typer.Option(
|
||||
False,
|
||||
"--reload/--no-reload",
|
||||
help="Enable auto reload on application's files changes"
|
||||
False, '--reload/--no-reload', help="Enable auto reload on application's files changes"
|
||||
),
|
||||
_: Optional[bool] = typer.Option(
|
||||
None,
|
||||
"--version",
|
||||
callback=version_callback,
|
||||
is_eager=True,
|
||||
help="Shows the version and exit."
|
||||
)
|
||||
None, '--version', callback=version_callback, is_eager=True, help='Shows the version and exit.'
|
||||
),
|
||||
):
|
||||
log_dictconfig = None
|
||||
if log_config:
|
||||
|
@ -109,7 +57,7 @@ def main(
|
|||
try:
|
||||
log_dictconfig = json.loads(log_config_file.read())
|
||||
except Exception:
|
||||
print("Unable to parse provided logging config.")
|
||||
print('Unable to parse provided logging config.')
|
||||
raise typer.Exit(1)
|
||||
|
||||
Granian(
|
||||
|
@ -131,5 +79,5 @@ def main(
|
|||
ssl_cert=ssl_certificate,
|
||||
ssl_key=ssl_keyfile,
|
||||
url_path_prefix=url_path_prefix,
|
||||
reload=reload
|
||||
reload=reload,
|
||||
).serve()
|
||||
|
|
|
@ -2,23 +2,23 @@ from enum import Enum
|
|||
|
||||
|
||||
class Interfaces(str, Enum):
|
||||
ASGI = "asgi"
|
||||
RSGI = "rsgi"
|
||||
WSGI = "wsgi"
|
||||
ASGI = 'asgi'
|
||||
RSGI = 'rsgi'
|
||||
WSGI = 'wsgi'
|
||||
|
||||
|
||||
class HTTPModes(str, Enum):
|
||||
auto = "auto"
|
||||
http1 = "1"
|
||||
http2 = "2"
|
||||
auto = 'auto'
|
||||
http1 = '1'
|
||||
http2 = '2'
|
||||
|
||||
|
||||
class ThreadModes(str, Enum):
|
||||
runtime = "runtime"
|
||||
workers = "workers"
|
||||
runtime = 'runtime'
|
||||
workers = 'workers'
|
||||
|
||||
|
||||
class Loops(str, Enum):
|
||||
auto = "auto"
|
||||
asyncio = "asyncio"
|
||||
uvloop = "uvloop"
|
||||
auto = 'auto'
|
||||
asyncio = 'asyncio'
|
||||
uvloop = 'uvloop'
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
import copy
|
||||
import logging
|
||||
import logging.config
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class LogLevels(str, Enum):
|
||||
critical = "critical"
|
||||
error = "error"
|
||||
warning = "warning"
|
||||
warn = "warn"
|
||||
info = "info"
|
||||
debug = "debug"
|
||||
critical = 'critical'
|
||||
error = 'error'
|
||||
warning = 'warning'
|
||||
warn = 'warn'
|
||||
info = 'info'
|
||||
debug = 'debug'
|
||||
|
||||
|
||||
log_levels_map = {
|
||||
|
@ -21,27 +20,21 @@ log_levels_map = {
|
|||
LogLevels.warning: logging.WARNING,
|
||||
LogLevels.warn: logging.WARN,
|
||||
LogLevels.info: logging.INFO,
|
||||
LogLevels.debug: logging.DEBUG
|
||||
LogLevels.debug: logging.DEBUG,
|
||||
}
|
||||
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"root": {"level": "INFO", "handlers": ["console"]},
|
||||
"formatters": {
|
||||
"generic": {
|
||||
"()": "logging.Formatter",
|
||||
"fmt": "[%(levelname)s] %(message)s",
|
||||
"datefmt": "[%Y-%m-%d %H:%M:%S %z]"
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'root': {'level': 'INFO', 'handlers': ['console']},
|
||||
'formatters': {
|
||||
'generic': {
|
||||
'()': 'logging.Formatter',
|
||||
'fmt': '[%(levelname)s] %(message)s',
|
||||
'datefmt': '[%Y-%m-%d %H:%M:%S %z]',
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"formatter": "generic",
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stdout",
|
||||
}
|
||||
}
|
||||
'handlers': {'console': {'formatter': 'generic', 'class': 'logging.StreamHandler', 'stream': 'ext://sys.stdout'}},
|
||||
}
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
@ -51,5 +44,5 @@ def configure_logging(level: LogLevels, config: Optional[Dict[str, Any]] = None)
|
|||
log_config = copy.deepcopy(LOGGING_CONFIG)
|
||||
if config:
|
||||
log_config.update(config)
|
||||
log_config["root"]["level"] = log_levels_map[level]
|
||||
log_config['root']['level'] = log_levels_map[level]
|
||||
logging.config.dictConfig(log_config)
|
||||
|
|
|
@ -2,7 +2,5 @@ import copyreg
|
|||
|
||||
from ._granian import ListenerHolder as SocketHolder
|
||||
|
||||
copyreg.pickle(
|
||||
SocketHolder,
|
||||
lambda v: (SocketHolder, v.__getstate__())
|
||||
)
|
||||
|
||||
copyreg.pickle(SocketHolder, lambda v: (SocketHolder, v.__getstate__()))
|
||||
|
|
|
@ -2,12 +2,12 @@ from enum import Enum
|
|||
from typing import Union
|
||||
|
||||
from ._granian import (
|
||||
RSGIHTTPProtocol as HTTPProtocol,
|
||||
RSGIWebsocketProtocol as WebsocketProtocol,
|
||||
RSGIHeaders as Headers,
|
||||
RSGIScope as Scope,
|
||||
RSGIProtocolError as ProtocolError,
|
||||
RSGIProtocolClosed as ProtocolClosed
|
||||
RSGIHeaders as Headers, # noqa
|
||||
RSGIHTTPProtocol as HTTPProtocol, # noqa
|
||||
RSGIProtocolClosed as ProtocolClosed, # noqa
|
||||
RSGIProtocolError as ProtocolError, # noqa
|
||||
RSGIScope as Scope, # noqa
|
||||
RSGIWebsocketProtocol as WebsocketProtocol, # noqa
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ import socket
|
|||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
@ -16,11 +15,12 @@ from ._futures import future_watcher_wrapper
|
|||
from ._granian import ASGIWorker, RSGIWorker, WSGIWorker
|
||||
from ._internal import load_target
|
||||
from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
|
||||
from .constants import Interfaces, HTTPModes, Loops, ThreadModes
|
||||
from .constants import HTTPModes, Interfaces, Loops, ThreadModes
|
||||
from .log import LogLevels, configure_logging, logger
|
||||
from .net import SocketHolder
|
||||
from .wsgi import _callback_wrapper as _wsgi_call_wrap
|
||||
|
||||
|
||||
multiprocessing.allow_connection_pickling()
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ class Granian:
|
|||
def __init__(
|
||||
self,
|
||||
target: str,
|
||||
address: str = "127.0.0.1",
|
||||
address: str = '127.0.0.1',
|
||||
port: int = 8000,
|
||||
interface: Interfaces = Interfaces.RSGI,
|
||||
workers: int = 1,
|
||||
|
@ -48,7 +48,7 @@ class Granian:
|
|||
ssl_cert: Optional[Path] = None,
|
||||
ssl_key: Optional[Path] = None,
|
||||
url_path_prefix: Optional[str] = None,
|
||||
reload: bool = False
|
||||
reload: bool = False,
|
||||
):
|
||||
self.target = target
|
||||
self.bind_addr = address
|
||||
|
@ -75,11 +75,7 @@ class Granian:
|
|||
self.procs: List[multiprocessing.Process] = []
|
||||
self.exit_event = threading.Event()
|
||||
|
||||
def build_ssl_context(
|
||||
self,
|
||||
cert: Optional[Path],
|
||||
key: Optional[Path]
|
||||
):
|
||||
def build_ssl_context(self, cert: Optional[Path], key: Optional[Path]):
|
||||
if not (cert and key):
|
||||
self.ssl_ctx = (False, None, None)
|
||||
return
|
||||
|
@ -108,7 +104,7 @@ class Granian:
|
|||
log_level,
|
||||
log_config,
|
||||
ssl_ctx,
|
||||
scope_opts
|
||||
scope_opts,
|
||||
):
|
||||
from granian._loops import loops, set_loop_signals
|
||||
|
||||
|
@ -129,29 +125,12 @@ class Granian:
|
|||
wcallback = future_watcher_wrapper(wcallback)
|
||||
|
||||
worker = ASGIWorker(
|
||||
worker_id,
|
||||
sfd,
|
||||
threads,
|
||||
pthreads,
|
||||
http_mode,
|
||||
http1_buffer_size,
|
||||
websockets,
|
||||
loop_opt,
|
||||
*ssl_ctx
|
||||
)
|
||||
serve = getattr(worker, {
|
||||
ThreadModes.runtime: "serve_rth",
|
||||
ThreadModes.workers: "serve_wth"
|
||||
}[threading_mode])
|
||||
serve(
|
||||
wcallback,
|
||||
loop,
|
||||
contextvars.copy_context(),
|
||||
shutdown_event.wait()
|
||||
worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx
|
||||
)
|
||||
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
|
||||
serve(wcallback, loop, contextvars.copy_context(), shutdown_event.wait())
|
||||
loop.run_until_complete(lifespan_handler.shutdown())
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _spawn_rsgi_worker(
|
||||
worker_id,
|
||||
|
@ -168,7 +147,7 @@ class Granian:
|
|||
log_level,
|
||||
log_config,
|
||||
ssl_ctx,
|
||||
scope_opts
|
||||
scope_opts,
|
||||
):
|
||||
from granian._loops import loops, set_loop_signals
|
||||
|
||||
|
@ -176,38 +155,23 @@ class Granian:
|
|||
loop = loops.get(loop_impl)
|
||||
sfd = socket.fileno()
|
||||
target = callback_loader()
|
||||
callback = (
|
||||
getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else
|
||||
target
|
||||
)
|
||||
callback = getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else target
|
||||
callback_init = (
|
||||
getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else
|
||||
lambda *args, **kwargs: None
|
||||
getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else lambda *args, **kwargs: None
|
||||
)
|
||||
|
||||
shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT])
|
||||
callback_init(loop)
|
||||
|
||||
worker = RSGIWorker(
|
||||
worker_id,
|
||||
sfd,
|
||||
threads,
|
||||
pthreads,
|
||||
http_mode,
|
||||
http1_buffer_size,
|
||||
websockets,
|
||||
loop_opt,
|
||||
*ssl_ctx
|
||||
worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx
|
||||
)
|
||||
serve = getattr(worker, {
|
||||
ThreadModes.runtime: "serve_rth",
|
||||
ThreadModes.workers: "serve_wth"
|
||||
}[threading_mode])
|
||||
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
|
||||
serve(
|
||||
future_watcher_wrapper(callback) if not loop_opt else callback,
|
||||
loop,
|
||||
contextvars.copy_context(),
|
||||
shutdown_event.wait()
|
||||
shutdown_event.wait(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -226,7 +190,7 @@ class Granian:
|
|||
log_level,
|
||||
log_config,
|
||||
ssl_ctx,
|
||||
scope_opts
|
||||
scope_opts,
|
||||
):
|
||||
from granian._loops import loops, set_loop_signals
|
||||
|
||||
|
@ -237,46 +201,20 @@ class Granian:
|
|||
|
||||
shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT])
|
||||
|
||||
worker = WSGIWorker(
|
||||
worker_id,
|
||||
sfd,
|
||||
threads,
|
||||
pthreads,
|
||||
http_mode,
|
||||
http1_buffer_size,
|
||||
*ssl_ctx
|
||||
)
|
||||
serve = getattr(worker, {
|
||||
ThreadModes.runtime: "serve_rth",
|
||||
ThreadModes.workers: "serve_wth"
|
||||
}[threading_mode])
|
||||
serve(
|
||||
_wsgi_call_wrap(callback, scope_opts),
|
||||
loop,
|
||||
contextvars.copy_context(),
|
||||
shutdown_event.wait()
|
||||
)
|
||||
worker = WSGIWorker(worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, *ssl_ctx)
|
||||
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
|
||||
serve(_wsgi_call_wrap(callback, scope_opts), loop, contextvars.copy_context(), shutdown_event.wait())
|
||||
|
||||
def _init_shared_socket(self):
|
||||
self._shd = SocketHolder.from_address(
|
||||
self.bind_addr,
|
||||
self.bind_port,
|
||||
self.backlog
|
||||
)
|
||||
self._shd = SocketHolder.from_address(self.bind_addr, self.bind_port, self.backlog)
|
||||
self._sfd = self._shd.get_fd()
|
||||
|
||||
def signal_handler(self, *args, **kwargs):
|
||||
self.exit_event.set()
|
||||
|
||||
def _spawn_proc(
|
||||
self,
|
||||
id,
|
||||
target,
|
||||
callback_loader,
|
||||
socket_loader
|
||||
) -> multiprocessing.Process:
|
||||
def _spawn_proc(self, id, target, callback_loader, socket_loader) -> multiprocessing.Process:
|
||||
return multiprocessing.get_context().Process(
|
||||
name="granian-worker",
|
||||
name='granian-worker',
|
||||
target=target,
|
||||
args=(
|
||||
id,
|
||||
|
@ -293,10 +231,8 @@ class Granian:
|
|||
self.log_level,
|
||||
self.log_config,
|
||||
self.ssl_ctx,
|
||||
{
|
||||
"url_path_prefix": self.url_path_prefix
|
||||
}
|
||||
)
|
||||
{'url_path_prefix': self.url_path_prefix},
|
||||
),
|
||||
)
|
||||
|
||||
def _spawn_workers(self, sock, spawn_target, target_loader):
|
||||
|
@ -305,14 +241,11 @@ class Granian:
|
|||
|
||||
for idx in range(self.workers):
|
||||
proc = self._spawn_proc(
|
||||
id=idx + 1,
|
||||
target=spawn_target,
|
||||
callback_loader=target_loader,
|
||||
socket_loader=socket_loader
|
||||
id=idx + 1, target=spawn_target, callback_loader=target_loader, socket_loader=socket_loader
|
||||
)
|
||||
proc.start()
|
||||
self.procs.append(proc)
|
||||
logger.info(f"Spawning worker-{idx + 1} with pid: {proc.pid}")
|
||||
logger.info(f'Spawning worker-{idx + 1} with pid: {proc.pid}')
|
||||
|
||||
def _stop_workers(self):
|
||||
for proc in self.procs:
|
||||
|
@ -321,7 +254,7 @@ class Granian:
|
|||
proc.join()
|
||||
|
||||
def startup(self, spawn_target, target_loader):
|
||||
logger.info("Starting granian")
|
||||
logger.info('Starting granian')
|
||||
|
||||
for sig in self.SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
@ -329,13 +262,13 @@ class Granian:
|
|||
self._init_shared_socket()
|
||||
sock = socket.socket(fileno=self._sfd)
|
||||
sock.set_inheritable(True)
|
||||
logger.info(f"Listening at: {self.bind_addr}:{self.bind_port}")
|
||||
logger.info(f'Listening at: {self.bind_addr}:{self.bind_port}')
|
||||
|
||||
self._spawn_workers(sock, spawn_target, target_loader)
|
||||
return sock
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Shutting down granian")
|
||||
logger.info('Shutting down granian')
|
||||
self._stop_workers()
|
||||
|
||||
def _serve(self, spawn_target, target_loader):
|
||||
|
@ -360,12 +293,12 @@ class Granian:
|
|||
self,
|
||||
spawn_target: Optional[Callable[..., None]] = None,
|
||||
target_loader: Optional[Callable[..., Callable[..., Any]]] = None,
|
||||
wrap_loader: bool = True
|
||||
wrap_loader: bool = True,
|
||||
):
|
||||
default_spawners = {
|
||||
Interfaces.ASGI: self._spawn_asgi_worker,
|
||||
Interfaces.RSGI: self._spawn_rsgi_worker,
|
||||
Interfaces.WSGI: self._spawn_wsgi_worker
|
||||
Interfaces.WSGI: self._spawn_wsgi_worker,
|
||||
}
|
||||
if target_loader:
|
||||
if wrap_loader:
|
||||
|
@ -383,8 +316,5 @@ class Granian:
|
|||
"Number of workers will now fallback to 1."
|
||||
)
|
||||
|
||||
serve_method = (
|
||||
self._serve_with_reloader if self.reload_on_changes else
|
||||
self._serve
|
||||
)
|
||||
serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
|
||||
serve_method(spawn_target, target_loader)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
|
@ -14,30 +13,27 @@ class Response:
|
|||
self.status = 200
|
||||
self.headers = []
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
status: str,
|
||||
headers: List[Tuple[str, str]],
|
||||
exc_info: Any = None
|
||||
):
|
||||
def __call__(self, status: str, headers: List[Tuple[str, str]], exc_info: Any = None):
|
||||
self.status = int(status.split(' ', 1)[0])
|
||||
self.headers = headers
|
||||
|
||||
|
||||
def _callback_wrapper(callback, scope_opts):
|
||||
basic_env = dict(os.environ)
|
||||
basic_env.update({
|
||||
'GATEWAY_INTERFACE': 'CGI/1.1',
|
||||
'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '',
|
||||
'SERVER_SOFTWARE': 'Granian',
|
||||
'wsgi.errors': sys.stderr,
|
||||
'wsgi.input_terminated': True,
|
||||
'wsgi.input': None,
|
||||
'wsgi.multiprocess': False,
|
||||
'wsgi.multithread': False,
|
||||
'wsgi.run_once': False,
|
||||
'wsgi.version': (1, 0)
|
||||
})
|
||||
basic_env.update(
|
||||
{
|
||||
'GATEWAY_INTERFACE': 'CGI/1.1',
|
||||
'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '',
|
||||
'SERVER_SOFTWARE': 'Granian',
|
||||
'wsgi.errors': sys.stderr,
|
||||
'wsgi.input_terminated': True,
|
||||
'wsgi.input': None,
|
||||
'wsgi.multiprocess': False,
|
||||
'wsgi.multithread': False,
|
||||
'wsgi.run_once': False,
|
||||
'wsgi.version': (1, 0),
|
||||
}
|
||||
)
|
||||
|
||||
@wraps(callback)
|
||||
def wrapper(scope: Scope) -> Tuple[int, List[Tuple[str, str]], bytes]:
|
||||
|
@ -46,7 +42,7 @@ def _callback_wrapper(callback, scope_opts):
|
|||
|
||||
if isinstance(rv, list):
|
||||
resp_type = 0
|
||||
rv = b"".join(rv)
|
||||
rv = b''.join(rv)
|
||||
else:
|
||||
resp_type = 1
|
||||
rv = iter(rv)
|
||||
|
|
126
pyproject.toml
126
pyproject.toml
|
@ -1,65 +1,111 @@
|
|||
[project]
|
||||
name = "granian"
|
||||
name = 'granian'
|
||||
authors = [
|
||||
{name = "Giovanni Barillari", email = "g@baro.dev"}
|
||||
{name = 'Giovanni Barillari', email = 'g@baro.dev'}
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Programming Language :: Python :: Implementation :: PyPy",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Rust",
|
||||
"Topic :: Internet :: WWW/HTTP"
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
'Intended Audience :: Developers',
|
||||
'License :: OSI Approved :: BSD License',
|
||||
'Operating System :: MacOS',
|
||||
'Operating System :: Microsoft :: Windows',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Programming Language :: Python :: 3.11',
|
||||
'Programming Language :: Python :: Implementation :: CPython',
|
||||
'Programming Language :: Python :: Implementation :: PyPy',
|
||||
'Programming Language :: Python',
|
||||
'Programming Language :: Rust',
|
||||
'Topic :: Internet :: WWW/HTTP'
|
||||
]
|
||||
|
||||
dynamic = [
|
||||
"description",
|
||||
"keywords",
|
||||
"license",
|
||||
"readme",
|
||||
"version"
|
||||
'description',
|
||||
'keywords',
|
||||
'license',
|
||||
'readme',
|
||||
'version'
|
||||
]
|
||||
|
||||
requires-python = ">=3.8"
|
||||
requires-python = '>=3.8'
|
||||
dependencies = [
|
||||
"watchfiles~=0.18",
|
||||
"typer~=0.4",
|
||||
"uvloop~=0.17.0; sys_platform != 'win32' and platform_python_implementation == 'CPython'"
|
||||
'watchfiles~=0.18',
|
||||
'typer~=0.4',
|
||||
'uvloop~=0.17.0; sys_platform != "win32" and platform_python_implementation == "CPython"'
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
lint = [
|
||||
'black~=23.7.0',
|
||||
'ruff~=0.0.287'
|
||||
]
|
||||
test = [
|
||||
"httpx~=0.23.0",
|
||||
"pytest~=7.1.2",
|
||||
"pytest-asyncio~=0.18.3",
|
||||
"websockets~=10.3"
|
||||
'httpx~=0.23.0',
|
||||
'pytest~=7.1.2',
|
||||
'pytest-asyncio~=0.18.3',
|
||||
'websockets~=10.3'
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/emmett-framework/granian"
|
||||
Funding = "https://github.com/sponsors/gi0baro"
|
||||
Source = "https://github.com/emmett-framework/granian"
|
||||
Homepage = 'https://github.com/emmett-framework/granian'
|
||||
Funding = 'https://github.com/sponsors/gi0baro'
|
||||
Source = 'https://github.com/emmett-framework/granian'
|
||||
|
||||
[project.scripts]
|
||||
granian = "granian:cli.cli"
|
||||
granian = 'granian:cli.cli'
|
||||
|
||||
[build-system]
|
||||
requires = ["maturin>=1.1.0,<1.3.0"]
|
||||
build-backend = "maturin"
|
||||
requires = ['maturin>=1.1.0,<1.3.0']
|
||||
build-backend = 'maturin'
|
||||
|
||||
[tool.maturin]
|
||||
module-name = "granian._granian"
|
||||
bindings = "pyo3"
|
||||
module-name = 'granian._granian'
|
||||
bindings = 'pyo3'
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
extend-select = [
|
||||
# E and F are enabled by default
|
||||
'B', # flake8-bugbear
|
||||
'C4', # flake8-comprehensions
|
||||
'C90', # mccabe
|
||||
'I', # isort
|
||||
'N', # pep8-naming
|
||||
'Q', # flake8-quotes
|
||||
'RUF100', # ruff (unused noqa)
|
||||
'S', # flake8-bandit
|
||||
'W' # pycodestyle
|
||||
]
|
||||
extend-ignore = [
|
||||
'B008', # function calls in args defaults are fine
|
||||
'B009', # getattr with constants is fine
|
||||
'B034', # re.split won't confuse us
|
||||
'B904', # rising without from is fine
|
||||
'E501', # leave line length to black
|
||||
'N818', # leave to us exceptions naming
|
||||
'S101' # assert is fine
|
||||
]
|
||||
flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' }
|
||||
mccabe = { max-complexity = 13 }
|
||||
|
||||
[tool.ruff.isort]
|
||||
combine-as-imports = true
|
||||
lines-after-imports = 2
|
||||
known-first-party = ['granian', 'tests']
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
'granian/_granian.pyi' = ['I001']
|
||||
'tests/**' = ['B018', 'S110', 'S501']
|
||||
|
||||
[tool.black]
|
||||
color = true
|
||||
line-length = 120
|
||||
target-version = ['py38', 'py39', 'py310', 'py311']
|
||||
skip-string-normalization = true # leave this to ruff
|
||||
skip-magic-trailing-comma = true # leave this to ruff
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_mode = 'auto'
|
||||
|
|
|
@ -3,45 +3,33 @@ use pyo3::prelude::*;
|
|||
use pyo3_asyncio::TaskLocals;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::{
|
||||
callbacks::{
|
||||
CallbackWrapper,
|
||||
callback_impl_run,
|
||||
callback_impl_run_pytask,
|
||||
callback_impl_loop_run,
|
||||
callback_impl_loop_pytask,
|
||||
callback_impl_loop_step,
|
||||
callback_impl_loop_wake,
|
||||
callback_impl_loop_err
|
||||
},
|
||||
runtime::RuntimeRef,
|
||||
ws::{HyperWebsocket, UpgradeData}
|
||||
};
|
||||
use super::{
|
||||
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol},
|
||||
types::ASGIScope as Scope
|
||||
types::ASGIScope as Scope,
|
||||
};
|
||||
use crate::{
|
||||
callbacks::{
|
||||
callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step,
|
||||
callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper,
|
||||
},
|
||||
runtime::RuntimeRef,
|
||||
ws::{HyperWebsocket, UpgradeData},
|
||||
};
|
||||
|
||||
|
||||
#[pyclass]
|
||||
pub(crate) struct CallbackRunnerHTTP {
|
||||
proto: Py<HTTPProtocol>,
|
||||
context: TaskLocals,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackRunnerHTTP {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: HTTPProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
|
||||
let pyproto = Py::new(py, proto).unwrap();
|
||||
Self {
|
||||
proto: pyproto.clone(),
|
||||
context: cb.context,
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,14 +52,14 @@ macro_rules! callback_impl_done_http {
|
|||
let _ = tx.send(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! callback_impl_done_err {
|
||||
($self:expr, $py:expr) => {
|
||||
log::warn!("Application callable raised an exception");
|
||||
$self.done($py)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
|
@ -79,22 +67,17 @@ pub(crate) struct CallbackTaskHTTP {
|
|||
proto: Py<HTTPProtocol>,
|
||||
context: TaskLocals,
|
||||
pycontext: PyObject,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackTaskHTTP {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: PyObject,
|
||||
proto: Py<HTTPProtocol>,
|
||||
context: TaskLocals
|
||||
) -> PyResult<Self> {
|
||||
pub fn new(py: Python, cb: PyObject, proto: Py<HTTPProtocol>, context: TaskLocals) -> PyResult<Self> {
|
||||
let pyctx = context.context(py);
|
||||
Ok(Self {
|
||||
proto,
|
||||
context,
|
||||
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
|
||||
cb
|
||||
cb,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -128,21 +111,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
|
|||
context: TaskLocals,
|
||||
cb: PyObject,
|
||||
#[pyo3(get)]
|
||||
scope: PyObject
|
||||
scope: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackWrappedRunnerHTTP {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: HTTPProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
|
||||
Self {
|
||||
proto: Py::new(py, proto).unwrap(),
|
||||
context: cb.context,
|
||||
cb: cb.callback,
|
||||
scope: scope.into_py(py)
|
||||
scope: scope.into_py(py),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -168,21 +146,16 @@ impl CallbackWrappedRunnerHTTP {
|
|||
pub(crate) struct CallbackRunnerWebsocket {
|
||||
proto: Py<WebsocketProtocol>,
|
||||
context: TaskLocals,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackRunnerWebsocket {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: WebsocketProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
|
||||
let pyproto = Py::new(py, proto).unwrap();
|
||||
Self {
|
||||
proto: pyproto.clone(),
|
||||
context: cb.context,
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -203,7 +176,7 @@ macro_rules! callback_impl_done_ws {
|
|||
let _ = tx.send(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
|
@ -211,22 +184,17 @@ pub(crate) struct CallbackTaskWebsocket {
|
|||
proto: Py<WebsocketProtocol>,
|
||||
context: TaskLocals,
|
||||
pycontext: PyObject,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackTaskWebsocket {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: PyObject,
|
||||
proto: Py<WebsocketProtocol>,
|
||||
context: TaskLocals
|
||||
) -> PyResult<Self> {
|
||||
pub fn new(py: Python, cb: PyObject, proto: Py<WebsocketProtocol>, context: TaskLocals) -> PyResult<Self> {
|
||||
let pyctx = context.context(py);
|
||||
Ok(Self {
|
||||
proto,
|
||||
context,
|
||||
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
|
||||
cb
|
||||
cb,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -260,21 +228,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
|
|||
context: TaskLocals,
|
||||
cb: PyObject,
|
||||
#[pyo3(get)]
|
||||
scope: PyObject
|
||||
scope: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackWrappedRunnerWebsocket {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: WebsocketProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
|
||||
Self {
|
||||
proto: Py::new(py, proto).unwrap(),
|
||||
context: cb.context,
|
||||
cb: cb.callback,
|
||||
scope: scope.into_py(py)
|
||||
scope: scope.into_py(py),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -325,7 +288,7 @@ macro_rules! call_impl_rtb_http {
|
|||
cb: CallbackWrapper,
|
||||
rt: RuntimeRef,
|
||||
req: Request<Body>,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<Response<Body>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = HTTPProtocol::new(rt, req, tx);
|
||||
|
@ -345,7 +308,7 @@ macro_rules! call_impl_rtt_http {
|
|||
cb: CallbackWrapper,
|
||||
rt: RuntimeRef,
|
||||
req: Request<Body>,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<Response<Body>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = HTTPProtocol::new(rt, req, tx);
|
||||
|
@ -368,7 +331,7 @@ macro_rules! call_impl_rtb_ws {
|
|||
rt: RuntimeRef,
|
||||
ws: HyperWebsocket,
|
||||
upgrade: UpgradeData,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
|
||||
|
@ -389,7 +352,7 @@ macro_rules! call_impl_rtt_ws {
|
|||
rt: RuntimeRef,
|
||||
ws: HyperWebsocket,
|
||||
upgrade: UpgradeData,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
|
||||
|
|
|
@ -88,5 +88,5 @@ macro_rules! error_message {
|
|||
}
|
||||
|
||||
pub(crate) use error_flow;
|
||||
pub(crate) use error_transport;
|
||||
pub(crate) use error_message;
|
||||
pub(crate) use error_transport;
|
||||
|
|
|
@ -1,34 +1,22 @@
|
|||
use hyper::{
|
||||
Body,
|
||||
Request,
|
||||
Response,
|
||||
StatusCode,
|
||||
header::SERVER as HK_SERVER,
|
||||
http::response::Builder as ResponseBuilder
|
||||
header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{
|
||||
callbacks::CallbackWrapper,
|
||||
http::{HV_SERVER, response_500},
|
||||
runtime::RuntimeRef,
|
||||
ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade}
|
||||
};
|
||||
use super::{
|
||||
callbacks::{
|
||||
call_rtb_http,
|
||||
call_rtb_http_pyw,
|
||||
call_rtb_ws,
|
||||
call_rtb_ws_pyw,
|
||||
call_rtt_http,
|
||||
call_rtt_http_pyw,
|
||||
call_rtt_ws,
|
||||
call_rtt_ws_pyw
|
||||
call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws,
|
||||
call_rtt_ws_pyw,
|
||||
},
|
||||
types::ASGIScope as Scope
|
||||
types::ASGIScope as Scope,
|
||||
};
|
||||
use crate::{
|
||||
callbacks::CallbackWrapper,
|
||||
http::{response_500, HV_SERVER},
|
||||
runtime::RuntimeRef,
|
||||
ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData},
|
||||
};
|
||||
|
||||
|
||||
macro_rules! default_scope {
|
||||
($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => {
|
||||
|
@ -39,7 +27,7 @@ macro_rules! default_scope {
|
|||
$req.method().as_ref(),
|
||||
$server_addr,
|
||||
$client_addr,
|
||||
$req.headers()
|
||||
$req.headers(),
|
||||
)
|
||||
};
|
||||
}
|
||||
|
@ -53,7 +41,7 @@ macro_rules! handle_http_response {
|
|||
response_500()
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! handle_request {
|
||||
|
@ -64,7 +52,7 @@ macro_rules! handle_request {
|
|||
server_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
req: Request<Body>,
|
||||
scheme: &str
|
||||
scheme: &str,
|
||||
) -> Response<Body> {
|
||||
let scope = default_scope!(server_addr, client_addr, &req, scheme);
|
||||
handle_http_response!($handler, rt, callback, req, scope)
|
||||
|
@ -80,7 +68,7 @@ macro_rules! handle_request_with_ws {
|
|||
server_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
req: Request<Body>,
|
||||
scheme: &str
|
||||
scheme: &str,
|
||||
) -> Response<Body> {
|
||||
let mut scope = default_scope!(server_addr, client_addr, &req, scheme);
|
||||
|
||||
|
@ -95,24 +83,20 @@ macro_rules! handle_request_with_ws {
|
|||
rt.inner.spawn(async move {
|
||||
let tx_ref = restx.clone();
|
||||
|
||||
match $handler_ws(
|
||||
callback,
|
||||
rth,
|
||||
ws,
|
||||
UpgradeData::new(res, restx),
|
||||
scope
|
||||
).await {
|
||||
match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await {
|
||||
Ok(consumed) => {
|
||||
if !consumed {
|
||||
let _ = tx_ref.send(
|
||||
ResponseBuilder::new()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header(HK_SERVER, HV_SERVER)
|
||||
.body(Body::from(""))
|
||||
.unwrap()
|
||||
).await;
|
||||
let _ = tx_ref
|
||||
.send(
|
||||
ResponseBuilder::new()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header(HK_SERVER, HV_SERVER)
|
||||
.body(Body::from(""))
|
||||
.unwrap(),
|
||||
)
|
||||
.await;
|
||||
};
|
||||
},
|
||||
}
|
||||
_ => {
|
||||
log::error!("ASGI protocol failure");
|
||||
let _ = tx_ref.send(response_500()).await;
|
||||
|
@ -124,10 +108,10 @@ macro_rules! handle_request_with_ws {
|
|||
Some(res) => {
|
||||
resrx.close();
|
||||
res
|
||||
},
|
||||
_ => response_500()
|
||||
}
|
||||
_ => response_500(),
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(err) => {
|
||||
return ResponseBuilder::new()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
|
|
315
src/asgi/io.rs
315
src/asgi/io.rs
|
@ -1,33 +1,34 @@
|
|||
use bytes::Bytes;
|
||||
use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}};
|
||||
use futures::{
|
||||
sink::SinkExt,
|
||||
stream::{SplitSink, SplitStream, StreamExt},
|
||||
};
|
||||
use hyper::{
|
||||
Request,
|
||||
Response,
|
||||
body::{Body, HttpBody, Sender as BodySender},
|
||||
header::{HeaderName, HeaderValue, HeaderMap, SERVER as HK_SERVER}
|
||||
header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
|
||||
Request, Response,
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyBytes, PyDict};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio::sync::{Mutex, oneshot};
|
||||
use tungstenite::Message;
|
||||
|
||||
use super::{
|
||||
errors::{error_flow, error_message, error_transport, UnsupportedASGIMessage},
|
||||
types::ASGIMessageType,
|
||||
};
|
||||
use crate::{
|
||||
http::HV_SERVER,
|
||||
runtime::{RuntimeRef, future_into_py_iter, future_into_py_futlike},
|
||||
ws::{HyperWebsocket, UpgradeData}
|
||||
runtime::{future_into_py_futlike, future_into_py_iter, RuntimeRef},
|
||||
ws::{HyperWebsocket, UpgradeData},
|
||||
};
|
||||
use super::{
|
||||
errors::{UnsupportedASGIMessage, error_flow, error_transport, error_message},
|
||||
types::ASGIMessageType
|
||||
};
|
||||
|
||||
|
||||
const EMPTY_BYTES: Vec<u8> = Vec::new();
|
||||
const EMPTY_STRING: String = String::new();
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct ASGIHTTPProtocol {
|
||||
rt: RuntimeRef,
|
||||
tx: Option<oneshot::Sender<Response<Body>>>,
|
||||
|
@ -36,15 +37,11 @@ pub(crate) struct ASGIHTTPProtocol {
|
|||
response_chunked: bool,
|
||||
response_status: Option<i16>,
|
||||
response_headers: Option<HeaderMap>,
|
||||
body_tx: Option<Arc<Mutex<BodySender>>>
|
||||
body_tx: Option<Arc<Mutex<BodySender>>>,
|
||||
}
|
||||
|
||||
impl ASGIHTTPProtocol {
|
||||
pub fn new(
|
||||
rt: RuntimeRef,
|
||||
request: Request<Body>,
|
||||
tx: oneshot::Sender<Response<Body>>
|
||||
) -> Self {
|
||||
pub fn new(rt: RuntimeRef, request: Request<Body>, tx: oneshot::Sender<Response<Body>>) -> Self {
|
||||
Self {
|
||||
rt,
|
||||
tx: Some(tx),
|
||||
|
@ -53,7 +50,7 @@ impl ASGIHTTPProtocol {
|
|||
response_chunked: false,
|
||||
response_status: None,
|
||||
response_headers: None,
|
||||
body_tx: None
|
||||
body_tx: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,7 +68,7 @@ impl ASGIHTTPProtocol {
|
|||
fn send_body<'p>(&self, py: Python<'p>, tx: Arc<Mutex<BodySender>>, body: Vec<u8>) -> PyResult<&'p PyAny> {
|
||||
future_into_py_futlike(self.rt.clone(), py, async move {
|
||||
let mut tx = tx.lock().await;
|
||||
match (&mut *tx).send_data(body.into()).await {
|
||||
match (*tx).send_data(body.into()).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
log::warn!("ASGI transport tx error: {:?}", err);
|
||||
|
@ -94,18 +91,18 @@ impl ASGIHTTPProtocol {
|
|||
let mut bodym = body_ref.lock().await;
|
||||
let body = &mut *bodym;
|
||||
let mut more_body = false;
|
||||
let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| {
|
||||
buf.map_or_else(|_| Bytes::new(), |buf| {
|
||||
more_body = !body.is_end_stream();
|
||||
buf
|
||||
})
|
||||
let chunk = body.data().await.map_or_else(Bytes::new, |buf| {
|
||||
buf.map_or_else(
|
||||
|_| Bytes::new(),
|
||||
|buf| {
|
||||
more_body = !body.is_end_stream();
|
||||
buf
|
||||
},
|
||||
)
|
||||
});
|
||||
Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "type"),
|
||||
pyo3::intern!(py, "http.request")
|
||||
)?;
|
||||
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?;
|
||||
dict.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &chunk[..]))?;
|
||||
dict.set_item(pyo3::intern!(py, "more_body"), more_body)?;
|
||||
Ok(dict.to_object(py))
|
||||
|
@ -115,16 +112,14 @@ impl ASGIHTTPProtocol {
|
|||
|
||||
fn send<'p>(&mut self, py: Python<'p>, asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> {
|
||||
match adapt_message_type(data) {
|
||||
Ok(ASGIMessageType::HTTPStart) => {
|
||||
match self.response_started {
|
||||
false => {
|
||||
self.response_status = Some(adapt_status_code(data)?);
|
||||
self.response_headers = Some(adapt_headers(data));
|
||||
self.response_started = true;
|
||||
asyncw.call0()
|
||||
},
|
||||
true => error_flow!()
|
||||
Ok(ASGIMessageType::HTTPStart) => match self.response_started {
|
||||
false => {
|
||||
self.response_status = Some(adapt_status_code(data)?);
|
||||
self.response_headers = Some(adapt_headers(data));
|
||||
self.response_started = true;
|
||||
asyncw.call0()
|
||||
}
|
||||
true => error_flow!(),
|
||||
},
|
||||
Ok(ASGIMessageType::HTTPBody) => {
|
||||
let (body, more) = adapt_body(data);
|
||||
|
@ -133,7 +128,7 @@ impl ASGIHTTPProtocol {
|
|||
let headers = self.response_headers.take().unwrap();
|
||||
self.send_response(self.response_status.unwrap(), headers, body.into());
|
||||
asyncw.call0()
|
||||
},
|
||||
}
|
||||
(true, true, false) => {
|
||||
self.response_chunked = true;
|
||||
let headers = self.response_headers.take().unwrap();
|
||||
|
@ -142,37 +137,31 @@ impl ASGIHTTPProtocol {
|
|||
self.body_tx = Some(tx.clone());
|
||||
self.send_response(self.response_status.unwrap(), headers, body_stream);
|
||||
self.send_body(py, tx, body)
|
||||
},
|
||||
(true, true, true) => {
|
||||
match self.body_tx.as_mut() {
|
||||
Some(tx) => {
|
||||
let tx = tx.clone();
|
||||
self.send_body(py, tx, body)
|
||||
},
|
||||
_ => error_flow!()
|
||||
}
|
||||
(true, true, true) => match self.body_tx.as_mut() {
|
||||
Some(tx) => {
|
||||
let tx = tx.clone();
|
||||
self.send_body(py, tx, body)
|
||||
}
|
||||
_ => error_flow!(),
|
||||
},
|
||||
(true, false, true) => {
|
||||
match self.body_tx.take() {
|
||||
Some(tx) => {
|
||||
match body.is_empty() {
|
||||
false => self.send_body(py, tx, body),
|
||||
true => asyncw.call0()
|
||||
}
|
||||
},
|
||||
_ => error_flow!()
|
||||
}
|
||||
(true, false, true) => match self.body_tx.take() {
|
||||
Some(tx) => match body.is_empty() {
|
||||
false => self.send_body(py, tx, body),
|
||||
true => asyncw.call0(),
|
||||
},
|
||||
_ => error_flow!(),
|
||||
},
|
||||
_ => error_flow!()
|
||||
_ => error_flow!(),
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(err) => Err(err.into()),
|
||||
_ => error_message!()
|
||||
_ => error_message!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct ASGIWebsocketProtocol {
|
||||
rt: RuntimeRef,
|
||||
tx: Option<oneshot::Sender<bool>>,
|
||||
|
@ -181,16 +170,11 @@ pub(crate) struct ASGIWebsocketProtocol {
|
|||
ws_tx: Arc<Mutex<Option<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>>,
|
||||
ws_rx: Arc<Mutex<Option<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>>,
|
||||
accepted: Arc<Mutex<bool>>,
|
||||
closed: bool
|
||||
closed: bool,
|
||||
}
|
||||
|
||||
impl ASGIWebsocketProtocol {
|
||||
pub fn new(
|
||||
rt: RuntimeRef,
|
||||
tx: oneshot::Sender<bool>,
|
||||
websocket: HyperWebsocket,
|
||||
upgrade: UpgradeData
|
||||
) -> Self {
|
||||
pub fn new(rt: RuntimeRef, tx: oneshot::Sender<bool>, websocket: HyperWebsocket, upgrade: UpgradeData) -> Self {
|
||||
Self {
|
||||
rt,
|
||||
tx: Some(tx),
|
||||
|
@ -199,7 +183,7 @@ impl ASGIWebsocketProtocol {
|
|||
ws_tx: Arc::new(Mutex::new(None)),
|
||||
ws_rx: Arc::new(Mutex::new(None)),
|
||||
accepted: Arc::new(Mutex::new(false)),
|
||||
closed: false
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,7 +195,7 @@ impl ASGIWebsocketProtocol {
|
|||
let tx = self.ws_tx.clone();
|
||||
let rx = self.ws_rx.clone();
|
||||
future_into_py_iter(self.rt.clone(), py, async move {
|
||||
if let Ok(_) = upgrade.send().await {
|
||||
if (upgrade.send().await).is_ok() {
|
||||
if let Ok(stream) = websocket.await {
|
||||
let mut wtx = tx.lock().await;
|
||||
let mut wrx = rx.lock().await;
|
||||
|
@ -220,7 +204,7 @@ impl ASGIWebsocketProtocol {
|
|||
*wtx = Some(tx);
|
||||
*wrx = Some(rx);
|
||||
*accepted = true;
|
||||
return Ok(())
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
error_flow!()
|
||||
|
@ -228,18 +212,14 @@ impl ASGIWebsocketProtocol {
|
|||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn send_message<'p>(
|
||||
&self,
|
||||
py: Python<'p>,
|
||||
data: &'p PyDict
|
||||
) -> PyResult<&'p PyAny> {
|
||||
fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
|
||||
let transport = self.ws_tx.clone();
|
||||
let message = ws_message_into_rs(data);
|
||||
future_into_py_iter(self.rt.clone(), py, async move {
|
||||
if let Ok(message) = message {
|
||||
if let Some(ws) = &mut *(transport.lock().await) {
|
||||
if let Ok(_) = ws.send(message).await {
|
||||
return Ok(())
|
||||
if (ws.send(message).await).is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
};
|
||||
|
@ -253,8 +233,8 @@ impl ASGIWebsocketProtocol {
|
|||
let transport = self.ws_tx.clone();
|
||||
future_into_py_iter(self.rt.clone(), py, async move {
|
||||
if let Some(ws) = &mut *(transport.lock().await) {
|
||||
if let Ok(_) = ws.close().await {
|
||||
return Ok(())
|
||||
if (ws.close().await).is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
error_flow!()
|
||||
|
@ -262,10 +242,7 @@ impl ASGIWebsocketProtocol {
|
|||
}
|
||||
|
||||
fn consumed(&self) -> bool {
|
||||
match &self.upgrade {
|
||||
Some(_) => false,
|
||||
_ => true
|
||||
}
|
||||
self.upgrade.is_none()
|
||||
}
|
||||
|
||||
pub fn tx(&mut self) -> (Option<oneshot::Sender<bool>>, bool) {
|
||||
|
@ -278,44 +255,26 @@ impl ASGIWebsocketProtocol {
|
|||
fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
|
||||
let transport = self.ws_rx.clone();
|
||||
let accepted = self.accepted.clone();
|
||||
let closed = self.closed.clone();
|
||||
let closed = self.closed;
|
||||
future_into_py_futlike(self.rt.clone(), py, async move {
|
||||
let accepted = accepted.lock().await;
|
||||
match (*accepted, closed) {
|
||||
(false, false) => {
|
||||
return Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "type"),
|
||||
pyo3::intern!(py, "websocket.connect")
|
||||
)?;
|
||||
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?;
|
||||
Ok(dict.to_object(py))
|
||||
})
|
||||
},
|
||||
(true, false) => {},
|
||||
_ => {
|
||||
return error_flow!()
|
||||
}
|
||||
(true, false) => {}
|
||||
_ => return error_flow!(),
|
||||
}
|
||||
if let Some(ws) = &mut *(transport.lock().await) {
|
||||
loop {
|
||||
match ws.next().await {
|
||||
Some(recv) => {
|
||||
match recv {
|
||||
Ok(Message::Ping(_)) => {
|
||||
continue
|
||||
},
|
||||
Ok(message) => {
|
||||
return ws_message_into_py(message)
|
||||
},
|
||||
_ => {
|
||||
break
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
break
|
||||
}
|
||||
while let Some(recv) = ws.next().await {
|
||||
match recv {
|
||||
Ok(Message::Ping(_)) => continue,
|
||||
Ok(message) => return ws_message_into_py(message),
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -325,25 +284,17 @@ impl ASGIWebsocketProtocol {
|
|||
|
||||
fn send<'p>(&mut self, py: Python<'p>, _asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> {
|
||||
match (adapt_message_type(data), self.closed) {
|
||||
(Ok(ASGIMessageType::WSAccept), _) => {
|
||||
self.accept(py)
|
||||
},
|
||||
(Ok(ASGIMessageType::WSClose), false) => {
|
||||
self.close(py)
|
||||
},
|
||||
(Ok(ASGIMessageType::WSMessage), false) => {
|
||||
self.send_message(py, data)
|
||||
},
|
||||
(Ok(ASGIMessageType::WSAccept), _) => self.accept(py),
|
||||
(Ok(ASGIMessageType::WSClose), false) => self.close(py),
|
||||
(Ok(ASGIMessageType::WSMessage), false) => self.send_message(py, data),
|
||||
(Err(err), _) => Err(err.into()),
|
||||
_ => error_message!()
|
||||
_ => error_message!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
fn adapt_message_type(
|
||||
message: &PyDict
|
||||
) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
|
||||
fn adapt_message_type(message: &PyDict) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
|
||||
match message.get_item("type") {
|
||||
Some(item) => {
|
||||
let message_type: &str = item.extract()?;
|
||||
|
@ -353,20 +304,18 @@ fn adapt_message_type(
|
|||
"websocket.accept" => Ok(ASGIMessageType::WSAccept),
|
||||
"websocket.close" => Ok(ASGIMessageType::WSClose),
|
||||
"websocket.send" => Ok(ASGIMessageType::WSMessage),
|
||||
_ => error_message!()
|
||||
_ => error_message!(),
|
||||
}
|
||||
},
|
||||
_ => error_message!()
|
||||
}
|
||||
_ => error_message!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn adapt_status_code(message: &PyDict) -> Result<i16, UnsupportedASGIMessage> {
|
||||
match message.get_item("status") {
|
||||
Some(item) => {
|
||||
Ok(item.extract()?)
|
||||
},
|
||||
_ => error_message!()
|
||||
Some(item) => Ok(item.extract()?),
|
||||
_ => error_message!(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -377,34 +326,26 @@ fn adapt_headers(message: &PyDict) -> HeaderMap {
|
|||
match message.get_item("headers") {
|
||||
Some(item) => {
|
||||
let accum: Vec<Vec<&[u8]>> = item.extract().unwrap_or(Vec::new());
|
||||
for tup in accum.iter() {
|
||||
match (
|
||||
HeaderName::from_bytes(tup[0]),
|
||||
HeaderValue::from_bytes(tup[1])
|
||||
) {
|
||||
(Ok(key), Ok(val)) => { ret.append(key, val); },
|
||||
_ => {}
|
||||
for tup in &accum {
|
||||
if let (Ok(key), Ok(val)) = (HeaderName::from_bytes(tup[0]), HeaderValue::from_bytes(tup[1])) {
|
||||
ret.append(key, val);
|
||||
}
|
||||
};
|
||||
}
|
||||
ret
|
||||
},
|
||||
_ => ret
|
||||
}
|
||||
_ => ret,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn adapt_body(message: &PyDict) -> (Vec<u8>, bool) {
|
||||
let body = match message.get_item("body") {
|
||||
Some(item) => {
|
||||
item.extract().unwrap_or(EMPTY_BYTES)
|
||||
},
|
||||
_ => EMPTY_BYTES
|
||||
Some(item) => item.extract().unwrap_or(EMPTY_BYTES),
|
||||
_ => EMPTY_BYTES,
|
||||
};
|
||||
let more = match message.get_item("more_body") {
|
||||
Some(item) => {
|
||||
item.extract().unwrap_or(false)
|
||||
},
|
||||
_ => false
|
||||
Some(item) => item.extract().unwrap_or(false),
|
||||
_ => false,
|
||||
};
|
||||
(body, more)
|
||||
}
|
||||
|
@ -412,22 +353,12 @@ fn adapt_body(message: &PyDict) -> (Vec<u8>, bool) {
|
|||
#[inline(always)]
|
||||
fn ws_message_into_rs(message: &PyDict) -> PyResult<Message> {
|
||||
match (message.get_item("bytes"), message.get_item("text")) {
|
||||
(Some(item), None) => {
|
||||
Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES)))
|
||||
},
|
||||
(None, Some(item)) => {
|
||||
Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING)))
|
||||
},
|
||||
(Some(itemb), Some(itemt)) => {
|
||||
match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
|
||||
(Some(msgb), None) => {
|
||||
Ok(Message::Binary(msgb))
|
||||
},
|
||||
(None, Some(msgt)) => {
|
||||
Ok(Message::Text(msgt))
|
||||
},
|
||||
_ => error_flow!()
|
||||
}
|
||||
(Some(item), None) => Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES))),
|
||||
(None, Some(item)) => Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING))),
|
||||
(Some(itemb), Some(itemt)) => match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
|
||||
(Some(msgb), None) => Ok(Message::Binary(msgb)),
|
||||
(None, Some(msgt)) => Ok(Message::Text(msgt)),
|
||||
_ => error_flow!(),
|
||||
},
|
||||
_ => {
|
||||
error_flow!()
|
||||
|
@ -438,41 +369,23 @@ fn ws_message_into_rs(message: &PyDict) -> PyResult<Message> {
|
|||
#[inline(always)]
|
||||
fn ws_message_into_py(message: Message) -> PyResult<PyObject> {
|
||||
match message {
|
||||
Message::Binary(message) => {
|
||||
Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "type"),
|
||||
pyo3::intern!(py, "websocket.receive")
|
||||
)?;
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "bytes"),
|
||||
PyBytes::new(py, &message[..])
|
||||
)?;
|
||||
Ok(dict.to_object(py))
|
||||
})
|
||||
},
|
||||
Message::Text(message) => {
|
||||
Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "type"),
|
||||
pyo3::intern!(py, "websocket.receive")
|
||||
)?;
|
||||
dict.set_item(pyo3::intern!(py, "text"), message)?;
|
||||
Ok(dict.to_object(py))
|
||||
})
|
||||
},
|
||||
Message::Close(_) => {
|
||||
Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "type"),
|
||||
pyo3::intern!(py, "websocket.disconnect")
|
||||
)?;
|
||||
Ok(dict.to_object(py))
|
||||
})
|
||||
},
|
||||
Message::Binary(message) => Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
|
||||
dict.set_item(pyo3::intern!(py, "bytes"), PyBytes::new(py, &message[..]))?;
|
||||
Ok(dict.to_object(py))
|
||||
}),
|
||||
Message::Text(message) => Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
|
||||
dict.set_item(pyo3::intern!(py, "text"), message)?;
|
||||
Ok(dict.to_object(py))
|
||||
}),
|
||||
Message::Close(_) => Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.disconnect"))?;
|
||||
Ok(dict.to_object(py))
|
||||
}),
|
||||
v => {
|
||||
log::warn!("Unsupported websocket message received {:?}", v);
|
||||
error_flow!()
|
||||
|
|
|
@ -1,29 +1,14 @@
|
|||
use pyo3::prelude::*;
|
||||
|
||||
use crate::{
|
||||
workers::{
|
||||
WorkerConfig,
|
||||
serve_rth,
|
||||
serve_wth,
|
||||
serve_rth_ssl,
|
||||
serve_wth_ssl
|
||||
}
|
||||
};
|
||||
use super::http::{
|
||||
handle_rtb,
|
||||
handle_rtb_pyw,
|
||||
handle_rtt,
|
||||
handle_rtt_pyw,
|
||||
handle_rtb_ws,
|
||||
handle_rtb_ws_pyw,
|
||||
handle_rtt_ws,
|
||||
handle_rtt_ws_pyw
|
||||
handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws,
|
||||
handle_rtt_ws_pyw,
|
||||
};
|
||||
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
|
||||
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub struct ASGIWorker {
|
||||
config: WorkerConfig
|
||||
config: WorkerConfig,
|
||||
}
|
||||
|
||||
impl ASGIWorker {
|
||||
|
@ -74,7 +59,7 @@ impl ASGIWorker {
|
|||
opt_enabled: bool,
|
||||
ssl_enabled: bool,
|
||||
ssl_cert: Option<&str>,
|
||||
ssl_key: Option<&str>
|
||||
ssl_key: Option<&str>,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
config: WorkerConfig::new(
|
||||
|
@ -88,22 +73,16 @@ impl ASGIWorker {
|
|||
opt_enabled,
|
||||
ssl_enabled,
|
||||
ssl_cert,
|
||||
ssl_key
|
||||
)
|
||||
ssl_key,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn serve_rth(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
match (
|
||||
self.config.websockets_enabled,
|
||||
self.config.ssl_enabled,
|
||||
self.config.opt_enabled
|
||||
self.config.opt_enabled,
|
||||
) {
|
||||
(false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx),
|
||||
(false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx),
|
||||
|
@ -112,21 +91,15 @@ impl ASGIWorker {
|
|||
(false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
|
||||
(false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx),
|
||||
(true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx),
|
||||
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
|
||||
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
|
||||
}
|
||||
}
|
||||
|
||||
fn serve_wth(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
match (
|
||||
self.config.websockets_enabled,
|
||||
self.config.ssl_enabled,
|
||||
self.config.opt_enabled
|
||||
self.config.opt_enabled,
|
||||
) {
|
||||
(false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx),
|
||||
(false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx),
|
||||
|
@ -135,7 +108,7 @@ impl ASGIWorker {
|
|||
(false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
|
||||
(false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx),
|
||||
(true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx),
|
||||
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
|
||||
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
use hyper::{Uri, Version, header::HeaderMap};
|
||||
use hyper::{header::HeaderMap, Uri, Version};
|
||||
use once_cell::sync::OnceCell;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyBytes, PyDict, PyList, PyString};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
|
||||
const SCHEME_HTTPS: &str = "https";
|
||||
const SCHEME_WS: &str = "ws";
|
||||
const SCHEME_WSS: &str = "wss";
|
||||
|
@ -17,10 +16,10 @@ pub(crate) enum ASGIMessageType {
|
|||
HTTPBody,
|
||||
WSAccept,
|
||||
WSClose,
|
||||
WSMessage
|
||||
WSMessage,
|
||||
}
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct ASGIScope {
|
||||
http_version: Version,
|
||||
scheme: String,
|
||||
|
@ -31,7 +30,7 @@ pub(crate) struct ASGIScope {
|
|||
client_ip: IpAddr,
|
||||
client_port: u16,
|
||||
headers: HeaderMap,
|
||||
is_websocket: bool
|
||||
is_websocket: bool,
|
||||
}
|
||||
|
||||
impl ASGIScope {
|
||||
|
@ -42,31 +41,31 @@ impl ASGIScope {
|
|||
method: &str,
|
||||
server: SocketAddr,
|
||||
client: SocketAddr,
|
||||
headers: &HeaderMap
|
||||
headers: &HeaderMap,
|
||||
) -> Self {
|
||||
Self {
|
||||
http_version: http_version,
|
||||
http_version,
|
||||
scheme: scheme.to_string(),
|
||||
method: method.to_string(),
|
||||
uri: uri,
|
||||
uri,
|
||||
server_ip: server.ip(),
|
||||
server_port: server.port(),
|
||||
client_ip: client.ip(),
|
||||
client_port: client.port(),
|
||||
headers: headers.to_owned(),
|
||||
is_websocket: false
|
||||
headers: headers.clone(),
|
||||
is_websocket: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_websocket(&mut self) {
|
||||
self.is_websocket = true
|
||||
self.is_websocket = true;
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn py_proto(&self) -> &str {
|
||||
match self.is_websocket {
|
||||
false => "http",
|
||||
true => "websocket"
|
||||
true => "websocket",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,7 +75,7 @@ impl ASGIScope {
|
|||
Version::HTTP_10 => "1",
|
||||
Version::HTTP_11 => "1.1",
|
||||
Version::HTTP_2 => "2",
|
||||
_ => "1"
|
||||
_ => "1",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,22 +84,20 @@ impl ASGIScope {
|
|||
let scheme = &self.scheme[..];
|
||||
match self.is_websocket {
|
||||
false => scheme,
|
||||
true => {
|
||||
match scheme {
|
||||
SCHEME_HTTPS => SCHEME_WSS,
|
||||
_ => SCHEME_WS
|
||||
}
|
||||
}
|
||||
true => match scheme {
|
||||
SCHEME_HTTPS => SCHEME_WSS,
|
||||
_ => SCHEME_WS,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn py_headers<'p>(&self, py: Python<'p>) -> PyResult<&'p PyList> {
|
||||
let rv = PyList::empty(py);
|
||||
for (key, value) in self.headers.iter() {
|
||||
for (key, value) in &self.headers {
|
||||
rv.append((
|
||||
PyBytes::new(py, key.as_str().as_bytes()),
|
||||
PyBytes::new(py, value.as_bytes())
|
||||
PyBytes::new(py, value.as_bytes()),
|
||||
))?;
|
||||
}
|
||||
Ok(rv)
|
||||
|
@ -110,17 +107,10 @@ impl ASGIScope {
|
|||
#[pymethods]
|
||||
impl ASGIScope {
|
||||
fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str) -> PyResult<&'p PyAny> {
|
||||
let (
|
||||
path,
|
||||
query_string,
|
||||
proto,
|
||||
http_version,
|
||||
server,
|
||||
client,
|
||||
scheme,
|
||||
method
|
||||
) = py.allow_threads(|| {
|
||||
let (path, query_string) = self.uri.path_and_query()
|
||||
let (path, query_string, proto, http_version, server, client, scheme, method) = py.allow_threads(|| {
|
||||
let (path, query_string) = self
|
||||
.uri
|
||||
.path_and_query()
|
||||
.map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or("")));
|
||||
(
|
||||
path,
|
||||
|
@ -136,19 +126,23 @@ impl ASGIScope {
|
|||
let dict: &PyDict = PyDict::new(py);
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "asgi"),
|
||||
ASGI_VERSION.get_or_try_init(|| {
|
||||
let rv = PyDict::new(py);
|
||||
rv.set_item("version", "3.0")?;
|
||||
rv.set_item("spec_version", "2.3")?;
|
||||
Ok::<PyObject, PyErr>(rv.into())
|
||||
})?.as_ref(py)
|
||||
ASGI_VERSION
|
||||
.get_or_try_init(|| {
|
||||
let rv = PyDict::new(py);
|
||||
rv.set_item("version", "3.0")?;
|
||||
rv.set_item("spec_version", "2.3")?;
|
||||
Ok::<PyObject, PyErr>(rv.into())
|
||||
})?
|
||||
.as_ref(py),
|
||||
)?;
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "extensions"),
|
||||
ASGI_EXTENSIONS.get_or_try_init(|| {
|
||||
let rv = PyDict::new(py);
|
||||
Ok::<PyObject, PyErr>(rv.into())
|
||||
})?.as_ref(py)
|
||||
ASGI_EXTENSIONS
|
||||
.get_or_try_init(|| {
|
||||
let rv = PyDict::new(py);
|
||||
Ok::<PyObject, PyErr>(rv.into())
|
||||
})?
|
||||
.as_ref(py),
|
||||
)?;
|
||||
dict.set_item(pyo3::intern!(py, "type"), proto)?;
|
||||
dict.set_item(pyo3::intern!(py, "http_version"), http_version)?;
|
||||
|
@ -160,17 +154,12 @@ impl ASGIScope {
|
|||
dict.set_item(pyo3::intern!(py, "path"), path)?;
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "raw_path"),
|
||||
PyString::new(py, path)
|
||||
.call_method1(
|
||||
pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),)
|
||||
)?
|
||||
PyString::new(py, path).call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),))?,
|
||||
)?;
|
||||
dict.set_item(
|
||||
pyo3::intern!(py, "query_string"),
|
||||
PyString::new(py, query_string)
|
||||
.call_method1(
|
||||
pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),)
|
||||
)?
|
||||
.call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),))?,
|
||||
)?;
|
||||
dict.set_item(pyo3::intern!(py, "headers"), self.py_headers(py)?)?;
|
||||
Ok(dict)
|
||||
|
|
161
src/callbacks.rs
161
src/callbacks.rs
|
@ -2,32 +2,27 @@ use once_cell::sync::OnceCell;
|
|||
use pyo3::prelude::*;
|
||||
use pyo3::pyclass::IterNextOutput;
|
||||
|
||||
|
||||
static CONTEXTVARS: OnceCell<PyObject> = OnceCell::new();
|
||||
static CONTEXT: OnceCell<PyObject> = OnceCell::new();
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct CallbackWrapper {
|
||||
pub callback: PyObject,
|
||||
pub context: pyo3_asyncio::TaskLocals
|
||||
pub context: pyo3_asyncio::TaskLocals,
|
||||
}
|
||||
|
||||
impl CallbackWrapper {
|
||||
pub(crate) fn new(
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny
|
||||
) -> Self {
|
||||
pub(crate) fn new(callback: PyObject, event_loop: &PyAny, context: &PyAny) -> Self {
|
||||
Self {
|
||||
callback,
|
||||
context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context)
|
||||
context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub(crate) struct PyIterAwaitable {
|
||||
result: Option<PyResult<PyObject>>
|
||||
result: Option<PyResult<PyObject>>,
|
||||
}
|
||||
|
||||
impl PyIterAwaitable {
|
||||
|
@ -36,7 +31,7 @@ impl PyIterAwaitable {
|
|||
}
|
||||
|
||||
pub(crate) fn set_result(&mut self, result: PyResult<PyObject>) {
|
||||
self.result = Some(result)
|
||||
self.result = Some(result);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,13 +47,11 @@ impl PyIterAwaitable {
|
|||
|
||||
fn __next__(&mut self, py: Python) -> PyResult<IterNextOutput<PyObject, PyObject>> {
|
||||
match self.result.take() {
|
||||
Some(res) => {
|
||||
match res {
|
||||
Ok(v) => Ok(IterNextOutput::Return(v)),
|
||||
Err(err) => Err(err)
|
||||
}
|
||||
Some(res) => match res {
|
||||
Ok(v) => Ok(IterNextOutput::Return(v)),
|
||||
Err(err) => Err(err),
|
||||
},
|
||||
_ => Ok(IterNextOutput::Yield(py.None()))
|
||||
_ => Ok(IterNextOutput::Yield(py.None())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -68,7 +61,7 @@ pub(crate) struct PyFutureAwaitable {
|
|||
py_block: bool,
|
||||
event_loop: PyObject,
|
||||
result: Option<PyResult<PyObject>>,
|
||||
cb: Option<(PyObject, Py<pyo3::types::PyDict>)>
|
||||
cb: Option<(PyObject, Py<pyo3::types::PyDict>)>,
|
||||
}
|
||||
|
||||
impl PyFutureAwaitable {
|
||||
|
@ -77,7 +70,7 @@ impl PyFutureAwaitable {
|
|||
event_loop,
|
||||
py_block: true,
|
||||
result: None,
|
||||
cb: None
|
||||
cb: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,12 +78,9 @@ impl PyFutureAwaitable {
|
|||
pyself.result = Some(result);
|
||||
if let Some((cb, ctx)) = pyself.cb.take() {
|
||||
let py = pyself.py();
|
||||
let _ = pyself.event_loop.call_method(
|
||||
py,
|
||||
"call_soon_threadsafe",
|
||||
(cb, &pyself),
|
||||
Some(ctx.as_ref(py))
|
||||
);
|
||||
let _ = pyself
|
||||
.event_loop
|
||||
.call_method(py, "call_soon_threadsafe", (cb, &pyself), Some(ctx.as_ref(py)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -108,25 +98,22 @@ impl PyFutureAwaitable {
|
|||
|
||||
#[setter(_asyncio_future_blocking)]
|
||||
fn set_block(&mut self, val: bool) {
|
||||
self.py_block = val
|
||||
self.py_block = val;
|
||||
}
|
||||
|
||||
fn get_loop(&mut self) -> PyObject {
|
||||
self.event_loop.clone()
|
||||
}
|
||||
|
||||
fn add_done_callback(
|
||||
mut pyself: PyRefMut<'_, Self>,
|
||||
py: Python,
|
||||
cb: PyObject,
|
||||
context: PyObject
|
||||
) -> PyResult<()> {
|
||||
fn add_done_callback(mut pyself: PyRefMut<'_, Self>, py: Python, cb: PyObject, context: PyObject) -> PyResult<()> {
|
||||
let kwctx = pyo3::types::PyDict::new(py);
|
||||
kwctx.set_item("context", context)?;
|
||||
match pyself.result {
|
||||
Some(_) => {
|
||||
pyself.event_loop.call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?;
|
||||
},
|
||||
pyself
|
||||
.event_loop
|
||||
.call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?;
|
||||
}
|
||||
_ => {
|
||||
pyself.cb = Some((cb, kwctx.into_py(py)));
|
||||
}
|
||||
|
@ -136,9 +123,9 @@ impl PyFutureAwaitable {
|
|||
|
||||
fn cancel(mut pyself: PyRefMut<'_, Self>, py: Python) -> bool {
|
||||
if let Some((cb, kwctx)) = pyself.cb.take() {
|
||||
let _ = pyself.event_loop.call_method(
|
||||
py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py))
|
||||
);
|
||||
let _ = pyself
|
||||
.event_loop
|
||||
.call_method(py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py)));
|
||||
}
|
||||
false
|
||||
}
|
||||
|
@ -150,79 +137,69 @@ impl PyFutureAwaitable {
|
|||
pyself
|
||||
}
|
||||
|
||||
fn __next__(
|
||||
mut pyself: PyRefMut<'_, Self>
|
||||
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
|
||||
fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
|
||||
match pyself.result {
|
||||
Some(_) => {
|
||||
match pyself.result.take().unwrap() {
|
||||
Ok(v) => Ok(IterNextOutput::Return(v)),
|
||||
Err(err) => Err(err)
|
||||
}
|
||||
Some(_) => match pyself.result.take().unwrap() {
|
||||
Ok(v) => Ok(IterNextOutput::Return(v)),
|
||||
Err(err) => Err(err),
|
||||
},
|
||||
_ => Ok(IterNextOutput::Yield(pyself))
|
||||
_ => Ok(IterNextOutput::Yield(pyself)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn contextvars(py: Python) -> PyResult<&PyAny> {
|
||||
Ok(CONTEXTVARS
|
||||
.get_or_try_init(|| py.import("contextvars").map(|m| m.into()))?
|
||||
.get_or_try_init(|| py.import("contextvars").map(std::convert::Into::into))?
|
||||
.as_ref(py))
|
||||
}
|
||||
|
||||
pub fn empty_pycontext(py: Python) -> PyResult<&PyAny> {
|
||||
Ok(CONTEXT
|
||||
.get_or_try_init(|| contextvars(py)?.getattr("Context")?.call0().map(|c| c.into()))?
|
||||
.get_or_try_init(|| {
|
||||
contextvars(py)?
|
||||
.getattr("Context")?
|
||||
.call0()
|
||||
.map(std::convert::Into::into)
|
||||
})?
|
||||
.as_ref(py))
|
||||
}
|
||||
|
||||
macro_rules! callback_impl_run {
|
||||
() => {
|
||||
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
|
||||
pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
|
||||
let event_loop = self.context.event_loop(py);
|
||||
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
|
||||
let kwctx = pyo3::types::PyDict::new(py);
|
||||
kwctx.set_item(
|
||||
pyo3::intern!(py, "context"),
|
||||
crate::callbacks::empty_pycontext(py)?
|
||||
crate::callbacks::empty_pycontext(py)?,
|
||||
)?;
|
||||
event_loop.call_method(
|
||||
pyo3::intern!(py, "call_soon_threadsafe"),
|
||||
(target,),
|
||||
Some(kwctx)
|
||||
)
|
||||
event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! callback_impl_run_pytask {
|
||||
() => {
|
||||
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
|
||||
pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
|
||||
let event_loop = self.context.event_loop(py);
|
||||
let context = self.context.context(py);
|
||||
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
|
||||
let kwctx = pyo3::types::PyDict::new(py);
|
||||
kwctx.set_item(
|
||||
pyo3::intern!(py, "context"),
|
||||
context
|
||||
)?;
|
||||
event_loop.call_method(
|
||||
pyo3::intern!(py, "call_soon_threadsafe"),
|
||||
(target,),
|
||||
Some(kwctx)
|
||||
)
|
||||
kwctx.set_item(pyo3::intern!(py, "context"), context)?;
|
||||
event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! callback_impl_loop_run {
|
||||
() => {
|
||||
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
|
||||
pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> {
|
||||
let context = self.pycontext.clone().into_ref(py);
|
||||
context.call_method1(
|
||||
pyo3::intern!(py, "run"),
|
||||
(self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,)
|
||||
(self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
@ -232,7 +209,7 @@ macro_rules! callback_impl_loop_pytask {
|
|||
($pyself:expr, $py:expr) => {
|
||||
$pyself.context.event_loop($py).call_method1(
|
||||
pyo3::intern!($py, "create_task"),
|
||||
($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,)
|
||||
($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,),
|
||||
)
|
||||
};
|
||||
}
|
||||
|
@ -241,12 +218,9 @@ macro_rules! callback_impl_loop_step {
|
|||
($pyself:expr, $py:expr) => {
|
||||
match $pyself.cb.call_method1($py, pyo3::intern!($py, "send"), ($py.None(),)) {
|
||||
Ok(res) => {
|
||||
let blocking: bool = match res.getattr(
|
||||
$py,
|
||||
pyo3::intern!($py, "_asyncio_future_blocking")
|
||||
) {
|
||||
let blocking: bool = match res.getattr($py, pyo3::intern!($py, "_asyncio_future_blocking")) {
|
||||
Ok(v) => v.extract($py)?,
|
||||
_ => false
|
||||
_ => false,
|
||||
};
|
||||
|
||||
let ctx = $pyself.pycontext.clone();
|
||||
|
@ -255,43 +229,30 @@ macro_rules! callback_impl_loop_step {
|
|||
|
||||
match blocking {
|
||||
true => {
|
||||
res.setattr(
|
||||
$py,
|
||||
pyo3::intern!($py, "_asyncio_future_blocking"),
|
||||
false
|
||||
)?;
|
||||
res.setattr($py, pyo3::intern!($py, "_asyncio_future_blocking"), false)?;
|
||||
res.call_method(
|
||||
$py,
|
||||
pyo3::intern!($py, "add_done_callback"),
|
||||
(
|
||||
$pyself
|
||||
.into_py($py)
|
||||
.getattr($py, pyo3::intern!($py, "_loop_wake"))?,
|
||||
),
|
||||
Some(kwctx)
|
||||
($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_wake"))?,),
|
||||
Some(kwctx),
|
||||
)?;
|
||||
Ok(())
|
||||
},
|
||||
}
|
||||
false => {
|
||||
let event_loop = $pyself.context.event_loop($py);
|
||||
event_loop.call_method(
|
||||
pyo3::intern!($py, "call_soon"),
|
||||
(
|
||||
$pyself
|
||||
.into_py($py)
|
||||
.getattr($py, pyo3::intern!($py, "_loop_step"))?,
|
||||
),
|
||||
Some(kwctx)
|
||||
($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_step"))?,),
|
||||
Some(kwctx),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(err) => {
|
||||
if (
|
||||
err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py) ||
|
||||
err.is_instance_of::<pyo3::exceptions::asyncio::CancelledError>($py)
|
||||
) {
|
||||
if (err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py)
|
||||
|| err.is_instance_of::<pyo3::exceptions::asyncio::CancelledError>($py))
|
||||
{
|
||||
$pyself.done($py);
|
||||
Ok(())
|
||||
} else {
|
||||
|
@ -307,7 +268,7 @@ macro_rules! callback_impl_loop_wake {
|
|||
($pyself:expr, $py:expr, $fut:expr) => {
|
||||
match $fut.call_method0($py, pyo3::intern!($py, "result")) {
|
||||
Ok(_) => $pyself.into_py($py).call_method0($py, pyo3::intern!($py, "_loop_step")),
|
||||
Err(err) => $pyself._loop_err($py, err)
|
||||
Err(err) => $pyself._loop_err($py, err),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -322,10 +283,10 @@ macro_rules! callback_impl_loop_err {
|
|||
};
|
||||
}
|
||||
|
||||
pub(crate) use callback_impl_run;
|
||||
pub(crate) use callback_impl_run_pytask;
|
||||
pub(crate) use callback_impl_loop_run;
|
||||
pub(crate) use callback_impl_loop_err;
|
||||
pub(crate) use callback_impl_loop_pytask;
|
||||
pub(crate) use callback_impl_loop_run;
|
||||
pub(crate) use callback_impl_loop_step;
|
||||
pub(crate) use callback_impl_loop_wake;
|
||||
pub(crate) use callback_impl_loop_err;
|
||||
pub(crate) use callback_impl_run;
|
||||
pub(crate) use callback_impl_run_pytask;
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use hyper::{Body, Response, header::{HeaderValue, SERVER as HK_SERVER}};
|
||||
use hyper::{
|
||||
header::{HeaderValue, SERVER as HK_SERVER},
|
||||
Body, Response,
|
||||
};
|
||||
|
||||
pub(crate) const HV_SERVER: HeaderValue = HeaderValue::from_static("granian");
|
||||
|
||||
|
|
|
@ -8,8 +8,8 @@ mod callbacks;
|
|||
mod http;
|
||||
mod rsgi;
|
||||
mod runtime;
|
||||
mod tls;
|
||||
mod tcp;
|
||||
mod tls;
|
||||
mod utils;
|
||||
mod workers;
|
||||
mod ws;
|
||||
|
|
|
@ -2,45 +2,33 @@ use pyo3::prelude::*;
|
|||
use pyo3_asyncio::TaskLocals;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::{
|
||||
callbacks::{
|
||||
CallbackWrapper,
|
||||
callback_impl_run,
|
||||
callback_impl_run_pytask,
|
||||
callback_impl_loop_run,
|
||||
callback_impl_loop_pytask,
|
||||
callback_impl_loop_step,
|
||||
callback_impl_loop_wake,
|
||||
callback_impl_loop_err
|
||||
},
|
||||
runtime::RuntimeRef,
|
||||
ws::{HyperWebsocket, UpgradeData}
|
||||
};
|
||||
use super::{
|
||||
io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol},
|
||||
types::{RSGIScope as Scope, PyResponse, PyResponseBody}
|
||||
types::{PyResponse, PyResponseBody, RSGIScope as Scope},
|
||||
};
|
||||
use crate::{
|
||||
callbacks::{
|
||||
callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step,
|
||||
callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper,
|
||||
},
|
||||
runtime::RuntimeRef,
|
||||
ws::{HyperWebsocket, UpgradeData},
|
||||
};
|
||||
|
||||
|
||||
#[pyclass]
|
||||
pub(crate) struct CallbackRunnerHTTP {
|
||||
proto: Py<HTTPProtocol>,
|
||||
context: TaskLocals,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackRunnerHTTP {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: HTTPProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
|
||||
let pyproto = Py::new(py, proto).unwrap();
|
||||
Self {
|
||||
proto: pyproto.clone(),
|
||||
context: cb.context,
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,19 +46,17 @@ macro_rules! callback_impl_done_http {
|
|||
($self:expr, $py:expr) => {
|
||||
if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() {
|
||||
if let Some(tx) = proto.tx() {
|
||||
let _ = tx.send(
|
||||
PyResponse::Body(PyResponseBody::empty(500, Vec::new()))
|
||||
);
|
||||
let _ = tx.send(PyResponse::Body(PyResponseBody::empty(500, Vec::new())));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! callback_impl_done_err {
|
||||
($self:expr, $py:expr) => {
|
||||
log::warn!("Application callable raised an exception");
|
||||
$self.done($py)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
|
@ -78,18 +64,18 @@ pub(crate) struct CallbackTaskHTTP {
|
|||
proto: Py<HTTPProtocol>,
|
||||
context: TaskLocals,
|
||||
pycontext: PyObject,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackTaskHTTP {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: PyObject,
|
||||
proto: Py<HTTPProtocol>,
|
||||
context: TaskLocals
|
||||
) -> PyResult<Self> {
|
||||
pub fn new(py: Python, cb: PyObject, proto: Py<HTTPProtocol>, context: TaskLocals) -> PyResult<Self> {
|
||||
let pyctx = context.context(py);
|
||||
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb })
|
||||
Ok(Self {
|
||||
proto,
|
||||
context,
|
||||
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
|
||||
cb,
|
||||
})
|
||||
}
|
||||
|
||||
fn done(&self, py: Python) {
|
||||
|
@ -122,21 +108,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
|
|||
context: TaskLocals,
|
||||
cb: PyObject,
|
||||
#[pyo3(get)]
|
||||
scope: PyObject
|
||||
scope: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackWrappedRunnerHTTP {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: HTTPProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self {
|
||||
Self {
|
||||
proto: Py::new(py, proto).unwrap(),
|
||||
context: cb.context,
|
||||
cb: cb.callback,
|
||||
scope: scope.into_py(py)
|
||||
scope: scope.into_py(py),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -162,21 +143,16 @@ impl CallbackWrappedRunnerHTTP {
|
|||
pub(crate) struct CallbackRunnerWebsocket {
|
||||
proto: Py<WebsocketProtocol>,
|
||||
context: TaskLocals,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackRunnerWebsocket {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: WebsocketProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
|
||||
let pyproto = Py::new(py, proto).unwrap();
|
||||
Self {
|
||||
proto: pyproto.clone(),
|
||||
context: cb.context,
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap()
|
||||
cb: cb.callback.call1(py, (scope, pyproto)).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -197,7 +173,7 @@ macro_rules! callback_impl_done_ws {
|
|||
let _ = tx.send(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
|
@ -205,18 +181,18 @@ pub(crate) struct CallbackTaskWebsocket {
|
|||
proto: Py<WebsocketProtocol>,
|
||||
context: TaskLocals,
|
||||
pycontext: PyObject,
|
||||
cb: PyObject
|
||||
cb: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackTaskWebsocket {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: PyObject,
|
||||
proto: Py<WebsocketProtocol>,
|
||||
context: TaskLocals
|
||||
) -> PyResult<Self> {
|
||||
pub fn new(py: Python, cb: PyObject, proto: Py<WebsocketProtocol>, context: TaskLocals) -> PyResult<Self> {
|
||||
let pyctx = context.context(py);
|
||||
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb })
|
||||
Ok(Self {
|
||||
proto,
|
||||
context,
|
||||
pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(),
|
||||
cb,
|
||||
})
|
||||
}
|
||||
|
||||
fn done(&self, py: Python) {
|
||||
|
@ -249,21 +225,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
|
|||
context: TaskLocals,
|
||||
cb: PyObject,
|
||||
#[pyo3(get)]
|
||||
scope: PyObject
|
||||
scope: PyObject,
|
||||
}
|
||||
|
||||
impl CallbackWrappedRunnerWebsocket {
|
||||
pub fn new(
|
||||
py: Python,
|
||||
cb: CallbackWrapper,
|
||||
proto: WebsocketProtocol,
|
||||
scope: Scope
|
||||
) -> Self {
|
||||
pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self {
|
||||
Self {
|
||||
proto: Py::new(py, proto).unwrap(),
|
||||
context: cb.context,
|
||||
cb: cb.callback,
|
||||
scope: scope.into_py(py)
|
||||
scope: scope.into_py(py),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -291,7 +262,7 @@ macro_rules! call_impl_rtb_http {
|
|||
cb: CallbackWrapper,
|
||||
rt: RuntimeRef,
|
||||
req: hyper::Request<hyper::Body>,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<PyResponse> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = HTTPProtocol::new(rt, tx, req);
|
||||
|
@ -311,7 +282,7 @@ macro_rules! call_impl_rtt_http {
|
|||
cb: CallbackWrapper,
|
||||
rt: RuntimeRef,
|
||||
req: hyper::Request<hyper::Body>,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<PyResponse> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = HTTPProtocol::new(rt, tx, req);
|
||||
|
@ -334,7 +305,7 @@ macro_rules! call_impl_rtb_ws {
|
|||
rt: RuntimeRef,
|
||||
ws: HyperWebsocket,
|
||||
upgrade: UpgradeData,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<(i32, bool)> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
|
||||
|
@ -355,7 +326,7 @@ macro_rules! call_impl_rtt_ws {
|
|||
rt: RuntimeRef,
|
||||
ws: HyperWebsocket,
|
||||
upgrade: UpgradeData,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> oneshot::Receiver<(i32, bool)> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use pyo3::{create_exception, exceptions::PyRuntimeError};
|
||||
|
||||
|
||||
create_exception!(_granian, RSGIProtocolError, PyRuntimeError, "RSGIProtocolError");
|
||||
create_exception!(_granian, RSGIProtocolClosed, PyRuntimeError, "RSGIProtocolClosed");
|
||||
|
||||
|
|
|
@ -1,34 +1,22 @@
|
|||
use hyper::{
|
||||
Body,
|
||||
Request,
|
||||
Response,
|
||||
StatusCode,
|
||||
header::SERVER as HK_SERVER,
|
||||
http::response::Builder as ResponseBuilder
|
||||
header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{
|
||||
callbacks::CallbackWrapper,
|
||||
http::{HV_SERVER, response_500},
|
||||
runtime::RuntimeRef,
|
||||
ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade}
|
||||
};
|
||||
use super::{
|
||||
callbacks::{
|
||||
call_rtb_http,
|
||||
call_rtb_http_pyw,
|
||||
call_rtb_ws,
|
||||
call_rtb_ws_pyw,
|
||||
call_rtt_http,
|
||||
call_rtt_http_pyw,
|
||||
call_rtt_ws,
|
||||
call_rtt_ws_pyw
|
||||
call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws,
|
||||
call_rtt_ws_pyw,
|
||||
},
|
||||
types::{RSGIScope as Scope, PyResponse}
|
||||
types::{PyResponse, RSGIScope as Scope},
|
||||
};
|
||||
use crate::{
|
||||
callbacks::CallbackWrapper,
|
||||
http::{response_500, HV_SERVER},
|
||||
runtime::RuntimeRef,
|
||||
ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData},
|
||||
};
|
||||
|
||||
|
||||
macro_rules! default_scope {
|
||||
($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => {
|
||||
|
@ -40,7 +28,7 @@ macro_rules! default_scope {
|
|||
$req.method().as_ref(),
|
||||
$server_addr,
|
||||
$client_addr,
|
||||
$req.headers()
|
||||
$req.headers(),
|
||||
)
|
||||
};
|
||||
}
|
||||
|
@ -48,12 +36,8 @@ macro_rules! default_scope {
|
|||
macro_rules! handle_http_response {
|
||||
($handler:expr, $rt:expr, $callback:expr, $req:expr, $scope:expr) => {
|
||||
match $handler($callback, $rt, $req, $scope).await {
|
||||
Ok(PyResponse::Body(pyres)) => {
|
||||
pyres.to_response()
|
||||
},
|
||||
Ok(PyResponse::File(pyres)) => {
|
||||
pyres.to_response().await
|
||||
},
|
||||
Ok(PyResponse::Body(pyres)) => pyres.to_response(),
|
||||
Ok(PyResponse::File(pyres)) => pyres.to_response().await,
|
||||
_ => {
|
||||
log::error!("RSGI protocol failure");
|
||||
response_500()
|
||||
|
@ -70,7 +54,7 @@ macro_rules! handle_request {
|
|||
server_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
req: Request<Body>,
|
||||
scheme: &str
|
||||
scheme: &str,
|
||||
) -> Response<Body> {
|
||||
let scope = default_scope!(server_addr, client_addr, &req, scheme);
|
||||
handle_http_response!($handler, rt, callback, req, scope)
|
||||
|
@ -86,7 +70,7 @@ macro_rules! handle_request_with_ws {
|
|||
server_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
req: Request<Body>,
|
||||
scheme: &str
|
||||
scheme: &str,
|
||||
) -> Response<Body> {
|
||||
let mut scope = default_scope!(server_addr, client_addr, &req, scheme);
|
||||
|
||||
|
@ -101,27 +85,23 @@ macro_rules! handle_request_with_ws {
|
|||
rt.inner.spawn(async move {
|
||||
let tx_ref = restx.clone();
|
||||
|
||||
match $handler_ws(
|
||||
callback,
|
||||
rth,
|
||||
ws,
|
||||
UpgradeData::new(res, restx),
|
||||
scope
|
||||
).await {
|
||||
match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await {
|
||||
Ok((status, consumed)) => {
|
||||
if !consumed {
|
||||
let _ = tx_ref.send(
|
||||
ResponseBuilder::new()
|
||||
.status(
|
||||
StatusCode::from_u16(status as u16)
|
||||
.unwrap_or(StatusCode::FORBIDDEN)
|
||||
)
|
||||
.header(HK_SERVER, HV_SERVER)
|
||||
.body(Body::from(""))
|
||||
.unwrap()
|
||||
).await;
|
||||
let _ = tx_ref
|
||||
.send(
|
||||
ResponseBuilder::new()
|
||||
.status(
|
||||
StatusCode::from_u16(status as u16)
|
||||
.unwrap_or(StatusCode::FORBIDDEN),
|
||||
)
|
||||
.header(HK_SERVER, HV_SERVER)
|
||||
.body(Body::from(""))
|
||||
.unwrap(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
},
|
||||
}
|
||||
_ => {
|
||||
log::error!("RSGI protocol failure");
|
||||
let _ = tx_ref.send(response_500()).await;
|
||||
|
@ -133,10 +113,10 @@ macro_rules! handle_request_with_ws {
|
|||
Some(res) => {
|
||||
resrx.close();
|
||||
res
|
||||
},
|
||||
_ => response_500()
|
||||
}
|
||||
},
|
||||
}
|
||||
_ => response_500(),
|
||||
};
|
||||
}
|
||||
Err(err) => {
|
||||
return ResponseBuilder::new()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
|
@ -149,7 +129,6 @@ macro_rules! handle_request_with_ws {
|
|||
|
||||
handle_http_response!($handler_req, rt, callback, req, scope)
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
|
232
src/rsgi/io.rs
232
src/rsgi/io.rs
|
@ -1,35 +1,40 @@
|
|||
use bytes::Bytes;
|
||||
use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}};
|
||||
use hyper::{body::{Body, Sender as BodySender, HttpBody}, Request};
|
||||
use futures::{
|
||||
sink::SinkExt,
|
||||
stream::{SplitSink, SplitStream, StreamExt},
|
||||
};
|
||||
use hyper::{
|
||||
body::{Body, HttpBody, Sender as BodySender},
|
||||
Request,
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyBytes, PyString};
|
||||
use std::sync::Arc;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tungstenite::Message;
|
||||
|
||||
use crate::{
|
||||
runtime::{Runtime, RuntimeRef, future_into_py_iter, future_into_py_futlike},
|
||||
ws::{HyperWebsocket, UpgradeData}
|
||||
};
|
||||
use super::{
|
||||
errors::{error_proto, error_stream},
|
||||
types::{PyResponse, PyResponseBody, PyResponseFile}
|
||||
types::{PyResponse, PyResponseBody, PyResponseFile},
|
||||
};
|
||||
use crate::{
|
||||
runtime::{future_into_py_futlike, future_into_py_iter, Runtime, RuntimeRef},
|
||||
ws::{HyperWebsocket, UpgradeData},
|
||||
};
|
||||
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct RSGIHTTPStreamTransport {
|
||||
rt: RuntimeRef,
|
||||
tx: Arc<Mutex<BodySender>>
|
||||
tx: Arc<Mutex<BodySender>>,
|
||||
}
|
||||
|
||||
impl RSGIHTTPStreamTransport {
|
||||
pub fn new(
|
||||
rt: RuntimeRef,
|
||||
transport: BodySender
|
||||
) -> Self {
|
||||
Self { rt: rt, tx: Arc::new(Mutex::new(transport)) }
|
||||
pub fn new(rt: RuntimeRef, transport: BodySender) -> Self {
|
||||
Self {
|
||||
rt,
|
||||
tx: Arc::new(Mutex::new(transport)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,8 +46,8 @@ impl RSGIHTTPStreamTransport {
|
|||
if let Ok(mut stream) = transport.try_lock() {
|
||||
return match stream.send_data(data.into()).await {
|
||||
Ok(_) => Ok(()),
|
||||
_ => error_stream!()
|
||||
}
|
||||
_ => error_stream!(),
|
||||
};
|
||||
}
|
||||
error_proto!()
|
||||
})
|
||||
|
@ -54,31 +59,27 @@ impl RSGIHTTPStreamTransport {
|
|||
if let Ok(mut stream) = transport.try_lock() {
|
||||
return match stream.send_data(data.into()).await {
|
||||
Ok(_) => Ok(()),
|
||||
_ => error_stream!()
|
||||
}
|
||||
_ => error_stream!(),
|
||||
};
|
||||
}
|
||||
error_proto!()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct RSGIHTTPProtocol {
|
||||
rt: RuntimeRef,
|
||||
tx: Option<oneshot::Sender<super::types::PyResponse>>,
|
||||
body: Arc<Mutex<Body>>
|
||||
body: Arc<Mutex<Body>>,
|
||||
}
|
||||
|
||||
impl RSGIHTTPProtocol {
|
||||
pub fn new(
|
||||
rt: RuntimeRef,
|
||||
tx: oneshot::Sender<super::types::PyResponse>,
|
||||
request: Request<Body>
|
||||
) -> Self {
|
||||
pub fn new(rt: RuntimeRef, tx: oneshot::Sender<super::types::PyResponse>, request: Request<Body>) -> Self {
|
||||
Self {
|
||||
rt,
|
||||
tx: Some(tx),
|
||||
body: Arc::new(Mutex::new(request.into_body()))
|
||||
body: Arc::new(Mutex::new(request.into_body())),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,14 +111,13 @@ impl RSGIHTTPProtocol {
|
|||
let mut bodym = body_ref.lock().await;
|
||||
let body = &mut *bodym;
|
||||
if body.is_end_stream() {
|
||||
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted"))
|
||||
return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted"));
|
||||
}
|
||||
let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| {
|
||||
buf.unwrap_or_else(|_| Bytes::new())
|
||||
});
|
||||
Ok(Python::with_gil(|py| {
|
||||
PyBytes::new(py, &chunk[..]).to_object(py)
|
||||
}))
|
||||
let chunk = body
|
||||
.data()
|
||||
.await
|
||||
.map_or_else(Bytes::new, |buf| buf.unwrap_or_else(|_| Bytes::new()));
|
||||
Ok(Python::with_gil(|py| PyBytes::new(py, &chunk[..]).to_object(py)))
|
||||
})?;
|
||||
Ok(Some(fut))
|
||||
}
|
||||
|
@ -125,36 +125,28 @@ impl RSGIHTTPProtocol {
|
|||
#[pyo3(signature = (status=200, headers=vec![]))]
|
||||
fn response_empty(&mut self, status: u16, headers: Vec<(String, String)>) {
|
||||
if let Some(tx) = self.tx.take() {
|
||||
let _ = tx.send(
|
||||
PyResponse::Body(PyResponseBody::empty(status, headers))
|
||||
);
|
||||
let _ = tx.send(PyResponse::Body(PyResponseBody::empty(status, headers)));
|
||||
}
|
||||
}
|
||||
|
||||
#[pyo3(signature = (status=200, headers=vec![], body=vec![]))]
|
||||
fn response_bytes(&mut self, status: u16, headers: Vec<(String, String)>, body: Vec<u8>) {
|
||||
if let Some(tx) = self.tx.take() {
|
||||
let _ = tx.send(
|
||||
PyResponse::Body(PyResponseBody::from_bytes(status, headers, body))
|
||||
);
|
||||
let _ = tx.send(PyResponse::Body(PyResponseBody::from_bytes(status, headers, body)));
|
||||
}
|
||||
}
|
||||
|
||||
#[pyo3(signature = (status=200, headers=vec![], body="".to_string()))]
|
||||
#[pyo3(signature = (status=200, headers=vec![], body=String::new()))]
|
||||
fn response_str(&mut self, status: u16, headers: Vec<(String, String)>, body: String) {
|
||||
if let Some(tx) = self.tx.take() {
|
||||
let _ = tx.send(
|
||||
PyResponse::Body(PyResponseBody::from_string(status, headers, body))
|
||||
);
|
||||
let _ = tx.send(PyResponse::Body(PyResponseBody::from_string(status, headers, body)));
|
||||
}
|
||||
}
|
||||
|
||||
#[pyo3(signature = (status, headers, file))]
|
||||
fn response_file(&mut self, status: u16, headers: Vec<(String, String)>, file: String) {
|
||||
if let Some(tx) = self.tx.take() {
|
||||
let _ = tx.send(
|
||||
PyResponse::File(PyResponseFile::new(status, headers, file))
|
||||
);
|
||||
let _ = tx.send(PyResponse::File(PyResponseFile::new(status, headers, file)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,34 +155,33 @@ impl RSGIHTTPProtocol {
|
|||
&mut self,
|
||||
py: Python<'p>,
|
||||
status: u16,
|
||||
headers: Vec<(String, String)>
|
||||
headers: Vec<(String, String)>,
|
||||
) -> PyResult<&'p PyAny> {
|
||||
if let Some(tx) = self.tx.take() {
|
||||
let (body_tx, body_stream) = Body::channel();
|
||||
let _ = tx.send(
|
||||
PyResponse::Body(PyResponseBody::new(status, headers, body_stream))
|
||||
);
|
||||
let _ = tx.send(PyResponse::Body(PyResponseBody::new(status, headers, body_stream)));
|
||||
let trx = Py::new(py, RSGIHTTPStreamTransport::new(self.rt.clone(), body_tx))?;
|
||||
return Ok(trx.into_ref(py))
|
||||
return Ok(trx.into_ref(py));
|
||||
}
|
||||
error_proto!()
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct RSGIWebsocketTransport {
|
||||
rt: RuntimeRef,
|
||||
tx: Arc<Mutex<SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>>>,
|
||||
rx: Arc<Mutex<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>
|
||||
rx: Arc<Mutex<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>>,
|
||||
}
|
||||
|
||||
impl RSGIWebsocketTransport {
|
||||
pub fn new(
|
||||
rt: RuntimeRef,
|
||||
transport: WebSocketStream<hyper::upgrade::Upgraded>
|
||||
) -> Self {
|
||||
pub fn new(rt: RuntimeRef, transport: WebSocketStream<hyper::upgrade::Upgraded>) -> Self {
|
||||
let (tx, rx) = transport.split();
|
||||
Self { rt: rt, tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)) }
|
||||
Self {
|
||||
rt,
|
||||
tx: Arc::new(Mutex::new(tx)),
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
|
@ -209,27 +200,14 @@ impl RSGIWebsocketTransport {
|
|||
let transport = self.rx.clone();
|
||||
future_into_py_futlike(self.rt.clone(), py, async move {
|
||||
if let Ok(mut stream) = transport.try_lock() {
|
||||
loop {
|
||||
match stream.next().await {
|
||||
Some(recv) => {
|
||||
match recv {
|
||||
Ok(Message::Ping(_)) => {
|
||||
continue
|
||||
},
|
||||
Ok(message) => {
|
||||
return message_into_py(message)
|
||||
},
|
||||
_ => {
|
||||
break
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
break
|
||||
}
|
||||
while let Some(recv) = stream.next().await {
|
||||
match recv {
|
||||
Ok(Message::Ping(_)) => continue,
|
||||
Ok(message) => return message_into_py(message),
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
return error_stream!()
|
||||
return error_stream!();
|
||||
}
|
||||
error_proto!()
|
||||
})
|
||||
|
@ -241,8 +219,8 @@ impl RSGIWebsocketTransport {
|
|||
if let Ok(mut stream) = transport.try_lock() {
|
||||
return match stream.send(Message::Binary(data)).await {
|
||||
Ok(_) => Ok(()),
|
||||
_ => error_stream!()
|
||||
}
|
||||
_ => error_stream!(),
|
||||
};
|
||||
}
|
||||
error_proto!()
|
||||
})
|
||||
|
@ -254,22 +232,22 @@ impl RSGIWebsocketTransport {
|
|||
if let Ok(mut stream) = transport.try_lock() {
|
||||
return match stream.send(Message::Text(data)).await {
|
||||
Ok(_) => Ok(()),
|
||||
_ => error_stream!()
|
||||
}
|
||||
_ => error_stream!(),
|
||||
};
|
||||
}
|
||||
error_proto!()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct RSGIWebsocketProtocol {
|
||||
rt: RuntimeRef,
|
||||
tx: Option<oneshot::Sender<(i32, bool)>>,
|
||||
websocket: Arc<Mutex<HyperWebsocket>>,
|
||||
upgrade: Option<UpgradeData>,
|
||||
transport: Arc<Mutex<Option<Py<RSGIWebsocketTransport>>>>,
|
||||
status: i32
|
||||
status: i32,
|
||||
}
|
||||
|
||||
impl RSGIWebsocketProtocol {
|
||||
|
@ -277,7 +255,7 @@ impl RSGIWebsocketProtocol {
|
|||
rt: RuntimeRef,
|
||||
tx: oneshot::Sender<(i32, bool)>,
|
||||
websocket: HyperWebsocket,
|
||||
upgrade: UpgradeData
|
||||
upgrade: UpgradeData,
|
||||
) -> Self {
|
||||
Self {
|
||||
rt,
|
||||
|
@ -285,15 +263,12 @@ impl RSGIWebsocketProtocol {
|
|||
websocket: Arc::new(Mutex::new(websocket)),
|
||||
upgrade: Some(upgrade),
|
||||
transport: Arc::new(Mutex::new(None)),
|
||||
status: 0
|
||||
status: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn consumed(&self) -> bool {
|
||||
match &self.upgrade {
|
||||
Some(_) => false,
|
||||
_ => true
|
||||
}
|
||||
self.upgrade.is_none()
|
||||
}
|
||||
|
||||
pub fn tx(&mut self) -> (Option<oneshot::Sender<(i32, bool)>>, (i32, bool)) {
|
||||
|
@ -304,18 +279,20 @@ impl RSGIWebsocketProtocol {
|
|||
enum WebsocketMessageType {
|
||||
Close = 0,
|
||||
Bytes = 1,
|
||||
Text = 2
|
||||
Text = 2,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct WebsocketInboundCloseMessage {
|
||||
#[pyo3(get)]
|
||||
kind: usize
|
||||
kind: usize,
|
||||
}
|
||||
|
||||
impl WebsocketInboundCloseMessage {
|
||||
pub fn new() -> Self {
|
||||
Self { kind: WebsocketMessageType::Close as usize }
|
||||
Self {
|
||||
kind: WebsocketMessageType::Close as usize,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -324,12 +301,15 @@ struct WebsocketInboundBytesMessage {
|
|||
#[pyo3(get)]
|
||||
kind: usize,
|
||||
#[pyo3(get)]
|
||||
data: Py<PyBytes>
|
||||
data: Py<PyBytes>,
|
||||
}
|
||||
|
||||
impl WebsocketInboundBytesMessage {
|
||||
pub fn new(data:Py<PyBytes>) -> Self {
|
||||
Self { kind: WebsocketMessageType::Bytes as usize, data: data }
|
||||
pub fn new(data: Py<PyBytes>) -> Self {
|
||||
Self {
|
||||
kind: WebsocketMessageType::Bytes as usize,
|
||||
data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -338,12 +318,15 @@ struct WebsocketInboundTextMessage {
|
|||
#[pyo3(get)]
|
||||
kind: usize,
|
||||
#[pyo3(get)]
|
||||
data: Py<PyString>
|
||||
data: Py<PyString>,
|
||||
}
|
||||
|
||||
impl WebsocketInboundTextMessage {
|
||||
pub fn new(data: Py<PyString>) -> Self {
|
||||
Self { kind: WebsocketMessageType::Text as usize, data: data }
|
||||
Self {
|
||||
kind: WebsocketMessageType::Text as usize,
|
||||
data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -374,23 +357,18 @@ impl RSGIWebsocketProtocol {
|
|||
future_into_py_iter(self.rt.clone(), py, async move {
|
||||
let mut ws = transport.lock().await;
|
||||
match upgrade.send().await {
|
||||
Ok(_) => {
|
||||
match (&mut *ws).await {
|
||||
Ok(stream) => {
|
||||
let mut trx = itransport.lock().await;
|
||||
Ok(Python::with_gil(|py| {
|
||||
let pytransport = Py::new(
|
||||
py,
|
||||
RSGIWebsocketTransport::new(rth, stream)
|
||||
).unwrap();
|
||||
*trx = Some(pytransport.clone());
|
||||
pytransport
|
||||
}))
|
||||
},
|
||||
_ => error_proto!()
|
||||
Ok(_) => match (&mut *ws).await {
|
||||
Ok(stream) => {
|
||||
let mut trx = itransport.lock().await;
|
||||
Ok(Python::with_gil(|py| {
|
||||
let pytransport = Py::new(py, RSGIWebsocketTransport::new(rth, stream)).unwrap();
|
||||
*trx = Some(pytransport.clone());
|
||||
pytransport
|
||||
}))
|
||||
}
|
||||
_ => error_proto!(),
|
||||
},
|
||||
_ => error_proto!()
|
||||
_ => error_proto!(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -399,25 +377,13 @@ impl RSGIWebsocketProtocol {
|
|||
#[inline(always)]
|
||||
fn message_into_py(message: Message) -> PyResult<PyObject> {
|
||||
match message {
|
||||
Message::Binary(message) => {
|
||||
Ok(Python::with_gil(|py| {
|
||||
WebsocketInboundBytesMessage::new(
|
||||
PyBytes::new(py, &message).into()
|
||||
).into_py(py)
|
||||
}))
|
||||
},
|
||||
Message::Text(message) => {
|
||||
Ok(Python::with_gil(|py| {
|
||||
WebsocketInboundTextMessage::new(
|
||||
PyString::new(py, &message).into()
|
||||
).into_py(py)
|
||||
}))
|
||||
},
|
||||
Message::Close(_) => {
|
||||
Ok(Python::with_gil(|py| {
|
||||
WebsocketInboundCloseMessage::new().into_py(py)
|
||||
}))
|
||||
}
|
||||
Message::Binary(message) => Ok(Python::with_gil(|py| {
|
||||
WebsocketInboundBytesMessage::new(PyBytes::new(py, &message).into()).into_py(py)
|
||||
})),
|
||||
Message::Text(message) => Ok(Python::with_gil(|py| {
|
||||
WebsocketInboundTextMessage::new(PyString::new(py, &message).into()).into_py(py)
|
||||
})),
|
||||
Message::Close(_) => Ok(Python::with_gil(|py| WebsocketInboundCloseMessage::new().into_py(py))),
|
||||
v => {
|
||||
log::warn!("Unsupported websocket message received {:?}", v);
|
||||
error_proto!()
|
||||
|
|
|
@ -1,28 +1,14 @@
|
|||
use pyo3::prelude::*;
|
||||
|
||||
use crate::{
|
||||
workers::{
|
||||
WorkerConfig,
|
||||
serve_rth,
|
||||
serve_wth,
|
||||
serve_rth_ssl,
|
||||
serve_wth_ssl
|
||||
}
|
||||
};
|
||||
use super::http::{
|
||||
handle_rtb,
|
||||
handle_rtb_pyw,
|
||||
handle_rtt,
|
||||
handle_rtt_pyw,
|
||||
handle_rtb_ws,
|
||||
handle_rtb_ws_pyw,
|
||||
handle_rtt_ws,
|
||||
handle_rtt_ws_pyw
|
||||
handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws,
|
||||
handle_rtt_ws_pyw,
|
||||
};
|
||||
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub struct RSGIWorker {
|
||||
config: WorkerConfig
|
||||
config: WorkerConfig,
|
||||
}
|
||||
|
||||
impl RSGIWorker {
|
||||
|
@ -73,7 +59,7 @@ impl RSGIWorker {
|
|||
opt_enabled: bool,
|
||||
ssl_enabled: bool,
|
||||
ssl_cert: Option<&str>,
|
||||
ssl_key: Option<&str>
|
||||
ssl_key: Option<&str>,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
config: WorkerConfig::new(
|
||||
|
@ -87,22 +73,16 @@ impl RSGIWorker {
|
|||
opt_enabled,
|
||||
ssl_enabled,
|
||||
ssl_cert,
|
||||
ssl_key
|
||||
)
|
||||
ssl_key,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn serve_rth(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
match (
|
||||
self.config.websockets_enabled,
|
||||
self.config.ssl_enabled,
|
||||
self.config.opt_enabled
|
||||
self.config.opt_enabled,
|
||||
) {
|
||||
(false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx),
|
||||
(false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx),
|
||||
|
@ -111,21 +91,15 @@ impl RSGIWorker {
|
|||
(false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
|
||||
(false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx),
|
||||
(true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx),
|
||||
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
|
||||
(true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
|
||||
}
|
||||
}
|
||||
|
||||
fn serve_wth(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
match (
|
||||
self.config.websockets_enabled,
|
||||
self.config.ssl_enabled,
|
||||
self.config.opt_enabled
|
||||
self.config.opt_enabled,
|
||||
) {
|
||||
(false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx),
|
||||
(false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx),
|
||||
|
@ -134,7 +108,7 @@ impl RSGIWorker {
|
|||
(false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
|
||||
(false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx),
|
||||
(true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx),
|
||||
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx)
|
||||
(true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use hyper::{
|
||||
header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
|
||||
Body, Uri, Version
|
||||
Body, Uri, Version,
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyString;
|
||||
|
@ -10,11 +10,10 @@ use tokio_util::codec::{BytesCodec, FramedRead};
|
|||
|
||||
use crate::http::HV_SERVER;
|
||||
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RSGIHeaders {
|
||||
inner: HeaderMap
|
||||
inner: HeaderMap,
|
||||
}
|
||||
|
||||
impl RSGIHeaders {
|
||||
|
@ -29,7 +28,7 @@ impl RSGIHeaders {
|
|||
let mut ret = Vec::with_capacity(self.inner.keys_len());
|
||||
for key in self.inner.keys() {
|
||||
ret.push(key.as_str());
|
||||
};
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
|
@ -37,15 +36,15 @@ impl RSGIHeaders {
|
|||
let mut ret = Vec::with_capacity(self.inner.keys_len());
|
||||
for val in self.inner.values() {
|
||||
ret.push(val.to_str().unwrap());
|
||||
};
|
||||
}
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
fn items(&self) -> PyResult<Vec<(&str, &str)>> {
|
||||
let mut ret = Vec::with_capacity(self.inner.keys_len());
|
||||
for (key, val) in self.inner.iter() {
|
||||
for (key, val) in &self.inner {
|
||||
ret.push((key.as_str(), val.to_str().unwrap()));
|
||||
};
|
||||
}
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
|
@ -56,18 +55,16 @@ impl RSGIHeaders {
|
|||
#[pyo3(signature = (key, default=None))]
|
||||
fn get(&self, py: Python, key: &str, default: Option<PyObject>) -> Option<PyObject> {
|
||||
match self.inner.get(key) {
|
||||
Some(val) => {
|
||||
match val.to_str() {
|
||||
Ok(string) => Some(PyString::new(py, string).into()),
|
||||
_ => default
|
||||
}
|
||||
Some(val) => match val.to_str() {
|
||||
Ok(string) => Some(PyString::new(py, string).into()),
|
||||
_ => default,
|
||||
},
|
||||
_ => default
|
||||
_ => default,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct RSGIScope {
|
||||
#[pyo3(get)]
|
||||
proto: String,
|
||||
|
@ -84,7 +81,7 @@ pub(crate) struct RSGIScope {
|
|||
#[pyo3(get)]
|
||||
client: String,
|
||||
#[pyo3(get)]
|
||||
headers: RSGIHeaders
|
||||
headers: RSGIHeaders,
|
||||
}
|
||||
|
||||
impl RSGIScope {
|
||||
|
@ -96,23 +93,23 @@ impl RSGIScope {
|
|||
method: &str,
|
||||
server: SocketAddr,
|
||||
client: SocketAddr,
|
||||
headers: &HeaderMap
|
||||
headers: &HeaderMap,
|
||||
) -> Self {
|
||||
Self {
|
||||
proto: proto.to_string(),
|
||||
http_version: http_version,
|
||||
http_version,
|
||||
rsgi_version: "1.2".to_string(),
|
||||
scheme: scheme.to_string(),
|
||||
method: method.to_string(),
|
||||
uri: uri,
|
||||
uri,
|
||||
server: server.to_string(),
|
||||
client: client.to_string(),
|
||||
headers: RSGIHeaders::new(headers)
|
||||
headers: RSGIHeaders::new(headers),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_proto(&mut self, value: &str) {
|
||||
self.proto = value.to_string()
|
||||
self.proto = value.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,7 +122,7 @@ impl RSGIScope {
|
|||
Version::HTTP_11 => "1.1",
|
||||
Version::HTTP_2 => "2",
|
||||
Version::HTTP_3 => "3",
|
||||
_ => "1"
|
||||
_ => "1",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,38 +139,36 @@ impl RSGIScope {
|
|||
|
||||
pub(crate) enum PyResponse {
|
||||
Body(PyResponseBody),
|
||||
File(PyResponseFile)
|
||||
File(PyResponseFile),
|
||||
}
|
||||
|
||||
pub(crate) struct PyResponseBody {
|
||||
status: u16,
|
||||
headers: Vec<(String, String)>,
|
||||
body: Body
|
||||
body: Body,
|
||||
}
|
||||
|
||||
pub(crate) struct PyResponseFile {
|
||||
status: u16,
|
||||
headers: Vec<(String, String)>,
|
||||
file_path: String
|
||||
file_path: String,
|
||||
}
|
||||
|
||||
macro_rules! response_head_from_py {
|
||||
($status:expr, $headers:expr, $res:expr) => {
|
||||
{
|
||||
let mut rh = hyper::http::HeaderMap::new();
|
||||
($status:expr, $headers:expr, $res:expr) => {{
|
||||
let mut rh = hyper::http::HeaderMap::new();
|
||||
|
||||
rh.insert(HK_SERVER, HV_SERVER);
|
||||
for (key, value) in $headers {
|
||||
rh.append(
|
||||
HeaderName::from_bytes(key.as_bytes()).unwrap(),
|
||||
HeaderValue::from_str(&value).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
*$res.status_mut() = $status.try_into().unwrap();
|
||||
*$res.headers_mut() = rh;
|
||||
rh.insert(HK_SERVER, HV_SERVER);
|
||||
for (key, value) in $headers {
|
||||
rh.append(
|
||||
HeaderName::from_bytes(key.as_bytes()).unwrap(),
|
||||
HeaderValue::from_str(&value).unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
*$res.status_mut() = $status.try_into().unwrap();
|
||||
*$res.headers_mut() = rh;
|
||||
}};
|
||||
}
|
||||
|
||||
impl PyResponseBody {
|
||||
|
@ -182,18 +177,30 @@ impl PyResponseBody {
|
|||
}
|
||||
|
||||
pub fn empty(status: u16, headers: Vec<(String, String)>) -> Self {
|
||||
Self { status, headers, body: Body::empty() }
|
||||
Self {
|
||||
status,
|
||||
headers,
|
||||
body: Body::empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes(status: u16, headers: Vec<(String, String)>, body: Vec<u8>) -> Self {
|
||||
Self { status, headers, body: Body::from(body) }
|
||||
Self {
|
||||
status,
|
||||
headers,
|
||||
body: Body::from(body),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_string(status: u16, headers: Vec<(String, String)>, body: String) -> Self {
|
||||
Self { status, headers, body: Body::from(body) }
|
||||
Self {
|
||||
status,
|
||||
headers,
|
||||
body: Body::from(body),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_response(self) -> hyper::Response::<Body> {
|
||||
pub fn to_response(self) -> hyper::Response<Body> {
|
||||
let mut res = hyper::Response::<Body>::new(self.body);
|
||||
response_head_from_py!(self.status, &self.headers, res);
|
||||
res
|
||||
|
@ -202,10 +209,14 @@ impl PyResponseBody {
|
|||
|
||||
impl PyResponseFile {
|
||||
pub fn new(status: u16, headers: Vec<(String, String)>, file_path: String) -> Self {
|
||||
Self { status, headers, file_path }
|
||||
Self {
|
||||
status,
|
||||
headers,
|
||||
file_path,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn to_response(&self) -> hyper::Response::<Body> {
|
||||
pub async fn to_response(&self) -> hyper::Response<Body> {
|
||||
let file = File::open(&self.file_path).await.unwrap();
|
||||
let stream = FramedRead::new(file, BytesCodec::new());
|
||||
let mut res = hyper::Response::<Body>::new(Body::wrap_stream(stream));
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
use once_cell::unsync::OnceCell as UnsyncOnceCell;
|
||||
use pyo3_asyncio::TaskLocals;
|
||||
use pyo3::prelude::*;
|
||||
use std::{future::Future, io, pin::Pin, sync::{Arc, Mutex}};
|
||||
use tokio::{runtime::Builder, task::{JoinHandle, LocalSet}};
|
||||
use pyo3_asyncio::TaskLocals;
|
||||
use std::{
|
||||
future::Future,
|
||||
io,
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use tokio::{
|
||||
runtime::Builder,
|
||||
task::{JoinHandle, LocalSet},
|
||||
};
|
||||
|
||||
use super::callbacks::{PyFutureAwaitable, PyIterAwaitable};
|
||||
|
||||
|
||||
tokio::task_local! {
|
||||
static TASK_LOCALS: UnsyncOnceCell<TaskLocals>;
|
||||
}
|
||||
|
@ -27,11 +34,7 @@ pub trait Runtime: Send + 'static {
|
|||
}
|
||||
|
||||
pub trait ContextExt: Runtime {
|
||||
fn scope<F, R>(
|
||||
&self,
|
||||
locals: TaskLocals,
|
||||
fut: F
|
||||
) -> Pin<Box<dyn Future<Output = R> + Send>>
|
||||
fn scope<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
|
||||
where
|
||||
F: Future<Output = R> + Send + 'static;
|
||||
|
||||
|
@ -45,36 +48,34 @@ pub trait SpawnLocalExt: Runtime {
|
|||
}
|
||||
|
||||
pub trait LocalContextExt: Runtime {
|
||||
fn scope_local<F, R>(
|
||||
&self,
|
||||
locals: TaskLocals,
|
||||
fut: F
|
||||
) -> Pin<Box<dyn Future<Output = R>>>
|
||||
fn scope_local<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
|
||||
where
|
||||
F: Future<Output = R> + 'static;
|
||||
}
|
||||
|
||||
pub(crate) struct RuntimeWrapper {
|
||||
rt: tokio::runtime::Runtime
|
||||
rt: tokio::runtime::Runtime,
|
||||
}
|
||||
|
||||
impl RuntimeWrapper {
|
||||
pub fn new(blocking_threads: usize) -> Self {
|
||||
Self { rt: default_runtime(blocking_threads).unwrap() }
|
||||
Self {
|
||||
rt: default_runtime(blocking_threads).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_runtime(rt: tokio::runtime::Runtime) -> Self {
|
||||
Self { rt: rt }
|
||||
Self { rt }
|
||||
}
|
||||
|
||||
pub fn handler(&self) -> RuntimeRef {
|
||||
RuntimeRef::new(self.rt.handle().to_owned())
|
||||
RuntimeRef::new(self.rt.handle().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RuntimeRef {
|
||||
pub inner: tokio::runtime::Handle
|
||||
pub inner: tokio::runtime::Handle,
|
||||
}
|
||||
|
||||
impl RuntimeRef {
|
||||
|
@ -108,11 +109,7 @@ impl Runtime for RuntimeRef {
|
|||
}
|
||||
|
||||
impl ContextExt for RuntimeRef {
|
||||
fn scope<F, R>(
|
||||
&self,
|
||||
locals: TaskLocals,
|
||||
fut: F
|
||||
) -> Pin<Box<dyn Future<Output = R> + Send>>
|
||||
fn scope<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
|
||||
where
|
||||
F: Future<Output = R> + Send + 'static,
|
||||
{
|
||||
|
@ -123,7 +120,7 @@ impl ContextExt for RuntimeRef {
|
|||
}
|
||||
|
||||
fn get_task_locals() -> Option<TaskLocals> {
|
||||
match TASK_LOCALS.try_with(|c| c.get().map(|locals| locals.clone())) {
|
||||
match TASK_LOCALS.try_with(|c| c.get().cloned()) {
|
||||
Ok(locals) => locals,
|
||||
Err(_) => None,
|
||||
}
|
||||
|
@ -140,11 +137,7 @@ impl SpawnLocalExt for RuntimeRef {
|
|||
}
|
||||
|
||||
impl LocalContextExt for RuntimeRef {
|
||||
fn scope_local<F, R>(
|
||||
&self,
|
||||
locals: TaskLocals,
|
||||
fut: F
|
||||
) -> Pin<Box<dyn Future<Output = R>>>
|
||||
fn scope_local<F, R>(&self, locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
|
||||
where
|
||||
F: Future<Output = R> + 'static,
|
||||
{
|
||||
|
@ -169,7 +162,7 @@ pub(crate) fn init_runtime_mt(threads: usize, blocking_threads: usize) -> Runtim
|
|||
.max_blocking_threads(blocking_threads)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -177,12 +170,8 @@ pub(crate) fn init_runtime_st(blocking_threads: usize) -> RuntimeWrapper {
|
|||
RuntimeWrapper::new(blocking_threads)
|
||||
}
|
||||
|
||||
pub(crate) fn into_future(
|
||||
awaitable: &PyAny,
|
||||
) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send> {
|
||||
pyo3_asyncio::into_future_with_locals(
|
||||
&get_current_locals::<RuntimeRef>(awaitable.py())?, awaitable
|
||||
)
|
||||
pub(crate) fn into_future(awaitable: &PyAny) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send> {
|
||||
pyo3_asyncio::into_future_with_locals(&get_current_locals::<RuntimeRef>(awaitable.py())?, awaitable)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -241,10 +230,7 @@ where
|
|||
rt.spawn(async move {
|
||||
let result = fut.await;
|
||||
Python::with_gil(move |py| {
|
||||
PyFutureAwaitable::set_result(
|
||||
py_aw.as_ref(py).borrow_mut(),
|
||||
result.map(|v| v.into_py(py))
|
||||
);
|
||||
PyFutureAwaitable::set_result(py_aw.as_ref(py).borrow_mut(), result.map(|v| v.into_py(py)));
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -269,11 +255,7 @@ where
|
|||
let rth = rt.handler();
|
||||
|
||||
rt.spawn(async move {
|
||||
let val = rth.scope(
|
||||
task_locals.clone(),
|
||||
fut
|
||||
)
|
||||
.await;
|
||||
let val = rth.scope(task_locals.clone(), fut).await;
|
||||
if let Ok(mut result) = result_tx.lock() {
|
||||
*result = Some(val.unwrap());
|
||||
}
|
||||
|
@ -292,7 +274,7 @@ where
|
|||
|
||||
pub(crate) fn block_on_local<F>(rt: RuntimeWrapper, local: LocalSet, fut: F)
|
||||
where
|
||||
F: Future + 'static
|
||||
F: Future + 'static,
|
||||
{
|
||||
local.block_on(&rt.rt, fut);
|
||||
}
|
||||
|
|
33
src/tcp.rs
33
src/tcp.rs
|
@ -9,10 +9,9 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket};
|
|||
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub struct ListenerHolder {
|
||||
socket: TcpListener
|
||||
socket: TcpListener,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -20,28 +19,19 @@ impl ListenerHolder {
|
|||
#[cfg(unix)]
|
||||
#[new]
|
||||
pub fn new(fd: i32) -> PyResult<Self> {
|
||||
let socket = unsafe {
|
||||
TcpListener::from_raw_fd(fd)
|
||||
};
|
||||
Ok(Self { socket: socket })
|
||||
let socket = unsafe { TcpListener::from_raw_fd(fd) };
|
||||
Ok(Self { socket })
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
#[new]
|
||||
pub fn new(fd: u64) -> PyResult<Self> {
|
||||
let socket = unsafe {
|
||||
TcpListener::from_raw_socket(fd)
|
||||
};
|
||||
Ok(Self { socket: socket })
|
||||
let socket = unsafe { TcpListener::from_raw_socket(fd) };
|
||||
Ok(Self { socket })
|
||||
}
|
||||
|
||||
#[classmethod]
|
||||
pub fn from_address(
|
||||
_cls: &PyType,
|
||||
address: &str,
|
||||
port: u16,
|
||||
backlog: i32
|
||||
) -> PyResult<Self> {
|
||||
pub fn from_address(_cls: &PyType, address: &str, port: u16, backlog: i32) -> PyResult<Self> {
|
||||
let address: SocketAddr = (address.parse::<IpAddr>()?, port).into();
|
||||
let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
|
||||
socket.set_reuse_address(true)?;
|
||||
|
@ -54,17 +44,13 @@ impl ListenerHolder {
|
|||
#[cfg(unix)]
|
||||
pub fn __getstate__(&self, py: Python) -> PyObject {
|
||||
let fd = self.socket.as_raw_fd();
|
||||
(
|
||||
fd.into_py(py),
|
||||
).to_object(py)
|
||||
(fd.into_py(py),).to_object(py)
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
pub fn __getstate__(&self, py: Python) -> PyObject {
|
||||
let fd = self.socket.as_raw_socket();
|
||||
(
|
||||
fd.into_py(py),
|
||||
).to_object(py)
|
||||
(fd.into_py(py),).to_object(py)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
|
@ -84,7 +70,6 @@ impl ListenerHolder {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> {
|
||||
module.add_class::<ListenerHolder>()?;
|
||||
|
||||
|
|
32
src/tls.rs
32
src/tls.rs
|
@ -1,20 +1,22 @@
|
|||
use futures::stream::StreamExt;
|
||||
use hyper::server::{accept, conn::{AddrIncoming, AddrStream}};
|
||||
use hyper::server::{
|
||||
accept,
|
||||
conn::{AddrIncoming, AddrStream},
|
||||
};
|
||||
use std::{fs, future, io, iter::Iterator, net::TcpListener, sync::Arc};
|
||||
use tls_listener::{Error as TlsError, TlsListener};
|
||||
use tokio_rustls::{
|
||||
TlsAcceptor,
|
||||
rustls::{Certificate, PrivateKey, ServerConfig},
|
||||
server::TlsStream
|
||||
server::TlsStream,
|
||||
TlsAcceptor,
|
||||
};
|
||||
|
||||
|
||||
pub(crate) type TlsAddrStream = TlsStream<AddrStream>;
|
||||
|
||||
pub(crate) fn tls_listen(
|
||||
config: Arc<ServerConfig>,
|
||||
tcp: TcpListener
|
||||
) -> impl accept::Accept<Conn=TlsAddrStream, Error=TlsError<io::Error, io::Error>> {
|
||||
tcp: TcpListener,
|
||||
) -> impl accept::Accept<Conn = TlsAddrStream, Error = TlsError<io::Error, io::Error>> {
|
||||
tcp.set_nonblocking(true).unwrap();
|
||||
let tcp_listener = tokio::net::TcpListener::from_std(tcp).unwrap();
|
||||
let incoming = AddrIncoming::from_listener(tcp_listener).unwrap();
|
||||
|
@ -34,30 +36,26 @@ fn tls_error(err: String) -> io::Error {
|
|||
}
|
||||
|
||||
pub(crate) fn load_certs(filename: &str) -> io::Result<Vec<Certificate>> {
|
||||
let certfile = fs::File::open(filename)
|
||||
.map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?;
|
||||
let certfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?;
|
||||
let mut reader = io::BufReader::new(certfile);
|
||||
|
||||
let certs = rustls_pemfile::certs(&mut reader)
|
||||
.map_err(|_| tls_error("failed to load certificate".into()))?;
|
||||
let certs = rustls_pemfile::certs(&mut reader).map_err(|_| tls_error("failed to load certificate".into()))?;
|
||||
Ok(certs.into_iter().map(Certificate).collect())
|
||||
}
|
||||
|
||||
pub(crate) fn load_private_key(filename: &str) -> io::Result<PrivateKey> {
|
||||
let keyfile = fs::File::open(filename)
|
||||
.map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?;
|
||||
let keyfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?;
|
||||
let mut reader = io::BufReader::new(keyfile);
|
||||
|
||||
let keys = rustls_pemfile::read_all(&mut reader)
|
||||
.map_err(|_| tls_error("failed to load private key".into()))?;
|
||||
let keys = rustls_pemfile::read_all(&mut reader).map_err(|_| tls_error("failed to load private key".into()))?;
|
||||
if keys.len() != 1 {
|
||||
return Err(tls_error("expected a single private key".into()));
|
||||
}
|
||||
|
||||
let key = match &keys[0] {
|
||||
rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.to_vec()),
|
||||
rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.to_vec()),
|
||||
rustls_pemfile::Item::ECKey(key) => PrivateKey(key.to_vec()),
|
||||
rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.clone()),
|
||||
rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.clone()),
|
||||
rustls_pemfile::Item::ECKey(key) => PrivateKey(key.clone()),
|
||||
_ => {
|
||||
return Err(tls_error("failed to load private key".into()));
|
||||
}
|
||||
|
|
12
src/utils.rs
12
src/utils.rs
|
@ -1,13 +1,15 @@
|
|||
pub(crate) fn header_contains_value(
|
||||
headers: &hyper::HeaderMap,
|
||||
header: impl hyper::header::AsHeaderName,
|
||||
value: impl AsRef<[u8]>
|
||||
value: impl AsRef<[u8]>,
|
||||
) -> bool {
|
||||
let value = value.as_ref();
|
||||
for header in headers.get_all(header) {
|
||||
if header.as_bytes().split(|&c| c == b',').any(
|
||||
|x| trim(x).eq_ignore_ascii_case(value)
|
||||
) {
|
||||
if header
|
||||
.as_bytes()
|
||||
.split(|&c| c == b',')
|
||||
.any(|x| trim(x).eq_ignore_ascii_case(value))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -30,7 +32,7 @@ fn trim_start(data: &[u8]) -> &[u8] {
|
|||
#[inline]
|
||||
fn trim_end(data: &[u8]) -> &[u8] {
|
||||
if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
|
||||
&data[..last + 1]
|
||||
&data[..=last]
|
||||
} else {
|
||||
b""
|
||||
}
|
||||
|
|
277
src/workers.rs
277
src/workers.rs
|
@ -8,8 +8,8 @@ use std::os::windows::io::FromRawSocket;
|
|||
|
||||
use super::asgi::serve::ASGIWorker;
|
||||
use super::rsgi::serve::RSGIWorker;
|
||||
use super::wsgi::serve::WSGIWorker;
|
||||
use super::tls::{load_certs as tls_load_certs, load_private_key as tls_load_pkey};
|
||||
use super::wsgi::serve::WSGIWorker;
|
||||
|
||||
pub(crate) struct WorkerConfig {
|
||||
pub id: i32,
|
||||
|
@ -22,7 +22,7 @@ pub(crate) struct WorkerConfig {
|
|||
pub opt_enabled: bool,
|
||||
pub ssl_enabled: bool,
|
||||
ssl_cert: Option<String>,
|
||||
ssl_key: Option<String>
|
||||
ssl_key: Option<String>,
|
||||
}
|
||||
|
||||
impl WorkerConfig {
|
||||
|
@ -37,7 +37,7 @@ impl WorkerConfig {
|
|||
opt_enabled: bool,
|
||||
ssl_enabled: bool,
|
||||
ssl_cert: Option<&str>,
|
||||
ssl_key: Option<&str>
|
||||
ssl_key: Option<&str>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
|
@ -49,23 +49,19 @@ impl WorkerConfig {
|
|||
websockets_enabled,
|
||||
opt_enabled,
|
||||
ssl_enabled,
|
||||
ssl_cert: ssl_cert.map_or(None, |v| Some(v.into())),
|
||||
ssl_key: ssl_key.map_or(None, |v| Some(v.into()))
|
||||
ssl_cert: ssl_cert.map(std::convert::Into::into),
|
||||
ssl_key: ssl_key.map(std::convert::Into::into),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub fn tcp_listener(&self) -> TcpListener {
|
||||
unsafe {
|
||||
TcpListener::from_raw_fd(self.socket_fd)
|
||||
}
|
||||
unsafe { TcpListener::from_raw_fd(self.socket_fd) }
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
pub fn tcp_listener(&self) -> TcpListener {
|
||||
unsafe {
|
||||
TcpListener::from_raw_socket(self.socket_fd as u64)
|
||||
}
|
||||
unsafe { TcpListener::from_raw_socket(self.socket_fd as u64) }
|
||||
}
|
||||
|
||||
pub fn tls_cfg(&self) -> tokio_rustls::rustls::ServerConfig {
|
||||
|
@ -74,13 +70,13 @@ impl WorkerConfig {
|
|||
.with_no_client_auth()
|
||||
.with_single_cert(
|
||||
tls_load_certs(&self.ssl_cert.clone().unwrap()[..]).unwrap(),
|
||||
tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap()
|
||||
tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
cfg.alpn_protocols = match &self.http_mode[..] {
|
||||
"1" => vec![b"http/1.1".to_vec()],
|
||||
"2" => vec![b"h2".to_vec()],
|
||||
_ => vec![b"h2".to_vec(), b"http/1.1".to_vec()]
|
||||
_ => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
|
||||
};
|
||||
cfg
|
||||
}
|
||||
|
@ -102,7 +98,7 @@ pub(crate) struct WorkerExecutor;
|
|||
|
||||
impl<F> hyper::rt::Executor<F> for WorkerExecutor
|
||||
where
|
||||
F: std::future::Future + 'static
|
||||
F: std::future::Future + 'static,
|
||||
{
|
||||
fn execute(&self, fut: F) {
|
||||
tokio::task::spawn_local(fut);
|
||||
|
@ -123,14 +119,9 @@ macro_rules! build_service {
|
|||
let rth = rth.clone();
|
||||
|
||||
async move {
|
||||
Ok::<_, std::convert::Infallible>($target(
|
||||
rth,
|
||||
callback_wrapper,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
req,
|
||||
"http"
|
||||
).await)
|
||||
Ok::<_, std::convert::Infallible>(
|
||||
$target(rth, callback_wrapper, local_addr, remote_addr, req, "http").await,
|
||||
)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
@ -153,14 +144,9 @@ macro_rules! build_service_ssl {
|
|||
let rth = rth.clone();
|
||||
|
||||
async move {
|
||||
Ok::<_, std::convert::Infallible>($target(
|
||||
rth,
|
||||
callback_wrapper,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
req,
|
||||
"https"
|
||||
).await)
|
||||
Ok::<_, std::convert::Infallible>(
|
||||
$target(rth, callback_wrapper, local_addr, remote_addr, req, "https").await,
|
||||
)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
@ -170,13 +156,7 @@ macro_rules! build_service_ssl {
|
|||
|
||||
macro_rules! serve_rth {
|
||||
($func_name:ident, $target:expr) => {
|
||||
fn $func_name(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
pyo3_log::init();
|
||||
let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads);
|
||||
let rth = rt.handler();
|
||||
|
@ -184,34 +164,30 @@ macro_rules! serve_rth {
|
|||
let http1_only = self.config.http_mode == "1";
|
||||
let http2_only = self.config.http_mode == "2";
|
||||
let http1_buffer_max = self.config.http1_buffer_max.clone();
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
|
||||
callback, event_loop, context
|
||||
);
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
|
||||
|
||||
let worker_id = self.config.id;
|
||||
log::info!("Started worker-{}", worker_id);
|
||||
|
||||
let svc_loop = crate::runtime::run_until_complete(
|
||||
rt.handler(),
|
||||
event_loop,
|
||||
async move {
|
||||
let service = crate::workers::build_service!(
|
||||
callback_wrapper, rth, $target
|
||||
);
|
||||
let server = hyper::Server::from_tcp(tcp_listener).unwrap()
|
||||
.http1_only(http1_only)
|
||||
.http2_only(http2_only)
|
||||
.http1_max_buf_size(http1_buffer_max)
|
||||
.serve(service);
|
||||
server.with_graceful_shutdown(async move {
|
||||
Python::with_gil(|py| {
|
||||
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
|
||||
}).await.unwrap();
|
||||
}).await.unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
Ok(())
|
||||
}
|
||||
);
|
||||
let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move {
|
||||
let service = crate::workers::build_service!(callback_wrapper, rth, $target);
|
||||
let server = hyper::Server::from_tcp(tcp_listener)
|
||||
.unwrap()
|
||||
.http1_only(http1_only)
|
||||
.http2_only(http2_only)
|
||||
.http1_max_buf_size(http1_buffer_max)
|
||||
.serve(service);
|
||||
server
|
||||
.with_graceful_shutdown(async move {
|
||||
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
match svc_loop {
|
||||
Ok(_) => {}
|
||||
|
@ -226,13 +202,7 @@ macro_rules! serve_rth {
|
|||
|
||||
macro_rules! serve_rth_ssl {
|
||||
($func_name:ident, $target:expr) => {
|
||||
fn $func_name(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
pyo3_log::init();
|
||||
let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads);
|
||||
let rth = rt.handler();
|
||||
|
@ -241,38 +211,29 @@ macro_rules! serve_rth_ssl {
|
|||
let http2_only = self.config.http_mode == "2";
|
||||
let http1_buffer_max = self.config.http1_buffer_max.clone();
|
||||
let tls_cfg = self.config.tls_cfg();
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
|
||||
callback, event_loop, context
|
||||
);
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
|
||||
|
||||
let worker_id = self.config.id;
|
||||
log::info!("Started worker-{}", worker_id);
|
||||
|
||||
let svc_loop = crate::runtime::run_until_complete(
|
||||
rt.handler(),
|
||||
event_loop,
|
||||
async move {
|
||||
let service = crate::workers::build_service_ssl!(
|
||||
callback_wrapper, rth, $target
|
||||
);
|
||||
let server = hyper::Server::builder(
|
||||
crate::tls::tls_listen(
|
||||
std::sync::Arc::new(tls_cfg), tcp_listener
|
||||
)
|
||||
)
|
||||
.http1_only(http1_only)
|
||||
.http2_only(http2_only)
|
||||
.http1_max_buf_size(http1_buffer_max)
|
||||
.serve(service);
|
||||
server.with_graceful_shutdown(async move {
|
||||
Python::with_gil(|py| {
|
||||
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
|
||||
}).await.unwrap();
|
||||
}).await.unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
Ok(())
|
||||
}
|
||||
);
|
||||
let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move {
|
||||
let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target);
|
||||
let server = hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener))
|
||||
.http1_only(http1_only)
|
||||
.http2_only(http2_only)
|
||||
.http1_max_buf_size(http1_buffer_max)
|
||||
.serve(service);
|
||||
server
|
||||
.with_graceful_shutdown(async move {
|
||||
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
match svc_loop {
|
||||
Ok(_) => {}
|
||||
|
@ -287,22 +248,14 @@ macro_rules! serve_rth_ssl {
|
|||
|
||||
macro_rules! serve_wth {
|
||||
($func_name: ident, $target:expr) => {
|
||||
fn $func_name(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
pyo3_log::init();
|
||||
let rtm = crate::runtime::init_runtime_mt(1, 1);
|
||||
|
||||
let worker_id = self.config.id;
|
||||
log::info!("Started worker-{}", worker_id);
|
||||
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
|
||||
callback, event_loop, context
|
||||
);
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
|
||||
let mut workers = vec![];
|
||||
let (stx, srx) = tokio::sync::watch::channel(false);
|
||||
|
||||
|
@ -323,38 +276,36 @@ macro_rules! serve_wth {
|
|||
let local = tokio::task::LocalSet::new();
|
||||
|
||||
crate::runtime::block_on_local(rt, local, async move {
|
||||
let service = crate::workers::build_service!(
|
||||
callback_wrapper, rth, $target
|
||||
);
|
||||
let server = hyper::Server::from_tcp(tcp_listener).unwrap()
|
||||
let service = crate::workers::build_service!(callback_wrapper, rth, $target);
|
||||
let server = hyper::Server::from_tcp(tcp_listener)
|
||||
.unwrap()
|
||||
.executor(crate::workers::WorkerExecutor)
|
||||
.http1_only(http1_only)
|
||||
.http2_only(http2_only)
|
||||
.http1_max_buf_size(http1_buffer_max)
|
||||
.serve(service);
|
||||
server.with_graceful_shutdown(async move {
|
||||
srx.changed().await.unwrap();
|
||||
}).await.unwrap();
|
||||
server
|
||||
.with_graceful_shutdown(async move {
|
||||
srx.changed().await.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1);
|
||||
});
|
||||
}));
|
||||
};
|
||||
}
|
||||
|
||||
let main_loop = crate::runtime::run_until_complete(
|
||||
rtm.handler(),
|
||||
event_loop,
|
||||
async move {
|
||||
Python::with_gil(|py| {
|
||||
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
|
||||
}).await.unwrap();
|
||||
stx.send(true).unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
while let Some(worker) = workers.pop() {
|
||||
worker.join().unwrap();
|
||||
}
|
||||
Ok(())
|
||||
let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move {
|
||||
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
stx.send(true).unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
while let Some(worker) = workers.pop() {
|
||||
worker.join().unwrap();
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
match main_loop {
|
||||
Ok(_) => {}
|
||||
|
@ -369,22 +320,14 @@ macro_rules! serve_wth {
|
|||
|
||||
macro_rules! serve_wth_ssl {
|
||||
($func_name: ident, $target:expr) => {
|
||||
fn $func_name(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
pyo3_log::init();
|
||||
let rtm = crate::runtime::init_runtime_mt(1, 1);
|
||||
|
||||
let worker_id = self.config.id;
|
||||
log::info!("Started worker-{}", worker_id);
|
||||
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(
|
||||
callback, event_loop, context
|
||||
);
|
||||
let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context);
|
||||
let mut workers = vec![];
|
||||
let (stx, srx) = tokio::sync::watch::channel(false);
|
||||
|
||||
|
@ -406,42 +349,36 @@ macro_rules! serve_wth_ssl {
|
|||
let local = tokio::task::LocalSet::new();
|
||||
|
||||
crate::runtime::block_on_local(rt, local, async move {
|
||||
let service = crate::workers::build_service_ssl!(
|
||||
callback_wrapper, rth, $target
|
||||
);
|
||||
let server = hyper::Server::builder(
|
||||
crate::tls::tls_listen(
|
||||
std::sync::Arc::new(tls_cfg), tcp_listener
|
||||
)
|
||||
)
|
||||
.executor(crate::workers::WorkerExecutor)
|
||||
.http1_only(http1_only)
|
||||
.http2_only(http2_only)
|
||||
.http1_max_buf_size(http1_buffer_max)
|
||||
.serve(service);
|
||||
server.with_graceful_shutdown(async move {
|
||||
srx.changed().await.unwrap();
|
||||
}).await.unwrap();
|
||||
let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target);
|
||||
let server =
|
||||
hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener))
|
||||
.executor(crate::workers::WorkerExecutor)
|
||||
.http1_only(http1_only)
|
||||
.http2_only(http2_only)
|
||||
.http1_max_buf_size(http1_buffer_max)
|
||||
.serve(service);
|
||||
server
|
||||
.with_graceful_shutdown(async move {
|
||||
srx.changed().await.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1);
|
||||
});
|
||||
}));
|
||||
};
|
||||
}
|
||||
|
||||
let main_loop = crate::runtime::run_until_complete(
|
||||
rtm.handler(),
|
||||
event_loop,
|
||||
async move {
|
||||
Python::with_gil(|py| {
|
||||
crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()
|
||||
}).await.unwrap();
|
||||
stx.send(true).unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
while let Some(worker) = workers.pop() {
|
||||
worker.join().unwrap();
|
||||
}
|
||||
Ok(())
|
||||
let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move {
|
||||
Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
stx.send(true).unwrap();
|
||||
log::info!("Stopping worker-{}", worker_id);
|
||||
while let Some(worker) = workers.pop() {
|
||||
worker.join().unwrap();
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
match main_loop {
|
||||
Ok(_) => {}
|
||||
|
@ -457,8 +394,8 @@ macro_rules! serve_wth_ssl {
|
|||
pub(crate) use build_service;
|
||||
pub(crate) use build_service_ssl;
|
||||
pub(crate) use serve_rth;
|
||||
pub(crate) use serve_wth;
|
||||
pub(crate) use serve_rth_ssl;
|
||||
pub(crate) use serve_wth;
|
||||
pub(crate) use serve_wth_ssl;
|
||||
|
||||
pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> {
|
||||
|
|
64
src/ws.rs
64
src/ws.rs
|
@ -1,24 +1,24 @@
|
|||
use hyper::{
|
||||
Body,
|
||||
Request,
|
||||
Response,
|
||||
StatusCode,
|
||||
header::{CONNECTION, UPGRADE},
|
||||
http::response::Builder
|
||||
http::response::Builder,
|
||||
Body, Request, Response, StatusCode,
|
||||
};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tungstenite::{
|
||||
error::ProtocolError,
|
||||
handshake::derive_accept_key,
|
||||
protocol::{Role, WebSocketConfig}
|
||||
protocol::{Role, WebSocketConfig},
|
||||
};
|
||||
use pin_project::pin_project;
|
||||
use std::{future::Future, pin::Pin, task::{Context, Poll}};
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::utils::header_contains_value;
|
||||
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct HyperWebsocket {
|
||||
|
@ -37,15 +37,9 @@ impl Future for HyperWebsocket {
|
|||
Poll::Ready(x) => x,
|
||||
};
|
||||
|
||||
let upgraded = upgraded.map_err(|_|
|
||||
tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete)
|
||||
)?;
|
||||
let upgraded = upgraded.map_err(|_| tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete))?;
|
||||
|
||||
let stream = WebSocketStream::from_raw_socket(
|
||||
upgraded,
|
||||
Role::Server,
|
||||
this.config.take(),
|
||||
);
|
||||
let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, this.config.take());
|
||||
tokio::pin!(stream);
|
||||
|
||||
match stream.as_mut().poll(cx) {
|
||||
|
@ -58,18 +52,15 @@ impl Future for HyperWebsocket {
|
|||
pub(crate) struct UpgradeData {
|
||||
response_builder: Option<Builder>,
|
||||
response_tx: Option<mpsc::Sender<Response<Body>>>,
|
||||
pub consumed: bool
|
||||
pub consumed: bool,
|
||||
}
|
||||
|
||||
impl UpgradeData {
|
||||
pub fn new(
|
||||
response_builder: Builder,
|
||||
response_tx: mpsc::Sender<Response<Body>>)
|
||||
-> Self {
|
||||
pub fn new(response_builder: Builder, response_tx: mpsc::Sender<Response<Body>>) -> Self {
|
||||
Self {
|
||||
response_builder: Some(response_builder),
|
||||
response_tx: Some(response_tx),
|
||||
consumed: false
|
||||
consumed: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,19 +70,16 @@ impl UpgradeData {
|
|||
Ok(_) => {
|
||||
self.consumed = true;
|
||||
Ok(())
|
||||
},
|
||||
err => err
|
||||
}
|
||||
err => err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn is_upgrade_request<B>(request: &Request<B>) -> bool {
|
||||
header_contains_value(
|
||||
request.headers(), CONNECTION, "Upgrade"
|
||||
) && header_contains_value(
|
||||
request.headers(), UPGRADE, "websocket"
|
||||
)
|
||||
header_contains_value(request.headers(), CONNECTION, "Upgrade")
|
||||
&& header_contains_value(request.headers(), UPGRADE, "websocket")
|
||||
}
|
||||
|
||||
pub(crate) fn upgrade_intent<B>(
|
||||
|
@ -100,13 +88,17 @@ pub(crate) fn upgrade_intent<B>(
|
|||
) -> Result<(Builder, HyperWebsocket), ProtocolError> {
|
||||
let request = request.borrow_mut();
|
||||
|
||||
let key = request.headers()
|
||||
let key = request
|
||||
.headers()
|
||||
.get("Sec-WebSocket-Key")
|
||||
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
|
||||
|
||||
if request.headers().get("Sec-WebSocket-Version").map(
|
||||
|v| v.as_bytes()
|
||||
) != Some(b"13") {
|
||||
if request
|
||||
.headers()
|
||||
.get("Sec-WebSocket-Version")
|
||||
.map(hyper::http::HeaderValue::as_bytes)
|
||||
!= Some(b"13")
|
||||
{
|
||||
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
|
||||
}
|
||||
|
||||
|
|
|
@ -2,44 +2,35 @@ use hyper::Body;
|
|||
use pyo3::prelude::*;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use super::types::{WSGIResponseBodyIter, WSGIScope as Scope};
|
||||
use crate::callbacks::CallbackWrapper;
|
||||
use super::types::{WSGIScope as Scope, WSGIResponseBodyIter};
|
||||
|
||||
const WSGI_LIST_RESPONSE_BODY: i32 = 0;
|
||||
const WSGI_ITER_RESPONSE_BODY: i32 = 1;
|
||||
|
||||
|
||||
#[inline(always)]
|
||||
fn run_callback(
|
||||
callback: PyObject,
|
||||
scope: Scope
|
||||
) -> PyResult<(i32, Vec<(String, String)>, Body)> {
|
||||
fn run_callback(callback: PyObject, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> {
|
||||
Python::with_gil(|py| {
|
||||
let (status, headers, body_type, pybody) = callback.call1(py, (scope,))?
|
||||
.extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?;
|
||||
let (status, headers, body_type, pybody) =
|
||||
callback
|
||||
.call1(py, (scope,))?
|
||||
.extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?;
|
||||
let body = match body_type {
|
||||
WSGI_LIST_RESPONSE_BODY => Body::from(pybody.extract::<Vec<u8>>(py)?),
|
||||
WSGI_ITER_RESPONSE_BODY => Body::wrap_stream(WSGIResponseBodyIter::new(pybody)),
|
||||
_ => Body::empty()
|
||||
_ => Body::empty(),
|
||||
};
|
||||
Ok((status, headers, body))
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn call_rtb_http(
|
||||
cb: CallbackWrapper,
|
||||
scope: Scope
|
||||
) -> PyResult<(i32, Vec<(String, String)>, Body)> {
|
||||
run_callback(cb.callback.clone(), scope)
|
||||
pub(crate) fn call_rtb_http(cb: CallbackWrapper, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> {
|
||||
run_callback(cb.callback, scope)
|
||||
}
|
||||
|
||||
pub(crate) fn call_rtt_http(
|
||||
cb: CallbackWrapper,
|
||||
scope: Scope
|
||||
scope: Scope,
|
||||
) -> JoinHandle<PyResult<(i32, Vec<(String, String)>, Body)>> {
|
||||
let callback = cb.callback.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
run_callback(callback, scope)
|
||||
})
|
||||
tokio::task::spawn_blocking(move || run_callback(cb.callback, scope))
|
||||
}
|
||||
|
|
|
@ -1,21 +1,18 @@
|
|||
use hyper::{
|
||||
Body,
|
||||
Request,
|
||||
Response,
|
||||
header::{SERVER as HK_SERVER, HeaderName, HeaderValue}
|
||||
header::{HeaderName, HeaderValue, SERVER as HK_SERVER},
|
||||
Body, Request, Response,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::{
|
||||
callbacks::CallbackWrapper,
|
||||
http::{HV_SERVER, response_500},
|
||||
runtime::RuntimeRef,
|
||||
};
|
||||
use super::{
|
||||
callbacks::{call_rtb_http, call_rtt_http},
|
||||
types::WSGIScope as Scope
|
||||
types::WSGIScope as Scope,
|
||||
};
|
||||
use crate::{
|
||||
callbacks::CallbackWrapper,
|
||||
http::{response_500, HV_SERVER},
|
||||
runtime::RuntimeRef,
|
||||
};
|
||||
|
||||
|
||||
#[inline(always)]
|
||||
fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) -> Response<Body> {
|
||||
|
@ -26,7 +23,7 @@ fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) ->
|
|||
for (key, val) in pyheaders {
|
||||
headers.append(
|
||||
HeaderName::from_bytes(key.as_bytes()).unwrap(),
|
||||
HeaderValue::from_str(&val).unwrap()
|
||||
HeaderValue::from_str(&val).unwrap(),
|
||||
);
|
||||
}
|
||||
res
|
||||
|
@ -38,14 +35,11 @@ pub(crate) async fn handle_rtt(
|
|||
server_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
req: Request<Body>,
|
||||
scheme: &str
|
||||
scheme: &str,
|
||||
) -> Response<Body> {
|
||||
if let Ok(res) = call_rtt_http(
|
||||
callback,
|
||||
Scope::new(scheme, server_addr, client_addr, req).await
|
||||
).await {
|
||||
if let Ok(res) = call_rtt_http(callback, Scope::new(scheme, server_addr, client_addr, req).await).await {
|
||||
if let Ok((status, headers, body)) = res {
|
||||
return build_response(status, headers, body)
|
||||
return build_response(status, headers, body);
|
||||
}
|
||||
log::warn!("Application callable raised an exception");
|
||||
} else {
|
||||
|
@ -60,12 +54,9 @@ pub(crate) async fn handle_rtb(
|
|||
server_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
req: Request<Body>,
|
||||
scheme: &str
|
||||
scheme: &str,
|
||||
) -> Response<Body> {
|
||||
match call_rtb_http(
|
||||
callback,
|
||||
Scope::new(scheme, server_addr, client_addr, req).await
|
||||
) {
|
||||
match call_rtb_http(callback, Scope::new(scheme, server_addr, client_addr, req).await) {
|
||||
Ok((status, headers, body)) => build_response(status, headers, body),
|
||||
_ => {
|
||||
log::warn!("Application callable raised an exception");
|
||||
|
|
|
@ -1,17 +1,11 @@
|
|||
use pyo3::prelude::*;
|
||||
|
||||
use crate::workers::{
|
||||
WorkerConfig,
|
||||
serve_rth,
|
||||
serve_wth,
|
||||
serve_rth_ssl,
|
||||
serve_wth_ssl
|
||||
};
|
||||
use super::http::{handle_rtb, handle_rtt};
|
||||
use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig};
|
||||
|
||||
#[pyclass(module="granian._granian")]
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub struct WSGIWorker {
|
||||
config: WorkerConfig
|
||||
config: WorkerConfig,
|
||||
}
|
||||
|
||||
impl WSGIWorker {
|
||||
|
@ -46,7 +40,7 @@ impl WSGIWorker {
|
|||
http1_buffer_max: usize,
|
||||
ssl_enabled: bool,
|
||||
ssl_cert: Option<&str>,
|
||||
ssl_key: Option<&str>
|
||||
ssl_key: Option<&str>,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
config: WorkerConfig::new(
|
||||
|
@ -60,34 +54,22 @@ impl WSGIWorker {
|
|||
true,
|
||||
ssl_enabled,
|
||||
ssl_cert,
|
||||
ssl_key
|
||||
)
|
||||
ssl_key,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn serve_rth(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
match self.config.ssl_enabled {
|
||||
false => self._serve_rth(callback, event_loop, context, signal_rx),
|
||||
true => self._serve_rth_ssl(callback, event_loop, context, signal_rx)
|
||||
true => self._serve_rth_ssl(callback, event_loop, context, signal_rx),
|
||||
}
|
||||
}
|
||||
|
||||
fn serve_wth(
|
||||
&self,
|
||||
callback: PyObject,
|
||||
event_loop: &PyAny,
|
||||
context: &PyAny,
|
||||
signal_rx: PyObject
|
||||
) {
|
||||
fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) {
|
||||
match self.config.ssl_enabled {
|
||||
false => self._serve_wth(callback, event_loop, context, signal_rx),
|
||||
true => self._serve_wth_ssl(callback, event_loop, context, signal_rx)
|
||||
true => self._serve_wth_ssl(callback, event_loop, context, signal_rx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,23 +1,21 @@
|
|||
use futures::Stream;
|
||||
use hyper::{
|
||||
body::Bytes,
|
||||
header::{CONTENT_TYPE, CONTENT_LENGTH, HeaderMap},
|
||||
Body,
|
||||
Method,
|
||||
Request,
|
||||
Uri,
|
||||
Version
|
||||
header::{HeaderMap, CONTENT_LENGTH, CONTENT_TYPE},
|
||||
Body, Method, Request, Uri, Version,
|
||||
};
|
||||
use pyo3::{prelude::*, types::IntoPyDict};
|
||||
use pyo3::types::{PyBytes, PyDict, PyList};
|
||||
use std::{net::{IpAddr, SocketAddr}, task::{Context, Poll}};
|
||||
use pyo3::{prelude::*, types::IntoPyDict};
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
const LINE_SPLIT: u8 = u8::from_be_bytes(*b"\n");
|
||||
|
||||
|
||||
#[pyclass(module = "granian._granian")]
|
||||
pub(crate) struct WSGIBody {
|
||||
inner: Bytes
|
||||
inner: Bytes,
|
||||
}
|
||||
|
||||
impl WSGIBody {
|
||||
|
@ -36,9 +34,9 @@ impl WSGIBody {
|
|||
match self.inner.iter().position(|&c| c == LINE_SPLIT) {
|
||||
Some(next_split) => {
|
||||
let bytes = self.inner.split_to(next_split);
|
||||
Some(PyBytes::new(py, &bytes[..]))
|
||||
},
|
||||
_ => None
|
||||
Some(PyBytes::new(py, &bytes))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -48,18 +46,16 @@ impl WSGIBody {
|
|||
None => {
|
||||
let bytes = self.inner.split_to(self.inner.len());
|
||||
PyBytes::new(py, &bytes[..])
|
||||
},
|
||||
Some(size) => {
|
||||
match size {
|
||||
0 => PyBytes::new(py, b""),
|
||||
size => {
|
||||
let limit = self.inner.len();
|
||||
let rsize = if size > limit { limit } else { size };
|
||||
let bytes = self.inner.split_to(rsize);
|
||||
PyBytes::new(py, &bytes[..])
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(size) => match size {
|
||||
0 => PyBytes::new(py, b""),
|
||||
size => {
|
||||
let limit = self.inner.len();
|
||||
let rsize = if size > limit { limit } else { size };
|
||||
let bytes = self.inner.split_to(rsize);
|
||||
PyBytes::new(py, &bytes[..])
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,16 +65,17 @@ impl WSGIBody {
|
|||
let bytes = self.inner.split_to(next_split);
|
||||
self.inner = self.inner.slice(1..);
|
||||
PyBytes::new(py, &bytes[..])
|
||||
},
|
||||
_ => PyBytes::new(py, b"")
|
||||
}
|
||||
_ => PyBytes::new(py, b""),
|
||||
}
|
||||
}
|
||||
|
||||
#[pyo3(signature = (_hint=None))]
|
||||
fn readlines<'p>(&mut self, py: Python<'p>, _hint: Option<PyObject>) -> &'p PyList {
|
||||
let lines: Vec<&PyBytes> = self.inner
|
||||
let lines: Vec<&PyBytes> = self
|
||||
.inner
|
||||
.split(|&c| c == LINE_SPLIT)
|
||||
.map(|item| PyBytes::new(py, &item[..]))
|
||||
.map(|item| PyBytes::new(py, item))
|
||||
.collect();
|
||||
self.inner.clear();
|
||||
PyList::new(py, lines)
|
||||
|
@ -95,28 +92,19 @@ pub(crate) struct WSGIScope {
|
|||
server_port: u16,
|
||||
client: String,
|
||||
headers: HeaderMap,
|
||||
body: Bytes
|
||||
body: Bytes,
|
||||
}
|
||||
|
||||
impl WSGIScope {
|
||||
pub async fn new(
|
||||
scheme: &str,
|
||||
server: SocketAddr,
|
||||
client: SocketAddr,
|
||||
request: Request<Body>,
|
||||
) -> Self {
|
||||
pub async fn new(scheme: &str, server: SocketAddr, client: SocketAddr, request: Request<Body>) -> Self {
|
||||
let http_version = request.version();
|
||||
let method = request.method().to_owned();
|
||||
let uri = request.uri().to_owned();
|
||||
let headers = request.headers().to_owned();
|
||||
let method = request.method().clone();
|
||||
let uri = request.uri().clone();
|
||||
let headers = request.headers().clone();
|
||||
|
||||
let body = match method {
|
||||
Method::HEAD | Method::GET | Method::OPTIONS => { Bytes::new() },
|
||||
_ => {
|
||||
hyper::body::to_bytes(request)
|
||||
.await
|
||||
.unwrap_or(Bytes::new())
|
||||
}
|
||||
Method::HEAD | Method::GET | Method::OPTIONS => Bytes::new(),
|
||||
_ => hyper::body::to_bytes(request).await.unwrap_or(Bytes::new()),
|
||||
};
|
||||
|
||||
Self {
|
||||
|
@ -128,7 +116,7 @@ impl WSGIScope {
|
|||
server_port: server.port(),
|
||||
client: client.to_string(),
|
||||
headers,
|
||||
body
|
||||
body,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,7 +126,7 @@ impl WSGIScope {
|
|||
Version::HTTP_10 => "HTTP/1",
|
||||
Version::HTTP_11 => "HTTP/1.1",
|
||||
Version::HTTP_2 => "HTTP/2",
|
||||
_ => "HTTP/1"
|
||||
_ => "HTTP/1",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -157,21 +145,21 @@ impl WSGIScope {
|
|||
content_type,
|
||||
content_len,
|
||||
headers,
|
||||
body
|
||||
body,
|
||||
) = py.allow_threads(|| {
|
||||
let (path, query_string) = self.uri.path_and_query()
|
||||
let (path, query_string) = self
|
||||
.uri
|
||||
.path_and_query()
|
||||
.map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or("")));
|
||||
let content_type = self.headers.remove(CONTENT_TYPE);
|
||||
let content_len = self.headers.remove(CONTENT_LENGTH);
|
||||
let mut headers = Vec::with_capacity(self.headers.len());
|
||||
|
||||
for (key, val) in self.headers.iter() {
|
||||
headers.push(
|
||||
(
|
||||
format!("HTTP_{}", key.as_str().replace("-", "_").to_uppercase()),
|
||||
val.to_str().unwrap_or_default()
|
||||
)
|
||||
);
|
||||
for (key, val) in &self.headers {
|
||||
headers.push((
|
||||
format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()),
|
||||
val.to_str().unwrap_or_default(),
|
||||
));
|
||||
}
|
||||
|
||||
(
|
||||
|
@ -185,7 +173,7 @@ impl WSGIScope {
|
|||
content_type,
|
||||
content_len,
|
||||
headers,
|
||||
WSGIBody::new(self.body.to_owned())
|
||||
WSGIBody::new(self.body.clone()),
|
||||
)
|
||||
});
|
||||
|
||||
|
@ -202,13 +190,13 @@ impl WSGIScope {
|
|||
if let Some(content_type) = content_type {
|
||||
ret.set_item(
|
||||
pyo3::intern!(py, "CONTENT_TYPE"),
|
||||
content_type.to_str().unwrap_or_default()
|
||||
content_type.to_str().unwrap_or_default(),
|
||||
)?;
|
||||
}
|
||||
if let Some(content_len) = content_len {
|
||||
ret.set_item(
|
||||
pyo3::intern!(py, "CONTENT_LENGTH"),
|
||||
content_len.to_str().unwrap_or_default()
|
||||
content_len.to_str().unwrap_or_default(),
|
||||
)?;
|
||||
}
|
||||
|
||||
|
@ -219,7 +207,7 @@ impl WSGIScope {
|
|||
}
|
||||
|
||||
pub(crate) struct WSGIResponseBodyIter {
|
||||
inner: PyObject
|
||||
inner: PyObject,
|
||||
}
|
||||
|
||||
impl WSGIResponseBodyIter {
|
||||
|
@ -235,27 +223,20 @@ impl WSGIResponseBodyIter {
|
|||
impl Stream for WSGIResponseBodyIter {
|
||||
type Item = PyResult<Vec<u8>>;
|
||||
|
||||
fn poll_next(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
Python::with_gil(|py| {
|
||||
match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) {
|
||||
Ok(chunk_obj) => {
|
||||
match chunk_obj.extract::<Vec<u8>>(py) {
|
||||
Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
|
||||
_ => {
|
||||
self.close_inner(py);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
if err.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
|
||||
self.close_inner(py);
|
||||
}
|
||||
fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) {
|
||||
Ok(chunk_obj) => match chunk_obj.extract::<Vec<u8>>(py) {
|
||||
Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
|
||||
_ => {
|
||||
self.close_inner(py);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
if err.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
|
||||
self.close_inner(py);
|
||||
}
|
||||
Poll::Ready(None)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,55 +1,45 @@
|
|||
import json
|
||||
|
||||
|
||||
PLAINTEXT_RESPONSE = {
|
||||
'type': 'http.response.start',
|
||||
'status': 200,
|
||||
'headers': [
|
||||
[b'content-type', b'text/plain; charset=utf-8'],
|
||||
]
|
||||
}
|
||||
JSON_RESPONSE = {
|
||||
'type': 'http.response.start',
|
||||
'status': 200,
|
||||
'headers': [
|
||||
[b'content-type', b'application/json'],
|
||||
]
|
||||
'headers': [[b'content-type', b'text/plain; charset=utf-8']],
|
||||
}
|
||||
JSON_RESPONSE = {'type': 'http.response.start', 'status': 200, 'headers': [[b'content-type', b'application/json']]}
|
||||
|
||||
|
||||
async def info(scope, receive, send):
|
||||
await send(JSON_RESPONSE)
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': json.dumps({
|
||||
'type': scope['type'],
|
||||
'asgi': scope['asgi'],
|
||||
'http_version': scope['http_version'],
|
||||
'scheme': scope['scheme'],
|
||||
'method': scope['method'],
|
||||
'path': scope['path'],
|
||||
'query_string': scope['query_string'].decode("latin-1"),
|
||||
'headers': {
|
||||
k.decode("utf8"): v.decode("utf8")
|
||||
for k, v in scope['headers']
|
||||
}
|
||||
}).encode("utf8"),
|
||||
'more_body': False
|
||||
})
|
||||
await send(
|
||||
{
|
||||
'type': 'http.response.body',
|
||||
'body': json.dumps(
|
||||
{
|
||||
'type': scope['type'],
|
||||
'asgi': scope['asgi'],
|
||||
'http_version': scope['http_version'],
|
||||
'scheme': scope['scheme'],
|
||||
'method': scope['method'],
|
||||
'path': scope['path'],
|
||||
'query_string': scope['query_string'].decode('latin-1'),
|
||||
'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
|
||||
}
|
||||
).encode('utf8'),
|
||||
'more_body': False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def echo(scope, receive, send):
|
||||
await send(PLAINTEXT_RESPONSE)
|
||||
more_body = True
|
||||
body = b""
|
||||
body = b''
|
||||
while more_body:
|
||||
msg = await receive()
|
||||
more_body = msg['more_body']
|
||||
body += msg['body']
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': body,
|
||||
'more_body': False
|
||||
})
|
||||
await send({'type': 'http.response.body', 'body': body, 'more_body': False})
|
||||
|
||||
|
||||
async def ws_reject(scope, receive, send):
|
||||
|
@ -58,21 +48,22 @@ async def ws_reject(scope, receive, send):
|
|||
|
||||
async def ws_info(scope, receive, send):
|
||||
await send({'type': 'websocket.accept'})
|
||||
await send({
|
||||
'type': 'websocket.send',
|
||||
'text': json.dumps({
|
||||
'type': scope['type'],
|
||||
'asgi': scope['asgi'],
|
||||
'http_version': scope['http_version'],
|
||||
'scheme': scope['scheme'],
|
||||
'path': scope['path'],
|
||||
'query_string': scope['query_string'].decode("latin-1"),
|
||||
'headers': {
|
||||
k.decode("utf8"): v.decode("utf8")
|
||||
for k, v in scope['headers']
|
||||
}
|
||||
})
|
||||
})
|
||||
await send(
|
||||
{
|
||||
'type': 'websocket.send',
|
||||
'text': json.dumps(
|
||||
{
|
||||
'type': scope['type'],
|
||||
'asgi': scope['asgi'],
|
||||
'http_version': scope['http_version'],
|
||||
'scheme': scope['scheme'],
|
||||
'path': scope['path'],
|
||||
'query_string': scope['query_string'].decode('latin-1'),
|
||||
'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
await send({'type': 'websocket.close'})
|
||||
|
||||
|
||||
|
@ -98,10 +89,7 @@ async def ws_push(scope, receive, send):
|
|||
|
||||
try:
|
||||
while True:
|
||||
await send({
|
||||
'type': 'websocket.send',
|
||||
'text': 'ping'
|
||||
})
|
||||
await send({'type': 'websocket.send', 'text': 'ping'})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
@ -116,12 +104,12 @@ async def err_proto(scope, receive, send):
|
|||
|
||||
def app(scope, receive, send):
|
||||
return {
|
||||
"/info": info,
|
||||
"/echo": echo,
|
||||
"/ws_reject": ws_reject,
|
||||
"/ws_info": ws_info,
|
||||
"/ws_echo": ws_echo,
|
||||
"/ws_push": ws_push,
|
||||
"/err_app": err_app,
|
||||
"/err_proto": err_proto
|
||||
'/info': info,
|
||||
'/echo': echo,
|
||||
'/ws_reject': ws_reject,
|
||||
'/ws_info': ws_info,
|
||||
'/ws_echo': ws_echo,
|
||||
'/ws_push': ws_push,
|
||||
'/err_app': err_app,
|
||||
'/err_proto': err_proto,
|
||||
}[scope['path']](scope, receive, send)
|
||||
|
|
|
@ -1,55 +1,42 @@
|
|||
import json
|
||||
|
||||
from granian.rsgi import (
|
||||
HTTPProtocol,
|
||||
Scope,
|
||||
WebsocketMessageType,
|
||||
WebsocketProtocol
|
||||
)
|
||||
from granian.rsgi import HTTPProtocol, Scope, WebsocketMessageType, WebsocketProtocol
|
||||
|
||||
|
||||
async def info(scope: Scope, protocol: HTTPProtocol):
|
||||
protocol.response_bytes(
|
||||
200,
|
||||
[('content-type', 'application/json')],
|
||||
json.dumps({
|
||||
'proto': scope.proto,
|
||||
'http_version': scope.http_version,
|
||||
'rsgi_version': scope.rsgi_version,
|
||||
'scheme': scope.scheme,
|
||||
'method': scope.method,
|
||||
'path': scope.path,
|
||||
'query_string': scope.query_string,
|
||||
'headers': {k: v for k, v in scope.headers.items()}
|
||||
}).encode("utf8")
|
||||
json.dumps(
|
||||
{
|
||||
'proto': scope.proto,
|
||||
'http_version': scope.http_version,
|
||||
'rsgi_version': scope.rsgi_version,
|
||||
'scheme': scope.scheme,
|
||||
'method': scope.method,
|
||||
'path': scope.path,
|
||||
'query_string': scope.query_string,
|
||||
'headers': dict(scope.headers.items()),
|
||||
}
|
||||
).encode('utf8'),
|
||||
)
|
||||
|
||||
|
||||
async def echo(_, protocol: HTTPProtocol):
|
||||
msg = await protocol()
|
||||
protocol.response_bytes(
|
||||
200,
|
||||
[('content-type', 'text/plain; charset=utf-8')],
|
||||
msg
|
||||
)
|
||||
protocol.response_bytes(200, [('content-type', 'text/plain; charset=utf-8')], msg)
|
||||
|
||||
|
||||
async def echo_stream(_, protocol: HTTPProtocol):
|
||||
trx = protocol.response_stream(
|
||||
200,
|
||||
[('content-type', 'text/plain; charset=utf-8')]
|
||||
)
|
||||
trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')])
|
||||
async for msg in protocol:
|
||||
await trx.send_bytes(msg)
|
||||
|
||||
|
||||
async def stream(_, protocol: HTTPProtocol):
|
||||
trx = protocol.response_stream(
|
||||
200,
|
||||
[('content-type', 'text/plain; charset=utf-8')]
|
||||
)
|
||||
trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')])
|
||||
for _ in range(0, 3):
|
||||
await trx.send_bytes(b"test")
|
||||
await trx.send_bytes(b'test')
|
||||
|
||||
|
||||
async def ws_reject(_, protocol: WebsocketProtocol):
|
||||
|
@ -59,16 +46,20 @@ async def ws_reject(_, protocol: WebsocketProtocol):
|
|||
async def ws_info(scope: Scope, protocol: WebsocketProtocol):
|
||||
trx = await protocol.accept()
|
||||
|
||||
await trx.send_str(json.dumps({
|
||||
'proto': scope.proto,
|
||||
'http_version': scope.http_version,
|
||||
'rsgi_version': scope.rsgi_version,
|
||||
'scheme': scope.scheme,
|
||||
'method': scope.method,
|
||||
'path': scope.path,
|
||||
'query_string': scope.query_string,
|
||||
'headers': {k: v for k, v in scope.headers.items()}
|
||||
}))
|
||||
await trx.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
'proto': scope.proto,
|
||||
'http_version': scope.http_version,
|
||||
'rsgi_version': scope.rsgi_version,
|
||||
'scheme': scope.scheme,
|
||||
'method': scope.method,
|
||||
'path': scope.path,
|
||||
'query_string': scope.query_string,
|
||||
'headers': dict(scope.headers.items()),
|
||||
}
|
||||
)
|
||||
)
|
||||
while True:
|
||||
message = await trx.receive()
|
||||
if message.kind == WebsocketMessageType.close:
|
||||
|
@ -97,7 +88,7 @@ async def ws_push(_, protocol: WebsocketProtocol):
|
|||
|
||||
try:
|
||||
while True:
|
||||
await trx.send_str("ping")
|
||||
await trx.send_str('ping')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
@ -110,13 +101,13 @@ async def err_app(scope: Scope, protocol: HTTPProtocol):
|
|||
|
||||
def app(scope, protocol):
|
||||
return {
|
||||
"/info": info,
|
||||
"/echo": echo,
|
||||
"/echos": echo_stream,
|
||||
"/stream": stream,
|
||||
"/ws_reject": ws_reject,
|
||||
"/ws_info": ws_info,
|
||||
"/ws_echo": ws_echo,
|
||||
"/ws_push": ws_push,
|
||||
"/err_app": err_app
|
||||
'/info': info,
|
||||
'/echo': echo,
|
||||
'/echos': echo_stream,
|
||||
'/stream': stream,
|
||||
'/ws_reject': ws_reject,
|
||||
'/ws_info': ws_info,
|
||||
'/ws_echo': ws_echo,
|
||||
'/ws_push': ws_push,
|
||||
'/err_app': err_app,
|
||||
}[scope.path](scope, protocol)
|
||||
|
|
|
@ -2,37 +2,32 @@ import json
|
|||
|
||||
|
||||
def info(environ, protocol):
|
||||
protocol(
|
||||
"200 OK",
|
||||
[('content-type', 'application/json')]
|
||||
)
|
||||
return [json.dumps({
|
||||
'scheme': environ['wsgi.url_scheme'],
|
||||
'method': environ['REQUEST_METHOD'],
|
||||
'path': environ["PATH_INFO"],
|
||||
'query_string': environ["QUERY_STRING"],
|
||||
'content_length': environ['CONTENT_LENGTH'],
|
||||
'headers': {k: v for k, v in environ.items() if k.startswith("HTTP_")}
|
||||
}).encode("utf8")]
|
||||
protocol('200 OK', [('content-type', 'application/json')])
|
||||
return [
|
||||
json.dumps(
|
||||
{
|
||||
'scheme': environ['wsgi.url_scheme'],
|
||||
'method': environ['REQUEST_METHOD'],
|
||||
'path': environ['PATH_INFO'],
|
||||
'query_string': environ['QUERY_STRING'],
|
||||
'content_length': environ['CONTENT_LENGTH'],
|
||||
'headers': {k: v for k, v in environ.items() if k.startswith('HTTP_')},
|
||||
}
|
||||
).encode('utf8')
|
||||
]
|
||||
|
||||
|
||||
def echo(environ, protocol):
|
||||
protocol(
|
||||
'200 OK',
|
||||
[('content-type', 'text/plain; charset=utf-8')]
|
||||
)
|
||||
protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')])
|
||||
return [environ['wsgi.input'].read()]
|
||||
|
||||
|
||||
def iterbody(environ, protocol):
|
||||
def response():
|
||||
for _ in range(0, 3):
|
||||
yield b"test"
|
||||
yield b'test'
|
||||
|
||||
protocol(
|
||||
'200 OK',
|
||||
[('content-type', 'text/plain; charset=utf-8')]
|
||||
)
|
||||
protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')])
|
||||
return response()
|
||||
|
||||
|
||||
|
@ -41,9 +36,6 @@ def err_app(environ, protocol):
|
|||
|
||||
|
||||
def app(environ, protocol):
|
||||
return {
|
||||
"/info": info,
|
||||
"/echo": echo,
|
||||
"/iterbody": iterbody,
|
||||
"/err_app": err_app
|
||||
}[environ["PATH_INFO"]](environ, protocol)
|
||||
return {'/info': info, '/echo': echo, '/iterbody': iterbody, '/err_app': err_app}[environ['PATH_INFO']](
|
||||
environ, protocol
|
||||
)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
import os
|
||||
import socket
|
||||
|
||||
from contextlib import asynccontextmanager, closing
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
@ -11,19 +10,20 @@ import pytest
|
|||
|
||||
@asynccontextmanager
|
||||
async def _server(interface, port, threading_mode, tls=False):
|
||||
certs_path = Path.cwd() / "tests" / "fixtures" / "tls"
|
||||
certs_path = Path.cwd() / 'tests' / 'fixtures' / 'tls'
|
||||
tls_opts = (
|
||||
f"--ssl-certificate {certs_path / 'cert.pem'} "
|
||||
f"--ssl-keyfile {certs_path / 'key.pem'} "
|
||||
) if tls else ""
|
||||
(f"--ssl-certificate {certs_path / 'cert.pem'} " f"--ssl-keyfile {certs_path / 'key.pem'} ") if tls else ''
|
||||
)
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
"".join([
|
||||
f"granian --interface {interface} --port {port} ",
|
||||
f"--threads 1 --threading-mode {threading_mode} ",
|
||||
tls_opts,
|
||||
f"tests.apps.{interface}:app"
|
||||
]),
|
||||
env=dict(os.environ)
|
||||
''.join(
|
||||
[
|
||||
f'granian --interface {interface} --port {port} ',
|
||||
f'--threads 1 --threading-mode {threading_mode} ',
|
||||
tls_opts,
|
||||
f'tests.apps.{interface}:app',
|
||||
]
|
||||
),
|
||||
env=dict(os.environ),
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
try:
|
||||
|
@ -33,7 +33,7 @@ async def _server(interface, port, threading_mode, tls=False):
|
|||
await proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope='function')
|
||||
def server_port():
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
sock.bind(('localhost', 0))
|
||||
|
@ -41,26 +41,26 @@ def server_port():
|
|||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope='function')
|
||||
def asgi_server(server_port):
|
||||
return partial(_server, "asgi", server_port)
|
||||
return partial(_server, 'asgi', server_port)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope='function')
|
||||
def rsgi_server(server_port):
|
||||
return partial(_server, "rsgi", server_port)
|
||||
return partial(_server, 'rsgi', server_port)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope='function')
|
||||
def wsgi_server(server_port):
|
||||
return partial(_server, "wsgi", server_port)
|
||||
return partial(_server, 'wsgi', server_port)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope='function')
|
||||
def server(server_port, request):
|
||||
return partial(_server, request.param, server_port)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope='function')
|
||||
def server_tls(server_port, request):
|
||||
return partial(_server, request.param, server_port, tls=True)
|
||||
|
|
|
@ -3,92 +3,59 @@ import pytest
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_scope(asgi_server, threading_mode):
|
||||
async with asgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/info?test=true")
|
||||
res = httpx.get(f'http://localhost:{port}/info?test=true')
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.headers["content-type"] == "application/json"
|
||||
assert res.headers['content-type'] == 'application/json'
|
||||
|
||||
data = res.json()
|
||||
assert data['asgi'] == {
|
||||
'version': '3.0',
|
||||
'spec_version': '2.3'
|
||||
}
|
||||
assert data['type'] == "http"
|
||||
assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'}
|
||||
assert data['type'] == 'http'
|
||||
assert data['http_version'] == '1.1'
|
||||
assert data['scheme'] == 'http'
|
||||
assert data['method'] == "GET"
|
||||
assert data['method'] == 'GET'
|
||||
assert data['path'] == '/info'
|
||||
assert data['query_string'] == 'test=true'
|
||||
assert data['headers']['host'] == f'localhost:{port}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_body(asgi_server, threading_mode):
|
||||
async with asgi_server(threading_mode) as port:
|
||||
res = httpx.post(f"http://localhost:{port}/echo", content="test")
|
||||
res = httpx.post(f'http://localhost:{port}/echo', content='test')
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.text == "test"
|
||||
assert res.text == 'test'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_body_large(asgi_server, threading_mode):
|
||||
data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)])
|
||||
data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)])
|
||||
async with asgi_server(threading_mode) as port:
|
||||
res = httpx.post(f"http://localhost:{port}/echo", content=data)
|
||||
res = httpx.post(f'http://localhost:{port}/echo', content=data)
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.text == data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_app_error(asgi_server, threading_mode):
|
||||
async with asgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/err_app")
|
||||
res = httpx.get(f'http://localhost:{port}/err_app')
|
||||
|
||||
assert res.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_protocol_error(asgi_server, threading_mode):
|
||||
async with asgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/err_proto")
|
||||
res = httpx.get(f'http://localhost:{port}/err_proto')
|
||||
|
||||
assert res.status_code == 500
|
||||
|
|
|
@ -1,20 +1,18 @@
|
|||
import httpx
|
||||
import json
|
||||
import pathlib
|
||||
import pytest
|
||||
import ssl
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("server_tls", ["asgi", "rsgi"], indirect=True)
|
||||
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
|
||||
@pytest.mark.parametrize('server_tls', ['asgi', 'rsgi'], indirect=True)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_http_scope(server_tls, threading_mode):
|
||||
async with server_tls(threading_mode) as port:
|
||||
res = httpx.get(
|
||||
f"https://localhost:{port}/info?test=true",
|
||||
verify=False
|
||||
)
|
||||
res = httpx.get(f'https://localhost:{port}/info?test=true', verify=False)
|
||||
|
||||
assert res.status_code == 200
|
||||
data = res.json()
|
||||
|
@ -22,17 +20,14 @@ async def test_http_scope(server_tls, threading_mode):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_asgi_ws_scope(asgi_server, threading_mode):
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem"
|
||||
localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem'
|
||||
ssl_context.load_verify_locations(localhost_pem)
|
||||
|
||||
async with asgi_server(threading_mode, tls=True) as port:
|
||||
async with websockets.connect(
|
||||
f"wss://localhost:{port}/ws_info?test=true",
|
||||
ssl=ssl_context
|
||||
) as ws:
|
||||
async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws:
|
||||
res = await ws.recv()
|
||||
|
||||
data = json.loads(res)
|
||||
|
@ -40,17 +35,14 @@ async def test_asgi_ws_scope(asgi_server, threading_mode):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_rsgi_ws_scope(rsgi_server, threading_mode):
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem"
|
||||
localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem'
|
||||
ssl_context.load_verify_locations(localhost_pem)
|
||||
|
||||
async with rsgi_server(threading_mode, tls=True) as port:
|
||||
async with websockets.connect(
|
||||
f"wss://localhost:{port}/ws_info?test=true",
|
||||
ssl=ssl_context
|
||||
) as ws:
|
||||
async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws:
|
||||
res = await ws.recv()
|
||||
|
||||
data = json.loads(res)
|
||||
|
|
|
@ -3,90 +3,60 @@ import pytest
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_scope(rsgi_server, threading_mode):
|
||||
async with rsgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/info?test=true")
|
||||
res = httpx.get(f'http://localhost:{port}/info?test=true')
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.headers["content-type"] == "application/json"
|
||||
assert res.headers['content-type'] == 'application/json'
|
||||
|
||||
data = res.json()
|
||||
assert data['proto'] == "http"
|
||||
assert data['proto'] == 'http'
|
||||
assert data['http_version'] == '1.1'
|
||||
assert data['rsgi_version'] == '1.2'
|
||||
assert data['scheme'] == 'http'
|
||||
assert data['method'] == "GET"
|
||||
assert data['method'] == 'GET'
|
||||
assert data['path'] == '/info'
|
||||
assert data['query_string'] == 'test=true'
|
||||
assert data['headers']['host'] == f'localhost:{port}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_body(rsgi_server, threading_mode):
|
||||
async with rsgi_server(threading_mode) as port:
|
||||
res = httpx.post(f"http://localhost:{port}/echo", content="test")
|
||||
res = httpx.post(f'http://localhost:{port}/echo', content='test')
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.text == "test"
|
||||
assert res.text == 'test'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_body_stream_req(rsgi_server, threading_mode):
|
||||
data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)])
|
||||
data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)])
|
||||
async with rsgi_server(threading_mode) as port:
|
||||
res = httpx.post(f"http://localhost:{port}/echos", content=data)
|
||||
res = httpx.post(f'http://localhost:{port}/echos', content=data)
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.text == data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_body_stream_res(rsgi_server, threading_mode):
|
||||
async with rsgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/stream")
|
||||
res = httpx.get(f'http://localhost:{port}/stream')
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.text == "test" * 3
|
||||
assert res.text == 'test' * 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_app_error(rsgi_server, threading_mode):
|
||||
async with rsgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/err_app")
|
||||
res = httpx.get(f'http://localhost:{port}/err_app')
|
||||
|
||||
assert res.status_code == 500
|
||||
|
|
|
@ -1,56 +1,48 @@
|
|||
import json
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="skip on windows")
|
||||
@pytest.mark.skipif(sys.platform == 'win32', reason='skip on windows')
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True)
|
||||
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
|
||||
@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_messages(server, threading_mode):
|
||||
async with server(threading_mode) as port:
|
||||
async with websockets.connect(f"ws://localhost:{port}/ws_echo") as ws:
|
||||
await ws.send("foo")
|
||||
async with websockets.connect(f'ws://localhost:{port}/ws_echo') as ws:
|
||||
await ws.send('foo')
|
||||
res_text = await ws.recv()
|
||||
await ws.send(b"foo")
|
||||
await ws.send(b'foo')
|
||||
res_bytes = await ws.recv()
|
||||
|
||||
assert res_text == "foo"
|
||||
assert res_bytes == b"foo"
|
||||
assert res_text == 'foo'
|
||||
assert res_bytes == b'foo'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True)
|
||||
@pytest.mark.parametrize("threading_mode", ["runtime", "workers"])
|
||||
@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_reject(server, threading_mode):
|
||||
async with server(threading_mode) as port:
|
||||
with pytest.raises(websockets.InvalidStatusCode) as exc:
|
||||
async with websockets.connect(f"ws://localhost:{port}/ws_reject") as ws:
|
||||
async with websockets.connect(f'ws://localhost:{port}/ws_reject'):
|
||||
pass
|
||||
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_asgi_scope(asgi_server, threading_mode):
|
||||
async with asgi_server(threading_mode) as port:
|
||||
async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws:
|
||||
async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws:
|
||||
res = await ws.recv()
|
||||
|
||||
data = json.loads(res)
|
||||
assert data['asgi'] == {
|
||||
'version': '3.0',
|
||||
'spec_version': '2.3'
|
||||
}
|
||||
assert data['type'] == "websocket"
|
||||
assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'}
|
||||
assert data['type'] == 'websocket'
|
||||
assert data['http_version'] == '1.1'
|
||||
assert data['scheme'] == 'ws'
|
||||
assert data['path'] == '/ws_info'
|
||||
|
@ -59,16 +51,10 @@ async def test_asgi_scope(asgi_server, threading_mode):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_rsgi_scope(rsgi_server, threading_mode):
|
||||
async with rsgi_server(threading_mode) as port:
|
||||
async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws:
|
||||
async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws:
|
||||
res = await ws.recv()
|
||||
|
||||
data = json.loads(res)
|
||||
|
@ -76,7 +62,7 @@ async def test_rsgi_scope(rsgi_server, threading_mode):
|
|||
assert data['http_version'] == '1.1'
|
||||
assert data['rsgi_version'] == '1.2'
|
||||
assert data['scheme'] == 'http'
|
||||
assert data['method'] == "GET"
|
||||
assert data['method'] == 'GET'
|
||||
assert data['path'] == '/ws_info'
|
||||
assert data['query_string'] == 'test=true'
|
||||
assert data['headers']['host'] == f'localhost:{port}'
|
||||
|
|
|
@ -3,24 +3,18 @@ import pytest
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_scope(wsgi_server, threading_mode):
|
||||
payload = "body_payload"
|
||||
payload = 'body_payload'
|
||||
async with wsgi_server(threading_mode) as port:
|
||||
res = httpx.post(f"http://localhost:{port}/info?test=true", content=payload)
|
||||
res = httpx.post(f'http://localhost:{port}/info?test=true', content=payload)
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.headers["content-type"] == "application/json"
|
||||
assert res.headers['content-type'] == 'application/json'
|
||||
|
||||
data = res.json()
|
||||
assert data['scheme'] == 'http'
|
||||
assert data['method'] == "POST"
|
||||
assert data['method'] == 'POST'
|
||||
assert data['path'] == '/info'
|
||||
assert data['query_string'] == 'test=true'
|
||||
assert data['headers']['HTTP_HOST'] == f'localhost:{port}'
|
||||
|
@ -28,47 +22,29 @@ async def test_scope(wsgi_server, threading_mode):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_body(wsgi_server, threading_mode):
|
||||
async with wsgi_server(threading_mode) as port:
|
||||
res = httpx.post(f"http://localhost:{port}/echo", content="test")
|
||||
res = httpx.post(f'http://localhost:{port}/echo', content='test')
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.text == "test"
|
||||
assert res.text == 'test'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_iterbody(wsgi_server, threading_mode):
|
||||
async with wsgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/iterbody")
|
||||
res = httpx.get(f'http://localhost:{port}/iterbody')
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.text == "test" * 3
|
||||
assert res.text == 'test' * 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"threading_mode",
|
||||
[
|
||||
"runtime",
|
||||
"workers"
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_app_error(wsgi_server, threading_mode):
|
||||
async with wsgi_server(threading_mode) as port:
|
||||
res = httpx.get(f"http://localhost:{port}/err_app")
|
||||
res = httpx.get(f'http://localhost:{port}/err_app')
|
||||
|
||||
assert res.status_code == 500
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue