Issue #26494: Fixed crash on iterating exhausting iterators.

Affected classes are generic sequence iterators, iterators of str, bytes,
bytearray, list, tuple, set, frozenset, dict, OrderedDict, corresponding
views and os.scandir() iterator.
This commit is contained in:
Serhiy Storchaka 2016-03-30 20:41:15 +03:00
commit ab479c49d3
19 changed files with 94 additions and 24 deletions

View file

@ -5,6 +5,7 @@ Tests common to tuple, list and UserList.UserList
import unittest import unittest
import sys import sys
import pickle import pickle
from test import support
# Various iterables # Various iterables
# This is used for checking the constructor (here and in test_deque.py) # This is used for checking the constructor (here and in test_deque.py)
@ -408,3 +409,7 @@ class CommonTest(unittest.TestCase):
lst2 = pickle.loads(pickle.dumps(lst, proto)) lst2 = pickle.loads(pickle.dumps(lst, proto))
self.assertEqual(lst2, lst) self.assertEqual(lst2, lst)
self.assertNotEqual(id(lst2), id(lst)) self.assertNotEqual(id(lst2), id(lst))
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, self.type2test)
support.check_free_after_iterating(self, reversed, self.type2test)

View file

@ -2432,3 +2432,22 @@ def run_in_subinterp(code):
"memory allocations") "memory allocations")
import _testcapi import _testcapi
return _testcapi.run_in_subinterp(code) return _testcapi.run_in_subinterp(code)
def check_free_after_iterating(test, iter, cls, args=()):
class A(cls):
def __del__(self):
nonlocal done
done = True
try:
next(it)
except StopIteration:
pass
done = False
it = iter(A(*args))
# Issue 26494: Shouldn't crash
test.assertRaises(StopIteration, next, it)
# The sequence should be deallocated just after the end of iterating
gc_collect()
test.assertTrue(done)

View file

@ -761,6 +761,10 @@ class BaseBytesTest:
self.assertRaisesRegex(TypeError, r'\bendswith\b', b.endswith, self.assertRaisesRegex(TypeError, r'\bendswith\b', b.endswith,
x, None, None, None) x, None, None, None)
def test_free_after_iterating(self):
test.support.check_free_after_iterating(self, iter, self.type2test)
test.support.check_free_after_iterating(self, reversed, self.type2test)
class BytesTest(BaseBytesTest, unittest.TestCase): class BytesTest(BaseBytesTest, unittest.TestCase):
type2test = bytes type2test = bytes

View file

@ -918,6 +918,10 @@ class TestSequence(seq_tests.CommonTest):
# For now, bypass tests that require slicing # For now, bypass tests that require slicing
pass pass
def test_free_after_iterating(self):
# For now, bypass tests that require slicing
self.skipTest("Exhausted deque iterator doesn't free a deque")
#============================================================================== #==============================================================================
libreftest = """ libreftest = """

View file

@ -954,6 +954,12 @@ class DictTest(unittest.TestCase):
d = {X(): 0, 1: 1} d = {X(): 0, 1: 1}
self.assertRaises(RuntimeError, d.update, other) self.assertRaises(RuntimeError, d.update, other)
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, dict)
support.check_free_after_iterating(self, lambda d: iter(d.keys()), dict)
support.check_free_after_iterating(self, lambda d: iter(d.values()), dict)
support.check_free_after_iterating(self, lambda d: iter(d.items()), dict)
from test import mapping_tests from test import mapping_tests
class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):

View file

@ -3,6 +3,7 @@
import sys import sys
import unittest import unittest
from test.support import run_unittest, TESTFN, unlink, cpython_only from test.support import run_unittest, TESTFN, unlink, cpython_only
from test.support import check_free_after_iterating
import pickle import pickle
import collections.abc import collections.abc
@ -980,6 +981,9 @@ class TestCase(unittest.TestCase):
self.assertEqual(next(it), 0) self.assertEqual(next(it), 0)
self.assertEqual(next(it), 1) self.assertEqual(next(it), 1)
def test_free_after_iterating(self):
check_free_after_iterating(self, iter, SequenceClass, (0,))
def test_main(): def test_main():
run_unittest(TestCase) run_unittest(TestCase)

