Add support for websockets subprotocols (#246)

This commit is contained in:
Giovanni Barillari 2024-03-18 23:19:59 +01:00 committed by GitHub
parent 32f3915f12
commit bf79551e1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 134 additions and 120 deletions

View file

@ -7,7 +7,7 @@ use tokio::sync::oneshot;
use super::{
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport},
utils::{build_scope, scope_native_parts},
utils::{build_scope_http, build_scope_ws, scope_native_parts},
};
use crate::{
callbacks::{
@ -312,18 +312,7 @@ macro_rules! call_impl_rtb_http {
client
);
Python::with_gil(|py| {
let scope = build_scope(
py,
&req,
"http",
version,
server,
client,
scheme,
&path,
query_string,
)
.unwrap();
let scope = build_scope_http(py, &req, version, server, client, scheme, &path, query_string).unwrap();
let _ = $runner::new(py, cb, protocol, scope).run(py);
});
@ -360,18 +349,8 @@ macro_rules! call_impl_rtt_http {
client
);
Python::with_gil(|py| {
let scope = build_scope(
py,
&req,
"http",
version,
server,
client,
&scheme,
&path,
query_string,
)
.unwrap();
let scope =
build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
let _ = $runner::new(py, cb, protocol, scope).run(py);
});
});
@ -408,18 +387,7 @@ macro_rules! call_impl_rtb_ws {
client
);
Python::with_gil(|py| {
let scope = build_scope(
py,
&req,
"websocket",
version,
server,
client,
scheme,
&path,
query_string,
)
.unwrap();
let scope = build_scope_ws(py, &req, version, server, client, scheme, &path, query_string).unwrap();
let _ = $runner::new(py, cb, protocol, scope).run(py);
});
@ -457,18 +425,8 @@ macro_rules! call_impl_rtt_ws {
client
);
Python::with_gil(|py| {
let scope = build_scope(
py,
&req,
"websocket",
version,
server,
client,
&scheme,
&path,
query_string,
)
.unwrap();
let scope =
build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
let _ = $runner::new(py, cb, protocol, scope).run(py);
});
});

View file

