@@ -102,83 +102,6 @@ def test_sdpa_mask_patched(self):
102102 got = patched_sdpa_mask (** kwargs )
103103 self .assertEqualArray (expected , got )
104104
@requires_transformers("4.99")
def test_sdpa_mask_recent_torch_is_running(self):
    """Compare ``patched_sdpa_mask_recent_torch`` with a local reference.

    The reference is defined inline (as a copy of the transformers
    ``sdpa_mask_recent_torch`` logic, judging by its name) and both
    callables receive the same keyword arguments, once with
    ``allow_is_causal_skip=True`` and once with it set to ``False``.
    """
    masking = transformers.masking_utils

    def _copy_vmap_for_bhqkv(mask_function, bh_indices=True):
        # One torch.vmap layer per mapped argument: kv index first, then
        # q index, and optionally head and batch indices.
        axes = [(None, None, None, 0), (None, None, 0, None)]
        if bh_indices:
            axes += [(None, 0, None, None), (0, None, None, None)]
        vmapped = mask_function
        for in_dims in axes:
            vmapped = torch.vmap(vmapped, in_dims=in_dims, out_dims=0)
        return vmapped

    def copy_of_sdpa_mask_recent_torch(
        batch_size,
        cache_position,
        kv_length,
        kv_offset=0,
        mask_function=masking.causal_mask_function,
        attention_mask=None,
        local_size=None,
        allow_is_causal_skip=True,
        **kwargs,
    ):
        device = cache_position.device
        q_length = cache_position.shape[0]
        padding_mask = masking.prepare_padding_mask(attention_mask, kv_length, kv_offset)

        # Early exit: no explicit mask is built when skipping is allowed
        # and the helper reports that the causal mask can be ignored.
        if allow_is_causal_skip and masking._ignore_causal_mask_sdpa(
            padding_mask, q_length, kv_length, kv_offset, local_size
        ):
            return None

        kv_arange = torch.arange(kv_length, device=device)
        kv_arange += kv_offset

        if padding_mask is not None:
            # Combine the requested mask with the padding information.
            mask_function = masking.and_masks(
                mask_function, masking.padding_mask_function(padding_mask)
            )

        batch_arange = torch.arange(batch_size, device=device)
        head_arange = torch.arange(1, device=device)
        with masking.TransformGetItemToIndex():
            return _copy_vmap_for_bhqkv(mask_function)(
                batch_arange, head_arange, cache_position, kv_arange
            )

    reference = copy_of_sdpa_mask_recent_torch
    patched = patch_transformers.patched_sdpa_mask_recent_torch

    def make_inputs(allow_is_causal_skip):
        # Fresh tensors for every scenario so that the two runs never
        # share tensor objects.
        return {
            "batch_size": 1,
            "cache_position": torch.tensor([3], dtype=torch.int64),
            "kv_length": 4,
            "kv_offset": 0,
            "mask_function": masking.causal_mask_function,
            "attention_mask": torch.tensor([[True, True, True, True]]),
            "local_size": None,
            "allow_is_causal_skip": allow_is_causal_skip,
            "allow_is_bidirectional_skip": False,
        }

    # Skip allowed: results are compared with plain equality.
    # NOTE(review): presumably both sides return None here - confirm.
    inputs = make_inputs(True)
    expected = reference(**inputs)
    got = patched(**inputs)
    self.assertEqual(expected, got)

    # Skip disabled: results are compared as arrays.
    inputs = make_inputs(False)
    expected = reference(**inputs)
    got = patched(**inputs)
    self.assertEqualArray(expected, got)
181-
182105 def test_sdpa_attention_forward_not_causal (self ):
183106 sdpa_attention_forward = sdpa_attention .sdpa_attention_forward
184107 patched_sdpa_attention_forward = patch_transformers .patched_sdpa_attention_forward
0 commit comments