22ui/tabs/train_tab.py — Full training loop: Pretrain / SFT / DPO / ORPO / GRPO / Distill
33"""
44import os
5+ import glob
56import time
67import threading
78import queue
4748 "grpo" : "./tests/fixtures/tiny_grpo.jsonl" ,
4849 "distill" : "./tests/fixtures/tiny_sft.jsonl" ,
4950}
51+ CHECKPOINT_TOPOLOGY_KEYS = (
52+ "hidden_size" ,
53+ "num_hidden_layers" ,
54+ "num_experts" ,
55+ "moe_intermediate_size" ,
56+ "vocab_size" ,
57+ "lookahead_steps" ,
58+ "kv_latent_dim" ,
59+ "rope_dim" ,
60+ "num_attention_heads" ,
61+ "num_key_value_heads" ,
62+ )
5063STAGE_HELP_TEXT = {
5164 "pretrain" : {
5265 "zh-Hans" : "从通用语料继续预训练。`init_weight` 可留空;若存在同名 checkpoint,将按当前拓扑尝试恢复。" ,
@@ -394,7 +407,36 @@ def _sniff_checkpoint(path: str) -> dict:
394407 total = int (lookahead_proj .shape [0 ])
395408 n_exp = int (out ["num_experts" ])
396409 if n_exp > 0 and total % n_exp == 0 :
397- out ["lookahead_steps" ] = total // n_exp
410+ # The router predicts current-token routing plus K future
411+ # steps, so the saved output rows are (lookahead_steps + 1)
412+ # * num_experts. Do not report the +1 as user topology.
413+ out ["lookahead_steps" ] = max (0 , total // n_exp - 1 )
414+
415+ qnope = sd .get ("model.layers.0.self_attn.q_nope_proj.weight" )
416+ qrope = sd .get ("model.layers.0.self_attn.q_rope_proj.weight" )
417+ kvdown = sd .get ("model.layers.0.self_attn.kv_down_proj.weight" )
418+ vproj = sd .get ("model.layers.0.self_attn.v_proj.weight" )
419+ if qnope is not None and qrope is not None and out .get ("hidden_size" ):
420+ candidates = [
421+ n for n in range (1 , out ["hidden_size" ] + 1 )
422+ if out ["hidden_size" ] % n == 0
423+ and int (qrope .shape [0 ]) % n == 0
424+ and int (qnope .shape [0 ]) % n == 0
425+ ]
426+ if candidates :
427+ preferred = 8 if 8 in candidates else max (candidates )
428+ head_dim = out ["hidden_size" ] // preferred
429+ rope_dim = int (qrope .shape [0 ]) // preferred
430+ nope_dim = int (qnope .shape [0 ]) // preferred
431+ if rope_dim + nope_dim == head_dim :
432+ out ["num_attention_heads" ] = preferred
433+ out ["rope_dim" ] = rope_dim
434+ if kvdown is not None :
435+ out ["kv_latent_dim" ] = int (kvdown .shape [0 ])
436+ if vproj is not None and out .get ("kv_latent_dim" ):
437+ head_dim = out .get ("hidden_size" , 0 ) // max (out .get ("num_attention_heads" , 8 ), 1 )
438+ if head_dim > 0 :
439+ out ["num_key_value_heads" ] = max (1 , int (vproj .shape [0 ]) // head_dim )
398440
399441 return out
400442 except Exception :
@@ -477,16 +519,42 @@ def _default_init_path(self, save_dir: str, mode: str, hidden_size: int) -> str:
477519 return ""
478520 return os .path .join (save_dir , f"{ upstream } _{ hidden_size } _moe.pth" )
479521
522+ def _resolve_default_init_path (self , save_dir : str , mode : str , hidden_size : int ) -> str :
523+ exact = self ._default_init_path (save_dir , mode , hidden_size )
524+ if os .path .exists (exact ):
525+ return exact
526+ upstream = STAGE_DEFAULT_INIT .get (mode )
527+ if not upstream :
528+ return exact
529+ candidates = [
530+ p for p in glob .glob (os .path .join (save_dir , f"{ upstream } _*_moe.pth" ))
531+ if os .path .isfile (p )
532+ ]
533+ if not candidates :
534+ return exact
535+ return max (candidates , key = os .path .getmtime )
536+
480537 def _topology_mismatches (self , sniffed : dict , model_cfg_kwargs : dict ) -> list [str ]:
481538 mismatches = []
482- for k in [
483- "hidden_size" , "num_hidden_layers" , "num_experts" ,
484- "moe_intermediate_size" , "vocab_size" , "lookahead_steps" ,
485- ]:
539+ for k in CHECKPOINT_TOPOLOGY_KEYS :
486540 if k in sniffed and k in model_cfg_kwargs and int (sniffed [k ]) != int (model_cfg_kwargs [k ]):
487541 mismatches .append (f"{ k } : ckpt={ sniffed [k ]} != ui={ model_cfg_kwargs [k ]} " )
488542 return mismatches
489543
544+ def _adopt_checkpoint_topology (self , model_cfg_kwargs : dict , sniffed : dict ) -> dict :
545+ adopted = {}
546+ for key in CHECKPOINT_TOPOLOGY_KEYS :
547+ if key not in sniffed :
548+ continue
549+ current = model_cfg_kwargs .get (key )
550+ value = sniffed [key ]
551+ if current is None or int (current ) != int (value ):
552+ adopted [key ] = value
553+ model_cfg_kwargs [key ] = value
554+ if "moe_intermediate_size" in sniffed :
555+ model_cfg_kwargs ["intermediate_size" ] = sniffed ["moe_intermediate_size" ]
556+ return adopted
557+
490558 def _build_stage_args (self , cfg : dict , mode : str , save_dir : str , hidden_size : int ):
491559 reward_spec = (cfg .get ("reward_spec" ) or "toy" ).strip () or "toy"
492560 teacher_path = (cfg .get ("teacher_path" ) or "" ).strip ()
@@ -681,6 +749,24 @@ def _run(self, cfg: dict, mode: str):
681749 if opt_key in AUTO_SENTINEL_KEYS and val == 0 :
682750 continue
683751 model_cfg_kwargs [opt_key ] = val
752+ save_dir = cfg .get ("save_dir" , "./out" )
753+ init_weight = (cfg .get ("init_weight" ) or "" ).strip ()
754+ load_path = ""
755+ init_sniffed = {}
756+ if mode != "pretrain" :
757+ load_path = init_weight or self ._resolve_default_init_path (
758+ save_dir , mode , int (model_cfg_kwargs ["hidden_size" ])
759+ )
760+ if os .path .exists (load_path ):
761+ init_sniffed = _sniff_checkpoint (load_path )
762+ adopted = self ._adopt_checkpoint_topology (model_cfg_kwargs , init_sniffed )
763+ if adopted :
764+ summary = ", " .join (f"{ k } ={ v } " for k , v in adopted .items ())
765+ self ._put (
766+ f"[{ mode .upper ()} ] Adopted topology from init checkpoint "
767+ f"{ load_path } : { summary } "
768+ )
769+
684770 model_cfg = ChronosConfig (** model_cfg_kwargs )
685771 model = ChronosForCausalLM (model_cfg )
686772 params_m = sum (p .numel () for p in model .parameters ()) / 1e6
@@ -689,14 +775,12 @@ def _run(self, cfg: dict, mode: str):
689775 f"E={ model_cfg .num_experts } , ffn={ model_cfg .intermediate_size } , "
690776 f"vocab={ model_cfg .vocab_size } )" )
691777
692- save_dir = cfg .get ("save_dir" , "./out" )
693778 ckp_path = self ._stage_checkpoint_path (save_dir , mode , model_cfg .hidden_size )
694779
695780 # Init-weight resolution:
696781 # - pretrain: optionally resume from its own checkpoint if present
697782 # - other stages: require an upstream weight (explicit init_weight
698783 # or the stage's default predecessor checkpoint).
699- init_weight = (cfg .get ("init_weight" ) or "" ).strip ()
700784 if mode == "pretrain" :
701785 resume_path = init_weight or ckp_path
702786 if os .path .exists (resume_path ):
@@ -714,7 +798,6 @@ def _run(self, cfg: dict, mode: str):
714798 else :
715799 self ._put ("Pretraining from random init" )
716800 else :
717- load_path = init_weight or self ._default_init_path (save_dir , mode , model_cfg .hidden_size )
718801 if not os .path .exists (load_path ):
719802 raise FileNotFoundError (
720803 f"[{ mode .upper ()} ] requires an upstream checkpoint to initialize from. "
@@ -723,7 +806,7 @@ def _run(self, cfg: dict, mode: str):
723806 f"or run the prior stage first."
724807 )
725808
726- sniffed = _sniff_checkpoint (load_path )
809+ sniffed = init_sniffed or _sniff_checkpoint (load_path )
727810 mismatch_hints = self ._topology_mismatches (sniffed , model_cfg_kwargs )
728811 if mismatch_hints :
729812 raise RuntimeError (
0 commit comments