@@ -741,72 +741,53 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
 
         # TODO add optimizers for multitask
         # author: iProzd
-        if self.opt_type in ["Adam", "AdamW"]:
+        if self.opt_type == "LKF":
+            self.optimizer = LKFOptimizer(
+                self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
+            )
+        else:
+            # === Common path for gradient-based optimizers ===
             adam_betas = (
                 float(self.opt_param["adam_beta1"]),
                 float(self.opt_param["adam_beta2"]),
             )
             weight_decay = float(self.opt_param["weight_decay"])
-            optimizer_class = (
-                torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
-            )
+
+            if self.opt_type in ("Adam", "AdamW"):
+                cls = torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
+                extra = {"betas": adam_betas, "fused": DEVICE.type != "cpu"}
+            elif self.opt_type == "AdaMuon":
+                cls = AdaMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                }
+            elif self.opt_type == "HybridMuon":
+                cls = HybridMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                    "muon_2d_only": bool(self.opt_param["muon_2d_only"]),
+                    "min_2d_dim": int(self.opt_param["min_2d_dim"]),
+                }
+            else:
+                raise ValueError(f"Unsupported optimizer type '{self.opt_type}'")
+
             self.optimizer = self._create_optimizer(
-                optimizer_class,
+                cls,
                 lr=self.lr_exp.start_lr,
-                betas=adam_betas,
                 weight_decay=weight_decay,
-                fused=DEVICE.type != "cpu",
-            )
-            self._load_optimizer_state(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
-            )
-        elif self.opt_type == "LKF":
-            self.optimizer = LKFOptimizer(
-                self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
-            )
-        elif self.opt_type == "AdaMuon":
-            self.optimizer = self._create_optimizer(
-                AdaMuonOptimizer,
-                lr=self.lr_exp.start_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-            )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
-            )
-        elif self.opt_type == "HybridMuon":
-            self.optimizer = self._create_optimizer(
-                HybridMuonOptimizer,
-                lr=self.lr_exp.start_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-                muon_2d_only=bool(self.opt_param["muon_2d_only"]),
-                min_2d_dim=int(self.opt_param["min_2d_dim"]),
+                **extra,
             )
             self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
             )
-        else:
-            raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
         if self.zero_stage > 0 and self.rank == 0:
             if self.zero_stage == 1:
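
The refactor above collapses the per-optimizer construction sites into one: each branch now only selects an optimizer class and its type-specific keyword arguments, and a single `_create_optimizer` call builds the instance. Below is a minimal, self-contained sketch of that pattern; `make_optimizer` and its parameter dict are hypothetical stand-ins, only `torch.optim.Adam`/`torch.optim.AdamW` are real APIs, and the project-specific `AdaMuonOptimizer`, `HybridMuonOptimizer`, and `_create_optimizer` are not reproduced. The `fused` flag, which the diff only requests off-CPU, is omitted here for portability.

```python
# Sketch of the class-plus-extras dispatch pattern used in the diff.
import torch


def make_optimizer(opt_type: str, params, opt_param: dict, lr: float):
    # Shared hyperparameters, computed once for every gradient-based optimizer.
    adam_betas = (
        float(opt_param["adam_beta1"]),
        float(opt_param["adam_beta2"]),
    )
    weight_decay = float(opt_param["weight_decay"])

    # Each branch only picks the class and its type-specific kwargs.
    if opt_type in ("Adam", "AdamW"):
        cls = torch.optim.Adam if opt_type == "Adam" else torch.optim.AdamW
        extra = {"betas": adam_betas}
    else:
        raise ValueError(f"Unsupported optimizer type '{opt_type}'")

    # Single construction site, mirroring the diff's _create_optimizer call.
    return cls(params, lr=lr, weight_decay=weight_decay, **extra)


model = torch.nn.Linear(4, 2)
opt = make_optimizer(
    "AdamW",
    model.parameters(),
    {"adam_beta1": 0.9, "adam_beta2": 0.999, "weight_decay": 0.01},
    lr=1e-3,
)
print(type(opt).__name__)  # -> AdamW
```

Keeping `lr` and `weight_decay` in the shared call while type-specific settings travel in `extra` means supporting a new optimizer only requires one new `elif` branch rather than another full construction block.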
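Every gradient-based branch also attaches the same `LambdaLR` warm-up schedule. The sketch below shows how that wiring behaves; the body of `warm_up_linear` is an assumption (a linear ramp from 0 to 1 over `warmup_steps`), since only its signature appears in the hunk header, and `start_step`/`warmup_steps` stand in for the trainer attributes.

```python
# Hedged sketch of the LambdaLR warm-up wiring from the diff.
import torch


def warm_up_linear(step: int, warmup_steps: int) -> float:
    # Assumed behavior: scale the base LR linearly up to 1.0, then hold.
    return min(step / warmup_steps, 1.0) if warmup_steps > 0 else 1.0


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
start_step, warmup_steps = 0, 100  # start_step > 0 resumes mid-warm-up
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: warm_up_linear(step + start_step, warmup_steps),
)
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [5e-05]: 5% of the way through warm-up
```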