Skip to content

Feat: add checkpoint loading mechanism#146

Open
JYMiracle305 wants to merge 5 commits intomasterfrom
feature/add_checkpoint
Open

Feat: add checkpoint loading mechanism#146
JYMiracle305 wants to merge 5 commits intomasterfrom
feature/add_checkpoint

Conversation

@JYMiracle305
Copy link
Copy Markdown
Contributor

@JYMiracle305 JYMiracle305 commented Apr 21, 2026

From ArcaLunar

Checkpoint 读取工具主要参数:

  • --checkpoint_dir 训练过程中的保存目录
  • --save_steps 每 N 次保存一次,设置为 0 则不保存
  • --max_checkpoint_keep 最多保留 K 个 checkpoint
  • --save_optimizer_state 是否保存优化器的状态
  • --resume_from 从指定 checkpoint 目录恢复训练

Checkpoint 文件可以通过从 /data/shared/....../llmc/gpt2 (or llama3) 的原始模型参数训练而来,例子可见仓库中的 REPORT.md(Experiment 实际上也测试了llama3,但是命令只记录了 GPT2 训练),model.bin, optimizer.bin, trainer_state.json 都可以从训练中获取.因此不在附件中提供

Experiment

CUDA_VISIBLE_DEVICES=5,6,7 ./gpt2 --input_bin ../../data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath ../../data/llmc/gpt2/gpt2_124M.bin --checkpoint_dir ../ckpt2/gpt2-noresume/ --num_iteration 100 --save_steps 20 --save_optimizer_state true --max_checkpoint_keep 10
CUDA_VISIBLE_DEVICES=5,6,7 ./gpt2 --input_bin ../../data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath ../../data/llmc/gpt2/gpt2_124M.bin --checkpoint_dir ../ckpt2/gpt2-resumefrom40/ --num_iteration 100 --save_steps 20 --save_optimizer_state true --max_checkpoint_keep 10 --resume_from ../ckpt2/gpt2-noresume/checkpoint_step_000040/ > ../ckpt2/gpt2-resumefrom40/gpt2-resume.log 2>&1

(以上两条训练命令同样用 llama3 也运行了)

运行 compare_loss.py,对于 llama3 模型,由于从 step 40 恢复训练,所以 step 1~40 数据缺失,而其余 60 步的 loss 在 FP32, BF16 下均吻合

  Summary: 60/100 steps matched

==================================================
Overall Summary:
  fp32:    0/1 test cases passed (threshold: 1e-05)
  bfloat16: 0/0 test cases passed (threshold: 1e-02)
  Total:   0/1 test cases passed
==================================================

==================================================
Overall Summary:
  fp32:    0/0 test cases passed (threshold: 1e-05)
  bfloat16: 0/1 test cases passed (threshold: 1e-02)
  Total:   0/1 test cases passed
==================================================

对于 GPT2,模型保存的逻辑有误:训练中 lm_head 与 wte 并非真共享,而 LLMC 存取又按“共享”假设处理,resume 后 lm_head 很容易和 no resume 不一致。解决方法是把训练用 checkpoint 从 LLMC 回调路径切到原生 StateDict 二进制路径,并在加载后显式重建权重绑定语义 (example/gpt2/main.cc).经过修复后,也可以通过.

@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch from e8c5dd5 to 0a3deb2 Compare April 24, 2026 09:22
@JYMiracle305 JYMiracle305 changed the title [WIP] Feat: add checkpoint loading mechanism Feat: add checkpoint loading mechanism Apr 29, 2026
Comment thread example/gpt2/main.cc
DEFINE_string(checkpoint_format, "pth",
"checkpoint format: bin|pth. "
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
"'pth' generates model.pth/optimizer.pth (native StateDict binary).");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 pth 和 bin是不是存的是同一种 native StateDict binary 格式啊,能不能去掉 pth

std::unordered_map<std::string, std::shared_ptr<Tensor>> Adam::StateDict() const {
std::unordered_map<std::string, std::shared_ptr<Tensor>> state;
for (size_t i = 0; i < m_.size(); ++i) {
state.emplace(std::format("adam.m.{}", i), m_[i]);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里m_来自Module::Parameters(),是个vector,不是强保序的,可能有点风险

return state;
}

void Module::LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需不需要检查一下 checkpoint 中是否有多余 key

@@ -0,0 +1,1060 @@
#include "example/common/checkpoint_loader.h"
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件太重了,有一千多行,checkpoint相关的基建和 llama / gpt 的 save / load 都混在一起了。要不要拆分一个example/common/checkpoint_utils.h/.cc,然后保留 gpt2 和 llama3 各自的特化调用?这个可以再讨论一下


ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
ResumeFromCheckpointResult result;
int ddp_world_size = nn::parallel::global::GetDataParallelSize();
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里还需要检查 TP / PP / SP size 吗?

Comment thread example/common/utils.cc
#include "gflags/gflags.h"
#include "gflags/gflags_declare.h"
#include "glog/logging.h"
#include "infini_train/include/nn/parallel/global.h"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个 include 还需要吗,是不是可以删掉

return DataLoaderIterator(*dataset_, batch_size_, max_batch_idx_, max_batch_idx_);
}

DataLoaderIterator DataLoader::IteratorAtBatchIndex(size_t batch_idx) const {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

检查一下batch_idx % ddp_world_size == ddp_rank

int64_t global_step = 0;
int64_t data_batch_idx = 0;
int64_t data_batch_stride = 1;
float best_loss = 0.0f;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

best_loss 不应该是0.0,应该是infinity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants