Skip to content

Commit 3ff9ee7

Browse files
committed
Rework AArch64 alignment
1 parent 3a6c20d commit 3ff9ee7

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

Python/jit.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,10 @@ void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *s
431431

432432
#if defined(__aarch64__) || defined(_M_ARM64)
433433
#define TRAMPOLINE_SIZE 16
434+
#define DATA_ALIGN 8
434435
#else
435436
#define TRAMPOLINE_SIZE 0
437+
#define DATA_ALIGN 0
436438
#endif
437439

438440
// Generate and patch AArch64 trampolines. The symbols to jump to are stored
@@ -522,7 +524,8 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
522524
// Round up to the nearest page:
523525
size_t page_size = get_page_size();
524526
assert((page_size & (page_size - 1)) == 0);
525-
size_t padding = page_size - ((code_size + state.trampolines.size + data_size) & (page_size - 1));
527+
size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1))
528+
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size) & (page_size - 1));
526529
size_t total_size = code_size + state.trampolines.size + data_size + padding;
527530
unsigned char *memory = jit_alloc(total_size);
528531
if (memory == NULL) {
@@ -545,7 +548,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
545548
// Loop again to emit the code:
546549
unsigned char *code = memory;
547550
state.trampolines.mem = memory + code_size;
548-
unsigned char *data = memory + code_size + state.trampolines.size;
551+
unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
549552
// Compile the shim, which handles converting between the native
550553
// calling convention and the calling convention used by jitted code
551554
// (which may be different for efficiency reasons).
@@ -567,7 +570,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
567570
code += group->code_size;
568571
data += group->data_size;
569572
assert(code == memory + code_size);
570-
assert(data == memory + code_size + state.trampolines.size + data_size);
573+
assert(data == memory + code_size + state.trampolines.size + code_padding + data_size);
571574
#ifdef MAP_JIT
572575
pthread_jit_write_protect_np(1);
573576
#endif

Tools/jit/_optimizers.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ class Optimizer:
8484
r'\s*(?P<label>[\w."$?@]+):'
8585
)
8686
# Override everything that follows in subclasses:
87-
_alignment: typing.ClassVar[int] = 1
8887
_branches: typing.ClassVar[dict[str, str | None]] = {}
8988
# Two groups (instruction and target):
9089
_re_branch: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
@@ -197,15 +196,12 @@ def _insert_continue_label(self) -> None:
197196
# jmp FOO
198197
# After:
199198
# jmp FOO
200-
# .balign 8
201199
# _JIT_CONTINUE:
202200
# This lets the assembler encode _JIT_CONTINUE jumps at build time!
203-
align = _Block()
204-
align.noninstructions.append(f"\t.balign\t{self._alignment}")
205201
continuation = self._lookup_label(f"{self.prefix}_JIT_CONTINUE")
206202
assert continuation.label
207203
continuation.noninstructions.append(f"{continuation.label}:")
208-
end.link, align.link, continuation.link = align, continuation, end.link
204+
end.link, continuation.link = continuation, end.link
209205

210206
def _mark_hot_blocks(self) -> None:
211207
# Start with the last block, and perform a DFS to find all blocks that
@@ -285,8 +281,6 @@ def run(self) -> None:
285281
class OptimizerAArch64(Optimizer): # pylint: disable = too-few-public-methods
286282
"""aarch64-apple-darwin/aarch64-pc-windows-msvc/aarch64-unknown-linux-gnu"""
287283

288-
# TODO: @diegorusso
289-
_alignment = 8
290284
# https://developer.arm.com/documentation/ddi0602/2025-03/Base-Instructions/B--Branch-
291285
_re_jump = re.compile(r"\s*b\s+(?P<target>[\w.]+)")
292286

0 commit comments

Comments
 (0)