@@ -125,6 +125,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer", {})
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
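For context, the new top-level `optimizer` section of the input config might look like the sketch below. The key names mirror the ones `get_opt_param` consumes further down, but the exact schema and all defaults come from `argcheck.normalize()`; the values here are purely illustrative.

```python
# Hypothetical input fragment; defaults are filled in by
# argcheck.normalize() before the trainer reads this section.
config = {
    "model": {},     # model definition, unchanged
    "training": {},  # training schedule, unchanged
    "optimizer": {
        "type": "AdamW",  # Adam / AdamW / LKF / AdaMuon / HybridMuon
        "weight_decay": 0.001,
        "adam_beta1": 0.9,
        "adam_beta2": 0.95,
    },
}
```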
@@ -157,25 +158,17 @@ def __init__(
         self.lcurve_should_print_header = True
 
         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-            opt_type = params.get("opt_type", "Adam")
-            opt_param = {
-                # LKF parameters
-                "kf_blocksize": params.get("kf_blocksize", 5120),
-                "kf_start_pref_e": params.get("kf_start_pref_e", 1),
-                "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
-                "kf_start_pref_f": params.get("kf_start_pref_f", 1),
-                "kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
-                # Common parameters
-                "weight_decay": params.get("weight_decay", 0.001),
-                # Muon/AdaMuon parameters
-                "momentum": params.get("momentum", 0.95),
-                "adam_beta1": params.get("adam_beta1", 0.9),
-                "adam_beta2": params.get("adam_beta2", 0.95),
-                "lr_adjust": params.get("lr_adjust", 10.0),
-                "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
-                "muon_2d_only": params.get("muon_2d_only", True),
-                "min_2d_dim": params.get("min_2d_dim", 1),
-            }
+            """
+            Extract optimizer parameters.
+
+            Note: Default values are already filled by argcheck.normalize()
+            before this function is called.
+            """
+            opt_type = params.get("type", "Adam")
+            if opt_type not in ("Adam", "AdamW", "LKF", "AdaMuon", "HybridMuon"):
+                raise ValueError(f"Unsupported optimizer type '{opt_type}'")
+            opt_param = dict(params)
+            opt_param.pop("type", None)
             return opt_type, opt_param
 
         def cycle_iterator(iterable: Iterable) -> Generator[Any, None, None]:
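With this hunk, `get_opt_param` stops enumerating every supported option: it validates the optimizer type and passes all remaining keys through untouched, so new options only need to be declared in argcheck. A quick sketch of the resulting behavior, on hypothetical input:

```python
# Assumes the dict was already normalized by argcheck.normalize().
opt_type, opt_param = get_opt_param(
    {"type": "AdaMuon", "momentum": 0.95, "weight_decay": 0.001}
)
assert opt_type == "AdaMuon"
assert opt_param == {"momentum": 0.95, "weight_decay": 0.001}  # "type" stripped
```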
@@ -279,7 +272,7 @@ def get_sample() -> Any:
             return get_sample
 
         def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["stop_steps"] = self.num_steps
             lr_schedule = BaseLR(**lr_params)
             return lr_schedule
 
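The decay schedule now spans the full run instead of excluding the warmup window; warmup is applied separately through the `LambdaLR` wrapper set up below. With hypothetical numbers:

```python
# num_steps = 1_000_000, warmup_steps = 10_000
# before: lr_params["stop_steps"] == 990_000    # decay window excluded warmup
# after:  lr_params["stop_steps"] == 1_000_000  # decay spans the whole run
```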
@@ -299,7 +292,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
                     self.optim_dict[model_key]
                 )
         else:
-            self.opt_type, self.opt_param = get_opt_param(training_params)
+            self.opt_type, self.opt_param = get_opt_param(optimizer_params)
 
         # loss_param_tmp for Hessian activation
         loss_param_tmp = None
@@ -712,20 +705,38 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
 
         # TODO add optimizers for multitask
         # author: iProzd
-        if self.opt_type in ["Adam", "AdamW"]:
-            if self.opt_type == "Adam":
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
-                    fused=False if DEVICE.type == "cpu" else True,
-                )
-            else:
-                self.optimizer = torch.optim.AdamW(
-                    self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
-                    weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=False if DEVICE.type == "cpu" else True,
-                )
+        if self.opt_type == "Adam":
+            adam_betas = (
+                float(self.opt_param["adam_beta1"]),
+                float(self.opt_param["adam_beta2"]),
+            )
+            weight_decay = float(self.opt_param["weight_decay"])
+            self.optimizer = torch.optim.Adam(
+                self.wrapper.parameters(),
+                lr=self.lr_exp.start_lr,
+                betas=adam_betas,
+                weight_decay=weight_decay,
+                fused=False if DEVICE.type == "cpu" else True,
+            )
+            if optimizer_state_dict is not None and self.restart_training:
+                self.optimizer.load_state_dict(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+            )
+        elif self.opt_type == "AdamW":
+            adam_betas = (
+                float(self.opt_param["adam_beta1"]),
+                float(self.opt_param["adam_beta2"]),
+            )
+            weight_decay = float(self.opt_param["weight_decay"])
+            self.optimizer = torch.optim.AdamW(
+                self.wrapper.parameters(),
+                lr=self.lr_exp.start_lr,
+                betas=adam_betas,
+                weight_decay=weight_decay,
+                fused=False if DEVICE.type == "cpu" else True,
+            )
             if optimizer_state_dict is not None and self.restart_training:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
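Each branch wraps its optimizer in the same `LambdaLR` schedule driven by `warm_up_linear`, which the hunk header shows is defined earlier in this file. As a rough sketch only (the real implementation may also fold in the decay schedule after warmup), a linear-warmup factor with that signature typically looks like:

```python
def warm_up_linear(step: int, warmup_steps: int) -> float:
    # Scale the base LR linearly from 0 to 1 over the warmup window,
    # then return 1.0 so the base schedule takes over unchanged.
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    return 1.0
```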
@@ -749,6 +760,12 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust=float(self.opt_param["lr_adjust"]),
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
+            if optimizer_state_dict is not None and self.restart_training:
+                self.optimizer.load_state_dict(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+            )
         elif self.opt_type == "HybridMuon":
             self.optimizer = HybridMuonOptimizer(
                 self.wrapper.parameters(),