gh-91576: Speed up iteration of strings (#91574)

This commit is contained in:
Kumar Aditya 2022-04-18 19:48:27 +05:30 committed by GitHub
parent a29f858124
commit 8c54c3dacc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 6 deletions

View file

@ -20,6 +20,7 @@ extern void _PyUnicode_Fini(PyInterpreterState *);
extern void _PyUnicode_FiniTypes(PyInterpreterState *); extern void _PyUnicode_FiniTypes(PyInterpreterState *);
extern void _PyStaticUnicode_Dealloc(PyObject *); extern void _PyStaticUnicode_Dealloc(PyObject *);
extern PyTypeObject _PyUnicodeASCIIIter_Type;
/* other API */ /* other API */

View file

@ -9,6 +9,7 @@ import _string
import codecs import codecs
import itertools import itertools
import operator import operator
import pickle
import struct import struct
import sys import sys
import textwrap import textwrap
@ -185,6 +186,36 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual(next(it), "\u3333") self.assertEqual(next(it), "\u3333")
self.assertRaises(StopIteration, next, it) self.assertRaises(StopIteration, next, it)
def test_iterators_invocation(self):
cases = [type(iter('abc')), type(iter('🚀'))]
for cls in cases:
with self.subTest(cls=cls):
self.assertRaises(TypeError, cls)
def test_iteration(self):
cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"]
for case in cases:
with self.subTest(string=case):
self.assertEqual(case, "".join(iter(case)))
def test_exhausted_iterator(self):
cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"]
for case in cases:
with self.subTest(case=case):
iterator = iter(case)
tuple(iterator)
self.assertRaises(StopIteration, next, iterator)
def test_pickle_iterator(self):
cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"]
for case in cases:
with self.subTest(case=case):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
it = iter(case)
with self.subTest(proto=proto):
pickled = "".join(pickle.loads(pickle.dumps(it, proto)))
self.assertEqual(case, pickled)
def test_count(self): def test_count(self):
string_tests.CommonTest.test_count(self) string_tests.CommonTest.test_count(self)
# check mixed argument types # check mixed argument types

View file

@ -0,0 +1 @@
Speed up iteration of ASCII strings by 50%. Patch by Kumar Aditya.

View file

@ -1936,6 +1936,7 @@ static PyTypeObject* static_types[] = {
&_PyNamespace_Type, &_PyNamespace_Type,
&_PyNone_Type, &_PyNone_Type,
&_PyNotImplemented_Type, &_PyNotImplemented_Type,
&_PyUnicodeASCIIIter_Type,
&_PyUnion_Type, &_PyUnion_Type,
&_PyWeakref_CallableProxyType, &_PyWeakref_CallableProxyType,
&_PyWeakref_ProxyType, &_PyWeakref_ProxyType,

View file

@ -15697,7 +15697,7 @@ unicodeiter_traverse(unicodeiterobject *it, visitproc visit, void *arg)
static PyObject * static PyObject *
unicodeiter_next(unicodeiterobject *it) unicodeiter_next(unicodeiterobject *it)
{ {
PyObject *seq, *item; PyObject *seq;
assert(it != NULL); assert(it != NULL);
seq = it->it_seq; seq = it->it_seq;
@ -15709,10 +15709,8 @@ unicodeiter_next(unicodeiterobject *it)
int kind = PyUnicode_KIND(seq); int kind = PyUnicode_KIND(seq);
const void *data = PyUnicode_DATA(seq); const void *data = PyUnicode_DATA(seq);
Py_UCS4 chr = PyUnicode_READ(kind, data, it->it_index); Py_UCS4 chr = PyUnicode_READ(kind, data, it->it_index);
item = PyUnicode_FromOrdinal(chr); it->it_index++;
if (item != NULL) return unicode_char(chr);
++it->it_index;
return item;
} }
it->it_seq = NULL; it->it_seq = NULL;
@ -15720,6 +15718,29 @@ unicodeiter_next(unicodeiterobject *it)
return NULL; return NULL;
} }
/* tp_iternext implementation used by _PyUnicodeASCIIIter_Type.

   Returns a new reference to the entry of the _Py_SINGLETON(strings).ascii
   table indexed by the character at it->it_index, then advances the index.
   Returns NULL when it->it_seq is already NULL or when the index has
   reached the length of the string; in the latter case the iterator's
   reference to the string is released and it_seq is cleared. */
