Commit eddd12e: add comments

1 parent 196d378 commit eddd12e

4 files changed: +46 additions, -42 deletions


models/modeling_llama.py

Lines changed: 9 additions & 9 deletions
@@ -318,7 +318,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
  self._init_rope()
- self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # [YL] add kv_cluster
+ self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # [SnapKV] add kv_cluster

  def _init_rope(self):
  if self.config.rope_scaling is None:
@@ -402,7 +402,7 @@ def forward(
  "with a layer index."
  )
  # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- if hasattr(self, "kv_seq_len"): #[YL] add kv_seq_len
+ if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
  # print('self.kv_seq_len', self.kv_seq_len)
  if self.kv_seq_len != 0:
  kv_seq_len += self.kv_seq_len
@@ -414,7 +414,7 @@ def forward(
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

- # [YL] move to ahead
+ # [SnapKV] move to ahead
  key_states = repeat_kv(key_states, self.num_key_value_groups)
  value_states = repeat_kv(value_states, self.num_key_value_groups)

@@ -425,7 +425,7 @@ def forward(
  # key_states = repeat_kv(key_states, self.num_key_value_groups)
  # value_states = repeat_kv(value_states, self.num_key_value_groups)

- kv_seq_len = key_states.shape[-2] # [YL] adjust kv_seq_len
+ kv_seq_len = key_states.shape[-2] # [SnapKV] adjust kv_seq_len

  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

@@ -436,7 +436,7 @@ def forward(
  )

  if attention_mask is not None:
- attention_mask = attention_mask[...,-kv_seq_len:] # [YL]
+ attention_mask = attention_mask[...,-kv_seq_len:] # [SnapKV]
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  raise ValueError(
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
@@ -530,7 +530,7 @@ def forward(
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
  "with a layer index."
  )
- if hasattr(self, "kv_seq_len"): #[YL] add kv_seq_len
+ if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
  # print('self.kv_seq_len', self.kv_seq_len)
  if self.kv_seq_len != 0:
  kv_seq_len += self.kv_seq_len
@@ -541,14 +541,14 @@ def forward(

  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
- # [YL] move to ahead
+ # [SnapKV] move to ahead
  key_states = repeat_kv(key_states, self.num_key_value_groups)
  value_states = repeat_kv(value_states, self.num_key_value_groups)

  if past_key_value is not None:
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
  # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- if key_states.shape[-2] == kv_seq_len: # [YL] add kv_cluster
+ if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
  self.kv_seq_len = kv_seq_len
  key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
  past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
@@ -1273,7 +1273,7 @@ def forward(
  def prepare_inputs_for_generation(
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
  ):
- if past_key_values is None: # [YL]
+ if past_key_values is None: # [SnapKV]
  for layer in self.model.layers:
  layer.self_attn.kv_seq_len = 0
  if past_key_values is not None:
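Note on the pattern above: the forward changes split the cache update into a prefill branch and a decode branch. When key_states covers the full kv_seq_len, the layer is seeing the whole prompt, so the KV pairs are compressed by kv_cluster.update_kv before being written to the cache; on later decode steps, the new token's KV is appended to the already compressed cache and the per-layer kv_seq_len counter is advanced. The snippet below is a simplified paraphrase of that branch, not the verbatim patched code, and it assumes update_kv returns key/value tensors with the same layout as its inputs.

# Simplified paraphrase of the cache-update branch added in the hunks above.
# KVCluster.update_kv is assumed to return compressed key/value tensors with the
# same [bsz, num_heads, seq_len, head_dim] layout as its inputs.
if past_key_value is not None:
    cache_kwargs = {"sin": sin, "cos": cos}  # specific to RoPE models
    if key_states.shape[-2] == kv_seq_len:
        # Prefill: this layer sees the entire prompt, so compress before caching.
        self.kv_seq_len = kv_seq_len
        key_compress, value_compress = self.kv_cluster.update_kv(
            key_states, query_states, value_states, attention_mask, self.num_key_value_groups
        )
        past_key_value.update(key_compress, value_compress, self.layer_idx, cache_kwargs)
    else:
        # Decode: append the new token's KV to the already compressed cache.
        self.kv_seq_len += q_len
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )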

models/modeling_mistral.py

Lines changed: 14 additions & 13 deletions
@@ -22,6 +22,7 @@
  PyTorch Mistral baseline model.
  https://github.com/huggingface/transformers/blob/v4.36-release/src/transformers/models/mistral/modeling_mistral.py
  Please write change log here:
+ [SnapKV] save attention weights
  """

  import inspect
@@ -49,7 +50,7 @@
  replace_return_docstrings,
  )
  from transformers.models.mistral.configuration_mistral import MistralConfig
- from snapkv_utils import KVCluster
+ from utils_yl_ratio_avgpool_v2 import KVCluster # [SnapKV]


  if is_flash_attn_2_available():
@@ -239,7 +240,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
  max_position_embeddings=self.max_position_embeddings,
  base=self.rope_theta,
  )
- self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # add kv_cluster
+ self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # [SnapKV] add kv_cluster
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

@@ -276,7 +277,7 @@ def forward(
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
  "with a layer index."
  )
- if hasattr(self, "kv_seq_len"): # add kv_seq_len
+ if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
  # print('self.kv_seq_len', self.kv_seq_len)
  if self.kv_seq_len != 0:
  kv_seq_len += self.kv_seq_len
@@ -289,21 +290,21 @@ def forward(
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

  # repeat k/v heads if n_kv_heads < n_heads
- # move to ahead
+ # [SnapKV] move to ahead
  key_states = repeat_kv(key_states, self.num_key_value_groups)
  value_states = repeat_kv(value_states, self.num_key_value_groups)

  if past_key_value is not None:
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
- if key_states.shape[-2] == kv_seq_len: # add kv_cluster
+ if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
  self.kv_seq_len = kv_seq_len
  key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
  past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
  else:
  self.kv_seq_len += q_len
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

- kv_seq_len = key_states.shape[-2] # adjust kv_seq_len
+ kv_seq_len = key_states.shape[-2] # [SnapKV] adjust kv_seq_len

  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

@@ -359,7 +360,7 @@ def __init__(self, *args, **kwargs):
  # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
- # self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # add kv_cluster
+ # self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # [SnapKV] add kv_cluster


  def forward(
@@ -405,7 +406,7 @@ def forward(
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
  "with a layer index."
  )
- if hasattr(self, "kv_seq_len"): # add kv_seq_len
+ if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
  # print('self.kv_seq_len', self.kv_seq_len)
  if self.kv_seq_len != 0:
  kv_seq_len += self.kv_seq_len
@@ -432,7 +433,7 @@ def forward(
  " make sure to upgrade flash-attn library."
  )
  # repeat k/v heads if n_kv_heads < n_heads
- # move to ahead
+ # [SnapKV] move to ahead
  key_states = repeat_kv(key_states, self.num_key_value_groups)
  value_states = repeat_kv(value_states, self.num_key_value_groups)

@@ -463,7 +464,7 @@ def forward(
  attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
- if key_states.shape[-2] == kv_seq_len: # add kv_cluster
+ if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
  self.kv_seq_len = kv_seq_len
  key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
  past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
@@ -502,7 +503,7 @@ def forward(
  key_states = key_states.transpose(1, 2)
  value_states = value_states.transpose(1, 2)
  # print('layer id', self.layer_idx, 'query_states', query_states.shape, 'key_states', key_states.shape, 'value_states', value_states.shape, 'kv_seq_len', kv_seq_len, 'dropout_rate', dropout_rate, 'use_sliding_windows', use_sliding_windows)
- # change attention_mask to None
+ # [SnapKV] change attention_mask to None
  # print('layer id', self.layer_idx, 'query_states', query_states.shape, 'key_states', key_states.shape, 'value_states', value_states.shape, 'attention_mask', attention_mask.shape, 'kv_seq_len', kv_seq_len, 'dropout_rate', dropout_rate, 'use_sliding_windows', use_sliding_windows)
  attn_output = self._flash_attention_forward(
  query_states,
@@ -956,7 +957,7 @@ def forward(
  )

  if self._use_flash_attention_2:
- # if False: # attention_mask is used for compression
+ # if False: # [SnapKV] attention_mask is used for compression
  # 2d mask is passed through the layers
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  else:
@@ -1161,7 +1162,7 @@ def prepare_inputs_for_generation(
  max_cache_length = past_key_values.get_max_length()
  else:
  # # cache_length = past_length = past_key_values[0][0].shape[2]
- # if len(past_key_values) == 0: # for the first time, past_key_values is empty
+ # if len(past_key_values) == 0: # [SnapKV] for the first time, past_key_values is empty
  # print('fuck')
  # for layer in self.model.layers:
  # if hasattr(layer, "self_attn"):
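For readers who want a feel for the KVCluster interface these files import, here is a hypothetical sketch in the spirit of SnapKV. The constructor arguments match the calls above (window_size, max_capacity_prompt); the kernel_size parameter, the average pooling of scores, and the top-k selection are assumptions suggested by the module name utils_yl_ratio_avgpool_v2, not the shipped implementation, and the sketch ignores attention_mask and num_key_value_groups for brevity.

import math
import torch
import torch.nn.functional as F

class KVCluster:
    """Hypothetical sketch of the KVCluster interface used above. The shipped code
    lives in the imported utils module; the scoring and pooling details here are
    assumptions in the spirit of SnapKV."""

    def __init__(self, window_size=100, max_capacity_prompt=500, kernel_size=5):
        self.window_size = window_size                   # recent queries used as the observation window
        self.max_capacity_prompt = max_capacity_prompt   # KV budget kept after prefill
        self.kernel_size = kernel_size                   # assumed width of the score-smoothing avg pool

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        # All tensors: [bsz, num_heads, seq_len, head_dim]; mask/groups ignored in this sketch.
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len <= self.max_capacity_prompt:
            return key_states, value_states  # short prompt: nothing to compress

        # Let the last `window_size` queries vote on which prefix positions matter.
        obs_q = query_states[..., -self.window_size:, :]
        attn = torch.matmul(obs_q, key_states.transpose(2, 3)) / math.sqrt(head_dim)
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(query_states.dtype)
        scores = attn[..., :-self.window_size].sum(dim=-2)  # [bsz, num_heads, prefix_len]
        scores = F.avg_pool1d(scores, kernel_size=self.kernel_size,
                              padding=self.kernel_size // 2, stride=1)

        # Keep the top-scoring prefix positions plus the whole observation window.
        top_k = self.max_capacity_prompt - self.window_size
        idx = scores.topk(top_k, dim=-1).indices.sort(dim=-1).values
        idx = idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        k_keep = key_states[..., :-self.window_size, :].gather(2, idx)
        v_keep = value_states[..., :-self.window_size, :].gather(2, idx)
        key_compress = torch.cat([k_keep, key_states[..., -self.window_size:, :]], dim=2)
        value_compress = torch.cat([v_keep, value_states[..., -self.window_size:, :]], dim=2)
        return key_compress, value_compress

Under this sketch, with window_size=100 and max_capacity_prompt=500 as configured above, a long prompt is reduced to at most 500 KV entries per head: the last 100 positions plus the 400 prefix positions that the observation window attends to most.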

models/modeling_mixtral.py

Lines changed: 11 additions & 9 deletions
@@ -52,7 +52,7 @@
  )
  from transformers.utils.import_utils import is_torch_fx_available
  from transformers.models.mixtral.configuration_mixtral import MixtralConfig
- from snapkv_utils import KVCluster
+ from utils_yl_ratio_avgpool_v2 import KVCluster # [SnapKV]

  if is_flash_attn_2_available():
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -276,7 +276,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
  max_position_embeddings=self.max_position_embeddings,
  base=self.rope_theta,
  )
- self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # add kv_cluster
+ self.kv_cluster = KVCluster(window_size = 100, max_capacity_prompt = 500) # [SnapKV] add kv_cluster

  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -314,7 +314,7 @@ def forward(
  "with a layer index."
  )
  # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- if hasattr(self, "kv_seq_len"): # add kv_seq_len
+ if hasattr(self, "kv_seq_len"): # [SnapKV] add kv_seq_len
  # print('self.kv_seq_len', self.kv_seq_len)
  if self.kv_seq_len != 0:
  kv_seq_len += self.kv_seq_len
@@ -326,12 +326,13 @@ def forward(
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+ # [SnapKV] move to ahead
  key_states = repeat_kv(key_states, self.num_key_value_groups)
  value_states = repeat_kv(value_states, self.num_key_value_groups)

  if past_key_value is not None:
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
- if key_states.shape[-2] == kv_seq_len: # add kv_cluster
+ if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
  self.kv_seq_len = kv_seq_len
  key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
  past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
@@ -343,7 +344,7 @@ def forward(
  # key_states = repeat_kv(key_states, self.num_key_value_groups)
  # value_states = repeat_kv(value_states, self.num_key_value_groups)

- kv_seq_len = key_states.shape[-2] # adjust kv_seq_len
+ kv_seq_len = key_states.shape[-2] # [SnapKV] adjust kv_seq_len

  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

@@ -354,7 +355,7 @@ def forward(
  )

  if attention_mask is not None:
- attention_mask = attention_mask[...,-kv_seq_len:]
+ attention_mask = attention_mask[...,-kv_seq_len:] # [SnapKV]
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  raise ValueError(
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
@@ -437,7 +438,7 @@ def forward(
  "with a layer index."
  )
  # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- if hasattr(self, "kv_seq_len"): # add kv_seq_len
+ if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
  # print('self.kv_seq_len', self.kv_seq_len)
  if self.kv_seq_len != 0:
  kv_seq_len += self.kv_seq_len
@@ -465,6 +466,7 @@ def forward(
  " make sure to upgrade flash-attn library."
  )

+ # [SnapKV] move to ahead
  key_states = repeat_kv(key_states, self.num_key_value_groups)
  value_states = repeat_kv(value_states, self.num_key_value_groups)

@@ -497,7 +499,7 @@ def forward(

  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
  # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- if key_states.shape[-2] == kv_seq_len: # add kv_cluster
+ if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
  self.kv_seq_len = kv_seq_len
  key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
  past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
@@ -1413,7 +1415,7 @@ def forward(
  def prepare_inputs_for_generation(
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
  ):
- if past_key_values is None:
+ if past_key_values is None: # [SnapKV]
  for layer in self.model.layers:
  layer.self_attn.kv_seq_len = 0
  # Omit tokens covered by past_key_values
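All three models also reset the per-layer counter at the start of a fresh generation call, so the next prompt is compressed from scratch. A simplified view of that hook, taken from the prepare_inputs_for_generation hunks above:

def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    if past_key_values is None:  # [SnapKV] fresh generation call: no cache yet
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0  # clear the running length so prefill recompresses
    # ... the rest of the original Hugging Face logic is unchanged ...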
