GH-135379: Support limited scalar replacement for replicated uops in the JIT code generator. (GH-135563)

* Use it to support efficient specializations of COPY and SWAP in the JIT.
This commit is contained in:
Mark Shannon 2025-06-17 13:43:09 +01:00 committed by GitHub
parent a9e66a7c50
commit 8dd8b5c2f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 313 additions and 198 deletions

View file

@ -180,7 +180,7 @@ class Uop:
properties: Properties
_size: int = -1
implicitly_created: bool = False
replicated = 0
replicated = range(0)
replicates: "Uop | None" = None
# Size of the instruction(s), only set for uops containing the INSTRUCTION_SIZE macro
instruction_size: int | None = None
@ -868,6 +868,28 @@ def compute_properties(op: parser.CodeDef) -> Properties:
needs_prev=variable_used(op, "prev_instr"),
)
def expand(items: list[StackItem], oparg: int) -> list[StackItem]:
    """Replace the one oparg-sized array item in *items* with scalar items.

    An item counts as oparg-sized when the text "oparg" appears in its
    ``size`` expression.  If exactly one such item exists and its size
    evaluates to an integer ``count`` for the given *oparg*, that item is
    replaced by ``count`` new ``StackItem`` objects named ``<name>_0`` ..
    ``<name>_<count-1>``, each with an empty size and the original item's
    ``peek`` and ``used`` values.

    In every other case (no such item, more than one such item, or a size
    expression that cannot be evaluated) the original *items* list is
    returned unchanged.  The input list is never mutated.
    """
    # Only replace array item with scalar if no more than one item is an array
    array_positions = [
        position for position, item in enumerate(items) if "oparg" in item.size
    ]
    if len(array_positions) != 1:
        return items
    index = array_positions[0]
    array_item = items[index]
    try:
        # Bind oparg as a real variable instead of splicing its digits into
        # the expression text: textual substitution would turn a name such as
        # "oparg1" into a bogus number and can change operator precedence.
        # Builtins are hidden so the expression can only do arithmetic.
        count = int(eval(array_item.size, {"__builtins__": {}}, {"oparg": oparg}))
    except (ValueError, TypeError, NameError, SyntaxError, ArithmeticError):
        # The size is not a plain integer expression in oparg; leave the
        # stack items as they were.
        return items
    scalars = [
        StackItem(f"{array_item.name}_{i}", "", array_item.peek, array_item.used)
        for i in range(count)
    ]
    return items[:index] + scalars + items[index + 1:]
def scalarize_stack(stack: StackEffect, oparg: int) -> StackEffect:
    """Run ``expand`` over both sides of *stack* for the given *oparg*.

    ``stack.inputs`` is rewritten first, then ``stack.outputs``.  The
    *stack* object is updated in place and the same object is returned.
    """
    for side in ("inputs", "outputs"):
        expanded = expand(getattr(stack, side), oparg)
        setattr(stack, side, expanded)
    return stack
def make_uop(
name: str,
@ -887,20 +909,26 @@ def make_uop(
)
for anno in op.annotations:
if anno.startswith("replicate"):
result.replicated = int(anno[10:-1])
text = anno[10:-1]
start, stop = text.split(":")
result.replicated = range(int(start), int(stop))
break
else:
return result
for oparg in range(result.replicated):
for oparg in result.replicated:
name_x = name + "_" + str(oparg)
properties = compute_properties(op)
properties.oparg = False
properties.const_oparg = oparg
stack = analyze_stack(op)
if not variable_used(op, "oparg"):
stack = scalarize_stack(stack, oparg)
else:
properties.const_oparg = oparg
rep = Uop(
name=name_x,
context=op.context,
annotations=op.annotations,
stack=analyze_stack(op),
stack=stack,
caches=analyze_caches(inputs),
local_stores=find_variable_stores(op),
body=op.block,