bpo-46615: Don't crash when set operations mutate the sets (GH-31120) (GH-31312)

Ensure strong references are acquired whenever using `set_next()`. Added randomized test cases for `__eq__` methods that sometimes mutate sets when called.

(cherry picked from commit 4a66615ba7)
This commit is contained in:
Dennis Sweeney 2022-02-13 05:29:42 -05:00 committed by GitHub
parent ebe73e6095
commit c31b8a97a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 226 additions and 8 deletions

View file

@ -1799,6 +1799,192 @@ class TestWeirdBugs(unittest.TestCase):
s = {0}
s.update(other)
class TestOperationsMutating:
    """Shared helpers for the bpo-46615 regression tests.

    Subclasses assign ``constructor1`` and ``constructor2``, the callables
    used to build the two containers, and must also provide ``assertIn``
    (normally by mixing in ``unittest.TestCase``).
    """

    constructor1 = None
    constructor2 = None

    def make_sets_of_bad_objects(self):
        """Build two containers of objects whose ``__eq__`` misbehaves.

        The elements compare unequal while the containers are being
        built.  Once both exist, every ``__eq__`` call may clear either
        container and returns a random boolean.  ``__hash__`` always
        returns a random value in ``{0, 1}``.

        Returns the pair ``(first, second)``.
        """
        class Bad:
            def __eq__(self, other):
                if misbehave:
                    # Each container is independently cleared with
                    # probability 1/20 before the random answer is given.
                    if not randrange(20):
                        first.clear()
                    if not randrange(20):
                        second.clear()
                    return randrange(2) == 1
                return False

            def __hash__(self):
                return randrange(2)

        # Stay well-behaved while the containers are being populated.
        misbehave = False
        first_size = randrange(50)
        first = self.constructor1(Bad() for _ in range(first_size))
        second_size = randrange(50)
        second = self.constructor2(Bad() for _ in range(second_size))
        # From here on, comparisons may mutate the containers.
        misbehave = True
        return first, second

    def check_set_op_does_not_crash(self, function):
        """Call ``function(set1, set2)`` on 100 fresh pairs of bad containers.

        A ``RuntimeError`` is tolerated as long as its message mentions
        "changed size during iteration"; any other exception propagates.
        """
        attempts = 100
        for _ in range(attempts):
            operands = self.make_sets_of_bad_objects()
            try:
                function(*operands)
            except RuntimeError as exc:
                # Raising is fine; the point is that the interpreter survives.
                self.assertIn("changed size during iteration", str(exc))
class TestBinaryOpsMutating(TestOperationsMutating):
    """Exercise set operators and iteration on operands that may mutate.

    Every test hands a two-argument callable to
    ``check_set_op_does_not_crash``, which invokes it on freshly built
    pairs of containers whose elements may clear either container during
    comparison.
    """

    # --- Rich comparisons -------------------------------------------------

    def test_eq_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a == b)

    def test_ne_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a != b)

    def test_lt_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a < b)

    def test_le_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a <= b)

    def test_gt_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a > b)

    def test_ge_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a >= b)

    # --- Operators that build a new set -----------------------------------

    def test_and_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a & b)

    def test_or_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a | b)

    def test_sub_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a - b)

    def test_xor_with_mutation(self):
        self.check_set_op_does_not_crash(lambda a, b: a ^ b)

    # --- In-place operators -----------------------------------------------
    # Augmented assignment is a statement, so each case needs a named
    # function instead of a lambda.

    def test_iand_with_mutation(self):
        # ``&=`` is in-place AND; the test name now says so (it used to
        # read "iadd", but no ``+=`` is exercised here).
        def f(a, b):
            a &= b
        self.check_set_op_does_not_crash(f)

    def test_ior_with_mutation(self):
        def f(a, b):
            a |= b
        self.check_set_op_does_not_crash(f)

    def test_isub_with_mutation(self):
        def f(a, b):
            a -= b
        self.check_set_op_does_not_crash(f)

    def test_ixor_with_mutation(self):
        def f(a, b):
            a ^= b
        self.check_set_op_does_not_crash(f)

    # --- Plain iteration --------------------------------------------------

    def test_iteration_with_mutation(self):
        # Iterate b nested inside a.
        def f1(a, b):
            for x in a:
                pass
            for y in b:
                pass

        # Iterate a nested inside b.
        def f2(a, b):
            for y in b:
                pass
            for x in a:
                pass

        # Iterate both containers in lockstep.
        def f3(a, b):
            for x, y in zip(a, b):
                pass

        self.check_set_op_does_not_crash(f1)
        self.check_set_op_does_not_crash(f2)
        self.check_set_op_does_not_crash(f3)
class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
    """Binary-operator tests with both operands built by ``set``."""

    constructor1 = set
    constructor2 = set
class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
    """Binary-operator tests with both operands built by ``SetSubclass``."""

    constructor1 = SetSubclass
    constructor2 = SetSubclass
class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
    """Binary-operator tests: first operand ``set``, second ``SetSubclass``."""

    constructor1 = set
    constructor2 = SetSubclass
class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
    """Binary-operator tests: first operand ``SetSubclass``, second ``set``."""

    constructor1 = SetSubclass
    constructor2 = set
