gh-119793: Add optional length-checking to map() (GH-120471)

Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
Co-authored-by: Raymond Hettinger <rhettinger@users.noreply.github.com>
This commit is contained in:
Nice Zombies 2024-11-04 15:00:19 +01:00 committed by GitHub
parent bfc1d2504c
commit 3032fcd90e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 210 additions and 17 deletions

View file

@ -1205,14 +1205,19 @@ are always available. They are listed here in alphabetical order.
unchanged from previous versions. unchanged from previous versions.
.. function:: map(function, iterable, *iterables) .. function:: map(function, iterable, /, *iterables, strict=False)
Return an iterator that applies *function* to every item of *iterable*, Return an iterator that applies *function* to every item of *iterable*,
yielding the results. If additional *iterables* arguments are passed, yielding the results. If additional *iterables* arguments are passed,
*function* must take that many arguments and is applied to the items from all *function* must take that many arguments and is applied to the items from all
iterables in parallel. With multiple iterables, the iterator stops when the iterables in parallel. With multiple iterables, the iterator stops when the
shortest iterable is exhausted. For cases where the function inputs are shortest iterable is exhausted. If *strict* is ``True`` and one of the
already arranged into argument tuples, see :func:`itertools.starmap`\. iterables is exhausted before the others, a :exc:`ValueError` is raised. For
cases where the function inputs are already arranged into argument tuples,
see :func:`itertools.starmap`.
.. versionchanged:: 3.14
Added the *strict* parameter.
.. function:: max(iterable, *, key=None) .. function:: max(iterable, *, key=None)

View file

@ -175,6 +175,10 @@ Improved error messages
Other language changes Other language changes
====================== ======================
* The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
(Contributed by Wannes Boeykens in :gh:`119793`.)
* Incorrect usage of :keyword:`await` and asynchronous comprehensions * Incorrect usage of :keyword:`await` and asynchronous comprehensions
is now detected even if the code is optimized away by the :option:`-O` is now detected even if the code is optimized away by the :option:`-O`
command-line option. For example, ``python -O -c 'assert await 1'`` command-line option. For example, ``python -O -c 'assert await 1'``

View file

@ -148,6 +148,9 @@ def filter_char(arg):
def map_char(arg): def map_char(arg):
return chr(ord(arg)+1) return chr(ord(arg)+1)
def pack(*args):
return args
class BuiltinTest(unittest.TestCase): class BuiltinTest(unittest.TestCase):
# Helper to check picklability # Helper to check picklability
def check_iter_pickle(self, it, seq, proto): def check_iter_pickle(self, it, seq, proto):
@ -1269,6 +1272,108 @@ class BuiltinTest(unittest.TestCase):
m2 = map(map_char, "Is this the real life?") m2 = map(map_char, "Is this the real life?")
self.check_iter_pickle(m1, list(m2), proto) self.check_iter_pickle(m1, list(m2), proto)
# strict map tests based on strict zip tests
def test_map_pickle_strict(self):
a = (1, 2, 3)
b = (4, 5, 6)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
self.check_iter_pickle(m1, t, proto)
def test_map_pickle_strict_fail(self):
a = (1, 2, 3)
b = (4, 5, 6, 7)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
m2 = pickle.loads(pickle.dumps(m1, proto))
self.assertEqual(self.iter_error(m1, ValueError), t)
self.assertEqual(self.iter_error(m2, ValueError), t)
def test_map_strict(self):
self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
((1, 'a'), (2, 'b'), (3, 'c')))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2, 3, 4), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), (1, 2), 'abc', strict=True))
def test_map_strict_iterators(self):
x = iter(range(5))
y = [0]
z = iter(range(5))
self.assertRaises(ValueError, list,
(map(pack, x, y, z, strict=True)))
self.assertEqual(next(x), 2)
self.assertEqual(next(z), 1)
def test_map_strict_error_handling(self):
class Error(Exception):
pass
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise Error
return self.size
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), Error)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), Error)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), Error)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), Error)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_map_strict_error_handling_stopiteration(self):
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise StopIteration
return self.size
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), ValueError)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), ValueError)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_max(self): def test_max(self):
self.assertEqual(max('123123'), '3') self.assertEqual(max('123123'), '3')
self.assertEqual(max(1, 2, 3), 3) self.assertEqual(max(1, 2, 3), 3)

View file

@ -2433,10 +2433,10 @@ class SubclassWithKwargsTest(unittest.TestCase):
subclass(*args, newarg=3) subclass(*args, newarg=3)
for cls, args, result in testcases: for cls, args, result in testcases:
# Constructors of repeat, zip, compress accept keyword arguments. # Constructors of repeat, zip, map, compress accept keyword arguments.
# Their subclasses need overriding __new__ to support new # Their subclasses need overriding __new__ to support new
# keyword arguments. # keyword arguments.
if cls in [repeat, zip, compress]: if cls in [repeat, zip, map, compress]:
continue continue
with self.subTest(cls): with self.subTest(cls):
class subclass_with_init(cls): class subclass_with_init(cls):

View file

@ -0,0 +1,3 @@
The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
Patch by Wannes Boeykens.

View file

