Skip to content

Commit 8a67cda

Browse files
authored
Add files via upload
1 parent 38aecda commit 8a67cda

4 files changed

Lines changed: 111 additions & 30 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ classifiers = [
1919
"Topic :: Scientific/Engineering :: Artificial Intelligence",
2020
]
2121
dependencies = [
22-
"torch>=2.1.0",
22+
"torch>=2.4.0",
2323
"transformers>=4.40.0",
2424
"datasets>=2.18.0",
2525
"safetensors>=0.4.0",

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
torch>=2.1.0
1+
torch>=2.4.0
22
transformers>=4.40.0
33
datasets>=2.18.0
44
tokenizers>=0.19.0
55
safetensors>=0.4.0
66
optuna>=3.6.0
77
numpy>=1.24.0
88
tqdm>=4.66.0
9-
gradio>=6.0.0

ui/tabs/config_tab.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@
1212
"""
1313
from ui.gradio_compat import gr
1414

15-
from chronos.model.config import ChronosConfig
1615
from ui.i18n import t, register_translatable
1716
from ui.estimator import (
1817
ArchConfig, total_params, active_params, memory_footprint,
1918
estimated_decode_tps, fmt_bytes, fmt_params,
2019
)
2120
from ui.presets import (
22-
MINIMIND_MOE_DEFAULTS, PRESETS, preset_names, get_preset,
23-
values_in_input_order, save_config, load_config, CONFIG_INPUT_ORDER,
21+
preset_names, get_preset, values_in_input_order, save_config, load_config,
2422
)
2523

2624

@@ -88,7 +86,9 @@ def build_config_tab():
8886
data_path has been moved to the Train tab, so the tuple now has 3
8987
elements. Callers must be updated.
9088
"""
91-
config_state = gr.State(ChronosConfig().__dict__.copy())
89+
initial_preset = "Recommended-CN (≈120M)"
90+
initial_cfg = get_preset(initial_preset)
91+
config_state = gr.State(dict(initial_cfg))
9292

9393
with gr.Tab(t("tab.config")) as tab:
9494
register_translatable(tab, "tab.config")
@@ -97,7 +97,7 @@ def build_config_tab():
9797
with gr.Row():
9898
preset_dd = gr.Dropdown(
9999
choices=preset_names(),
100-
value="Recommended-CN (≈120M)",
100+
value=initial_preset,
101101
label=t("config.preset"),
102102
scale=2,
103103
)
@@ -123,7 +123,7 @@ def build_config_tab():
123123
with gr.Column(scale=3):
124124
gr.Markdown(f"### {t('config.arch')}")
125125

126-
D = MINIMIND_MOE_DEFAULTS # MiniMind-MoE starting values
126+
D = initial_cfg
127127

128128
with gr.Row():
129129
hidden_size = gr.Slider(*RANGES["hidden_size"], value=D["hidden_size"], label=t("config.hidden_size"))
@@ -211,7 +211,7 @@ def build_config_tab():
211211
register_translatable(save_interval, "config.save_interval")
212212
register_translatable(save_dir, "config.save_dir")
213213

214-
config_display = gr.JSON(label="Current Config (saved to config_state)", value={})
214+
config_display = gr.JSON(label="Current Config (saved to config_state)", value=initial_cfg)
215215

216216
with gr.Column(scale=1, min_width=260):
217217
gr.Markdown(f"### 🧬 {t('designer.title')}")
@@ -282,11 +282,11 @@ def update_config(*vals):
282282
# ── Preset / Save / Load wiring ───────────────────────
283283
def apply_preset(name):
284284
cfg = get_preset(name)
285-
return [gr.update(value=v) for v in values_in_input_order(cfg)]
285+
return [dict(cfg), dict(cfg)] + [gr.update(value=v) for v in values_in_input_order(cfg)]
286286

287287
def reset_minimind():
288288
cfg = get_preset("MiniMind-MoE (default)")
289-
return [gr.update(value=v) for v in values_in_input_order(cfg)]
289+
return [dict(cfg), dict(cfg)] + [gr.update(value=v) for v in values_in_input_order(cfg)]
290290

291291
def do_save(cfg, path):
292292
try:
@@ -299,21 +299,20 @@ def do_load(path):
299299
try:
300300
cfg = load_config(path)
301301
vals = values_in_input_order(cfg)
302-
return [f"✅ Loaded from `{path}`"] + [gr.update(value=v) for v in vals]
302+
return [f"✅ Loaded from `{path}`", dict(cfg), dict(cfg)] + [gr.update(value=v) for v in vals]
303303
except FileNotFoundError:
304-
return [f"❌ Not found: `{path}`"] + [gr.update() for _ in all_inputs]
304+
return [f"❌ Not found: `{path}`", gr.update(), gr.update()] + [gr.update() for _ in all_inputs]
305305
except Exception as e:
306-
return [f"❌ Load failed: {e}"] + [gr.update() for _ in all_inputs]
306+
return [f"❌ Load failed: {e}", gr.update(), gr.update()] + [gr.update() for _ in all_inputs]
307307

308-
apply_preset_btn.click(fn=apply_preset, inputs=[preset_dd], outputs=all_inputs)
308+
apply_preset_btn.click(fn=apply_preset, inputs=[preset_dd], outputs=[config_state, config_display] + all_inputs)
309309
# Selecting a preset in the dropdown should sync immediately;
310310
# the explicit "Load Preset" button stays as a re-apply affordance.
311-
preset_dd.change(fn=apply_preset, inputs=[preset_dd], outputs=all_inputs)
312-
reset_btn.click(fn=reset_minimind, outputs=all_inputs)
311+
preset_dd.change(fn=apply_preset, inputs=[preset_dd], outputs=[config_state, config_display] + all_inputs)
312+
reset_btn.click(fn=reset_minimind, outputs=[config_state, config_display] + all_inputs)
313313
save_btn.click(fn=do_save, inputs=[config_state, cfg_path], outputs=[save_status])
314-
load_btn.click(fn=do_load, inputs=[cfg_path], outputs=[save_status] + all_inputs)
314+
load_btn.click(fn=do_load, inputs=[cfg_path], outputs=[save_status, config_state, config_display] + all_inputs)
315315

316-
_initial_cfg = ChronosConfig().__dict__
317-
total_box.value, active_box.value, vram_box.value, ssd_box.value, kv_box.value, tps_box.value = _estimate(_initial_cfg)
316+
total_box.value, active_box.value, vram_box.value, ssd_box.value, kv_box.value, tps_box.value = _estimate(initial_cfg)
318317

319318
return config_state, all_inputs, save_dir

ui/tabs/train_tab.py

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
ui/tabs/train_tab.py — Full training loop: Pretrain / SFT / DPO / ORPO / GRPO / Distill
33
"""
44
import os
5+
import glob
56
import time
67
import threading
78
import queue
@@ -47,6 +48,18 @@
4748
"grpo": "./tests/fixtures/tiny_grpo.jsonl",
4849
"distill": "./tests/fixtures/tiny_sft.jsonl",
4950
}
51+
# Model-config fields that are compared against, and may be overwritten from,
# the values sniffed out of an init checkpoint. Consumers iterate this tuple
# in order, so the order below is also the order in which fields are reported.
CHECKPOINT_TOPOLOGY_KEYS = (
    # Backbone / mixture-of-experts dimensions.
    "hidden_size", "num_hidden_layers", "num_experts",
    "moe_intermediate_size", "vocab_size", "lookahead_steps",
    # Attention layout.
    "kv_latent_dim", "rope_dim",
    "num_attention_heads", "num_key_value_heads",
)
5063
STAGE_HELP_TEXT = {
5164
"pretrain": {
5265
"zh-Hans": "从通用语料继续预训练。`init_weight` 可留空;若存在同名 checkpoint,将按当前拓扑尝试恢复。",
@@ -394,7 +407,36 @@ def _sniff_checkpoint(path: str) -> dict:
394407
total = int(lookahead_proj.shape[0])
395408
n_exp = int(out["num_experts"])
396409
if n_exp > 0 and total % n_exp == 0:
397-
out["lookahead_steps"] = total // n_exp
410+
# The router predicts current-token routing plus K future
411+
# steps, so the saved output rows are (lookahead_steps + 1)
412+
# * num_experts. Do not report the +1 as user topology.
413+
out["lookahead_steps"] = max(0, total // n_exp - 1)
414+
415+
qnope = sd.get("model.layers.0.self_attn.q_nope_proj.weight")
416+
qrope = sd.get("model.layers.0.self_attn.q_rope_proj.weight")
417+
kvdown = sd.get("model.layers.0.self_attn.kv_down_proj.weight")
418+
vproj = sd.get("model.layers.0.self_attn.v_proj.weight")
419+
if qnope is not None and qrope is not None and out.get("hidden_size"):
420+
candidates = [
421+
n for n in range(1, out["hidden_size"] + 1)
422+
if out["hidden_size"] % n == 0
423+
and int(qrope.shape[0]) % n == 0
424+
and int(qnope.shape[0]) % n == 0
425+
]
426+
if candidates:
427+
preferred = 8 if 8 in candidates else max(candidates)
428+
head_dim = out["hidden_size"] // preferred
429+
rope_dim = int(qrope.shape[0]) // preferred
430+
nope_dim = int(qnope.shape[0]) // preferred
431+
if rope_dim + nope_dim == head_dim:
432+
out["num_attention_heads"] = preferred
433+
out["rope_dim"] = rope_dim
434+
if kvdown is not None:
435+
out["kv_latent_dim"] = int(kvdown.shape[0])
436+
if vproj is not None and out.get("kv_latent_dim"):
437+
head_dim = out.get("hidden_size", 0) // max(out.get("num_attention_heads", 8), 1)
438+
if head_dim > 0:
439+
out["num_key_value_heads"] = max(1, int(vproj.shape[0]) // head_dim)
398440

399441
return out
400442
except Exception:
@@ -477,16 +519,42 @@ def _default_init_path(self, save_dir: str, mode: str, hidden_size: int) -> str:
477519
return ""
478520
return os.path.join(save_dir, f"{upstream}_{hidden_size}_moe.pth")
479521

522+
def _resolve_default_init_path(self, save_dir: str, mode: str, hidden_size: int) -> str:
    """Choose the checkpoint file a stage should initialise from by default.

    Resolution order:
      1. The exact path built by ``_default_init_path`` for ``hidden_size``,
         when that file exists.
      2. Otherwise, the most recently modified ``<upstream>_*_moe.pth`` file
         directly inside ``save_dir``, where ``upstream`` is
         ``STAGE_DEFAULT_INIT[mode]``. This lets a stage pick up a
         predecessor checkpoint saved with a different hidden size.
      3. Otherwise, the exact path again (which does not exist), so the
         caller's own existence check decides what to do.

    Args:
        save_dir: Directory in which stage checkpoints are stored.
        mode: Training stage name, used to look up the upstream stage.
        hidden_size: Hidden size currently configured for the model.

    Returns:
        A checkpoint path; it is only guaranteed to exist in cases 1 and 2.
    """
    exact = self._default_init_path(save_dir, mode, hidden_size)
    if os.path.exists(exact):
        return exact

    upstream = STAGE_DEFAULT_INIT.get(mode)
    if not upstream:
        # No predecessor stage is registered for this mode; nothing to scan.
        return exact

    # save_dir comes from user input. Escape it so glob metacharacters in the
    # directory name ("[", "]", "*", "?") are matched literally and only the
    # hidden-size slot of the file name acts as a wildcard.
    pattern = os.path.join(glob.escape(save_dir), f"{upstream}_*_moe.pth")
    candidates = [path for path in glob.glob(pattern) if os.path.isfile(path)]
    if not candidates:
        return exact

    # Prefer the newest checkpoint when several hidden sizes are present.
    return max(candidates, key=os.path.getmtime)
536+
480537
def _topology_mismatches(self, sniffed: dict, model_cfg_kwargs: dict) -> list[str]:
481538
mismatches = []
482-
for k in [
483-
"hidden_size", "num_hidden_layers", "num_experts",
484-
"moe_intermediate_size", "vocab_size", "lookahead_steps",
485-
]:
539+
for k in CHECKPOINT_TOPOLOGY_KEYS:
486540
if k in sniffed and k in model_cfg_kwargs and int(sniffed[k]) != int(model_cfg_kwargs[k]):
487541
mismatches.append(f"{k}: ckpt={sniffed[k]} != ui={model_cfg_kwargs[k]}")
488542
return mismatches
489543

544+
def _adopt_checkpoint_topology(self, model_cfg_kwargs: dict, sniffed: dict) -> dict:
    """Overwrite UI topology settings with the values found in a checkpoint.

    For every key in ``CHECKPOINT_TOPOLOGY_KEYS`` that ``sniffed`` reports,
    the checkpoint value replaces the entry in ``model_cfg_kwargs`` when the
    entry is missing/``None`` or differs after integer conversion.
    ``model_cfg_kwargs`` is modified in place.

    Args:
        model_cfg_kwargs: Keyword arguments being assembled for the model
            config; updated in place.
        sniffed: Topology values inferred from the checkpoint.

    Returns:
        A dict of the keys that were changed, mapped to the checkpoint
        value. The mirrored ``intermediate_size`` is not included.
    """
    overrides = {}
    for name in CHECKPOINT_TOPOLOGY_KEYS:
        if name in sniffed:
            ui_value = model_cfg_kwargs.get(name)
            ckpt_value = sniffed[name]
            differs = ui_value is None or int(ui_value) != int(ckpt_value)
            if differs:
                overrides[name] = ckpt_value
                model_cfg_kwargs[name] = ckpt_value

    # The dense FFN width always follows the checkpoint's MoE expert width
    # when the latter is known, regardless of whether anything else changed.
    if "moe_intermediate_size" in sniffed:
        model_cfg_kwargs["intermediate_size"] = sniffed["moe_intermediate_size"]

    return overrides
557+
490558
def _build_stage_args(self, cfg: dict, mode: str, save_dir: str, hidden_size: int):
491559
reward_spec = (cfg.get("reward_spec") or "toy").strip() or "toy"
492560
teacher_path = (cfg.get("teacher_path") or "").strip()
@@ -681,6 +749,24 @@ def _run(self, cfg: dict, mode: str):
681749
if opt_key in AUTO_SENTINEL_KEYS and val == 0:
682750
continue
683751
model_cfg_kwargs[opt_key] = val
752+
save_dir = cfg.get("save_dir", "./out")
753+
init_weight = (cfg.get("init_weight") or "").strip()
754+
load_path = ""
755+
init_sniffed = {}
756+
if mode != "pretrain":
757+
load_path = init_weight or self._resolve_default_init_path(
758+
save_dir, mode, int(model_cfg_kwargs["hidden_size"])
759+
)
760+
if os.path.exists(load_path):
761+
init_sniffed = _sniff_checkpoint(load_path)
762+
adopted = self._adopt_checkpoint_topology(model_cfg_kwargs, init_sniffed)
763+
if adopted:
764+
summary = ", ".join(f"{k}={v}" for k, v in adopted.items())
765+
self._put(
766+
f"[{mode.upper()}] Adopted topology from init checkpoint "
767+
f"{load_path}: {summary}"
768+
)
769+
684770
model_cfg = ChronosConfig(**model_cfg_kwargs)
685771
model = ChronosForCausalLM(model_cfg)
686772
params_m = sum(p.numel() for p in model.parameters()) / 1e6
@@ -689,14 +775,12 @@ def _run(self, cfg: dict, mode: str):
689775
f"E={model_cfg.num_experts}, ffn={model_cfg.intermediate_size}, "
690776
f"vocab={model_cfg.vocab_size})")
691777

692-
save_dir = cfg.get("save_dir", "./out")
693778
ckp_path = self._stage_checkpoint_path(save_dir, mode, model_cfg.hidden_size)
694779

695780
# Init-weight resolution:
696781
# - pretrain: optionally resume from its own checkpoint if present
697782
# - other stages: require an upstream weight (explicit init_weight
698783
# or the stage's default predecessor checkpoint).
699-
init_weight = (cfg.get("init_weight") or "").strip()
700784
if mode == "pretrain":
701785
resume_path = init_weight or ckp_path
702786
if os.path.exists(resume_path):
@@ -714,7 +798,6 @@ def _run(self, cfg: dict, mode: str):
714798
else:
715799
self._put("Pretraining from random init")
716800
else:
717-
load_path = init_weight or self._default_init_path(save_dir, mode, model_cfg.hidden_size)
718801
if not os.path.exists(load_path):
719802
raise FileNotFoundError(
720803
f"[{mode.upper()}] requires an upstream checkpoint to initialize from. "
@@ -723,7 +806,7 @@ def _run(self, cfg: dict, mode: str):
723806
f"or run the prior stage first."
724807
)
725808

726-
sniffed = _sniff_checkpoint(load_path)
809+
sniffed = init_sniffed or _sniff_checkpoint(load_path)
727810
mismatch_hints = self._topology_mismatches(sniffed, model_cfg_kwargs)
728811
if mismatch_hints:
729812
raise RuntimeError(

0 commit comments

Comments
 (0)