docs/examples/attention/attention.ipynb
+2 -1 (2 additions, 1 deletion)
@@ -151,6 +151,7 @@
 "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
 "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n",
 "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n",
+"- **Sliding window attention (SWA):** flash-attention has SWA(left, right) support for all mask types except top-left causal masks, with or without dropout, and without bias. cuDNN attention supports SWA(left, 0) starting from 9.2 and SWA(left, right) starting from 9.6, without dropout, and with `bias_type=\"no_bias\"`.\n",
 "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
 "\n",
"To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb
+1 -0 (1 addition, 0 deletions)
@@ -28,6 +28,7 @@
 "source": [
 "### Question 1: Why choose Striped > 1?\n",
 "\n",
+"\n",
 "Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n",