gh-110067: Make max heap methods public and add missing ones (GH-130725)

Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Co-authored-by: Petr Viktorin <encukou@gmail.com>
This commit is contained in:
Stan Ulbrych 2025-05-05 16:52:49 +01:00 committed by GitHub
parent bb5ec6ea6e
commit f5b784741d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 524 additions and 123 deletions

View file

@ -13,8 +13,9 @@ c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
# _heapq is imported, so check them there
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
'_heappop_max', '_heapreplace_max', '_heapify_max']
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace']
# Add max-heap variants
func_names += [func + '_max' for func in func_names]
class TestModules(TestCase):
def test_py_functions(self):
@ -24,7 +25,7 @@ class TestModules(TestCase):
@skipUnless(c_heapq, 'requires _heapq')
def test_c_functions(self):
for fname in func_names:
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq', fname)
def load_tests(loader, tests, ignore):
@ -74,6 +75,34 @@ class TestHeap:
except AttributeError:
pass
def test_max_push_pop(self):
# 1) Push 256 random numbers and pop them off, verifying all's OK.
heap = []
data = []
self.check_max_invariant(heap)
for i in range(256):
item = random.random()
data.append(item)
self.module.heappush_max(heap, item)
self.check_max_invariant(heap)
results = []
while heap:
item = self.module.heappop_max(heap)
self.check_max_invariant(heap)
results.append(item)
data_sorted = data[:]
data_sorted.sort(reverse=True)
self.assertEqual(data_sorted, results)
# 2) Check that the invariant holds for a sorted array
self.check_max_invariant(results)
self.assertRaises(TypeError, self.module.heappush_max, [])
exc_types = (AttributeError, TypeError)
self.assertRaises(exc_types, self.module.heappush_max, None, None)
self.assertRaises(exc_types, self.module.heappop_max, None)
def check_invariant(self, heap):
# Check the heap invariant.
for pos, item in enumerate(heap):
@ -81,6 +110,11 @@ class TestHeap:
parentpos = (pos-1) >> 1
self.assertTrue(heap[parentpos] <= item)
def check_max_invariant(self, heap):
for pos, item in enumerate(heap[1:], start=1):
parentpos = (pos - 1) >> 1
self.assertGreaterEqual(heap[parentpos], item)
def test_heapify(self):
for size in list(range(30)) + [20000]:
heap = [random.random() for dummy in range(size)]
@ -89,6 +123,14 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapify, None)
def test_heapify_max(self):
for size in list(range(30)) + [20000]:
heap = [random.random() for dummy in range(size)]
self.module.heapify_max(heap)
self.check_max_invariant(heap)
self.assertRaises(TypeError, self.module.heapify_max, None)
def test_naive_nbest(self):
data = [random.randrange(2000) for i in range(1000)]
heap = []
@ -109,10 +151,7 @@ class TestHeap:
def test_nbest(self):
# Less-naive "N-best" algorithm, much faster (if len(data) is big
# enough <wink>) than sorting all of data. However, if we had a max
# heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps).
# enough <wink>) than sorting all of data.
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
@ -125,6 +164,17 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapreplace, None, None)
self.assertRaises(IndexError, self.module.heapreplace, [], None)
def test_nbest_maxheap(self):
# With a max heap instead of a min heap, the "N-best" algorithm can
# go even faster still via heapify'ing all of data (linear time), then
# doing 10 heappops (10 log-time steps).
data = [random.randrange(2000) for i in range(1000)]
heap = data[:]
self.module.heapify_max(heap)
result = [self.module.heappop_max(heap) for _ in range(10)]
result.reverse()
self.assertEqual(result, sorted(data)[-10:])
def test_nbest_with_pushpop(self):
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
@ -134,6 +184,62 @@ class TestHeap:
self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
self.assertEqual(self.module.heappushpop([], 'x'), 'x')
def test_naive_nworst(self):
# Max-heap variant of "test_naive_nbest"
data = [random.randrange(2000) for i in range(1000)]
heap = []
for item in data:
self.module.heappush_max(heap, item)
if len(heap) > 10:
self.module.heappop_max(heap)
heap.sort()
expected = sorted(data)[:10]
self.assertEqual(heap, expected)
def heapiter_max(self, heap):
# An iterator returning a max-heap's elements, largest-first.
try:
while 1:
yield self.module.heappop_max(heap)
except IndexError:
pass
def test_nworst(self):
# Max-heap variant of "test_nbest"
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify_max(heap)
for item in data[10:]:
if item < heap[0]: # this gets rarer the longer we run
self.module.heapreplace_max(heap, item)
expected = sorted(data, reverse=True)[-10:]
self.assertEqual(list(self.heapiter_max(heap)), expected)
self.assertRaises(TypeError, self.module.heapreplace_max, None)
self.assertRaises(TypeError, self.module.heapreplace_max, None, None)
self.assertRaises(IndexError, self.module.heapreplace_max, [], None)
def test_nworst_minheap(self):
# Min-heap variant of "test_nbest_maxheap"
data = [random.randrange(2000) for i in range(1000)]
heap = data[:]
self.module.heapify(heap)
result = [self.module.heappop(heap) for _ in range(10)]
result.reverse()
expected = sorted(data, reverse=True)[-10:]
self.assertEqual(result, expected)
def test_nworst_with_pushpop(self):
# Max-heap variant of "test_nbest_with_pushpop"
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify_max(heap)
for item in data[10:]:
self.module.heappushpop_max(heap, item)
expected = sorted(data, reverse=True)[-10:]
self.assertEqual(list(self.heapiter_max(heap)), expected)
self.assertEqual(self.module.heappushpop_max([], 'x'), 'x')
def test_heappushpop(self):
h = []
x = self.module.heappushpop(h, 10)
@ -153,12 +259,31 @@ class TestHeap:
x = self.module.heappushpop(h, 11)
self.assertEqual((h, x), ([11], 10))
def test_heappushpop_max(self):
h = []
x = self.module.heappushpop_max(h, 10)
self.assertTupleEqual((h, x), ([], 10))
h = [10]
x = self.module.heappushpop_max(h, 10.0)
self.assertTupleEqual((h, x), ([10], 10.0))
self.assertIsInstance(h[0], int)
self.assertIsInstance(x, float)
h = [10]
x = self.module.heappushpop_max(h, 11)
self.assertTupleEqual((h, x), ([10], 11))
h = [10]
x = self.module.heappushpop_max(h, 9)
self.assertTupleEqual((h, x), ([9], 10))
def test_heappop_max(self):
# _heapop_max has an optimization for one-item lists which isn't
# heapop_max has an optimization for one-item lists which isn't
# covered in other tests, so test that case explicitly here
h = [3, 2]
self.assertEqual(self.module._heappop_max(h), 3)
self.assertEqual(self.module._heappop_max(h), 2)
self.assertEqual(self.module.heappop_max(h), 3)
self.assertEqual(self.module.heappop_max(h), 2)
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
@ -175,6 +300,20 @@ class TestHeap:
heap_sorted = [self.module.heappop(heap) for i in range(size)]
self.assertEqual(heap_sorted, sorted(data))
def test_heapsort_max(self):
for trial in range(100):
size = random.randrange(50)
data = [random.randrange(25) for i in range(size)]
if trial & 1: # Half of the time, use heapify_max
heap = data[:]
self.module.heapify_max(heap)
else: # The rest of the time, use heappush_max
heap = []
for item in data:
self.module.heappush_max(heap, item)
heap_sorted = [self.module.heappop_max(heap) for i in range(size)]
self.assertEqual(heap_sorted, sorted(data, reverse=True))
def test_merge(self):
inputs = []
for i in range(random.randrange(25)):
@ -377,16 +516,20 @@ class SideEffectLT:
class TestErrorHandling:
def test_non_sequence(self):
for f in (self.module.heapify, self.module.heappop):
for f in (self.module.heapify, self.module.heappop,
self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, 10)
for f in (self.module.heappush, self.module.heapreplace,
self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10, 10)
def test_len_only(self):
for f in (self.module.heapify, self.module.heappop):
for f in (self.module.heapify, self.module.heappop,
self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly())
for f in (self.module.heappush, self.module.heapreplace):
for f in (self.module.heappush, self.module.heapreplace,
self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(TypeError, f, 2, LenOnly())
@ -395,7 +538,8 @@ class TestErrorHandling:
seq = [CmpErr(), CmpErr(), CmpErr()]
for f in (self.module.heapify, self.module.heappop):
self.assertRaises(ZeroDivisionError, f, seq)
for f in (self.module.heappush, self.module.heapreplace):
for f in (self.module.heappush, self.module.heapreplace,
self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises(ZeroDivisionError, f, seq, 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(ZeroDivisionError, f, 2, seq)
@ -403,6 +547,8 @@ class TestErrorHandling:
def test_arg_parsing(self):
for f in (self.module.heapify, self.module.heappop,
self.module.heappush, self.module.heapreplace,
self.module.heapify_max, self.module.heappop_max,
self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10)
@ -424,6 +570,10 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappush(heap, SideEffectLT(5, heap))
heap = []
heap.extend(SideEffectLT(i, heap) for i in range(200))
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappush_max(heap, SideEffectLT(5, heap))
def test_heappop_mutating_heap(self):
heap = []
@ -431,8 +581,12 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappop(heap)
heap = []
heap.extend(SideEffectLT(i, heap) for i in range(200))
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappop_max(heap)
def test_comparison_operator_modifiying_heap(self):
def test_comparison_operator_modifying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
class EvilClass(int):
@ -444,7 +598,7 @@ class TestErrorHandling:
self.module.heappush(heap, EvilClass(0))
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
def test_comparison_operator_modifiying_heap_two_heaps(self):
def test_comparison_operator_modifying_heap_two_heaps(self):
class h(int):
def __lt__(self, o):
@ -464,6 +618,17 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
list1, list2 = [], []
self.module.heappush_max(list1, h(0))
self.module.heappush_max(list2, g(0))
self.module.heappush_max(list1, g(1))
self.module.heappush_max(list2, h(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list2, h(1))
class TestErrorHandlingPython(TestErrorHandling, TestCase):
module = py_heapq