Commit 129b456

refactor(optimizer): finalize optimizer schema and backend handling
1 parent 4f182bc commit 129b456

14 files changed

Lines changed: 542 additions & 235 deletions


deepmd/pd/train/training.py

Lines changed: 19 additions & 11 deletions
@@ -115,6 +115,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer", {})
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
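
Both trainers now read optimizer settings from a top-level `optimizer` section of the input instead of from `training`. A minimal sketch of what that section might look like, using only keys that appear in this diff (the actual schema and its defaults are defined by `argcheck.normalize()`; the values below are illustrative placeholders, not the defaults):

```python
# Hypothetical input fragment: key names come from this diff, values are placeholders.
config = {
    "model": {},      # unchanged sections elided
    "training": {},
    "optimizer": {
        "type": "Adam",   # the PyTorch backend also accepts "AdamW", "LKF", "AdaMuon", "HybridMuon"
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "weight_decay": 0.0,
    },
}
optimizer_params = config.get("optimizer", {})  # mirrors the line added above
```
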
@@ -149,14 +150,17 @@ def __init__(
         self.lcurve_should_print_header = True

         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-            opt_type = params.get("opt_type", "Adam")
-            opt_param = {
-                "kf_blocksize": params.get("kf_blocksize", 5120),
-                "kf_start_pref_e": params.get("kf_start_pref_e", 1),
-                "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
-                "kf_start_pref_f": params.get("kf_start_pref_f", 1),
-                "kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
-            }
+            """
+            Extract optimizer parameters.
+
+            Note: Default values are already filled by argcheck.normalize()
+            before this function is called.
+            """
+            opt_type = params.get("type", "Adam")
+            if opt_type != "Adam":
+                raise ValueError(f"Not supported optimizer type '{opt_type}'")
+            opt_param = dict(params)
+            opt_param.pop("type", None)
             return opt_type, opt_param

         def get_data_loader(
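
As a standalone illustration (the real helper is a closure inside `Trainer.__init__` and not importable), the Paddle-backend logic above reduces to: copy the already-normalized dict, split off the `type` key, and reject anything other than Adam.

```python
from typing import Any


def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    # Mirrors the Paddle-backend helper above; defaults are assumed to have been
    # filled by argcheck.normalize() before this point.
    opt_type = params.get("type", "Adam")
    if opt_type != "Adam":
        raise ValueError(f"Not supported optimizer type '{opt_type}'")
    opt_param = dict(params)
    opt_param.pop("type", None)
    return opt_type, opt_param


print(get_opt_param({"type": "Adam", "adam_beta1": 0.9, "adam_beta2": 0.999}))
# ('Adam', {'adam_beta1': 0.9, 'adam_beta2': 0.999})
```
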
@@ -243,7 +247,7 @@ def get_sample() -> dict[str, Any]:
            return get_sample

        def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["stop_steps"] = self.num_steps
            lr_schedule = BaseLR(**lr_params)
            return lr_schedule

@@ -263,7 +267,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
                    self.optim_dict[model_key]
                )
        else:
-            self.opt_type, self.opt_param = get_opt_param(training_params)
+            self.opt_type, self.opt_param = get_opt_param(optimizer_params)

        # loss_param_tmp for Hessian activation
        loss_param_tmp = None
@@ -598,7 +602,11 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps),
            )
            self.optimizer = paddle.optimizer.Adam(
-                learning_rate=self.scheduler, parameters=self.wrapper.parameters()
+                learning_rate=self.scheduler,
+                parameters=self.wrapper.parameters(),
+                beta1=float(self.opt_param["adam_beta1"]),
+                beta2=float(self.opt_param["adam_beta2"]),
+                weight_decay=float(self.opt_param["weight_decay"]),
            )
            if optimizer_state_dict is not None and self.restart_training:
                self.optimizer.set_state_dict(optimizer_state_dict)

deepmd/pt/train/training.py

Lines changed: 52 additions & 35 deletions
@@ -125,6 +125,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer", {})
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
@@ -157,25 +158,17 @@ def __init__(
         self.lcurve_should_print_header = True

         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-            opt_type = params.get("opt_type", "Adam")
-            opt_param = {
-                # LKF parameters
-                "kf_blocksize": params.get("kf_blocksize", 5120),
-                "kf_start_pref_e": params.get("kf_start_pref_e", 1),
-                "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
-                "kf_start_pref_f": params.get("kf_start_pref_f", 1),
-                "kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
-                # Common parameters
-                "weight_decay": params.get("weight_decay", 0.001),
-                # Muon/AdaMuon parameters
-                "momentum": params.get("momentum", 0.95),
-                "adam_beta1": params.get("adam_beta1", 0.9),
-                "adam_beta2": params.get("adam_beta2", 0.95),
-                "lr_adjust": params.get("lr_adjust", 10.0),
-                "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
-                "muon_2d_only": params.get("muon_2d_only", True),
-                "min_2d_dim": params.get("min_2d_dim", 1),
-            }
+            """
+            Extract optimizer parameters.
+
+            Note: Default values are already filled by argcheck.normalize()
+            before this function is called.
+            """
+            opt_type = params.get("type", "Adam")
+            if opt_type not in ("Adam", "AdamW", "LKF", "AdaMuon", "HybridMuon"):
+                raise ValueError(f"Not supported optimizer type '{opt_type}'")
+            opt_param = dict(params)
+            opt_param.pop("type", None)
             return opt_type, opt_param

         def cycle_iterator(iterable: Iterable) -> Generator[Any, None, None]:
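
The PyTorch variant differs only in the accepted set of optimizer types, and the check is case-sensitive. A small standalone sketch of that validation (not the actual closure):

```python
SUPPORTED = ("Adam", "AdamW", "LKF", "AdaMuon", "HybridMuon")


def check_opt_type(opt_type: str) -> str:
    # Same rejection rule as the helper above; the comparison is case-sensitive.
    if opt_type not in SUPPORTED:
        raise ValueError(f"Not supported optimizer type '{opt_type}'")
    return opt_type


check_opt_type("AdamW")        # accepted
try:
    check_opt_type("adamw")    # rejected: case-sensitive match
except ValueError as err:
    print(err)
```
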
@@ -279,7 +272,7 @@ def get_sample() -> Any:
            return get_sample

        def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["stop_steps"] = self.num_steps
            lr_schedule = BaseLR(**lr_params)
            return lr_schedule

@@ -299,7 +292,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
                    self.optim_dict[model_key]
                )
        else:
-            self.opt_type, self.opt_param = get_opt_param(training_params)
+            self.opt_type, self.opt_param = get_opt_param(optimizer_params)

        # loss_param_tmp for Hessian activation
        loss_param_tmp = None
@@ -712,20 +705,38 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:

        # TODO add optimizers for multitask
        # author: iProzd
-        if self.opt_type in ["Adam", "AdamW"]:
-            if self.opt_type == "Adam":
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
-                    fused=False if DEVICE.type == "cpu" else True,
-                )
-            else:
-                self.optimizer = torch.optim.AdamW(
-                    self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
-                    weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=False if DEVICE.type == "cpu" else True,
-                )
+        if self.opt_type == "Adam":
+            adam_betas = (
+                float(self.opt_param["adam_beta1"]),
+                float(self.opt_param["adam_beta2"]),
+            )
+            weight_decay = float(self.opt_param["weight_decay"])
+            self.optimizer = torch.optim.Adam(
+                self.wrapper.parameters(),
+                lr=self.lr_exp.start_lr,
+                betas=adam_betas,
+                weight_decay=weight_decay,
+                fused=False if DEVICE.type == "cpu" else True,
+            )
+            if optimizer_state_dict is not None and self.restart_training:
+                self.optimizer.load_state_dict(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+            )
+        elif self.opt_type == "AdamW":
+            adam_betas = (
+                float(self.opt_param["adam_beta1"]),
+                float(self.opt_param["adam_beta2"]),
+            )
+            weight_decay = float(self.opt_param["weight_decay"])
+            self.optimizer = torch.optim.AdamW(
+                self.wrapper.parameters(),
+                lr=self.lr_exp.start_lr,
+                betas=adam_betas,
+                weight_decay=weight_decay,
+                fused=False if DEVICE.type == "cpu" else True,
+            )
            if optimizer_state_dict is not None and self.restart_training:
                self.optimizer.load_state_dict(optimizer_state_dict)
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
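
For reference, a self-contained sketch of the new construction path for AdamW: betas come from the extracted parameters, and a `LambdaLR` warmup is attached afterwards. The project's `warm_up_linear` is not shown in this diff, so a generic linear warmup factor stands in for it here.

```python
import torch

# Illustrative values; in the trainer they come from self.opt_param and self.lr_exp.
opt_param = {"adam_beta1": 0.9, "adam_beta2": 0.999, "weight_decay": 0.01}
model = torch.nn.Linear(4, 1)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(float(opt_param["adam_beta1"]), float(opt_param["adam_beta2"])),
    weight_decay=float(opt_param["weight_decay"]),
)

warmup_steps = 100
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # stands in for warm_up_linear(step + self.start_step, self.warmup_steps)
    lambda step: min(1.0, (step + 1) / warmup_steps),
)
```
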
@@ -749,6 +760,12 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                lr_adjust=float(self.opt_param["lr_adjust"]),
                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
            )
+            if optimizer_state_dict is not None and self.restart_training:
+                self.optimizer.load_state_dict(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+            )
        elif self.opt_type == "HybridMuon":
            self.optimizer = HybridMuonOptimizer(
                self.wrapper.parameters(),
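
The lines added here give the AdaMuon branch the same restart handling as the Adam/AdamW branches: restore the optimizer state, then attach the warmup scheduler. A minimal sketch of that state round trip, using a stock torch optimizer in place of AdaMuon:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer_state_dict = optimizer.state_dict()  # what a checkpoint would carry

# Restart path: rebuild the optimizer, restore moments and step counters,
# then wrap it in the scheduler.
restarted = torch.optim.Adam(model.parameters(), lr=1e-3)
restarted.load_state_dict(optimizer_state_dict)
scheduler = torch.optim.lr_scheduler.LambdaLR(restarted, lambda step: 1.0)
```
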

deepmd/tf/entrypoints/train.py

Lines changed: 0 additions & 1 deletion
@@ -162,7 +162,6 @@ def train(
        jdata["model"] = json.loads(t_training_script)["model"]

    jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
-
    jdata = normalize(jdata)

    if not is_compress and not skip_neighbor_stat:

deepmd/tf/train/trainer.py

Lines changed: 33 additions & 3 deletions
@@ -139,6 +139,22 @@ def get_lr_and_coef(lr_param: dict) -> Any:
        # learning rate
        lr_param = jdata["learning_rate"]
        self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param)
+        # optimizer
+        # Note: Default values are already filled by argcheck.normalize()
+        optimizer_param = jdata.get("optimizer", {})
+        self.optimizer_type = optimizer_param.get("type", "Adam")
+        self.optimizer_beta1 = float(optimizer_param.get("adam_beta1"))
+        self.optimizer_beta2 = float(optimizer_param.get("adam_beta2"))
+        self.optimizer_weight_decay = float(optimizer_param.get("weight_decay"))
+        if self.optimizer_type != "Adam":
+            raise RuntimeError(
+                f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
+            )
+        if self.optimizer_weight_decay != 0.0:
+            raise RuntimeError(
+                "TensorFlow Adam optimizer does not support weight_decay. "
+                "Set optimizer/weight_decay to 0."
+            )
        # loss
        # infer loss type by fitting_type
        loss_param = jdata.get("loss", {})
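
Because the TF1-style `tf.train.AdamOptimizer` has no decoupled weight decay, the TensorFlow backend now validates the optimizer section up front. A standalone sketch of what these checks accept and reject (not the `Trainer` itself; `weight_decay` gets a default here only to keep the example self-contained):

```python
def validate_tf_optimizer(optimizer_param: dict) -> None:
    # Standalone illustration of the new TF-backend checks.
    optimizer_type = optimizer_param.get("type", "Adam")
    weight_decay = float(optimizer_param.get("weight_decay", 0.0))
    if optimizer_type != "Adam":
        raise RuntimeError(
            f"Unsupported optimizer type {optimizer_type} for TensorFlow backend."
        )
    if weight_decay != 0.0:
        raise RuntimeError(
            "TensorFlow Adam optimizer does not support weight_decay. "
            "Set optimizer/weight_decay to 0."
        )


validate_tf_optimizer({"type": "Adam", "weight_decay": 0.0})       # passes
try:
    validate_tf_optimizer({"type": "Adam", "weight_decay": 0.01})  # rejected
except RuntimeError as err:
    print(err)
```
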
@@ -331,17 +347,31 @@ def _build_network(self, data: DeepmdDataSystem, suffix: str = "") -> None:
        log.info("built network")

    def _build_optimizer(self) -> Any:
+        if self.optimizer_type != "Adam":
+            raise RuntimeError(
+                f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
+            )
        if self.run_opt.is_distrib:
            if self.scale_lr_coef > 1.0:
                log.info("Scale learning rate by coef: %f", self.scale_lr_coef)
                optimizer = tf.train.AdamOptimizer(
-                    self.learning_rate * self.scale_lr_coef
+                    self.learning_rate * self.scale_lr_coef,
+                    beta1=self.optimizer_beta1,
+                    beta2=self.optimizer_beta2,
                )
            else:
-                optimizer = tf.train.AdamOptimizer(self.learning_rate)
+                optimizer = tf.train.AdamOptimizer(
+                    self.learning_rate,
+                    beta1=self.optimizer_beta1,
+                    beta2=self.optimizer_beta2,
+                )
            optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
        else:
-            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
+            optimizer = tf.train.AdamOptimizer(
+                learning_rate=self.learning_rate,
+                beta1=self.optimizer_beta1,
+                beta2=self.optimizer_beta2,
+            )

        if self.mixed_prec is not None:
            _TF_VERSION = Version(TF_VERSION)
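
The TF1 `AdamOptimizer` already exposes `beta1`/`beta2` keyword arguments, so forwarding the configured values is enough; it has no `weight_decay` argument, which is why the trainer hard-fails on a nonzero `weight_decay` instead of silently ignoring it. A minimal usage sketch with the TF1 compatibility API (assumes TensorFlow is installed):

```python
import tensorflow.compat.v1 as tf

# beta1/beta2 map directly onto AdamOptimizer keyword arguments.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3, beta1=0.9, beta2=0.999)
```
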

deepmd/tf/utils/compat.py

Lines changed: 2 additions & 0 deletions
@@ -4,13 +4,15 @@
 from deepmd.utils.compat import (
     convert_input_v0_v1,
     convert_input_v1_v2,
+    convert_optimizer_to_new_format,
     deprecate_numb_test,
     update_deepmd_input,
 )

 __all__ = [
     "convert_input_v0_v1",
     "convert_input_v1_v2",
+    "convert_optimizer_to_new_format",
     "deprecate_numb_test",
     "update_deepmd_input",
 ]
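
`convert_optimizer_to_new_format` is defined in `deepmd/utils/compat.py`, whose body is not part of this excerpt. Based on the legacy keys removed above (`opt_type` and `kf_*` under `training`) and the new top-level `optimizer` section with a `type` key, it plausibly relocates those keys along the following lines; this is a hedged sketch, not the actual implementation.

```python
# Hypothetical converter: move legacy optimizer keys out of "training" into a
# top-level "optimizer" section. The real convert_optimizer_to_new_format may differ.
LEGACY_KEYS = (
    "kf_blocksize",
    "kf_start_pref_e",
    "kf_limit_pref_e",
    "kf_start_pref_f",
    "kf_limit_pref_f",
)


def convert_optimizer_to_new_format_sketch(jdata: dict) -> dict:
    training = jdata.get("training", {})
    optimizer = dict(jdata.get("optimizer", {}))
    if "opt_type" in training:
        optimizer["type"] = training.pop("opt_type")
    for key in LEGACY_KEYS:
        if key in training:
            optimizer[key] = training.pop(key)
    if optimizer:
        jdata["optimizer"] = optimizer
    return jdata


print(convert_optimizer_to_new_format_sketch({"training": {"opt_type": "Adam"}}))
# {'training': {}, 'optimizer': {'type': 'Adam'}}
```
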
