 """PyTorch Mistral baseline model.
 https://github.com/huggingface/transformers/blob/v4.36-release/src/transformers/models/mistral/modeling_mistral.py
 Please write change log here:
-[YL] save attention weights
-[YL] for benchmarking
+[SnapKV] save attention weights
+[SnapKV] for benchmarking
 """

 import inspect
@@ -307,7 +307,7 @@ def forward(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

-        # [YL] get stats =====================
+        # [SnapKV] get stats =====================
         self.features_per_data = []
         threshold = self.threshold
         prev_len = self.prev_len
@@ -319,23 +319,8 @@ def forward(
         prev_attn_typical = attn_weights[0, :, start:end, :start] > threshold
         prev_attn_typical = prev_attn_typical.sum(1) > 0
         self.prev_attn_typical = prev_attn_typical
-
-        # for step in range(steps):
-        #     start = prev_len - window_size
-        #     end = prev_len
-        #     shift = window_size * step
-        #     prev_attn_sum = attn_weights[0, :, start:end, :start].sum(1)
-        #     cur_attn_sum = attn_weights[0, :, start + shift + window_size:end + shift + window_size, :start]
-        #     values, indices = torch.topk(prev_attn_sum, k=int(top_k * prev_len), dim=1)
-        #     mask = torch.zeros_like(prev_attn_sum, dtype=torch.bool, device=prev_attn_sum.device)
-        #     batch_indices = torch.arange(prev_attn_sum.size(0)).unsqueeze(1).expand_as(indices)
-        #     mask[batch_indices, indices] = 1
-        #     mask.unsqueeze_(1)
-        #     cur_attn_sum_threshold = cur_attn_sum > threshold
-        #     activation_overlap = cur_attn_sum_threshold & mask
-        #     hit_rate = activation_overlap.sum().float() / cur_attn_sum_threshold.sum().float()
-        #     self.features_per_data.append(hit_rate.item())
-        # [YL] end ==========================
+
+        # [SnapKV] end ==========================
 
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
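
For context, here is a minimal, self-contained sketch (not part of the commit) of the statistic the retained lines compute. The shapes and the values of threshold, prev_len, and window_size are illustrative assumptions, not the repo's configured ones, and the start/end definitions are inferred from the commented-out block deleted above (start = prev_len - window_size, end = prev_len):

import torch

# Illustrative setup (assumption): batch size 1, random softmax-normalized weights.
num_heads, seq_len = 32, 128
attn_weights = torch.softmax(torch.randn(1, num_heads, seq_len, seq_len), dim=-1)

threshold = 0.01       # assumed cutoff on per-position attention weight
prev_len = seq_len     # length of the already-processed prefix (assumption)
window_size = 16       # observation window = last window_size queries (assumption)
start, end = prev_len - window_size, prev_len

# For each head, flag prefix positions (columns before `start`) that receive
# more than `threshold` attention from at least one query in the window:
# (heads, window_size, start) -> any over the window dim -> (heads, start).
prev_attn_typical = attn_weights[0, :, start:end, :start] > threshold
prev_attn_typical = prev_attn_typical.sum(1) > 0
print(prev_attn_typical.shape)  # torch.Size([32, 112])

A position flagged here is one that the recent window of queries still attends to above the threshold, which is the per-head benchmarking signal the patch stores on the attention module.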