
[CB] Fix capture of max_seqlen #45323

Merged
remi-or merged 17 commits into main from cb-fix-cg-prefill
Apr 17, 2026

Conversation

@remi-or (Collaborator) commented Apr 8, 2026

Summary

This PR fixes the issue raised in #45274.
CUDA graph reuse in continuous batching used (num_q_tokens, max_kv_read) as the graph cache key. However, FlashAttention varlen kernels also depend on max_seqlen_q and max_seqlen_k, which are Python ints baked into the graph as kernel launch constants. When two batches shared the same padded tensor sizes but had different max_seqlen_k, a stale graph was replayed with incorrect kernel parameters, causing accuracy collapse (e.g., GSM8K dropping from ~0.45 to ~0.23).

Additionally, use_cuda_graph was a single boolean controlling both the varlen (prefill) and decode paths, with no way to enable graphs for one path but not the other.

Fix

  1. Per-path CUDA graph control

use_cuda_graph is now bool | tuple[bool, bool] | None, allowing independent control over the varlen and decode paths. This lets power users short-circuit the issue altogether and test different configurations for different workloads.
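As a minimal sketch of how such a tri-state flag could be normalized into per-path switches (the function name and the behavior of None are illustrative assumptions, not the PR's actual code):

```python
def resolve_cuda_graph_flags(use_cuda_graph):
    """Normalize bool | tuple[bool, bool] | None into
    (use_graph_varlen, use_graph_decode)."""
    if use_cuda_graph is None:
        # None falls back to a default; assumed here to enable both paths
        return (True, True)
    if isinstance(use_cuda_graph, bool):
        # a single bool applies to both the varlen and decode paths
        return (use_cuda_graph, use_cuda_graph)
    varlen, decode = use_cuda_graph
    return (bool(varlen), bool(decode))
```

With this shape, `use_cuda_graph=(False, True)` would capture graphs only for decode, which is what lets a power user disable the problematic varlen path entirely.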

  2. Extended graph key with max_seqlen_k

The varlen graph key is now (num_q_tokens, max_kv_read, *max_seqlen_k.values()), capturing all integer values that affect FA kernel behavior. max_seqlen_k is bucketed to power-of-2 padded sizes with a minimum size of 1024, while max_seqlen_q is padded to num_q_tokens (already in the key). The decode key remains (num_q_tokens,), since max_seqlen_k is unused by flash_attn_with_kvcache.
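A rough illustration of the bucketing and key construction described above, assuming max_seqlen_k is a dict of per-group values (the function names and the "full_attention" group key are hypothetical):

```python
def bucket_max_seqlen_k(max_seqlen_k: int, min_bucket: int = 1024) -> int:
    """Round up to the next power of two, with a floor of min_bucket.

    Bucketing limits how many distinct graphs get captured while still
    keeping every kernel-launch constant that FA bakes into the graph
    part of the cache key.
    """
    return max(min_bucket, 1 << max(0, max_seqlen_k - 1).bit_length())


def varlen_graph_key(num_q_tokens, max_kv_read, max_seqlen_k):
    """Varlen graph cache key: padded tensor sizes plus every bucketed
    max_seqlen_k value (max_seqlen_k assumed to be a dict per group)."""
    return (
        num_q_tokens,
        max_kv_read,
        *(bucket_max_seqlen_k(v) for v in max_seqlen_k.values()),
    )
```

Two batches with the same padded tensor sizes but different raw max_seqlen_k now only share a graph when they fall into the same power-of-2 bucket, at the cost of the kernel doing slightly more work for the padded upper bound.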

  3. Padding fix for cumulative_seqlens_q

The old code set padding entries in cu_seqlens_q to q_size to work around NaN issues, but this inflated max_seqlen_q and complicated graph keying. Padding entries are now set to 0 (zero-length sequences), which FA handles correctly and removes the NaN workaround.
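A minimal sketch of the new padding scheme (the helper name is illustrative): padding slots contribute zero-length sequences, so the cumulative sums simply repeat instead of inflating max_seqlen_q.

```python
def build_cu_seqlens_q(seq_lens, padded_batch_size):
    """Cumulative query lengths with zero-length padding entries.

    Padding slots add 0 (not q_size), so the FA varlen kernel treats
    them as empty sequences and max_seqlen_q stays at the true maximum.
    """
    padded = list(seq_lens) + [0] * (padded_batch_size - len(seq_lens))
    cu_seqlens = [0]
    for n in padded:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return cu_seqlens
```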

Performance

| Arguments | Main (tok/s) | Current (tok/s) | Diff (%) |
|---|---|---|---|
| --samples 10 | 870.35 | 866.77 | -0.4% |
| --samples 20 --num-blocks 20 | 528.24 | 517.7 | -2.0% |
| --samples 50 | 3606.35 | 3599.36 | -0.2% |
| --samples 100 | 5359.15 | 5344.67 | -0.3% |
| --samples 100 --attn flash_attention_2 | 3687.72 | 3682.75 | -0.1% |
| --samples 100 --attn sdpa | 1030.96 | 1031.0 | +0.0% |
| --samples 500 --no-use-async | 6673.33 | 6635.82 | -0.6% |
| --samples 500 --use-async | 8045.98 | 7992.98 | -0.7% |
| --samples 32 --max-new-tokens 2048 --use-async | 2088.15 | 2043.99 | -2.1% |
| --samples 32 --max-new-tokens 2048 --use-async --block-table 32 | 2715.6 | 2709.64 | -0.2% |
| --samples 500 --add-prefix --compile | 7511.76 | 7598.71 | +1.2% |
| --samples 50 --num-return-sequences 8 --do-sample | 870.29 | 867.97 | -0.3% |
| --samples 100 --num-return-sequences 4 --do-sample | 1712.62 | 1716.06 | +0.2% |

Performance is globally stable. A slight regression in some cases is expected, because max_seqlen_q/k are now always greater than or equal to the actual values, so a bit more work is created for the FA kernel.

Tests

The following tests pass:

tests/generation/test_continuous_batching.py
tests/cli/test_serve.py
tests/generation/test_paged_attention.py

Eval on GSM8K

| max_batch_tokens | Accuracy | Time (s) | Tok/s |
|---|---|---|---|
| 128 | 0.8205 | 231.35 | 1337.8 |
| 256 | 0.8246 | 166.64 | 1857.3 |
| 384 | 0.8246 | 148.91 | 2078.5 |
| 512 | 0.8246 | 140.21 | 2207.4 |
| 640 | 0.8246 | 136.86 | 2261.5 |
| 768 | 0.8246 | 133.70 | 2314.8 |
| 896 | 0.8238 | 135.36 | 2286.5 |
| 1024 | 0.8246 | 130.90 | 2364.5 |

Run with meta-llama/Llama-3.1-8B-Instruct, using kernels-community/flash-attn3 as the attention implementation.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Qubitium (Contributor) commented Apr 9, 2026

@remi-or Please also check and fix, if related, the crashes from my sweep. The crash occurs at high max_batched_tokens (>= 896, 1024). The threshold is likely model/shape-specific, but if you sweep max_bt you should hit the same crashes at some threshold of max_batched_tokens.

  MAIN = /tmp/transformers-main-head @ d3c7a19176
  PR = /root/transformers @ 33e9cbb08e
  Workload: Evaluation on GSM8K Platinum, paged|flash_attention_2, bf16, batch_size=24, max_rows=128, use_cuda_graph=auto, use_async_batching=auto.
  

  Sweep Results
  +--------+--------------+------------+-----------+-------------+-----------+----------+--------------+------------+
  | max_bt | MAIN acc,num | PR acc,num | Δ acc     | MAIN samp/s | PR samp/s | Δ samp/s | MAIN peak GB | PR peak GB | 
  +========+==============+============+===========+=============+===========+==========+==============+============+
    | 128    | 0.4140625    | 0.4140625  | +0.0000   | 9.7604      | 10.1033   | +3.51%   | 34.7090      | 34.7090    |
    | 256    | 0.4140625    | 0.4453125  | +0.0312   | 9.7349      | 8.7137    | -10.49%  | 34.9004      | 34.9004    |
--> | 384    | 0.2265625    | 0.4531250  | +0.2266   | 10.0055     | 9.9615    | -0.44%   | 35.0859      | 35.0898    |
    | 512    | 0.4218750    | 0.4062500  | -0.0156   | 10.0384     | 9.7630    | -2.74%   | 35.2949      | 35.2949    | 
    | 640    | 0.3984375    | 0.4140625  | +0.0156   | 10.2056     | 9.5572    | -6.35%   | 35.4883      | 35.4902    |
    | 768    | 0.3984375    | 0.4062500  | +0.0078   | 10.0023     | 9.7447    | -2.58%   | 35.6836      | 35.6836    | 
    | 896    | FAIL         | FAIL       | n/a       | n/a         | n/a       | n/a      | n/a          | n/a        | 
    | 1024   | FAIL         | FAIL       | n/a       | n/a         | n/a       | n/a      | n/a          | n/a        |
  +--------+--------------+------------+-----------+-------------+-----------+----------+--------------+------------+

  Failure Details
  +--------+--------------------------------+-----------------------+
  | max_bt | MAIN                           | PR                    |
  +========+================================+=======================+
  | 896    | illegal memory access          | illegal memory access |
  | 1024   | CUBLAS_STATUS_EXECUTION_FAILED | illegal memory access |
  +--------+--------------------------------+-----------------------+

@remi-or (Collaborator, Author) commented Apr 14, 2026

@Qubitium I am almost done with the draft status, I was just wondering what you were using to run your benches please? So I can run the same workload. Thank you!

@Qubitium (Contributor)

> @remi-or I am almost done with the draft status, I was just wondering what you were using to run your benches please? So I can run the same workload. Thank you!

I will recreate the sample script that reproduced this.

@Qubitium (Contributor) commented Apr 14, 2026

> @remi-or I am almost done with the draft status, I was just wondering what you were using to run your benches please? So I can run the same workload. Thank you!

I have re-created the benchmark script, but I cannot find/remember the model I used to reproduce the crashes or the score degradation without the fix! It was a simple sweep of the max_batch_token size to stress CB | FA2 | paged attention on an A100. Ah, I thought I was using Llama 3.2 1B Instruct, but it does not crash under the same config and the VRAM usage does not match up, so it is not the same/right model. I tried some other models and the VRAM usage also does not match up, nor do they crash at the aforementioned max_batch_token size boundaries. Let's ignore that for now then.

@remi-or remi-or requested a review from ArthurZucker April 15, 2026 01:05
@remi-or remi-or marked this pull request as ready for review April 15, 2026 01:07
@remi-or (Collaborator, Author) commented Apr 15, 2026

Ok, dropping the GSM8K benchmark. Ready for review then.
EDIT: ran the evals on my end to check everything was fine -- it is on the model I tried.

@remi-or (Collaborator, Author) commented Apr 16, 2026

BTW @Qubitium, I have not activated FA2 because I hit a bug when running it. I am using the kernels implementation of FA2, so that might be the issue. Here is an MRE of the issue. I wonder which implementation you were using and whether you encountered the same issue.

@ArthurZucker (Collaborator) left a comment

LGTM !

```python
# Validate and optionally filter processors based on their CB support
self._validate_processors(drop_unsupported_processors)
self._retrieve_processors_kwargs()
# Static boolean to know if there is any processing to do
```
Collaborator:

logit processing i guess :)

Collaborator:

maybe a property, since self.logits_processor can change

Collaborator (Author):

Actually, it can only be cleared right now! Which is already planned for. I changed the comment as asked.

@remi-or remi-or added this pull request to the merge queue Apr 17, 2026
Merged via the queue into main with commit 71560f4 Apr 17, 2026
29 checks passed
@remi-or remi-or deleted the cb-fix-cg-prefill branch April 17, 2026 03:35