GH-119726: Deduplicate AArch64 trampolines within a trace (GH-123872)

This commit is contained in:
Diego Russo 2024-10-02 20:07:20 +01:00 committed by GitHub
parent 7a178b7605
commit b85923a0fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 146 additions and 58 deletions

View file

@ -2,7 +2,6 @@
import dataclasses
import enum
import sys
import typing
import _schema
@ -103,8 +102,8 @@ _HOLE_EXPRS = {
HoleValue.OPERAND_HI: "(instruction->operand >> 32)",
HoleValue.OPERAND_LO: "(instruction->operand & UINT32_MAX)",
HoleValue.TARGET: "instruction->target",
HoleValue.JUMP_TARGET: "instruction_starts[instruction->jump_target]",
HoleValue.ERROR_TARGET: "instruction_starts[instruction->error_target]",
HoleValue.JUMP_TARGET: "state->instruction_starts[instruction->jump_target]",
HoleValue.ERROR_TARGET: "state->instruction_starts[instruction->error_target]",
HoleValue.ZERO: "",
}
@ -125,6 +124,7 @@ class Hole:
symbol: str | None
# ...plus this addend:
addend: int
need_state: bool = False
func: str = dataclasses.field(init=False)
# Convenience method:
replace = dataclasses.replace
@ -157,10 +157,12 @@ class Hole:
if value:
value += " + "
value += f"(uintptr_t)&{self.symbol}"
if _signed(self.addend):
if _signed(self.addend) or not value:
if value:
value += " + "
value += f"{_signed(self.addend):#x}"
if self.need_state:
return f"{self.func}({location}, {value}, state);"
return f"{self.func}({location}, {value});"
@ -175,7 +177,6 @@ class Stencil:
body: bytearray = dataclasses.field(default_factory=bytearray, init=False)
holes: list[Hole] = dataclasses.field(default_factory=list, init=False)
disassembly: list[str] = dataclasses.field(default_factory=list, init=False)
trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
def pad(self, alignment: int) -> None:
"""Pad the stencil to the given alignment."""
@ -184,39 +185,6 @@ class Stencil:
self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
self.body.extend([0] * padding)
def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole:
"""Even with the large code model, AArch64 Linux insists on 28-bit jumps."""
assert hole.symbol is not None
reuse_trampoline = hole.symbol in self.trampolines
if reuse_trampoline:
# Re-use the base address of the previously created trampoline
base = self.trampolines[hole.symbol]
else:
self.pad(alignment)
base = len(self.body)
new_hole = hole.replace(addend=base, symbol=None, value=HoleValue.DATA)
if reuse_trampoline:
return new_hole
self.disassembly += [
f"{base + 4 * 0:x}: 58000048 ldr x8, 8",
f"{base + 4 * 1:x}: d61f0100 br x8",
f"{base + 4 * 2:x}: 00000000",
f"{base + 4 * 2:016x}: R_AARCH64_ABS64 {hole.symbol}",
f"{base + 4 * 3:x}: 00000000",
]
for code in [
0x58000048.to_bytes(4, sys.byteorder),
0xD61F0100.to_bytes(4, sys.byteorder),
0x00000000.to_bytes(4, sys.byteorder),
0x00000000.to_bytes(4, sys.byteorder),
]:
self.body.extend(code)
self.holes.append(hole.replace(offset=base + 8, kind="R_AARCH64_ABS64"))
self.trampolines[hole.symbol] = base
return new_hole
def remove_jump(self, *, alignment: int = 1) -> None:
"""Remove a zero-length continuation jump, if it exists."""
hole = max(self.holes, key=lambda hole: hole.offset)
@ -282,8 +250,14 @@ class StencilGroup:
default_factory=dict, init=False
)
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
def process_relocations(self, *, alignment: int = 1) -> None:
def process_relocations(
self,
known_symbols: dict[str, int],
*,
alignment: int = 1,
) -> None:
"""Fix up all GOT and internal relocations for this stencil group."""
for hole in self.code.holes.copy():
if (
@ -291,9 +265,17 @@ class StencilGroup:
in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"}
and hole.value is HoleValue.ZERO
):
new_hole = self.data.emit_aarch64_trampoline(hole, alignment)
self.code.holes.remove(hole)
self.code.holes.append(new_hole)
hole.func = "patch_aarch64_trampoline"
hole.need_state = True
assert hole.symbol is not None
if hole.symbol in known_symbols:
ordinal = known_symbols[hole.symbol]
else:
ordinal = len(known_symbols)
known_symbols[hole.symbol] = ordinal
self._trampolines.add(ordinal)
hole.addend = ordinal
hole.symbol = None
self.code.remove_jump(alignment=alignment)
self.code.pad(alignment)
self.data.pad(8)
@ -348,9 +330,20 @@ class StencilGroup:
)
self.data.body.extend([0] * 8)
def _get_trampoline_mask(self) -> str:
bitmask: int = 0
trampoline_mask: list[str] = []
for ordinal in self._trampolines:
bitmask |= 1 << ordinal
while bitmask:
word = bitmask & ((1 << 32) - 1)
trampoline_mask.append(f"{word:#04x}")
bitmask >>= 32
return "{" + ", ".join(trampoline_mask) + "}"
def as_c(self, opname: str) -> str:
"""Dump this hole as a StencilGroup initializer."""
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}}}"
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"
def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:

