mirror of
https://github.com/python/cpython.git
synced 2025-08-31 05:58:33 +00:00
GH-98363: Add itertools.batched() (GH-98364)
This commit is contained in:
parent
70732d8a4c
commit
de3ece769a
5 changed files with 370 additions and 39 deletions
|
@ -159,6 +159,44 @@ class TestBasicOps(unittest.TestCase):
|
|||
with self.assertRaises(TypeError):
|
||||
list(accumulate([10, 20], 100))
|
||||
|
||||
def test_batched(self):
|
||||
self.assertEqual(list(batched('ABCDEFG', 3)),
|
||||
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']])
|
||||
self.assertEqual(list(batched('ABCDEFG', 2)),
|
||||
[['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']])
|
||||
self.assertEqual(list(batched('ABCDEFG', 1)),
|
||||
[['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']])
|
||||
|
||||
with self.assertRaises(TypeError): # Too few arguments
|
||||
list(batched('ABCDEFG'))
|
||||
with self.assertRaises(TypeError):
|
||||
list(batched('ABCDEFG', 3, None)) # Too many arguments
|
||||
with self.assertRaises(TypeError):
|
||||
list(batched(None, 3)) # Non-iterable input
|
||||
with self.assertRaises(TypeError):
|
||||
list(batched('ABCDEFG', 'hello')) # n is a string
|
||||
with self.assertRaises(ValueError):
|
||||
list(batched('ABCDEFG', 0)) # n is zero
|
||||
with self.assertRaises(ValueError):
|
||||
list(batched('ABCDEFG', -1)) # n is negative
|
||||
|
||||
data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
for n in range(1, 6):
|
||||
for i in range(len(data)):
|
||||
s = data[:i]
|
||||
batches = list(batched(s, n))
|
||||
with self.subTest(s=s, n=n, batches=batches):
|
||||
# Order is preserved and no data is lost
|
||||
self.assertEqual(''.join(chain(*batches)), s)
|
||||
# Each batch is an exact list
|
||||
self.assertTrue(all(type(batch) is list for batch in batches))
|
||||
# All but the last batch is of size n
|
||||
if batches:
|
||||
last_batch = batches.pop()
|
||||
self.assertTrue(all(len(batch) == n for batch in batches))
|
||||
self.assertTrue(len(last_batch) <= n)
|
||||
batches.append(last_batch)
|
||||
|
||||
def test_chain(self):
|
||||
|
||||
def chain2(*iterables):
|
||||
|
@ -1737,6 +1775,31 @@ class TestExamples(unittest.TestCase):
|
|||
|
||||
class TestPurePythonRoughEquivalents(unittest.TestCase):
|
||||
|
||||
def test_batched_recipe(self):
|
||||
def batched_recipe(iterable, n):
|
||||
"Batch data into lists of length n. The last batch may be shorter."
|
||||
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||
if n < 1:
|
||||
raise ValueError('n must be at least one')
|
||||
it = iter(iterable)
|
||||
while (batch := list(islice(it, n))):
|
||||
yield batch
|
||||
|
||||
for iterable, n in product(
|
||||
['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', None],
|
||||
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]):
|
||||
with self.subTest(iterable=iterable, n=n):
|
||||
try:
|
||||
e1, r1 = None, list(batched(iterable, n))
|
||||
except Exception as e:
|
||||
e1, r1 = type(e), None
|
||||
try:
|
||||
e2, r2 = None, list(batched_recipe(iterable, n))
|
||||
except Exception as e:
|
||||
e2, r2 = type(e), None
|
||||
self.assertEqual(r1, r2)
|
||||
self.assertEqual(e1, e2)
|
||||
|
||||
@staticmethod
|
||||
def islice(iterable, *args):
|
||||
s = slice(*args)
|
||||
|
@ -1788,6 +1851,10 @@ class TestGC(unittest.TestCase):
|
|||
a = []
|
||||
self.makecycle(accumulate([1,2,a,3]), a)
|
||||
|
||||
def test_batched(self):
|
||||
a = []
|
||||
self.makecycle(batched([1,2,a,3], 2), a)
|
||||
|
||||
def test_chain(self):
|
||||
a = []
|
||||
self.makecycle(chain(a), a)
|
||||
|
@ -1972,6 +2039,18 @@ class TestVariousIteratorArgs(unittest.TestCase):
|
|||
self.assertRaises(TypeError, accumulate, N(s))
|
||||
self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
|
||||
|
||||
def test_batched(self):
|
||||
s = 'abcde'
|
||||
r = [['a', 'b'], ['c', 'd'], ['e']]
|
||||
n = 2
|
||||
for g in (G, I, Ig, L, R):
|
||||
with self.subTest(g=g):
|
||||
self.assertEqual(list(batched(g(s), n)), r)
|
||||
self.assertEqual(list(batched(S(s), 2)), [])
|
||||
self.assertRaises(TypeError, batched, X(s), 2)
|
||||
self.assertRaises(TypeError, batched, N(s), 2)
|
||||
self.assertRaises(ZeroDivisionError, list, batched(E(s), 2))
|
||||
|
||||
def test_chain(self):
|
||||
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
|
||||
for g in (G, I, Ig, S, L, R):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue