# train_quantized.py
# Post-training dynamic quantization (int8) of a trained MeshGraphNet checkpoint.
import torch
import torch.quantization
import os
import json
from modulus.models.meshgraphnet import MeshGraphNet
from modulus.launch.utils import load_checkpoint, save_checkpoint
from omegaconf import DictConfig
import hydra
@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: DictConfig) -> None:
    """Apply post-training dynamic quantization (PTQ) to a trained MeshGraphNet.

    Reads the architecture hyper-parameters saved at training time from
    ``checkpoints/parameters.json``, rebuilds the network, restores the trained
    weights from the checkpoint named in ``cfg.checkpoints``, dynamically
    quantizes its Linear layers to int8, and writes the quantized state dict
    next to the original checkpoint as ``model_quantized.pt``.

    Args:
        cfg: Hydra config; must provide ``architecture.*`` hidden-dimension /
            processor settings and ``checkpoints.ckpt_path`` / ``ckpt_name``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    # Input feature sizes recorded at training time — the model cannot be
    # rebuilt without them.
    with open("checkpoints/parameters.json") as f:
        params = json.load(f)

    # Rebuild the network with the same architecture used for training.
    model = MeshGraphNet(
        params["infeat_nodes"],
        params["infeat_edges"],
        2,  # output dimension
        processor_size=cfg.architecture.processor_size,
        hidden_dim_node_encoder=cfg.architecture.hidden_dim_node_encoder,
        hidden_dim_edge_encoder=cfg.architecture.hidden_dim_edge_encoder,
        hidden_dim_processor=cfg.architecture.hidden_dim_processor,
        hidden_dim_node_decoder=cfg.architecture.hidden_dim_node_decoder,
    )
    model = model.to(device)
    model.eval()

    # Restore the trained weights.
    load_checkpoint(
        os.path.join(cfg.checkpoints.ckpt_path, cfg.checkpoints.ckpt_name),
        models=model,
        device=device,
    )

    # Dynamic PTQ converts weights to int8 and quantizes activations on the
    # fly, shrinking the model and speeding up inference for layers dominated
    # by matrix multiplies. The resulting quantized kernels are CPU-only, so
    # the model must be moved to the CPU before quantization — quantizing a
    # CUDA-resident model produces an unusable artifact.
    model = model.cpu()
    # Only nn.Linear (and RNN variants) are supported by dynamic quantization;
    # Conv layers in the spec would be silently ignored, so target Linear only.
    model_int8 = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

    # Persist the quantized weights next to the original checkpoint.
    quantized_model_path = os.path.join(cfg.checkpoints.ckpt_path, "model_quantized.pt")
    torch.save(model_int8.state_dict(), quantized_model_path)


if __name__ == "__main__":
    main()