GH-130415: Use boolean guards to narrow types to values in the JIT (GH-130659)

This commit is contained in:
Brandt Bucher 2025-03-02 13:21:34 -08:00 committed by GitHub
parent c6513f7a62
commit 7afa476874
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 348 additions and 166 deletions

View file

@ -140,8 +140,11 @@
}
case _UNARY_NOT: {
JitOptSymbol *value;
JitOptSymbol *res;
res = sym_new_not_null(ctx);
value = stack_pointer[-1];
sym_set_type(value, &PyBool_Type);
res = sym_new_truthiness(ctx, value, false);
stack_pointer[-1] = res;
break;
}
@ -151,7 +154,7 @@
JitOptSymbol *res;
value = stack_pointer[-1];
if (!optimize_to_bool(this_instr, ctx, value, &res)) {
res = sym_new_type(ctx, &PyBool_Type);
res = sym_new_truthiness(ctx, value, true);
}
stack_pointer[-1] = res;
break;
@ -163,7 +166,7 @@
value = stack_pointer[-1];
if (!optimize_to_bool(this_instr, ctx, value, &res)) {
sym_set_type(value, &PyBool_Type);
res = value;
res = sym_new_truthiness(ctx, value, true);
}
stack_pointer[-1] = res;
break;
@ -268,15 +271,15 @@
JitOptSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyLong_Type) && sym_matches_type(right, &PyLong_Type))
{
assert(PyLong_CheckExact(sym_get_const(left)));
assert(PyLong_CheckExact(sym_get_const(right)));
assert(PyLong_CheckExact(sym_get_const(ctx, left)));
assert(PyLong_CheckExact(sym_get_const(ctx, right)));
stack_pointer += -2;
assert(WITHIN_STACK_BOUNDS());
PyObject *temp = _PyLong_Multiply((PyLongObject *)sym_get_const(left),
(PyLongObject *)sym_get_const(right));
PyObject *temp = _PyLong_Multiply((PyLongObject *)sym_get_const(ctx, left),
(PyLongObject *)sym_get_const(ctx, right));
if (temp == NULL) {
goto error;
}
@ -303,15 +306,15 @@
JitOptSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyLong_Type) && sym_matches_type(right, &PyLong_Type))
{
assert(PyLong_CheckExact(sym_get_const(left)));
assert(PyLong_CheckExact(sym_get_const(right)));
assert(PyLong_CheckExact(sym_get_const(ctx, left)));
assert(PyLong_CheckExact(sym_get_const(ctx, right)));
stack_pointer += -2;
assert(WITHIN_STACK_BOUNDS());
PyObject *temp = _PyLong_Add((PyLongObject *)sym_get_const(left),
(PyLongObject *)sym_get_const(right));
PyObject *temp = _PyLong_Add((PyLongObject *)sym_get_const(ctx, left),
(PyLongObject *)sym_get_const(ctx, right));
if (temp == NULL) {
goto error;
}
@ -338,15 +341,15 @@
JitOptSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyLong_Type) && sym_matches_type(right, &PyLong_Type))
{
assert(PyLong_CheckExact(sym_get_const(left)));
assert(PyLong_CheckExact(sym_get_const(right)));
assert(PyLong_CheckExact(sym_get_const(ctx, left)));
assert(PyLong_CheckExact(sym_get_const(ctx, right)));
stack_pointer += -2;
assert(WITHIN_STACK_BOUNDS());
PyObject *temp = _PyLong_Subtract((PyLongObject *)sym_get_const(left),
(PyLongObject *)sym_get_const(right));
PyObject *temp = _PyLong_Subtract((PyLongObject *)sym_get_const(ctx, left),
(PyLongObject *)sym_get_const(ctx, right));
if (temp == NULL) {
goto error;
}
@ -404,14 +407,14 @@
JitOptSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyFloat_Type) && sym_matches_type(right, &PyFloat_Type))
{
assert(PyFloat_CheckExact(sym_get_const(left)));
assert(PyFloat_CheckExact(sym_get_const(right)));
assert(PyFloat_CheckExact(sym_get_const(ctx, left)));
assert(PyFloat_CheckExact(sym_get_const(ctx, right)));
PyObject *temp = PyFloat_FromDouble(
PyFloat_AS_DOUBLE(sym_get_const(left)) *
PyFloat_AS_DOUBLE(sym_get_const(right)));
PyFloat_AS_DOUBLE(sym_get_const(ctx, left)) *
PyFloat_AS_DOUBLE(sym_get_const(ctx, right)));
if (temp == NULL) {
goto error;
}
@ -438,14 +441,14 @@
JitOptSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyFloat_Type) && sym_matches_type(right, &PyFloat_Type))
{
assert(PyFloat_CheckExact(sym_get_const(left)));
assert(PyFloat_CheckExact(sym_get_const(right)));
assert(PyFloat_CheckExact(sym_get_const(ctx, left)));
assert(PyFloat_CheckExact(sym_get_const(ctx, right)));
PyObject *temp = PyFloat_FromDouble(
PyFloat_AS_DOUBLE(sym_get_const(left)) +
PyFloat_AS_DOUBLE(sym_get_const(right)));
PyFloat_AS_DOUBLE(sym_get_const(ctx, left)) +
PyFloat_AS_DOUBLE(sym_get_const(ctx, right)));
if (temp == NULL) {
goto error;
}
@ -472,14 +475,14 @@
JitOptSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyFloat_Type) && sym_matches_type(right, &PyFloat_Type))
{
assert(PyFloat_CheckExact(sym_get_const(left)));
assert(PyFloat_CheckExact(sym_get_const(right)));
assert(PyFloat_CheckExact(sym_get_const(ctx, left)));
assert(PyFloat_CheckExact(sym_get_const(ctx, right)));
PyObject *temp = PyFloat_FromDouble(
PyFloat_AS_DOUBLE(sym_get_const(left)) -
PyFloat_AS_DOUBLE(sym_get_const(right)));
PyFloat_AS_DOUBLE(sym_get_const(ctx, left)) -
PyFloat_AS_DOUBLE(sym_get_const(ctx, right)));
if (temp == NULL) {
goto error;
}
@ -520,9 +523,9 @@
JitOptSymbol *res;
right = stack_pointer[-1];
left = stack_pointer[-2];
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyUnicode_Type) && sym_matches_type(right, &PyUnicode_Type)) {
PyObject *temp = PyUnicode_Concat(sym_get_const(left), sym_get_const(right));
PyObject *temp = PyUnicode_Concat(sym_get_const(ctx, left), sym_get_const(ctx, right));
if (temp == NULL) {
goto error;
}
@ -547,9 +550,9 @@
right = stack_pointer[-1];
left = stack_pointer[-2];
JitOptSymbol *res;
if (sym_is_const(left) && sym_is_const(right) &&
if (sym_is_const(ctx, left) && sym_is_const(ctx, right) &&
sym_matches_type(left, &PyUnicode_Type) && sym_matches_type(right, &PyUnicode_Type)) {
PyObject *temp = PyUnicode_Concat(sym_get_const(left), sym_get_const(right));
PyObject *temp = PyUnicode_Concat(sym_get_const(ctx, left), sym_get_const(ctx, right));
if (temp == NULL) {
goto error;
}
@ -1159,8 +1162,8 @@
(void)dict_version;
(void)index;
attr = NULL;
if (sym_is_const(owner)) {
PyModuleObject *mod = (PyModuleObject *)sym_get_const(owner);
if (sym_is_const(ctx, owner)) {
PyModuleObject *mod = (PyModuleObject *)sym_get_const(ctx, owner);
if (PyModule_CheckExact(mod)) {
PyObject *dict = mod->md_dict;
stack_pointer[-1] = attr;
@ -1655,10 +1658,10 @@
JitOptSymbol *callable;
callable = stack_pointer[-2 - oparg];
uint32_t func_version = (uint32_t)this_instr->operand0;
if (sym_is_const(callable) && sym_matches_type(callable, &PyFunction_Type)) {
assert(PyFunction_Check(sym_get_const(callable)));
if (sym_is_const(ctx, callable) && sym_matches_type(callable, &PyFunction_Type)) {
assert(PyFunction_Check(sym_get_const(ctx, callable)));
REPLACE_OP(this_instr, _CHECK_FUNCTION_VERSION_INLINE, 0, func_version);
this_instr->operand1 = (uintptr_t)sym_get_const(callable);
this_instr->operand1 = (uintptr_t)sym_get_const(ctx, callable);
}
sym_set_type(callable, &PyFunction_Type);
break;
@ -1724,9 +1727,9 @@
self_or_null = stack_pointer[-1 - oparg];
callable = stack_pointer[-2 - oparg];
assert(sym_matches_type(callable, &PyFunction_Type));
if (sym_is_const(callable)) {
if (sym_is_const(ctx, callable)) {
if (sym_is_null(self_or_null) || sym_is_not_null(self_or_null)) {
PyFunctionObject *func = (PyFunctionObject *)sym_get_const(callable);
PyFunctionObject *func = (PyFunctionObject *)sym_get_const(ctx, callable);
PyCodeObject *co = (PyCodeObject *)func->func_code;
if (co->co_argcount == oparg + !sym_is_null(self_or_null)) {
REPLACE_OP(this_instr, _NOP, 0 ,0);
@ -2160,12 +2163,12 @@
res = sym_new_type(ctx, &PyFloat_Type);
}
else {
if (!sym_is_const(right)) {
if (!sym_is_const(ctx, right)) {
// Case A or B... can't know without the sign of the RHS:
res = sym_new_unknown(ctx);
}
else {
if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) {
if (_PyLong_IsNegative((PyLongObject *)sym_get_const(ctx, right))) {
// Case B:
res = sym_new_type(ctx, &PyFloat_Type);
}
@ -2230,8 +2233,8 @@
case _GUARD_IS_TRUE_POP: {
JitOptSymbol *flag;
flag = stack_pointer[-1];
if (sym_is_const(flag)) {
PyObject *value = sym_get_const(flag);
if (sym_is_const(ctx, flag)) {
PyObject *value = sym_get_const(ctx, flag);
assert(value != NULL);
stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS());
@ -2239,6 +2242,7 @@
stack_pointer += 1;
assert(WITHIN_STACK_BOUNDS());
}
sym_set_const(flag, Py_True);
stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS());
break;
@ -2247,8 +2251,8 @@
case _GUARD_IS_FALSE_POP: {
JitOptSymbol *flag;
flag = stack_pointer[-1];
if (sym_is_const(flag)) {
PyObject *value = sym_get_const(flag);
if (sym_is_const(ctx, flag)) {
PyObject *value = sym_get_const(ctx, flag);
assert(value != NULL);
stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS());
@ -2256,6 +2260,7 @@
stack_pointer += 1;
assert(WITHIN_STACK_BOUNDS());
}
sym_set_const(flag, Py_False);
stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS());
break;
@ -2264,8 +2269,8 @@
case _GUARD_IS_NONE_POP: {
JitOptSymbol *flag;
flag = stack_pointer[-1];
if (sym_is_const(flag)) {
PyObject *value = sym_get_const(flag);
if (sym_is_const(ctx, flag)) {
PyObject *value = sym_get_const(ctx, flag);
assert(value != NULL);
stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS());
@ -2283,14 +2288,15 @@
stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS());
}
sym_set_const(flag, Py_None);
break;
}
case _GUARD_IS_NOT_NONE_POP: {
JitOptSymbol *flag;
flag = stack_pointer[-1];
if (sym_is_const(flag)) {
PyObject *value = sym_get_const(flag);
if (sym_is_const(ctx, flag)) {
PyObject *value = sym_get_const(ctx, flag);
assert(value != NULL);
stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS());