mirror of
https://github.com/python/cpython.git
synced 2025-12-04 08:34:25 +00:00
Issue #8692: Improve performance of math.factorial:
(1) use a different algorithm that roughly halves the total number of
multiplications required and results in more balanced multiplications
(2) use a lookup table for small arguments
(3) fast accumulation of products in C integer arithmetic rather than
PyLong arithmetic when possible.
Typical speedup, from unscientific testing on a 64-bit laptop, is 4.5x
to 6.5x for arguments in the range 100 - 10000.
Patch by Daniel Stutzbach; extensive reviews by Alexander Belopolsky.
This commit is contained in:
parent
ae6265f8d0
commit
4c8a9a2df3
3 changed files with 307 additions and 30 deletions
|
|
@ -60,6 +60,56 @@ def ulps_check(expected, got, ulps=20):
|
||||||
return "error = {} ulps; permitted error = {} ulps".format(ulps_error,
|
return "error = {} ulps; permitted error = {} ulps".format(ulps_error,
|
||||||
ulps)
|
ulps)
|
||||||
|
|
||||||
|
# Here's a pure Python version of the math.factorial algorithm, for
|
||||||
|
# documentation and comparison purposes.
|
||||||
|
#
|
||||||
|
# Formula:
|
||||||
|
#
|
||||||
|
# factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n))
|
||||||
|
#
|
||||||
|
# where
|
||||||
|
#
|
||||||
|
# factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j
|
||||||
|
#
|
||||||
|
# The outer product above is an infinite product, but once i >= n.bit_length(),
|
||||||
|
# (n >> i) < 1 and the corresponding term of the product is empty. So only the
|
||||||
|
# finitely many terms for 0 <= i < n.bit_length() contribute anything.
|
||||||
|
#
|
||||||
|
# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner
|
||||||
|
# product in the formula above starts at 1 for i == n.bit_length(); for each i
|
||||||
|
# < n.bit_length() we get the inner product for i from that for i + 1 by
|
||||||
|
# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms,
|
||||||
|
# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2).
|
||||||
|
|
||||||
|
def count_set_bits(n):
    """Number of '1' bits in binary expansion of a nonnegative integer.

    Raises ValueError if n is negative, since the two's-complement
    expansion of a negative integer has no finite number of '1' bits.
    """
    if n < 0:
        raise ValueError("count_set_bits() not defined for negative values")
    # Counting characters in the binary representation avoids recursion, so
    # inputs with thousands of set bits cannot hit the recursion limit.
    return bin(n).count("1")
|
||||||
|
|
||||||
|
def partial_product(start, stop):
    """Return the product of the integers in range(start, stop, 2).

    The range is split recursively into two halves of roughly equal
    length so that the operands of each multiplication stay balanced.

    Both start and stop are expected to be odd, with start <= stop.
    """
    count = (stop - start) >> 1
    if count == 0:
        # Empty range: the empty product.
        return 1
    if count == 1:
        # Exactly one factor in the range, namely start itself.
        return start
    # Split point: the middle of the range, forced to be odd.
    split = (start + count) | 1
    lower_half = partial_product(start, split)
    upper_half = partial_product(split, stop)
    return lower_half * upper_half
|
||||||
|
|
||||||
|
def py_factorial(n):
    """Return n! for a nonnegative integer n.

    Uses the "Binary Split Factorial Formula" described at
    http://www.luschny.de/math/factorial/binarysplitfact.html:
    the odd part of n! is accumulated term by term and then shifted
    left by n - count_set_bits(n) to restore the power of two.
    """
    running = 1      # product of all odd j with 0 < j <= n >> shift
    odd_part = 1     # product of the running terms seen so far
    for shift in range(n.bit_length() - 1, -1, -1):
        # Odd numbers in the half-open interval (n >> shift+1, n >> shift].
        low = ((n >> (shift + 1)) + 1) | 1
        high = ((n >> shift) + 1) | 1
        running *= partial_product(low, high)
        odd_part *= running
    return odd_part << (n - count_set_bits(n))
