gh-123619: Add an unstable C API function for enabling deferred reference counting (GH-123635)

Co-authored-by: Sam Gross <colesbury@gmail.com>
This commit is contained in:
Peter Bierma 2024-11-13 08:27:16 -05:00 committed by GitHub
parent 29b5323c45
commit d00878b06a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 128 additions and 1 deletions

View file

@ -1,10 +1,13 @@
import enum
import unittest
from test import support
from test.support import import_helper
from test.support import os_helper
from test.support import threading_helper
_testlimitedcapi = import_helper.import_module('_testlimitedcapi')
_testcapi = import_helper.import_module('_testcapi')
_testinternalcapi = import_helper.import_module('_testinternalcapi')
class Constant(enum.IntEnum):
@ -131,5 +134,48 @@ class ClearWeakRefsNoCallbacksTest(unittest.TestCase):
_testcapi.pyobject_clear_weakrefs_no_callbacks(obj)
class EnableDeferredRefcountingTest(unittest.TestCase):
"""Test PyUnstable_Object_EnableDeferredRefcount"""
@support.requires_resource("cpu")
def test_enable_deferred_refcount(self):
from threading import Thread
self.assertEqual(_testcapi.pyobject_enable_deferred_refcount("not tracked"), 0)
foo = []
self.assertEqual(_testcapi.pyobject_enable_deferred_refcount(foo), int(support.Py_GIL_DISABLED))
# Make sure reference counting works on foo now
self.assertEqual(foo, [])
if support.Py_GIL_DISABLED:
self.assertTrue(_testinternalcapi.has_deferred_refcount(foo))
# Make sure that PyUnstable_Object_EnableDeferredRefcount is thread safe
def silly_func(obj):
self.assertIn(
_testcapi.pyobject_enable_deferred_refcount(obj),
(0, 1)
)
silly_list = [1, 2, 3]
threads = [
Thread(target=silly_func, args=(silly_list,)) for _ in range(5)
]
with threading_helper.catch_threading_exception() as cm:
for t in threads:
t.start()
for i in range(10):
silly_list.append(i)
for t in threads:
t.join()
self.assertIsNone(cm.exc_value)
if support.Py_GIL_DISABLED:
self.assertTrue(_testinternalcapi.has_deferred_refcount(silly_list))
if __name__ == "__main__":
unittest.main()