From 39e8192e4d58f5891d8177e9bc3b1d4020d7e643 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Tue, 18 Feb 2025 09:46:48 +0200 Subject: [PATCH 1/2] gh-130230: Fix crash in pow() with only Decimal third argument (GH-130237) (cherry picked from commit b93b7e566e5a4efe7f077af2083140e50bd2b08f) Co-authored-by: Serhiy Storchaka --- Lib/test/test_decimal.py | 9 +++++++++ ...-02-17-21-16-51.gh-issue-130230.9ta9P9.rst | 1 + Modules/_decimal/_decimal.c | 20 ++++++++++++++++++- 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index d1e7e69e7e951b..c4d05f87a8df22 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -4458,6 +4458,15 @@ def test_implicit_context(self): self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True) # three arg power self.assertEqual(pow(Decimal(10), 2, 7), 2) + if self.decimal == C: + self.assertEqual(pow(10, Decimal(2), 7), 2) + self.assertEqual(pow(10, 2, Decimal(7)), 2) + else: + # XXX: Three-arg power doesn't use __rpow__. + self.assertRaises(TypeError, pow, 10, Decimal(2), 7) + # XXX: There is no special method to dispatch on the + # third arg of three-arg power. + self.assertRaises(TypeError, pow, 10, 2, Decimal(7)) # exp self.assertEqual(Decimal("1.01").exp(), 3) # is_normal diff --git a/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst b/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst new file mode 100644 index 00000000000000..20327fd5f25b43 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst @@ -0,0 +1 @@ +Fix crash in :func:`pow` with only :class:`~decimal.Decimal` third argument. diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 29ae5f402a06f9..a1c5221f78fa10 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -140,6 +140,24 @@ find_state_left_or_right(PyObject *left, PyObject *right) return get_module_state(mod); } +static inline decimal_state * +find_state_ternary(PyObject *left, PyObject *right, PyObject *modulus) +{ + PyTypeObject *base; + if (PyType_GetBaseByToken(Py_TYPE(left), &dec_spec, &base) != 1) { + assert(!PyErr_Occurred()); + if (PyType_GetBaseByToken(Py_TYPE(right), &dec_spec, &base) != 1) { + assert(!PyErr_Occurred()); + PyType_GetBaseByToken(Py_TYPE(modulus), &dec_spec, &base); + } + } + assert(base != NULL); + void *state = _PyType_GetModuleState(base); + assert(state != NULL); + Py_DECREF(base); + return (decimal_state *)state; +} + #if !defined(MPD_VERSION_HEX) || MPD_VERSION_HEX < 0x02050000 #error "libmpdec version >= 2.5.0 required" @@ -4305,7 +4323,7 @@ nm_mpd_qpow(PyObject *base, PyObject *exp, PyObject *mod) PyObject *context; uint32_t status = 0; - decimal_state *state = find_state_left_or_right(base, exp); + decimal_state *state = find_state_ternary(base, exp, mod); CURRENT_CONTEXT(state, context); CONVERT_BINOP(&a, &b, base, exp, context); From 7b9cf13a77c09a81a3386a99064225e59072bd34 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Tue, 18 Feb 2025 11:26:14 +0200 Subject: [PATCH 2/2] Rewrite code for 3.13. --- Include/internal/pycore_typeobject.h | 1 + Modules/_decimal/_decimal.c | 17 ++++------------- Objects/typeobject.c | 20 ++++++++++++++++++++ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/Include/internal/pycore_typeobject.h b/Include/internal/pycore_typeobject.h index a6562f7f9ba74e..164b243dae7806 100644 --- a/Include/internal/pycore_typeobject.h +++ b/Include/internal/pycore_typeobject.h @@ -199,6 +199,7 @@ extern PyObject * _PyType_GetMRO(PyTypeObject *type); extern PyObject* _PyType_GetSubclasses(PyTypeObject *); extern int _PyType_HasSubclasses(PyTypeObject *); PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef2(PyTypeObject *, PyTypeObject *, PyModuleDef *); +PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef3(PyTypeObject *, PyTypeObject *, PyTypeObject *, PyModuleDef *); // PyType_Ready() must be called if _PyType_IsReady() is false. // See also the Py_TPFLAGS_READY flag. diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index a1c5221f78fa10..d1fbfd7d30317b 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -143,19 +143,10 @@ find_state_left_or_right(PyObject *left, PyObject *right) static inline decimal_state * find_state_ternary(PyObject *left, PyObject *right, PyObject *modulus) { - PyTypeObject *base; - if (PyType_GetBaseByToken(Py_TYPE(left), &dec_spec, &base) != 1) { - assert(!PyErr_Occurred()); - if (PyType_GetBaseByToken(Py_TYPE(right), &dec_spec, &base) != 1) { - assert(!PyErr_Occurred()); - PyType_GetBaseByToken(Py_TYPE(modulus), &dec_spec, &base); - } - } - assert(base != NULL); - void *state = _PyType_GetModuleState(base); - assert(state != NULL); - Py_DECREF(base); - return (decimal_state *)state; + PyObject *mod = _PyType_GetModuleByDef3(Py_TYPE(left), Py_TYPE(right), Py_TYPE(modulus), + &_decimal_module); + assert(mod != NULL); + return get_module_state(mod); } diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 57e03c669d9141..c84841070555d1 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -5038,6 +5038,26 @@ _PyType_GetModuleByDef2(PyTypeObject *left, PyTypeObject *right, return module; } +PyObject * +_PyType_GetModuleByDef3(PyTypeObject *left, PyTypeObject *right, PyTypeObject *third, + PyModuleDef *def) +{ + PyObject *module = get_module_by_def(left, def); + if (module == NULL) { + module = get_module_by_def(right, def); + if (module == NULL) { + module = get_module_by_def(third, def); + if (module == NULL) { + PyErr_Format( + PyExc_TypeError, + "PyType_GetModuleByDef: No superclass of '%s', '%s' nor '%s' has " + "the given module", left->tp_name, right->tp_name, third->tp_name); + } + } + } + return module; +} + void * PyObject_GetTypeData(PyObject *obj, PyTypeObject *cls) {