class TestMethodsMutating(TestOperationsMutating):
    """Exercise the named ``set`` methods on operands that may mutate.

    Each test passes the unbound ``set`` method itself to
    ``check_set_op_does_not_crash``, which calls it with the two
    containers built from ``constructor1`` and ``constructor2``.
    """

    # --- Predicates -------------------------------------------------------

    def test_issubset_with_mutation(self):
        self.check_set_op_does_not_crash(set.issubset)

    def test_issuperset_with_mutation(self):
        self.check_set_op_does_not_crash(set.issuperset)

    def test_isdisjoint_with_mutation(self):
        self.check_set_op_does_not_crash(set.isdisjoint)

    # --- Methods that build a new set -------------------------------------

    def test_intersection_with_mutation(self):
        self.check_set_op_does_not_crash(set.intersection)

    def test_union_with_mutation(self):
        self.check_set_op_does_not_crash(set.union)

    def test_difference_with_mutation(self):
        self.check_set_op_does_not_crash(set.difference)

    def test_symmetric_difference_with_mutation(self):
        self.check_set_op_does_not_crash(set.symmetric_difference)

    # --- In-place methods -------------------------------------------------

    def test_difference_update_with_mutation(self):
        self.check_set_op_does_not_crash(set.difference_update)

    def test_intersection_update_with_mutation(self):
        self.check_set_op_does_not_crash(set.intersection_update)

    def test_symmetric_difference_update_with_mutation(self):
        self.check_set_op_does_not_crash(set.symmetric_difference_update)

    def test_update_with_mutation(self):
        self.check_set_op_does_not_crash(set.update)
class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
    """Method tests with both operands built by ``set``."""

    constructor1 = set
    constructor2 = set
class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
    """Method tests with both operands built by ``SetSubclass``."""

    constructor1 = SetSubclass
    constructor2 = SetSubclass
class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
    """Method tests: first operand ``set``, second ``SetSubclass``."""

    constructor1 = set
    constructor2 = SetSubclass
class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
    """Method tests: first operand ``SetSubclass``, second ``set``."""

    constructor1 = SetSubclass
    constructor2 = set
class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
    """Method tests: first operand ``set``, second a dict from ``dict.fromkeys``."""

    constructor1 = set
    constructor2 = dict.fromkeys
class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
    """Method tests: first operand ``set``, second a ``list``."""

    constructor1 = set
    constructor2 = list
# Application tests (based on David Eppstein's graph recipes) ===================================
def powerset(U):

View file

@ -0,0 +1 @@
When iterating over sets internally in ``setobject.c``, acquire strong references to the resulting items from the set. This prevents crashes in corner cases of various set operations where the set gets mutated.

View file

@ -1207,17 +1207,21 @@ set_intersection(PySetObject *so, PyObject *other)
while (set_next((PySetObject *)other, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = set_contains_entry(so, key, hash);
if (rv < 0) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
if (rv) {
if (set_add_entry(result, key, hash)) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
return (PyObject *)result;
}
@ -1357,12 +1361,17 @@ set_isdisjoint(PySetObject *so, PyObject *other)
other = tmp;
}
while (set_next((PySetObject *)other, &pos, &entry)) {
rv = set_contains_entry(so, entry->key, entry->hash);
if (rv < 0)
PyObject *key = entry->key;
Py_INCREF(key);
rv = set_contains_entry(so, key, entry->hash);
Py_DECREF(key);
if (rv < 0) {
return NULL;
if (rv)
}
if (rv) {
Py_RETURN_FALSE;
}
}
Py_RETURN_TRUE;
}
@ -1420,11 +1429,16 @@ set_difference_update_internal(PySetObject *so, PyObject *other)
Py_INCREF(other);
}
while (set_next((PySetObject *)other, &pos, &entry))
if (set_discard_entry(so, entry->key, entry->hash) < 0) {
while (set_next((PySetObject *)other, &pos, &entry)) {
PyObject *key = entry->key;
Py_INCREF(key);
if (set_discard_entry(so, key, entry->hash) < 0) {
Py_DECREF(other);
Py_DECREF(key);
return -1;
}
Py_DECREF(key);
}
Py_DECREF(other);
} else {
@ -1515,17 +1529,21 @@ set_difference(PySetObject *so, PyObject *other)
while (set_next(so, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = _PyDict_Contains(other, key, hash);
if (rv < 0) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
if (!rv) {
if (set_add_entry((PySetObject *)result, key, hash)) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
return result;
}
@ -1534,17 +1552,21 @@ set_difference(PySetObject *so, PyObject *other)
while (set_next(so, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = set_contains_entry((PySetObject *)other, key, hash);
if (rv < 0) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
if (!rv) {
if (set_add_entry((PySetObject *)result, key, hash)) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
return result;
}
@ -1641,17 +1663,21 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
while (set_next(otherset, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = set_discard_entry(so, key, hash);
if (rv < 0) {
Py_DECREF(otherset);
Py_DECREF(key);
return NULL;
}
if (rv == DISCARD_NOTFOUND) {
if (set_add_entry(so, key, hash)) {
Py_DECREF(otherset);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
Py_DECREF(otherset);
Py_RETURN_NONE;
@ -1726,12 +1752,17 @@ set_issubset(PySetObject *so, PyObject *other)
Py_RETURN_FALSE;
while (set_next(so, &pos, &entry)) {
rv = set_contains_entry((PySetObject *)other, entry->key, entry->hash);
if (rv < 0)
PyObject *key = entry->key;
Py_INCREF(key);
rv = set_contains_entry((PySetObject *)other, key, entry->hash);
Py_DECREF(key);
if (rv < 0) {
return NULL;
if (!rv)
}
if (!rv) {
Py_RETURN_FALSE;
}
}
Py_RETURN_TRUE;
}