Simplify the signature for itertools.accumulate() to match numpy. Handle one item iterable the same way as min()/max().

This commit is contained in:
Raymond Hettinger 2010-12-03 02:09:34 +00:00
parent a7a0e1a0f4
commit d8ff4658fb
3 changed files with 31 additions and 37 deletions

View file

@ -90,13 +90,15 @@ loops that truncate the stream.
parameter (which defaults to :const:`0`). Elements may be any addable type parameter (which defaults to :const:`0`). Elements may be any addable type
including :class:`Decimal` or :class:`Fraction`. Equivalent to:: including :class:`Decimal` or :class:`Fraction`. Equivalent to::
def accumulate(iterable, start=0): def accumulate(iterable):
'Return running totals' 'Return running totals'
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15 # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
total = start it = iter(iterable)
for element in iterable: total = next(it)
total += element yield total
yield total for element in it:
total += element
yield total
.. versionadded:: 3.2 .. versionadded:: 3.2

View file

@ -59,18 +59,18 @@ class TestBasicOps(unittest.TestCase):
def test_accumulate(self): def test_accumulate(self):
self.assertEqual(list(accumulate(range(10))), # one positional arg self.assertEqual(list(accumulate(range(10))), # one positional arg
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45]) [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
self.assertEqual(list(accumulate(range(10), 100)), # two positional args self.assertEqual(list(accumulate(iterable=range(10))), # kw arg
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145]) [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
for typ in int, complex, Decimal, Fraction: # multiple types for typ in int, complex, Decimal, Fraction: # multiple types
self.assertEqual(list(accumulate(range(10), typ(0))), self.assertEqual(
list(accumulate(map(typ, range(10)))),
list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]))) list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
self.assertEqual(list(accumulate([])), []) # empty iterable self.assertEqual(list(accumulate([])), []) # empty iterable
self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args self.assertEqual(list(accumulate([7])), [7]) # iterable of length one
self.assertRaises(TypeError, accumulate, range(10), 5) # too many args
self.assertRaises(TypeError, accumulate) # too few args self.assertRaises(TypeError, accumulate) # too few args
self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
def test_chain(self): def test_chain(self):

View file

@ -2597,41 +2597,27 @@ static PyTypeObject accumulate_type;
static PyObject * static PyObject *
accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{ {
static char *kwargs[] = {"iterable", "start", NULL}; static char *kwargs[] = {"iterable", NULL};
PyObject *iterable; PyObject *iterable;
PyObject *it; PyObject *it;
PyObject *start = NULL;
accumulateobject *lz; accumulateobject *lz;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate", if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable))
kwargs, &iterable, &start)) return NULL;
return NULL;
/* Get iterator. */ /* Get iterator. */
it = PyObject_GetIter(iterable); it = PyObject_GetIter(iterable);
if (it == NULL) if (it == NULL)
return NULL; return NULL;
/* Default start value */
if (start == NULL) {
start = PyLong_FromLong(0);
if (start == NULL) {
Py_DECREF(it);
return NULL;
}
} else {
Py_INCREF(start);
}
/* create accumulateobject structure */ /* create accumulateobject structure */
lz = (accumulateobject *)type->tp_alloc(type, 0); lz = (accumulateobject *)type->tp_alloc(type, 0);
if (lz == NULL) { if (lz == NULL) {
Py_DECREF(it); Py_DECREF(it);
Py_DECREF(start); return NULL;
return NULL;
} }
lz->total = start; lz->total = NULL;
lz->it = it; lz->it = it;
return (PyObject *)lz; return (PyObject *)lz;
} }
@ -2661,11 +2647,17 @@ accumulate_next(accumulateobject *lz)
val = PyIter_Next(lz->it); val = PyIter_Next(lz->it);
if (val == NULL) if (val == NULL)
return NULL; return NULL;
if (lz->total == NULL) {
Py_INCREF(val);
lz->total = val;
return lz->total;
}
newtotal = PyNumber_Add(lz->total, val); newtotal = PyNumber_Add(lz->total, val);
Py_DECREF(val); Py_DECREF(val);
if (newtotal == NULL) if (newtotal == NULL)
return NULL; return NULL;
oldtotal = lz->total; oldtotal = lz->total;
lz->total = newtotal; lz->total = newtotal;
@ -2676,7 +2668,7 @@ accumulate_next(accumulateobject *lz)
} }
PyDoc_STRVAR(accumulate_doc, PyDoc_STRVAR(accumulate_doc,
"accumulate(iterable, start=0) --> accumulate object\n\ "accumulate(iterable) --> accumulate object\n\
\n\ \n\
Return series of accumulated sums."); Return series of accumulated sums.");