Skip to content

Commit 9f913f1

Browse files
committed
fix
1 parent 5d39ae1 commit 9f913f1

2 files changed

Lines changed: 1 addition & 78 deletions

File tree

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -102,83 +102,6 @@ def test_sdpa_mask_patched(self):
102102
got = patched_sdpa_mask(**kwargs)
103103
self.assertEqualArray(expected, got)
104104

105-
@requires_transformers("4.99")
106-
def test_sdpa_mask_recent_torch_is_running(self):
107-
def _copy_vmap_for_bhqkv(mask_function, bh_indices=True):
108-
dimensions = [(None, None, None, 0), (None, None, 0, None)]
109-
if bh_indices:
110-
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
111-
for dims in dimensions:
112-
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
113-
return mask_function
114-
115-
def copy_of_sdpa_mask_recent_torch(
116-
batch_size,
117-
cache_position,
118-
kv_length,
119-
kv_offset=0,
120-
mask_function=transformers.masking_utils.causal_mask_function,
121-
attention_mask=None,
122-
local_size=None,
123-
allow_is_causal_skip=True,
124-
**kwargs,
125-
):
126-
q_length = cache_position.shape[0]
127-
padding_mask = transformers.masking_utils.prepare_padding_mask(
128-
attention_mask, kv_length, kv_offset
129-
)
130-
if allow_is_causal_skip and transformers.masking_utils._ignore_causal_mask_sdpa(
131-
padding_mask, q_length, kv_length, kv_offset, local_size
132-
):
133-
return None
134-
kv_arange = torch.arange(kv_length, device=cache_position.device)
135-
kv_arange += kv_offset
136-
if padding_mask is not None:
137-
mask_function = transformers.masking_utils.and_masks(
138-
mask_function,
139-
transformers.masking_utils.padding_mask_function(padding_mask),
140-
)
141-
142-
batch_arange = torch.arange(batch_size, device=cache_position.device)
143-
head_arange = torch.arange(1, device=cache_position.device)
144-
with transformers.masking_utils.TransformGetItemToIndex():
145-
causal_mask = _copy_vmap_for_bhqkv(mask_function)(
146-
batch_arange, head_arange, cache_position, kv_arange
147-
)
148-
return causal_mask
149-
150-
sdpa_mask_recent_torch = copy_of_sdpa_mask_recent_torch
151-
patched_sdpa_mask_recent_torch = patch_transformers.patched_sdpa_mask_recent_torch
152-
kwargs = {
153-
"batch_size": 1,
154-
"cache_position": torch.tensor([3], dtype=torch.int64),
155-
"kv_length": 4,
156-
"kv_offset": 0,
157-
"mask_function": transformers.masking_utils.causal_mask_function,
158-
"attention_mask": torch.tensor([[True, True, True, True]]),
159-
"local_size": None,
160-
"allow_is_causal_skip": True,
161-
"allow_is_bidirectional_skip": False,
162-
}
163-
expected = sdpa_mask_recent_torch(**kwargs)
164-
got = patched_sdpa_mask_recent_torch(**kwargs)
165-
self.assertEqual(expected, got)
166-
167-
kwargs = {
168-
"batch_size": 1,
169-
"cache_position": torch.tensor([3], dtype=torch.int64),
170-
"kv_length": 4,
171-
"kv_offset": 0,
172-
"mask_function": transformers.masking_utils.causal_mask_function,
173-
"attention_mask": torch.tensor([[True, True, True, True]]),
174-
"local_size": None,
175-
"allow_is_causal_skip": False,
176-
"allow_is_bidirectional_skip": False,
177-
}
178-
expected = sdpa_mask_recent_torch(**kwargs)
179-
got = patched_sdpa_mask_recent_torch(**kwargs)
180-
self.assertEqualArray(expected, got)
181-
182105
def test_sdpa_attention_forward_not_causal(self):
183106
sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
184107
patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def patched_sdpa_mask_recent_torch(
165165
# `cache_position` is deprecated as an arg,
166166
# and will be removed in Transformers v5.6. Please use `q_length` and "
167167
# `q_offset` instead, similarly to `kv_length` and `kv_offset`"
168-
q_length, q_offset = q_length.shape[0], q_length[0].to(device)
169168
device = q_length.device
169+
q_length, q_offset = q_length.shape[0], q_length[0].to(device)
170170

171171
padding_mask = prepare_padding_mask(
172172
attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs

0 commit comments

Comments (0)