Revise ASGI ws accept flow to wait for send

This commit is contained in:
Giovanni Barillari 2025-06-02 19:45:15 +02:00
parent 02f683950c
commit 4db737084a
No known key found for this signature in database
2 changed files with 29 additions and 9 deletions

View file

@ -46,7 +46,7 @@ pub(crate) struct ASGIHTTPProtocol {
body_tx: Mutex<Option<mpsc::UnboundedSender<body::Bytes>>>,
flow_rx_exhausted: Arc<atomic::AtomicBool>,
flow_rx_closed: Arc<atomic::AtomicBool>,
flow_tx_waiter: Arc<tokio::sync::Notify>,
flow_tx_waiter: Arc<Notify>,
sent_response_code: Arc<atomic::AtomicU16>,
}
@ -307,7 +307,9 @@ pub(crate) struct ASGIWebsocketProtocol {
upgrade: Mutex<Option<UpgradeData>>,
ws_rx: Arc<AsyncMutex<Option<WSRxStream>>>,
ws_tx: Arc<AsyncMutex<Option<WSTxStream>>>,
accepted: Arc<atomic::AtomicBool>,
init_rx: atomic::AtomicBool,
init_tx: Arc<atomic::AtomicBool>,
init_event: Arc<Notify>,
closed: Arc<atomic::AtomicBool>,
}
@ -325,7 +327,9 @@ impl ASGIWebsocketProtocol {
upgrade: Mutex::new(Some(upgrade)),
ws_rx: Arc::new(AsyncMutex::new(None)),
ws_tx: Arc::new(AsyncMutex::new(None)),
accepted: Arc::new(false.into()),
init_rx: false.into(),
init_tx: Arc::new(false.into()),
init_event: Arc::new(Notify::new()),
closed: Arc::new(false.into()),
}
}
@ -334,7 +338,8 @@ impl ASGIWebsocketProtocol {
fn accept<'p>(&self, py: Python<'p>, subproto: Option<String>) -> PyResult<Bound<'p, PyAny>> {
let upgrade = self.upgrade.lock().unwrap().take();
let websocket = self.websocket.lock().unwrap().take();
let accepted = self.accepted.clone();
let accepted = self.init_tx.clone();
let accept_notify = self.init_event.clone();
let rx = self.ws_rx.clone();
let tx = self.ws_tx.clone();
@ -352,7 +357,9 @@ impl ASGIWebsocketProtocol {
let (tx, rx) = stream.split();
*wtx = Some(tx);
*wrx = Some(rx);
drop(wrx);
accepted.store(true, atomic::Ordering::Release);
accept_notify.notify_one();
return FutureResultToPy::None;
}
}
@ -422,14 +429,27 @@ impl ASGIWebsocketProtocol {
#[pymethods]
impl ASGIWebsocketProtocol {
fn receive<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let accepted = self.accepted.clone();
// if it's the first `receive` call, return the connect message
if self
.init_rx
.compare_exchange(false, true, atomic::Ordering::Relaxed, atomic::Ordering::Relaxed)
.is_ok()
{
return done_future_into_py(
py,
super::conversion::message_into_py(py, ASGIMessageType::WSConnect).map(Bound::unbind),
);
}
let accepted = self.init_tx.clone();
let accepted_ev = self.init_event.clone();
let closed = self.closed.clone();
let transport = self.ws_rx.clone();
future_into_py_futlike(self.rt.clone(), py, async move {
let accepted = accepted.load(atomic::Ordering::Acquire);
if !accepted {
return FutureResultToPy::ASGIMessage(ASGIMessageType::WSConnect);
if !accepted.load(atomic::Ordering::Acquire) {
// need to wait for the protocol to send the accept message and init transport
accepted_ev.notified().await;
}
if let Some(ws) = &mut *(transport.lock().await) {

View file

@ -359,8 +359,8 @@ impl PyDoneAwaitable {
fn __next__(&self, py: Python) -> PyResult<PyObject> {
self.result
.as_ref()
.map(|v| v.clone_ref(py))
.map_err(|v| v.clone_ref(py))
.map(|v| Err(PyStopIteration::new_err(v.clone_ref(py))))?
}
}