From dd8bf369f8e7d7ec474e5cc07da82e01d16bd7f2 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 20 May 2026 07:07:42 -0700 Subject: [PATCH] Fix Voxtral Metal streaming mask --- examples/models/voxtral_realtime/model.py | 3 +- .../tests/test_ring_kv_cache.py | 42 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 examples/models/voxtral_realtime/tests/test_ring_kv_cache.py diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index e591445cc56..3ff110c161e 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -1129,7 +1129,8 @@ def create_causal_mask( return torch.where( valid, torch.zeros(1, dtype=dtype, device=start_pos.device), - torch.tensor(float("-inf"), dtype=dtype, device=start_pos.device), + # MPS SDPA can propagate NaNs from -inf additive masks in AOTI. + torch.tensor(-1e9, dtype=dtype, device=start_pos.device), ) diff --git a/examples/models/voxtral_realtime/tests/test_ring_kv_cache.py b/examples/models/voxtral_realtime/tests/test_ring_kv_cache.py new file mode 100644 index 00000000000..274f0e177ba --- /dev/null +++ b/examples/models/voxtral_realtime/tests/test_ring_kv_cache.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from types import ModuleType +from unittest.mock import patch + +import torch + +with patch.dict( + "sys.modules", + {"executorch.extension.llm.custom_ops.custom_ops": ModuleType("custom_ops")}, +): + from executorch.examples.models.voxtral_realtime.model import StandardRingKVCache + + +class StandardRingKVCacheTest(unittest.TestCase): + def test_additive_mask_uses_finite_negative_values(self): + cache = StandardRingKVCache(window_size=4, n_heads=1, head_dim=2) + + mask = cache.create_causal_mask( + torch.tensor(0), seq_len=1, dtype=torch.bfloat16 + ) + + self.assertEqual(mask.dtype, torch.bfloat16) + self.assertTrue(torch.isfinite(mask).all()) + self.assertEqual(mask[0, 0].item(), 0) + self.assertLess(mask[0, 1].float().item(), -1e8) + + def test_bool_mask_keeps_bool_dtype(self): + cache = StandardRingKVCache(window_size=4, n_heads=1, head_dim=2) + + mask = cache.create_causal_mask(torch.tensor(3), seq_len=2, bool_mask=True) + + self.assertEqual(mask.dtype, torch.bool) + + +if __name__ == "__main__": + unittest.main()