gh-129107: make bytearray thread safe (#129108)

Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
Tomasz Pytel 2025-02-15 02:19:42 -05:00 committed by GitHub
parent d2e60d8e59
commit a05433f24a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 904 additions and 100 deletions

View file

@ -11,12 +11,16 @@ import sys
import copy
import functools
import pickle
import sysconfig
import tempfile
import textwrap
import threading
import unittest
import test.support
from test import support
from test.support import import_helper
from test.support import threading_helper
from test.support import warnings_helper
import test.string_tests
import test.list_tests
@ -2185,5 +2189,336 @@ class BytesSubclassTest(SubclassTest, unittest.TestCase):
type2test = BytesSubclass
class FreeThreadingTest(unittest.TestCase):
@unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_free_threading_bytearray(self):
# Test pretty much everything that can break under free-threading.
# Non-deterministic, but at least one of these things will fail if
# bytearray module is not free-thread safe.
def clear(b, a, *args): # MODIFIES!
b.wait()
try: a.clear()
except BufferError: pass
def clear2(b, a, c): # MODIFIES c!
b.wait()
try: c.clear()
except BufferError: pass
def pop1(b, a): # MODIFIES!
b.wait()
try: a.pop()
except IndexError: pass
def append1(b, a): # MODIFIES!
b.wait()
a.append(0)
def insert1(b, a): # MODIFIES!
b.wait()
a.insert(0, 0)
def extend(b, a): # MODIFIES!
c = bytearray(b'0' * 0x400000)
b.wait()
a.extend(c)
def remove(b, a): # MODIFIES!
c = ord('0')
b.wait()
try: a.remove(c)
except ValueError: pass
def reverse(b, a): # modifies inplace
b.wait()
a.reverse()
def reduce(b, a):
b.wait()
a.__reduce__()
def reduceex2(b, a):
b.wait()
a.__reduce_ex__(2)
def reduceex3(b, a):
b.wait()
c = a.__reduce_ex__(3)
assert not c[1] or 0xdd not in c[1][0]
def count0(b, a):
b.wait()
a.count(0)
def decode(b, a):
b.wait()
a.decode()
def find(b, a):
c = bytearray(b'0' * 0x40000)
b.wait()
a.find(c)
def hex(b, a):
b.wait()
a.hex('_')
def join(b, a):
b.wait()
a.join([b'1', b'2', b'3'])
def replace(b, a):
b.wait()
a.replace(b'0', b'')
def maketrans(b, a, c):
b.wait()
try: a.maketrans(a, c)
except ValueError: pass
def translate(b, a, c):
b.wait()
a.translate(c)
def copy(b, a):
b.wait()
c = a.copy()
if c: assert c[0] == 48 # '0'
def endswith(b, a):
b.wait()
assert not a.endswith(b'\xdd')
def index(b, a):
b.wait()
try: a.index(b'\xdd')
except ValueError: return
assert False
def lstrip(b, a):
b.wait()
assert not a.lstrip(b'0')
def partition(b, a):
b.wait()
assert not a.partition(b'\xdd')[2]
def removeprefix(b, a):
b.wait()
assert not a.removeprefix(b'0')
def removesuffix(b, a):
b.wait()
assert not a.removesuffix(b'0')
def rfind(b, a):
b.wait()
assert a.rfind(b'\xdd') == -1
def rindex(b, a):
b.wait()
try: a.rindex(b'\xdd')
except ValueError: return
assert False
def rpartition(b, a):
b.wait()
assert not a.rpartition(b'\xdd')[0]
def rsplit(b, a):
b.wait()
assert len(a.rsplit(b'\xdd')) == 1
def rstrip(b, a):
b.wait()
assert not a.rstrip(b'0')
def split(b, a):
b.wait()
assert len(a.split(b'\xdd')) == 1
def splitlines(b, a):
b.wait()
l = len(a.splitlines())
assert l > 1 or l == 0
def startswith(b, a):
b.wait()
assert not a.startswith(b'\xdd')
def strip(b, a):
b.wait()
assert not a.strip(b'0')
def repeat(b, a):
b.wait()
a * 2
def contains(b, a):
b.wait()
assert 0xdd not in a
def iconcat(b, a): # MODIFIES!
c = bytearray(b'0' * 0x400000)
b.wait()
a += c
def irepeat(b, a): # MODIFIES!
b.wait()
a *= 2
def subscript(b, a):
b.wait()
try: assert a[0] != 0xdd
except IndexError: pass
def ass_subscript(b, a): # MODIFIES!
c = bytearray(b'0' * 0x400000)
b.wait()
a[:] = c
def mod(b, a):
c = tuple(range(4096))
b.wait()
try: a % c
except TypeError: pass
def repr_(b, a):
b.wait()
repr(a)
def capitalize(b, a):
b.wait()
c = a.capitalize()
assert not c or c[0] not in (0xdd, 0xcd)
def center(b, a):
b.wait()
c = a.center(0x60000)
assert not c or c[0x20000] not in (0xdd, 0xcd)
def expandtabs(b, a):
b.wait()
c = a.expandtabs()
assert not c or c[0] not in (0xdd, 0xcd)
def ljust(b, a):
b.wait()
c = a.ljust(0x600000)
assert not c or c[0] not in (0xdd, 0xcd)
def lower(b, a):
b.wait()
c = a.lower()
assert not c or c[0] not in (0xdd, 0xcd)
def rjust(b, a):
b.wait()
c = a.rjust(0x600000)
assert not c or c[-1] not in (0xdd, 0xcd)
def swapcase(b, a):
b.wait()
c = a.swapcase()
assert not c or c[-1] not in (0xdd, 0xcd)
def title(b, a):
b.wait()
c = a.title()
assert not c or c[-1] not in (0xdd, 0xcd)
def upper(b, a):
b.wait()
c = a.upper()
assert not c or c[-1] not in (0xdd, 0xcd)
def zfill(b, a):
b.wait()
c = a.zfill(0x400000)
assert not c or c[-1] not in (0xdd, 0xcd)
def check(funcs, a=None, *args):
if a is None:
a = bytearray(b'0' * 0x400000)
barrier = threading.Barrier(len(funcs))
threads = []
for func in funcs:
thread = threading.Thread(target=func, args=(barrier, a, *args))
threads.append(thread)
with threading_helper.start_threads(threads):
pass
for thread in threads:
threading_helper.join_thread(thread)
# hard errors
check([clear] + [reduce] * 10)
check([clear] + [reduceex2] * 10)
check([clear] + [append1] * 10)
check([clear] * 10)
check([clear] + [count0] * 10)
check([clear] + [decode] * 10)
check([clear] + [extend] * 10)
check([clear] + [find] * 10)
check([clear] + [hex] * 10)
check([clear] + [insert1] * 10)
check([clear] + [join] * 10)
check([clear] + [pop1] * 10)
check([clear] + [remove] * 10)
check([clear] + [replace] * 10)
check([clear] + [reverse] * 10)
check([clear, clear2] + [maketrans] * 10, bytearray(range(128)), bytearray(range(128)))
check([clear] + [translate] * 10, None, bytearray.maketrans(bytearray(range(128)), bytearray(range(128))))
check([clear] + [repeat] * 10)
check([clear] + [iconcat] * 10)
check([clear] + [irepeat] * 10)
check([clear] + [ass_subscript] * 10)
check([clear] + [repr_] * 10)
# value errors
check([clear] + [reduceex3] * 10, bytearray(b'a' * 0x40000))
check([clear] + [copy] * 10)
check([clear] + [endswith] * 10)
check([clear] + [index] * 10)
check([clear] + [lstrip] * 10)
check([clear] + [partition] * 10)
check([clear] + [removeprefix] * 10, bytearray(b'0'))
check([clear] + [removesuffix] * 10, bytearray(b'0'))
check([clear] + [rfind] * 10)
check([clear] + [rindex] * 10)
check([clear] + [rpartition] * 10)
check([clear] + [rsplit] * 10, bytearray(b'0' * 0x4000))
check([clear] + [rstrip] * 10)
check([clear] + [split] * 10, bytearray(b'0' * 0x4000))
check([clear] + [splitlines] * 10, bytearray(b'\n' * 0x400))
check([clear] + [startswith] * 10)
check([clear] + [strip] * 10)
check([clear] + [contains] * 10)
check([clear] + [subscript] * 10)
check([clear] + [mod] * 10, bytearray(b'%d' * 4096))
check([clear] + [capitalize] * 10, bytearray(b'a' * 0x40000))
check([clear] + [center] * 10, bytearray(b'a' * 0x40000))
check([clear] + [expandtabs] * 10, bytearray(b'0\t' * 4096))
check([clear] + [ljust] * 10, bytearray(b'0' * 0x400000))
check([clear] + [lower] * 10, bytearray(b'A' * 0x400000))
check([clear] + [rjust] * 10, bytearray(b'0' * 0x400000))
check([clear] + [swapcase] * 10, bytearray(b'aA' * 0x200000))
check([clear] + [title] * 10, bytearray(b'aA' * 0x200000))
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))
if __name__ == "__main__":
unittest.main()