View file

@ -608,6 +608,12 @@ class OrderedDictTests:
gc.collect() gc.collect()
self.assertIsNone(r()) self.assertIsNone(r())
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, self.OrderedDict)
support.check_free_after_iterating(self, lambda d: iter(d.keys()), self.OrderedDict)
support.check_free_after_iterating(self, lambda d: iter(d.values()), self.OrderedDict)
support.check_free_after_iterating(self, lambda d: iter(d.items()), self.OrderedDict)
class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):

View file

@ -364,6 +364,9 @@ class TestJointOps:
gc.collect() gc.collect()
self.assertTrue(ref() is None, "Cycle was not collected") self.assertTrue(ref() is None, "Cycle was not collected")
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, self.thetype)
class TestSet(TestJointOps, unittest.TestCase): class TestSet(TestJointOps, unittest.TestCase):
thetype = set thetype = set
basetype = set basetype = set

View file

@ -2729,6 +2729,10 @@ class UnicodeTest(string_tests.CommonTest,
# Check that the second call returns the same result # Check that the second call returns the same result
self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1)) self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1))
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, str)
support.check_free_after_iterating(self, reversed, str)
class StringModuleTest(unittest.TestCase): class StringModuleTest(unittest.TestCase):
def test_formatter_parser(self): def test_formatter_parser(self):

View file

@ -10,6 +10,11 @@ Release date: tba
Core and Builtins Core and Builtins
----------------- -----------------
- Issue #26494: Fixed crash on iterating exhausting iterators.
Affected classes are generic sequence iterators, iterators of str, bytes,
bytearray, list, tuple, set, frozenset, dict, OrderedDict, corresponding
views and os.scandir() iterator.
- Issue #26574: Optimize ``bytes.replace(b'', b'.')`` and - Issue #26574: Optimize ``bytes.replace(b'', b'.')`` and
``bytearray.replace(b'', b'.')``. Patch written by Josh Snider. ``bytearray.replace(b'', b'.')``. Patch written by Josh Snider.

View file

@ -11956,13 +11956,15 @@ ScandirIterator_is_closed(ScandirIterator *iterator)
static void static void
ScandirIterator_closedir(ScandirIterator *iterator) ScandirIterator_closedir(ScandirIterator *iterator)
{ {
if (iterator->handle == INVALID_HANDLE_VALUE) HANDLE handle = iterator->handle;
if (handle == INVALID_HANDLE_VALUE)
return; return;
Py_BEGIN_ALLOW_THREADS
FindClose(iterator->handle);
Py_END_ALLOW_THREADS
iterator->handle = INVALID_HANDLE_VALUE; iterator->handle = INVALID_HANDLE_VALUE;
Py_BEGIN_ALLOW_THREADS
FindClose(handle);
Py_END_ALLOW_THREADS
} }
static PyObject * static PyObject *
@ -12018,13 +12020,15 @@ ScandirIterator_is_closed(ScandirIterator *iterator)
static void static void
ScandirIterator_closedir(ScandirIterator *iterator) ScandirIterator_closedir(ScandirIterator *iterator)
{ {
if (!iterator->dirp) DIR *dirp = iterator->dirp;
if (!dirp)
return; return;
Py_BEGIN_ALLOW_THREADS
closedir(iterator->dirp);
Py_END_ALLOW_THREADS
iterator->dirp = NULL; iterator->dirp = NULL;
Py_BEGIN_ALLOW_THREADS
closedir(dirp);
Py_END_ALLOW_THREADS
return; return;
} }

View file

@ -3126,8 +3126,8 @@ bytearrayiter_next(bytesiterobject *it)
return item; return item;
} }
Py_DECREF(seq);
it->it_seq = NULL; it->it_seq = NULL;
Py_DECREF(seq);
return NULL; return NULL;
} }

View file

@ -3806,8 +3806,8 @@ striter_next(striterobject *it)
return item; return item;
} }
Py_DECREF(seq);
it->it_seq = NULL; it->it_seq = NULL;
Py_DECREF(seq);
return NULL; return NULL;
} }

View file

@ -2988,8 +2988,8 @@ static PyObject *dictiter_iternextkey(dictiterobject *di)
return key; return key;
fail: fail:
Py_DECREF(d);
di->di_dict = NULL; di->di_dict = NULL;
Py_DECREF(d);
return NULL; return NULL;
} }
@ -3069,8 +3069,8 @@ static PyObject *dictiter_iternextvalue(dictiterobject *di)
return value; return value;
fail: fail:
Py_DECREF(d);
di->di_dict = NULL; di->di_dict = NULL;
Py_DECREF(d);
return NULL; return NULL;
} }
@ -3164,8 +3164,8 @@ static PyObject *dictiter_iternextitem(dictiterobject *di)
return result; return result;
fail: fail:
Py_DECREF(d);
di->di_dict = NULL; di->di_dict = NULL;
Py_DECREF(d);
return NULL; return NULL;
} }

View file

@ -69,8 +69,8 @@ iter_iternext(PyObject *iterator)
PyErr_ExceptionMatches(PyExc_StopIteration)) PyErr_ExceptionMatches(PyExc_StopIteration))
{ {
PyErr_Clear(); PyErr_Clear();
Py_DECREF(seq);
it->it_seq = NULL; it->it_seq = NULL;
Py_DECREF(seq);
} }
return NULL; return NULL;
} }

View file

@ -2776,8 +2776,8 @@ listiter_next(listiterobject *it)
return item; return item;
} }
Py_DECREF(seq);
it->it_seq = NULL; it->it_seq = NULL;
Py_DECREF(seq);
return NULL; return NULL;
} }
@ -2906,9 +2906,17 @@ static PyObject *
listreviter_next(listreviterobject *it) listreviter_next(listreviterobject *it)
{ {
PyObject *item; PyObject *item;
Py_ssize_t index = it->it_index; Py_ssize_t index;
PyListObject *seq = it->it_seq; PyListObject *seq;
assert(it != NULL);
seq = it->it_seq;
if (seq == NULL) {
return NULL;
}
assert(PyList_Check(seq));
index = it->it_index;
if (index>=0 && index < PyList_GET_SIZE(seq)) { if (index>=0 && index < PyList_GET_SIZE(seq)) {
item = PyList_GET_ITEM(seq, index); item = PyList_GET_ITEM(seq, index);
it->it_index--; it->it_index--;
@ -2916,10 +2924,8 @@ listreviter_next(listreviterobject *it)
return item; return item;
} }
it->it_index = -1; it->it_index = -1;
if (seq != NULL) {
it->it_seq = NULL; it->it_seq = NULL;
Py_DECREF(seq); Py_DECREF(seq);
}
return NULL; return NULL;
} }

View file

@ -916,8 +916,8 @@ static PyObject *setiter_iternext(setiterobject *si)
return key; return key;
fail: fail:
Py_DECREF(so);
si->si_set = NULL; si->si_set = NULL;
Py_DECREF(so);
return NULL; return NULL;
} }

View file

@ -961,8 +961,8 @@ tupleiter_next(tupleiterobject *it)
return item; return item;
} }
Py_DECREF(seq);
it->it_seq = NULL; it->it_seq = NULL;
Py_DECREF(seq);
return NULL; return NULL;
} }

View file

@ -15401,8 +15401,8 @@ unicodeiter_next(unicodeiterobject *it)
return item; return item;
} }
Py_DECREF(seq);
it->it_seq = NULL; it->it_seq = NULL;
Py_DECREF(seq);
return NULL; return NULL;
} }