Bug report

`attention_mla.MLA` scales queries after projection:

```python
query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale
```

However, `cudnn_jax_flash_attention` (the implementation used when `attention=cudnn_flash_jax`) also hardcodes the scale:

```python
scale=1.0 / math.sqrt(head_dim),
```

The query is therefore scaled twice, producing incorrect attention results that do not match `attention=dot_product` and the other implementations.
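A minimal standalone sketch of the mismatch (plain NumPy, not MaxText code; the `softmax` and `attention` helpers here are illustrative, with `attention` applying the scale inside the kernel the way `cudnn_jax_flash_attention` does):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, scale):
    # scale applied to the logits inside the kernel,
    # as the cudnn flash-attention path does
    logits = (q @ k.T) * scale
    return softmax(logits) @ v

rng = np.random.default_rng(0)
head_dim = 8
q = rng.normal(size=(4, head_dim))
k = rng.normal(size=(4, head_dim))
v = rng.normal(size=(4, head_dim))

softmax_scale = 1.0 / np.sqrt(head_dim)

# Reference: the scale is applied exactly once (attention=dot_product behavior).
ref = attention(q, k, v, scale=softmax_scale)

# Buggy path: the query is pre-scaled (as in MLA), and then the kernel
# scales the logits again with its hardcoded 1/sqrt(head_dim).
buggy = attention(q * softmax_scale, k, v, scale=1.0 / np.sqrt(head_dim))

print(np.allclose(ref, buggy))  # False: logits end up scaled by softmax_scale**2
```

One way to reconcile the two, assuming the MLA pre-scaling is kept, would be to pass `scale=1.0` to the cudnn kernel so the scale is applied only once.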
Logs/Output
No response
Environment Information
No response
Additional Context
No response
References:
- maxtext/src/MaxText/layers/attention_mla.py, line 804 (commit 3a17530): the query scaling in `attention_mla.MLA`
- maxtext/src/MaxText/layers/attention_op.py, line 1509 (commit 3a17530): the hardcoded scale in `cudnn_jax_flash_attention`