Fix asyncio timeouts support in rust task impl (#476)

This commit is contained in:
Giovanni Barillari 2024-12-28 17:38:19 +01:00 committed by GitHub
parent 3ab82a76b6
commit 9aff0caaed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 188 additions and 2 deletions

View file

@ -24,6 +24,15 @@ class _CBScheduler(_BaseCBScheduler):
super().__init__()
self._schedule_fn = _cbsched_schedule(loop, ctx, self._run, cb)
def cancel(self):
return False
def cancelling(self):
return 0
def uncancel(self):
return 0
class _CBSchedulerAIO(_BaseCBScheduler):
__slots__ = []

View file

@ -1,6 +1,5 @@
use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyDict, IntoPyObjectExt};
use std::sync::{atomic, Arc, OnceLock, RwLock};
use std::sync::{atomic, Arc, Mutex, OnceLock, RwLock};
use tokio::sync::Notify;
pub(crate) type ArcCBScheduler = Arc<Py<CallbackScheduler>>;
@ -255,9 +254,142 @@ impl CallbackScheduler {
self.schedule_fn.set(val).unwrap();
}
#[cfg(not(any(Py_3_12, Py_3_13)))]
fn _run(pyself: Py<Self>, py: Python, coro: PyObject) {
CallbackScheduler::send(pyself, py, coro);
}
#[cfg(any(Py_3_12, Py_3_13))]
fn _run(pyself: Py<Self>, py: Python, coro: PyObject) {
let stepper = Py::new(py, CallbackSchedulerStep::new(py, pyself, coro)).unwrap();
CallbackSchedulerStep::send(stepper, py);
}
}
#[pyclass(frozen, module = "granian._granian")]
pub(crate) struct CallbackSchedulerStep {
sched: Py<CallbackScheduler>,
coro: PyObject,
futw: Mutex<Option<PyObject>>,
pyname_wake: PyObject,
}
impl CallbackSchedulerStep {
#[cfg(any(Py_3_12, Py_3_13))]
pub(crate) fn new(py: Python, sched: Py<CallbackScheduler>, coro: PyObject) -> Self {
Self {
sched,
coro,
futw: Mutex::new(None),
pyname_wake: pyo3::intern!(py, "wake").into_py_any(py).unwrap(),
}
}
#[inline]
pub(crate) fn send(pyself: Py<Self>, py: Python) {
let rself = pyself.get();
let rsched = rself.sched.get();
let ptr = pyself.as_ptr();
{
let mut guard = rself.futw.lock().unwrap();
*guard = None;
}
unsafe {
pyo3::ffi::PyObject_CallOneArg(rsched.aio_tenter.as_ptr(), ptr);
}
if let Ok(res) = unsafe {
let res = pyo3::ffi::PyObject_CallMethodOneArg(
rself.coro.as_ptr(),
rsched.pyname_aiosend.as_ptr(),
rsched.pynone.as_ptr(),
);
Bound::from_owned_ptr_or_err(py, res)
} {
if unsafe {
let vptr = pyo3::ffi::PyObject_GetAttr(res.as_ptr(), rsched.pyname_aioblock.as_ptr());
Bound::from_owned_ptr_or_err(py, vptr)
.map(|v| v.extract::<bool>().unwrap_or(false))
.unwrap_or(false)
} {
let resp = res.as_ptr();
{
let mut guard = rself.futw.lock().unwrap();
*guard = Some(res.unbind().clone_ref(py));
}
unsafe {
pyo3::ffi::PyObject_SetAttr(resp, rsched.pyname_aioblock.as_ptr(), rsched.pyfalse.as_ptr());
pyo3::ffi::PyObject_Call(
pyo3::ffi::PyObject_GetAttr(resp, rsched.pyname_donecb.as_ptr()),
(pyself.clone_ref(py),).into_py_any(py).unwrap().into_ptr(),
rsched.pykw_ctx.as_ptr(),
);
}
} else {
unsafe {
let mptr = pyo3::ffi::PyObject_GetAttr(ptr, rself.pyname_wake.as_ptr());
pyo3::ffi::PyObject_CallMethodOneArg(
#[allow(clippy::used_underscore_binding)]
rsched._loop.as_ptr(),
rsched.pyname_loopcs.as_ptr(),
mptr,
);
}
}
}
unsafe {
pyo3::ffi::PyObject_CallOneArg(rsched.aio_texit.as_ptr(), ptr);
}
}
#[inline]
pub(crate) fn throw(pyself: Py<Self>, _py: Python, err: PyObject) {
let rself = pyself.get();
let rsched = rself.sched.get();
let ptr = pyself.as_ptr();
unsafe {
let corom = pyo3::ffi::PyObject_GetAttr(rself.coro.as_ptr(), rsched.pyname_aiothrow.as_ptr());
pyo3::ffi::PyObject_CallOneArg(rsched.aio_tenter.as_ptr(), ptr);
pyo3::ffi::PyObject_CallOneArg(corom, err.as_ptr());
pyo3::ffi::PyErr_Clear();
pyo3::ffi::PyObject_CallOneArg(rsched.aio_texit.as_ptr(), ptr);
}
}
}
#[pymethods]
impl CallbackSchedulerStep {
fn _step(pyself: Py<Self>, py: Python) {
CallbackSchedulerStep::send(pyself, py);
}
fn __call__(pyself: Py<Self>, py: Python, fut: PyObject) {
match fut.call_method0(py, pyo3::intern!(py, "result")) {
Ok(_) => CallbackSchedulerStep::send(pyself, py),
Err(err) => CallbackSchedulerStep::throw(pyself, py, err.into_py_any(py).unwrap()),
}
}
fn cancel(&self, py: Python) -> PyResult<PyObject> {
let guard = self.futw.lock().unwrap();
if let Some(v) = guard.as_ref() {
return v.call_method0(py, pyo3::intern!(py, "cancel"));
}
Ok(self.sched.get().pyfalse.clone_ref(py))
}
fn cancelling(&self) -> i32 {
0
}
fn uncancel(&self) -> i32 {
0
}
}
#[pyclass(frozen, module = "granian._granian")]

