Skip to content

Commit cade567

Browse files
committed
add motivations
1 parent eddd12e commit cade567

12 files changed

+4552
-0
lines changed

.DS_Store

6 KB
Binary file not shown.

observations/categorize_prompts.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import transformers
2+
import torch
3+
import json
4+
from fastchat.model import load_model, get_conversation_template
5+
import argparse
6+
from tqdm import tqdm
7+
# load jsonl dataset
def load_jsonl(file_path):
    """Load a JSONL file, skipping lines that are not valid JSON.

    Args:
        file_path: Path to a file containing one JSON document per line.

    Returns:
        A list of the successfully parsed JSON values, in file order.
    """
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Malformed lines are skipped on purpose (best-effort load);
                # the bare `except` was narrowed so unrelated errors (e.g.
                # KeyboardInterrupt, I/O failures) still surface.
                continue
    return data
17+
18+
def classify_per_data(dataset, idx, tokenizer, model_name):
    """Compute per-turn token-offset spans for one conversation record.

    The conversation prompt is grown turn by turn and re-tokenized; each
    recorded (start, end) pair is the tokenized prompt length just before
    and just after the assistant reply of that turn is appended.

    Args:
        dataset: List of records, each with 'data' (alternating user /
            assistant utterances) and 'id'.
        idx: Index of the record to process.
        tokenizer: HuggingFace tokenizer used to measure prompt length.
        model_name: Passed to fastchat to select the conversation template.

    Returns:
        Dict with 'max_turns', 'idx', 'id', and 'turns' — a list of
        (start, end) token counts, one pair per turn.
    """
    record = dataset[idx]
    utterances = record['data']
    turn_count = len(utterances) // 2
    conv = get_conversation_template(model_name)
    user_role, assistant_role = conv.roles[0], conv.roles[1]

    def prompt_len():
        # Token count of the conversation as rendered so far.
        return tokenizer.encode(conv.get_prompt(), return_tensors='pt').shape[1]

    spans = []
    for turn in range(turn_count):
        conv.append_message(user_role, utterances[2 * turn])
        begin = prompt_len()
        conv.append_message(assistant_role, utterances[2 * turn + 1])
        spans.append((begin, prompt_len()))

    return {
        'max_turns': turn_count,
        'idx': idx,
        'id': record['id'],
        'turns': spans,
    }
