Add specialized numba codegen for as_while=False scans #2024

@jessegrabowski

Description

tl;dr: The numba codegen for scan currently always emits a while loop. This can cause up to a 4x slowdown on long sequences compared with a native for loop. Here's a graph:

Image

numba_funcify_Scan unconditionally creates code of the form:

i = 0
cond = np.array(False)
while i < n_steps and not cond.item():
   inner_fn()
   i += 1

Scans with as_while=True are extremely rare relative to for scans. In the for case, we never update the cond variable, so the code is logically equivalent to a for loop. It turns out, however, that LLVM can't prove that. The LLM tells me that this is because np.array(False) creates a heap-resident numpy array via numba's memory allocator. LLVM cannot prove either that 1) the cond array doesn't alias the buffers being written to inside the loop, or that 2) the .item() call reads from a loop-invariant location.

LLVM conservatively assumes that the stores inside the loop body could modify that memory, so it emits the following IR:

B140.endif:                                       ; the loop header
  %lsr.iv10 = phi ptr [ %scevgep, %B140.endif.preheader ], [ %scevgep11, %B172 ]
  %lsr.iv = phi i64 [ %0, %B140.endif.preheader ], [ %lsr.iv.next, %B172 ]

  ;; ↓ MUST load cond from heap memory every iteration (can't prove it's invariant)
  %.193 = load i8, ptr %.6.i.i, align 1
  %.196 = icmp eq i8 %.193, 0                     ; cond == False?
  br i1 %.196, label %B172, label %B174            ; if cond, exit loop

B172:                                             ; the loop body
  %scevgep12 = getelementptr i8, ptr %lsr.iv10, i64 -8
  %.131 = load double, ptr %scevgep12, align 8    ; ← MUST load buf[i] from memory
  %.132 = fadd double %.131, 1.000000e+00         ; buf[i+1] = buf[i] + 1.0
  store double %.132, ptr %lsr.iv10, align 8       ; write to buf
  %lsr.iv.next = add i64 %lsr.iv, -1              ; decrement
  %scevgep11 = getelementptr i8, ptr %lsr.iv10, i64 8  ; advance
  %exitcond.not = icmp eq i64 %lsr.iv.next, 0     ; counter == 0?
  br i1 %exitcond.not, label %B174, label %B140.endif  ; loop back to cond check

Ideally, we should be getting this:

B30:                                              ; the loop body
  %lsr.iv1 = phi i64 [ %arg.n_steps, %B30.preheader ], [ %lsr.iv.next, %B30 ]
  %lsr.iv = phi ptr [ %invariant.gep, %B30.preheader ], [ %scevgep, %B30 ]
  %store_forwarded = phi double [ %load_initial, %B30.preheader ], [ %.168, %B30 ]
  ;; ↑ LLVM forwards the stored value directly via a phi node — no memory load needed!

  %.168 = fadd double %store_forwarded, 1.000000e+00    ; buf[i+1] = buf[i] + 1.0
  store double %.168, ptr %lsr.iv, align 8               ; write to buf
  %scevgep = getelementptr i8, ptr %lsr.iv, i64 8        ; advance pointer
  %lsr.iv.next = add i64 %lsr.iv1, -1                    ; decrement counter
  %exitcond.not = icmp eq i64 %lsr.iv.next, 0            ; counter == 0?
  br i1 %exitcond.not, label %B62, label %B30             ; loop or exit
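
For anyone who wants to poke at the two loop shapes outside of PyTensor, here is a minimal sketch. The function names (scan_while, scan_for, time_once) are made up for illustration, and the loop body is the same toy buf[i+1] = buf[i] + 1.0 update annotated in the IR above. It assumes numpy is installed and falls back to plain Python if numba is missing, in which case the timings say nothing about LLVM.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python so the sketch still runs
    def njit(fn):
        return fn


@njit
def scan_while(n_steps, buf):
    # Mirrors the loop shape generated today: cond lives in a 0-d array
    # and is re-read on every iteration even though it is never updated.
    i = 0
    cond = np.array(False)
    while i < n_steps and not cond.item():
        buf[i + 1] = buf[i] + 1.0
        i += 1
    return buf


@njit
def scan_for(n_steps, buf):
    # The proposed shape for as_while=False: a simple counted loop.
    for i in range(n_steps):
        buf[i + 1] = buf[i] + 1.0
    return buf


def time_once(fn, n_steps):
    # Hypothetical helper: one warm-up call so JIT compilation is not timed.
    fn(1, np.zeros(2))
    buf = np.zeros(n_steps + 1)
    start = time.perf_counter()
    fn(n_steps, buf)
    return time.perf_counter() - start
```

Calling time_once(scan_while, 10_000_000) and time_once(scan_for, 10_000_000) gives a rough comparison. Whether this toy reproduces the gap in the graph will depend on the numba and LLVM versions, so treat it as a starting point rather than a benchmark.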

The fix is pretty straightforward. We currently always generate loops that look like this:

def scan(n_steps, outer_in_1):

    outer_in_1_len = outer_in_1.shape[0]
    outer_in_1_sitsot_storage = outer_in_1

    outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.float64)

    i = 0
    cond = np.array(False)                                         # ← heap allocation
    while i < n_steps and not cond.item():                         # ← checked every iter
        outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]

        (inner_out_0,) = scan_inner_func(
            outer_in_1_sitsot_storage_temp_scalar_0
        )

        outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0
        i += 1                                                     # ← cond never updated

    ...
    return outer_in_1_sitsot_storage

Instead, we can check the as_while flag at compile time and, if False, specialize to something like this:

def scan(n_steps, outer_in_1):

    outer_in_1_len = outer_in_1.shape[0]
    outer_in_1_sitsot_storage = outer_in_1

    outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.float64)

    for i in range(n_steps):                                       # ← simple counted loop
        outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]

        (inner_out_0,) = scan_inner_func(
            outer_in_1_sitsot_storage_temp_scalar_0
        )

        outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0

    ...
    return outer_in_1_sitsot_storage

Basically something that looks like this (untested LLM code):

if op.info.as_while:
    loop_src = f"""
    i = 0
    while i < n_steps:
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}

        {inner_outputs} = scan_inner_func(
            {inner_in_args}
        )
{indent(inner_out_post_processing_block, " " * 8)}
{indent(inner_out_to_outer_out_stmts, " " * 8)}
        if cond:
            i += 1
            break
        i += 1
"""
else:
    loop_src = f"""
    for i in range(n_steps):
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}

        {inner_outputs} = scan_inner_func(
            {inner_in_args}
        )
{indent(inner_out_post_processing_block, " " * 8)}
{indent(inner_out_to_outer_out_stmts, " " * 8)}
"""
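
To sanity-check the idea without touching numba_funcify_Scan, here is a self-contained sketch of the same compile-time branch using only the standard library. build_loop_src and compile_scan are hypothetical helpers, not PyTensor functions, and the loop body is passed in as a plain string instead of the real storage and post-processing blocks.

```python
from textwrap import indent


def build_loop_src(as_while, body):
    """Return source for a `scan` function whose loop shape depends on as_while."""
    body_block = indent(body, " " * 8)
    if as_while:
        # While scans keep an explicit counter and break once cond is set.
        loop = (
            "    i = 0\n"
            "    while i < n_steps:\n"
            f"{body_block}\n"
            "        i += 1\n"
            "        if cond:\n"
            "            break\n"
        )
    else:
        # For scans never touch cond, so a counted loop is enough.
        loop = (
            "    for i in range(n_steps):\n"
            f"{body_block}\n"
        )
    return "def scan(n_steps, buf):\n" + loop + "    return buf\n"


def compile_scan(as_while, body):
    # Execute the generated source and hand back the resulting function.
    namespace = {}
    exec(build_loop_src(as_while, body), namespace)
    return namespace["scan"]


for_scan = compile_scan(False, "buf[i + 1] = buf[i] + 1.0")
while_scan = compile_scan(
    True, "buf[i + 1] = buf[i] + 1.0\ncond = buf[i + 1] >= 3.0"
)
print(for_scan(4, [0.0] * 5))    # [0.0, 1.0, 2.0, 3.0, 4.0]
print(while_scan(4, [0.0] * 5))  # [0.0, 1.0, 2.0, 3.0, 0.0]
```

The branch happens once, when the source string is built, so the for version carries no cond variable at all and there is nothing left for the optimizer to prove.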
