[3.13] gh-130230: Fix crash in pow() with only Decimal third argument (GH-130237) (GH-130246)

(cherry picked from commit b93b7e566e)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
Miss Islington (bot) 2025-02-18 12:18:37 +01:00 committed by GitHub
parent fc1c9f884e
commit 5d83b6c160
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 41 additions and 1 deletions

View file

@ -199,6 +199,7 @@ extern PyObject * _PyType_GetMRO(PyTypeObject *type);
extern PyObject* _PyType_GetSubclasses(PyTypeObject *);
extern int _PyType_HasSubclasses(PyTypeObject *);
PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef2(PyTypeObject *, PyTypeObject *, PyModuleDef *);
PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef3(PyTypeObject *, PyTypeObject *, PyTypeObject *, PyModuleDef *);
// PyType_Ready() must be called if _PyType_IsReady() is false.
// See also the Py_TPFLAGS_READY flag.

View file

@ -4458,6 +4458,15 @@ class Coverage:
self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True)
# three arg power
self.assertEqual(pow(Decimal(10), 2, 7), 2)
if self.decimal == C:
self.assertEqual(pow(10, Decimal(2), 7), 2)
self.assertEqual(pow(10, 2, Decimal(7)), 2)
else:
# XXX: Three-arg power doesn't use __rpow__.
self.assertRaises(TypeError, pow, 10, Decimal(2), 7)
# XXX: There is no special method to dispatch on the
# third arg of three-arg power.
self.assertRaises(TypeError, pow, 10, 2, Decimal(7))
# exp
self.assertEqual(Decimal("1.01").exp(), 3)
# is_normal

View file

@ -0,0 +1 @@
Fix crash in :func:`pow` with only :class:`~decimal.Decimal` third argument.

View file

@ -140,6 +140,15 @@ find_state_left_or_right(PyObject *left, PyObject *right)
return get_module_state(mod);
}
static inline decimal_state *
find_state_ternary(PyObject *left, PyObject *right, PyObject *modulus)
{
PyObject *mod = _PyType_GetModuleByDef3(Py_TYPE(left), Py_TYPE(right), Py_TYPE(modulus),
&_decimal_module);
assert(mod != NULL);
return get_module_state(mod);
}
#if !defined(MPD_VERSION_HEX) || MPD_VERSION_HEX < 0x02050000
#error "libmpdec version >= 2.5.0 required"
@ -4305,7 +4314,7 @@ nm_mpd_qpow(PyObject *base, PyObject *exp, PyObject *mod)
PyObject *context;
uint32_t status = 0;
decimal_state *state = find_state_left_or_right(base, exp);
decimal_state *state = find_state_ternary(base, exp, mod);
CURRENT_CONTEXT(state, context);
CONVERT_BINOP(&a, &b, base, exp, context);

View file

@ -5038,6 +5038,26 @@ _PyType_GetModuleByDef2(PyTypeObject *left, PyTypeObject *right,
return module;
}
PyObject *
_PyType_GetModuleByDef3(PyTypeObject *left, PyTypeObject *right, PyTypeObject *third,
PyModuleDef *def)
{
PyObject *module = get_module_by_def(left, def);
if (module == NULL) {
module = get_module_by_def(right, def);
if (module == NULL) {
module = get_module_by_def(third, def);
if (module == NULL) {
PyErr_Format(
PyExc_TypeError,
"PyType_GetModuleByDef: No superclass of '%s', '%s' nor '%s' has "
"the given module", left->tp_name, right->tp_name, third->tp_name);
}
}
}
return module;
}
void *
PyObject_GetTypeData(PyObject *obj, PyTypeObject *cls)
{