mirror of
https://github.com/python/cpython.git
synced 2025-10-09 16:34:44 +00:00
bpo-37295: Optimize math.comb() and math.perm() (GH-29090)
For very large numbers, use a divide-and-conquer algorithm to benefit from Karatsuba multiplication of large numbers. Where possible, do the calculations entirely in C unsigned long long instead of Python integers.
This commit is contained in:
parent
628abe4463
commit
60c320c38e
3 changed files with 200 additions and 95 deletions
|
@ -351,6 +351,11 @@ Optimizations
|
|||
* Pure ASCII strings are now normalized in constant time by :func:`unicodedata.normalize`.
|
||||
(Contributed by Dong-hee Na in :issue:`44987`.)
|
||||
|
||||
* :mod:`math` functions :func:`~math.comb` and :func:`~math.perm` are now up
|
||||
to 10 times or more faster for large arguments (the speed up is larger for
|
||||
larger *k*).
|
||||
(Contributed by Serhiy Storchaka in :issue:`37295`.)
|
||||
|
||||
|
||||
CPython bytecode changes
|
||||
========================
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Optimize :func:`math.comb` and :func:`math.perm`.
|
|
@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
|
|||
}
|
||||
|
||||
|
||||
/* Number of permutations and combinations.
|
||||
* P(n, k) = n! / (n-k)!
|
||||
* C(n, k) = P(n, k) / k!
|
||||
*/
|
||||
|
||||
/* Calculate C(n, k) for n in the 63-bit range. */
|
||||
static PyObject *
|
||||
perm_comb_small(unsigned long long n, unsigned long long k, int iscomb)
|
||||
{
|
||||
/* long long is at least 64 bit */
|
||||
static const unsigned long long fast_comb_limits[] = {
|
||||
0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449, // 0-7
|
||||
746, 453, 308, 227, 178, 147, 125, 110, // 8-15
|
||||
99, 90, 84, 79, 75, 72, 69, 68, // 16-23
|
||||
66, 65, 64, 63, 63, 62, 62, 62, // 24-31
|
||||
};
|
||||
static const unsigned long long fast_perm_limits[] = {
|
||||
0, ULLONG_MAX, 4294967296ULL, 2642246, 65537, 7133, 1627, 568, // 0-7
|
||||
259, 142, 88, 61, 45, 36, 30, // 8-14
|
||||
};
|
||||
|
||||
if (k == 0) {
|
||||
return PyLong_FromLong(1);
|
||||
}
|
||||
|
||||
/* For small enough n and k the result fits in the 64-bit range and can
|
||||
* be calculated without allocating intermediate PyLong objects. */
|
||||
if (iscomb
|
||||
? (k < Py_ARRAY_LENGTH(fast_comb_limits)
|
||||
&& n <= fast_comb_limits[k])
|
||||
: (k < Py_ARRAY_LENGTH(fast_perm_limits)
|
||||
&& n <= fast_perm_limits[k]))
|
||||
{
|
||||
unsigned long long result = n;
|
||||
if (iscomb) {
|
||||
for (unsigned long long i = 1; i < k;) {
|
||||
result *= --n;
|
||||
result /= ++i;
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (unsigned long long i = 1; i < k;) {
|
||||
result *= --n;
|
||||
++i;
|
||||
}
|
||||
}
|
||||
return PyLong_FromUnsignedLongLong(result);
|
||||
}
|
||||
|
||||
/* For larger n use recursive formula. */
|
||||
/* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
|
||||
unsigned long long j = k / 2;
|
||||
PyObject *a, *b;
|
||||
a = perm_comb_small(n, j, iscomb);
|
||||
if (a == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
b = perm_comb_small(n - j, k - j, iscomb);
|
||||
if (b == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(a, PyNumber_Multiply(a, b));
|
||||
Py_DECREF(b);
|
||||
if (iscomb && a != NULL) {
|
||||
b = perm_comb_small(k, j, 1);
|
||||
if (b == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(a, PyNumber_FloorDivide(a, b));
|
||||
Py_DECREF(b);
|
||||
}
|
||||
return a;
|
||||
|
||||
error:
|
||||
Py_DECREF(a);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* Calculate P(n, k) or C(n, k) using recursive formulas.
|
||||
* It is more efficient than sequential multiplication thanks to
|
||||
* Karatsuba multiplication.
|
||||
*/
|
||||
static PyObject *
|
||||
perm_comb(PyObject *n, unsigned long long k, int iscomb)
|
||||
{
|
||||
if (k == 0) {
|
||||
return PyLong_FromLong(1);
|
||||
}
|
||||
if (k == 1) {
|
||||
Py_INCREF(n);
|
||||
return n;
|
||||
}
|
||||
|
||||
/* P(n, k) = P(n, j) * P(n-j, k-j) */
|
||||
/* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
|
||||
unsigned long long j = k / 2;
|
||||
PyObject *a, *b;
|
||||
a = perm_comb(n, j, iscomb);
|
||||
if (a == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *t = PyLong_FromUnsignedLongLong(j);
|
||||
if (t == NULL) {
|
||||
goto error;
|
||||
}
|
||||
n = PyNumber_Subtract(n, t);
|
||||
Py_DECREF(t);
|
||||
if (n == NULL) {
|
||||
goto error;
|
||||
}
|
||||
b = perm_comb(n, k - j, iscomb);
|
||||
Py_DECREF(n);
|
||||
if (b == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(a, PyNumber_Multiply(a, b));
|
||||
Py_DECREF(b);
|
||||
if (iscomb && a != NULL) {
|
||||
b = perm_comb_small(k, j, 1);
|
||||
if (b == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(a, PyNumber_FloorDivide(a, b));
|
||||
Py_DECREF(b);
|
||||
}
|
||||
return a;
|
||||
|
||||
error:
|
||||
Py_DECREF(a);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/*[clinic input]
|
||||
math.perm
|
||||
|
||||
|
@ -3244,9 +3376,9 @@ static PyObject *
|
|||
math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
|
||||
/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
|
||||
{
|
||||
PyObject *result = NULL, *factor = NULL;
|
||||
PyObject *result = NULL;
|
||||
int overflow, cmp;
|
||||
long long i, factors;
|
||||
long long ki, ni;
|
||||
|
||||
if (k == Py_None) {
|
||||
return math_factorial(module, n);
|
||||
|
@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
|
|||
Py_DECREF(n);
|
||||
return NULL;
|
||||
}
|
||||
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
|
||||
|
||||
if (Py_SIZE(n) < 0) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
|
@ -3281,57 +3414,38 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
|
|||
goto error;
|
||||
}
|
||||
|
||||
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
|
||||
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
|
||||
assert(overflow >= 0 && !PyErr_Occurred());
|
||||
if (overflow > 0) {
|
||||
PyErr_Format(PyExc_OverflowError,
|
||||
"k must not exceed %lld",
|
||||
LLONG_MAX);
|
||||
goto error;
|
||||
}
|
||||
else if (factors == -1) {
|
||||
/* k is nonnegative, so a return value of -1 can only indicate error */
|
||||
goto error;
|
||||
}
|
||||
assert(ki >= 0);
|
||||
|
||||
if (factors == 0) {
|
||||
result = PyLong_FromLong(1);
|
||||
goto done;
|
||||
ni = PyLong_AsLongLongAndOverflow(n, &overflow);
|
||||
assert(overflow >= 0 && !PyErr_Occurred());
|
||||
if (!overflow && ki > 1) {
|
||||
assert(ni >= 0);
|
||||
result = perm_comb_small((unsigned long long)ni,
|
||||
(unsigned long long)ki, 0);
|
||||
}
|
||||
|
||||
result = n;
|
||||
Py_INCREF(result);
|
||||
if (factors == 1) {
|
||||
goto done;
|
||||
else {
|
||||
result = perm_comb(n, (unsigned long long)ki, 0);
|
||||
}
|
||||
|
||||
factor = Py_NewRef(n);
|
||||
PyObject *one = _PyLong_GetOne(); // borrowed ref
|
||||
for (i = 1; i < factors; ++i) {
|
||||
Py_SETREF(factor, PyNumber_Subtract(factor, one));
|
||||
if (factor == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(result, PyNumber_Multiply(result, factor));
|
||||
if (result == NULL) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
Py_DECREF(factor);
|
||||
|
||||
done:
|
||||
Py_DECREF(n);
|
||||
Py_DECREF(k);
|
||||
return result;
|
||||
|
||||
error:
|
||||
Py_XDECREF(factor);
|
||||
Py_XDECREF(result);
|
||||
Py_DECREF(n);
|
||||
Py_DECREF(k);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
/*[clinic input]
|
||||
math.comb
|
||||
|
||||
|
@ -3357,9 +3471,9 @@ static PyObject *
|
|||
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
|
||||
/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
|
||||
{
|
||||
PyObject *result = NULL, *factor = NULL, *temp;
|
||||
PyObject *result = NULL, *temp;
|
||||
int overflow, cmp;
|
||||
long long i, factors;
|
||||
long long ki, ni;
|
||||
|
||||
n = PyNumber_Index(n);
|
||||
if (n == NULL) {
|
||||
|
@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
|
|||
Py_DECREF(n);
|
||||
return NULL;
|
||||
}
|
||||
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
|
||||
|
||||
if (Py_SIZE(n) < 0) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
|
@ -3382,73 +3497,59 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
|
|||
goto error;
|
||||
}
|
||||
|
||||
/* k = min(k, n - k) */
|
||||
temp = PyNumber_Subtract(n, k);
|
||||
if (temp == NULL) {
|
||||
goto error;
|
||||
}
|
||||
if (Py_SIZE(temp) < 0) {
|
||||
Py_DECREF(temp);
|
||||
result = PyLong_FromLong(0);
|
||||
goto done;
|
||||
}
|
||||
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
|
||||
if (cmp > 0) {
|
||||
Py_SETREF(k, temp);
|
||||
ni = PyLong_AsLongLongAndOverflow(n, &overflow);
|
||||
assert(overflow >= 0 && !PyErr_Occurred());
|
||||
if (!overflow) {
|
||||
assert(ni >= 0);
|
||||
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
|
||||
assert(overflow >= 0 && !PyErr_Occurred());
|
||||
if (overflow || ki > ni) {
|
||||
result = PyLong_FromLong(0);
|
||||
goto done;
|
||||
}
|
||||
assert(ki >= 0);
|
||||
ki = Py_MIN(ki, ni - ki);
|
||||
if (ki > 1) {
|
||||
result = perm_comb_small((unsigned long long)ni,
|
||||
(unsigned long long)ki, 1);
|
||||
goto done;
|
||||
}
|
||||
/* For k == 1 just return the original n in perm_comb(). */
|
||||
}
|
||||
else {
|
||||
Py_DECREF(temp);
|
||||
if (cmp < 0) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
|
||||
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
|
||||
if (overflow > 0) {
|
||||
PyErr_Format(PyExc_OverflowError,
|
||||
"min(n - k, k) must not exceed %lld",
|
||||
LLONG_MAX);
|
||||
goto error;
|
||||
}
|
||||
if (factors == -1) {
|
||||
/* k is nonnegative, so a return value of -1 can only indicate error */
|
||||
goto error;
|
||||
}
|
||||
|
||||
if (factors == 0) {
|
||||
result = PyLong_FromLong(1);
|
||||
goto done;
|
||||
}
|
||||
|
||||
result = n;
|
||||
Py_INCREF(result);
|
||||
if (factors == 1) {
|
||||
goto done;
|
||||
}
|
||||
|
||||
factor = Py_NewRef(n);
|
||||
PyObject *one = _PyLong_GetOne(); // borrowed ref
|
||||
for (i = 1; i < factors; ++i) {
|
||||
Py_SETREF(factor, PyNumber_Subtract(factor, one));
|
||||
if (factor == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(result, PyNumber_Multiply(result, factor));
|
||||
if (result == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
|
||||
/* k = min(k, n - k) */
|
||||
temp = PyNumber_Subtract(n, k);
|
||||
if (temp == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
|
||||
Py_DECREF(temp);
|
||||
if (result == NULL) {
|
||||
if (Py_SIZE(temp) < 0) {
|
||||
Py_DECREF(temp);
|
||||
result = PyLong_FromLong(0);
|
||||
goto done;
|
||||
}
|
||||
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
|
||||
if (cmp > 0) {
|
||||
Py_SETREF(k, temp);
|
||||
}
|
||||
else {
|
||||
Py_DECREF(temp);
|
||||
if (cmp < 0) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
|
||||
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
|
||||
assert(overflow >= 0 && !PyErr_Occurred());
|
||||
if (overflow) {
|
||||
PyErr_Format(PyExc_OverflowError,
|
||||
"min(n - k, k) must not exceed %lld",
|
||||
LLONG_MAX);
|
||||
goto error;
|
||||
}
|
||||
assert(ki >= 0);
|
||||
}
|
||||
Py_DECREF(factor);
|
||||
|
||||
result = perm_comb(n, (unsigned long long)ki, 1);
|
||||
|
||||
done:
|
||||
Py_DECREF(n);
|
||||
|
@ -3456,8 +3557,6 @@ done:
|
|||
return result;
|
||||
|
||||
error:
|
||||
Py_XDECREF(factor);
|
||||
Py_XDECREF(result);
|
||||
Py_DECREF(n);
|
||||
Py_DECREF(k);
|
||||
return NULL;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue