diff --git a/granian/asgi.py b/granian/asgi.py index 8dbdbaa..1ef6366 100644 --- a/granian/asgi.py +++ b/granian/asgi.py @@ -141,7 +141,8 @@ def future_handler(watcher): def handler(task): try: task.result() - watcher.done(True) except Exception: watcher.done(False) + raise + watcher.done(True) return handler diff --git a/granian/rsgi.py b/granian/rsgi.py index 0797aee..11fb54d 100644 --- a/granian/rsgi.py +++ b/granian/rsgi.py @@ -95,6 +95,7 @@ def future_handler(watcher): try: res = task.result() except Exception: - res = None + watcher.err() + raise watcher.done(res) return handler diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index 19245df..b09fbc1 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -55,6 +55,10 @@ pub(crate) async fn call( match rx.await { Ok(true) => Ok(()), + Ok(false) => { + log::warn!("Application callable raised an exception"); + error_flow!() + }, _ => error_flow!() } } diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index b882ee9..2d68e98 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -3,13 +3,13 @@ use tokio::sync::oneshot; use crate::callbacks::CallbackWrapper; use super::{ - errors::ApplicationError, + errors::{error_proto, error_app}, io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol}, types::RSGIScope as Scope }; -#[derive(FromPyObject)] +#[derive(FromPyObject, Debug)] pub(crate) struct CallbackResponse { pub mode: u32, pub status: i32, @@ -21,7 +21,7 @@ pub(crate) struct CallbackResponse { #[pyclass] pub(crate) struct CallbackResponseWatcher { - tx: Option>, + tx: Option>>, #[pyo3(get)] event_loop: PyObject, #[pyo3(get)] @@ -32,7 +32,7 @@ impl CallbackResponseWatcher { pub fn new( py: Python, cb: CallbackWrapper, - tx: oneshot::Sender + tx: oneshot::Sender> ) -> Self { Self { tx: Some(tx), @@ -46,8 +46,22 @@ impl CallbackResponseWatcher { impl CallbackResponseWatcher { fn done(&mut self, py: Python, result: PyObject) -> PyResult<()> { if let Some(tx) = self.tx.take() { - // FIXME: handle failure - let _ = tx.send(result.extract(py)?); + match result.extract(py) { + Ok(res) => { + let _ = tx.send(res); + return Ok(()) + }, + _ => { + let _ = tx.send(None); + } + } + }; + error_proto!() + } + + fn err(&mut self) -> PyResult<()> { + if let Some(tx) = self.tx.take() { + let _ = tx.send(None); }; Ok(()) } @@ -55,7 +69,7 @@ impl CallbackResponseWatcher { #[pyclass] pub(crate) struct CallbackProtocolWatcher { - tx: Option>, + tx: Option>>, #[pyo3(get)] event_loop: PyObject, #[pyo3(get)] @@ -66,7 +80,7 @@ impl CallbackProtocolWatcher { pub fn new( py: Python, cb: CallbackWrapper, - tx: oneshot::Sender<(i32, bool)> + tx: oneshot::Sender> ) -> Self { Self { tx: Some(tx), @@ -80,8 +94,22 @@ impl CallbackProtocolWatcher { impl CallbackProtocolWatcher { fn done(&mut self, py: Python, result: PyObject) -> PyResult<()> { if let Some(tx) = self.tx.take() { - // FIXME: handle failure - let _ = tx.send(result.extract(py)?); + match result.extract(py) { + Ok(res) => { + let _ = tx.send(res); + return Ok(()) + }, + _ => { + let _ = tx.send(None); + } + } + }; + error_proto!() + } + + fn err(&mut self) -> PyResult<()> { + if let Some(tx) = self.tx.take() { + let _ = tx.send(None); }; Ok(()) } @@ -99,8 +127,19 @@ pub(crate) async fn call_response( })?; match rx.await { - Ok(v) => Ok(v), - _ => Err(ApplicationError.into()) + Ok(res) => { + match res { + Some(res) => Ok(res), + _ => { + log::warn!("Application failed to return a response"); + error_app!() + } + } + }, + _ => { + log::error!("RSGI protocol failure"); + error_proto!() + } } } @@ -116,7 +155,18 @@ pub(crate) async fn call_protocol( })?; match rx.await { - Ok(v) => Ok(v), - _ => Err(ApplicationError.into()) + Ok(res) => { + match res { + Some(res) => Ok(res), + _ => { + log::warn!("Application failed to close protocol"); + error_app!() + } + } + }, + _ => { + log::error!("RSGI protocol failure"); + error_proto!() + } } } diff --git a/src/rsgi/errors.rs b/src/rsgi/errors.rs index 8472978..8ad9dc4 100644 --- a/src/rsgi/errors.rs +++ b/src/rsgi/errors.rs @@ -62,8 +62,15 @@ impl std::convert::From for PyErr { macro_rules! error_proto { () => { - Err(RSGIProtocolError.into()) + Err(super::errors::RSGIProtocolError.into()) + }; +} + +macro_rules! error_app { + () => { + Err(super::errors::ApplicationError.into()) }; } pub(crate) use error_proto; +pub(crate) use error_app; diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index ecefaac..6d057e7 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -12,7 +12,7 @@ use crate::{ runtime::{RuntimeRef, future_into_py}, ws::{HyperWebsocket, UpgradeData} }; -use super::errors::{RSGIProtocolError, error_proto}; +use super::errors::error_proto; #[pyclass(module="granian._granian")] diff --git a/tests/apps/asgi.py b/tests/apps/asgi.py index 6273447..3e41dea 100644 --- a/tests/apps/asgi.py +++ b/tests/apps/asgi.py @@ -85,11 +85,21 @@ async def ws_echo(scope, receive, send): }) +async def err_app(scope, receive, send): + 1 / 0 + + +async def err_proto(scope, receive, send): + await send({'type': 'wrong.msg'}) + + def app(scope, receive, send): return { "/info": info, "/echo": echo, "/ws_reject": ws_reject, "/ws_info": ws_info, - "/ws_echo": ws_echo + "/ws_echo": ws_echo, + "/err_app": err_app, + "/err_proto": err_proto }[scope['path']](scope, receive, send) diff --git a/tests/apps/rsgi.py b/tests/apps/rsgi.py index 5fdddf3..3028665 100644 --- a/tests/apps/rsgi.py +++ b/tests/apps/rsgi.py @@ -73,11 +73,21 @@ async def ws_echo(_, protocol: WebsocketProtocol): return protocol.close() +async def err_app(scope: Scope, protocol: HTTPProtocol): + 1 / 0 + + +async def err_proto(scope: Scope, protocol: HTTPProtocol): + return "bad" + + def app(scope, protocol): return { "/info": info, "/echo": echo, "/ws_reject": ws_reject, "/ws_info": ws_info, - "/ws_echo": ws_echo + "/ws_echo": ws_echo, + "/err_app": err_app, + "/err_proto": err_proto }[scope.path](scope, protocol) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index e52eb18..447dc66 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -45,3 +45,33 @@ async def test_body(asgi_server, threading_mode): assert res.status_code == 200 assert res.text == "test" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "threading_mode", + [ + "runtime", + "workers" + ] +) +async def test_app_error(asgi_server, threading_mode): + async with asgi_server(threading_mode) as port: + res = httpx.get(f"http://localhost:{port}/err_app") + + assert res.status_code == 500 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "threading_mode", + [ + "runtime", + "workers" + ] +) +async def test_protocol_error(asgi_server, threading_mode): + async with asgi_server(threading_mode) as port: + res = httpx.get(f"http://localhost:{port}/err_proto") + + assert res.status_code == 500 diff --git a/tests/test_rsgi.py b/tests/test_rsgi.py index 69008e8..0eab101 100644 --- a/tests/test_rsgi.py +++ b/tests/test_rsgi.py @@ -42,3 +42,33 @@ async def test_body(rsgi_server, threading_mode): assert res.status_code == 200 assert res.text == "test" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "threading_mode", + [ + "runtime", + "workers" + ] +) +async def test_app_error(rsgi_server, threading_mode): + async with rsgi_server(threading_mode) as port: + res = httpx.get(f"http://localhost:{port}/err_app") + + assert res.status_code == 500 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "threading_mode", + [ + "runtime", + "workers" + ] +) +async def test_protocol_error(rsgi_server, threading_mode): + async with rsgi_server(threading_mode) as port: + res = httpx.get(f"http://localhost:{port}/err_proto") + + assert res.status_code == 500