Commit de34fe0

refactor opt initialization
1 parent 765345f commit de34fe0

1 file changed: 33 additions & 52 deletions


deepmd/pt/train/training.py

@@ -741,72 +741,53 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
 
         # TODO add optimizers for multitask
         # author: iProzd
-        if self.opt_type in ["Adam", "AdamW"]:
+        if self.opt_type == "LKF":
+            self.optimizer = LKFOptimizer(
+                self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
+            )
+        else:
+            # === Common path for gradient-based optimizers ===
             adam_betas = (
                 float(self.opt_param["adam_beta1"]),
                 float(self.opt_param["adam_beta2"]),
             )
             weight_decay = float(self.opt_param["weight_decay"])
-            optimizer_class = (
-                torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
-            )
+
+            if self.opt_type in ("Adam", "AdamW"):
+                cls = torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
+                extra = {"betas": adam_betas, "fused": DEVICE.type != "cpu"}
+            elif self.opt_type == "AdaMuon":
+                cls = AdaMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                }
+            elif self.opt_type == "HybridMuon":
+                cls = HybridMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                    "muon_2d_only": bool(self.opt_param["muon_2d_only"]),
+                    "min_2d_dim": int(self.opt_param["min_2d_dim"]),
+                }
+            else:
+                raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
+
             self.optimizer = self._create_optimizer(
-                optimizer_class,
+                cls,
                 lr=self.lr_exp.start_lr,
-                betas=adam_betas,
                 weight_decay=weight_decay,
-                fused=DEVICE.type != "cpu",
-            )
-            self._load_optimizer_state(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
-            )
-        elif self.opt_type == "LKF":
-            self.optimizer = LKFOptimizer(
-                self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
-            )
-        elif self.opt_type == "AdaMuon":
-            self.optimizer = self._create_optimizer(
-                AdaMuonOptimizer,
-                lr=self.lr_exp.start_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-            )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
-            )
-        elif self.opt_type == "HybridMuon":
-            self.optimizer = self._create_optimizer(
-                HybridMuonOptimizer,
-                lr=self.lr_exp.start_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-                muon_2d_only=bool(self.opt_param["muon_2d_only"]),
-                min_2d_dim=int(self.opt_param["min_2d_dim"]),
+                **extra,
             )
             self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
             )
-        else:
-            raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
         if self.zero_stage > 0 and self.rank == 0:
             if self.zero_stage == 1:
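
For readers who want the shape of the refactor outside the diff, here is a minimal, self-contained sketch of the pattern this commit introduces: pick an optimizer class plus its type-specific keyword arguments, build the optimizer through one shared path, and attach the warm-up LambdaLR scheduler. The build_optimizer helper, the simplified warm_up_linear, and the restriction to Adam/AdamW are illustrative assumptions; deepmd's actual _create_optimizer, parameter grouping, and the LKF/Muon optimizer classes are not reproduced here.

# Minimal standalone sketch (not deepmd's actual implementation): only Adam/AdamW
# are wired up, and warm_up_linear below is a simplified placeholder for the real
# schedule in deepmd/pt/train/training.py.
import torch


def warm_up_linear(step: int, warmup_steps: int) -> float:
    # Simplified: ramp the lr factor linearly from 0 to 1 over warmup_steps.
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)


def build_optimizer(model, opt_type, opt_param, start_lr, warmup_steps=0, start_step=0):
    adam_betas = (float(opt_param["adam_beta1"]), float(opt_param["adam_beta2"]))
    weight_decay = float(opt_param["weight_decay"])

    # Dispatch: choose the optimizer class and its type-specific keyword arguments,
    # mirroring the cls/extra pair built in the refactored branch.
    if opt_type in ("Adam", "AdamW"):
        cls = torch.optim.Adam if opt_type == "Adam" else torch.optim.AdamW
        extra = {"betas": adam_betas}
    else:
        raise ValueError(f"Not supported optimizer type '{opt_type}'")

    # One shared construction path, analogous to self._create_optimizer(cls, ..., **extra).
    optimizer = cls(model.parameters(), lr=start_lr, weight_decay=weight_decay, **extra)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: warm_up_linear(step + start_step, warmup_steps),
    )
    return optimizer, scheduler


# Usage: a toy model with hypothetical hyperparameters.
model = torch.nn.Linear(4, 2)
opt_param = {"adam_beta1": 0.9, "adam_beta2": 0.999, "weight_decay": 0.01}
optimizer, scheduler = build_optimizer(
    model, "AdamW", opt_param, start_lr=1e-3, warmup_steps=100
)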