@ -32,6 +32,7 @@ use crate::{
const EMPTY_BYTES: Cow<[u8]> = Cow::Borrowed(b"");
const EMPTY_STRING: String = String::new();
static WS_SUBPROTO_HNAME: &str = "Sec-WebSocket-Protocol";
#[pyclass(frozen, module = "granian._granian")]
pub(crate) struct ASGIHTTPProtocol {
@ -140,18 +141,17 @@ impl ASGIHTTPProtocol {
}
fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
match adapt_message_type(data) {
Ok(ASGIMessageType::HTTPStart) => match self.response_started.load(atomic::Ordering::Relaxed) {
match adapt_message_type(py, data) {
Ok(ASGIMessageType::HTTPStart(intent)) => match self.response_started.load(atomic::Ordering::Relaxed) {
false => {
let mut response_intent = self.response_intent.lock().unwrap();
*response_intent = Some((adapt_status_code(py, data)?, adapt_headers(py, data)));
*response_intent = Some(intent);
self.response_started.store(true, atomic::Ordering::Relaxed);
empty_future_into_py(py)
}
true => error_flow!(),
},
Ok(ASGIMessageType::HTTPBody) => {
let (body, more) = adapt_body(py, data);
Ok(ASGIMessageType::HTTPBody((body, more))) => {
match (
self.response_started.load(atomic::Ordering::Relaxed),
more,
@ -200,12 +200,11 @@ impl ASGIHTTPProtocol {
_ => error_flow!(),
}
}
Ok(ASGIMessageType::HTTPFile) => match (
Ok(ASGIMessageType::HTTPFile(file_path)) => match (
self.response_started.load(atomic::Ordering::Relaxed),
adapt_file(py, data),
self.tx.lock().unwrap().take(),
) {
(true, Ok(file_path), Some(tx)) => {
(true, Some(tx)) => {
let (status, headers) = self.response_intent.lock().unwrap().take().unwrap();
future_into_py_iter(self.rt.clone(), py, async move {
let res = match File::open(&file_path).await {
@ -288,7 +287,7 @@ impl ASGIWebsocketProtocol {
}
#[inline(always)]
fn accept<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
fn accept<'p>(&self, py: Python<'p>, subproto: Option<String>) -> PyResult<&'p PyAny> {
let upgrade = self.upgrade.lock().unwrap().take();
let websocket = self.websocket.lock().unwrap().take();
let accepted = self.accepted.clone();
@ -297,7 +296,11 @@ impl ASGIWebsocketProtocol {
future_into_py_iter(self.rt.clone(), py, async move {
if let Some(mut upgrade) = upgrade {
if (upgrade.send().await).is_ok() {
let upgrade_headers = match subproto {
Some(v) => vec![(WS_SUBPROTO_HNAME.to_string(), v)],
_ => vec![],
};
if (upgrade.send(Some(upgrade_headers)).await).is_ok() {
if let Some(websocket) = websocket {
if let Ok(stream) = websocket.await {
let mut wtx = tx.lock().await;
@ -316,29 +319,23 @@ impl ASGIWebsocketProtocol {
}
#[inline(always)]
fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
fn send_message<'p>(&self, py: Python<'p>, data: Message) -> PyResult<&'p PyAny> {
let transport = self.ws_tx.clone();
let message = ws_message_into_rs(py, data);
let closed = self.closed.clone();
future_into_py_futlike(self.rt.clone(), py, async move {
match message {
Ok(message) => {
if let Some(ws) = &mut *(transport.lock().await) {
match ws.send(message).await {
Ok(()) => return Ok(()),
_ => {
if closed.load(atomic::Ordering::Relaxed) {
log::info!("Attempted to write to a closed websocket");
return Ok(());
}
}
};
};
error_flow!()
}
Err(err) => Err(err),
}
if let Some(ws) = &mut *(transport.lock().await) {
match ws.send(data).await {
Ok(()) => return Ok(()),
_ => {
if closed.load(atomic::Ordering::Relaxed) {
log::info!("Attempted to write to a closed websocket");
return Ok(());
}
}
};
};
error_flow!()
})
}
@ -416,27 +413,36 @@ impl ASGIWebsocketProtocol {
}
fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
match adapt_message_type(data) {
Ok(ASGIMessageType::WSAccept) => self.accept(py),
match adapt_message_type(py, data) {
Ok(ASGIMessageType::WSAccept(subproto)) => self.accept(py, subproto),
Ok(ASGIMessageType::WSClose) => self.close(py),
Ok(ASGIMessageType::WSMessage) => self.send_message(py, data),
Ok(ASGIMessageType::WSMessage(message)) => self.send_message(py, message),
_ => future_into_py_iter::<_, _, PyErr>(self.rt.clone(), py, async { error_message!() }),
}
}
}
#[inline(never)]
fn adapt_message_type(message: &PyDict) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
match message.get_item("type") {
fn adapt_message_type(py: Python, message: &PyDict) -> Result<ASGIMessageType, UnsupportedASGIMessage> {
match message.get_item(pyo3::intern!(py, "type")) {
Ok(Some(item)) => {
let message_type: &str = item.extract()?;
match message_type {
"http.response.start" => Ok(ASGIMessageType::HTTPStart),
"http.response.body" => Ok(ASGIMessageType::HTTPBody),
"http.response.pathsend" => Ok(ASGIMessageType::HTTPFile),
"websocket.accept" => Ok(ASGIMessageType::WSAccept),
"http.response.start" => Ok(ASGIMessageType::HTTPStart((
adapt_status_code(py, message)?,
adapt_headers(py, message),
))),
"http.response.body" => Ok(ASGIMessageType::HTTPBody(adapt_body(py, message))),
"http.response.pathsend" => Ok(ASGIMessageType::HTTPFile(adapt_file(py, message)?)),
"websocket.accept" => {
let subproto: Option<String> = match message.get_item(pyo3::intern!(py, "subprotocol")) {
Ok(Some(item)) => item.extract::<String>().map(Some)?,
_ => None,
};
Ok(ASGIMessageType::WSAccept(subproto))
}
"websocket.close" => Ok(ASGIMessageType::WSClose),
"websocket.send" => Ok(ASGIMessageType::WSMessage),
"websocket.send" => Ok(ASGIMessageType::WSMessage(ws_message_into_rs(py, message)?)),
_ => error_message!(),
}
}
@ -505,11 +511,9 @@ fn ws_message_into_rs(py: Python, message: &PyDict) -> PyResult<Message> {
(Some(itemb), Some(itemt)) => match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) {
(Some(msgb), None) => Ok(Message::Binary(msgb)),
(None, Some(msgt)) => Ok(Message::Text(msgt)),
_ => error_flow!(),
_ => error_message!(),
},
_ => {
error_flow!()
}
_ => error_message!(),
}
}

View file

@ -1,8 +1,11 @@
use hyper::HeaderMap;
use tokio_tungstenite::tungstenite::Message;
pub(crate) enum ASGIMessageType {
HTTPStart,
HTTPBody,
HTTPFile,
WSAccept,
HTTPStart((i16, HeaderMap)),
HTTPBody((Box<[u8]>, bool)),
HTTPFile(String),
WSAccept(Option<String>),
WSClose,
WSMessage,
WSMessage(Message),
}

View file

@ -5,7 +5,7 @@ use hyper::{
use pyo3::{
prelude::*,
sync::GILOnceCell,
types::{PyBytes, PyDict, PyList},
types::{PyBytes, PyDict, PyList, PyString},
};
static ASGI_VERSION: GILOnceCell<PyObject> = GILOnceCell::new();
@ -94,4 +94,54 @@ pub(super) fn build_scope<'p>(
Ok(scope)
}
#[inline]
pub(super) fn build_scope_http<'p>(
py: Python<'p>,
req: &'p request::Parts,
version: &'p str,
server: (String, String),
client: (String, String),
scheme: &'p str,
path: &'p str,
query_string: &'p str,
) -> PyResult<&'p PyDict> {
build_scope(py, req, "http", version, server, client, scheme, path, query_string)
}
#[inline]
pub(super) fn build_scope_ws<'p>(
py: Python<'p>,
req: &'p request::Parts,
version: &'p str,
server: (String, String),
client: (String, String),
scheme: &'p str,
path: &'p str,
query_string: &'p str,
) -> PyResult<&'p PyDict> {
let scope = build_scope(
py,
req,
"websocket",
version,
server,
client,
scheme,
path,
query_string,
)?;
scope.set_item(
pyo3::intern!(py, "subprotocols"),
PyList::new(
py,
req.headers
.get_all("Sec-WebSocket-Protocol")
.iter()
.map(|v| PyString::new(py, v.to_str().unwrap()))
.collect::<Vec<&PyString>>(),
),
)?;
Ok(scope)
}
pub(super) use scope_native_parts;

View file

@ -360,7 +360,7 @@ impl RSGIWebsocketProtocol {
let itransport = self.transport.clone();
future_into_py_iter(self.rt.clone(), py, async move {
let mut ws = transport.lock().await;
match upgrade.send().await {
match upgrade.send(None).await {
Ok(()) => match (&mut *ws).await {
Ok(stream) => {
let mut trx = itransport.lock().unwrap();

View file

@ -1,6 +1,6 @@
use http_body_util::BodyExt;
use hyper::{
header::{CONNECTION, UPGRADE},
header::{HeaderName, HeaderValue, CONNECTION, UPGRADE},
http::response::Builder,
Request, Response, StatusCode,
};
@ -60,34 +60,33 @@ impl Future for HyperWebsocket {
}
pub(crate) struct UpgradeData {
response_builder: Option<Builder>,
response_tx: Option<mpsc::Sender<HTTPResponse>>,
pub consumed: bool,
response: Option<(Builder, mpsc::Sender<HTTPResponse>)>,
}
impl UpgradeData {
pub fn new(response_builder: Builder, response_tx: mpsc::Sender<HTTPResponse>) -> Self {
Self {
response_builder: Some(response_builder),
response_tx: Some(response_tx),
consumed: false,
response: Some((response_builder, response_tx)),
}
}
pub async fn send(&mut self) -> Result<(), mpsc::error::SendError<HTTPResponse>> {
let res = self
.response_builder
.take()
.unwrap()
.body(http_body_util::Empty::new().map_err(|e| match e {}).boxed())
.unwrap();
match self.response_tx.take().unwrap().send(res).await {
Ok(()) => {
self.consumed = true;
Ok(())
pub async fn send(&mut self, headers: Option<Vec<(String, String)>>) -> anyhow::Result<()> {
if let Some((mut builder, tx)) = self.response.take() {
if let Some(headers) = headers {
let rheaders = builder.headers_mut().unwrap();
for (key, val) in &headers {
rheaders.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(val).unwrap(),
);
}
}
err => err,
let res = builder
.body(http_body_util::Empty::new().map_err(|e| match e {}).boxed())
.unwrap();
return Ok(tx.send(res).await?);
}
Err(anyhow::Error::msg("Already consumed"))
}
}

View file

@ -74,6 +74,7 @@ async def ws_info(scope, receive, send):
'path': scope['path'],
'query_string': scope['query_string'].decode('latin-1'),
'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']},
'subprotocols': scope['subprotocols'],
}
),
}

View file

@ -1,13 +1,11 @@
import json
import os
import sys
import pytest
import websockets
@pytest.mark.asyncio
@pytest.mark.skipif(sys.platform == 'win32', reason='skip on windows')
@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True)
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_messages(server, threading_mode):
@ -50,6 +48,7 @@ async def test_asgi_scope(asgi_server, threading_mode):
assert data['path'] == '/ws_info'
assert data['query_string'] == 'test=true'
assert data['headers']['host'] == f'localhost:{port}'
assert not data['subprotocols']
@pytest.mark.asyncio