import json
from fastchat.model import load_model, get_conversation_template
from models.modeling_mistral_benchmark import MistralForCausalLM as MyMistralForCausalLM
import torch
import transformers
from tqdm import tqdm
import numpy as np
import argparse

def load_jsonl(file_path):
    # Read a JSONL file, skipping any lines that fail to parse.
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                pass
    return data

def load_data(data, model_name, tokenizer):
    # Build a multi-turn conversation prompt from alternating user/assistant turns
    # and return it together with the token offsets bracketing the final assistant turn.
    conv = get_conversation_template(model_name)
    max_turns = len(data) // 2
    for i in range(max_turns - 1):
        conv.append_message(conv.roles[0], data[i * 2])
        conv.append_message(conv.roles[1], data[i * 2 + 1])
    conv.append_message(conv.roles[0], data[(max_turns - 1) * 2])
    start = tokenizer.encode(conv.get_prompt(), return_tensors='pt').shape[1]
    conv.append_message(conv.roles[1], data[(max_turns - 1) * 2 + 1])
    end = tokenizer.encode(conv.get_prompt(), return_tensors='pt').shape[1]
    return conv.get_prompt(), start - 1, end

def main(args):
    dataset_path = args.dataset_path
    dataset = load_jsonl(dataset_path)
    model_name = args.model_name
    model = MyMistralForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        # use_flash_attention_2=True
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name,
        padding_side="right",
        use_fast=False,
    )

    features = []
    window_size = args.window_size
    steps = args.steps
    threshold = args.threshold

    layer_len = len(model.model.layers)

    for data in tqdm(dataset):
        features_per_data = []
        with torch.inference_mode():
            prompt, prev_len, end = load_data(data, model_name, tokenizer)
            input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
            # Configure every attention layer with the feature-extraction parameters
            # before running the forward pass.
            for layer_id in range(layer_len):
                model.model.layers[layer_id].self_attn.prev_len = prev_len
                model.model.layers[layer_id].self_attn.steps = steps
                model.model.layers[layer_id].self_attn.threshold = threshold
                model.model.layers[layer_id].self_attn.window_size = window_size
            outputs = model(input_ids, output_attentions=False, use_cache=False)
            # Collect the features computed inside each patched attention module.
            for layer_id in range(layer_len):
                features_per_data.append(model.model.layers[layer_id].self_attn.features_per_data)
                # attn_weights = model.model.layers[layer_id].self_attn.attn_weights
                # total_len = attn_weights.shape[-1]
                # for step in range(steps):
                #     start = prev_len - window_size
                #     end = prev_len
                #     shift = window_size * step
                #     prev_attn_sum = attn_weights[0, :, start:end, :start].sum(1)
                #     cur_attn_sum = attn_weights[0, :, start + shift + window_size:end + shift + window_size, :start].sum(1)
                #     prev_attn_sum_threshold = prev_attn_sum > (threshold * window_size)
                #     cur_attn_sum_threshold = cur_attn_sum > (threshold * window_size)
                #     activation_overlap = (prev_attn_sum_threshold & cur_attn_sum_threshold).sum(-1)
                #     activation_sum = cur_attn_sum_threshold.sum(-1)
                #     hit_rate = activation_overlap / activation_sum
                #     hit_rate = hit_rate.mean()

                #     features_per_data.append(hit_rate.item())
                # total_len = attn_weights.shape[-1]
                # for step in range(steps):
                #     activation_overlaps = []
                #     for channel_id in range(attn_weights.shape[1]):
                #         start = prev_len - window_size
                #         end = prev_len
                #         shift = window_size * step
                #         prev_attn_sum = attn_weights[0, channel_id, start:end, :start].sum(0)
                #         cur_attn_sum = attn_weights[0, channel_id, start + shift + window_size:end + shift + window_size, :start].sum(0)
                #         activation_overlap = ((prev_attn_sum > threshold) & (cur_attn_sum > threshold)).sum() / (cur_attn_sum > threshold).sum()
                #         # skip if the overlap ratio is NaN (no activations above threshold)
                #         if not torch.isnan(activation_overlap):
                #             activation_overlaps.append(activation_overlap.item())

                #     features_per_data.append(np.mean(activation_overlaps))
        # Append the per-layer features for this example as one JSONL line.
        with open(args.output_path, 'a') as f:
            f.write(json.dumps(features_per_data) + '\n')
        # np.save(args.output_path, np.array(features))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--model_name", type=str, default='mistralai/Mistral-7B-Instruct-v0.2')
    parser.add_argument("--window_size", type=int, default=128)
    parser.add_argument("--steps", type=int, default=4)
    parser.add_argument("--threshold", type=float, default=0.005)
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()
    main(args)
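
# Example invocation (a sketch only; the script and file names below are hypothetical
# placeholders, not names defined by this repository):
#   python extract_attention_features.py \
#       --dataset_path data/conversations.jsonl \
#       --output_path features.jsonl \
#       --window_size 128 --steps 4 --threshold 0.005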