Added itertools.tee()

It works like the pure python verion except:
* it stops storing data after of the iterators gets deallocated
* the data queue is implemented with two stacks instead of one dictionary.
This commit is contained in:
Raymond Hettinger 2003-10-24 08:45:23 +00:00
parent 16b9fa8db3
commit 6a5b027742
4 changed files with 446 additions and 93 deletions

View file

@ -3,6 +3,7 @@ from test import test_support
from itertools import *
import sys
import operator
import random
def onearg(x):
'Test function of one argument'
@ -198,6 +199,50 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, dropwhile(10, [(4,5)]).next)
self.assertRaises(ValueError, dropwhile(errfunc, [(4,5)]).next)
def test_tee(self):
n = 100
def irange(n):
for i in xrange(n):
yield i
a, b = tee([]) # test empty iterator
self.assertEqual(list(a), [])
self.assertEqual(list(b), [])
a, b = tee(irange(n)) # test 100% interleaved
self.assertEqual(zip(a,b), zip(range(n),range(n)))
a, b = tee(irange(n)) # test 0% interleaved
self.assertEqual(list(a), range(n))
self.assertEqual(list(b), range(n))
a, b = tee(irange(n)) # test dealloc of leading iterator
self.assertEqual(a.next(), 0)
self.assertEqual(a.next(), 1)
del a
self.assertEqual(list(b), range(n))
a, b = tee(irange(n)) # test dealloc of trailing iterator
self.assertEqual(a.next(), 0)
self.assertEqual(a.next(), 1)
del b
self.assertEqual(list(a), range(2, n))
for j in xrange(5): # test randomly interleaved
order = [0]*n + [1]*n
random.shuffle(order)
lists = ([], [])
its = tee(irange(n))
for i in order:
value = its[i].next()
lists[i].append(value)
self.assertEqual(lists[0], range(n))
self.assertEqual(lists[1], range(n))
self.assertRaises(TypeError, tee)
self.assertRaises(TypeError, tee, 3)
self.assertRaises(TypeError, tee, [1,2], 'x')
def test_StopIteration(self):
self.assertRaises(StopIteration, izip().next)
@ -208,12 +253,65 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(StopIteration, islice([], None).next)
self.assertRaises(StopIteration, islice(StopNow(), None).next)
p, q = tee([])
self.assertRaises(StopIteration, p.next)
self.assertRaises(StopIteration, q.next)
p, q = tee(StopNow())
self.assertRaises(StopIteration, p.next)
self.assertRaises(StopIteration, q.next)
self.assertRaises(StopIteration, repeat(None, 0).next)
for f in (ifilter, ifilterfalse, imap, takewhile, dropwhile, starmap):
self.assertRaises(StopIteration, f(lambda x:x, []).next)
self.assertRaises(StopIteration, f(lambda x:x, StopNow()).next)
class TestGC(unittest.TestCase):
def makecycle(self, iterator, container):
container.append(iterator)
iterator.next()
del container, iterator
def test_chain(self):
a = []
self.makecycle(chain(a), a)
def test_cycle(self):
a = []
self.makecycle(cycle([a]*2), a)
def test_ifilter(self):
a = []
self.makecycle(ifilter(lambda x:True, [a]*2), a)
def test_ifilterfalse(self):
a = []
self.makecycle(ifilterfalse(lambda x:False, a), a)
def test_izip(self):
a = []
self.makecycle(izip([a]*2, [a]*3), a)
def test_imap(self):
a = []
self.makecycle(imap(lambda x:x, [a]*2), a)
def test_islice(self):
a = []
self.makecycle(islice([a]*2, None), a)
def test_starmap(self):
a = []
self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
def test_tee(self):
a = []
p, q = t = tee([a]*2)
a += [a, p, q, t]
p.next()
del a, p, q, t
def R(seqn):
'Regular generator'
for i in seqn:
@ -290,45 +388,6 @@ def L(seqn):
'Test multiple tiers of iterators'
return chain(imap(lambda x:x, R(Ig(G(seqn)))))
class TestGC(unittest.TestCase):
def makecycle(self, iterator, container):
container.append(iterator)
iterator.next()
del container, iterator
def test_chain(self):
a = []
self.makecycle(chain(a), a)
def test_cycle(self):
a = []
self.makecycle(cycle([a]*2), a)
def test_ifilter(self):
a = []
self.makecycle(ifilter(lambda x:True, [a]*2), a)
def test_ifilterfalse(self):
a = []
self.makecycle(ifilterfalse(lambda x:False, a), a)
def test_izip(self):
a = []
self.makecycle(izip([a]*2, [a]*3), a)
def test_imap(self):
a = []
self.makecycle(imap(lambda x:x, [a]*2), a)
def test_islice(self):
a = []
self.makecycle(islice([a]*2, None), a)
def test_starmap(self):
a = []
self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
class TestVariousIteratorArgs(unittest.TestCase):
@ -427,6 +486,16 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, list, dropwhile(isOdd, N(s)))
self.assertRaises(ZeroDivisionError, list, dropwhile(isOdd, E(s)))
def test_tee(self):
for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
for g in (G, I, Ig, S, L, R):
it1, it2 = tee(g(s))
self.assertEqual(list(it1), list(g(s)))
self.assertEqual(list(it2), list(g(s)))
self.assertRaises(TypeError, tee, X(s))
self.assertRaises(TypeError, list, tee(N(s))[0])
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
class RegressionTests(unittest.TestCase):
def test_sf_793826(self):
@ -531,6 +600,17 @@ Samuele
>>> def dotproduct(vec1, vec2):
... return sum(imap(operator.mul, vec1, vec2))
>>> def flatten(listOfLists):
... return list(chain(*listOfLists))
>>> def repeatfunc(func, times=None, *args):
... "Repeat calls to func with specified arguments."
... " Example: repeatfunc(random.random)"
... if times is None:
... return starmap(func, repeat(args))
... else:
... return starmap(func, repeat(args, times))
>>> def window(seq, n=2):
... "Returns a sliding window (of width n) over data from the iterable"
... " s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... "
@ -542,20 +622,6 @@ Samuele
... result = result[1:] + (elem,)
... yield result
>>> def tee(iterable):
... "Return two independent iterators from a single iterable"
... def gen(next, data={}, cnt=[0]):
... dpop = data.pop
... for i in count():
... if i == cnt[0]:
... item = data[i] = next()
... cnt[0] += 1
... else:
... item = dpop(i)
... yield item
... next = iter(iterable).next
... return (gen(next), gen(next))
This is not part of the examples but it tests to make sure the definitions
perform as purported.
@ -592,6 +658,17 @@ False
>>> quantify(xrange(99), lambda x: x%2==0)
50
>>> a = [[1, 2, 3], [4, 5, 6]]
>>> flatten(a)
[1, 2, 3, 4, 5, 6]
>>> list(repeatfunc(pow, 5, 2, 3))
[8, 8, 8, 8, 8]
>>> import random
>>> take(5, imap(int, repeatfunc(random.random)))
[0, 0, 0, 0, 0]
>>> list(window('abc'))
[('a', 'b'), ('b', 'c')]
@ -607,14 +684,6 @@ False
>>> dotproduct([1,2,3], [4,5,6])
32
>>> x, y = tee(chain(xrange(2,10)))
>>> list(x), list(y)
([2, 3, 4, 5, 6, 7, 8, 9], [2, 3, 4, 5, 6, 7, 8, 9])
>>> x, y = tee(chain(xrange(2,10)))
>>> zip(x, y)
[(2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)]
"""
__test__ = {'libreftest' : libreftest}