 PyTorch Mistral baseline model.
 https://github.com/huggingface/transformers/blob/v4.36-release/src/transformers/models/mistral/modeling_mistral.py
 Please write change log here:
+[SnapKV] save attention weights
 """

 import inspect
     replace_return_docstrings,
 )
 from transformers.models.mistral.configuration_mistral import MistralConfig
-from snapkv_utils import KVCluster
+from utils_yl_ratio_avgpool_v2 import KVCluster  # [SnapKV]


 if is_flash_attn_2_available():
@@ -239,7 +240,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
         )
-        self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # add kv_cluster
+        self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # [SnapKV] add kv_cluster
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

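KVCluster itself lives in the imported utility module and is not shown in this diff. As a rough, hedged sketch of the interface the patched attention relies on — a constructor taking window_size and max_capacity_prompt, and an update_kv(key_states, query_states, value_states, attention_mask, num_key_value_groups) method returning a compressed key/value pair — something along the following lines would fit. The class name, the pooling kernel size, and the scoring details are illustrative assumptions, not the project's actual implementation.

import torch
import torch.nn.functional as F

class KVClusterSketch:
    """Illustrative stand-in for KVCluster: keep the last `window_size` tokens plus the
    prompt tokens that this observation window attends to most, up to `max_capacity_prompt`."""

    def __init__(self, window_size=100, max_capacity_prompt=500):
        assert max_capacity_prompt > window_size
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        # Shapes: (bsz, num_heads, seq_len, head_dim); k/v heads are already repeated upstream.
        # attention_mask / num_key_value_groups are accepted for interface parity only here;
        # padding and causal masking of the observation window are omitted for brevity.
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len <= self.max_capacity_prompt:
            return key_states, value_states
        # Score every prompt position by how much the last `window_size` queries attend to it.
        obs_queries = query_states[..., -self.window_size:, :]
        scores = torch.matmul(obs_queries, key_states.transpose(2, 3)) / (head_dim ** 0.5)
        scores = F.softmax(scores, dim=-1, dtype=torch.float32)
        votes = scores[..., :-self.window_size].sum(dim=-2)               # (bsz, heads, q_len - window)
        votes = F.avg_pool1d(votes, kernel_size=5, padding=2, stride=1)   # smooth per-token votes
        keep = self.max_capacity_prompt - self.window_size
        top_idx = votes.topk(keep, dim=-1).indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        k_past = key_states[..., :-self.window_size, :].gather(2, top_idx)
        v_past = value_states[..., :-self.window_size, :].gather(2, top_idx)
        # The observation window itself is always retained.
        k_win = key_states[..., -self.window_size:, :]
        v_win = value_states[..., -self.window_size:, :]
        return torch.cat([k_past, k_win], dim=2), torch.cat([v_past, v_win], dim=2)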
@@ -276,7 +277,7 @@ def forward(
276277 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
277278 "with a layer index."
278279 )
279- if hasattr (self , "kv_seq_len" ): # add kv_seq_len
280+ if hasattr (self , "kv_seq_len" ): #[SnapKV] add kv_seq_len
280281 # print('self.kv_seq_len', self.kv_seq_len)
281282 if self .kv_seq_len != 0 :
282283 kv_seq_len += self .kv_seq_len
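The extra kv_seq_len attribute exists because, once the prompt cache has been compressed, the number of cached entries no longer equals the number of tokens actually processed, so lengths derived from the cache alone would be too small for the rotary embeddings and masks of later tokens. A toy illustration with assumed numbers:

# Illustration only: why the true length is tracked on the module.
prompt_len = 4000       # tokens processed during prefill
cached_entries = 500    # entries left after compression (max_capacity_prompt)

true_kv_seq_len = prompt_len          # what the patched code keeps in self.kv_seq_len
naive_kv_seq_len = cached_entries     # what the cache length alone would suggest
assert true_kv_seq_len != naive_kv_seq_len  # the next token sits at position 4000, not 500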
@@ -289,21 +290,21 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         # repeat k/v heads if n_kv_heads < n_heads
-        # move to ahead
+        # [SnapKV] move to ahead
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            if key_states.shape[-2] == kv_seq_len: # add kv_cluster
+            if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
                 self.kv_seq_len = kv_seq_len
                 key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
                 past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
             else:
                 self.kv_seq_len += q_len
                 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        kv_seq_len = key_states.shape[-2] # adjust kv_seq_len
+        kv_seq_len = key_states.shape[-2] # [SnapKV] adjust kv_seq_len

         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

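With the hard-coded settings above (window_size=100, max_capacity_prompt=500), the first branch fires once per sequence, during prefill, when the incoming keys cover the whole kv_seq_len; every later decoding step takes the else branch and simply appends. A rough per-head cache-size sketch under assumed prompt and generation lengths:

# Back-of-the-envelope cache size per head (illustrative numbers).
max_capacity_prompt, prompt_len, new_tokens = 500, 8000, 256

cached = min(prompt_len, max_capacity_prompt)  # prefill branch: prompt compressed by update_kv
cached += new_tokens                           # decode branch: one entry appended per step
print(cached)                                  # 756 entries instead of 8256 without compression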
@@ -359,7 +360,7 @@ def __init__(self, *args, **kwargs):
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-        # self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # add kv_cluster
+        # self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # [SnapKV] add kv_cluster


     def forward(
@@ -405,7 +406,7 @@ def forward(
405406 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
406407 "with a layer index."
407408 )
408- if hasattr (self , "kv_seq_len" ): # add kv_seq_len
409+ if hasattr (self , "kv_seq_len" ): #[SnapKV] add kv_seq_len
409410 # print('self.kv_seq_len', self.kv_seq_len)
410411 if self .kv_seq_len != 0 :
411412 kv_seq_len += self .kv_seq_len
@@ -432,7 +433,7 @@ def forward(
432433 " make sure to upgrade flash-attn library."
433434 )
434435 # repeat k/v heads if n_kv_heads < n_heads
435- # move to ahead
436+ # [SnapKV] move to ahead
436437 key_states = repeat_kv (key_states , self .num_key_value_groups )
437438 value_states = repeat_kv (value_states , self .num_key_value_groups )
438439
@@ -463,7 +464,7 @@ def forward(
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            if key_states.shape[-2] == kv_seq_len: # add kv_cluster
+            if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
                 self.kv_seq_len = kv_seq_len
                 key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
                 past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
@@ -502,7 +503,7 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
         # print('layer id', self.layer_idx, 'query_states', query_states.shape, 'key_states', key_states.shape, 'value_states', value_states.shape, 'kv_seq_len', kv_seq_len, 'dropout_rate', dropout_rate, 'use_sliding_windows', use_sliding_windows)
-        # change attention_mask to None
+        # [SnapKV] change attention_mask to None
         # print('layer id', self.layer_idx, 'query_states', query_states.shape, 'key_states', key_states.shape, 'value_states', value_states.shape, 'attention_mask', attention_mask.shape, 'kv_seq_len', kv_seq_len, 'dropout_rate', dropout_rate, 'use_sliding_windows', use_sliding_windows)
         attn_output = self._flash_attention_forward(
             query_states,
@@ -956,7 +957,7 @@ def forward(
                 )

         if self._use_flash_attention_2:
-        # if False: # attention_mask is used for compression
+        # if False: # [SnapKV] attention_mask is used for compression
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         else:
@@ -1161,7 +1162,7 @@ def prepare_inputs_for_generation(
                 max_cache_length = past_key_values.get_max_length()
             else:
                 # # cache_length = past_length = past_key_values[0][0].shape[2]
-                # if len(past_key_values) == 0: # for the first time, past_key_values is empty
+                # if len(past_key_values) == 0: # [SnapKV] for the first time, past_key_values is empty
                 # print('fuck')
                 # for layer in self.model.layers:
                 # if hasattr(layer, "self_attn"):
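The commented-out lines above hint at clearing this per-layer state between requests. A hedged helper along those lines (the function name and where it is called from are assumptions, not part of this commit) could look like:

def reset_snapkv_state(model):
    """Clear each layer's kv_seq_len counter so the next prompt is treated as a
    fresh prefill and gets compressed again."""
    for layer in model.model.layers:
        if hasattr(layer.self_attn, "kv_seq_len"):
            layer.self_attn.kv_seq_len = 0

Calling such a helper before each generate() run would keep one request's compression bookkeeping from leaking into the next.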