bpo-37295: Optimize math.comb() and math.perm() (GH-29090)

For very large numbers, use a divide-and-conquer algorithm to benefit
from Karatsuba multiplication of large numbers.

When possible, do the calculations entirely in C unsigned long long
instead of Python integers.
This commit is contained in:
Serhiy Storchaka 2021-12-05 22:26:10 +02:00 committed by GitHub
parent 628abe4463
commit 60c320c38e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 200 additions and 95 deletions

View file

@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
}
/* Number of permutations and combinations.
* P(n, k) = n! / (n-k)!
* C(n, k) = P(n, k) / k!
*/
/* Calculate C(n, k) for n in the 63-bit range. */
static PyObject *
perm_comb_small(unsigned long long n, unsigned long long k, int iscomb)
{
/* long long is at least 64 bit */
static const unsigned long long fast_comb_limits[] = {
0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449, // 0-7
746, 453, 308, 227, 178, 147, 125, 110, // 8-15
99, 90, 84, 79, 75, 72, 69, 68, // 16-23
66, 65, 64, 63, 63, 62, 62, 62, // 24-31
};
static const unsigned long long fast_perm_limits[] = {
0, ULLONG_MAX, 4294967296ULL, 2642246, 65537, 7133, 1627, 568, // 0-7
259, 142, 88, 61, 45, 36, 30, // 8-14
};
if (k == 0) {
return PyLong_FromLong(1);
}
/* For small enough n and k the result fits in the 64-bit range and can
* be calculated without allocating intermediate PyLong objects. */
if (iscomb
? (k < Py_ARRAY_LENGTH(fast_comb_limits)
&& n <= fast_comb_limits[k])
: (k < Py_ARRAY_LENGTH(fast_perm_limits)
&& n <= fast_perm_limits[k]))
{
unsigned long long result = n;
if (iscomb) {
for (unsigned long long i = 1; i < k;) {
result *= --n;
result /= ++i;
}
}
else {
for (unsigned long long i = 1; i < k;) {
result *= --n;
++i;
}
}
return PyLong_FromUnsignedLongLong(result);
}
/* For larger n use recursive formula. */
/* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
unsigned long long j = k / 2;
PyObject *a, *b;
a = perm_comb_small(n, j, iscomb);
if (a == NULL) {
return NULL;
}
b = perm_comb_small(n - j, k - j, iscomb);
if (b == NULL) {
goto error;
}
Py_SETREF(a, PyNumber_Multiply(a, b));
Py_DECREF(b);
if (iscomb && a != NULL) {
b = perm_comb_small(k, j, 1);
if (b == NULL) {
goto error;
}
Py_SETREF(a, PyNumber_FloorDivide(a, b));
Py_DECREF(b);
}
return a;
error:
Py_DECREF(a);
return NULL;
}
/* Calculate P(n, k) or C(n, k) using recursive formulas.
* It is more efficient than sequential multiplication thanks to
* Karatsuba multiplication.
*/
static PyObject *
perm_comb(PyObject *n, unsigned long long k, int iscomb)
{
if (k == 0) {
return PyLong_FromLong(1);
}
if (k == 1) {
Py_INCREF(n);
return n;
}
/* P(n, k) = P(n, j) * P(n-j, k-j) */
/* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
unsigned long long j = k / 2;
PyObject *a, *b;
a = perm_comb(n, j, iscomb);
if (a == NULL) {
return NULL;
}
PyObject *t = PyLong_FromUnsignedLongLong(j);
if (t == NULL) {
goto error;
}
n = PyNumber_Subtract(n, t);
Py_DECREF(t);
if (n == NULL) {
goto error;
}
b = perm_comb(n, k - j, iscomb);
Py_DECREF(n);
if (b == NULL) {
goto error;
}
Py_SETREF(a, PyNumber_Multiply(a, b));
Py_DECREF(b);
if (iscomb && a != NULL) {
b = perm_comb_small(k, j, 1);
if (b == NULL) {
goto error;
}
Py_SETREF(a, PyNumber_FloorDivide(a, b));
Py_DECREF(b);
}
return a;
error:
Py_DECREF(a);
return NULL;
}
/*[clinic input]
math.perm
@ -3244,9 +3376,9 @@ static PyObject *
math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
{
PyObject *result = NULL, *factor = NULL;
PyObject *result = NULL;
int overflow, cmp;
long long i, factors;
long long ki, ni;
if (k == Py_None) {
return math_factorial(module, n);
@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
@ -3281,57 +3414,38 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
"k must not exceed %lld",
LLONG_MAX);
goto error;
}
else if (factors == -1) {
/* k is nonnegative, so a return value of -1 can only indicate error */
goto error;
}
assert(ki >= 0);
if (factors == 0) {
result = PyLong_FromLong(1);
goto done;
ni = PyLong_AsLongLongAndOverflow(n, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (!overflow && ki > 1) {
assert(ni >= 0);
result = perm_comb_small((unsigned long long)ni,
(unsigned long long)ki, 0);
}
result = n;
Py_INCREF(result);
if (factors == 1) {
goto done;
else {
result = perm_comb(n, (unsigned long long)ki, 0);
}
factor = Py_NewRef(n);
PyObject *one = _PyLong_GetOne(); // borrowed ref
for (i = 1; i < factors; ++i) {
Py_SETREF(factor, PyNumber_Subtract(factor, one));
if (factor == NULL) {
goto error;
}
Py_SETREF(result, PyNumber_Multiply(result, factor));
if (result == NULL) {
goto error;
}
}
Py_DECREF(factor);
done:
Py_DECREF(n);
Py_DECREF(k);
return result;
error:
Py_XDECREF(factor);
Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;
}
/*[clinic input]
math.comb
@ -3357,9 +3471,9 @@ static PyObject *
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
{
PyObject *result = NULL, *factor = NULL, *temp;
PyObject *result = NULL, *temp;
int overflow, cmp;
long long i, factors;
long long ki, ni;
n = PyNumber_Index(n);
if (n == NULL) {
@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
@ -3382,73 +3497,59 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}
/* k = min(k, n - k) */
temp = PyNumber_Subtract(n, k);
if (temp == NULL) {
goto error;
}
if (Py_SIZE(temp) < 0) {
Py_DECREF(temp);
result = PyLong_FromLong(0);
goto done;
}
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
if (cmp > 0) {
Py_SETREF(k, temp);
ni = PyLong_AsLongLongAndOverflow(n, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (!overflow) {
assert(ni >= 0);
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (overflow || ki > ni) {
result = PyLong_FromLong(0);
goto done;
}
assert(ki >= 0);
ki = Py_MIN(ki, ni - ki);
if (ki > 1) {
result = perm_comb_small((unsigned long long)ni,
(unsigned long long)ki, 1);
goto done;
}
/* For k == 1 just return the original n in perm_comb(). */
}
else {
Py_DECREF(temp);
if (cmp < 0) {
goto error;
}
}
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
"min(n - k, k) must not exceed %lld",
LLONG_MAX);
goto error;
}
if (factors == -1) {
/* k is nonnegative, so a return value of -1 can only indicate error */
goto error;
}
if (factors == 0) {
result = PyLong_FromLong(1);
goto done;
}
result = n;
Py_INCREF(result);
if (factors == 1) {
goto done;
}
factor = Py_NewRef(n);
PyObject *one = _PyLong_GetOne(); // borrowed ref
for (i = 1; i < factors; ++i) {
Py_SETREF(factor, PyNumber_Subtract(factor, one));
if (factor == NULL) {
goto error;
}
Py_SETREF(result, PyNumber_Multiply(result, factor));
if (result == NULL) {
goto error;
}
temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
/* k = min(k, n - k) */
temp = PyNumber_Subtract(n, k);
if (temp == NULL) {
goto error;
}
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
Py_DECREF(temp);
if (result == NULL) {
if (Py_SIZE(temp) < 0) {
Py_DECREF(temp);
result = PyLong_FromLong(0);
goto done;
}
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
if (cmp > 0) {
Py_SETREF(k, temp);
}
else {
Py_DECREF(temp);
if (cmp < 0) {
goto error;
}
}
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (overflow) {
PyErr_Format(PyExc_OverflowError,
"min(n - k, k) must not exceed %lld",
LLONG_MAX);
goto error;
}
assert(ki >= 0);
}
Py_DECREF(factor);
result = perm_comb(n, (unsigned long long)ki, 1);
done:
Py_DECREF(n);
@ -3456,8 +3557,6 @@ done:
return result;
error:
Py_XDECREF(factor);
Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;