static PyObject *
unicode_ascii_iter_next(unicodeiterobject *it)
{
    assert(it != NULL);
    PyObject *seq = it->it_seq;
    if (seq == NULL) {
        /* it_seq was cleared by an earlier call: nothing left to yield. */
        return NULL;
    }
    assert(_PyUnicode_CHECK(seq));
    assert(PyUnicode_IS_COMPACT_ASCII(seq));

    if (it->it_index >= PyUnicode_GET_LENGTH(seq)) {
        /* End reached: clear it_seq before dropping the reference so the
           iterator never points at a string it no longer owns. */
        it->it_seq = NULL;
        Py_DECREF(seq);
        return NULL;
    }

    /* The character data is read from the memory located directly after
       the PyASCIIObject struct of the string. */
    const void *data = ((void*)(_PyASCIIObject_CAST(seq) + 1));
    Py_UCS1 chr = (Py_UCS1)PyUnicode_READ(PyUnicode_1BYTE_KIND,
                                          data, it->it_index);
    it->it_index++;
    return Py_NewRef((PyObject*)&_Py_SINGLETON(strings).ascii[chr]);
}
static PyObject * static PyObject *
unicodeiter_len(unicodeiterobject *it, PyObject *Py_UNUSED(ignored)) unicodeiter_len(unicodeiterobject *it, PyObject *Py_UNUSED(ignored))
{ {
@ -15808,6 +15829,19 @@ PyTypeObject PyUnicodeIter_Type = {
0, 0,
}; };
/* Type object named "str_ascii_iterator".  unicode_iter() allocates an
   instance of this type when PyUnicode_IS_COMPACT_ASCII(seq) is true.
   Instances use the unicodeiterobject struct, and the type reuses
   unicodeiter_dealloc, unicodeiter_traverse and unicodeiter_methods;
   its tp_iternext slot is unicode_ascii_iter_next(). */
PyTypeObject _PyUnicodeASCIIIter_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
.tp_name = "str_ascii_iterator",
.tp_basicsize = sizeof(unicodeiterobject),
.tp_dealloc = (destructor)unicodeiter_dealloc,
.tp_getattro = PyObject_GenericGetAttr,
/* Py_TPFLAGS_HAVE_GC goes together with the tp_traverse slot below and
   with the PyObject_GC_New() allocation in unicode_iter(). */
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
.tp_traverse = (traverseproc)unicodeiter_traverse,
/* iter(iterator) returns the iterator itself. */
.tp_iter = PyObject_SelfIter,
.tp_iternext = (iternextfunc)unicode_ascii_iter_next,
.tp_methods = unicodeiter_methods,
};
static PyObject * static PyObject *
unicode_iter(PyObject *seq) unicode_iter(PyObject *seq)
{ {
@ -15819,7 +15853,12 @@ unicode_iter(PyObject *seq)
} }
if (PyUnicode_READY(seq) == -1) if (PyUnicode_READY(seq) == -1)
return NULL; return NULL;
it = PyObject_GC_New(unicodeiterobject, &PyUnicodeIter_Type); if (PyUnicode_IS_COMPACT_ASCII(seq)) {
it = PyObject_GC_New(unicodeiterobject, &_PyUnicodeASCIIIter_Type);
}
else {
it = PyObject_GC_New(unicodeiterobject, &PyUnicodeIter_Type);
}
if (it == NULL) if (it == NULL)
return NULL; return NULL;
it->it_index = 0; it->it_index = 0;