benchmark_input_reduction.py
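"""
Benchmark how GPU memory use (and generation throughput) scale with vocabulary
size for swiss-ai/Apertus-8B-Instruct-2509, by swapping the model's
embed_tokens and lm_head layers for smaller, randomly initialised replacements
at a series of vocabulary fractions.
"""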
import torch
import time
import gc
import statistics
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
MODEL_ID = "swiss-ai/Apertus-8B-Instruct-2509"
# Fractions of the original vocabulary size to benchmark
PERCENTAGES = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.0]
NUM_TOKENS_TO_GENERATE = 100   # tokens generated per timed run
WARMUP_TOKENS = 10             # tokens generated in the untimed warmup
NUM_RUNS = 10                  # timed runs per vocabulary size


def get_total_memory_gb():
    """Return the summed peak allocated memory across all visible GPUs, in GB."""
    total_mem = 0
    for i in range(torch.cuda.device_count()):
        total_mem += torch.cuda.max_memory_allocated(i)
    return total_mem / (1024 ** 3)


def reset_all_memory_stats():
    """Reset the peak-memory counters on every GPU and release cached blocks."""
    for i in range(torch.cuda.device_count()):
        torch.cuda.reset_peak_memory_stats(i)
        torch.cuda.empty_cache()


def benchmark_vocab_scaling():
    print(f"Loading model: {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    input_text = "The future of artificial intelligence in Switzerland is"
    base_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    original_vocab_size = model.config.vocab_size
    hidden_size = model.model.embed_tokens.embedding_dim

    print(f"\nOriginal Vocab Size: {original_vocab_size}")
    print("Testing full vocab (Input+Output) scaling.")
    print("-" * 90)
    print(f"{'Percent':<8} | {'Vocab':<8} | {'Total Mem (GB)':<15} | {'Tok/s (median)':<15}")
    print("-" * 90)
    for p in PERCENTAGES:
        # Capture devices before deleting
        input_device = model.model.embed_tokens.weight.device
        head_device = model.lm_head.weight.device
        new_vocab_size = int(original_vocab_size * p)

        # 1. DELETE OLD LAYERS
        del model.model.embed_tokens
        del model.lm_head
        gc.collect()
        torch.cuda.empty_cache()

        # 2. CREATE NEW LAYERS (Both must match new_vocab_size)
        model.model.embed_tokens = torch.nn.Embedding(
            new_vocab_size,
            hidden_size,
            device=input_device,
            dtype=torch.bfloat16
        )
        model.lm_head = torch.nn.Linear(
            hidden_size,
            new_vocab_size,
            bias=False,
            device=head_device,
            dtype=torch.bfloat16
        )
        # Update config so generation knows the limit
        model.config.vocab_size = new_vocab_size

        # 3. Clamp the input ids into the reduced vocabulary range
        safe_ids = base_inputs.input_ids % new_vocab_size
        current_inputs = {
            "input_ids": safe_ids,
            "attention_mask": base_inputs.attention_mask
        }
        # 4. Reset Stats
        reset_all_memory_stats()

        # 5. Warmup (untimed)
        try:
            with torch.no_grad():
                _ = model.generate(
                    **current_inputs,
                    max_new_tokens=WARMUP_TOKENS,
                    min_new_tokens=WARMUP_TOKENS,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )
        except Exception as e:
            print(f"Error during warmup: {e}")
            continue
        # 6. Benchmark Loop: time each run and record decoding throughput
        run_throughputs = []
        for _ in range(NUM_RUNS):
            torch.cuda.synchronize()
            start = time.perf_counter()
            with torch.no_grad():
                _ = model.generate(
                    **current_inputs,
                    max_new_tokens=NUM_TOKENS_TO_GENERATE,
                    min_new_tokens=NUM_TOKENS_TO_GENERATE,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )
            torch.cuda.synchronize()
            run_throughputs.append(NUM_TOKENS_TO_GENERATE / (time.perf_counter() - start))

        total_mem = get_total_memory_gb()
        median_tps = statistics.median(run_throughputs)
        print(f"{p*100:>6.0f}% | {new_vocab_size:<8} | {total_mem:<15.4f} | {median_tps:<15.2f}")


if __name__ == "__main__":
    benchmark_vocab_scaling()