gh-93143: Avoid NULL check in LOAD_FAST based on analysis in the compiler (GH-93144)

This commit is contained in:
Dennis Sweeney 2022-05-31 16:32:30 -04:00 committed by GitHub
parent 8a5e3c2ec6
commit f425f3bb27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 371 additions and 52 deletions

View file

@ -1,5 +1,6 @@
import dis
from itertools import combinations, product
import sys
import textwrap
import unittest
@ -682,5 +683,184 @@ class TestBuglets(unittest.TestCase):
compile("while True or not spam: pass", "<test>", "exec")
class TestMarkingVariablesAsUnKnown(BytecodeTestCase):
def setUp(self):
self.addCleanup(sys.settrace, sys.gettrace())
sys.settrace(None)
def test_load_fast_known_simple(self):
def f():
x = 1
y = x + x
self.assertInBytecode(f, 'LOAD_FAST')
def test_load_fast_unknown_simple(self):
def f():
if condition():
x = 1
print(x)
self.assertInBytecode(f, 'LOAD_FAST_CHECK')
self.assertNotInBytecode(f, 'LOAD_FAST')
def test_load_fast_unknown_because_del(self):
def f():
x = 1
del x
print(x)
self.assertInBytecode(f, 'LOAD_FAST_CHECK')
self.assertNotInBytecode(f, 'LOAD_FAST')
def test_load_fast_known_because_parameter(self):
def f1(x):
print(x)
self.assertInBytecode(f1, 'LOAD_FAST')
self.assertNotInBytecode(f1, 'LOAD_FAST_CHECK')
def f2(*, x):
print(x)
self.assertInBytecode(f2, 'LOAD_FAST')
self.assertNotInBytecode(f2, 'LOAD_FAST_CHECK')
def f3(*args):
print(args)
self.assertInBytecode(f3, 'LOAD_FAST')
self.assertNotInBytecode(f3, 'LOAD_FAST_CHECK')
def f4(**kwargs):
print(kwargs)
self.assertInBytecode(f4, 'LOAD_FAST')
self.assertNotInBytecode(f4, 'LOAD_FAST_CHECK')
def f5(x=0):
print(x)
self.assertInBytecode(f5, 'LOAD_FAST')
self.assertNotInBytecode(f5, 'LOAD_FAST_CHECK')
def test_load_fast_known_because_already_loaded(self):
def f():
if condition():
x = 1
print(x)
print(x)
self.assertInBytecode(f, 'LOAD_FAST_CHECK')
self.assertInBytecode(f, 'LOAD_FAST')
def test_load_fast_known_multiple_branches(self):
def f():
if condition():
x = 1
else:
x = 2
print(x)
self.assertInBytecode(f, 'LOAD_FAST')
self.assertNotInBytecode(f, 'LOAD_FAST_CHECK')
def test_load_fast_unknown_after_error(self):
def f():
try:
res = 1 / 0
except ZeroDivisionError:
pass
return res
# LOAD_FAST (known) still occurs in the no-exception branch.
# Assert that it doesn't occur in the LOAD_FAST_CHECK branch.
self.assertInBytecode(f, 'LOAD_FAST_CHECK')
def test_load_fast_unknown_after_error_2(self):
def f():
try:
1 / 0
except:
print(a, b, c, d, e, f, g)
a = b = c = d = e = f = g = 1
self.assertInBytecode(f, 'LOAD_FAST_CHECK')
self.assertNotInBytecode(f, 'LOAD_FAST')
def test_setting_lineno_adds_check(self):
code = textwrap.dedent("""\
def f():
x = 2
L = 3
L = 4
for i in range(55):
x + 6
del x
L = 8
L = 9
L = 10
""")
ns = {}
exec(code, ns)
f = ns['f']
self.assertInBytecode(f, "LOAD_FAST")
def trace(frame, event, arg):
if event == 'line' and frame.f_lineno == 9:
frame.f_lineno = 2
sys.settrace(None)
return None
return trace
sys.settrace(trace)
f()
self.assertNotInBytecode(f, "LOAD_FAST")
def make_function_with_no_checks(self):
code = textwrap.dedent("""\
def f():
x = 2
L = 3
L = 4
L = 5
if not L:
x + 7
y = 2
""")
ns = {}
exec(code, ns)
f = ns['f']
self.assertInBytecode(f, "LOAD_FAST")
self.assertNotInBytecode(f, "LOAD_FAST_CHECK")
return f
def test_deleting_local_adds_check(self):
f = self.make_function_with_no_checks()
def trace(frame, event, arg):
if event == 'line' and frame.f_lineno == 4:
del frame.f_locals["x"]
sys.settrace(None)
return None
return trace
sys.settrace(trace)
f()
self.assertNotInBytecode(f, "LOAD_FAST")
self.assertInBytecode(f, "LOAD_FAST_CHECK")
def test_modifying_local_does_not_add_check(self):
f = self.make_function_with_no_checks()
def trace(frame, event, arg):
if event == 'line' and frame.f_lineno == 4:
frame.f_locals["x"] = 42
sys.settrace(None)
return None
return trace
sys.settrace(trace)
f()
self.assertInBytecode(f, "LOAD_FAST")
self.assertNotInBytecode(f, "LOAD_FAST_CHECK")
def test_initializing_local_does_not_add_check(self):
f = self.make_function_with_no_checks()
def trace(frame, event, arg):
if event == 'line' and frame.f_lineno == 4:
frame.f_locals["y"] = 42
sys.settrace(None)
return None
return trace
sys.settrace(trace)
f()
self.assertInBytecode(f, "LOAD_FAST")
self.assertNotInBytecode(f, "LOAD_FAST_CHECK")
if __name__ == "__main__":
unittest.main()