mirror of
https://github.com/python/cpython.git
synced 2025-12-04 08:34:25 +00:00
Rework integer overflow path in math.prod and add more tests (GH-11809)
The overflow check was relying on undefined behaviour as it was using the result of the multiplication to do the check, and once the overflow has already happened, any operation on the result is undefined behaviour. Some extra checks that exercise code paths related to this are also added.
This commit is contained in:
parent
62fa51f121
commit
0411411c6b
2 changed files with 137 additions and 40 deletions
|
|
@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase):
|
|||
self.fail('Failures in test_mtestfile:\n ' +
|
||||
'\n '.join(failures))
|
||||
|
||||
def test_prod(self):
|
||||
prod = math.prod
|
||||
self.assertEqual(prod([]), 1)
|
||||
self.assertEqual(prod([], start=5), 5)
|
||||
self.assertEqual(prod(list(range(2,8))), 5040)
|
||||
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
|
||||
self.assertEqual(prod(range(1, 10), start=10), 3628800)
|
||||
|
||||
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
|
||||
|
||||
# Test overflow in fast-path for integers
|
||||
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
|
||||
# Test overflow in fast-path for floats
|
||||
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
|
||||
|
||||
self.assertRaises(TypeError, prod)
|
||||
self.assertRaises(TypeError, prod, 42)
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
|
||||
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
|
||||
values = [bytearray(b'a'), bytearray(b'b')]
|
||||
self.assertRaises(TypeError, prod, values, bytearray(b''))
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]])
|
||||
self.assertRaises(TypeError, prod, [{2:3}])
|
||||
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
|
||||
with self.assertRaises(TypeError):
|
||||
prod([10, 20], [30, 40]) # start is a keyword-only argument
|
||||
|
||||
self.assertEqual(prod([0, 1, 2, 3]), 0)
|
||||
self.assertEqual(prod([1, 0, 2, 3]), 0)
|
||||
self.assertEqual(prod([1, 2, 3, 0]), 0)
|
||||
|
||||
def _naive_prod(iterable, start=1):
|
||||
for elem in iterable:
|
||||
start *= elem
|
||||
return start
|
||||
|
||||
# Big integers
|
||||
|
||||
iterable = range(1, 10000)
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = range(-10000, -1)
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = range(-1000, 1000)
|
||||
self.assertEqual(prod(iterable), 0)
|
||||
|
||||
# Big floats
|
||||
|
||||
iterable = [float(x) for x in range(1, 1000)]
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = [float(x) for x in range(-1000, -1)]
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = [float(x) for x in range(-1000, 1000)]
|
||||
self.assertIsNaN(prod(iterable))
|
||||
|
||||
# Float tests
|
||||
|
||||
self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
|
||||
self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
|
||||
self.assertIsNaN(prod([1, float("nan"), 0, 3]))
|
||||
self.assertIsNaN(prod([1, float("inf"), float("nan"),3]))
|
||||
self.assertIsNaN(prod([1, float("-inf"), float("nan"),3]))
|
||||
self.assertIsNaN(prod([1, float("nan"), float("inf"),3]))
|
||||
self.assertIsNaN(prod([1, float("nan"), float("-inf"),3]))
|
||||
|
||||
self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf'))
|
||||
self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf'))
|
||||
|
||||
self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4]))
|
||||
self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4]))
|
||||
self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
|
||||
self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
|
||||
|
||||
# Type preservation
|
||||
|
||||
self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int)
|
||||
self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float)
|
||||
self.assertEqual(type(prod(range(1, 10000))), int)
|
||||
self.assertEqual(type(prod(range(1, 10000), start=1.0)), float)
|
||||
self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
|
||||
decimal.Decimal)
|
||||
|
||||
# Custom assertions.
|
||||
|
||||
def assertIsNaN(self, value):
|
||||
|
|
@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase):
|
|||
self.assertAllClose(fraction_examples, rel_tol=1e-8)
|
||||
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
|
||||
|
||||
def test_prod(self):
|
||||
prod = math.prod
|
||||
self.assertEqual(prod([]), 1)
|
||||
self.assertEqual(prod([], start=5), 5)
|
||||
self.assertEqual(prod(list(range(2,8))), 5040)
|
||||
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
|
||||
self.assertEqual(prod(range(1, 10), start=10), 3628800)
|
||||
|
||||
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
|
||||
|
||||
# Test overflow in fast-path for integers
|
||||
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
|
||||
# Test overflow in fast-path for floats
|
||||
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
|
||||
|
||||
self.assertRaises(TypeError, prod)
|
||||
self.assertRaises(TypeError, prod, 42)
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
|
||||
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
|
||||
values = [bytearray(b'a'), bytearray(b'b')]
|
||||
self.assertRaises(TypeError, prod, values, bytearray(b''))
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]])
|
||||
self.assertRaises(TypeError, prod, [{2:3}])
|
||||
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
|
||||
with self.assertRaises(TypeError):
|
||||
prod([10, 20], [30, 40]) # start is a keyword-only argument
|
||||
|
||||
self.assertEqual(prod([0, 1, 2, 3]), 0)
|
||||
self.assertEqual(prod([1, 0, 2, 3]), 0)
|
||||
self.assertEqual(prod(range(10)), 0)
|
||||
|
||||
def test_main():
|
||||
from doctest import DocFileSuite
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue