gh-84570: Add Timeouts to SendChannel.send() and RecvChannel.recv() (gh-110567)

This commit is contained in:
Eric Snow 2023-10-17 17:05:49 -06:00 committed by GitHub
parent 7029c1a1c5
commit c58c63fdf6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 205 additions and 50 deletions

View file

@ -89,6 +89,12 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock);
// unset: -1 seconds, in nanoseconds // unset: -1 seconds, in nanoseconds
#define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000)) #define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000))
// Exported for the _xxinterpchannels module.
PyAPI_FUNC(int) PyThread_ParseTimeoutArg(
PyObject *arg,
int blocking,
PY_TIMEOUT_T *timeout);
/* Helper to acquire an interruptible lock with a timeout. If the lock acquire /* Helper to acquire an interruptible lock with a timeout. If the lock acquire
* is interrupted, signal handlers are run, and if they raise an exception, * is interrupted, signal handlers are run, and if they raise an exception,
* PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE * PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE

View file

@ -170,15 +170,25 @@ class RecvChannel(_ChannelEnd):
_end = 'recv' _end = 'recv'
def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds def recv(self, timeout=None, *,
_sentinel=object(),
_delay=10 / 1000, # 10 milliseconds
):
"""Return the next object from the channel. """Return the next object from the channel.
This blocks until an object has been sent, if none have been This blocks until an object has been sent, if none have been
sent already. sent already.
""" """
if timeout is not None:
timeout = int(timeout)
if timeout < 0:
raise ValueError(f'timeout value must be non-negative')
end = time.time() + timeout
obj = _channels.recv(self._id, _sentinel) obj = _channels.recv(self._id, _sentinel)
while obj is _sentinel: while obj is _sentinel:
time.sleep(_delay) time.sleep(_delay)
if timeout is not None and time.time() >= end:
raise TimeoutError
obj = _channels.recv(self._id, _sentinel) obj = _channels.recv(self._id, _sentinel)
return obj return obj
@ -203,12 +213,12 @@ class SendChannel(_ChannelEnd):
_end = 'send' _end = 'send'
def send(self, obj): def send(self, obj, timeout=None):
"""Send the object (i.e. its data) to the channel's receiving end. """Send the object (i.e. its data) to the channel's receiving end.
This blocks until the object is received. This blocks until the object is received.
""" """
_channels.send(self._id, obj, blocking=True) _channels.send(self._id, obj, timeout=timeout, blocking=True)
def send_nowait(self, obj): def send_nowait(self, obj):
"""Send the object to the channel's receiving end. """Send the object to the channel's receiving end.
@ -221,12 +231,12 @@ class SendChannel(_ChannelEnd):
# See bpo-32604 and gh-19829. # See bpo-32604 and gh-19829.
return _channels.send(self._id, obj, blocking=False) return _channels.send(self._id, obj, blocking=False)
def send_buffer(self, obj): def send_buffer(self, obj, timeout=None):
"""Send the object's buffer to the channel's receiving end. """Send the object's buffer to the channel's receiving end.
This blocks until the object is received. This blocks until the object is received.
""" """
_channels.send_buffer(self._id, obj, blocking=True) _channels.send_buffer(self._id, obj, timeout=timeout, blocking=True)
def send_buffer_nowait(self, obj): def send_buffer_nowait(self, obj):
"""Send the object's buffer to the channel's receiving end. """Send the object's buffer to the channel's receiving end.

View file

@ -864,22 +864,34 @@ class ChannelTests(TestBase):
self.assertEqual(received, obj) self.assertEqual(received, obj)
def test_send_closed_while_waiting(self): def test_send_timeout(self):
obj = b'spam' obj = b'spam'
wait = self.build_send_waiter(obj)
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send(cid, obj, blocking=True)
t.join()
def test_send_buffer_closed_while_waiting(self): with self.subTest('non-blocking with timeout'):
cid = channels.create()
with self.assertRaises(ValueError):
channels.send(cid, obj, blocking=False, timeout=0.1)
with self.subTest('timeout hit'):
cid = channels.create()
with self.assertRaises(TimeoutError):
channels.send(cid, obj, blocking=True, timeout=0.1)
with self.assertRaises(channels.ChannelEmptyError):
received = channels.recv(cid)
print(repr(received))
with self.subTest('timeout not hit'):
cid = channels.create()
def f():
recv_wait(cid)
t = threading.Thread(target=f)
t.start()
channels.send(cid, obj, blocking=True, timeout=10)
t.join()
def test_send_buffer_timeout(self):
try: try:
self._has_run_once self._has_run_once_timeout
except AttributeError: except AttributeError:
# At the moment, this test leaks a few references. # At the moment, this test leaks a few references.
# It looks like the leak originates with the addition # It looks like the leak originates with the addition
@ -888,19 +900,95 @@ class ChannelTests(TestBase):
# if the refleak isn't fixed yet, so we skip here. # if the refleak isn't fixed yet, so we skip here.
raise unittest.SkipTest('temporarily skipped due to refleaks') raise unittest.SkipTest('temporarily skipped due to refleaks')
else: else:
self._has_run_once = True self._has_run_once_timeout = True
obj = bytearray(b'spam')
with self.subTest('non-blocking with timeout'):
cid = channels.create()
with self.assertRaises(ValueError):
channels.send_buffer(cid, obj, blocking=False, timeout=0.1)
with self.subTest('timeout hit'):
cid = channels.create()
with self.assertRaises(TimeoutError):
channels.send_buffer(cid, obj, blocking=True, timeout=0.1)
with self.assertRaises(channels.ChannelEmptyError):
received = channels.recv(cid)
print(repr(received))
with self.subTest('timeout not hit'):
cid = channels.create()
def f():
recv_wait(cid)
t = threading.Thread(target=f)
t.start()
channels.send_buffer(cid, obj, blocking=True, timeout=10)
t.join()
def test_send_closed_while_waiting(self):
obj = b'spam'
wait = self.build_send_waiter(obj)
with self.subTest('without timeout'):
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send(cid, obj, blocking=True)
t.join()
with self.subTest('with timeout'):
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send(cid, obj, blocking=True, timeout=30)
t.join()
def test_send_buffer_closed_while_waiting(self):
try:
self._has_run_once_closed
except AttributeError:
# At the moment, this test leaks a few references.
# It looks like the leak originates with the addition
# of _channels.send_buffer() (gh-110246), whereas the
# tests were added afterward. We want this test even
# if the refleak isn't fixed yet, so we skip here.
raise unittest.SkipTest('temporarily skipped due to refleaks')
else:
self._has_run_once_closed = True
obj = bytearray(b'spam') obj = bytearray(b'spam')
wait = self.build_send_waiter(obj, buffer=True) wait = self.build_send_waiter(obj, buffer=True)
cid = channels.create()
def f(): with self.subTest('without timeout'):
wait() cid = channels.create()
channels.close(cid, force=True) def f():
t = threading.Thread(target=f) wait()
t.start() channels.close(cid, force=True)
with self.assertRaises(channels.ChannelClosedError): t = threading.Thread(target=f)
channels.send_buffer(cid, obj, blocking=True) t.start()
t.join() with self.assertRaises(channels.ChannelClosedError):
channels.send_buffer(cid, obj, blocking=True)
t.join()
with self.subTest('with timeout'):
cid = channels.create()
def f():
wait()
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send_buffer(cid, obj, blocking=True, timeout=30)
t.join()
#------------------- #-------------------
# close # close

View file

@ -1022,6 +1022,11 @@ class TestSendRecv(TestBase):
self.assertEqual(obj2, b'eggs') self.assertEqual(obj2, b'eggs')
self.assertNotEqual(id(obj2), int(out)) self.assertNotEqual(id(obj2), int(out))
def test_recv_timeout(self):
r, _ = interpreters.create_channel()
with self.assertRaises(TimeoutError):
r.recv(timeout=1)
def test_recv_channel_does_not_exist(self): def test_recv_channel_does_not_exist(self):
ch = interpreters.RecvChannel(1_000_000) ch = interpreters.RecvChannel(1_000_000)
with self.assertRaises(interpreters.ChannelNotFoundError): with self.assertRaises(interpreters.ChannelNotFoundError):

View file

@ -214,6 +214,8 @@ _queue_SimpleQueue_get_impl(simplequeueobject *self, PyTypeObject *cls,
PY_TIMEOUT_T microseconds; PY_TIMEOUT_T microseconds;
PyThreadState *tstate = PyThreadState_Get(); PyThreadState *tstate = PyThreadState_Get();
// XXX Use PyThread_ParseTimeoutArg().
if (block == 0) { if (block == 0) {
/* Non-blocking */ /* Non-blocking */
microseconds = 0; microseconds = 0;

