Optimize heapq.nsmallest/nlargest for cases where n==1 or n>=size.

This commit is contained in:
Raymond Hettinger 2009-01-12 10:37:32 +00:00
parent c22ab18e91
commit b5bc33cdab

View file

@ -129,7 +129,7 @@ From all times, sorting has always been a Great Art! :-)
__all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge', __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge',
'nlargest', 'nsmallest', 'heappushpop'] 'nlargest', 'nsmallest', 'heappushpop']
from itertools import islice, repeat, count, imap, izip, tee from itertools import islice, repeat, count, imap, izip, tee, chain
from operator import itemgetter, neg from operator import itemgetter, neg
import bisect import bisect
@ -354,10 +354,32 @@ def nsmallest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key)[:n] Equivalent to: sorted(iterable, key=key)[:n]
""" """
# Short-cut for n==1 is to use min() when len(iterable)>0
if n == 1:
it = iter(iterable)
head = list(islice(it, 1))
if not head:
return []
if key is None:
return [min(chain(head, it))]
return [min(chain(head, it), key=key)]
# When n>=size, it's faster to use sort()
try:
size = len(iterable)
except (TypeError, AttributeError):
pass
else:
if n >= size:
return sorted(iterable, key=key)[:n]
# When key is none, use simpler decoration
if key is None: if key is None:
it = izip(iterable, count()) # decorate it = izip(iterable, count()) # decorate
result = _nsmallest(n, it) result = _nsmallest(n, it)
return map(itemgetter(0), result) # undecorate return map(itemgetter(0), result) # undecorate
# General case, slowest method
in1, in2 = tee(iterable) in1, in2 = tee(iterable)
it = izip(imap(key, in1), count(), in2) # decorate it = izip(imap(key, in1), count(), in2) # decorate
result = _nsmallest(n, it) result = _nsmallest(n, it)
@ -369,10 +391,33 @@ def nlargest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key, reverse=True)[:n] Equivalent to: sorted(iterable, key=key, reverse=True)[:n]
""" """
# Short-cut for n==1 is to use max() when len(iterable)>0
if n == 1:
it = iter(iterable)
head = list(islice(it, 1))
if not head:
return []
if key is None:
return [max(chain(head, it))]
return [max(chain(head, it), key=key)]
# When n>=size, it's faster to use sort()
try:
size = len(iterable)
except (TypeError, AttributeError):
pass
else:
if n >= size:
return sorted(iterable, key=key, reverse=True)[:n]
# When key is none, use simpler decoration
if key is None: if key is None:
it = izip(iterable, imap(neg, count())) # decorate it = izip(iterable, imap(neg, count())) # decorate
result = _nlargest(n, it) result = _nlargest(n, it)
return map(itemgetter(0), result) # undecorate return map(itemgetter(0), result) # undecorate
# General case, slowest method
in1, in2 = tee(iterable) in1, in2 = tee(iterable)
it = izip(imap(key, in1), imap(neg, count()), in2) # decorate it = izip(imap(key, in1), imap(neg, count()), in2) # decorate
result = _nlargest(n, it) result = _nlargest(n, it)