36+
37+
def main(args):
    """Categorize every conversation in the dataset, appending JSONL output.

    Results are appended one line at a time so partial progress survives
    an interrupted run.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)
    dataset = load_jsonl(args.dataset_path)
    for idx in tqdm(range(len(dataset))):
        record_info = classify_per_data(dataset, idx, tokenizer, args.model_name)
        # One JSON object per line; a crash loses at most the current record.
        with open(args.output_path, 'a') as out_file:
            out_file.write(json.dumps(record_info) + '\n')
46+
47+
if __name__ == "__main__":
    # Command-line entry point.
    parser = argparse.ArgumentParser(description="Categorize prompts in a dataset")
    for flag, help_text in (
        ('--dataset_path', 'path to the dataset'),
        ('--output_path', 'path to the output'),
        ('--model_name', 'model name'),
    ):
        parser.add_argument(flag, type=str, help=help_text)
    main(parser.parse_args())
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import json
2+
import argparse
3+
# load jsonl dataset
def load_jsonl(file_path):
    """Load a JSONL file, skipping lines that are not valid JSON.

    Args:
        file_path: Path to a file containing one JSON document per line.

    Returns:
        A list of the successfully parsed JSON values, in file order.
    """
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Malformed lines are skipped on purpose (best-effort load);
                # the bare `except` was narrowed so unrelated errors still
                # surface instead of being silently swallowed.
                continue
    return data
13+
14+
def collect_prompt(dataset_id, data_info, prev_range=(2000, 2500), min_length=64 * 8):
    """Select (idx, turn) pairs whose turn falls inside a length window.

    Args:
        dataset_id: Identifier recorded on every selected entry.
        data_info: Iterable of dicts with 'idx' and 'turns', where 'turns'
            is a list of (start, end) token offsets, one per turn.
        prev_range: (lo, hi) window; a turn is kept only when its start
            offset lies strictly inside it. The default is a tuple rather
            than a list so the shared default cannot be mutated across
            calls (mutable-default-argument pitfall); list arguments from
            existing callers still work.
        min_length: A turn is kept only when end - start strictly exceeds
            this token count.

    Returns:
        List of dicts: {'idx', 'turn', 'start', 'end', 'dataset_id'}.
    """
    lo, hi = prev_range
    prompt_list = []
    for info in data_info:
        for i, (start, end) in enumerate(info['turns']):
            if lo < start < hi and end - start > min_length:
                prompt_list.append({'idx': info['idx'], 'turn': i, 'start': start,
                                    'end': end, 'dataset_id': dataset_id})
    return prompt_list
22+
23+
# main
def main(args):
    """Filter categorized turn info into per-length-bucket prompt files.

    For each dataset shard and each length window, prompts whose start
    offset falls inside the window (and whose turn is long enough) are
    written to ./data/filtered_info/ as pretty-printed JSON.
    """
    bucket_count = (args.length_end - args.length_start) // args.length_step + 1
    for dataset_id in range(args.dataset_id_begin, args.dataset_id_end + 1):
        shard_path = args.data_info_path.replace('DATASET_ID', str(dataset_id))
        data_info = load_jsonl(shard_path)
        for step in range(bucket_count):
            cur_start = args.length_start + step * args.length_step
            # The final bucket is open-ended (capped with a large sentinel).
            if step < bucket_count - 1:
                cur_end = cur_start + args.length_step
            else:
                cur_end = 100000
            prompt_list = collect_prompt(dataset_id, data_info, [cur_start, cur_end], args.min_length)
            with open(f'./data/filtered_info/prompt_{dataset_id}_len_{cur_start}_{cur_end}.json', 'w') as out_file:
                json.dump(prompt_list, out_file, indent=2)
41+
42+
if __name__ == '__main__':
    # Command-line entry point; defaults mirror the ultrachat shard layout.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_info_path', type=str,
                        default='./data/ultrachat_DATASET_ID_categorized.jsonl')
    int_defaults = {
        '--dataset_id_begin': 0,
        '--dataset_id_end': 9,
        '--min_length': 64 * 8,
        '--length_start': 1000,
        '--length_end': 3000,
        '--length_step': 500,
    }
    for flag, value in int_defaults.items():
        parser.add_argument(flag, type=int, default=value)
    main(parser.parse_args())

observations/collect_draw.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import json
4+
import matplotlib.lines as mlines
5+
6+
def draw_ablation(features, steps, avg_prompt_len, avg_turn, avg_context_len, total_num, save_path):
    """Plot per-layer hit rates for each window and save the figure.

    Args:
        features: Array of shape (layers, steps); features[:, i] is the
            per-layer hit-rate curve for window i.
        steps: Number of windows (columns of `features`) to plot.
        avg_prompt_len: Average prompt length, shown in the legend.
        avg_turn: Average turn count, shown in the legend.
        avg_context_len: Average context length, shown in the legend.
        total_num: Total sample count, shown in the legend.
        save_path: Destination image path.
    """
    plt.figure(figsize=(10, 5))
    for i in range(steps):
        plt.plot(features[:, i], label=f'window {i}')

    plt.ylim(0, 1.1)
    plt.grid()

    # Invisible-line legend entries used purely to display summary stats.
    custom_entries = [
        mlines.Line2D([], [], color='none', marker='None', linestyle='None', label=f'Avg Prompt Len: {avg_prompt_len}'),
        mlines.Line2D([], [], color='none', marker='None', linestyle='None', label=f'Avg Turn: {avg_turn}'),
        mlines.Line2D([], [], color='none', marker='None', linestyle='None', label=f'Avg Context Len: {avg_context_len}'),
        mlines.Line2D([], [], color='none', marker='None', linestyle='None', label=f'Total Num: {total_num}')
    ]

    # Combine existing handles (if any) with the custom stat entries.
    handles, labels = plt.gca().get_legend_handles_labels()
    handles.extend(custom_entries)
    plt.legend(handles=handles)

    plt.title('Hit rates for different windows')
    plt.xlabel('Layer')
    plt.ylabel('Hit rate (%)')
    plt.tight_layout()  # Adjust layout to make room for the legend

    plt.savefig(save_path)
    # Close the figure after saving: repeated calls (main loops over several
    # inputs) otherwise accumulate open figures and leak memory.
    plt.close()
35+
36+
def main():
    """Render one ablation figure per (feature file, summary file) pair."""
    feature_paths = [
        './data/features_finegrained/features_1000_1500_step_128.jsonl',
        './data/features_finegrained/features_1500_2000_step_128.jsonl',
        './data/features_finegrained/features_2000_2500_step_128.jsonl',
        './data/features_finegrained/features_2500_3000_step_128.jsonl',
        './data/features_finegrained/features_3000_100000_step_128.jsonl',
    ]

    data_paths = [
        './data/random_prompts/random_prompt_1000_1500_summary.json',
        './data/random_prompts/random_prompt_1500_2000_summary.json',
        './data/random_prompts/random_prompt_2000_2500_summary.json',
        './data/random_prompts/random_prompt_2500_3000_summary.json',
        './data/random_prompts/random_prompt_3000_100000_summary.json',
    ]

    layers = 32
    steps = 4

    for i, (feat_path, summary_path) in enumerate(zip(feature_paths, data_paths)):
        with open(feat_path, 'r') as f:
            rows = [json.loads(line) for line in f]
        # Average the per-sample (layers x steps) grids over all samples.
        features = np.array(rows).reshape(-1, layers, steps).mean(axis=0)
        print(features.shape)
        with open(summary_path, 'r') as f:
            summary = json.load(f)

        draw_ablation(features, steps, summary['avg_prompt_len'], summary['avg_turn'],
                      summary['avg_context_len'], summary['total_num'],
                      f'./data/figures/ablation_finegrained_{i}.png')


if __name__ == '__main__':
    main()

observations/collect_features.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import json
2+
from fastchat.model import load_model, get_conversation_template
3+
from models.modeling_mistral_benchmark import MistralForCausalLM as MyMistralForCausalLM
4+
import torch
5+
import transformers
6+
from tqdm import tqdm
7+
import numpy as np
8+
import argparse
9+
10+
def load_jsonl(file_path):
    """Load a JSONL file, skipping lines that are not valid JSON.

    Args:
        file_path: Path to a file containing one JSON document per line.

    Returns:
        A list of the successfully parsed JSON values, in file order.
    """
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Malformed lines are skipped on purpose (best-effort load);
                # the bare `except` was narrowed so unrelated errors still
                # surface instead of being silently swallowed.
                continue
    return data
19+
20+
def load_data(data, model_name, tokenizer):
    """Build the full conversation prompt and locate its final turn.

    All turns except the last are appended verbatim; the tokenized prompt
    length is measured just before and just after the final assistant
    message to delimit that message in token space.

    Args:
        data: Alternating user / assistant utterances.
        model_name: Passed to fastchat to select the conversation template.
        tokenizer: HuggingFace tokenizer used to measure prompt length.

    Returns:
        (prompt, prev_len, end): the rendered prompt, the token index one
        before the final assistant message begins, and the token length of
        the complete prompt.
    """
    conv = get_conversation_template(model_name)
    user_role, assistant_role = conv.roles[0], conv.roles[1]
    last = len(data) // 2 - 1

    for turn in range(last):
        conv.append_message(user_role, data[2 * turn])
        conv.append_message(assistant_role, data[2 * turn + 1])

    def token_len():
        return tokenizer.encode(conv.get_prompt(), return_tensors='pt').shape[1]

    conv.append_message(user_role, data[2 * last])
    start = token_len()
    conv.append_message(assistant_role, data[2 * last + 1])
    end = token_len()
    # NOTE(review): `start - 1` differs from the `start` convention used by
    # categorize_prompts.classify_per_data — presumably intentional for the
    # downstream prev_len probe; confirm.
    return conv.get_prompt(), start - 1, end
31+
32+
def main(args):
    """Measure per-layer attention features for every conversation in a dataset.

    The patched Mistral attention modules are configured with the probe
    parameters (prev_len / steps / threshold / window_size) before each
    forward pass; afterwards each layer's recorded `features_per_data` is
    collected and appended to the output JSONL, one line per sample.
    """
    dataset = load_jsonl(args.dataset_path)
    model = MyMistralForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        # use_flash_attention_2=True
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name,
        padding_side="right",
        use_fast=False,
    )

    for data in tqdm(dataset):
        per_data_features = []
        with torch.inference_mode():
            prompt, prev_len, end = load_data(data, args.model_name, tokenizer)
            input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
            # Push the probe configuration into every attention layer.
            for layer in model.model.layers:
                attn = layer.self_attn
                attn.prev_len = prev_len
                attn.steps = args.steps
                attn.threshold = args.threshold
                attn.window_size = args.window_size
            # Forward pass; the patched attention records its features as a
            # side effect, so the output itself is not needed.
            model(input_ids, output_attentions=False, use_cache=False)
            for layer in model.model.layers:
                per_data_features.append(layer.self_attn.features_per_data)
        # Append one JSON line per sample so interrupted runs keep progress.
        with open(args.output_path, 'a') as f:
            f.write(json.dumps(per_data_features) + '\n')
104+
105+
if __name__ == "__main__":
    # Command-line entry point.
    parser = argparse.ArgumentParser()
    # Input / output locations.
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    # Model and probe configuration.
    parser.add_argument("--model_name", type=str, default='mistralai/Mistral-7B-Instruct-v0.2')
    parser.add_argument("--window_size", type=int, default=128)
    parser.add_argument("--steps", type=int, default=4)
    parser.add_argument("--threshold", type=float, default=0.005)
    main(parser.parse_args())

0 commit comments

Comments
 (0)