View file

@ -88,14 +88,15 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
char *kwlist[] = {"blocking", "timeout", NULL}; char *kwlist[] = {"blocking", "timeout", NULL};
int blocking = 1; int blocking = 1;
PyObject *timeout_obj = NULL; PyObject *timeout_obj = NULL;
const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);
*timeout = unset_timeout ;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|pO:acquire", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwds, "|pO:acquire", kwlist,
&blocking, &timeout_obj)) &blocking, &timeout_obj))
return -1; return -1;
// XXX Use PyThread_ParseTimeoutArg().
const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);
*timeout = unset_timeout;
if (timeout_obj if (timeout_obj
&& _PyTime_FromSecondsObject(timeout, && _PyTime_FromSecondsObject(timeout,
timeout_obj, _PyTime_ROUND_TIMEOUT) < 0) timeout_obj, _PyTime_ROUND_TIMEOUT) < 0)
@ -108,7 +109,7 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
} }
if (*timeout < 0 && *timeout != unset_timeout) { if (*timeout < 0 && *timeout != unset_timeout) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"timeout value must be positive"); "timeout value must be a non-negative number");
return -1; return -1;
} }
if (!blocking) if (!blocking)

View file

@ -242,9 +242,8 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
} }
static int static int
wait_for_lock(PyThread_type_lock mutex) wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
{ {
PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT;
PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout); PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
if (res == PY_LOCK_INTR) { if (res == PY_LOCK_INTR) {
/* KeyboardInterrupt, etc. */ /* KeyboardInterrupt, etc. */
@ -1883,7 +1882,8 @@ _channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
} }
static int static int
_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj) _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj,
PY_TIMEOUT_T timeout)
{ {
// We use a stack variable here, so we must ensure that &waiting // We use a stack variable here, so we must ensure that &waiting
// is not held by any channel item at the point this function exits. // is not held by any channel item at the point this function exits.
@ -1901,7 +1901,7 @@ _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
} }
/* Wait until the object is received. */ /* Wait until the object is received. */
if (wait_for_lock(waiting.mutex) < 0) { if (wait_for_lock(waiting.mutex, timeout) < 0) {
assert(PyErr_Occurred()); assert(PyErr_Occurred());
_waiting_finish_releasing(&waiting); _waiting_finish_releasing(&waiting);
/* The send() call is failing now, so make sure the item /* The send() call is failing now, so make sure the item
@ -2816,25 +2816,29 @@ receive end.");
static PyObject * static PyObject *
channel_send(PyObject *self, PyObject *args, PyObject *kwds) channel_send(PyObject *self, PyObject *args, PyObject *kwds)
{ {
// XXX Add a timeout arg. static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
static char *kwlist[] = {"cid", "obj", "blocking", NULL};
int64_t cid;
struct channel_id_converter_data cid_data = { struct channel_id_converter_data cid_data = {
.module = self, .module = self,
}; };
PyObject *obj; PyObject *obj;
int blocking = 1; int blocking = 1;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist, PyObject *timeout_obj = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$pO:channel_send", kwlist,
channel_id_converter, &cid_data, &obj, channel_id_converter, &cid_data, &obj,
&blocking)) { &blocking, &timeout_obj)) {
return NULL;
}
int64_t cid = cid_data.cid;
PY_TIMEOUT_T timeout;
if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
return NULL; return NULL;
} }
cid = cid_data.cid;
/* Queue up the object. */ /* Queue up the object. */
int err = 0; int err = 0;
if (blocking) { if (blocking) {
err = _channel_send_wait(&_globals.channels, cid, obj); err = _channel_send_wait(&_globals.channels, cid, obj, timeout);
} }
else { else {
err = _channel_send(&_globals.channels, cid, obj, NULL); err = _channel_send(&_globals.channels, cid, obj, NULL);
@ -2855,20 +2859,25 @@ By default this waits for the object to be received.");
static PyObject * static PyObject *
channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds) channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
{ {
static char *kwlist[] = {"cid", "obj", "blocking", NULL}; static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
int64_t cid;
struct channel_id_converter_data cid_data = { struct channel_id_converter_data cid_data = {
.module = self, .module = self,
}; };
PyObject *obj; PyObject *obj;
int blocking = 1; int blocking = 1;
PyObject *timeout_obj = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, if (!PyArg_ParseTupleAndKeywords(args, kwds,
"O&O|$p:channel_send_buffer", kwlist, "O&O|$pO:channel_send_buffer", kwlist,
channel_id_converter, &cid_data, &obj, channel_id_converter, &cid_data, &obj,
&blocking)) { &blocking, &timeout_obj)) {
return NULL;
}
int64_t cid = cid_data.cid;
PY_TIMEOUT_T timeout;
if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
return NULL; return NULL;
} }
cid = cid_data.cid;
PyObject *tempobj = PyMemoryView_FromObject(obj); PyObject *tempobj = PyMemoryView_FromObject(obj);
if (tempobj == NULL) { if (tempobj == NULL) {
@ -2878,7 +2887,7 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
/* Queue up the object. */ /* Queue up the object. */
int err = 0; int err = 0;
if (blocking) { if (blocking) {
err = _channel_send_wait(&_globals.channels, cid, tempobj); err = _channel_send_wait(&_globals.channels, cid, tempobj, timeout);
} }
else { else {
err = _channel_send(&_globals.channels, cid, tempobj, NULL); err = _channel_send(&_globals.channels, cid, tempobj, NULL);

View file

@ -93,6 +93,40 @@ PyThread_set_stacksize(size_t size)
} }
int
PyThread_ParseTimeoutArg(PyObject *arg, int blocking, PY_TIMEOUT_T *timeout_p)
{
assert(_PyTime_FromSeconds(-1) == PyThread_UNSET_TIMEOUT);
if (arg == NULL || arg == Py_None) {
*timeout_p = blocking ? PyThread_UNSET_TIMEOUT : 0;
return 0;
}
if (!blocking) {
PyErr_SetString(PyExc_ValueError,
"can't specify a timeout for a non-blocking call");
return -1;
}
_PyTime_t timeout;
if (_PyTime_FromSecondsObject(&timeout, arg, _PyTime_ROUND_TIMEOUT) < 0) {
return -1;
}
if (timeout < 0) {
PyErr_SetString(PyExc_ValueError,
"timeout value must be a non-negative number");
return -1;
}
if (_PyTime_AsMicroseconds(timeout,
_PyTime_ROUND_TIMEOUT) > PY_TIMEOUT_MAX) {
PyErr_SetString(PyExc_OverflowError,
"timeout value is too large");
return -1;
}
*timeout_p = timeout;
return 0;
}
PyLockStatus PyLockStatus
PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock, PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock,
PY_TIMEOUT_T timeout) PY_TIMEOUT_T timeout)