
Commit 14152c3: add figures

Parent: c966dc4

5 files changed: +6 -21 lines

.DS_Store (binary file, 0 bytes; not shown)

README.md (1 addition, 1 deletion)
@@ -2,7 +2,7 @@
 We introduce an innovative and out-of-box KV cache compression method, SnapKV.

 ![Comprehensive Experiment Results on LongBench](./figures/longbench.jpg)
-![Pressure Test Result on Needle-in-a-Haystack](./figures/LWM-Text-Chat-1M_SnapKV.pdf)
+![Pressure Test Result on Needle-in-a-Haystack](./figures/LWM-Text-Chat-1M_SnapKV.jpg)

 ## Quick Start
 ### Use SnapKV-optimized Models
figures/LWM-Text-Chat-1M_SnapKV.jpg (binary, 137 KB; added)

figures/LWM-Text-Chat-1M_SnapKV.pdf (binary, -22.9 KB; removed, not shown)

observations/models/modeling_mistral_benchmark_layerwise.py (5 additions, 20 deletions)
@@ -22,8 +22,8 @@
 PyTorch Mistral baseline model.
 https://github.com/huggingface/transformers/blob/v4.36-release/src/transformers/models/mistral/modeling_mistral.py
 Please write change log here:
-[YL] save attention weights
-[YL] for benchmarking
+[SnapKV] save attention weights
+[SnapKV] for benchmarking
 """

 import inspect
@@ -307,7 +307,7 @@ def forward(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

-        # [YL] get stats=====================
+        # [SnapKV] get stats=====================
         self.features_per_data = []
         threshold = self.threshold
         prev_len = self.prev_len
@@ -319,23 +319,8 @@ def forward(
         prev_attn_typical = attn_weights[0, :, start:end, :start] > threshold
         prev_attn_typical = prev_attn_typical.sum(1) > 0
         self.prev_attn_typical = prev_attn_typical
-
-        # for step in range(steps):
-        #     start = prev_len - window_size
-        #     end = prev_len
-        #     shift = window_size * step
-        #     prev_attn_sum = attn_weights[0, :, start:end, :start].sum(1)
-        #     cur_attn_sum = attn_weights[0, :, start + shift + window_size:end + shift + window_size, :start]
-        #     values, indices = torch.topk(prev_attn_sum, k=int(top_k * prev_len), dim=1)
-        #     mask = torch.zeros_like(prev_attn_sum, dtype=torch.bool, device=prev_attn_sum.device)
-        #     batch_indices = torch.arange(prev_attn_sum.size(0)).unsqueeze(1).expand_as(indices)
-        #     mask[batch_indices, indices] = 1
-        #     mask.unsqueeze_(1)
-        #     cur_attn_sum_threshold = cur_attn_sum > threshold
-        #     activation_overlap = cur_attn_sum_threshold & mask
-        #     hit_rate = activation_overlap.sum().float() / cur_attn_sum_threshold.sum().float()
-        #     self.features_per_data.append(hit_rate.item())
-        # [YL] end ==========================
+
+        # [SnapKV] end ==========================

         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
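The statistics block this diff retains marks, per attention head, which prefix key positions receive above-threshold attention from any query in the trailing observation window. Below is a minimal self-contained sketch of that computation; the tensor shapes, `window_size`, `threshold`, and the random inputs are illustrative assumptions, not the repository's API (the batch dimension the original indexes with `attn_weights[0, ...]` is dropped here):

```python
import torch

# Illustrative shapes (assumption): attention weights for a single sequence,
# [num_heads, seq_len, seq_len]; rows are query positions, columns are keys.
num_heads, seq_len = 4, 64
window_size, threshold = 8, 0.02  # assumed hyperparameter values
attn_weights = torch.softmax(torch.randn(num_heads, seq_len, seq_len), dim=-1)

start, end = seq_len - window_size, seq_len  # trailing observation window
# Per head, flag prefix key positions (< start) that receive above-threshold
# attention from at least one of the last `window_size` queries.
hits = attn_weights[:, start:end, :start] > threshold  # [heads, window, start]
prev_attn_typical = hits.sum(1) > 0                    # [heads, start], bool
print(prev_attn_typical.float().mean().item())  # fraction of flagged positions
```

The commented-out loop the commit deletes extended the same idea: it selected top-k prefix positions by the window's summed attention and measured the hit rate against positions that later queries actually attend to above the threshold.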
