Naive attention kernel in Forth #44

@tetsuo-cpp

Description

Summary

Write a naive attention kernel in WarpForth that computes single-head scaled dot-product attention:

O = softmax(Q @ K^T / sqrt(d_k)) @ V
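
For reference, the dense version of this formula is a few lines of NumPy (a sketch; the causal mask from the design below is deliberately omitted here):

import numpy as np

def attention_reference(Q, K, V):
    # O = softmax(Q @ K^T / sqrt(d_k)) @ V, with a row-wise stable softmax.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                        # (seq_len, seq_len)
    scores = scores - scores.max(axis=-1, keepdims=True)   # stabilize before exp
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V                                     # (seq_len, head_dim)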

Design

Approach: one kernel, one thread block per query row

Each thread block computes one row of the output. Within a block, threads cooperate on the following steps (sketched in Python after the list):

  1. Dot products: Q[row] · K[j] for all j (parallelized across threads)
  2. Scaling: divide by sqrt(d_k)
  3. Causal mask: set scores[j] = -inf for j > row (GPT-2 is autoregressive)
  4. Softmax: max-reduction → subtract max → exp → sum-reduction → divide
  5. V accumulation: weighted sum of V rows
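
A minimal NumPy sketch of what one block computes for a single query row (sequential here; in the kernel, steps 1, 4, and 5 are strided across the block's threads, and the helper name is illustrative):

import numpy as np

def attention_row(Q, K, V, row):
    # Mirrors the five kernel steps for output row `row`.
    seq_len, head_dim = K.shape
    scores = K @ Q[row] / np.sqrt(head_dim)  # steps 1-2: dot products, then scale
    scores[row + 1:] = -np.inf               # step 3: causal mask (j > row)
    scores = scores - scores.max()           # step 4a: subtract max for stability
    weights = np.exp(scores)                 # step 4b: exponentiate
    weights = weights / weights.sum()        # step 4c: normalize
    return weights @ V                       # step 5: weighted sum of V rows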

Parameters

\! kernel attention
\! param Q f64[SEQ*DIM]        \ Query matrix (seq_len × head_dim), row-major
\! param K f64[SEQ*DIM]        \ Key matrix
\! param V f64[SEQ*DIM]        \ Value matrix
\! param O f64[SEQ*DIM]        \ Output matrix
\! param SEQ_LEN i64           \ Sequence length
\! param HEAD_DIM i64          \ Head dimension (e.g., 64)
\! shared SCORES f64[MAX_SEQ]  \ Scratch for scores (one row)
\! shared SCRATCH f64[MAX_SEQ] \ Scratch for reductions
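
Since all four matrices are flat, row-major f64 buffers, element (row, d) of Q lives at offset row * HEAD_DIM + d. In NumPy terms (sizes illustrative):

import numpy as np

SEQ_LEN, HEAD_DIM = 16, 64
Q = np.arange(SEQ_LEN * HEAD_DIM, dtype=np.float64)  # flat, row-major buffer
row, d = 3, 5
assert Q[row * HEAD_DIM + d] == Q.reshape(SEQ_LEN, HEAD_DIM)[row, d]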

Grid/block configuration

  • Grid: (seq_len, 1, 1) — one block per query row
  • Block: (block_size, 1, 1) — threads cooperate within a row
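
Since block_size is typically smaller than seq_len, each thread covers the score columns j ≡ tid (mod block_size), the usual strided-loop pattern. A Python model of that partition (tid and block_size are illustrative names; the kernel reads them from thread/block builtins):

def thread_indices(tid, block_size, seq_len):
    # Score columns j owned by thread `tid` in a strided loop over one row.
    return range(tid, seq_len, block_size)

assert list(thread_indices(0, 4, 10)) == [0, 4, 8]
assert list(thread_indices(3, 4, 10)) == [3, 7]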

Causal masking

GPT-2 uses autoregressive (causal) attention: scores where the key position is greater than the query position must be set to -inf before the softmax, so that exp turns them into zero weight. This is essential for correctness.
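
In the dense NumPy reference, the same mask is one np.triu call (a sketch):

import numpy as np

seq_len = 4
scores = np.zeros((seq_len, seq_len))
mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)  # True where key j > query i
scores[mask] = -np.inf  # exp(-inf) = 0, so masked positions get zero weight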

Softmax implementation

Numerically stable softmax, in three passes over the scores (sketched after the list):

  1. Max reduction via shared memory + BARRIER (needs FMAX)
  2. Exp + sum: compute exp(score - max), store back, reduce sum (needs FEXP)
  3. Normalize: divide each score by sum
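
The same three passes in NumPy, written pass by pass rather than fused (a sketch; in the kernel each pass is a strided loop plus a shared-memory reduction, with a BARRIER between passes):

import numpy as np

def stable_softmax(scores):
    m = scores.max()        # pass 1: max reduction (FMAX in the kernel)
    e = np.exp(scores - m)  # pass 2: exp(score - max), stored back (FEXP) ...
    s = e.sum()             #         ... then sum reduction
    return e / s            # pass 3: normalize by the sum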

Dependencies

Files to create

  • demo/attention.forth — The kernel source
  • test/Pipeline/attention.forth or a GPU test — correctness test against a NumPy reference

Acceptance criteria

  • Kernel compiles to PTX via warpforthc
  • Output matches torch.nn.functional.scaled_dot_product_attention for small inputs
  • Causal masking works correctly
  • Tested with seq_len=16, head_dim=64 at minimum
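
A sketch of that comparison, assuming the kernel's output has been copied back into a NumPy array O (is_causal=True applies the same mask as above; float64 inputs run on torch's math backend):

import numpy as np
import torch
import torch.nn.functional as F

seq_len, head_dim = 16, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((seq_len, head_dim)) for _ in range(3))

q, k, v = (torch.from_numpy(x)[None] for x in (Q, K, V))  # add a batch dim
expected = F.scaled_dot_product_attention(q, k, v, is_causal=True)[0].numpy()

# O = ...  run the WarpForth kernel on Q, K, V and copy the result back ...
# np.testing.assert_allclose(O, expected, rtol=1e-6, atol=1e-8)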