View file

@ -44,6 +44,7 @@ class _Target(typing.Generic[_S, _R]):
stable: bool = False
debug: bool = False
verbose: bool = False
known_symbols: dict[str, int] = dataclasses.field(default_factory=dict)
def _compute_digest(self, out: pathlib.Path) -> str:
hasher = hashlib.sha256()
@ -95,7 +96,9 @@ class _Target(typing.Generic[_S, _R]):
if group.data.body:
line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
group.data.disassembly.append(line)
group.process_relocations(alignment=self.alignment)
group.process_relocations(
known_symbols=self.known_symbols, alignment=self.alignment
)
return group
def _handle_section(self, section: _S, group: _stencils.StencilGroup) -> None:
@ -231,7 +234,7 @@ class _Target(typing.Generic[_S, _R]):
if comment:
file.write(f"// {comment}\n")
file.write("\n")
for line in _writer.dump(stencil_groups):
for line in _writer.dump(stencil_groups, self.known_symbols):
file.write(f"{line}\n")
try:
jit_stencils_new.replace(jit_stencils)

View file

@ -2,17 +2,24 @@
import itertools
import typing
import math
import _stencils
def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
def _dump_footer(
groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int]
) -> typing.Iterator[str]:
symbol_mask_size = max(math.ceil(len(symbols) / 32), 1)
yield f'static_assert(SYMBOL_MASK_WORDS >= {symbol_mask_size}, "SYMBOL_MASK_WORDS too small");'
yield ""
yield "typedef struct {"
yield " void (*emit)("
yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor,"
yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[]);"
yield " const _PyUOpInstruction *instruction, jit_state *state);"
yield " size_t code_size;"
yield " size_t data_size;"
yield " symbol_mask trampoline_mask;"
yield "} StencilGroup;"
yield ""
yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};"
@ -23,13 +30,18 @@ def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[s
continue
yield f" [{opname}] = {group.as_c(opname)},"
yield "};"
yield ""
yield f"static const void * const symbols_map[{max(len(symbols), 1)}] = {{"
for symbol, ordinal in symbols.items():
yield f" [{ordinal}] = &{symbol},"
yield "};"
def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator[str]:
yield "void"
yield f"emit_{opname}("
yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor,"
yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[])"
yield " const _PyUOpInstruction *instruction, jit_state *state)"
yield "{"
for part, stencil in [("code", group.code), ("data", group.data)]:
for line in stencil.disassembly:
@ -58,8 +70,10 @@ def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator
yield ""
def dump(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
def dump(
groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int]
) -> typing.Iterator[str]:
"""Yield a JIT compiler line-by-line as a C header file."""
for opname, group in sorted(groups.items()):
yield from _dump_stencil(opname, group)
yield from _dump_footer(groups)
yield from _dump_footer(groups, symbols)