gh-97982: Factorize PyUnicode_Count() and unicode_count() code (#98025)

Add unicode_count_impl() so that PyUnicode_Count() and unicode_count()
share a single implementation instead of duplicating the counting code.
This commit is contained in:
Nikita Sobolev 2022-10-12 19:27:53 +03:00 committed by GitHub
parent e9569ec43e
commit ccab67ba79
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 60 deletions

View file

@ -241,6 +241,10 @@ class UnicodeTest(string_tests.CommonTest,
self.checkequal(0, 'a' * 10, 'count', 'a\u0102') self.checkequal(0, 'a' * 10, 'count', 'a\u0102')
self.checkequal(0, 'a' * 10, 'count', 'a\U00100304') self.checkequal(0, 'a' * 10, 'count', 'a\U00100304')
self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304') self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304')
# test subclass
class MyStr(str):
pass
self.checkequal(3, MyStr('aaa'), 'count', 'a')
def test_find(self): def test_find(self):
string_tests.CommonTest.test_find(self) string_tests.CommonTest.test_find(self)
@ -3002,6 +3006,12 @@ class CAPITest(unittest.TestCase):
self.assertEqual(unicode_count(uni, ch, 0, len(uni)), 1) self.assertEqual(unicode_count(uni, ch, 0, len(uni)), 1)
self.assertEqual(unicode_count(st, ch, 0, len(st)), 0) self.assertEqual(unicode_count(st, ch, 0, len(st)), 0)
# subclasses should still work
class MyStr(str):
pass
self.assertEqual(unicode_count(MyStr('aab'), 'a', 0, 3), 2)
# Test PyUnicode_FindChar() # Test PyUnicode_FindChar()
@support.cpython_only @support.cpython_only
@unittest.skipIf(_testcapi is None, 'need _testcapi module') @unittest.skipIf(_testcapi is None, 'need _testcapi module')

View file

@ -8964,21 +8964,20 @@ _PyUnicode_InsertThousandsGrouping(
return count; return count;
} }
static Py_ssize_t
Py_ssize_t unicode_count_impl(PyObject *str,
PyUnicode_Count(PyObject *str, PyObject *substr,
PyObject *substr, Py_ssize_t start,
Py_ssize_t start, Py_ssize_t end)
Py_ssize_t end)
{ {
assert(PyUnicode_Check(str));
assert(PyUnicode_Check(substr));
Py_ssize_t result; Py_ssize_t result;
int kind1, kind2; int kind1, kind2;
const void *buf1 = NULL, *buf2 = NULL; const void *buf1 = NULL, *buf2 = NULL;
Py_ssize_t len1, len2; Py_ssize_t len1, len2;
if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0)
return -1;
kind1 = PyUnicode_KIND(str); kind1 = PyUnicode_KIND(str);
kind2 = PyUnicode_KIND(substr); kind2 = PyUnicode_KIND(substr);
if (kind1 < kind2) if (kind1 < kind2)
@ -8998,6 +8997,7 @@ PyUnicode_Count(PyObject *str,
goto onError; goto onError;
} }
// We don't reuse `anylib_count` here because of the explicit casts.
switch (kind1) { switch (kind1) {
case PyUnicode_1BYTE_KIND: case PyUnicode_1BYTE_KIND:
result = ucs1lib_count( result = ucs1lib_count(
@ -9033,6 +9033,18 @@ PyUnicode_Count(PyObject *str,
return -1; return -1;
} }
/* Count occurrences of `substr` in `str` within [start, end).
 *
 * Both arguments are checked with ensure_unicode() before any work is
 * done; if either check fails, -1 is returned. Otherwise the counting
 * is delegated to unicode_count_impl() and its result is returned
 * unchanged.
 */
Py_ssize_t
PyUnicode_Count(PyObject *str,
                PyObject *substr,
                Py_ssize_t start,
                Py_ssize_t end)
{
    /* Validate `str` first; `substr` is only checked when `str` passed,
       matching the short-circuit order of a combined `||` test. */
    if (ensure_unicode(str) < 0) {
        return -1;
    }
    if (ensure_unicode(substr) < 0) {
        return -1;
    }

    return unicode_count_impl(str, substr, start, end);
}
Py_ssize_t Py_ssize_t
PyUnicode_Find(PyObject *str, PyUnicode_Find(PyObject *str,
PyObject *substr, PyObject *substr,
@ -10848,62 +10860,16 @@ unicode_count(PyObject *self, PyObject *args)
PyObject *substring = NULL; /* initialize to fix a compiler warning */ PyObject *substring = NULL; /* initialize to fix a compiler warning */
Py_ssize_t start = 0; Py_ssize_t start = 0;
Py_ssize_t end = PY_SSIZE_T_MAX; Py_ssize_t end = PY_SSIZE_T_MAX;
PyObject *result; Py_ssize_t result;
int kind1, kind2;
const void *buf1, *buf2;
Py_ssize_t len1, len2, iresult;
if (!parse_args_finds_unicode("count", args, &substring, &start, &end)) if (!parse_args_finds_unicode("count", args, &substring, &start, &end))
return NULL; return NULL;
kind1 = PyUnicode_KIND(self); result = unicode_count_impl(self, substring, start, end);
kind2 = PyUnicode_KIND(substring); if (result == -1)
if (kind1 < kind2) return NULL;
return PyLong_FromLong(0);
len1 = PyUnicode_GET_LENGTH(self); return PyLong_FromSsize_t(result);
len2 = PyUnicode_GET_LENGTH(substring);
ADJUST_INDICES(start, end, len1);
if (end - start < len2)
return PyLong_FromLong(0);
buf1 = PyUnicode_DATA(self);
buf2 = PyUnicode_DATA(substring);
if (kind2 != kind1) {
buf2 = unicode_askind(kind2, buf2, len2, kind1);
if (!buf2)
return NULL;
}
switch (kind1) {
case PyUnicode_1BYTE_KIND:
iresult = ucs1lib_count(
((const Py_UCS1*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
case PyUnicode_2BYTE_KIND:
iresult = ucs2lib_count(
((const Py_UCS2*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
case PyUnicode_4BYTE_KIND:
iresult = ucs4lib_count(
((const Py_UCS4*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
default:
Py_UNREACHABLE();
}
result = PyLong_FromSsize_t(iresult);
assert((kind2 == kind1) == (buf2 == PyUnicode_DATA(substring)));
if (kind2 != kind1)
PyMem_Free((void *)buf2);
return result;
} }
/*[clinic input] /*[clinic input]