Description
tl;dr: The numba codegen for scan currently always emits a `while` loop. This can lead to up to a 4x slowdown on long sequences vs a native for loop. Here's a graph:
`numba_funcify_Scan` unconditionally creates code of the form:

```python
i = 0
cond = np.array(False)
while i < n_steps and not cond.item():
    inner_fn()
    i += 1
```
`while` scans are extremely rare relative to `for` scans. In the `for` case, we never update the `cond` variable, so the code is logically equivalent to a for loop. It turns out, however, that LLVM can't prove that. The LLM tells me this is because `np.array(False)` creates a heap-resident numpy array via numba's memory allocator, so LLVM cannot prove that 1) the `cond` array doesn't alias the buffers being written to inside the loop, nor that 2) the `.item()` call reads from a loop-invariant location.
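The aliasing concern is real in general, even if numba's allocator would never actually produce it. Here is a deliberately contrived NumPy example (purely illustrative, unrelated to the generated code) where a boolean exit flag really does alias the buffer being written, so a store into the buffer flips the flag:

```python
import numpy as np

# A byte buffer standing in for the scan's output storage.
buf = np.zeros(4, dtype=np.uint8)

# A boolean "cond" that is secretly a view into the first byte of buf.
cond = buf[:1].view(np.bool_)

assert cond.item() is False
buf[0] = 1  # a store into the "output buffer"...
assert cond.item() is True  # ...changes cond through the alias
```

Without alias analysis that rules this case out, the compiler must reload the flag from memory on every iteration.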
LLVM conservatively assumes that the stores inside the loop body could modify that memory, so it emits the following IR:
```llvm
B140.endif:                                ; the loop header
  %lsr.iv10 = phi ptr [ %scevgep, %B140.endif.preheader ], [ %scevgep11, %B172 ]
  %lsr.iv = phi i64 [ %0, %B140.endif.preheader ], [ %lsr.iv.next, %B172 ]
  ;; ↓ MUST load cond from heap memory every iteration (can't prove it's invariant)
  %.193 = load i8, ptr %.6.i.i, align 1
  %.196 = icmp eq i8 %.193, 0              ; cond == False?
  br i1 %.196, label %B172, label %B174    ; if cond, exit loop

B172:                                      ; the loop body
  %scevgep12 = getelementptr i8, ptr %lsr.iv10, i64 -8
  %.131 = load double, ptr %scevgep12, align 8    ; ← MUST load buf[i] from memory
  %.132 = fadd double %.131, 1.000000e+00         ; buf[i+1] = buf[i] + 1.0
  store double %.132, ptr %lsr.iv10, align 8      ; write to buf
  %lsr.iv.next = add i64 %lsr.iv, -1              ; decrement
  %scevgep11 = getelementptr i8, ptr %lsr.iv10, i64 8   ; advance
  %exitcond.not = icmp eq i64 %lsr.iv.next, 0     ; counter == 0?
  br i1 %exitcond.not, label %B174, label %B140.endif   ; loop back to cond check
```
Ideally, we should be getting this:
```llvm
B30:                                       ; the loop body
  %lsr.iv1 = phi i64 [ %arg.n_steps, %B30.preheader ], [ %lsr.iv.next, %B30 ]
  %lsr.iv = phi ptr [ %invariant.gep, %B30.preheader ], [ %scevgep, %B30 ]
  %store_forwarded = phi double [ %load_initial, %B30.preheader ], [ %.168, %B30 ]
  ;; ↑ LLVM forwards the stored value directly via a phi node — no memory load needed!
  %.168 = fadd double %store_forwarded, 1.000000e+00    ; buf[i+1] = buf[i] + 1.0
  store double %.168, ptr %lsr.iv, align 8              ; write to buf
  %scevgep = getelementptr i8, ptr %lsr.iv, i64 8       ; advance pointer
  %lsr.iv.next = add i64 %lsr.iv1, -1                   ; decrement counter
  %exitcond.not = icmp eq i64 %lsr.iv.next, 0           ; counter == 0?
  br i1 %exitcond.not, label %B62, label %B30           ; loop or exit
```
The fix is pretty straightforward. We currently always generate loops that look like this:
```python
def scan(n_steps, outer_in_1):
    outer_in_1_len = outer_in_1.shape[0]
    outer_in_1_sitsot_storage = outer_in_1
    outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.float64)
    i = 0
    cond = np.array(False)  # ← heap allocation
    while i < n_steps and not cond.item():  # ← checked every iter
        outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]
        (inner_out_0,) = scan_inner_func(
            outer_in_1_sitsot_storage_temp_scalar_0
        )
        outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0
        i += 1  # ← cond never updated
    ...
    return outer_in_1_sitsot_storage
```
Instead, we can check the `as_while` flag at compile time and, if `False`, specialize to something like this:
```python
def scan(n_steps, outer_in_1):
    outer_in_1_len = outer_in_1.shape[0]
    outer_in_1_sitsot_storage = outer_in_1
    outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.float64)
    for i in range(n_steps):  # ← simple counted loop
        outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]
        (inner_out_0,) = scan_inner_func(
            outer_in_1_sitsot_storage_temp_scalar_0
        )
        outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0
    ...
    return outer_in_1_sitsot_storage
```
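As a sanity check of the equivalence claim, here is a runnable pure-NumPy sketch of both variants, with a hypothetical inner function that just adds 1.0 and simplified names (this is not the actual generated code):

```python
import numpy as np

def scan_inner_func(x):
    # Hypothetical inner function: next = prev + 1.0
    return x[()] + 1.0

def scan_while(n_steps, buf):
    n = buf.shape[0]
    temp = np.empty((), dtype=np.float64)
    i = 0
    cond = np.array(False)  # heap-resident flag; never updated in the for case
    while i < n_steps and not cond.item():
        temp[()] = buf[i % n]
        buf[(i + 1) % n] = scan_inner_func(temp)
        i += 1
    return buf

def scan_for(n_steps, buf):
    n = buf.shape[0]
    temp = np.empty((), dtype=np.float64)
    for i in range(n_steps):  # counted loop: no exit-flag load per iteration
        temp[()] = buf[i % n]
        buf[(i + 1) % n] = scan_inner_func(temp)
    return buf

a = scan_while(4, np.zeros(5))
b = scan_for(4, np.zeros(5))
assert np.array_equal(a, b)  # identical results: [0, 1, 2, 3, 4]
```

Since `cond` is never written, the two functions are step-for-step identical; the only difference is the per-iteration flag check that blocks LLVM's store-to-load forwarding.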
Basically something that looks like this (untested LLM code):
```python
if op.info.as_while:
    loop_src = f"""
    i = 0
    while i < n_steps:
        {indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
        {inner_outputs} = scan_inner_func(
            {inner_in_args}
        )
        {indent(inner_out_post_processing_block, " " * 8)}
        {indent(inner_out_to_outer_out_stmts, " " * 8)}
        if cond:
            i += 1
            break
        i += 1
    """
else:
    loop_src = f"""
    for i in range(n_steps):
        {indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
        {inner_outputs} = scan_inner_func(
            {inner_in_args}
        )
        {indent(inner_out_post_processing_block, " " * 8)}
        {indent(inner_out_to_outer_out_stmts, " " * 8)}
    """
```