Fix non-canonical cu_seqlens_k from preprocessor #514

Open

jlamypoirier wants to merge 2 commits into jlp_sdpa-attention from jlp_varlen-cu-seqlens-fix


Conversation

@jlamypoirier
Collaborator

Summary

  • Preprocessor emitted cu_seqlens_k[0] = first_document_begin instead of 0, violating the canonical varlen prefix-sum layout. SDPA EFFICIENT backward writes corrupt dK/dV rows when fed this, propagating wrong K/V projection grads through the reduce-scatter under sequence-data-parallel + micro-batch splits.
  • Preprocessor now produces canonical cu_seqlens_k starting at 0; narrows document_index_k / position_index to the active K extent; exposes the dropped leading-prefix length as a new first_document_begin int kwarg.
  • Pre-allocate one K/V buffer per attention layer across all micro-sequences of a sequence. Forward writes the SDP-gather result into the next slice via gather_op(out=); backward accumulates per-micro-seq K/V grad into a shared grad buffer slice. The leading + trailing narrows and the per-step torch.cat / AttachGrad cross-micro-seq splice are absorbed into _query_key_value's custom autograd region.

Test plan

  • tests/data/test_preprocessing.py (843/843 pass) — updated to expect canonical layout
  • tests/layers/test_attention.py (56/56 pass on CPU and CUDA)
  • tests/models/test_model.py: gpt_2-ms4, gpt_2-sdp2, gpt_2-sdp2_stp2, gpt_2-sdp2_stp2_bf4, and gpt_2-stp2_pp2s1_bf4 still fail with ~0.5–1% relative gradient drift vs. the simple baseline. The cu_seqlens_k fix is correct; the residual drift reflects the inherent numerics of the sdpa_nested EFFICIENT backend, which exceed the 0.3% tolerance calibrated against the deterministic backup path. Before Add SDPA attention implementation #512, main used backup as the fp32+CUDA default; #512 switched it to sdpa_nested. To be addressed separately: either revert that auto-default for fp32 or loosen the comparison tolerance for the sdpa path.

The data preprocessor emitted `cu_seqlens_k[0] = first_document_begin` rather
than 0, violating the canonical varlen prefix-sum layout required by every
public varlen attention API. SDPA's EFFICIENT backward writes corrupted dK/dV
rows when fed this layout, propagating wrong gradients through the K/V
projection's reduce-scatter under sequence-data-parallel + micro-batch splits.
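A minimal sketch of the canonicalization, assuming the non-canonical prefix sum differs from the canonical one only by the constant leading offset; the tensor values are illustrative, not the preprocessor's code:

```python
import torch

# Illustrative non-canonical layout the preprocessor used to emit: the
# prefix sum starts at the first document's begin offset rather than 0.
cu_seqlens_k = torch.tensor([128, 320, 512], dtype=torch.int32)

# Canonical varlen layout: rebase so cu_seqlens_k[0] == 0, and surface the
# dropped leading-prefix length separately (the new kwarg in this PR).
first_document_begin = int(cu_seqlens_k[0])
cu_seqlens_k = cu_seqlens_k - first_document_begin

assert cu_seqlens_k[0] == 0  # layout required by varlen attention APIs
```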

Three changes that compose:
- `LengthModelInputPreprocessor` now produces `cu_seqlens_k` starting at 0 and
  narrows `document_index_k` / `position_index` to the active K extent. The
  dropped leading-prefix length is exposed as a new `first_document_begin` int
  kwarg.
- Pre-allocate one K/V buffer per attention layer across all micro-sequences
  of a sequence. Each forward writes the SDP-gather result into the next slice
  via `gather_op(out=)`; backward accumulates each micro-seq's K/V grad into
  a shared grad buffer slice (first sketch after this list). The leading +
  trailing narrows and the per-step `torch.cat` / `AttachGrad` workaround for
  the cross-micro-seq splice are all absorbed into the `_query_key_value`
  custom autograd region.
- `_preprocess_for_backup_attention` builds the attention mask against the
  narrowed K cols so `sdpa_dense` and `backup` consume the same K extent as
  flash and `sdpa_nested` (second sketch after this list).
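A hedged sketch of the second change. `KVBuffer`, its shapes, and the grad handling are illustrative stand-ins, not the real `gather_op(out=)` / `_query_key_value` autograd wiring:

```python
import torch

class KVBuffer:
    """Illustrative per-layer K/V buffer shared across one sequence's micro-sequences."""

    def __init__(self, total_len: int, heads: int, head_dim: int, dtype, device):
        # One allocation for the whole sequence instead of a per-step torch.cat.
        self.kv = torch.empty(2, total_len, heads, head_dim, dtype=dtype, device=device)
        self.kv_grad = torch.zeros_like(self.kv)  # accumulated in backward
        self.filled = 0  # K/V length written by earlier micro-sequences

    def write_step(self, kv_step: torch.Tensor) -> torch.Tensor:
        # Forward: write this micro-sequence's K/V into the next free slice
        # (the real code does this with gather_op(out=...)).
        step = kv_step.shape[1]
        self.kv[:, self.filled : self.filled + step] = kv_step
        self.filled += step
        # Attention for this micro-sequence sees everything filled so far.
        return self.kv[:, : self.filled]

    def accumulate_grad(self, begin: int, end: int, grad: torch.Tensor) -> None:
        # Backward: each micro-sequence adds its K/V grad into the shared slice.
        self.kv_grad[:, begin:end] += grad
```

The third change can be pictured as below; the mask construction is a guess at the shape of the logic, with the `document_index_*` / `position_index_*` arguments standing in for the narrowed index tensors named above, not `_preprocess_for_backup_attention`'s actual code:

```python
import torch

def build_dense_mask(document_index_q, document_index_k, position_index_q, position_index_k):
    # Build the dense mask only over the narrowed (active) K columns, so the
    # dense/backup paths attend over the same K extent as the varlen paths:
    # attend within the same document, causally by position.
    same_document = document_index_q[:, None] == document_index_k[None, :]
    causal = position_index_q[:, None] >= position_index_k[None, :]
    return same_document & causal
```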

Update `tests/data/test_preprocessing.py` to expect the canonical layout.
`_test_first_document_begin` injects a fake past K/V slot with arbitrary leading
data, drives attention through manually built kwargs with `sequence_k_past`
and `first_document_begin` both set to a non-zero `past_length`, and verifies:
- forward output matches a per-doc reference computed on the active documents
  alone (the dropped prefix has no observable effect),
- parameter gradients match the reference,
- the K/V grad buffer at `[:past_length]` is exactly zero — the specific
  guarantee of the cu_seqlens_k canonicalization fix.
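A minimal sketch of that last check, assuming direct access to the grad buffer; the names are illustrative rather than the test's actual code:

```python
import torch

def check_dropped_prefix_grad(kv_grad_buffer: torch.Tensor, past_length: int) -> None:
    # The canonicalized cu_seqlens_k must make the dropped leading prefix
    # invisible to backward: its grad rows stay exactly zero, not just small.
    prefix = kv_grad_buffer[:past_length]
    assert torch.equal(prefix, torch.zeros_like(prefix))
```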

The test runs backup + sdpa_dense in fp32 and flash + sdpa_nested in bf16
(flash rejects fp32). It plugs into the existing `test_attention`
parametrization as a new case with `name="first_document_begin"`, dispatched
via a name check (sketched below).
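The dispatch could look roughly like this; the other case names and the `_test_attention_case` helper are assumptions about the existing parametrization, with only `_test_first_document_begin` taken from this PR:

```python
import pytest

@pytest.mark.parametrize("name", ["causal", "document_mask", "first_document_begin"])
def test_attention(name):
    if name == "first_document_begin":
        # New case from this PR, dispatched via name check.
        _test_first_document_begin()
    else:
        _test_attention_case(name)  # pre-existing path (illustrative name)
```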
