diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 290ba2cad8e..40df7b606ae 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -1521,6 +1521,14 @@ class BuiltinTest(unittest.TestCase): self.assertRaises(TypeError, vars, 42) self.assertEqual(vars(self.C_get_vars()), {'a':2}) + def iter_error(self, iterable, error): + """Collect `iterable` into a list, catching an expected `error`.""" + items = [] + with self.assertRaises(error): + for item in iterable: + items.append(item) + return items + def test_zip(self): a = (1, 2, 3) b = (4, 5, 6) @@ -1573,6 +1581,66 @@ class BuiltinTest(unittest.TestCase): z1 = zip(a, b) self.check_iter_pickle(z1, t, proto) + def test_zip_pickle_strict(self): + a = (1, 2, 3) + b = (4, 5, 6) + t = [(1, 4), (2, 5), (3, 6)] + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z1 = zip(a, b, strict=True) + self.check_iter_pickle(z1, t, proto) + + def test_zip_pickle_strict_fail(self): + a = (1, 2, 3) + b = (4, 5, 6, 7) + t = [(1, 4), (2, 5), (3, 6)] + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z1 = zip(a, b, strict=True) + z2 = pickle.loads(pickle.dumps(z1, proto)) + self.assertEqual(self.iter_error(z1, ValueError), t) + self.assertEqual(self.iter_error(z2, ValueError), t) + + def test_zip_pickle_stability(self): + # Pickles of zip((1, 2, 3), (4, 5, 6)) dumped from 3.9: + pickles = [ + b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\nI6\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\n.', + b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05K\x06tq\x05tq\x06Rq\x07K\x00btq\x08Rq\t.', + b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.', + b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.', + b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.', + b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.', + ] + for protocol, dump in enumerate(pickles): + z1 = zip((1, 2, 3), (4, 5, 6)) + z2 = zip((1, 2, 3), (4, 5, 6), strict=False) + z3 = pickle.loads(dump) + l3 = list(z3) + self.assertEqual(type(z3), zip) + self.assertEqual(pickle.dumps(z1, protocol), dump) + self.assertEqual(pickle.dumps(z2, protocol), dump) + self.assertEqual(list(z1), l3) + self.assertEqual(list(z2), l3) + + def test_zip_pickle_strict_stability(self): + # Pickles of zip((1, 2, 3), (4, 5), strict=True) dumped from 3.10: + pickles = [ + b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\nI01\nb.', + b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05tq\x05tq\x06Rq\x07K\x00btq\x08Rq\tI01\nb.', + b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.', + b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.', + b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.', + b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.', + ] + a = (1, 2, 3) + b = (4, 5) + t = [(1, 4), (2, 5)] + for protocol, dump in enumerate(pickles): + z1 = zip(a, b, strict=True) + z2 = pickle.loads(dump) + self.assertEqual(pickle.dumps(z1, protocol), dump) + self.assertEqual(type(z2), zip) + self.assertEqual(self.iter_error(z1, ValueError), t) + self.assertEqual(self.iter_error(z2, ValueError), t) + def test_zip_bad_iterable(self): exception = TypeError() @@ -1585,6 +1653,88 @@ class BuiltinTest(unittest.TestCase): self.assertIs(cm.exception, exception) + def test_zip_strict(self): + self.assertEqual(tuple(zip((1, 2, 3), 'abc', strict=True)), + ((1, 'a'), (2, 'b'), (3, 'c'))) + self.assertRaises(ValueError, tuple, + zip((1, 2, 3, 4), 'abc', strict=True)) + self.assertRaises(ValueError, tuple, + zip((1, 2), 'abc', strict=True)) + self.assertRaises(ValueError, tuple, + zip((1, 2), (1, 2), 'abc', strict=True)) + + def test_zip_strict_iterators(self): + x = iter(range(5)) + y = [0] + z = iter(range(5)) + self.assertRaises(ValueError, list, + (zip(x, y, z, strict=True))) + self.assertEqual(next(x), 2) + self.assertEqual(next(z), 1) + + def test_zip_strict_error_handling(self): + + class Error(Exception): + pass + + class Iter: + def __init__(self, size): + self.size = size + def __iter__(self): + return self + def __next__(self): + self.size -= 1 + if self.size < 0: + raise Error + return self.size + + l1 = self.iter_error(zip("AB", Iter(1), strict=True), Error) + self.assertEqual(l1, [("A", 0)]) + l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError) + self.assertEqual(l2, [("A", 1, "A")]) + l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), Error) + self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")]) + l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError) + self.assertEqual(l4, [("A", 2), ("B", 1)]) + l5 = self.iter_error(zip(Iter(1), "AB", strict=True), Error) + self.assertEqual(l5, [(0, "A")]) + l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError) + self.assertEqual(l6, [(1, "A")]) + l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), Error) + self.assertEqual(l7, [(1, "A"), (0, "B")]) + l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError) + self.assertEqual(l8, [(2, "A"), (1, "B")]) + + def test_zip_strict_error_handling_stopiteration(self): + + class Iter: + def __init__(self, size): + self.size = size + def __iter__(self): + return self + def __next__(self): + self.size -= 1 + if self.size < 0: + raise StopIteration + return self.size + + l1 = self.iter_error(zip("AB", Iter(1), strict=True), ValueError) + self.assertEqual(l1, [("A", 0)]) + l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError) + self.assertEqual(l2, [("A", 1, "A")]) + l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), ValueError) + self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")]) + l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError) + self.assertEqual(l4, [("A", 2), ("B", 1)]) + l5 = self.iter_error(zip(Iter(1), "AB", strict=True), ValueError) + self.assertEqual(l5, [(0, "A")]) + l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError) + self.assertEqual(l6, [(1, "A")]) + l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), ValueError) + self.assertEqual(l7, [(1, "A"), (0, "B")]) + l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError) + self.assertEqual(l8, [(2, "A"), (1, "B")]) + def test_format(self): # Test the basic machinery of the format() builtin. Don't test # the specifics of the various formatters diff --git a/Misc/NEWS.d/next/Core and Builtins/2020-06-17-10-27-17.bpo-40636.MYaCIe.rst b/Misc/NEWS.d/next/Core and Builtins/2020-06-17-10-27-17.bpo-40636.MYaCIe.rst new file mode 100644 index 00000000000..ba26ad9373c --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2020-06-17-10-27-17.bpo-40636.MYaCIe.rst @@ -0,0 +1,3 @@ +:func:`zip` now supports :pep:`618`'s ``strict`` parameter, which raises a +:exc:`ValueError` if the arguments are exhausted at different lengths. +Patch by Brandt Bucher. diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index 65f95280846..c6ede1cd7f6 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -2517,9 +2517,10 @@ builtin_issubclass_impl(PyObject *module, PyObject *cls, typedef struct { PyObject_HEAD - Py_ssize_t tuplesize; - PyObject *ittuple; /* tuple of iterators */ + Py_ssize_t tuplesize; + PyObject *ittuple; /* tuple of iterators */ PyObject *result; + int strict; } zipobject; static PyObject * @@ -2530,9 +2531,21 @@ zip_new(PyTypeObject *type, PyObject *args, PyObject *kwds) PyObject *ittuple; /* tuple of iterators */ PyObject *result; Py_ssize_t tuplesize; + int strict = 0; - if (type == &PyZip_Type && !_PyArg_NoKeywords("zip", kwds)) - return NULL; + if (kwds) { + PyObject *empty = PyTuple_New(0); + if (empty == NULL) { + return NULL; + } + static char *kwlist[] = {"strict", NULL}; + int parsed = PyArg_ParseTupleAndKeywords( + empty, kwds, "|$p:zip", kwlist, &strict); + Py_DECREF(empty); + if (!parsed) { + return NULL; + } + } /* args must be a tuple */ assert(PyTuple_Check(args)); @@ -2573,6 +2586,7 @@ zip_new(PyTypeObject *type, PyObject *args, PyObject *kwds) lz->ittuple = ittuple; lz->tuplesize = tuplesize; lz->result = result; + lz->strict = strict; return (PyObject *)lz; } @@ -2613,6 +2627,9 @@ zip_next(zipobject *lz) item = (*Py_TYPE(it)->tp_iternext)(it); if (item == NULL) { Py_DECREF(result); + if (lz->strict) { + goto check; + } return NULL; } olditem = PyTuple_GET_ITEM(result, i); @@ -2628,28 +2645,85 @@ zip_next(zipobject *lz) item = (*Py_TYPE(it)->tp_iternext)(it); if (item == NULL) { Py_DECREF(result); + if (lz->strict) { + goto check; + } return NULL; } PyTuple_SET_ITEM(result, i, item); } } return result; +check: + if (PyErr_Occurred()) { + if (!PyErr_ExceptionMatches(PyExc_StopIteration)) { + // next() on argument i raised an exception (not StopIteration) + return NULL; + } + PyErr_Clear(); + } + if (i) { + // ValueError: zip() argument 2 is shorter than argument 1 + // ValueError: zip() argument 3 is shorter than arguments 1-2 + const char* plural = i == 1 ? " " : "s 1-"; + return PyErr_Format(PyExc_ValueError, + "zip() argument %d is shorter than argument%s%d", + i + 1, plural, i); + } + for (i = 1; i < tuplesize; i++) { + it = PyTuple_GET_ITEM(lz->ittuple, i); + item = (*Py_TYPE(it)->tp_iternext)(it); + if (item) { + Py_DECREF(item); + const char* plural = i == 1 ? " " : "s 1-"; + return PyErr_Format(PyExc_ValueError, + "zip() argument %d is longer than argument%s%d", + i + 1, plural, i); + } + if (PyErr_Occurred()) { + if (!PyErr_ExceptionMatches(PyExc_StopIteration)) { + // next() on argument i raised an exception (not StopIteration) + return NULL; + } + PyErr_Clear(); + } + // Argument i is exhausted. So far so good... + } + // All arguments are exhausted. Success! + return NULL; } static PyObject * zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored)) { /* Just recreate the zip with the internal iterator tuple */ - return Py_BuildValue("OO", Py_TYPE(lz), lz->ittuple); + if (lz->strict) { + return PyTuple_Pack(3, Py_TYPE(lz), lz->ittuple, Py_True); + } + return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple); +} + +PyDoc_STRVAR(setstate_doc, "Set state information for unpickling."); + +static PyObject * +zip_setstate(zipobject *lz, PyObject *state) +{ + int strict = PyObject_IsTrue(state); + if (strict < 0) { + return NULL; + } + lz->strict = strict; + Py_RETURN_NONE; } static PyMethodDef zip_methods[] = { {"__reduce__", (PyCFunction)zip_reduce, METH_NOARGS, reduce_doc}, - {NULL, NULL} /* sentinel */ + {"__setstate__", (PyCFunction)zip_setstate, METH_O, setstate_doc}, + {NULL} /* sentinel */ }; PyDoc_STRVAR(zip_doc, -"zip(*iterables) --> A zip object yielding tuples until an input is exhausted.\n\ +"zip(*iterables, strict=False) --> Yield tuples until an input is exhausted.\n\ \n\ >>> list(zip('abcdefg', range(3), range(4)))\n\ [('a', 0, 0), ('b', 1, 1), ('c', 2, 2)]\n\ @@ -2657,7 +2731,10 @@ PyDoc_STRVAR(zip_doc, The zip object yields n-length tuples, where n is the number of iterables\n\ passed as positional arguments to zip(). The i-th element in every tuple\n\ comes from the i-th iterable argument to zip(). This continues until the\n\ -shortest argument is exhausted."); +shortest argument is exhausted.\n\ +\n\ +If strict is true and one of the arguments is exhausted before the others,\n\ +raise a ValueError."); PyTypeObject PyZip_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0)