gh-96471: Add asyncio queue shutdown (#104228)

Co-authored-by: Duprat <yduprat@gmail.com>
This commit is contained in:
Laurie O 2024-04-07 00:27:13 +10:00 committed by GitHub
parent 1d3225ae05
commit df4d84c3cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 301 additions and 3 deletions

View file

@ -522,5 +522,204 @@ class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCa
q_class = asyncio.PriorityQueue
class _QueueShutdownTestMixin:
q_class = None
def assertRaisesShutdown(self, msg="Didn't appear to shut-down queue"):
return self.assertRaises(asyncio.QueueShutDown, msg=msg)
async def test_format(self):
q = self.q_class()
q.shutdown()
self.assertEqual(q._format(), 'maxsize=0 shutdown')
async def test_shutdown_empty(self):
# Test shutting down an empty queue
# Setup empty queue, and join() and get() tasks
q = self.q_class()
loop = asyncio.get_running_loop()
get_task = loop.create_task(q.get())
await asyncio.sleep(0) # want get task pending before shutdown
# Perform shut-down
q.shutdown(immediate=False) # unfinished tasks: 0 -> 0
self.assertEqual(q.qsize(), 0)
# Ensure join() task successfully finishes
await q.join()
# Ensure get() task is finished, and raised ShutDown
await asyncio.sleep(0)
self.assertTrue(get_task.done())
with self.assertRaisesShutdown():
await get_task
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
async def test_shutdown_nonempty(self):
# Test shutting down a non-empty queue
# Setup full queue with 1 item, and join() and put() tasks
q = self.q_class(maxsize=1)
loop = asyncio.get_running_loop()
q.put_nowait("data")
join_task = loop.create_task(q.join())
put_task = loop.create_task(q.put("data2"))
# Ensure put() task is not finished
await asyncio.sleep(0)
self.assertFalse(put_task.done())
# Perform shut-down
q.shutdown(immediate=False) # unfinished tasks: 1 -> 1
self.assertEqual(q.qsize(), 1)
# Ensure put() task is finished, and raised ShutDown
await asyncio.sleep(0)
self.assertTrue(put_task.done())
with self.assertRaisesShutdown():
await put_task
# Ensure get() succeeds on enqueued item
self.assertEqual(await q.get(), "data")
# Ensure join() task is not finished
await asyncio.sleep(0)
self.assertFalse(join_task.done())
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
# Ensure there is 1 unfinished task, and join() task succeeds
q.task_done()
await asyncio.sleep(0)
self.assertTrue(join_task.done())
await join_task
with self.assertRaises(
ValueError, msg="Didn't appear to mark all tasks done"
):
q.task_done()
async def test_shutdown_immediate(self):
# Test immediately shutting down a queue
# Setup queue with 1 item, and a join() task
q = self.q_class()
loop = asyncio.get_running_loop()
q.put_nowait("data")
join_task = loop.create_task(q.join())
# Perform shut-down
q.shutdown(immediate=True) # unfinished tasks: 1 -> 0
self.assertEqual(q.qsize(), 0)
# Ensure join() task has successfully finished
await asyncio.sleep(0)
self.assertTrue(join_task.done())
await join_task
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
# Ensure there are no unfinished tasks
with self.assertRaises(
ValueError, msg="Didn't appear to mark all tasks done"
):
q.task_done()
async def test_shutdown_immediate_with_unfinished(self):
# Test immediately shutting down a queue with unfinished tasks
# Setup queue with 2 items (1 retrieved), and a join() task
q = self.q_class()
loop = asyncio.get_running_loop()
q.put_nowait("data")
q.put_nowait("data")
join_task = loop.create_task(q.join())
self.assertEqual(await q.get(), "data")
# Perform shut-down
q.shutdown(immediate=True) # unfinished tasks: 2 -> 1
self.assertEqual(q.qsize(), 0)
# Ensure join() task is not finished
await asyncio.sleep(0)
self.assertFalse(join_task.done())
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
# Ensure there is 1 unfinished task
q.task_done()
with self.assertRaises(
ValueError, msg="Didn't appear to mark all tasks done"
):
q.task_done()
# Ensure join() task has successfully finished
await asyncio.sleep(0)
self.assertTrue(join_task.done())
await join_task
class QueueShutdownTests(
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
):
q_class = asyncio.Queue
class LifoQueueShutdownTests(
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
):
q_class = asyncio.LifoQueue
class PriorityQueueShutdownTests(
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
):
q_class = asyncio.PriorityQueue
if __name__ == '__main__':
unittest.main()