bpo-35431: Refactor math.comb() implementation. (GH-13725)

* Fixed some bugs.
* Added support for index-likes objects.
* Improved error messages.
* Cleaned up and optimized the code.
* Added more tests.
This commit is contained in:
Serhiy Storchaka 2019-06-01 22:09:02 +03:00 committed by GitHub
parent 9843bc110d
commit 2b843ac0ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 102 deletions

View file

@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
/*[clinic input]
math.comb
n: object(subclass_of='&PyLong_Type')
k: object(subclass_of='&PyLong_Type')
n: object
k: object
/
Number of ways to choose *k* items from *n* items without repetition and without order.
Number of ways to choose k items from n items without repetition and without order.
Also called the binomial coefficient. It is mathematically equal to the expression
n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in
@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n.
static PyObject *
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/
/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/
{
PyObject *val = NULL,
*temp_obj1 = NULL,
*temp_obj2 = NULL,
*dump_var = NULL;
PyObject *result = NULL, *factor = NULL, *temp;
int overflow, cmp;
long long i, terms;
long long i, factors;
cmp = PyObject_RichCompareBool(n, k, Py_LT);
if (cmp < 0) {
goto fail_comb;
n = PyNumber_Index(n);
if (n == NULL) {
return NULL;
}
else if (cmp > 0) {
PyErr_Format(PyExc_ValueError,
"n must be an integer greater than or equal to k");
goto fail_comb;
k = PyNumber_Index(k);
if (k == NULL) {
Py_DECREF(n);
return NULL;
}
/* b = min(b, a - b) */
dump_var = PyNumber_Subtract(n, k);
if (dump_var == NULL) {
goto fail_comb;
if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
"n must be a non-negative integer");
goto error;
}
cmp = PyObject_RichCompareBool(k, dump_var, Py_GT);
if (cmp < 0) {
goto fail_comb;
/* k = min(k, n - k) */
temp = PyNumber_Subtract(n, k);
if (temp == NULL) {
goto error;
}
else if (cmp > 0) {
k = dump_var;
dump_var = NULL;
if (Py_SIZE(temp) < 0) {
Py_DECREF(temp);
PyErr_SetString(PyExc_ValueError,
"k must be an integer less than or equal to n");
goto error;
}
cmp = PyObject_RichCompareBool(k, temp, Py_GT);
if (cmp > 0) {
Py_SETREF(k, temp);
}
else {
Py_DECREF(dump_var);
dump_var = NULL;
Py_DECREF(temp);
if (cmp < 0) {
goto error;
}
}
terms = PyLong_AsLongLongAndOverflow(k, &overflow);
if (terms < 0 && PyErr_Occurred()) {
goto fail_comb;
}
else if (overflow > 0) {
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
"minimum(n - k, k) must not exceed %lld",
"min(n - k, k) must not exceed %lld",
LLONG_MAX);
goto fail_comb;
goto error;
}
else if (overflow < 0 || terms < 0) {
PyErr_Format(PyExc_ValueError,
"k must be a positive integer");
goto fail_comb;
else if (overflow < 0 || factors < 0) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_ValueError,
"k must be a non-negative integer");
}
goto error;
}
if (terms == 0) {
return PyNumber_Long(_PyLong_One);
if (factors == 0) {
result = PyLong_FromLong(1);
goto done;
}
val = PyNumber_Long(n);
for (i = 1; i < terms; ++i) {
temp_obj1 = PyLong_FromSsize_t(i);
if (temp_obj1 == NULL) {
goto fail_comb;
}
temp_obj2 = PyNumber_Subtract(n, temp_obj1);
if (temp_obj2 == NULL) {
goto fail_comb;
}
dump_var = val;
val = PyNumber_Multiply(val, temp_obj2);
if (val == NULL) {
goto fail_comb;
}
Py_DECREF(dump_var);
dump_var = NULL;
Py_DECREF(temp_obj2);
temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1));
if (temp_obj2 == NULL) {
goto fail_comb;
}
dump_var = val;
val = PyNumber_FloorDivide(val, temp_obj2);
if (val == NULL) {
goto fail_comb;
}
Py_DECREF(dump_var);
Py_DECREF(temp_obj1);
Py_DECREF(temp_obj2);
result = n;
Py_INCREF(result);
if (factors == 1) {
goto done;
}
return val;
factor = n;
Py_INCREF(factor);
for (i = 1; i < factors; ++i) {
Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
if (factor == NULL) {
goto error;
}
Py_SETREF(result, PyNumber_Multiply(result, factor));
if (result == NULL) {
goto error;
}
fail_comb:
Py_XDECREF(val);
Py_XDECREF(dump_var);
Py_XDECREF(temp_obj1);
Py_XDECREF(temp_obj2);
temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
if (temp == NULL) {
goto error;
}
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
Py_DECREF(temp);
if (result == NULL) {
goto error;
}
}
Py_DECREF(factor);
done:
Py_DECREF(n);
Py_DECREF(k);
return result;
error:
Py_XDECREF(factor);
Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;
}