Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@remi-or Please also check and fix, if related, my sweep crashes: the crash appears at high max_batched_tokens (>= 896 and 1024). The threshold is likely model/shape specific, but it reproduces if you sweep.
MAIN = /tmp/transformers-main-head @ d3c7a19176
PR = /root/transformers @ 33e9cbb08e
Workload: Evaluation on GSM8K Platinum, paged|flash_attention_2, bf16, batch_size=24, max_rows=128, use_cuda_graph=auto, use_async_batching=auto.
Sweep Results
+--------+--------------+------------+-----------+-------------+-----------+----------+--------------+------------+
| max_bt | MAIN acc,num | PR acc,num | Δ acc | MAIN samp/s | PR samp/s | Δ samp/s | MAIN peak GB | PR peak GB |
+========+==============+============+===========+=============+===========+==========+==============+============+
| 128 | 0.4140625 | 0.4140625 | +0.0000 | 9.7604 | 10.1033 | +3.51% | 34.7090 | 34.7090 |
| 256 | 0.4140625 | 0.4453125 | +0.0312 | 9.7349 | 8.7137 | -10.49% | 34.9004 | 34.9004 |
--> | 384 | 0.2265625 | 0.4531250 | +0.2266 | 10.0055 | 9.9615 | -0.44% | 35.0859 | 35.0898 |
| 512 | 0.4218750 | 0.4062500 | -0.0156 | 10.0384 | 9.7630 | -2.74% | 35.2949 | 35.2949 |
| 640 | 0.3984375 | 0.4140625 | +0.0156 | 10.2056 | 9.5572 | -6.35% | 35.4883 | 35.4902 |
| 768 | 0.3984375 | 0.4062500 | +0.0078 | 10.0023 | 9.7447 | -2.58% | 35.6836 | 35.6836 |
| 896 | FAIL | FAIL | n/a | n/a | n/a | n/a | n/a | n/a |
| 1024 | FAIL | FAIL | n/a | n/a | n/a | n/a | n/a | n/a |
+--------+--------------+------------+-----------+-------------+-----------+----------+--------------+------------+
Failure Details
+--------+-------------------------------+--------------------------------+
| max_bt | MAIN                          | PR                             |
+========+===============================+================================+
| 896    | illegal memory access         | illegal memory access          |
| 1024   | CUBLAS_STATUS_EXECUTION_FAILED| illegal memory access          |
+--------+-------------------------------+--------------------------------+
@Qubitium I am almost done with the draft. I was just wondering what you were using to run your benches, please? So I can run the same workload. Thank you!
I will recreate the sample script that reproduced this. |
I have re-created the benchmark script, but I cannot find/remember the model I used to reproduce the crashes or the score degradation without the fix! It was a simple sweep of the max_batch_token size to stress cb | fa2 | paged attention on an A100. Ah, I thought I was just using Llama 3.2 1B Instruct, but it is not crashing under the same config and the VRAM usage is not matching up, so it is not the same/right model. I tried some other models, and their VRAM usage is also not matching up, nor are they crashing at the aforementioned sizes.
|
Ok, dropping the GSM8K benchmark. Ready for review, then.
|
BTW @Qubitium I have not activated FA2 because I hit a bug when running it. I am using the kernels implementation of FA2, so that might be the issue. Here is an MRE of the issue. I wonder what implementation you were using and if you encountered the same issue.
# Validate and optionally filter processors based on their CB support
self._validate_processors(drop_unsupported_processors)
self._retrieve_processors_kwargs()
# Static boolean to know if there is any processing to do
logit processing i guess :)
maybe property since self.logits_processor can change
Actually, it can only be cleared right now! Which is already planned for. I changed the comment as asked.
Summary
This PR fixes the issue raised in #45274 .
CUDA graph reuse in continuous batching used (num_q_tokens, max_kv_read) as the graph cache key. However, FlashAttention varlen kernels also depend on max_seqlen_q and max_seqlen_k, which are Python ints baked into the graph as kernel launch constants. When two batches shared the same padded tensor sizes but had different max_seqlen_k, a stale graph was replayed with incorrect kernel parameters, causing accuracy collapse (e.g., GSM8K dropping from ~0.45 to ~0.23).
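The failure mode above can be shown with a minimal, hardware-free sketch (names and the cache structure are illustrative, not the actual transformers internals): a cache keyed only on padded tensor sizes hands back a graph whose baked-in max_seqlen_k belongs to a different batch.

```python
# Illustrative sketch of the stale-graph bug (not the real transformers code).
# A "graph" here is just a record of the launch constants it baked in at capture.

graph_cache: dict = {}

def get_graph(num_q_tokens: int, max_kv_read: int, max_seqlen_k: int) -> dict:
    key = (num_q_tokens, max_kv_read)  # BUG: max_seqlen_k is not part of the key
    if key not in graph_cache:
        # "Capture": the current max_seqlen_k gets frozen into the graph.
        graph_cache[key] = {"baked_max_seqlen_k": max_seqlen_k}
    return graph_cache[key]

# Two batches with identical padded sizes but different true max_seqlen_k:
g1 = get_graph(256, 4096, max_seqlen_k=512)
g2 = get_graph(256, 4096, max_seqlen_k=2048)  # replays g1's captured graph

print(g2["baked_max_seqlen_k"])  # 512: a stale launch constant for this batch
```

The replayed graph runs the varlen kernel with the wrong max_seqlen_k, which matches the silent accuracy collapse seen in the sweep rather than a hard crash.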
Additionally, use_cuda_graph was a single boolean controlling both the varlen (prefill) and decode paths, with no way to enable graphs for one path but not the other.
Fix
use_cuda_graph is now bool | tuple[bool, bool] | None, allowing independent control over the varlen and decode paths. This lets power users short-circuit the issue altogether and test different configs for different workloads.
The varlen graph key is now (num_q_tokens, max_kv_read, *max_seqlen_k.values()), capturing all integer values that affect FA kernel behavior. The variable max_seqlen_k is bucketed using power-of-2 padded sizes with a minimum size of 1024. max_seqlen_q is padded to num_q_tokens (already in the key). The decode key remains (num_q_tokens,) since max_seqlen_k is unused by flash_attn_with_kvcache.
The old code set padding entries in cu_seqlens_q to q_size to work around NaN issues, but this inflated max_seqlen_q and complicated graph keying. Padding entries are now set to 0 (zero-length sequences), which FA handles correctly and removes the NaN workaround.
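The bucketing described above can be sketched as follows (function names and defaults are assumptions for illustration, not the PR's actual code): max_seqlen_k is rounded up to the next power of two with a floor of 1024, and the bucketed value is folded into the varlen graph key so the baked-in constant is always an upper bound on the true value.

```python
def bucket_seqlen(value: int, floor: int = 1024) -> int:
    """Round up to the next power of two, never below `floor` (assumed 1024)."""
    if value <= floor:
        return floor
    return 1 << (value - 1).bit_length()

def varlen_graph_key(num_q_tokens: int, max_kv_read: int, max_seqlen_k: int) -> tuple:
    # Bucketing bounds the number of distinct captured graphs while keeping the
    # baked-in max_seqlen_k >= the batch's true value, so replay stays correct.
    return (num_q_tokens, max_kv_read, bucket_seqlen(max_seqlen_k))

print(bucket_seqlen(512))   # 1024 (floor applies)
print(bucket_seqlen(1500))  # 2048 (next power of two)
```

A slightly over-sized max_seqlen_k only costs the FA kernel some extra work, which is consistent with the small perf regression noted below.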
Performance
Perf is globally stable; a bit of regression in some cases is expected because max_seqlen_q/k are now always bigger than the actual value, hence a bit more work is created for the FA kernel.
Tests
The following tests pass:
Eval on GSM8K sweeping max_batch_tokens with meta-llama/Llama-3.1-8B-Instruct and attn kernels-community/flash-attn3