
[mps] Fix NaN in Attention.get_attention_scores when attention_mask is None#13701

Open
Shreyas-jk wants to merge 1 commit into huggingface:main from Shreyas-jk:fix/mps-attention-scores-nan-11229

Conversation


@Shreyas-jk Shreyas-jk commented May 8, 2026

What does this PR do?

Fixes #11229

Attention.get_attention_scores allocates baddbmm_input with torch.empty() and uses beta=0, relying on baddbmm to ignore the uninitialized input. The MPS baddbmm kernel does not short-circuit on beta=0, so any NaN/Inf in the uninitialized memory propagates through 0 * NaN = NaN and poisons the attention output. CUDA happens to mask this because its allocator typically returns zero-initialized memory.
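The failure mode above can be demonstrated without MPS hardware. The sketch below emulates `baddbmm` (`out = beta * input + alpha * bmm(batch1, batch2)`) in numpy, with NaN standing in for the uninitialized `torch.empty()` contents; the function names are illustrative, not the diffusers or torch API:

```python
import numpy as np

def baddbmm_naive(inp, b1, b2, beta=0.0, alpha=1.0):
    # Kernel that always multiplies the input by beta, as the MPS kernel
    # effectively does: 0 * NaN = NaN, so garbage in `inp` poisons the output.
    return beta * inp + alpha * (b1 @ b2)

def baddbmm_shortcircuit(inp, b1, b2, beta=0.0, alpha=1.0):
    # Kernel that skips the input term entirely when beta == 0, which is
    # what torch.baddbmm documents (NaN/Inf in input is ignored for beta=0).
    if beta == 0.0:
        return alpha * (b1 @ b2)
    return beta * inp + alpha * (b1 @ b2)

# "Uninitialized" input containing NaN, standing in for a polluted MPS pool.
inp = np.full((2, 4, 4), np.nan)
b1 = np.ones((2, 4, 8))
b2 = np.ones((2, 8, 4))

assert np.isnan(baddbmm_naive(inp, b1, b2)).all()            # poisoned output
assert np.isfinite(baddbmm_shortcircuit(inp, b1, b2)).all()  # clean output
```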

This change uses torch.zeros instead of torch.empty only on MPS, leaving the CUDA / CPU / XPU paths unchanged so they don't pay the extra fill cost.
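A minimal sketch of the shape of that change (names are illustrative stand-ins, not the exact diffusers code; numpy is used here so the snippet runs anywhere):

```python
import numpy as np

def make_baddbmm_input(shape, device_type, dtype=np.float32):
    # Hypothetical helper mirroring the PR's logic: on MPS the kernel reads
    # the input even with beta=0, so it must be finite; elsewhere the buffer
    # is never read, so skipping the fill keeps the fast path unchanged.
    if device_type == "mps":
        return np.zeros(shape, dtype=dtype)  # guaranteed finite
    return np.empty(shape, dtype=dtype)      # contents unspecified; never read

buf = make_baddbmm_input((2, 77, 77), "mps")
assert not np.isnan(buf).any()
```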

In real workloads this surfaces as black/NaN images from StableDiffusionXLPipeline with enable_attention_slicing() on Apple Silicon + fp16, which is the standard memory-saving path on Macs with limited unified memory.

Reproduction

Minimal repro on M-series MPS (without the fix): 30/30 trials produce NaN. With the fix: 0/30. CPU baseline: 0/5. Verified on M4 MacBook Pro, torch 2.11.0, fp16 and fp32.

The added test GetAttentionScoresMPSTests.test_no_nan_when_attention_mask_is_none_on_mps reproduces the bug deterministically (fails 20/20 without the fix, passes 20/20 with it) by polluting the MPS allocator pool with NaN-filled tensors before each call.
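The test's pattern can be sketched in a device-independent way: seed the scratch buffer with NaN (standing in for the polluted allocator pool), compute scores through the fixed zero-filled path, and assert every trial is NaN-free. All names here are illustrative, not the diffusers API:

```python
import numpy as np

def get_scores_fixed(query, key, scale):
    # Scratch buffer "returned by the allocator" full of NaN...
    inp = np.full((query.shape[0], query.shape[1], key.shape[1]), np.nan)
    inp[...] = 0.0  # ...neutralized by the fix's zero-fill on MPS
    scores = inp + scale * (query @ key.transpose(0, 2, 1))
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)  # softmax over keys

rng = np.random.default_rng(0)
for _ in range(20):  # mirrors the 20/20 deterministic runs in the added test
    q = rng.standard_normal((1, 16, 8))
    k = rng.standard_normal((1, 16, 8))
    assert not np.isnan(get_scores_fixed(q, k, 8 ** -0.5)).any()
```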

Before submitting

Who can review?

@pcuenca (MPS / Apple Silicon maintainer per the PR template)



Development

Successfully merging this pull request may close these issues.

enable_attention_slicing give NaN results for SDXL on MPS
