mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
gh-123471: Make itertools.chain thread-safe (#135689)
This commit is contained in:
parent
536a5ff153
commit
0533c1faf2
3 changed files with 43 additions and 4 deletions
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from threading import Thread, Barrier
|
from threading import Thread, Barrier
|
||||||
from itertools import batched, cycle
|
from itertools import batched, chain, cycle
|
||||||
from test.support import threading_helper
|
from test.support import threading_helper
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ class ItertoolsThreading(unittest.TestCase):
|
||||||
barrier.wait()
|
barrier.wait()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
_ = next(it)
|
next(it)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -62,6 +62,34 @@ class ItertoolsThreading(unittest.TestCase):
|
||||||
|
|
||||||
barrier.reset()
|
barrier.reset()
|
||||||
|
|
||||||
|
@threading_helper.reap_threads
|
||||||
|
def test_chain(self):
|
||||||
|
number_of_threads = 6
|
||||||
|
number_of_iterations = 20
|
||||||
|
|
||||||
|
barrier = Barrier(number_of_threads)
|
||||||
|
def work(it):
|
||||||
|
barrier.wait()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
next(it)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
data = [(1, )] * 200
|
||||||
|
for it in range(number_of_iterations):
|
||||||
|
chain_iterator = chain(*data)
|
||||||
|
worker_threads = []
|
||||||
|
for ii in range(number_of_threads):
|
||||||
|
worker_threads.append(
|
||||||
|
Thread(target=work, args=[chain_iterator]))
|
||||||
|
|
||||||
|
with threading_helper.start_threads(worker_threads):
|
||||||
|
pass
|
||||||
|
|
||||||
|
barrier.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Make concurrent iterations over :class:`itertools.chain` safe under :term:`free threading`.
|
|
@ -1880,8 +1880,8 @@ chain_traverse(PyObject *op, visitproc visit, void *arg)
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyObject *
|
static inline PyObject *
|
||||||
chain_next(PyObject *op)
|
chain_next_lock_held(PyObject *op)
|
||||||
{
|
{
|
||||||
chainobject *lz = chainobject_CAST(op);
|
chainobject *lz = chainobject_CAST(op);
|
||||||
PyObject *item;
|
PyObject *item;
|
||||||
|
@ -1919,6 +1919,16 @@ chain_next(PyObject *op)
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
chain_next(PyObject *op)
|
||||||
|
{
|
||||||
|
PyObject *result;
|
||||||
|
Py_BEGIN_CRITICAL_SECTION(op);
|
||||||
|
result = chain_next_lock_held(op);
|
||||||
|
Py_END_CRITICAL_SECTION()
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
PyDoc_STRVAR(chain_doc,
|
PyDoc_STRVAR(chain_doc,
|
||||||
"chain(*iterables)\n\
|
"chain(*iterables)\n\
|
||||||
--\n\
|
--\n\
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue