GH-128682: Account for escapes in DECREF_INPUTS (GH-129953)

* Handle escapes in DECREF_INPUTS

* Mark a few more functions as escaping

* Replace DECREF_INPUTS with PyStackRef_CLOSE where possible
This commit is contained in:
Mark Shannon 2025-02-12 17:44:59 +00:00 committed by GitHub
parent 3e222e3a15
commit 72f56654d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 2228 additions and 921 deletions

View file

@ -224,12 +224,13 @@ def array_or_scalar(var: StackItem | Local) -> str:
return "array" if var.is_array() else "scalar"
class Stack:
def __init__(self, extract_bits: bool=True) -> None:
def __init__(self, extract_bits: bool=True, cast_type: str = "uintptr_t") -> None:
self.top_offset = StackOffset.empty()
self.base_offset = StackOffset.empty()
self.variables: list[Local] = []
self.defined: set[str] = set()
self.extract_bits = extract_bits
self.cast_type = cast_type
def pop(self, var: StackItem) -> tuple[str, Local]:
self.top_offset.pop(var)
@ -298,8 +299,8 @@ class Stack:
out: CWriter,
var: StackItem,
base_offset: StackOffset,
cast_type: str = "uintptr_t",
extract_bits: bool = True,
cast_type: str,
extract_bits: bool,
) -> None:
cast = f"({cast_type})" if var.type else ""
bits = ".bits" if cast and extract_bits else ""
@ -315,9 +316,7 @@ class Stack:
out.emit(f"stack_pointer += {number};\n")
out.emit("assert(WITHIN_STACK_BOUNDS());\n")
def flush(
self, out: CWriter, cast_type: str = "uintptr_t"
) -> None:
def flush(self, out: CWriter) -> None:
out.start_line()
var_offset = self.base_offset.copy()
for var in self.variables:
@ -325,7 +324,7 @@ class Stack:
var.defined and
not var.in_memory
):
Stack._do_emit(out, var.item, var_offset, cast_type, self.extract_bits)
Stack._do_emit(out, var.item, var_offset, self.cast_type, self.extract_bits)
var.in_memory = True
var_offset.push(var.item)
number = self.top_offset.to_c()
@ -347,7 +346,7 @@ class Stack:
)
def copy(self) -> "Stack":
other = Stack(self.extract_bits)
other = Stack(self.extract_bits, self.cast_type)
other.top_offset = self.top_offset.copy()
other.base_offset = self.base_offset.copy()
other.variables = [var.copy() for var in self.variables]
@ -508,17 +507,26 @@ class Storage:
return True
return False
def flush(self, out: CWriter, cast_type: str = "uintptr_t") -> None:
def flush(self, out: CWriter) -> None:
self.clear_dead_inputs()
self._push_defined_outputs()
self.stack.flush(out, cast_type)
self.stack.flush(out)
def save(self, out: CWriter) -> None:
assert self.spilled >= 0
if self.spilled == 0:
self.flush(out)
out.start_line()
out.emit("_PyFrame_SetStackPointer(frame, stack_pointer);\n")
out.emit_spill()
self.spilled += 1
def save_inputs(self, out: CWriter) -> None:
assert self.spilled >= 0
if self.spilled == 0:
self.clear_dead_inputs()
self.stack.flush(out)
out.start_line()
out.emit_spill()
self.spilled += 1
def reload(self, out: CWriter) -> None:
@ -528,7 +536,7 @@ class Storage:
self.spilled -= 1
if self.spilled == 0:
out.start_line()
out.emit("stack_pointer = _PyFrame_GetStackPointer(frame);\n")
out.emit_reload()
@staticmethod
def for_uop(stack: Stack, uop: Uop) -> tuple[list[str], "Storage"]:
@ -637,3 +645,91 @@ class Storage:
outputs = ", ".join([var.compact_str() for var in self.outputs])
peeks = ", ".join([var.name for var in self.peeks])
return f"{stack_comment[:-2]}{next_line}inputs: {inputs}{next_line}outputs: {outputs}{next_line}peeks: {peeks} */"
def close_inputs(self, out: CWriter) -> None:
tmp_defined = False
def close_named(close: str, name: str, overwrite: str) -> None:
nonlocal tmp_defined
if overwrite:
if not tmp_defined:
out.emit("_PyStackRef ")
tmp_defined = True
out.emit(f"tmp = {name};\n")
out.emit(f"{name} = {overwrite};\n")
if not var.is_array():
var.in_memory = False
self.flush(out)
out.emit(f"{close}(tmp);\n")
else:
out.emit(f"{close}({name});\n")
def close_variable(var: Local, overwrite: str) -> None:
nonlocal tmp_defined
close = "PyStackRef_CLOSE"
if "null" in var.name or var.condition and var.condition != "1":
close = "PyStackRef_XCLOSE"
if var.size:
if var.size == "1":
close_named(close, f"{var.name}[0]", overwrite)
else:
if overwrite and not tmp_defined:
out.emit("_PyStackRef tmp;\n")
tmp_defined = True
out.emit(f"for (int _i = {var.size}; --_i >= 0;) {{\n")
close_named(close, f"{var.name}[_i]", overwrite)
out.emit("}\n")
else:
if var.condition and var.condition == "0":
return
close_named(close, var.name, overwrite)
self.clear_dead_inputs()
if not self.inputs:
return
output: Local | None = None
for var in self.outputs:
if var.is_array():
if len(self.inputs) > 1:
raise StackError("Cannot call DECREF_INPUTS with multiple live input(s) and array output")
elif var.defined:
if output is not None:
raise StackError("Cannot call DECREF_INPUTS with more than one live output")
output = var
self.save_inputs(out)
if output is not None:
lowest = self.inputs[0]
if lowest.is_array():
try:
size = int(lowest.size)
except:
size = -1
if size <= 0:
raise StackError("Cannot call DECREF_INPUTS with non fixed size array as lowest input on stack")
if size > 1:
raise StackError("Cannot call DECREF_INPUTS with array size > 1 as lowest input on stack")
output.defined = False
close_variable(lowest, output.name)
else:
lowest.in_memory = False
output.defined = False
close_variable(lowest, output.name)
to_close = self.inputs[: 0 if output is not None else None: -1]
if len(to_close) == 1 and not to_close[0].is_array():
self.reload(out)
to_close[0].defined = False
self.flush(out)
self.save_inputs(out)
close_variable(to_close[0], "")
self.reload(out)
else:
for var in to_close:
assert var.defined or var.is_array()
close_variable(var, "PyStackRef_NULL")
self.reload(out)
for var in self.inputs:
var.defined = False
if output is not None:
output.defined = True
# MyPy false positive
lowest.defined = False # type: ignore[possibly-undefined]
self.flush(out)