gh-119793: Add optional length-checking to map() (GH-120471)

Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
Co-authored-by: Raymond Hettinger <rhettinger@users.noreply.github.com>
This commit is contained in:
Nice Zombies 2024-11-04 15:00:19 +01:00 committed by GitHub
parent bfc1d2504c
commit 3032fcd90e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 210 additions and 17 deletions

View file

@ -1311,6 +1311,7 @@ typedef struct {
PyObject_HEAD
PyObject *iters;
PyObject *func;
int strict;
} mapobject;
static PyObject *
@ -1319,10 +1320,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject *it, *iters, *func;
mapobject *lz;
Py_ssize_t numargs, i;
int strict = 0;
if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) &&
!_PyArg_NoKeywords("map", kwds))
return NULL;
if (kwds) {
PyObject *empty = PyTuple_New(0);
if (empty == NULL) {
return NULL;
}
static char *kwlist[] = {"strict", NULL};
int parsed = PyArg_ParseTupleAndKeywords(
empty, kwds, "|$p:map", kwlist, &strict);
Py_DECREF(empty);
if (!parsed) {
return NULL;
}
}
numargs = PyTuple_Size(args);
if (numargs < 2) {
@ -1354,6 +1366,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
lz->iters = iters;
func = PyTuple_GET_ITEM(args, 0);
lz->func = Py_NewRef(func);
lz->strict = strict;
return (PyObject *)lz;
}
@ -1363,11 +1376,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
size_t nargsf, PyObject *kwnames)
{
PyTypeObject *tp = _PyType_CAST(type);
if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
return NULL;
}
Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
// Fallback to map_new()
PyThreadState *tstate = _PyThreadState_GET();
return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
}
if (nargs < 2) {
PyErr_SetString(PyExc_TypeError,
"map() must have at least two arguments.");
@ -1395,6 +1411,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
}
lz->iters = iters;
lz->func = Py_NewRef(args[0]);
lz->strict = 0;
return (PyObject *)lz;
}
@ -1419,6 +1436,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
static PyObject *
map_next(mapobject *lz)
{
Py_ssize_t i;
PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
PyObject **stack;
PyObject *result = NULL;
@ -1437,10 +1455,13 @@ map_next(mapobject *lz)
}
Py_ssize_t nargs = 0;
for (Py_ssize_t i=0; i < niters; i++) {
for (i=0; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = Py_TYPE(it)->tp_iternext(it);
if (val == NULL) {
if (lz->strict) {
goto check;
}
goto exit;
}
stack[i] = val;
@ -1450,13 +1471,50 @@ map_next(mapobject *lz)
result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);
exit:
for (Py_ssize_t i=0; i < nargs; i++) {
for (i=0; i < nargs; i++) {
Py_DECREF(stack[i]);
}
if (stack != small_stack) {
PyMem_Free(stack);
}
return result;
check:
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
if (i) {
// ValueError: map() argument 2 is shorter than argument 1
// ValueError: map() argument 3 is shorter than arguments 1-2
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is shorter than argument%s%d",
i + 1, plural, i);
}
for (i = 1; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
if (val) {
Py_DECREF(val);
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is longer than argument%s%d",
i + 1, plural, i);
}
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
// Argument i is exhausted. So far so good...
}
// All arguments are exhausted. Success!
goto exit;
}
static PyObject *
@ -1473,21 +1531,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
}
if (lz->strict) {
return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
}
return Py_BuildValue("ON", Py_TYPE(lz), args);
}
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
static PyObject *
map_setstate(mapobject *lz, PyObject *state)
{
int strict = PyObject_IsTrue(state);
if (strict < 0) {
return NULL;
}
lz->strict = strict;
Py_RETURN_NONE;
}
static PyMethodDef map_methods[] = {
{"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
{"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
{NULL, NULL} /* sentinel */
};
PyDoc_STRVAR(map_doc,
"map(function, iterable, /, *iterables)\n\
"map(function, iterable, /, *iterables, strict=False)\n\
--\n\
\n\
Make an iterator that computes the function using arguments from\n\
each of the iterables. Stops when the shortest iterable is exhausted.");
each of the iterables. Stops when the shortest iterable is exhausted.\n\
\n\
If strict is true and one of the arguments is exhausted before the others,\n\
raise a ValueError.");
PyTypeObject PyMap_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
@ -3068,8 +3146,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
}
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
static PyObject *
zip_setstate(zipobject *lz, PyObject *state)
{