GH-131798: Narrow the result type of _BINARY_OP_SUBSCR_STR_INT to str in the JIT (GH-132153)

This commit is contained in:
Tomas R. 2025-04-08 17:22:54 +02:00 committed by GitHub
parent 933c6653cb
commit 71009cb835
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 27 additions and 1 deletions

View file

@ -1646,6 +1646,26 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIn("_TO_BOOL_STR", uops) self.assertIn("_TO_BOOL_STR", uops)
self.assertNotIn("_GUARD_TOS_UNICODE", uops) self.assertNotIn("_GUARD_TOS_UNICODE", uops)
def test_binary_subcsr_str_int_narrows_to_str(self):
def testfunc(n):
x = []
s = "foo"
for _ in range(n):
y = s[0] # _BINARY_OP_SUBSCR_STR_INT
z = "bar" + y # (_GUARD_TOS_UNICODE) + _BINARY_OP_ADD_UNICODE
x.append(z)
return x
res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
self.assertEqual(res, ["barf"] * TIER2_THRESHOLD)
self.assertIsNotNone(ex)
uops = get_opnames(ex)
self.assertIn("_BINARY_OP_SUBSCR_STR_INT", uops)
# _BINARY_OP_SUBSCR_STR_INT narrows the result to 'str' so
# the unicode guard before _BINARY_OP_ADD_UNICODE is removed.
self.assertNotIn("_GUARD_TOS_UNICODE", uops)
self.assertIn("_BINARY_OP_ADD_UNICODE", uops)
def global_identity(x): def global_identity(x):
return x return x

View file

@ -0,0 +1,2 @@
Allow the JIT to remove unicode guards after ``_BINARY_OP_SUBSCR_STR_INT``
by setting the return type to string.

View file

@ -366,6 +366,10 @@ dummy_func(void) {
ctx->done = true; ctx->done = true;
} }
op(_BINARY_OP_SUBSCR_STR_INT, (left, right -- res)) {
res = sym_new_type(ctx, &PyUnicode_Type);
}
op(_TO_BOOL, (value -- res)) { op(_TO_BOOL, (value -- res)) {
int already_bool = optimize_to_bool(this_instr, ctx, value, &res); int already_bool = optimize_to_bool(this_instr, ctx, value, &res);
if (!already_bool) { if (!already_bool) {

View file

@ -569,7 +569,7 @@
case _BINARY_OP_SUBSCR_STR_INT: { case _BINARY_OP_SUBSCR_STR_INT: {
JitOptSymbol *res; JitOptSymbol *res;
res = sym_new_not_null(ctx); res = sym_new_type(ctx, &PyUnicode_Type);
stack_pointer[-2] = res; stack_pointer[-2] = res;
stack_pointer += -1; stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS()); assert(WITHIN_STACK_BOUNDS());