gh-113202: Add a strict option to itertools.batched() (gh-113203)

This commit is contained in:
Raymond Hettinger 2023-12-16 09:13:50 -06:00 committed by GitHub
parent fe479fb8a9
commit 1583c40be9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 24 deletions

View file

@ -105,20 +105,11 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"
/* batched object ************************************************************/
/* Note: The built-in zip() function includes a "strict" argument
that was needed because that function would silently truncate data,
and there was no easy way for a user to detect the data loss.
The same reasoning does not apply to batched() which never drops data.
Instead, batched() produces a shorter tuple which can be handled
as the user sees fit. If requested, it would be reasonable to add
"fillvalue" support which had demonstrated value in zip_longest().
For now, the API is kept simple and clean.
*/
typedef struct {
PyObject_HEAD
PyObject *it;
Py_ssize_t batch_size;
bool strict;
} batchedobject;
/*[clinic input]
@ -126,6 +117,9 @@ typedef struct {
itertools.batched.__new__ as batched_new
iterable: object
n: Py_ssize_t
*
strict: bool = False
Batch data into tuples of length n. The last batch may be shorter than n.
Loops over the input iterable and accumulates data into tuples
@ -140,11 +134,15 @@ or when the input iterable is exhausted.
('D', 'E', 'F')
('G',)
If "strict" is True, raises a ValueError if the final batch is shorter
than n.
[clinic start generated code]*/
static PyObject *
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
int strict)
/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/
{
PyObject *it;
batchedobject *bo;
@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
}
bo->batch_size = n;
bo->it = it;
bo->strict = (bool) strict;
return (PyObject *)bo;
}
@ -233,6 +232,12 @@ batched_next(batchedobject *bo)
Py_DECREF(result);
return NULL;
}
if (bo->strict) {
Py_CLEAR(bo->it);
Py_DECREF(result);
PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
return NULL;
}
_PyTuple_Resize(&result, i);
return result;
}