diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e2ece5cb3685..94c6d1387843 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -676,9 +676,16 @@ def get_attention_scores(
             key = key.float()
 
         if attention_mask is None:
-            baddbmm_input = torch.empty(
-                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
-            )
+            if query.device.type == "mps":
+                # MPS's baddbmm does not short-circuit on beta=0, so NaN held in an
+                # uninitialized torch.empty() input can propagate; use zeros instead.
+                baddbmm_input = torch.zeros(
+                    query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+                )
+            else:
+                baddbmm_input = torch.empty(
+                    query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+                )
             beta = 0
         else:
             baddbmm_input = attention_mask
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index 8b45c2148504..e674a5d9ea01 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -133,3 +133,32 @@ def test_conversion_when_using_device_map(self):
 
         self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
         self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))
+
+
+class GetAttentionScoresMPSTests(unittest.TestCase):
+    @pytest.mark.skipif(torch_device != "mps", reason="test exercises an MPS-specific code path")
+    def test_no_nan_when_attention_mask_is_none_on_mps(self):
+        # Regression test: torch.empty() on MPS can return non-finite values,
+        # and MPS's baddbmm does not short-circuit on beta=0, so an unmasked
+        # call to get_attention_scores used to propagate NaN into the output.
+        torch.manual_seed(0)
+        heads, dim_head, seq_len = 4, 32, 256
+        attn = Attention(
+            query_dim=heads * dim_head,
+            heads=heads,
+            dim_head=dim_head,
+            bias=False,
+        ).to(torch_device, torch.float16)
+
+        for _ in range(20):
+            # Pollute the MPS allocator pool with non-finite values so that a
+            # subsequent torch.empty() is likely to return NaN-filled memory.
+            polluter = torch.full((heads, seq_len, seq_len), float("nan"), device=torch_device, dtype=torch.float16)
+            del polluter
+
+        query = torch.randn(1, seq_len, heads * dim_head, device=torch_device, dtype=torch.float16)
+        key = torch.randn(1, seq_len, heads * dim_head, device=torch_device, dtype=torch.float16)
+        scores = attn.get_attention_scores(
+            attn.head_to_batch_dim(query), attn.head_to_batch_dim(key), attention_mask=None
+        )
+        self.assertFalse(torch.isnan(scores).any().item(), "attention scores contain NaN on MPS")