1 parent 859e4b3 commit 64e1bc5
1 file changed
src/maxdiffusion/models/attention_flax.py
@@ -225,7 +225,6 @@ def _tpu_flash_attention(
     attention_mask: jax.Array = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
-  jax.debug.print("USing FLASH ATTENTION")
 
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   # This is the case for cross-attn.
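
A side note on the context lines above: the query block size for the TPU flash-attention kernel is chosen per dtype. A minimal sketch restating that choice, assuming only what the hunk shows (the helper name q_block_size is hypothetical, not from the repo):

import jax.numpy as jnp

def q_block_size(dtype) -> int:
  # Hypothetical helper mirroring the conditional in the hunk above:
  # bfloat16 values take half the bytes of float32, so the kernel can
  # presumably afford a larger query block on-chip.
  return 1024 if dtype == jnp.bfloat16 else 512
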
@@ -445,7 +444,6 @@ def _apply_attention_dot(
     float32_qk_product: bool,
     use_memory_efficient_attention: bool,
 ):
-  jax.debug.print("Using DOT PRODUCT ATTENTION")
   """Apply Attention."""
   if split_head_dim:
     b = key.shape[0]
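
Both hunks delete a stray jax.debug.print call from the attention dispatch paths. For context, a minimal sketch (illustrative, not from this repo) of why such calls are noisy: unlike a plain Python print, which fires only once while the function is traced, jax.debug.print installs a runtime callback that prints on every execution of the jit-compiled code.

import jax
import jax.numpy as jnp

@jax.jit
def scaled(x):
  # A plain print() here would run once, at trace time; jax.debug.print
  # fires on every call of the compiled function.
  jax.debug.print("sum = {}", jnp.sum(x))
  return x * 2.0

scaled(jnp.ones((4,)))  # prints "sum = 4.0" on each invocation

Leaving such calls in per-layer hot paths adds a host callback to every step, which is why they are typically stripped before merging, as this commit does.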