apply_language_filter.py
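# Prune the vocabulary of swiss-ai/Apertus-8B-Instruct-2509 to a single target
# language: keep only that language's token IDs (plus IDs 0-999 for special tokens),
# slice the embedding and LM-head matrices accordingly, then save the reduced model,
# tokenizer, index map and a standalone loader script.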
import gc
import os
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import update_tokenizer_vocab

MODEL_ID = "swiss-ai/Apertus-8B-Instruct-2509"
LANGUAGE_FILE = "token_analysis/artifacts/final_language_tokens.json"
DEVICE = "cuda:0"
TARGET_LANGUAGE = "eng"
OUTPUT_DIR = f"./models/apertus-8b-pruned-{TARGET_LANGUAGE}"
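
# Standalone loader script saved next to the pruned model. It loads index_map.pt
# (new ID -> original ID), translates prompt IDs into the reduced ID space, and
# maps generated IDs back to original IDs for decoding.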
LOADER_SCRIPT_CONTENT = '''import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer


class PrunedModelAdapter(torch.nn.Module):
    def __init__(self, model_path, device="cuda:0"):
        super().__init__()
        print(f"Loading Pruned Adapter from {model_path}...")
        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map=device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        map_path = os.path.join(model_path, "index_map.pt")
        self.index_map = torch.load(map_path).to(device)
        print("Building input translation map...")
        self.inverse_map = torch.full((self.tokenizer.vocab_size,), -1, dtype=torch.long, device=device)
        new_ids = torch.arange(len(self.index_map), device=device)
        self.inverse_map[self.index_map] = new_ids
        self.device = device
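
    # Greedy decoding in the reduced ID space; generated IDs are mapped back to
    # original tokenizer IDs before decoding.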
    def generate(self, prompt, max_new_tokens=20):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        old_ids = inputs.input_ids
        new_ids = self.inverse_map[old_ids]
        if (new_ids == -1).any():
            print("Warning: Prompt contains tokens that were pruned from this model!")
            new_ids[new_ids == -1] = 0
        curr_ids = new_ids
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(curr_ids)
            next_token_logits = outputs.logits[0, -1, :]
            new_next_id = torch.argmax(next_token_logits).item()
            curr_ids = torch.cat([curr_ids, torch.tensor([[new_next_id]], device=self.device)], dim=1)
            old_next_id = self.index_map[new_next_id].item()
            if old_next_id == self.tokenizer.eos_token_id:
                break
        final_new_ids = curr_ids[0]
        final_old_ids = self.index_map[final_new_ids]
        return self.tokenizer.decode(final_old_ids)
'''


def create_language_model():
    print(f"1. Loading Language Map from {LANGUAGE_FILE}...")
    with open(LANGUAGE_FILE, "r") as f:
        lang_data = json.load(f)

    if TARGET_LANGUAGE not in lang_data:
        raise ValueError(f"Language '{TARGET_LANGUAGE}' not found in {LANGUAGE_FILE}. Available: {list(lang_data.keys())}")
    print(f" Target Language: {TARGET_LANGUAGE}")

    keep_indices = lang_data[TARGET_LANGUAGE]
    print(" Ensuring Special Tokens (0-999) are present...")
    keep_indices.extend(range(0, 1000))
    keep_indices = sorted(list(set(keep_indices)))
    new_vocab_size = len(keep_indices)
    print(f" Reduced Vocab Size: {new_vocab_size}")
    index_map = torch.tensor(keep_indices, device=DEVICE)

    print(f"\n2. Loading Model: {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map=DEVICE,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    print("3. Slicing the model...")
    original_head = model.lm_head
    hidden_size = original_head.in_features
    reduced_head_weights = original_head.weight.data[keep_indices, :]

    original_embeddings = model.model.embed_tokens
    reduced_embed_weights = original_embeddings.weight.data[keep_indices, :]

    del original_head
    del original_embeddings
    del model.lm_head
    del model.model.embed_tokens
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(DEVICE)

    model.lm_head = torch.nn.Linear(hidden_size, new_vocab_size, bias=False, device=DEVICE, dtype=torch.bfloat16)
    model.lm_head.weight.data = reduced_head_weights
    model.model.embed_tokens = torch.nn.Embedding(new_vocab_size, hidden_size, device=DEVICE, dtype=torch.bfloat16)
    model.model.embed_tokens.weight.data = reduced_embed_weights
    model.config.vocab_size = new_vocab_size
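
    # Rebuild the tokenizer vocabulary so surviving tokens point at their new,
    # compacted IDs.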
    original_vocab = tokenizer.get_vocab()
    old_id_to_new_id = {old_id: new_id for new_id, old_id in enumerate(keep_indices)}
    new_vocab = {}
    for token, old_id in original_vocab.items():
        if old_id in old_id_to_new_id:
            new_id = old_id_to_new_id[old_id]
            new_vocab[token] = new_id
    print(f" Success! Model pruned to {TARGET_LANGUAGE}.")

    print(f"\n4. Saving to {OUTPUT_DIR}...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    torch.save(index_map, f"{OUTPUT_DIR}/index_map.pt")

    print(f"5. Writing 'loader.py' to {OUTPUT_DIR}...")
    with open(f"{OUTPUT_DIR}/loader.py", "w") as f:
        f.write(LOADER_SCRIPT_CONTENT)

    print("6. Updating Tokenizer Vocabulary...")
    update_tokenizer_vocab(OUTPUT_DIR, new_vocab)
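
    # Smoke test: reload the pruned model and tokenizer from disk and generate
    # a short completion.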
    test_model = AutoModelForCausalLM.from_pretrained(
        OUTPUT_DIR,
        dtype=torch.bfloat16,
        device_map=DEVICE,
        trust_remote_code=True
    )
    test_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)

    test_prompt = "Hello, how are you?"
    inputs = test_tokenizer(test_prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = test_model.generate(**inputs, max_new_tokens=50)
    decoded_output = test_tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(" Test successful! Model output:\n", decoded_output)

    print("Done! You can upload the entire folder now.")


if __name__ == "__main__":
    create_language_model()