GH-127809: Fix the JIT's understanding of ** (GH-127844)

This commit is contained in:
Brandt Bucher 2025-01-07 17:25:48 -08:00 committed by GitHub
parent e08b28235a
commit 65ae3d5a73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 200 additions and 27 deletions

View file

@ -530,6 +530,8 @@ dummy_func(
pure op(_BINARY_OP_MULTIPLY_INT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
@ -543,6 +545,8 @@ dummy_func(
pure op(_BINARY_OP_ADD_INT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
@ -556,6 +560,8 @@ dummy_func(
pure op(_BINARY_OP_SUBTRACT_INT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
@ -593,6 +599,8 @@ dummy_func(
pure op(_BINARY_OP_MULTIPLY_FLOAT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
@ -607,6 +615,8 @@ dummy_func(
pure op(_BINARY_OP_ADD_FLOAT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
@ -621,6 +631,8 @@ dummy_func(
pure op(_BINARY_OP_SUBTRACT_FLOAT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
@ -650,6 +662,8 @@ dummy_func(
pure op(_BINARY_OP_ADD_UNICODE, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = PyUnicode_Concat(left_o, right_o);
@ -672,6 +686,8 @@ dummy_func(
op(_BINARY_OP_INPLACE_ADD_UNICODE, (left, right --)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));
int next_oparg;
#if TIER_ONE

View file

@ -638,6 +638,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
@ -658,6 +660,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
@ -678,6 +682,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
@ -738,6 +744,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
((PyFloatObject *)left_o)->ob_fval *
@ -759,6 +767,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
((PyFloatObject *)left_o)->ob_fval +
@ -780,6 +790,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
((PyFloatObject *)left_o)->ob_fval -
@ -819,6 +831,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = PyUnicode_Concat(left_o, right_o);
PyStackRef_CLOSE_SPECIALIZED(left, _PyUnicode_ExactDealloc);
@ -838,6 +852,8 @@
left = stack_pointer[-2];
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));
int next_oparg;
#if TIER_ONE
assert(next_instr->op.code == STORE_FAST);

View file

@ -80,6 +80,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
((PyFloatObject *)left_o)->ob_fval +
@ -116,6 +118,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
@ -151,6 +155,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = PyUnicode_Concat(left_o, right_o);
PyStackRef_CLOSE_SPECIALIZED(left, _PyUnicode_ExactDealloc);
@ -185,6 +191,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));
int next_oparg;
#if TIER_ONE
assert(next_instr->op.code == STORE_FAST);
@ -247,6 +255,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
((PyFloatObject *)left_o)->ob_fval *
@ -283,6 +293,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
@ -318,6 +330,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
double dres =
((PyFloatObject *)left_o)->ob_fval -
@ -354,6 +368,8 @@
{
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));
STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);

View file

@ -167,23 +167,56 @@ dummy_func(void) {
}
op(_BINARY_OP, (left, right -- res)) {
PyTypeObject *ltype = sym_get_type(left);
PyTypeObject *rtype = sym_get_type(right);
if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) &&
rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type))
{
if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
ltype == &PyLong_Type && rtype == &PyLong_Type) {
/* If both inputs are ints and the op is not division the result is an int */
res = sym_new_type(ctx, &PyLong_Type);
bool lhs_int = sym_matches_type(left, &PyLong_Type);
bool rhs_int = sym_matches_type(right, &PyLong_Type);
bool lhs_float = sym_matches_type(left, &PyFloat_Type);
bool rhs_float = sym_matches_type(right, &PyFloat_Type);
if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) {
// There's something other than an int or float involved:
res = sym_new_unknown(ctx);
}
else if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) {
// This one's fun... the *type* of the result depends on the
// *values* being exponentiated. However, exponents with one
// constant part are reasonably common, so it's probably worth
// trying to infer some simple cases:
// - A: 1 ** 1 -> 1 (int ** int -> int)
// - B: 1 ** -1 -> 1.0 (int ** int -> float)
// - C: 1.0 ** 1 -> 1.0 (float ** int -> float)
// - D: 1 ** 1.0 -> 1.0 (int ** float -> float)
// - E: -1 ** 0.5 ~> 1j (int ** float -> complex)
// - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float)
// - G: -1.0 ** 0.5 ~> 1j (float ** float -> complex)
if (rhs_float) {
// Case D, E, F, or G... can't know without the sign of the LHS
// or whether the RHS is whole, which isn't worth the effort:
res = sym_new_unknown(ctx);
}
else {
/* For any other op combining ints/floats the result is a float */
else if (lhs_float) {
// Case C:
res = sym_new_type(ctx, &PyFloat_Type);
}
else if (!sym_is_const(right)) {
// Case A or B... can't know without the sign of the RHS:
res = sym_new_unknown(ctx);
}
else if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) {
// Case B:
res = sym_new_type(ctx, &PyFloat_Type);
}
else {
// Case A:
res = sym_new_type(ctx, &PyLong_Type);
}
}
else if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) {
res = sym_new_type(ctx, &PyFloat_Type);
}
else if (lhs_int && rhs_int) {
res = sym_new_type(ctx, &PyLong_Type);
}
else {
res = sym_new_unknown(ctx);
res = sym_new_type(ctx, &PyFloat_Type);
}
}

View file

@ -2307,23 +2307,68 @@
_Py_UopsSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
PyTypeObject *ltype = sym_get_type(left);
PyTypeObject *rtype = sym_get_type(right);
if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) &&
rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type))
{
if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
ltype == &PyLong_Type && rtype == &PyLong_Type) {
/* If both inputs are ints and the op is not division the result is an int */
res = sym_new_type(ctx, &PyLong_Type);
}
else {
/* For any other op combining ints/floats the result is a float */
res = sym_new_type(ctx, &PyFloat_Type);
}
bool lhs_int = sym_matches_type(left, &PyLong_Type);
bool rhs_int = sym_matches_type(right, &PyLong_Type);
bool lhs_float = sym_matches_type(left, &PyFloat_Type);
bool rhs_float = sym_matches_type(right, &PyFloat_Type);
if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) {
// There's something other than an int or float involved:
res = sym_new_unknown(ctx);
}
else {
res = sym_new_unknown(ctx);
if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) {
// This one's fun... the *type* of the result depends on the
// *values* being exponentiated. However, exponents with one
// constant part are reasonably common, so it's probably worth
// trying to infer some simple cases:
// - A: 1 ** 1 -> 1 (int ** int -> int)
// - B: 1 ** -1 -> 1.0 (int ** int -> float)
// - C: 1.0 ** 1 -> 1.0 (float ** int -> float)
// - D: 1 ** 1.0 -> 1.0 (int ** float -> float)
// - E: -1 ** 0.5 ~> 1j (int ** float -> complex)
// - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float)
// - G: -1.0 ** 0.5 ~> 1j (float ** float -> complex)
if (rhs_float) {
// Case D, E, F, or G... can't know without the sign of the LHS
// or whether the RHS is whole, which isn't worth the effort:
res = sym_new_unknown(ctx);
}
else {
if (lhs_float) {
// Case C:
res = sym_new_type(ctx, &PyFloat_Type);
}
else {
if (!sym_is_const(right)) {
// Case A or B... can't know without the sign of the RHS:
res = sym_new_unknown(ctx);
}
else {
if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) {
// Case B:
res = sym_new_type(ctx, &PyFloat_Type);
}
else {
// Case A:
res = sym_new_type(ctx, &PyLong_Type);
}
}
}
}
}
else {
if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) {
res = sym_new_type(ctx, &PyFloat_Type);
}
else {
if (lhs_int && rhs_int) {
res = sym_new_type(ctx, &PyLong_Type);
}
else {
res = sym_new_type(ctx, &PyFloat_Type);
}
}
}
}
stack_pointer[-2] = res;
stack_pointer += -1;