|
||||||
|
|
||||||
def acc_check(expected, got, rel_err=2e-15, abs_err = 5e-323):
|
def acc_check(expected, got, rel_err=2e-15, abs_err = 5e-323):
|
||||||
"""Determine whether non-NaN floats a and b are equal to within a
|
"""Determine whether non-NaN floats a and b are equal to within a
|
||||||
(small) rounding error. The default values for rel_err and
|
(small) rounding error. The default values for rel_err and
|
||||||
|
|
@ -365,18 +415,19 @@ class MathTests(unittest.TestCase):
|
||||||
self.ftest('fabs(1)', math.fabs(1), 1)
|
self.ftest('fabs(1)', math.fabs(1), 1)
|
||||||
|
|
||||||
def testFactorial(self):
|
def testFactorial(self):
|
||||||
def fact(n):
|
self.assertEqual(math.factorial(0), 1)
|
||||||
result = 1
|
self.assertEqual(math.factorial(0.0), 1)
|
||||||
for i in range(1, int(n)+1):
|
total = 1
|
||||||
result *= i
|
for i in range(1, 1000):
|
||||||
return result
|
total *= i
|
||||||
values = list(range(10)) + [50, 100, 500]
|
self.assertEqual(math.factorial(i), total)
|
||||||
random.shuffle(values)
|
self.assertEqual(math.factorial(float(i)), total)
|
||||||
for x in values:
|
self.assertEqual(math.factorial(i), py_factorial(i))
|
||||||
for cast in (int, float):
|
|
||||||
self.assertEqual(math.factorial(cast(x)), fact(x), (x, fact(x), math.factorial(x)))
|
|
||||||
self.assertRaises(ValueError, math.factorial, -1)
|
self.assertRaises(ValueError, math.factorial, -1)
|
||||||
|
self.assertRaises(ValueError, math.factorial, -1.0)
|
||||||
self.assertRaises(ValueError, math.factorial, math.pi)
|
self.assertRaises(ValueError, math.factorial, math.pi)
|
||||||
|
self.assertRaises(OverflowError, math.factorial, sys.maxsize+1)
|
||||||
|
self.assertRaises(OverflowError, math.factorial, 10e100)
|
||||||
|
|
||||||
def testFloor(self):
|
def testFloor(self):
|
||||||
self.assertRaises(TypeError, math.floor)
|
self.assertRaises(TypeError, math.floor)
|
||||||
|
|
|
||||||
|
|
@ -1132,6 +1132,12 @@ Library
|
||||||
Extension Modules
|
Extension Modules
|
||||||
-----------------
|
-----------------
|
||||||
|
|
||||||
|
- Issue #8692: Optimize math.factorial: replace the previous naive
|
||||||
|
algorithm with an improved 'binary-split' algorithm that uses fewer
|
||||||
|
multiplications and allows many of the multiplications to be
|
||||||
|
performed using plain C integer arithmetic instead of PyLong
|
||||||
|
arithmetic. Also uses a lookup table for small arguments.
|
||||||
|
|
||||||
- Issue #8674: Fixed a number of incorrect or undefined-behaviour-inducing
|
- Issue #8674: Fixed a number of incorrect or undefined-behaviour-inducing
|
||||||
overflow checks in the audioop module.
|
overflow checks in the audioop module.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1129,11 +1129,232 @@ PyDoc_STRVAR(math_fsum_doc,
|
||||||
Return an accurate floating point sum of values in the iterable.\n\
|
Return an accurate floating point sum of values in the iterable.\n\
|
||||||
Assumes IEEE-754 floating point arithmetic.");
|
Assumes IEEE-754 floating point arithmetic.");
|
||||||
|
|
||||||
|
/* Return the smallest integer k such that n < 2**k, or 0 if n == 0.
|
||||||
|
* Equivalent to floor(lg(x))+1. Also equivalent to: bitwidth_of_type -
|
||||||
|
* count_leading_zero_bits(x)
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* XXX: This routine does more or less the same thing as
|
||||||
|
* bits_in_digit() in Objects/longobject.c. Someday it would be nice to
|
||||||
|
* consolidate them. On BSD, there's a library function called fls()
|
||||||
|
* that we could use, and GCC provides __builtin_clz().
|
||||||
|
*/
|
||||||
|
|
||||||
|
static unsigned long
bit_length(unsigned long n)
{
    unsigned long nbits;

    /* Shift n right until nothing is left; the number of shifts needed
       is the position of the highest set bit (0 when n == 0). */
    for (nbits = 0; n != 0; n >>= 1)
        nbits++;
    return nbits;
}
|
||||||
|
|
||||||
|
static unsigned long
count_set_bits(unsigned long n)
{
    unsigned long ones;

    /* n & (n - 1) clears the lowest set bit of n, so the loop body runs
       exactly once per '1' bit. */
    for (ones = 0; n != 0; n &= n - 1)
        ones++;
    return ones;
}
|
||||||
|
|
||||||
|
/* Divide-and-conquer factorial algorithm
|
||||||
|
*
|
||||||
|
* Based on the formula and pseudo-code provided at:
|
||||||
|
* http://www.luschny.de/math/factorial/binarysplitfact.html
|
||||||
|
*
|
||||||
|
* Faster algorithms exist, but they're more complicated and depend on
|
||||||
|
* a fast prime factorization algorithm.
|
||||||
|
*
|
||||||
|
* Notes on the algorithm
|
||||||
|
* ----------------------
|
||||||
|
*
|
||||||
|
* factorial(n) is written in the form 2**k * m, with m odd. k and m are
|
||||||
|
* computed separately, and then combined using a left shift.
|
||||||
|
*
|
||||||
|
* The function factorial_odd_part computes the odd part m (i.e., the greatest
|
||||||
|
* odd divisor) of factorial(n), using the formula:
|
||||||
|
*
|
||||||
|
* factorial_odd_part(n) =
|
||||||
|
*
|
||||||
|
* product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j
|
||||||
|
*
|
||||||
|
* Example: factorial_odd_part(20) =
|
||||||
|
*
|
||||||
|
* (1) *
|
||||||
|
* (1) *
|
||||||
|
* (1 * 3 * 5) *
|
||||||
|
* (1 * 3 * 5 * 7 * 9) *
|
||||||
|
* (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
|
||||||
|
*
|
||||||
|
* Here i goes from large to small: the first term corresponds to i=4 (any
|
||||||
|
* larger i gives an empty product), and the last term corresponds to i=0.
|
||||||
|
* Each term can be computed from the last by multiplying by the extra odd
|
||||||
|
* numbers required: e.g., to get from the penultimate term to the last one,
|
||||||
|
* we multiply by (11 * 13 * 15 * 17 * 19).
|
||||||
|
*
|
||||||
|
* To see a hint of why this formula works, here are the same numbers as above
|
||||||
|
* but with the even parts (i.e., the appropriate powers of 2) included. For
|
||||||
|
* each subterm in the product for i, we multiply that subterm by 2**i:
|
||||||
|
*
|
||||||
|
* factorial(20) =
|
||||||
|
*
|
||||||
|
* (16) *
|
||||||
|
* (8) *
|
||||||
|
* (4 * 12 * 20) *
|
||||||
|
* (2 * 6 * 10 * 14 * 18) *
|
||||||
|
* (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
|
||||||
|
*
|
||||||
|
* The factorial_partial_product function computes the product of all odd j in
|
||||||
|
* range(start, stop) for given start and stop. It's used to compute the
|
||||||
|
* partial products like (11 * 13 * 15 * 17 * 19) in the example above. It
|
||||||
|
* operates recursively, repeatedly splitting the range into two roughly equal
|
||||||
|
* pieces until the subranges are small enough to be computed using only C
|
||||||
|
* integer arithmetic.
|
||||||
|
*
|
||||||
|
* The two-valuation k (i.e., the exponent of the largest power of 2 dividing
|
||||||
|
* the factorial) is computed independently in the main math_factorial
|
||||||
|
* function. By standard results, its value is:
|
||||||
|
*
|
||||||
|
* two_valuation = n//2 + n//4 + n//8 + ....
|
||||||
|
*
|
||||||
|
* It can be shown (e.g., by complete induction on n) that two_valuation is
|
||||||
|
* equal to n - count_set_bits(n), where count_set_bits(n) gives the number of
|
||||||
|
* '1'-bits in the binary expansion of n.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* factorial_partial_product: Compute product(range(start, stop, 2)) using
|
||||||
|
* divide and conquer. Assumes start and stop are odd and stop > start.
|
||||||
|
* max_bits must be >= bit_length(stop - 2). */
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
factorial_partial_product(unsigned long start, unsigned long stop,
|
||||||
|
unsigned long max_bits)
|
||||||
|
{
|
||||||
|
unsigned long midpoint, num_operands;
|
||||||
|
PyObject *left = NULL, *right = NULL, *result = NULL;
|
||||||
|
|
||||||
|
/* If the return value will fit an unsigned long, then we can
|
||||||
|
* multiply in a tight, fast loop where each multiply is O(1).
|
||||||
|
* Compute an upper bound on the number of bits required to store
|
||||||
|
* the answer.
|
||||||
|
*
|
||||||
|
* Storing some integer z requires floor(lg(z))+1 bits, which is
|
||||||
|
* conveniently the value returned by bit_length(z). The
|
||||||
|
* product x*y will require at most
|
||||||
|
* bit_length(x) + bit_length(y) bits to store, based
|
||||||
|
* on the idea that lg product = lg x + lg y.
|
||||||
|
*
|
||||||
|
* We know that stop - 2 is the largest number to be multiplied. From
|
||||||
|
* there, we have: bit_length(answer) <= num_operands *
|
||||||
|
* bit_length(stop - 2)
|
||||||
|
*/
|
||||||
|
|
||||||
|
num_operands = (stop - start) / 2;
|
||||||
|
/* The "num_operands <= 8 * SIZEOF_LONG" check guards against the
|
||||||
|
* unlikely case of an overflow in num_operands * max_bits. */
|
||||||
|
if (num_operands <= 8 * SIZEOF_LONG &&
|
||||||
|
num_operands * max_bits <= 8 * SIZEOF_LONG) {
|
||||||
|
unsigned long j, total;
|
||||||
|
for (total = start, j = start + 2; j < stop; j += 2)
|
||||||
|
total *= j;
|
||||||
|
return PyLong_FromUnsignedLong(total);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* find midpoint of range(start, stop), rounded up to next odd number. */
|
||||||
|
midpoint = (start + num_operands) | 1;
|
||||||
|
left = factorial_partial_product(start, midpoint,
|
||||||
|
bit_length(midpoint - 2));
|
||||||
|
if (left == NULL)
|
||||||
|
goto error;
|
||||||
|
right = factorial_partial_product(midpoint, stop, max_bits);
|
||||||
|
if (right == NULL)
|
||||||
|
goto error;
|
||||||
|
result = PyNumber_Multiply(left, right);
|
||||||
|
|
||||||
|
error:
|
||||||
|
Py_XDECREF(left);
|
||||||
|
Py_XDECREF(right);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* factorial_odd_part: compute the odd part of factorial(n). */
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
factorial_odd_part(unsigned long n)
|
||||||
|
{
|
||||||
|
long i;
|
||||||
|
unsigned long v, lower, upper;
|
||||||
|
PyObject *partial, *tmp, *inner, *outer;
|
||||||
|
|
||||||
|
inner = PyLong_FromLong(1);
|
||||||
|
if (inner == NULL)
|
||||||
|
return NULL;
|
||||||
|
outer = inner;
|
||||||
|
Py_INCREF(outer);
|
||||||
|
|
||||||
|
upper = 3;
|
||||||
|
for (i = bit_length(n) - 2; i >= 0; i--) {
|
||||||
|
v = n >> i;
|
||||||
|
if (v <= 2)
|
||||||
|
continue;
|
||||||
|
lower = upper;
|
||||||
|
/* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */
|
||||||
|
upper = (v + 1) | 1;
|
||||||
|
/* Here inner is the product of all odd integers j in the range (0,
|
||||||
|
n/2**(i+1)]. The factorial_partial_product call below gives the
|
||||||
|
product of all odd integers j in the range (n/2**(i+1), n/2**i]. */
|
||||||
|
partial = factorial_partial_product(lower, upper, bit_length(upper-2));
|
||||||
|
/* inner *= partial */
|
||||||
|
if (partial == NULL)
|
||||||
|
goto error;
|
||||||
|
tmp = PyNumber_Multiply(inner, partial);
|
||||||
|
Py_DECREF(partial);
|
||||||
|
if (tmp == NULL)
|
||||||
|
goto error;
|
||||||
|
Py_DECREF(inner);
|
||||||
|
inner = tmp;
|
||||||
|
/* Now inner is the product of all odd integers j in the range (0,
|
||||||
|
n/2**i], giving the inner product in the formula above. */
|
||||||
|
|
||||||
|
/* outer *= inner; */
|
||||||
|
tmp = PyNumber_Multiply(outer, inner);
|
||||||
|
if (tmp == NULL)
|
||||||
|
goto error;
|
||||||
|
Py_DECREF(outer);
|
||||||
|
outer = tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
goto done;
|
||||||
|
|
||||||
|
error:
|
||||||
|
Py_DECREF(outer);
|
||||||
|
done:
|
||||||
|
Py_DECREF(inner);
|
||||||
|
return outer;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Lookup table for small factorial values */
|
||||||
|
|
||||||
|
static const unsigned long SmallFactorials[] = {
    /* 0! through 12!: every value is below 2**32. */
    1,
    1,
    2,
    6,
    24,
    120,
    720,
    5040,
    40320,
    362880,
    3628800,
    39916800,
    479001600,
#if SIZEOF_LONG >= 8
    /* 13! through 20!: only included when long is at least 8 bytes. */
    6227020800,
    87178291200,
    1307674368000,
    20922789888000,
    355687428096000,
    6402373705728000,
    121645100408832000,
    2432902008176640000
#endif
};
|
||||||
|
|
||||||
static PyObject *
|
static PyObject *
|
||||||
math_factorial(PyObject *self, PyObject *arg)
|
math_factorial(PyObject *self, PyObject *arg)
|
||||||
{
|
{
|
||||||
long i, x;
|
long x;
|
||||||
PyObject *result, *iobj, *newresult;
|
PyObject *result, *odd_part, *two_valuation;
|
||||||
|
|
||||||
if (PyFloat_Check(arg)) {
|
if (PyFloat_Check(arg)) {
|
||||||
PyObject *lx;
|
PyObject *lx;
|
||||||
|
|
@ -1160,25 +1381,24 @@ math_factorial(PyObject *self, PyObject *arg)
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
result = (PyObject *)PyLong_FromLong(1);
|
/* use lookup table if x is small */
|
||||||
if (result == NULL)
|
if (x < (long)(sizeof(SmallFactorials)/sizeof(SmallFactorials[0])))
|
||||||
return NULL;
|
return PyLong_FromUnsignedLong(SmallFactorials[x]);
|
||||||
for (i=1 ; i<=x ; i++) {
|
|
||||||
iobj = (PyObject *)PyLong_FromLong(i);
|
|
||||||
if (iobj == NULL)
|
|
||||||
goto error;
|
|
||||||
newresult = PyNumber_Multiply(result, iobj);
|
|
||||||
Py_DECREF(iobj);
|
|
||||||
if (newresult == NULL)
|
|
||||||
goto error;
|
|
||||||
Py_DECREF(result);
|
|
||||||
result = newresult;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
|
|
||||||
error:
|
/* else express in the form odd_part * 2**two_valuation, and compute as
|
||||||
Py_DECREF(result);
|
odd_part << two_valuation. */
|
||||||
|
odd_part = factorial_odd_part(x);
|
||||||
|
if (odd_part == NULL)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
two_valuation = PyLong_FromLong(x - count_set_bits(x));
|
||||||
|
if (two_valuation == NULL) {
|
||||||
|
Py_DECREF(odd_part);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
result = PyNumber_Lshift(odd_part, two_valuation);
|
||||||
|
Py_DECREF(two_valuation);
|
||||||
|
Py_DECREF(odd_part);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
PyDoc_STRVAR(math_factorial_doc,
|
PyDoc_STRVAR(math_factorial_doc,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue