gh-76785: Fix interpreters.Queue.get_nowait() (gh-116166)

I missed this change in gh-115566.
This commit is contained in:
Eric Snow 2024-03-01 09:36:35 -07:00 committed by GitHub
parent a7549b03ce
commit 936d4611d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 73 additions and 46 deletions

View file

@ -215,10 +215,15 @@ class Queue:
is the same as get(). is the same as get().
""" """
try: try:
return _queues.get(self._id) obj, fmt = _queues.get(self._id)
except _queues.QueueEmpty as exc: except _queues.QueueEmpty as exc:
exc.__class__ = QueueEmpty exc.__class__ = QueueEmpty
raise # re-raise raise # re-raise
if fmt == _PICKLED:
obj = pickle.loads(obj)
else:
assert fmt == _SHARED_ONLY
return obj
_queues._register_queue_type(Queue) _queues._register_queue_type(Queue)

View file

@ -210,10 +210,15 @@ class TestQueueOps(TestBase):
]: ]:
with self.subTest(repr(obj)): with self.subTest(repr(obj)):
queue = queues.create() queue = queues.create()
queue.put(obj, syncobj=True) queue.put(obj, syncobj=True)
obj2 = queue.get() obj2 = queue.get()
self.assertEqual(obj2, obj) self.assertEqual(obj2, obj)
queue.put(obj, syncobj=True)
obj2 = queue.get_nowait()
self.assertEqual(obj2, obj)
for obj in [ for obj in [
[1, 2, 3], [1, 2, 3],
{'a': 13, 'b': 17}, {'a': 13, 'b': 17},
@ -237,10 +242,15 @@ class TestQueueOps(TestBase):
]: ]:
with self.subTest(repr(obj)): with self.subTest(repr(obj)):
queue = queues.create() queue = queues.create()
queue.put(obj, syncobj=False) queue.put(obj, syncobj=False)
obj2 = queue.get() obj2 = queue.get()
self.assertEqual(obj2, obj) self.assertEqual(obj2, obj)
queue.put(obj, syncobj=False)
obj2 = queue.get_nowait()
self.assertEqual(obj2, obj)
def test_get_timeout(self): def test_get_timeout(self):
queue = queues.create() queue = queues.create()
with self.assertRaises(queues.QueueEmpty): with self.assertRaises(queues.QueueEmpty):
@ -254,11 +264,13 @@ class TestQueueOps(TestBase):
def test_put_get_default_syncobj(self): def test_put_get_default_syncobj(self):
expected = list(range(20)) expected = list(range(20))
queue = queues.create(syncobj=True) queue = queues.create(syncobj=True)
for i in range(20): for methname in ('get', 'get_nowait'):
queue.put(i) with self.subTest(f'{methname}()'):
actual = [queue.get() for _ in range(20)] get = getattr(queue, methname)
for i in range(20):
self.assertEqual(actual, expected) queue.put(i)
actual = [get() for _ in range(20)]
self.assertEqual(actual, expected)
obj = [1, 2, 3] # lists are not shareable obj = [1, 2, 3] # lists are not shareable
with self.assertRaises(interpreters.NotShareableError): with self.assertRaises(interpreters.NotShareableError):
@ -267,29 +279,36 @@ class TestQueueOps(TestBase):
def test_put_get_default_not_syncobj(self): def test_put_get_default_not_syncobj(self):
expected = list(range(20)) expected = list(range(20))
queue = queues.create(syncobj=False) queue = queues.create(syncobj=False)
for i in range(20): for methname in ('get', 'get_nowait'):
queue.put(i) with self.subTest(f'{methname}()'):
actual = [queue.get() for _ in range(20)] get = getattr(queue, methname)
self.assertEqual(actual, expected) for i in range(20):
queue.put(i)
actual = [get() for _ in range(20)]
self.assertEqual(actual, expected)
obj = [1, 2, 3] # lists are not shareable obj = [1, 2, 3] # lists are not shareable
queue.put(obj) queue.put(obj)
obj2 = queue.get() obj2 = get()
self.assertEqual(obj, obj2) self.assertEqual(obj, obj2)
self.assertIsNot(obj, obj2) self.assertIsNot(obj, obj2)
def test_put_get_same_interpreter(self): def test_put_get_same_interpreter(self):
interp = interpreters.create() interp = interpreters.create()
interp.exec(dedent(""" interp.exec(dedent("""
from test.support.interpreters import queues from test.support.interpreters import queues
queue = queues.create() queue = queues.create()
orig = b'spam'
queue.put(orig, syncobj=True)
obj = queue.get()
assert obj == orig, 'expected: obj == orig'
assert obj is not orig, 'expected: obj is not orig'
""")) """))
for methname in ('get', 'get_nowait'):
with self.subTest(f'{methname}()'):
interp.exec(dedent(f"""
orig = b'spam'
queue.put(orig, syncobj=True)
obj = queue.{methname}()
assert obj == orig, 'expected: obj == orig'
assert obj is not orig, 'expected: obj is not orig'
"""))
def test_put_get_different_interpreters(self): def test_put_get_different_interpreters(self):
interp = interpreters.create() interp = interpreters.create()
@ -297,34 +316,37 @@ class TestQueueOps(TestBase):
queue2 = queues.create() queue2 = queues.create()
self.assertEqual(len(queues.list_all()), 2) self.assertEqual(len(queues.list_all()), 2)
obj1 = b'spam' for methname in ('get', 'get_nowait'):
queue1.put(obj1, syncobj=True) with self.subTest(f'{methname}()'):
obj1 = b'spam'
queue1.put(obj1, syncobj=True)
out = _run_output( out = _run_output(
interp, interp,
dedent(f""" dedent(f"""
from test.support.interpreters import queues from test.support.interpreters import queues
queue1 = queues.Queue({queue1.id}) queue1 = queues.Queue({queue1.id})
queue2 = queues.Queue({queue2.id}) queue2 = queues.Queue({queue2.id})
assert queue1.qsize() == 1, 'expected: queue1.qsize() == 1' assert queue1.qsize() == 1, 'expected: queue1.qsize() == 1'
obj = queue1.get() obj = queue1.{methname}()
assert queue1.qsize() == 0, 'expected: queue1.qsize() == 0' assert queue1.qsize() == 0, 'expected: queue1.qsize() == 0'
assert obj == b'spam', 'expected: obj == obj1' assert obj == b'spam', 'expected: obj == obj1'
# When going to another interpreter we get a copy. # When going to another interpreter we get a copy.
assert id(obj) != {id(obj1)}, 'expected: obj is not obj1' assert id(obj) != {id(obj1)}, 'expected: obj is not obj1'
obj2 = b'eggs' obj2 = b'eggs'
print(id(obj2)) print(id(obj2))
assert queue2.qsize() == 0, 'expected: queue2.qsize() == 0' assert queue2.qsize() == 0, 'expected: queue2.qsize() == 0'
queue2.put(obj2, syncobj=True) queue2.put(obj2, syncobj=True)
assert queue2.qsize() == 1, 'expected: queue2.qsize() == 1' assert queue2.qsize() == 1, 'expected: queue2.qsize() == 1'
""")) """))
self.assertEqual(len(queues.list_all()), 2) self.assertEqual(len(queues.list_all()), 2)
self.assertEqual(queue1.qsize(), 0) self.assertEqual(queue1.qsize(), 0)
self.assertEqual(queue2.qsize(), 1) self.assertEqual(queue2.qsize(), 1)
obj2 = queue2.get() get = getattr(queue2, methname)
self.assertEqual(obj2, b'eggs') obj2 = get()
self.assertNotEqual(id(obj2), int(out)) self.assertEqual(obj2, b'eggs')
self.assertNotEqual(id(obj2), int(out))
def test_put_cleared_with_subinterpreter(self): def test_put_cleared_with_subinterpreter(self):
interp = interpreters.create() interp = interpreters.create()