Skip to content

Commit 64e1bc5

Browse files
committed
removed extra debug
1 parent 859e4b3 commit 64e1bc5

1 file changed

Lines changed: 0 additions & 2 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def _tpu_flash_attention(
225225
attention_mask: jax.Array = None,
226226
) -> jax.Array:
227227
"""TPU Flash Attention"""
228-
jax.debug.print("USing FLASH ATTENTION")
229228

230229
q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
231230
# This is the case for cross-attn.
@@ -445,7 +444,6 @@ def _apply_attention_dot(
445444
float32_qk_product: bool,
446445
use_memory_efficient_attention: bool,
447446
):
448-
jax.debug.print("Using DOT PRODUCT ATTENTION")
449447
"""Apply Attention."""
450448
if split_head_dim:
451449
b = key.shape[0]

0 commit comments

Comments
 (0)