From 560607812df24abd2e2a48a1911bed2c61834686 Mon Sep 17 00:00:00 2001 From: Brandt Bucher Date: Wed, 11 Jun 2025 18:48:45 -0700 Subject: [PATCH 01/20] Do some textual assembly magic --- Tools/jit/_stencils.py | 79 +++----------- Tools/jit/_targets.py | 237 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 250 insertions(+), 66 deletions(-) diff --git a/Tools/jit/_stencils.py b/Tools/jit/_stencils.py index 03b0ba647b0db7..4cab3fea6ec694 100644 --- a/Tools/jit/_stencils.py +++ b/Tools/jit/_stencils.py @@ -17,8 +17,6 @@ class HoleValue(enum.Enum): # The base address of the machine code for the current uop (exposed as _JIT_ENTRY): CODE = enum.auto() - # The base address of the machine code for the next uop (exposed as _JIT_CONTINUE): - CONTINUE = enum.auto() # The base address of the read-only data for this uop: DATA = enum.auto() # The address of the current executor (exposed as _JIT_EXECUTOR): @@ -97,7 +95,6 @@ class HoleValue(enum.Enum): # Translate HoleValues to C expressions: _HOLE_EXPRS = { HoleValue.CODE: "(uintptr_t)code", - HoleValue.CONTINUE: "(uintptr_t)code + sizeof(code_body)", HoleValue.DATA: "(uintptr_t)data", HoleValue.EXECUTOR: "(uintptr_t)executor", # These should all have been turned into DATA values by process_relocations: @@ -209,63 +206,22 @@ def pad(self, alignment: int) -> None: self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}") self.body.extend([0] * padding) - def add_nops(self, nop: bytes, alignment: int) -> None: - """Add NOPs until there is alignment. Fail if it is not possible.""" - offset = len(self.body) - nop_size = len(nop) - - # Calculate the gap to the next multiple of alignment. - gap = -offset % alignment - if gap: - if gap % nop_size == 0: - count = gap // nop_size - self.body.extend(nop * count) - else: - raise ValueError( - f"Cannot add nops of size '{nop_size}' to a body with " - f"offset '{offset}' to align with '{alignment}'" - ) - - def remove_jump(self) -> None: - """Remove a zero-length continuation jump, if it exists.""" - hole = max(self.holes, key=lambda hole: hole.offset) - match hole: - case Hole( - offset=offset, - kind="IMAGE_REL_AMD64_REL32", - value=HoleValue.GOT, - symbol="_JIT_CONTINUE", - addend=-4, - ) as hole: - # jmp qword ptr [rip] - jump = b"\x48\xff\x25\x00\x00\x00\x00" - offset -= 3 - case Hole( - offset=offset, - kind="IMAGE_REL_I386_REL32" | "R_X86_64_PLT32" | "X86_64_RELOC_BRANCH", - value=HoleValue.CONTINUE, - symbol=None, - addend=addend, - ) as hole if ( - _signed(addend) == -4 - ): - # jmp 5 - jump = b"\xe9\x00\x00\x00\x00" - offset -= 1 - case Hole( - offset=offset, - kind="R_AARCH64_JUMP26", - value=HoleValue.CONTINUE, - symbol=None, - addend=0, - ) as hole: - # b #4 - jump = b"\x00\x00\x00\x14" - case _: - return - if self.body[offset:] == jump: - self.body = self.body[:offset] - self.holes.remove(hole) + # def add_nops(self, nop: bytes, alignment: int) -> None: + # """Add NOPs until there is alignment. Fail if it is not possible.""" + # offset = len(self.body) + # nop_size = len(nop) + + # # Calculate the gap to the next multiple of alignment. + # gap = -offset % alignment + # if gap: + # if gap % nop_size == 0: + # count = gap // nop_size + # self.body.extend(nop * count) + # else: + # raise ValueError( + # f"Cannot add nops of size '{nop_size}' to a body with " + # f"offset '{offset}' to align with '{alignment}'" + # ) @dataclasses.dataclass @@ -306,8 +262,7 @@ def process_relocations( self._trampolines.add(ordinal) hole.addend = ordinal hole.symbol = None - self.code.remove_jump() - self.code.add_nops(nop=nop, alignment=alignment) + # self.code.add_nops(nop=nop, alignment=alignment) self.data.pad(8) for stencil in [self.code, self.data]: for hole in stencil.holes: diff --git a/Tools/jit/_targets.py b/Tools/jit/_targets.py index d0a1c081ffecc2..1fa6579e7a58ac 100644 --- a/Tools/jit/_targets.py +++ b/Tools/jit/_targets.py @@ -3,6 +3,7 @@ import asyncio import dataclasses import hashlib +import itertools import json import os import pathlib @@ -34,6 +35,223 @@ "_R", _schema.COFFRelocation, _schema.ELFRelocation, _schema.MachORelocation ) +inverted_branches = {} +for op, nop in [ + ("ja", "jna"), + ("jae", "jnae"), + ("jb", "jnb"), + ("jbe", "jnbe"), + ("jc", "jnc"), + ("je", "jne"), + ("jg", "jng"), + ("jge", "jnge"), + ("jl", "jnl"), + ("jle", "jnle"), + ("jo", "jno"), + ("jp", "jnp"), + ("js", "jns"), + ("jz", "jnz"), + ("jpe", "jpo"), + ("jcxz", None), + ("jecxz", None), + ("jrxz", None), + ("loop", None), + ("loope", None), + ("loopne", None), + ("loopnz", None), + ("loopz", None), +]: + inverted_branches[op] = nop + if nop is not None: + inverted_branches[nop] = op + + +@dataclasses.dataclass +class _Line: + text: str + hot: bool = dataclasses.field(init=False, default=False) + predecessors: list["_Line"] = dataclasses.field( + init=False, repr=False, default_factory=list + ) + + +@dataclasses.dataclass +class _Label(_Line): + label: str + + +@dataclasses.dataclass +class _Jump(_Line): + target: _Label | None + + def remove(self) -> None: + self.text = "" + if self.target is not None: + self.target.predecessors.remove(self) + if not self.target.predecessors: + self.target.text = "" + self.target = None + + def update(self, target: _Label) -> None: + assert self.target is not None + self.target.predecessors.remove(self) + assert self.target.label in self.text + self.text = self.text.replace(self.target.label, target.label) + self.target = target + self.target.predecessors.append(self) + + +@dataclasses.dataclass +class _Branch(_Line): + op: str + target: _Label + fallthrough: _Line | None = None + + def update(self, target: _Label) -> None: + assert self.target is not None + self.target.predecessors.remove(self) + assert self.target.label in self.text + self.text = self.text.replace(self.target.label, target.label) + self.target = target + self.target.predecessors.append(self) + + def invert(self, jump: _Jump) -> bool: + inverted = inverted_branches[self.op] + if inverted is None or jump.target is None: + return False + assert self.op in self.text + self.text = self.text.replace(self.op, inverted) + old_target = self.target + self.update(jump.target) + jump.update(old_target) + return True + + +@dataclasses.dataclass +class _Return(_Line): + pass + + +@dataclasses.dataclass +class _Noise(_Line): + pass + + +def _branch( + line: str, use_label: typing.Callable[[str], _Label] +) -> tuple[str, _Label] | None: + branch = re.match(rf"\s*({'|'.join(inverted_branches)})\s+([\w\.]+)", line) + return branch and (branch.group(1), use_label(branch.group(2))) + + +def _jump(line: str, use_label: typing.Callable[[str], _Label]) -> _Label | None: + if jump := re.match(r"\s*jmp\s+([\w\.]+)", line): + return use_label(jump.group(1)) + return None + + +def _label(line: str, use_label: typing.Callable[[str], _Label]) -> _Label | None: + label = re.match(r"\s*([\w\.]+):", line) + return label and use_label(label.group(1)) + + +def _return(line: str) -> bool: + return re.match(r"\s*ret\s+", line) is not None + + +def _noise(line: str) -> bool: + return re.match(r"\s*[#\.]|\s*$", line) is not None + + +def _apply_asm_transformations(path: pathlib.Path) -> None: + labels = {} + + def use_label(label: str) -> _Label: + if label not in labels: + labels[label] = _Label("", label) + return labels[label] + + def new_line(text: str) -> _Line: + if branch := _branch(text, use_label): + op, label = branch + line = _Branch(text, op, label) + label.predecessors.append(line) + return line + if label := _jump(text, use_label): + line = _Jump(text, label) + label.predecessors.append(line) + return line + if line := _label(text, use_label): + assert line.text == "" + line.text = text + return line + if _return(text): + return _Return(text) + if _noise(text): + return _Noise(text) + return _Line(text) + + # Build graph: + lines = [] + line = _Noise("") # Dummy. + with path.open() as file: + for i, text in enumerate(file): + new = new_line(text) + if not isinstance(line, (_Jump, _Return)): + new.predecessors.append(line) + lines.append(new) + line = new + for i, line in enumerate(reversed(lines)): + if not isinstance(line, (_Label, _Noise)): + break + new = new_line("_JIT_CONTINUE:\n") + lines.insert(len(lines) - i, new) + line = new + # Mark hot lines: + todo = labels["_JIT_CONTINUE"].predecessors.copy() + while todo: + line = todo.pop() + line.hot = True + for predecessor in line.predecessors: + if not predecessor.hot: + todo.append(predecessor) + for pair in itertools.pairwise( + filter(lambda line: line.text and not isinstance(line, _Noise), lines) + ): + match pair: + case (_Branch(hot=True) as branch, _Jump(hot=False) as jump): + branch.invert(jump) + jump.hot = True + for pair in itertools.pairwise(lines): + match pair: + case (_Jump() | _Return(), _): + pass + case (_Line(hot=True), _Line(hot=False) as cold): + cold.hot = True + # Reorder blocks: + hot = [] + cold = [] + for line in lines: + if line.hot: + hot.append(line) + else: + cold.append(line) + lines = hot + cold + # Remove zero-length jumps: + again = True + while again: + again = False + for pair in itertools.pairwise( + filter(lambda line: line.text and not isinstance(line, _Noise), lines) + ): + match pair: + case (_Jump(target=target) as jump, label) if target is label: + jump.remove() + again = True + # Write new assembly: + with path.open("w") as file: + file.writelines(line.text for line in lines) + @dataclasses.dataclass class _Target(typing.Generic[_S, _R]): @@ -118,8 +336,9 @@ def _handle_relocation( async def _compile( self, opname: str, c: pathlib.Path, tempdir: pathlib.Path ) -> _stencils.StencilGroup: + s = tempdir / f"{opname}.s" o = tempdir / f"{opname}.o" - args = [ + args_s = [ f"--target={self.triple}", "-DPy_BUILD_CORE_MODULE", "-D_DEBUG" if self.debug else "-DNDEBUG", @@ -133,7 +352,8 @@ async def _compile( f"-I{CPYTHON / 'Python'}", f"-I{CPYTHON / 'Tools' / 'jit'}", "-O3", - "-c", + "-S", + # "-c", # Shorten full absolute file paths in the generated code (like the # __FILE__ macro and assert failure messages) for reproducibility: f"-ffile-prefix-map={CPYTHON}=.", @@ -152,11 +372,20 @@ async def _compile( "-fno-stack-protector", "-std=c11", "-o", - f"{o}", + f"{s}", f"{c}", *self.args, ] - await _llvm.run("clang", args, echo=self.verbose) + await _llvm.run("clang", args_s, echo=self.verbose) + _apply_asm_transformations(s) + args_o = [ + f"--target={self.triple}", + "-c", + "-o", + f"{o}", + f"{s}", + ] + await _llvm.run("clang", args_o, echo=self.verbose) return await self._parse(o) async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]: From 858624af55e42b7c7eceabf21877f744b00d0ee8 Mon Sep 17 00:00:00 2001 From: Brandt Bucher Date: Thu, 12 Jun 2025 13:37:31 -0700 Subject: [PATCH 02/20] Switch to a linked list --- Tools/jit/_targets.py | 215 ++++++++++++++++++++++++------------------ 1 file changed, 123 insertions(+), 92 deletions(-) diff --git a/Tools/jit/_targets.py b/Tools/jit/_targets.py index 1fa6579e7a58ac..103dac7a15664d 100644 --- a/Tools/jit/_targets.py +++ b/Tools/jit/_targets.py @@ -35,44 +35,59 @@ "_R", _schema.COFFRelocation, _schema.ELFRelocation, _schema.MachORelocation ) -inverted_branches = {} +branches = {} for op, nop in [ ("ja", "jna"), ("jae", "jnae"), ("jb", "jnb"), ("jbe", "jnbe"), ("jc", "jnc"), + ("jcxz", None), ("je", "jne"), + ("jecxz", None), ("jg", "jng"), ("jge", "jnge"), ("jl", "jnl"), ("jle", "jnle"), ("jo", "jno"), ("jp", "jnp"), - ("js", "jns"), - ("jz", "jnz"), ("jpe", "jpo"), - ("jcxz", None), - ("jecxz", None), ("jrxz", None), + ("js", "jns"), + ("jz", "jnz"), ("loop", None), ("loope", None), ("loopne", None), ("loopnz", None), ("loopz", None), ]: - inverted_branches[op] = nop + branches[op] = nop if nop is not None: - inverted_branches[nop] = op + branches[nop] = op @dataclasses.dataclass class _Line: + fallthrough: typing.ClassVar[bool] = True text: str hot: bool = dataclasses.field(init=False, default=False) predecessors: list["_Line"] = dataclasses.field( init=False, repr=False, default_factory=list ) + link: "_Line | None" = dataclasses.field(init=False, repr=False, default=None) + + def heat(self) -> None: + if self.hot: + return + self.hot = True + for predecessor in self.predecessors: + predecessor.heat() + if self.fallthrough and self.link is not None: + self.link.heat() + + def optimize(self) -> None: + if self.link is not None: + self.link.optimize() @dataclasses.dataclass @@ -82,18 +97,23 @@ class _Label(_Line): @dataclasses.dataclass class _Jump(_Line): - target: _Label | None + fallthrough = False + target: _Label + + def optimize(self) -> None: + super().optimize() + target_aliases = _aliases(self.target) + if any(alias in target_aliases for alias in _aliases(self.link)): + self.remove() def remove(self) -> None: - self.text = "" - if self.target is not None: - self.target.predecessors.remove(self) - if not self.target.predecessors: - self.target.text = "" - self.target = None + [predecessor] = self.predecessors + assert predecessor.link is self + self.target.predecessors.remove(self) + predecessor.link = self.link + self.target.predecessors.append(predecessor) def update(self, target: _Label) -> None: - assert self.target is not None self.target.predecessors.remove(self) assert self.target.label in self.text self.text = self.text.replace(self.target.label, target.label) @@ -105,10 +125,15 @@ def update(self, target: _Label) -> None: class _Branch(_Line): op: str target: _Label - fallthrough: _Line | None = None + + def optimize(self) -> None: + super().optimize() + if self.target.hot: + for jump in _aliases(self.link): + if isinstance(jump, _Jump) and self.invert(jump): + jump.optimize() def update(self, target: _Label) -> None: - assert self.target is not None self.target.predecessors.remove(self) assert self.target.label in self.text self.text = self.text.replace(self.target.label, target.label) @@ -116,11 +141,12 @@ def update(self, target: _Label) -> None: self.target.predecessors.append(self) def invert(self, jump: _Jump) -> bool: - inverted = inverted_branches[self.op] - if inverted is None or jump.target is None: + inverted = branches[self.op] + if inverted is None: return False assert self.op in self.text self.text = self.text.replace(self.op, inverted) + self.op = inverted old_target = self.target self.update(jump.target) jump.update(old_target) @@ -129,7 +155,7 @@ def invert(self, jump: _Jump) -> bool: @dataclasses.dataclass class _Return(_Line): - pass + fallthrough = False @dataclasses.dataclass @@ -140,7 +166,7 @@ class _Noise(_Line): def _branch( line: str, use_label: typing.Callable[[str], _Label] ) -> tuple[str, _Label] | None: - branch = re.match(rf"\s*({'|'.join(inverted_branches)})\s+([\w\.]+)", line) + branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line) return branch and (branch.group(1), use_label(branch.group(2))) @@ -163,25 +189,57 @@ def _noise(line: str) -> bool: return re.match(r"\s*[#\.]|\s*$", line) is not None -def _apply_asm_transformations(path: pathlib.Path) -> None: - labels = {} +def _aliases(line: _Line | None) -> list[_Line]: + aliases = [] + while line is not None and isinstance(line, (_Label, _Noise)): + aliases.append(line) + line = line.link + if line is not None: + aliases.append(line) + return aliases - def use_label(label: str) -> _Label: - if label not in labels: - labels[label] = _Label("", label) - return labels[label] - def new_line(text: str) -> _Line: - if branch := _branch(text, use_label): +@dataclasses.dataclass +class _AssemblyTransformer: + _path: pathlib.Path + _alignment: int = 1 + _lines: _Line = dataclasses.field(init=False) + _labels: dict[str, _Label] = dataclasses.field(init=False, default_factory=dict) + _ran: bool = dataclasses.field(init=False, default=False) + + def __post_init__(self) -> None: + dummy = current = _Noise("") + for line in self._path.read_text().splitlines(True): + new = self._new_line(line) + if current.fallthrough: + new.predecessors.append(current) + current.link = new + current = new + assert dummy.link is not None + self._lines = dummy.link + + def __iter__(self) -> typing.Iterator[_Line]: + line = self._lines + while line is not None: + yield line + line = line.link + + def _use_label(self, label: str) -> _Label: + if label not in self._labels: + self._labels[label] = _Label("", label) + return self._labels[label] + + def _new_line(self, text: str) -> _Line: + if branch := _branch(text, self._use_label): op, label = branch line = _Branch(text, op, label) label.predecessors.append(line) return line - if label := _jump(text, use_label): + if label := _jump(text, self._use_label): line = _Jump(text, label) label.predecessors.append(line) return line - if line := _label(text, use_label): + if line := _label(text, self._use_label): assert line.text == "" line.text = text return line @@ -191,66 +249,39 @@ def new_line(text: str) -> _Line: return _Noise(text) return _Line(text) - # Build graph: - lines = [] - line = _Noise("") # Dummy. - with path.open() as file: - for i, text in enumerate(file): - new = new_line(text) - if not isinstance(line, (_Jump, _Return)): - new.predecessors.append(line) - lines.append(new) - line = new - for i, line in enumerate(reversed(lines)): - if not isinstance(line, (_Label, _Noise)): - break - new = new_line("_JIT_CONTINUE:\n") - lines.insert(len(lines) - i, new) - line = new - # Mark hot lines: - todo = labels["_JIT_CONTINUE"].predecessors.copy() - while todo: - line = todo.pop() - line.hot = True - for predecessor in line.predecessors: - if not predecessor.hot: - todo.append(predecessor) - for pair in itertools.pairwise( - filter(lambda line: line.text and not isinstance(line, _Noise), lines) - ): - match pair: - case (_Branch(hot=True) as branch, _Jump(hot=False) as jump): - branch.invert(jump) - jump.hot = True - for pair in itertools.pairwise(lines): - match pair: - case (_Jump() | _Return(), _): - pass - case (_Line(hot=True), _Line(hot=False) as cold): - cold.hot = True - # Reorder blocks: - hot = [] - cold = [] - for line in lines: - if line.hot: - hot.append(line) - else: - cold.append(line) - lines = hot + cold - # Remove zero-length jumps: - again = True - while again: - again = False - for pair in itertools.pairwise( - filter(lambda line: line.text and not isinstance(line, _Noise), lines) - ): - match pair: - case (_Jump(target=target) as jump, label) if target is label: - jump.remove() - again = True - # Write new assembly: - with path.open("w") as file: - file.writelines(line.text for line in lines) + def _dump(self) -> str: + return "".join(line.text for line in self) + + def _break_on(self, name: str) -> None: + if self._path.stem == name: + print(self._dump()) + breakpoint() + + def run(self) -> None: + assert not self._ran + self._ran = True + last_line = None + for line in self: + if not isinstance(line, (_Label, _Noise)): + last_line = line + assert last_line is not None + new = self._new_line(f".balign {self._alignment}\n") + new.link = last_line.link + last_line.link = new + new = self._new_line("_JIT_CONTINUE:\n") + new.link = last_line.link + last_line.link = new + # Mark hot lines and optimize: + recursion_limit = sys.getrecursionlimit() + sys.setrecursionlimit(10_000) + try: + self._labels["_JIT_CONTINUE"].heat() + # self._break_on("_BUILD_TUPLE") + self._lines.optimize() + finally: + sys.setrecursionlimit(recursion_limit) + # Write new assembly: + self._path.write_text(self._dump()) @dataclasses.dataclass @@ -377,7 +408,7 @@ async def _compile( *self.args, ] await _llvm.run("clang", args_s, echo=self.verbose) - _apply_asm_transformations(s) + _AssemblyTransformer(s, self.alignment).run() args_o = [ f"--target={self.triple}", "-c", From 77886d09ef1019f5c250135efad6cc06147e1064 Mon Sep 17 00:00:00 2001 From: Brandt Bucher Date: Thu, 12 Jun 2025 14:35:26 -0700 Subject: [PATCH 03/20] Handle prefixes --- Tools/jit/_targets.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/Tools/jit/_targets.py b/Tools/jit/_targets.py index 103dac7a15664d..73f30fbc4de856 100644 --- a/Tools/jit/_targets.py +++ b/Tools/jit/_targets.py @@ -201,15 +201,16 @@ def _aliases(line: _Line | None) -> list[_Line]: @dataclasses.dataclass class _AssemblyTransformer: - _path: pathlib.Path - _alignment: int = 1 + path: pathlib.Path + alignment: int = 1 + prefix: str = "" _lines: _Line = dataclasses.field(init=False) _labels: dict[str, _Label] = dataclasses.field(init=False, default_factory=dict) _ran: bool = dataclasses.field(init=False, default=False) def __post_init__(self) -> None: dummy = current = _Noise("") - for line in self._path.read_text().splitlines(True): + for line in self.path.read_text().splitlines(True): new = self._new_line(line) if current.fallthrough: new.predecessors.append(current) @@ -253,7 +254,7 @@ def _dump(self) -> str: return "".join(line.text for line in self) def _break_on(self, name: str) -> None: - if self._path.stem == name: + if self.path.stem == name: print(self._dump()) breakpoint() @@ -265,23 +266,23 @@ def run(self) -> None: if not isinstance(line, (_Label, _Noise)): last_line = line assert last_line is not None - new = self._new_line(f".balign {self._alignment}\n") + new = self._new_line(f".balign {self.alignment}\n") new.link = last_line.link last_line.link = new - new = self._new_line("_JIT_CONTINUE:\n") + new = self._new_line(f"{self.prefix}_JIT_CONTINUE:\n") new.link = last_line.link last_line.link = new # Mark hot lines and optimize: recursion_limit = sys.getrecursionlimit() sys.setrecursionlimit(10_000) try: - self._labels["_JIT_CONTINUE"].heat() + self._labels[f"{self.prefix}_JIT_CONTINUE"].heat() # self._break_on("_BUILD_TUPLE") self._lines.optimize() finally: sys.setrecursionlimit(recursion_limit) # Write new assembly: - self._path.write_text(self._dump()) + self.path.write_text(self._dump()) @dataclasses.dataclass @@ -408,7 +409,7 @@ async def _compile( *self.args, ] await _llvm.run("clang", args_s, echo=self.verbose) - _AssemblyTransformer(s, self.alignment).run() + _AssemblyTransformer(s, alignment=self.alignment, prefix=self.prefix).run() args_o = [ f"--target={self.triple}", "-c", From da7207929c6a7cd7315624d3d530dc7cc32b4271 Mon Sep 17 00:00:00 2001 From: Brandt Bucher Date: Thu, 19 Jun 2025 15:44:24 -0700 Subject: [PATCH 04/20] Rework optimizer and break it out into its own module --- Tools/jit/_optimizers.py | 292 +++++++++++++++++++++++++++++++++++++++ Tools/jit/_targets.py | 272 +++--------------------------------- 2 files changed, 309 insertions(+), 255 deletions(-) create mode 100644 Tools/jit/_optimizers.py diff --git a/Tools/jit/_optimizers.py b/Tools/jit/_optimizers.py new file mode 100644 index 00000000000000..cdccce42384974 --- /dev/null +++ b/Tools/jit/_optimizers.py @@ -0,0 +1,292 @@ +import collections +import dataclasses +import pathlib +import re +import typing + +branches = {} +for op, nop in [ + ("ja", "jna"), + ("jae", "jnae"), + ("jb", "jnb"), + ("jbe", "jnbe"), + ("jc", "jnc"), + ("jcxz", None), + ("je", "jne"), + ("jecxz", None), + ("jg", "jng"), + ("jge", "jnge"), + ("jl", "jnl"), + ("jle", "jnle"), + ("jo", "jno"), + ("jp", "jnp"), + ("jpe", "jpo"), + ("jrxz", None), + ("js", "jns"), + ("jz", "jnz"), + ("loop", None), + ("loope", None), + ("loopne", None), + ("loopnz", None), + ("loopz", None), +]: + branches[op] = nop + if nop: + branches[nop] = op + + +def _get_branch(line: str) -> str | None: + branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line) + return branch and branch[2] + + +def _invert_branch(line: str, label: str) -> str | None: + branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line) + assert branch + inverted = branches.get(branch[1]) + if not inverted: + return None + line = line.replace(branch[1], inverted, 1) # XXX + line = line.replace(branch[2], label, 1) # XXX + return line + + +def _get_jump(line: str) -> str | None: + jump = re.match(r"\s*(?:rex64\s+)?jmpq?\s+\*?([\w\.]+)", line) + return jump and jump[1] + + +def _get_label(line: str) -> str | None: + label = re.match(r"\s*([\w\.]+):", line) + return label and label[1] + + +def _is_return(line: str) -> bool: + return re.match(r"\s*ret\s+", line) is not None + + +def _is_noise(line: str) -> bool: + return re.match(r"\s*([#\.]|$)", line) is not None + + +@dataclasses.dataclass +class _Block: + label: str + noise: list[str] = dataclasses.field(default_factory=list) + instructions: list[str] = dataclasses.field(default_factory=list) + target: typing.Self | None = None + link: typing.Self | None = None + fallthrough: bool = True + hot: bool = False + + def __eq__(self, other: object) -> bool: + return self is other + + def __hash__(self) -> int: + return super().__hash__() + + def resolve(self) -> typing.Self: + while self.link and not self.instructions: + self = self.link + return self + + +class _Labels(dict): + def __missing__(self, key: str) -> _Block: + self[key] = _Block(key) + return self[key] + + +@dataclasses.dataclass +class Optimizer: + + path: pathlib.Path + _: dataclasses.KW_ONLY + prefix: str = "" + _graph: _Block = dataclasses.field(init=False) + _labels: _Labels = dataclasses.field(init=False, default_factory=_Labels) + + _re_branch: typing.ClassVar[re.Pattern[str]] # Two groups: instruction and target. + _re_jump: typing.ClassVar[re.Pattern[str]] # One group: target. + _re_return: typing.ClassVar[re.Pattern[str]] # No groups. + + def __post_init__(self) -> None: + text = self._preprocess(self.path.read_text()) + self._graph = block = self._new_block() + for line in text.splitlines(): + if label := _get_label(line): + block.link = block = self._labels[label] + elif block.target or not block.fallthrough: + block.link = block = self._new_block() + if _is_noise(line) or _get_label(line): + if block.instructions: + block.link = block = self._new_block() + block.noise.append(line) + continue + block.instructions.append(line) + if target := _get_branch(line): + block.target = self._labels[target] + assert block.fallthrough + elif target := _get_jump(line): + block.target = self._labels[target] + block.fallthrough = False + elif _is_return(line): + assert not block.target + block.fallthrough = False + + def _new_block(self, label: str | None = None) -> _Block: + if not label: + label = f"{self.prefix}_JIT_LABEL_{len(self._labels)}" + assert label not in self._labels, label + block = self._labels[label] = _Block(label, [f"{label}:"]) + return block + + def _preprocess(self, text: str) -> str: + return text + + def _blocks(self) -> typing.Generator[_Block, None, None]: + block = self._graph + while block: + yield block + block = block.link + + def _lines(self) -> typing.Generator[str, None, None]: + for block in self._blocks(): + yield from block.noise + yield from block.instructions + + def _insert_continue_label(self) -> None: + for end in reversed(list(self._blocks())): + if end.instructions: + break + continuation = self._labels[f"{self.prefix}_JIT_CONTINUE"] + continuation.noise.append(f"{continuation.label}:") + end.link, continuation.link = continuation, end.link + + def _mark_hot_blocks(self) -> None: + predecessors = collections.defaultdict(set) + for block in self._blocks(): + if block.target: + predecessors[block.target].add(block) + if block.fallthrough and block.link: + predecessors[block.link].add(block) + todo = [self._labels[f"{self.prefix}_JIT_CONTINUE"]] + while todo: + block = todo.pop() + block.hot = True + todo.extend( + predecessor + for predecessor in predecessors[block] + if not predecessor.hot + ) + + def _invert_hot_branches(self) -> None: + for block in self._blocks(): + if ( + block.fallthrough + and block.target + and block.link + and block.target.hot + and not block.link.hot + ): + # Turn... + # branch hot + # ...into.. + # opposite-branch ._JIT_LABEL_N + # jmp hot + # ._JIT_LABEL_N: + label_block = self._new_block() + inverted = _invert_branch(block.instructions[-1], label_block.label) + if inverted is None: + continue + jump_block = self._new_block() + jump_block.instructions.append(f"\tjmp\t{block.target.label}") + jump_block.target = block.target + jump_block.fallthrough = False + block.instructions[-1] = inverted + block.target = label_block + label_block.link = block.link + jump_block.link = label_block + block.link = jump_block + + def _thread_jumps(self) -> None: + for block in self._blocks(): + while block.target: + label = block.target.label + target = block.target.resolve() + if ( + not target.fallthrough + and target.target + and len(target.instructions) == 1 + ): + block.instructions[-1] = block.instructions[-1].replace( + label, target.target.label + ) # XXX + block.target = target.target + else: + break + + def _remove_dead_code(self) -> None: + reachable = set() + todo = [self._graph] + while todo: + block = todo.pop() + reachable.add(block) + if block.target and block.target not in reachable: + todo.append(block.target) + if block.fallthrough and block.link and block.link not in reachable: + todo.append(block.link) + for block in self._blocks(): + if block not in reachable: + block.instructions.clear() + + def _remove_redundant_jumps(self) -> None: + for block in self._blocks(): + if ( + block.target + and block.link + and block.target.resolve() is block.link.resolve() + ): + block.target = None + block.fallthrough = True + block.instructions.pop() + + def _remove_unused_labels(self) -> None: + used = set() + for block in self._blocks(): + if block.target: + used.add(block.target) + for block in self._blocks(): + if block not in used and block.label.startswith( + f"{self.prefix}_JIT_LABEL_" + ): + del block.noise[0] + + def run(self) -> None: + self._insert_continue_label() + self._mark_hot_blocks() + self._invert_hot_branches() + self._thread_jumps() + self._remove_dead_code() + self._remove_redundant_jumps() + self._remove_unused_labels() + self.path.write_text("\n".join(self._lines())) + + +class OptimizerX86(Optimizer): + + _re_branch = re.compile( + rf"\s*(?P{'|'.join(branches)})\s+(?P[\w\.]+)" + ) + _re_jump = re.compile(r"\s*jmp\s+(?P[\w\.]+)") + _re_return = re.compile(r"\s*ret\b") + + +class OptimizerX86Windows(OptimizerX86): + + def _preprocess(self, text: str) -> str: + text = super()._preprocess(text) + far_indirect_jump = ( + rf"rex64\s+jmpq\s+\*__imp_(?P{self.prefix}_JIT_\w+)\(%rip\)" + ) + return re.sub(far_indirect_jump, r"jmp\t\g", text) diff --git a/Tools/jit/_targets.py b/Tools/jit/_targets.py index 73f30fbc4de856..1972596b2fd5de 100644 --- a/Tools/jit/_targets.py +++ b/Tools/jit/_targets.py @@ -3,7 +3,6 @@ import asyncio import dataclasses import hashlib -import itertools import json import os import pathlib @@ -13,6 +12,7 @@ import typing import _llvm +import _optimizers import _schema import _stencils import _writer @@ -35,258 +35,10 @@ "_R", _schema.COFFRelocation, _schema.ELFRelocation, _schema.MachORelocation ) -branches = {} -for op, nop in [ - ("ja", "jna"), - ("jae", "jnae"), - ("jb", "jnb"), - ("jbe", "jnbe"), - ("jc", "jnc"), - ("jcxz", None), - ("je", "jne"), - ("jecxz", None), - ("jg", "jng"), - ("jge", "jnge"), - ("jl", "jnl"), - ("jle", "jnle"), - ("jo", "jno"), - ("jp", "jnp"), - ("jpe", "jpo"), - ("jrxz", None), - ("js", "jns"), - ("jz", "jnz"), - ("loop", None), - ("loope", None), - ("loopne", None), - ("loopnz", None), - ("loopz", None), -]: - branches[op] = nop - if nop is not None: - branches[nop] = op - - -@dataclasses.dataclass -class _Line: - fallthrough: typing.ClassVar[bool] = True - text: str - hot: bool = dataclasses.field(init=False, default=False) - predecessors: list["_Line"] = dataclasses.field( - init=False, repr=False, default_factory=list - ) - link: "_Line | None" = dataclasses.field(init=False, repr=False, default=None) - - def heat(self) -> None: - if self.hot: - return - self.hot = True - for predecessor in self.predecessors: - predecessor.heat() - if self.fallthrough and self.link is not None: - self.link.heat() - - def optimize(self) -> None: - if self.link is not None: - self.link.optimize() - - -@dataclasses.dataclass -class _Label(_Line): - label: str - - -@dataclasses.dataclass -class _Jump(_Line): - fallthrough = False - target: _Label - - def optimize(self) -> None: - super().optimize() - target_aliases = _aliases(self.target) - if any(alias in target_aliases for alias in _aliases(self.link)): - self.remove() - - def remove(self) -> None: - [predecessor] = self.predecessors - assert predecessor.link is self - self.target.predecessors.remove(self) - predecessor.link = self.link - self.target.predecessors.append(predecessor) - - def update(self, target: _Label) -> None: - self.target.predecessors.remove(self) - assert self.target.label in self.text - self.text = self.text.replace(self.target.label, target.label) - self.target = target - self.target.predecessors.append(self) - - -@dataclasses.dataclass -class _Branch(_Line): - op: str - target: _Label - - def optimize(self) -> None: - super().optimize() - if self.target.hot: - for jump in _aliases(self.link): - if isinstance(jump, _Jump) and self.invert(jump): - jump.optimize() - - def update(self, target: _Label) -> None: - self.target.predecessors.remove(self) - assert self.target.label in self.text - self.text = self.text.replace(self.target.label, target.label) - self.target = target - self.target.predecessors.append(self) - - def invert(self, jump: _Jump) -> bool: - inverted = branches[self.op] - if inverted is None: - return False - assert self.op in self.text - self.text = self.text.replace(self.op, inverted) - self.op = inverted - old_target = self.target - self.update(jump.target) - jump.update(old_target) - return True - - -@dataclasses.dataclass -class _Return(_Line): - fallthrough = False - - -@dataclasses.dataclass -class _Noise(_Line): - pass - - -def _branch( - line: str, use_label: typing.Callable[[str], _Label] -) -> tuple[str, _Label] | None: - branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line) - return branch and (branch.group(1), use_label(branch.group(2))) - - -def _jump(line: str, use_label: typing.Callable[[str], _Label]) -> _Label | None: - if jump := re.match(r"\s*jmp\s+([\w\.]+)", line): - return use_label(jump.group(1)) - return None - - -def _label(line: str, use_label: typing.Callable[[str], _Label]) -> _Label | None: - label = re.match(r"\s*([\w\.]+):", line) - return label and use_label(label.group(1)) - - -def _return(line: str) -> bool: - return re.match(r"\s*ret\s+", line) is not None - - -def _noise(line: str) -> bool: - return re.match(r"\s*[#\.]|\s*$", line) is not None - - -def _aliases(line: _Line | None) -> list[_Line]: - aliases = [] - while line is not None and isinstance(line, (_Label, _Noise)): - aliases.append(line) - line = line.link - if line is not None: - aliases.append(line) - return aliases - - -@dataclasses.dataclass -class _AssemblyTransformer: - path: pathlib.Path - alignment: int = 1 - prefix: str = "" - _lines: _Line = dataclasses.field(init=False) - _labels: dict[str, _Label] = dataclasses.field(init=False, default_factory=dict) - _ran: bool = dataclasses.field(init=False, default=False) - - def __post_init__(self) -> None: - dummy = current = _Noise("") - for line in self.path.read_text().splitlines(True): - new = self._new_line(line) - if current.fallthrough: - new.predecessors.append(current) - current.link = new - current = new - assert dummy.link is not None - self._lines = dummy.link - - def __iter__(self) -> typing.Iterator[_Line]: - line = self._lines - while line is not None: - yield line - line = line.link - - def _use_label(self, label: str) -> _Label: - if label not in self._labels: - self._labels[label] = _Label("", label) - return self._labels[label] - - def _new_line(self, text: str) -> _Line: - if branch := _branch(text, self._use_label): - op, label = branch - line = _Branch(text, op, label) - label.predecessors.append(line) - return line - if label := _jump(text, self._use_label): - line = _Jump(text, label) - label.predecessors.append(line) - return line - if line := _label(text, self._use_label): - assert line.text == "" - line.text = text - return line - if _return(text): - return _Return(text) - if _noise(text): - return _Noise(text) - return _Line(text) - - def _dump(self) -> str: - return "".join(line.text for line in self) - - def _break_on(self, name: str) -> None: - if self.path.stem == name: - print(self._dump()) - breakpoint() - - def run(self) -> None: - assert not self._ran - self._ran = True - last_line = None - for line in self: - if not isinstance(line, (_Label, _Noise)): - last_line = line - assert last_line is not None - new = self._new_line(f".balign {self.alignment}\n") - new.link = last_line.link - last_line.link = new - new = self._new_line(f"{self.prefix}_JIT_CONTINUE:\n") - new.link = last_line.link - last_line.link = new - # Mark hot lines and optimize: - recursion_limit = sys.getrecursionlimit() - sys.setrecursionlimit(10_000) - try: - self._labels[f"{self.prefix}_JIT_CONTINUE"].heat() - # self._break_on("_BUILD_TUPLE") - self._lines.optimize() - finally: - sys.setrecursionlimit(recursion_limit) - # Write new assembly: - self.path.write_text(self._dump()) - @dataclasses.dataclass class _Target(typing.Generic[_S, _R]): + triple: str condition: str _: dataclasses.KW_ONLY @@ -298,6 +50,7 @@ class _Target(typing.Generic[_S, _R]): verbose: bool = False known_symbols: dict[str, int] = dataclasses.field(default_factory=dict) pyconfig_dir: pathlib.Path = pathlib.Path.cwd().resolve() + optimizer: type[_optimizers.Optimizer] | None = None def _get_nop(self) -> bytes: if re.fullmatch(r"aarch64-.*", self.triple): @@ -409,7 +162,8 @@ async def _compile( *self.args, ] await _llvm.run("clang", args_s, echo=self.verbose) - _AssemblyTransformer(s, alignment=self.alignment, prefix=self.prefix).run() + if self.optimizer: + self.optimizer(s, prefix=self.prefix).run() args_o = [ f"--target={self.triple}", "-c", @@ -804,18 +558,26 @@ def get_target(host: str) -> _COFF | _ELF | _MachO: "-Wno-ignored-attributes", ] condition = "defined(_M_IX86)" - target = _COFF(host, condition, args=args, prefix="_") + target = _COFF( + host, + condition, + args=args, + optimizer=_optimizers.OptimizerX86Windows, + prefix="_", + ) elif re.fullmatch(r"x86_64-apple-darwin.*", host): condition = "defined(__x86_64__) && defined(__APPLE__)" - target = _MachO(host, condition, prefix="_") + target = _MachO(host, condition, optimizer=_optimizers.OptimizerX86, prefix="_") elif re.fullmatch(r"x86_64-pc-windows-msvc", host): args = ["-fms-runtime-lib=dll"] condition = "defined(_M_X64)" - target = _COFF(host, condition, args=args) + target = _COFF( + host, condition, args=args, optimizer=_optimizers.OptimizerX86Windows + ) elif re.fullmatch(r"x86_64-.*-linux-gnu", host): args = ["-fno-pic", "-mcmodel=medium", "-mlarge-data-threshold=0"] condition = "defined(__x86_64__) && defined(__linux__)" - target = _ELF(host, condition, args=args) + target = _ELF(host, condition, args=args, optimizer=_optimizers.OptimizerX86) else: raise ValueError(host) return target From 9ebb32c7bb090f5c74d465c577760deb1eb62d74 Mon Sep 17 00:00:00 2001 From: Brandt Bucher Date: Thu, 19 Jun 2025 18:09:49 -0700 Subject: [PATCH 05/20] Add AArch64 stub --- Tools/jit/_optimizers.py | 247 ++++++++++++++++++++++----------------- Tools/jit/_stencils.py | 22 +--- Tools/jit/_targets.py | 24 ++-- 3 files changed, 151 insertions(+), 142 deletions(-) diff --git a/Tools/jit/_optimizers.py b/Tools/jit/_optimizers.py index cdccce42384974..a54edd34838be3 100644 --- a/Tools/jit/_optimizers.py +++ b/Tools/jit/_optimizers.py @@ -4,69 +4,34 @@ import re import typing -branches = {} -for op, nop in [ - ("ja", "jna"), - ("jae", "jnae"), - ("jb", "jnb"), - ("jbe", "jnbe"), - ("jc", "jnc"), - ("jcxz", None), - ("je", "jne"), - ("jecxz", None), - ("jg", "jng"), - ("jge", "jnge"), - ("jl", "jnl"), - ("jle", "jnle"), - ("jo", "jno"), - ("jp", "jnp"), - ("jpe", "jpo"), - ("jrxz", None), - ("js", "jns"), - ("jz", "jnz"), - ("loop", None), - ("loope", None), - ("loopne", None), - ("loopnz", None), - ("loopz", None), -]: - branches[op] = nop - if nop: - branches[nop] = op - - -def _get_branch(line: str) -> str | None: - branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line) - return branch and branch[2] - - -def _invert_branch(line: str, label: str) -> str | None: - branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line) - assert branch - inverted = branches.get(branch[1]) - if not inverted: - return None - line = line.replace(branch[1], inverted, 1) # XXX - line = line.replace(branch[2], label, 1) # XXX - return line - - -def _get_jump(line: str) -> str | None: - jump = re.match(r"\s*(?:rex64\s+)?jmpq?\s+\*?([\w\.]+)", line) - return jump and jump[1] - - -def _get_label(line: str) -> str | None: - label = re.match(r"\s*([\w\.]+):", line) - return label and label[1] - - -def _is_return(line: str) -> bool: - return re.match(r"\s*ret\s+", line) is not None - - -def _is_noise(line: str) -> bool: - return re.match(r"\s*([#\.]|$)", line) is not None +_RE_NEVER_MATCH = re.compile(r"(?!)") + +_X86_BRANCHES = { + "ja": "jna", + "jae": "jnae", + "jb": "jnb", + "jbe": "jnbe", + "jc": "jnc", + "jcxz": None, + "je": "jne", + "jecxz": None, + "jg": "jng", + "jge": "jnge", + "jl": "jnl", + "jle": "jnle", + "jo": "jno", + "jp": "jnp", + "jpe": "jpo", + "jrxz": None, + "js": "jns", + "jz": "jnz", + "loop": None, + "loope": None, + "loopne": None, + "loopnz": None, + "loopz": None, +} +_X86_BRANCHES |= {v: k for k, v in _X86_BRANCHES.items() if v} @dataclasses.dataclass @@ -91,12 +56,6 @@ def resolve(self) -> typing.Self: return self -class _Labels(dict): - def __missing__(self, key: str) -> _Block: - self[key] = _Block(key) - return self[key] - - @dataclasses.dataclass class Optimizer: @@ -104,44 +63,104 @@ class Optimizer: _: dataclasses.KW_ONLY prefix: str = "" _graph: _Block = dataclasses.field(init=False) - _labels: _Labels = dataclasses.field(init=False, default_factory=_Labels) - - _re_branch: typing.ClassVar[re.Pattern[str]] # Two groups: instruction and target. - _re_jump: typing.ClassVar[re.Pattern[str]] # One group: target. - _re_return: typing.ClassVar[re.Pattern[str]] # No groups. + _labels: dict = dataclasses.field(init=False, default_factory=dict) + _alignment: typing.ClassVar[int] = 1 + _branches: typing.ClassVar[dict[str, str | None]] = {} + _re_branch: typing.ClassVar[re.Pattern[str]] = ( + _RE_NEVER_MATCH # Two groups: instruction and target. + ) + _re_jump: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH # One group: target. + _re_label: typing.ClassVar[re.Pattern[str]] = re.compile( + r"\s*(?P