mirror of
https://github.com/python/cpython.git
synced 2025-08-04 00:48:58 +00:00
gh-129107: make bytearray
thread safe (#129108)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
parent
d2e60d8e59
commit
a05433f24a
5 changed files with 904 additions and 100 deletions
|
@ -11,12 +11,16 @@ import sys
|
|||
import copy
|
||||
import functools
|
||||
import pickle
|
||||
import sysconfig
|
||||
import tempfile
|
||||
import textwrap
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
import test.support
|
||||
from test import support
|
||||
from test.support import import_helper
|
||||
from test.support import threading_helper
|
||||
from test.support import warnings_helper
|
||||
import test.string_tests
|
||||
import test.list_tests
|
||||
|
@ -2185,5 +2189,336 @@ class BytesSubclassTest(SubclassTest, unittest.TestCase):
|
|||
type2test = BytesSubclass
|
||||
|
||||
|
||||
class FreeThreadingTest(unittest.TestCase):
|
||||
@unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
|
||||
@threading_helper.reap_threads
|
||||
@threading_helper.requires_working_threading()
|
||||
def test_free_threading_bytearray(self):
|
||||
# Test pretty much everything that can break under free-threading.
|
||||
# Non-deterministic, but at least one of these things will fail if
|
||||
# bytearray module is not free-thread safe.
|
||||
|
||||
def clear(b, a, *args): # MODIFIES!
|
||||
b.wait()
|
||||
try: a.clear()
|
||||
except BufferError: pass
|
||||
|
||||
def clear2(b, a, c): # MODIFIES c!
|
||||
b.wait()
|
||||
try: c.clear()
|
||||
except BufferError: pass
|
||||
|
||||
def pop1(b, a): # MODIFIES!
|
||||
b.wait()
|
||||
try: a.pop()
|
||||
except IndexError: pass
|
||||
|
||||
def append1(b, a): # MODIFIES!
|
||||
b.wait()
|
||||
a.append(0)
|
||||
|
||||
def insert1(b, a): # MODIFIES!
|
||||
b.wait()
|
||||
a.insert(0, 0)
|
||||
|
||||
def extend(b, a): # MODIFIES!
|
||||
c = bytearray(b'0' * 0x400000)
|
||||
b.wait()
|
||||
a.extend(c)
|
||||
|
||||
def remove(b, a): # MODIFIES!
|
||||
c = ord('0')
|
||||
b.wait()
|
||||
try: a.remove(c)
|
||||
except ValueError: pass
|
||||
|
||||
def reverse(b, a): # modifies inplace
|
||||
b.wait()
|
||||
a.reverse()
|
||||
|
||||
def reduce(b, a):
|
||||
b.wait()
|
||||
a.__reduce__()
|
||||
|
||||
def reduceex2(b, a):
|
||||
b.wait()
|
||||
a.__reduce_ex__(2)
|
||||
|
||||
def reduceex3(b, a):
|
||||
b.wait()
|
||||
c = a.__reduce_ex__(3)
|
||||
assert not c[1] or 0xdd not in c[1][0]
|
||||
|
||||
def count0(b, a):
|
||||
b.wait()
|
||||
a.count(0)
|
||||
|
||||
def decode(b, a):
|
||||
b.wait()
|
||||
a.decode()
|
||||
|
||||
def find(b, a):
|
||||
c = bytearray(b'0' * 0x40000)
|
||||
b.wait()
|
||||
a.find(c)
|
||||
|
||||
def hex(b, a):
|
||||
b.wait()
|
||||
a.hex('_')
|
||||
|
||||
def join(b, a):
|
||||
b.wait()
|
||||
a.join([b'1', b'2', b'3'])
|
||||
|
||||
def replace(b, a):
|
||||
b.wait()
|
||||
a.replace(b'0', b'')
|
||||
|
||||
def maketrans(b, a, c):
|
||||
b.wait()
|
||||
try: a.maketrans(a, c)
|
||||
except ValueError: pass
|
||||
|
||||
def translate(b, a, c):
|
||||
b.wait()
|
||||
a.translate(c)
|
||||
|
||||
def copy(b, a):
|
||||
b.wait()
|
||||
c = a.copy()
|
||||
if c: assert c[0] == 48 # '0'
|
||||
|
||||
def endswith(b, a):
|
||||
b.wait()
|
||||
assert not a.endswith(b'\xdd')
|
||||
|
||||
def index(b, a):
|
||||
b.wait()
|
||||
try: a.index(b'\xdd')
|
||||
except ValueError: return
|
||||
assert False
|
||||
|
||||
def lstrip(b, a):
|
||||
b.wait()
|
||||
assert not a.lstrip(b'0')
|
||||
|
||||
def partition(b, a):
|
||||
b.wait()
|
||||
assert not a.partition(b'\xdd')[2]
|
||||
|
||||
def removeprefix(b, a):
|
||||
b.wait()
|
||||
assert not a.removeprefix(b'0')
|
||||
|
||||
def removesuffix(b, a):
|
||||
b.wait()
|
||||
assert not a.removesuffix(b'0')
|
||||
|
||||
def rfind(b, a):
|
||||
b.wait()
|
||||
assert a.rfind(b'\xdd') == -1
|
||||
|
||||
def rindex(b, a):
|
||||
b.wait()
|
||||
try: a.rindex(b'\xdd')
|
||||
except ValueError: return
|
||||
assert False
|
||||
|
||||
def rpartition(b, a):
|
||||
b.wait()
|
||||
assert not a.rpartition(b'\xdd')[0]
|
||||
|
||||
def rsplit(b, a):
|
||||
b.wait()
|
||||
assert len(a.rsplit(b'\xdd')) == 1
|
||||
|
||||
def rstrip(b, a):
|
||||
b.wait()
|
||||
assert not a.rstrip(b'0')
|
||||
|
||||
def split(b, a):
|
||||
b.wait()
|
||||
assert len(a.split(b'\xdd')) == 1
|
||||
|
||||
def splitlines(b, a):
|
||||
b.wait()
|
||||
l = len(a.splitlines())
|
||||
assert l > 1 or l == 0
|
||||
|
||||
def startswith(b, a):
|
||||
b.wait()
|
||||
assert not a.startswith(b'\xdd')
|
||||
|
||||
def strip(b, a):
|
||||
b.wait()
|
||||
assert not a.strip(b'0')
|
||||
|
||||
def repeat(b, a):
|
||||
b.wait()
|
||||
a * 2
|
||||
|
||||
def contains(b, a):
|
||||
b.wait()
|
||||
assert 0xdd not in a
|
||||
|
||||
def iconcat(b, a): # MODIFIES!
|
||||
c = bytearray(b'0' * 0x400000)
|
||||
b.wait()
|
||||
a += c
|
||||
|
||||
def irepeat(b, a): # MODIFIES!
|
||||
b.wait()
|
||||
a *= 2
|
||||
|
||||
def subscript(b, a):
|
||||
b.wait()
|
||||
try: assert a[0] != 0xdd
|
||||
except IndexError: pass
|
||||
|
||||
def ass_subscript(b, a): # MODIFIES!
|
||||
c = bytearray(b'0' * 0x400000)
|
||||
b.wait()
|
||||
a[:] = c
|
||||
|
||||
def mod(b, a):
|
||||
c = tuple(range(4096))
|
||||
b.wait()
|
||||
try: a % c
|
||||
except TypeError: pass
|
||||
|
||||
def repr_(b, a):
|
||||
b.wait()
|
||||
repr(a)
|
||||
|
||||
def capitalize(b, a):
|
||||
b.wait()
|
||||
c = a.capitalize()
|
||||
assert not c or c[0] not in (0xdd, 0xcd)
|
||||
|
||||
def center(b, a):
|
||||
b.wait()
|
||||
c = a.center(0x60000)
|
||||
assert not c or c[0x20000] not in (0xdd, 0xcd)
|
||||
|
||||
def expandtabs(b, a):
|
||||
b.wait()
|
||||
c = a.expandtabs()
|
||||
assert not c or c[0] not in (0xdd, 0xcd)
|
||||
|
||||
def ljust(b, a):
|
||||
b.wait()
|
||||
c = a.ljust(0x600000)
|
||||
assert not c or c[0] not in (0xdd, 0xcd)
|
||||
|
||||
def lower(b, a):
|
||||
b.wait()
|
||||
c = a.lower()
|
||||
assert not c or c[0] not in (0xdd, 0xcd)
|
||||
|
||||
def rjust(b, a):
|
||||
b.wait()
|
||||
c = a.rjust(0x600000)
|
||||
assert not c or c[-1] not in (0xdd, 0xcd)
|
||||
|
||||
def swapcase(b, a):
|
||||
b.wait()
|
||||
c = a.swapcase()
|
||||
assert not c or c[-1] not in (0xdd, 0xcd)
|
||||
|
||||
def title(b, a):
|
||||
b.wait()
|
||||
c = a.title()
|
||||
assert not c or c[-1] not in (0xdd, 0xcd)
|
||||
|
||||
def upper(b, a):
|
||||
b.wait()
|
||||
c = a.upper()
|
||||
assert not c or c[-1] not in (0xdd, 0xcd)
|
||||
|
||||
def zfill(b, a):
|
||||
b.wait()
|
||||
c = a.zfill(0x400000)
|
||||
assert not c or c[-1] not in (0xdd, 0xcd)
|
||||
|
||||
def check(funcs, a=None, *args):
|
||||
if a is None:
|
||||
a = bytearray(b'0' * 0x400000)
|
||||
|
||||
barrier = threading.Barrier(len(funcs))
|
||||
threads = []
|
||||
|
||||
for func in funcs:
|
||||
thread = threading.Thread(target=func, args=(barrier, a, *args))
|
||||
|
||||
threads.append(thread)
|
||||
|
||||
with threading_helper.start_threads(threads):
|
||||
pass
|
||||
|
||||
for thread in threads:
|
||||
threading_helper.join_thread(thread)
|
||||
|
||||
# hard errors
|
||||
|
||||
check([clear] + [reduce] * 10)
|
||||
check([clear] + [reduceex2] * 10)
|
||||
check([clear] + [append1] * 10)
|
||||
check([clear] * 10)
|
||||
check([clear] + [count0] * 10)
|
||||
check([clear] + [decode] * 10)
|
||||
check([clear] + [extend] * 10)
|
||||
check([clear] + [find] * 10)
|
||||
check([clear] + [hex] * 10)
|
||||
check([clear] + [insert1] * 10)
|
||||
check([clear] + [join] * 10)
|
||||
check([clear] + [pop1] * 10)
|
||||
check([clear] + [remove] * 10)
|
||||
check([clear] + [replace] * 10)
|
||||
check([clear] + [reverse] * 10)
|
||||
check([clear, clear2] + [maketrans] * 10, bytearray(range(128)), bytearray(range(128)))
|
||||
check([clear] + [translate] * 10, None, bytearray.maketrans(bytearray(range(128)), bytearray(range(128))))
|
||||
|
||||
check([clear] + [repeat] * 10)
|
||||
check([clear] + [iconcat] * 10)
|
||||
check([clear] + [irepeat] * 10)
|
||||
check([clear] + [ass_subscript] * 10)
|
||||
check([clear] + [repr_] * 10)
|
||||
|
||||
# value errors
|
||||
|
||||
check([clear] + [reduceex3] * 10, bytearray(b'a' * 0x40000))
|
||||
check([clear] + [copy] * 10)
|
||||
check([clear] + [endswith] * 10)
|
||||
check([clear] + [index] * 10)
|
||||
check([clear] + [lstrip] * 10)
|
||||
check([clear] + [partition] * 10)
|
||||
check([clear] + [removeprefix] * 10, bytearray(b'0'))
|
||||
check([clear] + [removesuffix] * 10, bytearray(b'0'))
|
||||
check([clear] + [rfind] * 10)
|
||||
check([clear] + [rindex] * 10)
|
||||
check([clear] + [rpartition] * 10)
|
||||
check([clear] + [rsplit] * 10, bytearray(b'0' * 0x4000))
|
||||
check([clear] + [rstrip] * 10)
|
||||
check([clear] + [split] * 10, bytearray(b'0' * 0x4000))
|
||||
check([clear] + [splitlines] * 10, bytearray(b'\n' * 0x400))
|
||||
check([clear] + [startswith] * 10)
|
||||
check([clear] + [strip] * 10)
|
||||
|
||||
check([clear] + [contains] * 10)
|
||||
check([clear] + [subscript] * 10)
|
||||
check([clear] + [mod] * 10, bytearray(b'%d' * 4096))
|
||||
|
||||
check([clear] + [capitalize] * 10, bytearray(b'a' * 0x40000))
|
||||
check([clear] + [center] * 10, bytearray(b'a' * 0x40000))
|
||||
check([clear] + [expandtabs] * 10, bytearray(b'0\t' * 4096))
|
||||
check([clear] + [ljust] * 10, bytearray(b'0' * 0x400000))
|
||||
check([clear] + [lower] * 10, bytearray(b'A' * 0x400000))
|
||||
check([clear] + [rjust] * 10, bytearray(b'0' * 0x400000))
|
||||
check([clear] + [swapcase] * 10, bytearray(b'aA' * 0x200000))
|
||||
check([clear] + [title] * 10, bytearray(b'aA' * 0x200000))
|
||||
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
|
||||
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue