diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index c12b5a28..74901d7f 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -13,6 +13,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" @@ -56,8 +57,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run"); DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization -DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); +DEFINE_double(learning_rate, 1e-4, "Peak learning rate."); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_double(min_lr, 0.0, "Minimum learning rate."); +DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root"); +DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations."); +DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup."); +DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration)."); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -100,6 +107,8 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; +const std::unordered_set kSupportedLRDecayStyles + = {"none", "constant", "linear", "cosine", "inverse-square-root"}; // const std::unordered_map kModelToConfigs = { @@ -114,6 +123,8 @@ const std::unordered_map kModelToConfigs = { DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(lr_decay_style, + [](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -305,6 +316,16 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(params_to_optimize); } + const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration; + TrainingLRSchedulerConfig sched_config; + sched_config.lr = static_cast(FLAGS_learning_rate); + sched_config.min_lr = static_cast(FLAGS_min_lr); + sched_config.lr_decay_style = FLAGS_lr_decay_style; + sched_config.lr_decay_iters = lr_decay_iters; + sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters; + sched_config.lr_warmup_init = static_cast(FLAGS_lr_warmup_init); + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast( @@ -348,6 +369,7 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); float lossf = 0.0f; // model->Train(); if (pp_world_size == 1) { @@ -392,6 +414,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -401,6 +426,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -416,11 +444,10 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 117551d5..5c841946 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -11,6 +11,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" @@ -55,8 +56,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run"); DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization -DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); +DEFINE_double(learning_rate, 1e-5, "Peak learning rate."); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_double(min_lr, 0.0, "Minimum learning rate."); +DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root"); +DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations."); +DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup."); +DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration)."); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -95,11 +102,15 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; +const std::unordered_set kSupportedLRDecayStyles + = {"none", "constant", "linear", "cosine", "inverse-square-root"}; } // namespace DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(lr_decay_style, + [](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -284,6 +295,16 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(params_to_optimize); } + const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration; + TrainingLRSchedulerConfig sched_config; + sched_config.lr = static_cast(FLAGS_learning_rate); + sched_config.min_lr = static_cast(FLAGS_min_lr); + sched_config.lr_decay_style = FLAGS_lr_decay_style; + sched_config.lr_decay_iters = lr_decay_iters; + sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters; + sched_config.lr_warmup_init = static_cast(FLAGS_lr_warmup_init); + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(std::make_shared()) @@ -324,6 +345,7 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); float lossf = 0.0f; if (pp_world_size == 1) { // model->Train(); @@ -367,6 +389,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -376,6 +401,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -391,11 +419,10 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h new file mode 100644 index 00000000..13c8e79a --- /dev/null +++ b/infini_train/include/lr_scheduler.h @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini_train { + +class Optimizer; + +using StateValue = std::variant>; +using StateDict = std::unordered_map; + +struct TrainingLRSchedulerConfig { + std::string lr_decay_style = "constant"; + float lr = 0.0f; + float min_lr = 0.0f; + int64_t lr_decay_iters = 1; + int64_t lr_warmup_iters = 0; + float lr_warmup_init = 0.0f; +}; + +class LRScheduler { +public: + template static std::shared_ptr Create(Args &&...args) { + auto scheduler = std::make_shared(std::forward(args)...); + scheduler->InitialStep(); + return scheduler; + } + + explicit LRScheduler(std::shared_ptr optimizer, int64_t last_step = -1); + virtual ~LRScheduler() = default; + + LRScheduler(const LRScheduler &) = delete; + LRScheduler &operator=(const LRScheduler &) = delete; + + virtual void Step(); + virtual void Step(int64_t epoch); + virtual void InitialStep(); + + float GetLR() const; + float BaseLR() const; + int64_t LastStep() const; + + void ResetStep(int64_t step = -1); + virtual StateDict State() const; + virtual void LoadState(const StateDict &state); + + bool SharesOptimizerWith(const std::shared_ptr &opt) const; + +protected: + virtual float GetClosedFormLR() const = 0; + virtual float GetChainedFormLR() const; + void ApplyLR(float lr); + + std::shared_ptr optimizer_; + int64_t last_step_; + float recover_lr_; + float base_lr_; + bool is_initial_ = false; +}; + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, + const TrainingLRSchedulerConfig &config); + +namespace lr_schedulers { + +class ConstantLR : public LRScheduler { +public: + ConstantLR(std::shared_ptr optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, + int64_t last_step = -1); + ~ConstantLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float factor_; + const int64_t total_iters_; +}; + +class StepLR : public LRScheduler { +public: + StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); + ~StepLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const int64_t step_size_; + const float gamma_; +}; + +class LinearLR : public LRScheduler { +public: + LinearLR(std::shared_ptr optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f, + int64_t total_iters = 5, int64_t last_step = -1); + ~LinearLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float start_factor_; + const float end_factor_; + const int64_t total_iters_; +}; + +class LambdaLR : public LRScheduler { +public: + using LambdaFunc = std::function; + + LambdaLR(std::shared_ptr optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); + ~LambdaLR() override = default; + +protected: + float GetClosedFormLR() const override; + +private: + const LambdaFunc lr_lambda_; +}; + +class SequentialLR : public LRScheduler { +public: + SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step = -1); + ~SequentialLR() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + float GetClosedFormLR() const override; + void UndoChildInitialSteps(); + +private: + std::vector> schedulers_; + std::vector milestones_; +}; + +class ChainedScheduler : public LRScheduler { +public: + ChainedScheduler(std::shared_ptr optimizer, std::vector> schedulers, + int64_t last_step = -1); + ~ChainedScheduler() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + float GetClosedFormLR() const override; + +private: + std::vector> schedulers_; +}; + +} // namespace lr_schedulers +} // namespace infini_train diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index bc31442e..18947ec7 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -34,6 +34,9 @@ class DistributedOptimizer final : public infini_train::Optimizer { void StartParamSync(bool force_sync = false); void FinishParamSync(bool skip_next_bucket_dispatch = false); + virtual void set_learning_rate(float lr) override; + virtual float learning_rate() const override; + private: void BuildShardParamsAndBindGrads(); diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index fb0ae2d5..942fd92a 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -15,14 +15,25 @@ using OptimizerCreator = std::function(const std::vec class Optimizer { public: - explicit Optimizer(const std::vector> ¶ms); + explicit Optimizer(const std::vector> ¶ms, float learning_rate = 0.0f); virtual void ZeroGrad(bool set_to_none = true); virtual void Step() = 0; + virtual void set_learning_rate(float lr); + + virtual float learning_rate() const; + + float initial_learning_rate() const; + + void set_initial_learning_rate(float lr); + protected: std::vector> params_; + float learning_rate_ = 0.0f; + float initial_learning_rate_ = 0.0f; + bool initial_lr_set_ = false; }; namespace optimizers { @@ -37,9 +48,6 @@ class SGD : public Optimizer { return std::make_shared(params, learning_rate); }; } - -private: - const float learning_rate_ = 0.0; }; class Adam : public Optimizer { @@ -58,7 +66,6 @@ class Adam : public Optimizer { private: int64_t t_; - const float learning_rate_; const float beta1_; const float beta2_; const float eps_; diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc new file mode 100644 index 00000000..42afb165 --- /dev/null +++ b/infini_train/src/lr_scheduler.cc @@ -0,0 +1,372 @@ +#include "infini_train/include/lr_scheduler.h" + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" + +namespace infini_train { + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, + const TrainingLRSchedulerConfig &config) { + if (config.lr_decay_style == "none") { + return nullptr; + } + + CHECK(optimizer) << "CreateLRScheduler: optimizer must not be null."; + const float max_lr = config.lr != 0.0f ? config.lr : optimizer->learning_rate(); + CHECK_GT(max_lr, 0.0f) << "CreateLRScheduler: max_lr must be > 0."; + CHECK_GE(config.lr_warmup_init, 0.0f) << "CreateLRScheduler: lr_warmup_init must be >= 0."; + CHECK_GE(config.min_lr, 0.0f) << "CreateLRScheduler: min_lr must be >= 0."; + CHECK_GE(max_lr, config.min_lr) << "CreateLRScheduler: max_lr must be >= min_lr."; + CHECK_LE(config.lr_warmup_init, max_lr) << "CreateLRScheduler: lr_warmup_init must be <= max_lr."; + CHECK_GE(config.lr_warmup_iters, 0) << "CreateLRScheduler: lr_warmup_iters must be >= 0."; + CHECK_GT(config.lr_decay_iters, 0) << "CreateLRScheduler: lr_decay_iters must be > 0."; + CHECK_LT(config.lr_warmup_iters, config.lr_decay_iters) + << "CreateLRScheduler: lr_warmup_iters must be < lr_decay_iters."; + CHECK(config.lr_decay_style == "constant" || config.lr_decay_style == "linear" || config.lr_decay_style == "cosine" + || config.lr_decay_style == "inverse-square-root") + << "CreateLRScheduler: unsupported lr_decay_style: " << config.lr_decay_style; + + std::shared_ptr main_scheduler; + const int64_t decay_iters_after_warmup = config.lr_decay_iters - config.lr_warmup_iters; + if (config.lr_decay_style == "constant") { + main_scheduler = LRScheduler::Create(optimizer, [](int64_t) { return 1.0f; }); + } else if (config.lr_decay_style == "linear") { + main_scheduler = LRScheduler::Create(optimizer, 1.0f, config.min_lr / max_lr, + decay_iters_after_warmup); + } else if (config.lr_decay_style == "cosine") { + main_scheduler = LRScheduler::Create( + optimizer, [max_lr, min_lr = config.min_lr, decay_iters_after_warmup](int64_t step) { + if (step > decay_iters_after_warmup) { + return min_lr / max_lr; + } + const float decay_ratio = static_cast(step) / static_cast(decay_iters_after_warmup); + CHECK_GE(decay_ratio, 0.0f) << "CreateLRScheduler: decay " + "ratio must be >= 0."; + CHECK_LE(decay_ratio, 1.0f) << "CreateLRScheduler: decay " + "ratio must be <= 1."; + const float coeff = 0.5f * (std::cos(std::numbers::pi_v * decay_ratio) + 1.0f); + return (min_lr + coeff * (max_lr - min_lr)) / max_lr; + }); + } else if (config.lr_decay_style == "inverse-square-root") { + main_scheduler = LRScheduler::Create( + optimizer, [max_lr, min_lr = config.min_lr, lr_warmup_iters = config.lr_warmup_iters, + lr_decay_iters = config.lr_decay_iters](int64_t step) { + const int64_t global_step = step + lr_warmup_iters; + if (global_step > lr_decay_iters) { + return min_lr / max_lr; + } + const auto warmup = static_cast(std::max(lr_warmup_iters, 1)); + const auto current = static_cast(std::max(global_step, 1)); + return std::max(min_lr, max_lr * std::sqrt(warmup) / std::sqrt(current)) / max_lr; + }); + } + + CHECK(main_scheduler) << "CreateLRScheduler: failed to create scheduler."; + if (config.lr_warmup_iters == 0) { + return main_scheduler; + } + + auto warmup_scheduler = LRScheduler::Create( + optimizer, + [lr_warmup_init = config.lr_warmup_init, max_lr, lr_warmup_iters = config.lr_warmup_iters](int64_t step) { + const float warmup_ratio = static_cast(step) / static_cast(lr_warmup_iters); + return (lr_warmup_init + (max_lr - lr_warmup_init) * warmup_ratio) / max_lr; + }); + return LRScheduler::Create( + std::move(optimizer), std::vector>{warmup_scheduler, main_scheduler}, + std::vector{config.lr_warmup_iters}); +} + +LRScheduler::LRScheduler(std::shared_ptr optimizer, int64_t last_step) + : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) { + CHECK(optimizer_) << "LRScheduler: optimizer must not be null."; + optimizer_->set_initial_learning_rate(optimizer_->learning_rate()); + base_lr_ = optimizer_->initial_learning_rate(); +} + +void LRScheduler::Step() { + ++last_step_; + ApplyLR(GetChainedFormLR()); +} + +void LRScheduler::Step(int64_t epoch) { + last_step_ = epoch; + ApplyLR(GetClosedFormLR()); +} + +void LRScheduler::InitialStep() { + is_initial_ = true; + Step(); + is_initial_ = false; +} + +void LRScheduler::ApplyLR(float lr) { optimizer_->set_learning_rate(lr); } + +float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); } + +float LRScheduler::GetLR() const { return optimizer_->learning_rate(); } + +float LRScheduler::BaseLR() const { return base_lr_; } + +int64_t LRScheduler::LastStep() const { return last_step_; } + +bool LRScheduler::SharesOptimizerWith(const std::shared_ptr &opt) const { return optimizer_ == opt; } + +void LRScheduler::ResetStep(int64_t step) { last_step_ = step; } + +StateDict LRScheduler::State() const { + return { + {"last_step", last_step_}, + {"recover_lr", optimizer_->learning_rate()}, + {"base_lr", base_lr_}, + }; +} + +void LRScheduler::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + recover_lr_ = std::get(state.at("recover_lr")); + base_lr_ = std::get(state.at("base_lr")); + optimizer_->set_learning_rate(recover_lr_); +} + +// Concrete LR Schedulers + +namespace lr_schedulers { + +// --- ConstantLR --- + +ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int total_iters, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) { + CHECK_GT(factor_, 0.0f) << "ConstantLR: factor must be > 0."; + CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1."; +} + +float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; } + +float ConstantLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0) { + return lr * factor_; + } else if (last_step_ < total_iters_) { + return lr; + } else if (last_step_ == total_iters_) { + return lr / factor_; + } + return lr; +} + +// --- StepLR --- + +StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) { + CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0."; + CHECK_GT(gamma_, 0.0f) << "StepLR: gamma must be > 0."; +} + +float StepLR::GetClosedFormLR() const { + return base_lr_ + * static_cast(std::pow(static_cast(gamma_), static_cast(last_step_ / step_size_))); +} + +float StepLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0 || (last_step_ % step_size_) != 0) { + return lr; + } + return lr * gamma_; +} + +LinearLR::LinearLR(std::shared_ptr optimizer, float start_factor, float end_factor, int64_t total_iters, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor), + total_iters_(total_iters) { + CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0."; + CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1."; + CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0."; + CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1."; + CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0."; +} + +float LinearLR::GetClosedFormLR() const { + if (last_step_ >= total_iters_) { + return base_lr_ * end_factor_; + } + return base_lr_ + * (start_factor_ + + (end_factor_ - start_factor_) * static_cast(last_step_) / static_cast(total_iters_)); +} + +float LinearLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0) { + return lr * start_factor_; + } + if (last_step_ > total_iters_ || is_initial_) { + return lr; + } + if (last_step_ == total_iters_) { + const float prev_factor + = start_factor_ + + (end_factor_ - start_factor_) * static_cast(total_iters_ - 1) / static_cast(total_iters_); + return lr * (end_factor_ / prev_factor); + } + + const float numerator = end_factor_ - start_factor_; + const float denominator + = start_factor_ * static_cast(total_iters_) + static_cast(last_step_ - 1) * numerator; + return lr * (1.0f + numerator / denominator); +} + +LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) { + CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null."; +} + +float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); } + +float SequentialLR::GetClosedFormLR() const { + LOG(FATAL) << "SequentialLR does not support closed-form LR. Use Step() without an explicit epoch."; + return base_lr_; +} + +SequentialLR::SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), + milestones_(std::move(milestones)) { + CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; + + for (size_t i = 0; i < schedulers_.size(); ++i) { + CHECK(schedulers_[i]) << "SequentialLR: scheduler at index " << i << " must not be null."; + CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_)) + << "SequentialLR: scheduler at index " << i << " must share the same optimizer."; + } + + CHECK_EQ(milestones_.size(), schedulers_.size() - 1) + << "SequentialLR: milestones count must be schedulers count - 1."; + + for (size_t i = 1; i < milestones_.size(); ++i) { + CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing."; + } +} + +void SequentialLR::InitialStep() { + + optimizer_->set_learning_rate(schedulers_[0]->BaseLR()); + + UndoChildInitialSteps(); + + ++last_step_; + schedulers_[0]->InitialStep(); +} + +void SequentialLR::UndoChildInitialSteps() { + for (auto &sched : schedulers_) { + if (auto nested = std::dynamic_pointer_cast(sched)) { + nested->UndoChildInitialSteps(); + } + sched->ResetStep(sched->LastStep() - 1); + } +} + +void SequentialLR::Step() { + ++last_step_; + size_t idx = std::upper_bound(milestones_.begin(), milestones_.end(), last_step_) - milestones_.begin(); + + auto &scheduler = schedulers_[idx]; + + if (idx > 0 && milestones_[idx - 1] == last_step_) { + scheduler->Step(0); + } else { + scheduler->Step(); + } +} + +StateDict SequentialLR::State() const { + StateDict state; + state["last_step"] = last_step_; + state["recover_lr"] = optimizer_->learning_rate(); + state["base_lr"] = base_lr_; + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void SequentialLR::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + recover_lr_ = std::get(state.at("recover_lr")); + base_lr_ = std::get(state.at("base_lr")); + + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } + optimizer_->set_learning_rate(recover_lr_); +} + +ChainedScheduler::ChainedScheduler(std::shared_ptr optimizer, + std::vector> schedulers, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) { + CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler."; + + for (size_t i = 0; i < schedulers_.size(); ++i) { + CHECK(schedulers_[i]) << "ChainedScheduler: scheduler at index " << i << " must not be null."; + CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_)) + << "ChainedScheduler: scheduler at index " << i << " must share the same optimizer."; + } +} + +float ChainedScheduler::GetClosedFormLR() const { + LOG(FATAL) << "ChainedScheduler does not support closed-form LR. Use Step() without an explicit epoch."; + return base_lr_; +} + +void ChainedScheduler::InitialStep() { last_step_ = 0; } + +void ChainedScheduler::Step() { + ++last_step_; + for (auto &sched : schedulers_) { sched->Step(); } +} + +StateDict ChainedScheduler::State() const { + StateDict state = LRScheduler::State(); + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void ChainedScheduler::LoadState(const StateDict &state) { + LRScheduler::LoadState(state); + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } +} + +} // namespace lr_schedulers +} // namespace infini_train diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 55e5800b..2531ca60 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -114,6 +114,20 @@ void DistributedOptimizer::ZeroGrad(bool set_to_none) { } } +void DistributedOptimizer::set_learning_rate(float lr) { + Optimizer::set_learning_rate(lr); + if (base_optimizer_) { + base_optimizer_->set_learning_rate(lr); + } +} + +float DistributedOptimizer::learning_rate() const { + if (base_optimizer_) { + return base_optimizer_->learning_rate(); + } + return Optimizer::learning_rate(); +} + void DistributedOptimizer::Step() { // 1. Ensure grads are synced FinishGradSync(); diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index d5589b01..1712f50f 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -8,16 +8,32 @@ #include "infini_train/include/tensor.h" namespace infini_train { -Optimizer::Optimizer(const std::vector> ¶ms) : params_(params) {} +Optimizer::Optimizer(const std::vector> ¶ms, float learning_rate) + : params_(params), learning_rate_(learning_rate) {} void Optimizer::ZeroGrad(bool set_to_none) { for (auto param : params_) { param->ZeroGrad(set_to_none); } } +void Optimizer::set_learning_rate(float lr) { learning_rate_ = lr; } + +float Optimizer::learning_rate() const { return learning_rate_; } + +float Optimizer::initial_learning_rate() const { + CHECK(initial_lr_set_) << "Optimizer: initial_learning_rate not set. " + "Use with an LRScheduler first."; + return initial_learning_rate_; +} + +void Optimizer::set_initial_learning_rate(float lr) { + if (!initial_lr_set_) { + initial_learning_rate_ = lr; + initial_lr_set_ = true; + } +} namespace optimizers { -SGD::SGD(const std::vector> ¶ms, float learning_rate) - : Optimizer(params), learning_rate_(learning_rate) {} +SGD::SGD(const std::vector> ¶ms, float learning_rate) : Optimizer(params, learning_rate) {} void SGD::Step() { for (auto param : params_) { @@ -33,7 +49,7 @@ void SGD::Step() { } Adam::Adam(const std::vector> ¶ms, float learning_rate, float beta1, float beta2, float eps) - : Optimizer(params), t_(0), learning_rate_(learning_rate), beta1_(beta1), beta2_(beta2), eps_(eps) { + : Optimizer(params, learning_rate), t_(0), beta1_(beta1), beta2_(beta2), eps_(eps) { for (const auto ¶m : params_) { m_.emplace_back(std::make_shared(param->Dims(), param->Dtype(), param->GetDevice())); diff --git a/scripts/test_config.json b/scripts/test_config.json index 54332f70..81282289 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -304,6 +304,182 @@ } ] }, + { + "tag": "lr_scheduler", + "tests": [ + { + "id": "3_none_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "lr_decay_style": "none" + } + }, + { + "id": "4_constant_tp4", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "constant", + "lr_warmup_iters": 0, + "lr_decay_iters": 0 + } + }, + { + "id": "5_linear_tp4_sp_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "linear", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "6_cosine_pp8", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 8, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "cosine", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "7_inverse_sqrt_pp4_vpp2", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 4, + "virtual_pipeline_parallel": 2, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "inverse-square-root", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "8_cosine_all_parallel_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "cosine", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "3_bfloat16_linear", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "linear", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 0 + } + }, + { + "id": "4_bfloat16_inverse_sqrt_tp4_distopt", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "inverse-square-root", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "5_bfloat16_constant_tp4_sp", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "constant", + "lr_warmup_iters": 0, + "lr_decay_iters": 10 + } + }, + { + "id": "8_bfloat16_none_all_parallel", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "learning_rate": 0.00001, + "lr_decay_style": "none" + } + } + ] + }, { "tag": "lora", "tests": [ diff --git a/tests/optimizer/test_lr_scheduler.cc b/tests/optimizer/test_lr_scheduler.cc new file mode 100644 index 00000000..6d346bce --- /dev/null +++ b/tests/optimizer/test_lr_scheduler.cc @@ -0,0 +1,336 @@ +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "infini_train/include/lr_scheduler.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +#include "tests/common/test_utils.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { + +constexpr float kBaseLR = 0.1f; +constexpr float kEps = 1e-6f; + +class LRSchedulerTest : public infini_train::test::InfiniTrainTest {}; + +class LinearDecayScheduler : public LRScheduler { +public: + LinearDecayScheduler(std::shared_ptr optimizer, int64_t total_steps, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step), total_steps_(total_steps) {} + +protected: + float GetClosedFormLR() const override { + if (last_step_ >= total_steps_) { + return 0.0f; + } + return base_lr_ * (1.0f - static_cast(last_step_) / static_cast(total_steps_)); + } + +private: + int64_t total_steps_; +}; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +void ExpectStepSequence(const std::shared_ptr &scheduler, std::initializer_list expected, + float eps = kEps) { + for (float expected_lr : expected) { + scheduler->Step(); + EXPECT_NEAR(scheduler->GetLR(), expected_lr, eps); + } +} + +std::shared_ptr MakeSequentialScheduler(std::shared_ptr opt) { + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + return LRScheduler::Create(opt, + /*schedulers=*/std::vector>{linear, step_lr}, + /*milestones=*/std::vector{3}); +} + +std::shared_ptr MakeChainedScheduler(std::shared_ptr opt) { + auto step_lr = LRScheduler::Create(opt, /*step_size=*/2, /*gamma=*/0.5f); + auto lambda_lr = LRScheduler::Create(opt, /*lr_lambda=*/[](int64_t step) { return 1.0f - 0.05f * step; }); + return LRScheduler::Create( + opt, /*schedulers=*/std::vector>{step_lr, lambda_lr}); +} + +} // namespace + +TEST_P(LRSchedulerTest, BaseSchedulerStateRoundTripAndResume) { + constexpr int64_t kTotalSteps = 20; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, /*total_steps=*/kTotalSteps); + ExpectStepSequence(sched_ref, {0.095f, 0.09f, 0.085f, 0.08f, 0.075f, 0.07f, 0.065f}); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, /*total_steps=*/kTotalSteps); + ExpectStepSequence(sched_a, {0.095f, 0.09f, 0.085f}); + + StateDict state = sched_a->State(); + EXPECT_EQ(state.count("last_step"), 1); + EXPECT_EQ(state.count("recover_lr"), 1); + EXPECT_EQ(state.count("base_lr"), 1); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, /*total_steps=*/kTotalSteps); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.08f, 0.075f, 0.07f, 0.065f}); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + EXPECT_NEAR(opt_b->learning_rate(), sched_ref->GetLR(), kEps); +} + +TEST_P(LRSchedulerTest, ConstantLRMatchesExpectedSchedule) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + + EXPECT_EQ(sched->LastStep(), 0); + EXPECT_NEAR(sched->GetLR(), 0.05f, kEps); + EXPECT_NEAR(opt->learning_rate(), 0.05f, kEps); + + ExpectStepSequence(sched, {0.05f, 0.05f, 0.1f, 0.1f, 0.1f}); +} + +TEST_P(LRSchedulerTest, LinearLRMatchesExpectedScheduleAndClosedForm) { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + + EXPECT_NEAR(chainable->GetLR(), 0.02f, kEps); + ExpectStepSequence(chainable, {0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}, 1e-7f); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + auto opt_c = MakeDummyOptimizer(kBaseLR); + auto chainable_again = LRScheduler::Create(opt_c, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + + for (int epoch = 1; epoch <= 10; ++epoch) { + chainable_again->Step(); + closed_form->Step(epoch); + EXPECT_NEAR(chainable_again->GetLR(), closed_form->GetLR(), kEps); + } +} + +TEST_P(LRSchedulerTest, StepLRMatchesExpectedScheduleAndClosedForm) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + + EXPECT_NEAR(sched->GetLR(), kBaseLR, kEps); + ExpectStepSequence(sched, {0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}, 1e-7f); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, /*step_size=*/3, /*gamma=*/0.1f); + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, /*step_size=*/3, /*gamma=*/0.1f); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + EXPECT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + +TEST_P(LRSchedulerTest, LambdaLRMatchesExpectedScheduleAndRestoresState) { + auto lambda_fn = [](int64_t step) { return static_cast(std::pow(0.95, step)); }; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, /*lr_lambda=*/lambda_fn); + ExpectStepSequence(sched_ref, {0.095f, 0.09025f, 0.0857375f, 0.08145062f}, 1e-6f); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, /*lr_lambda=*/lambda_fn); + ExpectStepSequence(sched_a, {0.095f, 0.09025f}, 1e-6f); + StateDict state = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, /*lr_lambda=*/lambda_fn); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.0857375f, 0.08145062f}, 1e-6f); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), 1e-6f); +} + +TEST_P(LRSchedulerTest, SequentialLRSwitchesAtMilestonesAndRestoresState) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = MakeSequentialScheduler(opt); + + EXPECT_NEAR(sched->GetLR(), 0.0f, kEps); + ExpectStepSequence(sched, {0.1f / 3.0f, 0.2f / 3.0f, 0.1f, 0.1f, 0.1f, 0.05f}, 1e-5f); + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = MakeSequentialScheduler(opt_ref); + ExpectStepSequence(sched_ref, {0.1f / 3.0f, 0.2f / 3.0f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.05f}, 1e-5f); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = MakeSequentialScheduler(opt_a); + ExpectStepSequence(sched_a, {0.1f / 3.0f, 0.2f / 3.0f, 0.1f, 0.1f}, 1e-5f); + StateDict state = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = MakeSequentialScheduler(opt_b); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.1f, 0.05f, 0.05f, 0.05f}, 1e-5f); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); +} + +TEST_P(LRSchedulerTest, ChainedSchedulerComposesChildrenAndRestoresState) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = MakeChainedScheduler(opt); + + EXPECT_NEAR(sched->GetLR(), 0.1f, kEps); + ExpectStepSequence(sched, {0.095f, 0.09f, 0.085f, 0.08f}, kEps); + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = MakeChainedScheduler(opt_ref); + ExpectStepSequence(sched_ref, {0.095f, 0.09f, 0.085f, 0.08f, 0.075f, 0.07f}, kEps); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = MakeChainedScheduler(opt_a); + ExpectStepSequence(sched_a, {0.095f, 0.09f, 0.085f}, kEps); + StateDict state = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = MakeChainedScheduler(opt_b); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.08f, 0.075f, 0.07f}, kEps); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); +} + +TEST_P(LRSchedulerTest, TrainingSchedulerFactoryBuildsCommonDecayStyles) { + { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "constant", + .lr = 0.1f, + .min_lr = 0.0f, + .lr_decay_iters = 10, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + EXPECT_NEAR(sched->GetLR(), 0.1f, kEps); + ExpectStepSequence(sched, {0.1f, 0.1f, 0.1f}, kEps); + } + + { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "linear", + .lr = 1.0f, + .min_lr = 0.1f, + .lr_decay_iters = 6, + .lr_warmup_iters = 2, + .lr_warmup_init = 0.0f, + }); + EXPECT_NEAR(sched->GetLR(), 0.0f, kEps); + ExpectStepSequence(sched, {0.5f, 1.0f, 0.775f, 0.55f, 0.325f, 0.1f}, kEps); + } + + { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "cosine", + .lr = 1.0f, + .min_lr = 0.0f, + .lr_decay_iters = 4, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + ExpectStepSequence(sched, {0.853553f, 0.5f, 0.146447f, 0.0f}, 1e-5f); + } + + { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "inverse-square-root", + .lr = 1.0f, + .min_lr = 0.1f, + .lr_decay_iters = 10, + .lr_warmup_iters = 2, + .lr_warmup_init = 0.0f, + }); + ExpectStepSequence(sched, {0.5f, 1.0f, 0.8164966f, 0.7071068f, 0.6324555f}, 1e-5f); + } +} + +TEST_P(LRSchedulerTest, TrainingSchedulerFactoryReturnsNullForNoneStyle) { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "none", + .lr = 0.1f, + .min_lr = 0.0f, + .lr_decay_iters = 10, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + + EXPECT_EQ(sched, nullptr); +} + +TEST_P(LRSchedulerTest, RejectsInvalidSchedulerConfigurations) { + EXPECT_DEATH( + { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create(opt, /*step_size=*/0, /*gamma=*/0.1f); + (void)sched; + }, + ""); + + EXPECT_DEATH( + { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create(opt, /*lr_lambda=*/LambdaLR::LambdaFunc{}); + (void)sched; + }, + ""); + + EXPECT_DEATH( + { + auto opt1 = MakeDummyOptimizer(0.1f); + auto opt2 = MakeDummyOptimizer(0.1f); + auto linear = LRScheduler::Create(opt1, /*start_factor=*/0.5f, /*end_factor=*/1.0f, + /*total_iters=*/2); + auto step_lr = LRScheduler::Create(opt2, /*step_size=*/2, /*gamma=*/0.5f); + auto sched = LRScheduler::Create( + opt1, /*schedulers=*/std::vector>{linear, step_lr}, + /*milestones=*/std::vector{1}); + (void)sched; + }, + ""); + + EXPECT_DEATH( + { + auto opt = MakeDummyOptimizer(0.1f); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/2, /*gamma=*/0.5f); + std::shared_ptr sched = LRScheduler::Create( + opt, + /*schedulers=*/std::vector>{step_lr}); + sched->Step(0); + }, + ""); +} + +INFINI_TRAIN_REGISTER_TEST(LRSchedulerTest);