@ -1311,6 +1311,7 @@ typedef struct {
PyObject_HEAD PyObject_HEAD
PyObject *iters; PyObject *iters;
PyObject *func; PyObject *func;
int strict;
} mapobject; } mapobject;
static PyObject * static PyObject *
@ -1319,10 +1320,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject *it, *iters, *func; PyObject *it, *iters, *func;
mapobject *lz; mapobject *lz;
Py_ssize_t numargs, i; Py_ssize_t numargs, i;
int strict = 0;
if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) && if (kwds) {
!_PyArg_NoKeywords("map", kwds)) PyObject *empty = PyTuple_New(0);
return NULL; if (empty == NULL) {
return NULL;
}
static char *kwlist[] = {"strict", NULL};
int parsed = PyArg_ParseTupleAndKeywords(
empty, kwds, "|$p:map", kwlist, &strict);
Py_DECREF(empty);
if (!parsed) {
return NULL;
}
}
numargs = PyTuple_Size(args); numargs = PyTuple_Size(args);
if (numargs < 2) { if (numargs < 2) {
@ -1354,6 +1366,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
lz->iters = iters; lz->iters = iters;
func = PyTuple_GET_ITEM(args, 0); func = PyTuple_GET_ITEM(args, 0);
lz->func = Py_NewRef(func); lz->func = Py_NewRef(func);
lz->strict = strict;
return (PyObject *)lz; return (PyObject *)lz;
} }
@ -1363,11 +1376,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
size_t nargsf, PyObject *kwnames) size_t nargsf, PyObject *kwnames)
{ {
PyTypeObject *tp = _PyType_CAST(type); PyTypeObject *tp = _PyType_CAST(type);
if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
return NULL;
}
Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
// Fallback to map_new()
PyThreadState *tstate = _PyThreadState_GET();
return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
}
if (nargs < 2) { if (nargs < 2) {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"map() must have at least two arguments."); "map() must have at least two arguments.");
@ -1395,6 +1411,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
} }
lz->iters = iters; lz->iters = iters;
lz->func = Py_NewRef(args[0]); lz->func = Py_NewRef(args[0]);
lz->strict = 0;
return (PyObject *)lz; return (PyObject *)lz;
} }
@ -1419,6 +1436,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
static PyObject * static PyObject *
map_next(mapobject *lz) map_next(mapobject *lz)
{ {
Py_ssize_t i;
PyObject *small_stack[_PY_FASTCALL_SMALL_STACK]; PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
PyObject **stack; PyObject **stack;
PyObject *result = NULL; PyObject *result = NULL;
@ -1437,10 +1455,13 @@ map_next(mapobject *lz)
} }
Py_ssize_t nargs = 0; Py_ssize_t nargs = 0;
for (Py_ssize_t i=0; i < niters; i++) { for (i=0; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i); PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = Py_TYPE(it)->tp_iternext(it); PyObject *val = Py_TYPE(it)->tp_iternext(it);
if (val == NULL) { if (val == NULL) {
if (lz->strict) {
goto check;
}
goto exit; goto exit;
} }
stack[i] = val; stack[i] = val;
@ -1450,13 +1471,50 @@ map_next(mapobject *lz)
result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL); result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);
exit: exit:
for (Py_ssize_t i=0; i < nargs; i++) { for (i=0; i < nargs; i++) {
Py_DECREF(stack[i]); Py_DECREF(stack[i]);
} }
if (stack != small_stack) { if (stack != small_stack) {
PyMem_Free(stack); PyMem_Free(stack);
} }
return result; return result;
check:
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
if (i) {
// ValueError: map() argument 2 is shorter than argument 1
// ValueError: map() argument 3 is shorter than arguments 1-2
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is shorter than argument%s%d",
i + 1, plural, i);
}
for (i = 1; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
if (val) {
Py_DECREF(val);
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is longer than argument%s%d",
i + 1, plural, i);
}
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
// Argument i is exhausted. So far so good...
}
// All arguments are exhausted. Success!
goto exit;
} }
static PyObject * static PyObject *
@ -1473,21 +1531,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
PyTuple_SET_ITEM(args, i+1, Py_NewRef(it)); PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
} }
if (lz->strict) {
return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
}
return Py_BuildValue("ON", Py_TYPE(lz), args); return Py_BuildValue("ON", Py_TYPE(lz), args);
} }
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
static PyObject *
map_setstate(mapobject *lz, PyObject *state)
{
int strict = PyObject_IsTrue(state);
if (strict < 0) {
return NULL;
}
lz->strict = strict;
Py_RETURN_NONE;
}
static PyMethodDef map_methods[] = { static PyMethodDef map_methods[] = {
{"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc}, {"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
{"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
{NULL, NULL} /* sentinel */ {NULL, NULL} /* sentinel */
}; };
PyDoc_STRVAR(map_doc, PyDoc_STRVAR(map_doc,
"map(function, iterable, /, *iterables)\n\ "map(function, iterable, /, *iterables, strict=False)\n\
--\n\ --\n\
\n\ \n\
Make an iterator that computes the function using arguments from\n\ Make an iterator that computes the function using arguments from\n\
each of the iterables. Stops when the shortest iterable is exhausted."); each of the iterables. Stops when the shortest iterable is exhausted.\n\
\n\
If strict is true and one of the arguments is exhausted before the others,\n\
raise a ValueError.");
PyTypeObject PyMap_Type = { PyTypeObject PyMap_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) PyVarObject_HEAD_INIT(&PyType_Type, 0)
@ -3068,8 +3146,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple); return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
} }
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
static PyObject * static PyObject *
zip_setstate(zipobject *lz, PyObject *state) zip_setstate(zipobject *lz, PyObject *state)
{ {