diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index a1742f7..81c0a29 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -7,7 +7,7 @@ use tokio::sync::oneshot; use super::{ io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport}, - utils::{build_scope, scope_native_parts}, + utils::{build_scope_http, build_scope_ws, scope_native_parts}, }; use crate::{ callbacks::{ @@ -312,18 +312,7 @@ macro_rules! call_impl_rtb_http { client ); Python::with_gil(|py| { - let scope = build_scope( - py, - &req, - "http", - version, - server, - client, - scheme, - &path, - query_string, - ) - .unwrap(); + let scope = build_scope_http(py, &req, version, server, client, scheme, &path, query_string).unwrap(); let _ = $runner::new(py, cb, protocol, scope).run(py); }); @@ -360,18 +349,8 @@ macro_rules! call_impl_rtt_http { client ); Python::with_gil(|py| { - let scope = build_scope( - py, - &req, - "http", - version, - server, - client, - &scheme, - &path, - query_string, - ) - .unwrap(); + let scope = + build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); let _ = $runner::new(py, cb, protocol, scope).run(py); }); }); @@ -408,18 +387,7 @@ macro_rules! call_impl_rtb_ws { client ); Python::with_gil(|py| { - let scope = build_scope( - py, - &req, - "websocket", - version, - server, - client, - scheme, - &path, - query_string, - ) - .unwrap(); + let scope = build_scope_ws(py, &req, version, server, client, scheme, &path, query_string).unwrap(); let _ = $runner::new(py, cb, protocol, scope).run(py); }); @@ -457,18 +425,8 @@ macro_rules! call_impl_rtt_ws { client ); Python::with_gil(|py| { - let scope = build_scope( - py, - &req, - "websocket", - version, - server, - client, - &scheme, - &path, - query_string, - ) - .unwrap(); + let scope = + build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); let _ = $runner::new(py, cb, protocol, scope).run(py); }); }); diff --git a/src/asgi/io.rs b/src/asgi/io.rs index 12fc0b0..f46f1d2 100644 --- a/src/asgi/io.rs +++ b/src/asgi/io.rs @@ -32,6 +32,7 @@ use crate::{ const EMPTY_BYTES: Cow<[u8]> = Cow::Borrowed(b""); const EMPTY_STRING: String = String::new(); +static WS_SUBPROTO_HNAME: &str = "Sec-WebSocket-Protocol"; #[pyclass(frozen, module = "granian._granian")] pub(crate) struct ASGIHTTPProtocol { @@ -140,18 +141,17 @@ impl ASGIHTTPProtocol { } fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { - match adapt_message_type(data) { - Ok(ASGIMessageType::HTTPStart) => match self.response_started.load(atomic::Ordering::Relaxed) { + match adapt_message_type(py, data) { + Ok(ASGIMessageType::HTTPStart(intent)) => match self.response_started.load(atomic::Ordering::Relaxed) { false => { let mut response_intent = self.response_intent.lock().unwrap(); - *response_intent = Some((adapt_status_code(py, data)?, adapt_headers(py, data))); + *response_intent = Some(intent); self.response_started.store(true, atomic::Ordering::Relaxed); empty_future_into_py(py) } true => error_flow!(), }, - Ok(ASGIMessageType::HTTPBody) => { - let (body, more) = adapt_body(py, data); + Ok(ASGIMessageType::HTTPBody((body, more))) => { match ( self.response_started.load(atomic::Ordering::Relaxed), more, @@ -200,12 +200,11 @@ impl ASGIHTTPProtocol { _ => error_flow!(), } } - Ok(ASGIMessageType::HTTPFile) => match ( + Ok(ASGIMessageType::HTTPFile(file_path)) => match ( self.response_started.load(atomic::Ordering::Relaxed), - adapt_file(py, data), self.tx.lock().unwrap().take(), ) { - (true, Ok(file_path), Some(tx)) => { + (true, Some(tx)) => { let (status, headers) = self.response_intent.lock().unwrap().take().unwrap(); future_into_py_iter(self.rt.clone(), py, async move { let res = match File::open(&file_path).await { @@ -288,7 +287,7 @@ impl ASGIWebsocketProtocol { } #[inline(always)] - fn accept<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + fn accept<'p>(&self, py: Python<'p>, subproto: Option) -> PyResult<&'p PyAny> { let upgrade = self.upgrade.lock().unwrap().take(); let websocket = self.websocket.lock().unwrap().take(); let accepted = self.accepted.clone(); @@ -297,7 +296,11 @@ impl ASGIWebsocketProtocol { future_into_py_iter(self.rt.clone(), py, async move { if let Some(mut upgrade) = upgrade { - if (upgrade.send().await).is_ok() { + let upgrade_headers = match subproto { + Some(v) => vec![(WS_SUBPROTO_HNAME.to_string(), v)], + _ => vec![], + }; + if (upgrade.send(Some(upgrade_headers)).await).is_ok() { if let Some(websocket) = websocket { if let Ok(stream) = websocket.await { let mut wtx = tx.lock().await; @@ -316,29 +319,23 @@ impl ASGIWebsocketProtocol { } #[inline(always)] - fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { + fn send_message<'p>(&self, py: Python<'p>, data: Message) -> PyResult<&'p PyAny> { let transport = self.ws_tx.clone(); - let message = ws_message_into_rs(py, data); let closed = self.closed.clone(); future_into_py_futlike(self.rt.clone(), py, async move { - match message { - Ok(message) => { - if let Some(ws) = &mut *(transport.lock().await) { - match ws.send(message).await { - Ok(()) => return Ok(()), - _ => { - if closed.load(atomic::Ordering::Relaxed) { - log::info!("Attempted to write to a closed websocket"); - return Ok(()); - } - } - }; - }; - error_flow!() - } - Err(err) => Err(err), - } + if let Some(ws) = &mut *(transport.lock().await) { + match ws.send(data).await { + Ok(()) => return Ok(()), + _ => { + if closed.load(atomic::Ordering::Relaxed) { + log::info!("Attempted to write to a closed websocket"); + return Ok(()); + } + } + }; + }; + error_flow!() }) } @@ -416,27 +413,36 @@ impl ASGIWebsocketProtocol { } fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { - match adapt_message_type(data) { - Ok(ASGIMessageType::WSAccept) => self.accept(py), + match adapt_message_type(py, data) { + Ok(ASGIMessageType::WSAccept(subproto)) => self.accept(py, subproto), Ok(ASGIMessageType::WSClose) => self.close(py), - Ok(ASGIMessageType::WSMessage) => self.send_message(py, data), + Ok(ASGIMessageType::WSMessage(message)) => self.send_message(py, message), _ => future_into_py_iter::<_, _, PyErr>(self.rt.clone(), py, async { error_message!() }), } } } #[inline(never)] -fn adapt_message_type(message: &PyDict) -> Result { - match message.get_item("type") { +fn adapt_message_type(py: Python, message: &PyDict) -> Result { + match message.get_item(pyo3::intern!(py, "type")) { Ok(Some(item)) => { let message_type: &str = item.extract()?; match message_type { - "http.response.start" => Ok(ASGIMessageType::HTTPStart), - "http.response.body" => Ok(ASGIMessageType::HTTPBody), - "http.response.pathsend" => Ok(ASGIMessageType::HTTPFile), - "websocket.accept" => Ok(ASGIMessageType::WSAccept), + "http.response.start" => Ok(ASGIMessageType::HTTPStart(( + adapt_status_code(py, message)?, + adapt_headers(py, message), + ))), + "http.response.body" => Ok(ASGIMessageType::HTTPBody(adapt_body(py, message))), + "http.response.pathsend" => Ok(ASGIMessageType::HTTPFile(adapt_file(py, message)?)), + "websocket.accept" => { + let subproto: Option = match message.get_item(pyo3::intern!(py, "subprotocol")) { + Ok(Some(item)) => item.extract::().map(Some)?, + _ => None, + }; + Ok(ASGIMessageType::WSAccept(subproto)) + } "websocket.close" => Ok(ASGIMessageType::WSClose), - "websocket.send" => Ok(ASGIMessageType::WSMessage), + "websocket.send" => Ok(ASGIMessageType::WSMessage(ws_message_into_rs(py, message)?)), _ => error_message!(), } } @@ -505,11 +511,9 @@ fn ws_message_into_rs(py: Python, message: &PyDict) -> PyResult { (Some(itemb), Some(itemt)) => match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) { (Some(msgb), None) => Ok(Message::Binary(msgb)), (None, Some(msgt)) => Ok(Message::Text(msgt)), - _ => error_flow!(), + _ => error_message!(), }, - _ => { - error_flow!() - } + _ => error_message!(), } } diff --git a/src/asgi/types.rs b/src/asgi/types.rs index f6cd9d7..4e34dba 100644 --- a/src/asgi/types.rs +++ b/src/asgi/types.rs @@ -1,8 +1,11 @@ +use hyper::HeaderMap; +use tokio_tungstenite::tungstenite::Message; + pub(crate) enum ASGIMessageType { - HTTPStart, - HTTPBody, - HTTPFile, - WSAccept, + HTTPStart((i16, HeaderMap)), + HTTPBody((Box<[u8]>, bool)), + HTTPFile(String), + WSAccept(Option), WSClose, - WSMessage, + WSMessage(Message), } diff --git a/src/asgi/utils.rs b/src/asgi/utils.rs index 27872b9..f79c7ec 100644 --- a/src/asgi/utils.rs +++ b/src/asgi/utils.rs @@ -5,7 +5,7 @@ use hyper::{ use pyo3::{ prelude::*, sync::GILOnceCell, - types::{PyBytes, PyDict, PyList}, + types::{PyBytes, PyDict, PyList, PyString}, }; static ASGI_VERSION: GILOnceCell = GILOnceCell::new(); @@ -94,4 +94,54 @@ pub(super) fn build_scope<'p>( Ok(scope) } +#[inline] +pub(super) fn build_scope_http<'p>( + py: Python<'p>, + req: &'p request::Parts, + version: &'p str, + server: (String, String), + client: (String, String), + scheme: &'p str, + path: &'p str, + query_string: &'p str, +) -> PyResult<&'p PyDict> { + build_scope(py, req, "http", version, server, client, scheme, path, query_string) +} + +#[inline] +pub(super) fn build_scope_ws<'p>( + py: Python<'p>, + req: &'p request::Parts, + version: &'p str, + server: (String, String), + client: (String, String), + scheme: &'p str, + path: &'p str, + query_string: &'p str, +) -> PyResult<&'p PyDict> { + let scope = build_scope( + py, + req, + "websocket", + version, + server, + client, + scheme, + path, + query_string, + )?; + scope.set_item( + pyo3::intern!(py, "subprotocols"), + PyList::new( + py, + req.headers + .get_all("Sec-WebSocket-Protocol") + .iter() + .map(|v| PyString::new(py, v.to_str().unwrap())) + .collect::>(), + ), + )?; + Ok(scope) +} + pub(super) use scope_native_parts; diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index f3244d1..f99d406 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -360,7 +360,7 @@ impl RSGIWebsocketProtocol { let itransport = self.transport.clone(); future_into_py_iter(self.rt.clone(), py, async move { let mut ws = transport.lock().await; - match upgrade.send().await { + match upgrade.send(None).await { Ok(()) => match (&mut *ws).await { Ok(stream) => { let mut trx = itransport.lock().unwrap(); diff --git a/src/ws.rs b/src/ws.rs index a6624c4..b291b45 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -1,6 +1,6 @@ use http_body_util::BodyExt; use hyper::{ - header::{CONNECTION, UPGRADE}, + header::{HeaderName, HeaderValue, CONNECTION, UPGRADE}, http::response::Builder, Request, Response, StatusCode, }; @@ -60,34 +60,33 @@ impl Future for HyperWebsocket { } pub(crate) struct UpgradeData { - response_builder: Option, - response_tx: Option>, - pub consumed: bool, + response: Option<(Builder, mpsc::Sender)>, } impl UpgradeData { pub fn new(response_builder: Builder, response_tx: mpsc::Sender) -> Self { Self { - response_builder: Some(response_builder), - response_tx: Some(response_tx), - consumed: false, + response: Some((response_builder, response_tx)), } } - pub async fn send(&mut self) -> Result<(), mpsc::error::SendError> { - let res = self - .response_builder - .take() - .unwrap() - .body(http_body_util::Empty::new().map_err(|e| match e {}).boxed()) - .unwrap(); - match self.response_tx.take().unwrap().send(res).await { - Ok(()) => { - self.consumed = true; - Ok(()) + pub async fn send(&mut self, headers: Option>) -> anyhow::Result<()> { + if let Some((mut builder, tx)) = self.response.take() { + if let Some(headers) = headers { + let rheaders = builder.headers_mut().unwrap(); + for (key, val) in &headers { + rheaders.append( + HeaderName::from_bytes(key.as_bytes()).unwrap(), + HeaderValue::from_str(val).unwrap(), + ); + } } - err => err, + let res = builder + .body(http_body_util::Empty::new().map_err(|e| match e {}).boxed()) + .unwrap(); + return Ok(tx.send(res).await?); } + Err(anyhow::Error::msg("Already consumed")) } } diff --git a/tests/apps/asgi.py b/tests/apps/asgi.py index 8711aeb..02efb66 100644 --- a/tests/apps/asgi.py +++ b/tests/apps/asgi.py @@ -74,6 +74,7 @@ async def ws_info(scope, receive, send): 'path': scope['path'], 'query_string': scope['query_string'].decode('latin-1'), 'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']}, + 'subprotocols': scope['subprotocols'], } ), } diff --git a/tests/test_ws.py b/tests/test_ws.py index 6ce3b22..0045b8a 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -1,13 +1,11 @@ import json import os -import sys import pytest import websockets @pytest.mark.asyncio -@pytest.mark.skipif(sys.platform == 'win32', reason='skip on windows') @pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True) @pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_messages(server, threading_mode): @@ -50,6 +48,7 @@ async def test_asgi_scope(asgi_server, threading_mode): assert data['path'] == '/ws_info' assert data['query_string'] == 'test=true' assert data['headers']['host'] == f'localhost:{port}' + assert not data['subprotocols'] @pytest.mark.asyncio