Summary
Write a naive attention kernel in WarpForth that computes single-head scaled dot-product attention:
O = softmax(Q @ K^T / sqrt(d_k)) @ V
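For reference, a minimal NumPy sketch of the computation the kernel has to reproduce (causal, single head); the function name and shapes here are illustrative, not part of the repo:

```python
import numpy as np

def reference_attention(Q, K, V):
    """Naive causal scaled dot-product attention for a single head.
    Q, K, V: (seq_len, head_dim) float64 arrays; returns O of the same shape."""
    seq_len, head_dim = Q.shape
    scores = Q @ K.T / np.sqrt(head_dim)          # (seq_len, seq_len)
    # Causal mask: query i may only attend to key positions j <= i
    scores[np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)] = -np.inf
    # Numerically stable softmax over each row
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                            # (seq_len, head_dim)
```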
Design
Approach: one kernel, one thread block per query row
Each thread block computes one row of the output. Within a block, threads cooperate on:
- Dot products: Q[row] · K[j] for all j (parallelized across threads)
- Scaling: divide by sqrt(d_k)
- Causal mask: set scores[j] = -inf for j > row (GPT-2 is autoregressive)
- Softmax: max-reduction → subtract max → exp → sum-reduction → divide
- V accumulation: weighted sum of V rows
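Spelled out as serial Python for a single row (thread-level parallelism, shared memory, and BARRIERs elided; names are illustrative):

```python
import numpy as np

def attention_row(Q, K, V, row):
    """What one thread block computes: output row `row` of O."""
    seq_len, head_dim = Q.shape
    scores = np.empty(seq_len)
    for j in range(seq_len):              # dot products, spread across threads in the kernel
        scores[j] = np.dot(Q[row], K[j])
    scores /= np.sqrt(head_dim)           # scale by 1/sqrt(d_k)
    scores[row + 1:] = -np.inf            # causal mask: j > row contributes nothing
    scores -= scores.max()                # stable softmax: max-reduce, exp, sum-reduce, divide
    weights = np.exp(scores)
    weights /= weights.sum()
    return weights @ V                    # weighted sum of V rows
```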
Parameters
\! kernel attention
\! param Q f64[SEQ*DIM] \ Query matrix (seq_len × head_dim), row-major
\! param K f64[SEQ*DIM] \ Key matrix
\! param V f64[SEQ*DIM] \ Value matrix
\! param O f64[SEQ*DIM] \ Output matrix
\! param SEQ_LEN i64 \ Sequence length
\! param HEAD_DIM i64 \ Head dimension (e.g., 64)
\! shared SCORES f64[MAX_SEQ] \ Scratch for scores (one row)
\! shared SCRATCH f64[MAX_SEQ] \ Scratch for reductions
Grid/block configuration
- Grid: (seq_len, 1, 1) - one block per query row
- Block: (block_size, 1, 1) - threads cooperate within a row
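Host-side launch shape, sketched as pseudocode (the actual warpforthc/driver launch interface isn't specified in this issue):

```python
seq_len, block_size = 16, 128   # block_size is a tuning choice, not fixed by the design
grid  = (seq_len, 1, 1)         # blockIdx.x selects the query row
block = (block_size, 1, 1)      # threads stride over key positions j = tid, tid + block_size, ...
```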
Causal masking
GPT-2 uses autoregressive (causal) attention. Scores where key position > query position must be masked to -inf before softmax. This is essential for correctness.
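Equivalently, at the matrix level (NumPy, illustrative):

```python
import numpy as np
# scores[i, j] is masked whenever key position j exceeds query position i
i = np.arange(seq_len)[:, None]
j = np.arange(seq_len)[None, :]
scores = np.where(j > i, -np.inf, scores)
```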
Softmax implementation
Numerically stable softmax (3 passes over scores):
- Max reduction via shared memory + BARRIER (needs FMAX)
- Exp + sum: compute exp(score - max), store back, reduce sum (needs FEXP)
- Normalize: divide each score by sum
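A serial model of those three passes (Python; in the kernel the max and sum go through the shared SCRATCH buffer with BARRIERs between passes):

```python
import numpy as np

def softmax_3pass(scores):
    """Serial stand-in for the kernel's per-row softmax."""
    row_max = scores.max()                 # pass 1: max reduction (needs FMAX)
    total = 0.0
    for j in range(len(scores)):           # pass 2: exp(score - max), store back, sum (needs FEXP)
        scores[j] = np.exp(scores[j] - row_max)
        total += scores[j]
    scores /= total                        # pass 3: normalize
    return scores
```

Masked entries fall out naturally: exp(-inf - max) evaluates to 0, so they contribute nothing to the sum or to the V accumulation.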
Dependencies
- Float math intrinsics: FEXP, FSQRT, FLOG, FABS, FNEG, FMAX, FMIN
- #42 (Math intrinsics: FEXP, FSQRT, FMAX, FMIN) blocks this issue
Files to create
- demo/attention.forth - the kernel source
- test/Pipeline/attention.forth or GPU test - correctness test against a NumPy reference
Acceptance criteria
- Kernel compiles to PTX via warpforthc
- Output matches torch.nn.functional.scaled_dot_product_attention for small inputs
- Causal masking works correctly
- Tested with seq_len=16, head_dim=64 at minimum
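A sketch of the correctness test (PyTorch as the oracle; run_attention_kernel is a hypothetical placeholder for however the compiled PTX gets launched):

```python
import numpy as np
import torch
import torch.nn.functional as F

seq_len, head_dim = 16, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((seq_len, head_dim)) for _ in range(3))

# Oracle: PyTorch SDPA with built-in causal masking (batch dim added for safety)
q, k, v = (torch.from_numpy(x).unsqueeze(0) for x in (Q, K, V))
expected = F.scaled_dot_product_attention(q, k, v, is_causal=True).squeeze(0).numpy()

# Placeholder for however the warpforthc-compiled kernel is actually invoked
actual = run_attention_kernel(Q, K, V, seq_len, head_dim)

np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-8)
```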