Convert test_heapq.py to unittests.

This commit is contained in:
Raymond Hettinger 2004-06-10 05:07:18 +00:00
parent 33ecffb65a
commit bce036b49e

View file

@ -1,76 +1,81 @@
"""Unittests for heapq.""" """Unittests for heapq."""
from test.test_support import verify, vereq, verbose, TestFailed
from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest
import random import random
import unittest
from test import test_support
def check_invariant(heap):
# Check the heap invariant.
for pos, item in enumerate(heap):
if pos: # pos 0 has no parent
parentpos = (pos-1) >> 1
verify(heap[parentpos] <= item)
def heapiter(heap):
# An iterator returning a heap's elements, smallest-first. # An iterator returning a heap's elements, smallest-first.
class heapiter(object):
def __init__(self, heap):
self.heap = heap
def next(self):
try: try:
return heappop(self.heap) while 1:
yield heappop(heap)
except IndexError: except IndexError:
raise StopIteration pass
def __iter__(self): class TestHeap(unittest.TestCase):
return self
def test_main(): def test_push_pop(self):
# 1) Push 100 random numbers and pop them off, verifying all's OK. # 1) Push 256 random numbers and pop them off, verifying all's OK.
heap = [] heap = []
data = [] data = []
check_invariant(heap) self.check_invariant(heap)
for i in range(256): for i in range(256):
item = random.random() item = random.random()
data.append(item) data.append(item)
heappush(heap, item) heappush(heap, item)
check_invariant(heap) self.check_invariant(heap)
results = [] results = []
while heap: while heap:
item = heappop(heap) item = heappop(heap)
check_invariant(heap) self.check_invariant(heap)
results.append(item) results.append(item)
data_sorted = data[:] data_sorted = data[:]
data_sorted.sort() data_sorted.sort()
vereq(data_sorted, results) self.assertEqual(data_sorted, results)
# 2) Check that the invariant holds for a sorted array # 2) Check that the invariant holds for a sorted array
check_invariant(results) self.check_invariant(results)
# 3) Naive "N-best" algorithm
def check_invariant(self, heap):
# Check the heap invariant.
for pos, item in enumerate(heap):
if pos: # pos 0 has no parent
parentpos = (pos-1) >> 1
self.assert_(heap[parentpos] <= item)
def test_heapify(self):
for size in range(30):
heap = [random.random() for dummy in range(size)]
heapify(heap)
self.check_invariant(heap)
def test_naive_nbest(self):
data = [random.randrange(2000) for i in range(1000)]
heap = [] heap = []
for item in data: for item in data:
heappush(heap, item) heappush(heap, item)
if len(heap) > 10: if len(heap) > 10:
heappop(heap) heappop(heap)
heap.sort() heap.sort()
vereq(heap, data_sorted[-10:]) self.assertEqual(heap, sorted(data)[-10:])
# 4) Test heapify.
for size in range(30): def test_nbest(self):
heap = [random.random() for dummy in range(size)] # Less-naive "N-best" algorithm, much faster (if len(data) is big
heapify(heap)
check_invariant(heap)
# 5) Less-naive "N-best" algorithm, much faster (if len(data) is big
# enough <wink>) than sorting all of data. However, if we had a max # enough <wink>) than sorting all of data. However, if we had a max
# heap instead of a min heap, it could go faster still via # heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops # heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps). # (10 log-time steps).
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10] heap = data[:10]
heapify(heap) heapify(heap)
for item in data[10:]: for item in data[10:]:
if item > heap[0]: # this gets rarer the longer we run if item > heap[0]: # this gets rarer the longer we run
heapreplace(heap, item) heapreplace(heap, item)
vereq(list(heapiter(heap)), data_sorted[-10:]) self.assertEqual(list(heapiter(heap)), sorted(data)[-10:])
# 6) Exercise everything with repeated heapsort checks
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
for trial in xrange(100): for trial in xrange(100):
size = random.randrange(50) size = random.randrange(50)
data = [random.randrange(25) for i in range(size)] data = [random.randrange(25) for i in range(size)]
@ -81,21 +86,20 @@ def test_main():
heap = [] heap = []
for item in data: for item in data:
heappush(heap, item) heappush(heap, item)
data.sort() heap_sorted = [heappop(heap) for i in range(size)]
sorted = [heappop(heap) for i in range(size)] self.assertEqual(heap_sorted, sorted(data))
vereq(data, sorted)
# 7) Check nlargest() and nsmallest() def test_nsmallest(self):
data = [random.randrange(2000) for i in range(1000)] data = [random.randrange(2000) for i in range(1000)]
copy = data[:] self.assertEqual(nsmallest(data, 400), sorted(data)[:400])
copy.sort(reverse=True)
vereq(nlargest(data, 400), copy[:400])
copy.sort()
vereq(nsmallest(data, 400), copy[:400])
# Make user happy def test_largest(self):
if verbose: data = [random.randrange(2000) for i in range(1000)]
print "All OK" self.assertEqual(nlargest(data, 400), sorted(data, reverse=True)[:400])
def test_main():
test_support.run_unittest(TestHeap)
if __name__ == "__main__": if __name__ == "__main__":
test_main() test_main()