mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
Simplify the signature for itertools.accumulate() to match numpy. Handle one item iterable the same way as min()/max().
This commit is contained in:
parent
a7a0e1a0f4
commit
d8ff4658fb
3 changed files with 31 additions and 37 deletions
|
@ -90,11 +90,13 @@ loops that truncate the stream.
|
||||||
parameter (which defaults to :const:`0`). Elements may be any addable type
|
parameter (which defaults to :const:`0`). Elements may be any addable type
|
||||||
including :class:`Decimal` or :class:`Fraction`. Equivalent to::
|
including :class:`Decimal` or :class:`Fraction`. Equivalent to::
|
||||||
|
|
||||||
def accumulate(iterable, start=0):
|
def accumulate(iterable):
|
||||||
'Return running totals'
|
'Return running totals'
|
||||||
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15
|
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15
|
||||||
total = start
|
it = iter(iterable)
|
||||||
for element in iterable:
|
total = next(it)
|
||||||
|
yield total
|
||||||
|
for element in it:
|
||||||
total += element
|
total += element
|
||||||
yield total
|
yield total
|
||||||
|
|
||||||
|
|
|
@ -60,17 +60,17 @@ class TestBasicOps(unittest.TestCase):
|
||||||
def test_accumulate(self):
|
def test_accumulate(self):
|
||||||
self.assertEqual(list(accumulate(range(10))), # one positional arg
|
self.assertEqual(list(accumulate(range(10))), # one positional arg
|
||||||
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
|
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
|
||||||
self.assertEqual(list(accumulate(range(10), 100)), # two positional args
|
self.assertEqual(list(accumulate(iterable=range(10))), # kw arg
|
||||||
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
|
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
|
||||||
self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args
|
|
||||||
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
|
|
||||||
for typ in int, complex, Decimal, Fraction: # multiple types
|
for typ in int, complex, Decimal, Fraction: # multiple types
|
||||||
self.assertEqual(list(accumulate(range(10), typ(0))),
|
self.assertEqual(
|
||||||
|
list(accumulate(map(typ, range(10)))),
|
||||||
list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
|
list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
|
||||||
self.assertEqual(list(accumulate([])), []) # empty iterable
|
self.assertEqual(list(accumulate([])), []) # empty iterable
|
||||||
self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args
|
self.assertEqual(list(accumulate([7])), [7]) # iterable of length one
|
||||||
|
self.assertRaises(TypeError, accumulate, range(10), 5) # too many args
|
||||||
self.assertRaises(TypeError, accumulate) # too few args
|
self.assertRaises(TypeError, accumulate) # too few args
|
||||||
self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args
|
self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg
|
||||||
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
|
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
|
||||||
|
|
||||||
def test_chain(self):
|
def test_chain(self):
|
||||||
|
|
|
@ -2597,14 +2597,12 @@ static PyTypeObject accumulate_type;
|
||||||
static PyObject *
|
static PyObject *
|
||||||
accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||||
{
|
{
|
||||||
static char *kwargs[] = {"iterable", "start", NULL};
|
static char *kwargs[] = {"iterable", NULL};
|
||||||
PyObject *iterable;
|
PyObject *iterable;
|
||||||
PyObject *it;
|
PyObject *it;
|
||||||
PyObject *start = NULL;
|
|
||||||
accumulateobject *lz;
|
accumulateobject *lz;
|
||||||
|
|
||||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
|
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable))
|
||||||
kwargs, &iterable, &start))
|
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
/* Get iterator. */
|
/* Get iterator. */
|
||||||
|
@ -2612,26 +2610,14 @@ accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||||
if (it == NULL)
|
if (it == NULL)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
/* Default start value */
|
|
||||||
if (start == NULL) {
|
|
||||||
start = PyLong_FromLong(0);
|
|
||||||
if (start == NULL) {
|
|
||||||
Py_DECREF(it);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Py_INCREF(start);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* create accumulateobject structure */
|
/* create accumulateobject structure */
|
||||||
lz = (accumulateobject *)type->tp_alloc(type, 0);
|
lz = (accumulateobject *)type->tp_alloc(type, 0);
|
||||||
if (lz == NULL) {
|
if (lz == NULL) {
|
||||||
Py_DECREF(it);
|
Py_DECREF(it);
|
||||||
Py_DECREF(start);
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
lz->total = start;
|
lz->total = NULL;
|
||||||
lz->it = it;
|
lz->it = it;
|
||||||
return (PyObject *)lz;
|
return (PyObject *)lz;
|
||||||
}
|
}
|
||||||
|
@ -2662,6 +2648,12 @@ accumulate_next(accumulateobject *lz)
|
||||||
if (val == NULL)
|
if (val == NULL)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
|
if (lz->total == NULL) {
|
||||||
|
Py_INCREF(val);
|
||||||
|
lz->total = val;
|
||||||
|
return lz->total;
|
||||||
|
}
|
||||||
|
|
||||||
newtotal = PyNumber_Add(lz->total, val);
|
newtotal = PyNumber_Add(lz->total, val);
|
||||||
Py_DECREF(val);
|
Py_DECREF(val);
|
||||||
if (newtotal == NULL)
|
if (newtotal == NULL)
|
||||||
|
@ -2676,7 +2668,7 @@ accumulate_next(accumulateobject *lz)
|
||||||
}
|
}
|
||||||
|
|
||||||
PyDoc_STRVAR(accumulate_doc,
|
PyDoc_STRVAR(accumulate_doc,
|
||||||
"accumulate(iterable, start=0) --> accumulate object\n\
|
"accumulate(iterable) --> accumulate object\n\
|
||||||
\n\
|
\n\
|
||||||
Return series of accumulated sums.");
|
Return series of accumulated sums.");
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue