gh-113202: Add a strict option to itertools.batched() (gh-113203)

Raymond Hettinger 2023-12-16 09:13:50 -06:00 committed by GitHub
parent fe479fb8a9
commit 1583c40be9
5 changed files with 60 additions and 24 deletions

View file

@@ -164,11 +164,14 @@ loops that truncate the stream.
      Added the optional *initial* parameter.

-.. function:: batched(iterable, n)
+.. function:: batched(iterable, n, *, strict=False)

   Batch data from the *iterable* into tuples of length *n*. The last
   batch may be shorter than *n*.

+   If *strict* is true, will raise a :exc:`ValueError` if the final
+   batch is shorter than *n*.
+
   Loops over the input iterable and accumulates data into tuples up to
   size *n*. The input is consumed lazily, just enough to fill a batch.
   The result is yielded as soon as the batch is full or when the input
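
A quick illustration of the documented behavior (a minimal sketch, assuming Python 3.13+ where the strict flag exists):

    from itertools import batched

    # Default: the shorter final batch is yielded as-is.
    print(list(batched('ABCDEFG', 3)))    # [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    # strict=True: an incomplete final batch raises ValueError.
    try:
        list(batched('ABCDEFG', 3, strict=True))
    except ValueError as exc:
        print(exc)                        # batched(): incomplete batch
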
@@ -190,16 +193,21 @@ loops that truncate the stream.
   Roughly equivalent to::

-      def batched(iterable, n):
+      def batched(iterable, n, *, strict=False):
          # batched('ABCDEFG', 3) --> ABC DEF G
          if n < 1:
              raise ValueError('n must be at least one')
          it = iter(iterable)
          while batch := tuple(islice(it, n)):
+              if strict and len(batch) != n:
+                  raise ValueError('batched(): incomplete batch')
              yield batch

   .. versionadded:: 3.12

+   .. versionchanged:: 3.13
+      Added the *strict* option.
+
.. function:: chain(*iterables)
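
To try the recipe outside the docs, the same pure-Python equivalent as a self-contained sketch (the name batched_py is mine, chosen to avoid shadowing itertools.batched):

    from itertools import islice

    def batched_py(iterable, n, *, strict=False):
        # Mirror of the "roughly equivalent" code in the docs above.
        if n < 1:
            raise ValueError('n must be at least one')
        it = iter(iterable)
        while batch := tuple(islice(it, n)):
            if strict and len(batch) != n:
                raise ValueError('batched(): incomplete batch')
            yield batch

    print(list(batched_py('ABCDEFG', 3)))              # [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
    print(list(batched_py('ABCDEF', 3, strict=True)))  # [('A', 'B', 'C'), ('D', 'E', 'F')]
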
@@ -1039,7 +1047,7 @@ The following recipes have a more mathematical flavor:
   def reshape(matrix, cols):
       "Reshape a 2-D matrix to have a given number of columns."
       # reshape([(0, 1), (2, 3), (4, 5)], 3) --> (0, 1, 2), (3, 4, 5)
-       return batched(chain.from_iterable(matrix), cols)
+       return batched(chain.from_iterable(matrix), cols, strict=True)

   def transpose(matrix):
       "Swap the rows and columns of a 2-D matrix."
@@ -1270,6 +1278,10 @@ The following recipes have a more mathematical flavor:
    [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]
    >>> list(reshape(M, 4))
    [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
+    >>> list(reshape(M, 5))
+    Traceback (most recent call last):
+    ...
+    ValueError: batched(): incomplete batch
    >>> list(reshape(M, 6))
    [(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)]
    >>> list(reshape(M, 12))
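
The doctests above rely on a matrix M defined earlier in the recipe section; a self-contained approximation, with M reconstructed from the expected outputs (an assumed 3-by-4 matrix holding 0..11), is:

    from itertools import batched, chain

    def reshape(matrix, cols):
        "Reshape a 2-D matrix to have a given number of columns."
        return batched(chain.from_iterable(matrix), cols, strict=True)

    M = [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]   # assumed shape; flattens to 0..11
    print(list(reshape(M, 3)))   # [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]
    # With strict=True, a column count that does not evenly divide the 12
    # elements now fails loudly instead of yielding a ragged last row:
    # list(reshape(M, 5))  ->  ValueError: batched(): incomplete batch
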

View file

@@ -187,7 +187,11 @@ class TestBasicOps(unittest.TestCase):
                         [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)])
        self.assertEqual(list(batched('ABCDEFG', 1)),
                         [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)])
+        self.assertEqual(list(batched('ABCDEF', 2, strict=True)),
+                         [('A', 'B'), ('C', 'D'), ('E', 'F')])
+        with self.assertRaises(ValueError):     # Incomplete batch when strict
+            list(batched('ABCDEFG', 3, strict=True))
        with self.assertRaises(TypeError):      # Too few arguments
            list(batched('ABCDEFG'))
        with self.assertRaises(TypeError):

View file

@@ -0,0 +1 @@
+Add a ``strict`` option to ``batched()`` in the ``itertools`` module.

View file

@@ -10,7 +10,7 @@ preserve
#include "pycore_modsupport.h"    // _PyArg_UnpackKeywords()

PyDoc_STRVAR(batched_new__doc__,
-"batched(iterable, n)\n"
+"batched(iterable, n, *, strict=False)\n"
"--\n"
"\n"
"Batch data into tuples of length n. The last batch may be shorter than n.\n"
@@ -25,10 +25,14 @@ PyDoc_STRVAR(batched_new__doc__,
"    ...\n"
"    (\'A\', \'B\', \'C\')\n"
"    (\'D\', \'E\', \'F\')\n"
-"    (\'G\',)");
+"    (\'G\',)\n"
+"\n"
+"If \"strict\" is True, raises a ValueError if the final batch is shorter\n"
+"than n.");

static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n);
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+                 int strict);

static PyObject *
batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -36,14 +40,14 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
    PyObject *return_value = NULL;
    #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)

-    #define NUM_KEYWORDS 2
+    #define NUM_KEYWORDS 3
    static struct {
        PyGC_Head _this_is_not_used;
        PyObject_VAR_HEAD
        PyObject *ob_item[NUM_KEYWORDS];
    } _kwtuple = {
        .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
-        .ob_item = { &_Py_ID(iterable), &_Py_ID(n), },
+        .ob_item = { &_Py_ID(iterable), &_Py_ID(n), &_Py_ID(strict), },
    };
    #undef NUM_KEYWORDS
    #define KWTUPLE (&_kwtuple.ob_base.ob_base)
@@ -52,18 +56,20 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
    #  define KWTUPLE NULL
    #endif  // !Py_BUILD_CORE

-    static const char * const _keywords[] = {"iterable", "n", NULL};
+    static const char * const _keywords[] = {"iterable", "n", "strict", NULL};
    static _PyArg_Parser _parser = {
        .keywords = _keywords,
        .fname = "batched",
        .kwtuple = KWTUPLE,
    };
    #undef KWTUPLE
-    PyObject *argsbuf[2];
+    PyObject *argsbuf[3];
    PyObject * const *fastargs;
    Py_ssize_t nargs = PyTuple_GET_SIZE(args);
+    Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 2;
    PyObject *iterable;
    Py_ssize_t n;
+    int strict = 0;

    fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 2, 2, 0, argsbuf);
    if (!fastargs) {
@@ -82,7 +88,15 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
        }
        n = ival;
    }
-    return_value = batched_new_impl(type, iterable, n);
+    if (!noptargs) {
+        goto skip_optional_kwonly;
+    }
+    strict = PyObject_IsTrue(fastargs[2]);
+    if (strict < 0) {
+        goto exit;
+    }
+skip_optional_kwonly:
+    return_value = batched_new_impl(type, iterable, n, strict);

exit:
    return return_value;
@@ -914,4 +928,4 @@ skip_optional_pos:
exit:
    return return_value;
}
-/*[clinic end generated code: output=782fe7e30733779b input=a9049054013a1b77]*/
+/*[clinic end generated code: output=c6a515f765da86b5 input=a9049054013a1b77]*/
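
The generated parsing code above makes strict optional, keyword-only, and converted with PyObject_IsTrue(); a rough Python-level check of that behavior (assuming Python 3.13+):

    from itertools import batched

    # strict cannot be passed positionally; only two positional arguments are accepted.
    try:
        batched('ABCDEFG', 3, True)
    except TypeError as exc:
        print('TypeError:', exc)

    # Any truthy object is accepted for strict, mirroring PyObject_IsTrue().
    print(list(batched('ABCDEF', 3, strict=1)))   # [('A', 'B', 'C'), ('D', 'E', 'F')]
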

View file

@@ -105,20 +105,11 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"

/* batched object ************************************************************/

-/* Note: The built-in zip() function includes a "strict" argument
-   that was needed because that function would silently truncate data,
-   and there was no easy way for a user to detect the data loss.
-   The same reasoning does not apply to batched() which never drops data.
-   Instead, batched() produces a shorter tuple which can be handled
-   as the user sees fit. If requested, it would be reasonable to add
-   "fillvalue" support which had demonstrated value in zip_longest().
-   For now, the API is kept simple and clean.
- */
-
typedef struct {
    PyObject_HEAD
    PyObject *it;
    Py_ssize_t batch_size;
+    bool strict;
} batchedobject;

/*[clinic input]
@@ -126,6 +117,9 @@ typedef struct {
itertools.batched.__new__ as batched_new
    iterable: object
    n: Py_ssize_t
+    *
+    strict: bool = False

Batch data into tuples of length n. The last batch may be shorter than n.

Loops over the input iterable and accumulates data into tuples
@@ -140,11 +134,15 @@ or when the input iterable is exhausted.
    ('D', 'E', 'F')
    ('G',)

+If "strict" is True, raises a ValueError if the final batch is shorter
+than n.
+
[clinic start generated code]*/

static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
-/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+                 int strict)
+/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/
{
    PyObject *it;
    batchedobject *bo;
@@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
    }
    bo->batch_size = n;
    bo->it = it;
+    bo->strict = (bool) strict;
    return (PyObject *)bo;
}
@@ -233,6 +232,12 @@ batched_next(batchedobject *bo)
            Py_DECREF(result);
            return NULL;
        }
+        if (bo->strict) {
+            Py_CLEAR(bo->it);
+            Py_DECREF(result);
+            PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
+            return NULL;
+        }
        _PyTuple_Resize(&result, i);
        return result;
    }
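
Because the strict branch above clears the underlying iterator (Py_CLEAR(bo->it)) before raising, the batched object behaves as exhausted afterwards; a small Python-level sketch of the expected behavior (assuming Python 3.13+, where later calls simply signal exhaustion):

    from itertools import batched

    it = batched('ABCDEFG', 3, strict=True)
    print(next(it))      # ('A', 'B', 'C')
    print(next(it))      # ('D', 'E', 'F')
    try:
        next(it)         # only 'G' is left, so the batch is incomplete
    except ValueError as exc:
        print(exc)       # batched(): incomplete batch

    try:
        next(it)         # the underlying iterator was cleared
    except StopIteration:
        print('exhausted')
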