Fix stability of heapq's nlargest() and nsmallest().

This commit is contained in:
Raymond Hettinger 2007-01-04 17:53:34 +00:00
parent 2dc4db0174
commit 769a40a1d0
2 changed files with 14 additions and 18 deletions

View file

@ -130,7 +130,7 @@ __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'nlargest',
'nsmallest'] 'nsmallest']
from itertools import islice, repeat, count, imap, izip, tee from itertools import islice, repeat, count, imap, izip, tee
from operator import itemgetter from operator import itemgetter, neg
import bisect import bisect
def heappush(heap, item): def heappush(heap, item):
@ -315,8 +315,6 @@ def nsmallest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key)[:n] Equivalent to: sorted(iterable, key=key)[:n]
""" """
if key is None:
return _nsmallest(n, iterable)
in1, in2 = tee(iterable) in1, in2 = tee(iterable)
it = izip(imap(key, in1), count(), in2) # decorate it = izip(imap(key, in1), count(), in2) # decorate
result = _nsmallest(n, it) result = _nsmallest(n, it)
@ -328,10 +326,8 @@ def nlargest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key, reverse=True)[:n] Equivalent to: sorted(iterable, key=key, reverse=True)[:n]
""" """
if key is None:
return _nlargest(n, iterable)
in1, in2 = tee(iterable) in1, in2 = tee(iterable)
it = izip(imap(key, in1), count(), in2) # decorate it = izip(imap(key, in1), imap(neg, count()), in2) # decorate
result = _nlargest(n, it) result = _nlargest(n, it)
return map(itemgetter(2), result) # undecorate return map(itemgetter(2), result) # undecorate

View file

@ -104,20 +104,20 @@ class TestHeap(unittest.TestCase):
self.assertEqual(heap_sorted, sorted(data)) self.assertEqual(heap_sorted, sorted(data))
def test_nsmallest(self): def test_nsmallest(self):
data = [random.randrange(2000) for i in range(1000)] data = [(random.randrange(2000), i) for i in range(1000)]
f = lambda x: x * 547 % 2000 for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(nsmallest(n, data), sorted(data)[:n]) self.assertEqual(nsmallest(n, data), sorted(data)[:n])
self.assertEqual(nsmallest(n, data, key=f), self.assertEqual(nsmallest(n, data, key=f),
sorted(data, key=f)[:n]) sorted(data, key=f)[:n])
def test_nlargest(self): def test_nlargest(self):
data = [random.randrange(2000) for i in range(1000)] data = [(random.randrange(2000), i) for i in range(1000)]
f = lambda x: x * 547 % 2000 for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n]) self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n])
self.assertEqual(nlargest(n, data, key=f), self.assertEqual(nlargest(n, data, key=f),
sorted(data, key=f, reverse=True)[:n]) sorted(data, key=f, reverse=True)[:n])
#============================================================================== #==============================================================================