View file

@ -1,3 +1,4 @@
import asyncio
import json
import pathlib
@ -126,6 +127,31 @@ async def err_proto(scope, receive, send):
await send({'type': 'wrong.msg'})
async def timeout_n(scope, receive, send):
async def _inner():
return b'ok'
await send(PLAINTEXT_RESPONSE)
try:
ret = await asyncio.wait_for(_inner(), None)
except asyncio.TimeoutError:
ret = b'timeout'
await send({'type': 'http.response.body', 'body': ret, 'more_body': False})
async def timeout_w(scope, receive, send):
async def _inner():
await asyncio.sleep(3)
return b'ok'
await send(PLAINTEXT_RESPONSE)
try:
ret = await asyncio.wait_for(_inner(), 1)
except asyncio.TimeoutError:
ret = b'timeout'
await send({'type': 'http.response.body', 'body': ret, 'more_body': False})
async def lifespan(scope, receive, send):
msg = await receive()
if msg['type'] == 'lifespan.startup':
@ -147,4 +173,6 @@ def app(scope, receive, send):
'/ws_push': ws_push,
'/err_app': err_app,
'/err_proto': err_proto,
'/timeout_n': timeout_n,
'/timeout_w': timeout_w,
}[scope['path']](scope, receive, send)

View file

@ -88,3 +88,20 @@ async def test_sniffio(asgi_server, threading_mode):
assert res.status_code == 200
assert res.text == 'asyncio'
@pytest.mark.asyncio
@pytest.mark.skipif(bool(os.getenv('PGO_RUN')), reason='PGO build')
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_timeout(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
res = httpx.get(f'http://localhost:{port}/timeout_n')
assert res.status_code == 200
assert res.text == 'ok'
async with asgi_server(threading_mode) as port:
res = httpx.get(f'http://localhost:{port}/timeout_w')
assert res.status_code == 200
assert res.text == 'timeout'