mirror of
https://github.com/emmett-framework/granian.git
synced 2025-12-23 05:36:49 +00:00
Add support for websockets subprotocols (#246)
This commit is contained in:
parent
32f3915f12
commit
bf79551e1c
8 changed files with 134 additions and 120 deletions
|
|
@ -7,7 +7,7 @@ use tokio::sync::oneshot;
|
|||
|
||||
use super::{
|
||||
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport},
|
||||
utils::{build_scope, scope_native_parts},
|
||||
utils::{build_scope_http, build_scope_ws, scope_native_parts},
|
||||
};
|
||||
use crate::{
|
||||
callbacks::{
|
||||
|
|
@ -312,18 +312,7 @@ macro_rules! call_impl_rtb_http {
|
|||
client
|
||||
);
|
||||
Python::with_gil(|py| {
|
||||
let scope = build_scope(
|
||||
py,
|
||||
&req,
|
||||
"http",
|
||||
version,
|
||||
server,
|
||||
client,
|
||||
scheme,
|
||||
&path,
|
||||
query_string,
|
||||
)
|
||||
.unwrap();
|
||||
let scope = build_scope_http(py, &req, version, server, client, scheme, &path, query_string).unwrap();
|
||||
let _ = $runner::new(py, cb, protocol, scope).run(py);
|
||||
});
|
||||
|
||||
|
|
@ -360,18 +349,8 @@ macro_rules! call_impl_rtt_http {
|
|||
client
|
||||
);
|
||||
Python::with_gil(|py| {
|
||||
let scope = build_scope(
|
||||
py,
|
||||
&req,
|
||||
"http",
|
||||
version,
|
||||
server,
|
||||
client,
|
||||
&scheme,
|
||||
&path,
|
||||
query_string,
|
||||
)
|
||||
.unwrap();
|
||||
let scope =
|
||||
build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
|
||||
let _ = $runner::new(py, cb, protocol, scope).run(py);
|
||||
});
|
||||
});
|
||||
|
|
@ -408,18 +387,7 @@ macro_rules! call_impl_rtb_ws {
|
|||
client
|
||||
);
|
||||
Python::with_gil(|py| {
|
||||
let scope = build_scope(
|
||||
py,
|
||||
&req,
|
||||
"websocket",
|
||||
version,
|
||||
server,
|
||||
client,
|
||||
scheme,
|
||||
&path,
|
||||
query_string,
|
||||
)
|
||||
.unwrap();
|
||||
let scope = build_scope_ws(py, &req, version, server, client, scheme, &path, query_string).unwrap();
|
||||
let _ = $runner::new(py, cb, protocol, scope).run(py);
|
||||
});
|
||||
|
||||
|
|
@ -457,18 +425,8 @@ macro_rules! call_impl_rtt_ws {
|
|||
client
|
||||
);
|
||||
Python::with_gil(|py| {
|
||||
let scope = build_scope(
|
||||
py,
|
||||
&req,
|
||||
"websocket",
|
||||
version,
|
||||
server,
|
||||
client,
|
||||
&scheme,
|
||||
&path,
|
||||
query_string,
|
||||
)
|
||||
.unwrap();
|
||||
let scope =
|
||||
build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
|
||||
let _ = $runner::new(py, cb, protocol, scope).run(py);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ use crate::{
|
|||
|
||||
const EMPTY_BYTES: Cow<[u8]> = Cow::Borrowed(b"");
|
||||
const EMPTY_STRING: String = String::new();
|
||||
static WS_SUBPROTO_HNAME: &str = "Sec-WebSocket-Protocol";
|
||||
|
||||
#[pyclass(frozen, module = "granian._granian")]
|
||||
pub(crate) struct ASGIHTTPProtocol {
|
||||
|
|
@ -140,18 +141,17 @@ impl ASGIHTTPProtocol {
|
|||
}
|
||||
|
||||
fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
|
||||
match adapt_message_type(data) {
|
||||
Ok(ASGIMessageType::HTTPStart) => match self.response_started.load(atomic::Ordering::Relaxed) {
|
||||
match adapt_message_type(py, data) {
|
||||
Ok(ASGIMessageType::HTTPStart(intent)) => match self.response_started.load(atomic::Ordering::Relaxed) {
|
||||
false => {
|
||||
let mut response_intent = self.response_intent.lock().unwrap();
|
||||
*response_intent = Some((adapt_status_code(py, data)?, adapt_headers(py, data)));
|
||||
*response_intent = Some(intent);
|
||||
self.response_started.store(true, atomic::Ordering::Relaxed);
|
||||
empty_future_into_py(py)
|
||||
}
|
||||
true => error_flow!(),
|
||||
},
|
||||
Ok(ASGIMessageType::HTTPBody) => {
|
||||
let (body, more) = adapt_body(py, data);
|
||||
Ok(ASGIMessageType::HTTPBody((body, more))) => {
|
||||
match (
|
||||
self.response_started.load(atomic::Ordering::Relaxed),
|
||||
more,
|
||||
|
|
@ -200,12 +200,11 @@ impl ASGIHTTPProtocol {
|
|||
_ => error_flow!(),
|
||||
}
|
||||
}
|
||||
Ok(ASGIMessageType::HTTPFile) => match (
|
||||
Ok(ASGIMessageType::HTTPFile(file_path)) => match (
|
||||
self.response_started.load(atomic::Ordering::Relaxed),
|
||||
adapt_file(py, data),
|
||||
self.tx.lock().unwrap().take(),
|
||||
) {
|
||||
(true, Ok(file_path), Some(tx)) => {
|
||||
(true, Some(tx)) => {
|
||||
let (status, headers) = self.response_intent.lock().unwrap().take().unwrap();
|
||||
future_into_py_iter(self.rt.clone(), py, async move {
|
||||
let res = match File::open(&file_path).await {
|
||||
|
|
@ -288,7 +287,7 @@ impl ASGIWebsocketProtocol {
|
|||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn accept<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
|
||||
fn accept<'p>(&self, py: Python<'p>, subproto: Option<String>) -> PyResult<&'p PyAny> {
|
||||
let upgrade = self.upgrade.lock().unwrap().take();
|
||||
let websocket = self.websocket.lock().unwrap().take();
|
||||
let accepted = self.accepted.clone();
|
||||
|
|
@ -297,7 +296,11 @@ impl ASGIWebsocketProtocol {
|
|||
|
||||
future_into_py_iter(self.rt.clone(), py, async move {
|
||||
if let Some(mut upgrade) = upgrade {
|
||||
if (upgrade.send().await).is_ok() {
|
||||
let upgrade_headers = match subproto {
|
||||
Some(v) => vec![(WS_SUBPROTO_HNAME.to_string(), v)],
|
||||
_ => vec![],
|
||||
};
|
||||
if (upgrade.send(Some(upgrade_headers)).await).is_ok() {
|
||||
if let Some(websocket) = websocket {
|
||||
if let Ok(stream) = websocket.await {
|
||||
let mut wtx = tx.lock().await;
|
||||
|
|
@ -316,29 +319,23 @@ impl ASGIWebsocketProtocol {
|
|||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
|
||||
fn send_message<'p>(&self, py: Python<'p>, data: Message) -> PyResult<&'p PyAny> {
|
||||
let transport = self.ws_tx.clone();
|
||||
let message = ws_message_into_rs(py, data);
|
||||
let closed = self.closed.clone();
|
||||
|
||||
future_into_py_futlike(self.rt.clone(), py, async move {
|
||||
match message {
|
||||
Ok(message) => {
|
||||
if let Some(ws) = &mut *(transport.lock().await) {
|
||||
match ws.send(message).await {
|
||||
Ok(()) => return Ok(()),
|
||||
_ => {
|
||||
if closed.load(atomic::Ordering::Relaxed) {
|
||||
log::info!("Attempted to write to a closed websocket");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
error_flow!()
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
if let Some(ws) = &mut *(transport.lock().await) {
|
||||
match ws.send(data).await {
|
||||
Ok(()) => return Ok(()),
|
||||
_ => {
|
||||
if closed.load(atomic::Ordering::Relaxed) {
|
||||
log::info!("Attempted to write to a closed websocket");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
error_flow!()
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -416,27 +413,36 @@ impl ASGIWebsocketProtocol {
|
|||
}
|
||||
|
||||
fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
|
||||
match adapt_message_type(data) {
|
||||
Ok(ASGIMessageType::WSAccept) => self.accept(py),
|
||||
match adapt_message_type(py, data) {
|
||||
Ok(ASGIMessageType::WSAccept(subproto)) => self.accept(py, subproto),
|
||||
Ok(ASGIMessageType::WSClose) => self.close(py),
|
||||
Ok(ASGIMessageType::WSMessage) => self.send_message(py, data),
|
||||
Ok(ASGIMessageType::WSMessage(message)) => self.send_message(py, message),
|
||||
_ => future_into_py_iter::<_, _, PyErr>(self.rt.clone(), py, async { error_message!() }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
fn adapt_message_type(message: &PyDict) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
|
||||
match message.get_item("type") {
|
||||
fn adapt_message_type(py: Python, message: &PyDict) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
|
||||
match message.get_item(pyo3::intern!(py, "type")) {
|
||||
Ok(Some(item)) => {
|
||||
let message_type: &str = item.extract()?;
|
||||
match message_type {
|
||||
"http.response.start" => Ok(ASGIMessageType::HTTPStart),
|
||||
"http.response.body" => Ok(ASGIMessageType::HTTPBody),
|
||||
"http.response.pathsend" => Ok(ASGIMessageType::HTTPFile),
|
||||
"websocket.accept" => Ok(ASGIMessageType::WSAccept),
|
||||
"http.response.start" => Ok(ASGIMessageType::HTTPStart((
|
||||
adapt_status_code(py, message)?,
|
||||
adapt_headers(py, message),
|
||||
))),
|
||||
"http.response.body" => Ok(ASGIMessageType::HTTPBody(adapt_body(py, message))),
|
||||
"http.response.pathsend" => Ok(ASGIMessageType::HTTPFile(adapt_file(py, message)?)),
|
||||
"websocket.accept" => {
|
||||
let subproto: Option<String> = match message.get_item(pyo3::intern!(py, "subprotocol")) {
|
||||
Ok(Some(item)) => item.extract::<String>().map(Some)?,
|
||||
_ => None,
|
||||
};
|
||||
Ok(ASGIMessageType::WSAccept(subproto))
|
||||
}
|
||||
"websocket.close" => Ok(ASGIMessageType::WSClose),
|
||||
"websocket.send" => Ok(ASGIMessageType::WSMessage),
|
||||
"websocket.send" => Ok(ASGIMessageType::WSMessage(ws_message_into_rs(py, message)?)),
|
||||
_ => error_message!(),
|
||||
}
|
||||
}
|
||||
|
|
@ -505,11 +511,9 @@ fn ws_message_into_rs(py: Python, message: &PyDict) -> PyResult<Message> {
|
|||
(Some(itemb), Some(itemt)) => match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
|
||||
(Some(msgb), None) => Ok(Message::Binary(msgb)),
|
||||
(None, Some(msgt)) => Ok(Message::Text(msgt)),
|
||||
_ => error_flow!(),
|
||||
_ => error_message!(),
|
||||
},
|
||||
_ => {
|
||||
error_flow!()
|
||||
}
|
||||
_ => error_message!(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
use hyper::HeaderMap;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
pub(crate) enum ASGIMessageType {
|
||||
HTTPStart,
|
||||
HTTPBody,
|
||||
HTTPFile,
|
||||
WSAccept,
|
||||
HTTPStart((i16, HeaderMap)),
|
||||
HTTPBody((Box<[u8]>, bool)),
|
||||
HTTPFile(String),
|
||||
WSAccept(Option<String>),
|
||||
WSClose,
|
||||
WSMessage,
|
||||
WSMessage(Message),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use hyper::{
|
|||
use pyo3::{
|
||||
prelude::*,
|
||||
sync::GILOnceCell,
|
||||
types::{PyBytes, PyDict, PyList},
|
||||
types::{PyBytes, PyDict, PyList, PyString},
|
||||
};
|
||||
|
||||
static ASGI_VERSION: GILOnceCell<PyObject> = GILOnceCell::new();
|
||||
|
|
@ -94,4 +94,54 @@ pub(super) fn build_scope<'p>(
|
|||
Ok(scope)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn build_scope_http<'p>(
|
||||
py: Python<'p>,
|
||||
req: &'p request::Parts,
|
||||
version: &'p str,
|
||||
server: (String, String),
|
||||
client: (String, String),
|
||||
scheme: &'p str,
|
||||
path: &'p str,
|
||||
query_string: &'p str,
|
||||
) -> PyResult<&'p PyDict> {
|
||||
build_scope(py, req, "http", version, server, client, scheme, path, query_string)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn build_scope_ws<'p>(
|
||||
py: Python<'p>,
|
||||
req: &'p request::Parts,
|
||||
version: &'p str,
|
||||
server: (String, String),
|
||||
client: (String, String),
|
||||
scheme: &'p str,
|
||||
path: &'p str,
|
||||
query_string: &'p str,
|
||||
) -> PyResult<&'p PyDict> {
|
||||
let scope = build_scope(
|
||||
py,
|
||||
req,
|
||||
"websocket",
|
||||
version,
|
||||
server,
|
||||
client,
|
||||
scheme,
|
||||
path,
|
||||
query_string,
|
||||
)?;
|
||||
scope.set_item(
|
||||
pyo3::intern!(py, "subprotocols"),
|
||||
PyList::new(
|
||||
py,
|
||||
req.headers
|
||||
.get_all("Sec-WebSocket-Protocol")
|
||||
.iter()
|
||||
.map(|v| PyString::new(py, v.to_str().unwrap()))
|
||||
.collect::<Vec<&PyString>>(),
|
||||
),
|
||||
)?;
|
||||
Ok(scope)
|
||||
}
|
||||
|
||||
pub(super) use scope_native_parts;
|
||||
|
|
|
|||
|
|
@ -360,7 +360,7 @@ impl RSGIWebsocketProtocol {
|
|||
let itransport = self.transport.clone();
|
||||
future_into_py_iter(self.rt.clone(), py, async move {
|
||||
let mut ws = transport.lock().await;
|
||||
match upgrade.send().await {
|
||||
match upgrade.send(None).await {
|
||||
Ok(()) => match (&mut *ws).await {
|
||||
Ok(stream) => {
|
||||
let mut trx = itransport.lock().unwrap();
|
||||
|
|
|
|||
37
src/ws.rs
37
src/ws.rs
|
|
@ -1,6 +1,6 @@
|
|||
use http_body_util::BodyExt;
|
||||
use hyper::{
|
||||
header::{CONNECTION, UPGRADE},
|
||||
header::{HeaderName, HeaderValue, CONNECTION, UPGRADE},
|
||||
http::response::Builder,
|
||||
Request, Response, StatusCode,
|
||||
};
|
||||
|
|
@ -60,34 +60,33 @@ impl Future for HyperWebsocket {
|
|||
}
|
||||
|
||||
pub(crate) struct UpgradeData {
|
||||
response_builder: Option<Builder>,
|
||||
response_tx: Option<mpsc::Sender<HTTPResponse>>,
|
||||
pub consumed: bool,
|
||||
response: Option<(Builder, mpsc::Sender<HTTPResponse>)>,
|
||||
}
|
||||
|
||||
impl UpgradeData {
|
||||
pub fn new(response_builder: Builder, response_tx: mpsc::Sender<HTTPResponse>) -> Self {
|
||||
Self {
|
||||
response_builder: Some(response_builder),
|
||||
response_tx: Some(response_tx),
|
||||
consumed: false,
|
||||
response: Some((response_builder, response_tx)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self) -> Result<(), mpsc::error::SendError<HTTPResponse>> {
|
||||
let res = self
|
||||
.response_builder
|
||||
.take()
|
||||
.unwrap()
|
||||
.body(http_body_util::Empty::new().map_err(|e| match e {}).boxed())
|
||||
.unwrap();
|
||||
match self.response_tx.take().unwrap().send(res).await {
|
||||
Ok(()) => {
|
||||
self.consumed = true;
|
||||
Ok(())
|
||||
pub async fn send(&mut self, headers: Option<Vec<(String, String)>>) -> anyhow::Result<()> {
|
||||
if let Some((mut builder, tx)) = self.response.take() {
|
||||
if let Some(headers) = headers {
|
||||
let rheaders = builder.headers_mut().unwrap();
|
||||
for (key, val) in &headers {
|
||||
rheaders.append(
|
||||
HeaderName::from_bytes(key.as_bytes()).unwrap(),
|
||||
HeaderValue::from_str(val).unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
err => err,
|
||||
let res = builder
|
||||
.body(http_body_util::Empty::new().map_err(|e| match e {}).boxed())
|
||||
.unwrap();
|
||||
return Ok(tx.send(res).await?);
|
||||
}
|
||||
Err(anyhow::Error::msg("Already consumed"))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ async def ws_info(scope, receive, send):
|
|||
'path': scope['path'],
|
||||
'query_string': scope['query_string'].decode('latin-1'),
|
||||
'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
|
||||
'subprotocols': scope['subprotocols'],
|
||||
}
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(sys.platform == 'win32', reason='skip on windows')
|
||||
@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
|
||||
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
|
||||
async def test_messages(server, threading_mode):
|
||||
|
|
@ -50,6 +48,7 @@ async def test_asgi_scope(asgi_server, threading_mode):
|
|||
assert data['path'] == '/ws_info'
|
||||
assert data['query_string'] == 'test=true'
|
||||
assert data['headers']['host'] == f'localhost:{port}'
|
||||
assert not data['subprotocols']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue