GH-98363: Add itertools.batched() (GH-98364)

This commit is contained in:
Raymond Hettinger 2022-10-17 18:53:45 -05:00 committed by GitHub
parent 70732d8a4c
commit de3ece769a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 370 additions and 39 deletions

View file

@ -159,6 +159,44 @@ class TestBasicOps(unittest.TestCase):
with self.assertRaises(TypeError):
list(accumulate([10, 20], 100))
def test_batched(self):
    """Check batched() on fixed examples, bad arguments, and invariants."""
    # Fixed examples: batch size -> expected list of batches.
    expected_by_size = {
        3: [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']],
        2: [['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']],
        1: [['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']],
    }
    for size, expected in expected_by_size.items():
        self.assertEqual(list(batched('ABCDEFG', size)), expected)

    # Malformed calls: (expected exception, positional arguments).
    bad_calls = [
        (TypeError, ('ABCDEFG',)),            # Too few arguments
        (TypeError, ('ABCDEFG', 3, None)),    # Too many arguments
        (TypeError, (None, 3)),               # Non-iterable input
        (TypeError, ('ABCDEFG', 'hello')),    # n is a string
        (ValueError, ('ABCDEFG', 0)),         # n is zero
        (ValueError, ('ABCDEFG', -1)),        # n is negative
    ]
    for exc_type, args in bad_calls:
        with self.assertRaises(exc_type):
            list(batched(*args))

    # Invariants over every proper prefix of the alphabet and sizes 1..5.
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    for size in range(1, 6):
        for stop in range(len(alphabet)):
            prefix = alphabet[:stop]
            batches = list(batched(prefix, size))
            with self.subTest(s=prefix, n=size, batches=batches):
                # Order is preserved and no data is lost
                flattened = ''.join(chain.from_iterable(batches))
                self.assertEqual(flattened, prefix)
                # Each batch is an exact list (subclasses do not count)
                exact_lists = all(type(batch) is list for batch in batches)
                self.assertTrue(exact_lists)
                # Nothing more to check for an empty input
                if not batches:
                    continue
                # All but the last batch is of size n
                *full_batches, final_batch = batches
                full_ok = all(len(batch) == size for batch in full_batches)
                self.assertTrue(full_ok)
                self.assertTrue(len(final_batch) <= size)
def test_chain(self):
def chain2(*iterables):
@ -1737,6 +1775,31 @@ class TestExamples(unittest.TestCase):
class TestPurePythonRoughEquivalents(unittest.TestCase):
def test_batched_recipe(self):
    """Compare batched() against a pure-Python recipe on many inputs.

    For each (iterable, n) pair the two implementations must either
    produce equal results or raise the same exception type.
    """

    def batched_recipe(iterable, n):
        "Batch data into lists of length n. The last batch may be shorter."
        # batched('ABCDEFG', 3) --> ABC DEF G
        if n < 1:
            raise ValueError('n must be at least one')
        it = iter(iterable)
        while True:
            batch = list(islice(it, n))
            if not batch:
                return
            yield batch

    def outcome(func, iterable, n):
        # Return (exception type, result); exactly one of the two is None.
        try:
            return None, list(func(iterable, n))
        except Exception as exc:
            return type(exc), None

    # Strings of length 0..7 plus a non-iterable; sizes -1..9 plus a non-int.
    iterables = ['abcdefg'[:length] for length in range(8)] + [None]
    sizes = [*range(-1, 10), None]

    for iterable, n in product(iterables, sizes):
        with self.subTest(iterable=iterable, n=n):
            e1, r1 = outcome(batched, iterable, n)
            e2, r2 = outcome(batched_recipe, iterable, n)
            self.assertEqual(r1, r2)
            self.assertEqual(e1, e2)
@staticmethod
def islice(iterable, *args):
s = slice(*args)
@ -1788,6 +1851,10 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(accumulate([1,2,a,3]), a)
def test_batched(self):
    """Hand a batched() iterator and a list it references to makecycle().

    NOTE(review): makecycle() is defined outside this view; presumably it
    links the two objects into a reference cycle and checks that it is
    collectable -- confirm against the helper.
    """
    container = []
    batch_iter = batched([1, 2, container, 3], 2)
    self.makecycle(batch_iter, container)
def test_chain(self):
a = []
self.makecycle(chain(a), a)
@ -1972,6 +2039,18 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, accumulate, N(s))
self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
def test_batched(self):
    """Run batched() over the module's assorted iterable wrapper helpers.

    NOTE(review): G, I, Ig, L, R, S, X, N and E are defined outside this
    view; their roles are inferred only from the outcomes asserted here.
    """
    text = 'abcde'
    expected = [['a', 'b'], ['c', 'd'], ['e']]
    size = 2

    # Wrappers that are expected to yield every character of the input.
    for g in (G, I, Ig, L, R):
        with self.subTest(g=g):
            self.assertEqual(list(batched(g(text), size)), expected)

    # S(text) is expected to produce no batches at all.
    self.assertEqual(list(batched(S(text), 2)), [])

    # The wrappers are built outside the guarded region so that only the
    # batched() call itself is checked for TypeError.
    for wrapper in (X, N):
        wrapped = wrapper(text)
        with self.assertRaises(TypeError):
            batched(wrapped, 2)

    # Construction must succeed; the error surfaces only while iterating.
    failing = batched(E(text), 2)
    with self.assertRaises(ZeroDivisionError):
        list(failing)
def test_chain(self):
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
for g in (G, I, Ig, S, L, R):