Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -56,8 +57,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
DEFINE_double(learning_rate, 1e-4, "Peak learning rate.");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
// lr scheduler
// Train() copies these flags (together with learning_rate) into a TrainingLRSchedulerConfig and
// passes it to CreateLRScheduler(). lr_decay_style is restricted by the validator registered for it.
// Train() substitutes num_iteration for lr_decay_iters whenever the flag is not strictly positive.
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -100,6 +107,8 @@ constexpr char kDeviceCPU[] = "cpu";
constexpr char kDeviceCUDA[] = "cuda";
constexpr char kDtypeFP32[] = "float32";
constexpr char kDtypeBF16[] = "bfloat16";
// Decay-style names accepted for --lr_decay_style; consulted by the validator registered on that flag.
// Keep this list in sync with the options spelled out in the flag's help text.
const std::unordered_set<std::string> kSupportedLRDecayStyles
= {"none", "constant", "linear", "cosine", "inverse-square-root"};

//
const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
Expand All @@ -114,6 +123,8 @@ const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
// Validator for --lr_decay_style: returns true only for names present in kSupportedLRDecayStyles.
DEFINE_validator(lr_decay_style,
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -305,6 +316,16 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(params_to_optimize);
}

const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
TrainingLRSchedulerConfig sched_config;
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
sched_config.lr_decay_style = FLAGS_lr_decay_style;
sched_config.lr_decay_iters = lr_decay_iters;
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
Expand Down Expand Up @@ -348,6 +369,7 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
float lossf = 0.0f;
// model->Train();
if (pp_world_size == 1) {
Expand Down Expand Up @@ -392,6 +414,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -401,6 +426,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -416,11 +444,10 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
35 changes: 31 additions & 4 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -55,8 +56,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_double(learning_rate, 1e-5, "Peak learning rate.");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
// lr scheduler
// Train() copies these flags (together with learning_rate) into a TrainingLRSchedulerConfig and
// passes it to CreateLRScheduler(). lr_decay_style is restricted by the validator registered for it.
// Train() substitutes num_iteration for lr_decay_iters whenever the flag is not strictly positive.
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -95,11 +102,15 @@ constexpr char kDeviceCPU[] = "cpu";
constexpr char kDeviceCUDA[] = "cuda";
constexpr char kDtypeFP32[] = "float32";
constexpr char kDtypeBF16[] = "bfloat16";
// Decay-style names accepted for --lr_decay_style; consulted by the validator registered on that flag.
// Keep this list in sync with the options spelled out in the flag's help text.
const std::unordered_set<std::string> kSupportedLRDecayStyles
= {"none", "constant", "linear", "cosine", "inverse-square-root"};
} // namespace

DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
// Validator for --lr_decay_style: returns true only for names present in kSupportedLRDecayStyles.
DEFINE_validator(lr_decay_style,
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -284,6 +295,16 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(params_to_optimize);
}

const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
TrainingLRSchedulerConfig sched_config;
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
sched_config.lr_decay_style = FLAGS_lr_decay_style;
sched_config.lr_decay_iters = lr_decay_iters;
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
Expand Down Expand Up @@ -324,6 +345,7 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
float lossf = 0.0f;
if (pp_world_size == 1) {
// model->Train();
Expand Down Expand Up @@ -367,6 +389,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -376,6 +401,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -391,11 +419,10 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
173 changes: 173 additions & 0 deletions infini_train/include/lr_scheduler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#pragma once

#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

namespace infini_train {

class Optimizer;

// A single serializable scheduler state entry: one of int64_t, float, double, std::string or a
// vector of floats.
using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>;
// String-keyed collection of StateValue entries; the type exchanged by LRScheduler::State() and
// LRScheduler::LoadState().
using StateDict = std::unordered_map<std::string, StateValue>;

/**
 * Settings bundle consumed by CreateLRScheduler().
 *
 * The example training programs fill one of these from their command-line flags; the per-field
 * notes below repeat how those flags are described there.
 */
struct TrainingLRSchedulerConfig {
    // Decay schedule name. The example mains accept none|constant|linear|cosine|inverse-square-root.
    std::string lr_decay_style{"constant"};
    // Learning rate taken from --learning_rate (described there as the peak learning rate).
    float lr{0.0f};
    // Minimum learning rate.
    float min_lr{0.0f};
    // Number of iterations to decay the learning rate over.
    int64_t lr_decay_iters{1};
    // Number of linear warmup iterations.
    int64_t lr_warmup_iters{0};
    // Initial learning rate at the start of warmup.
    float lr_warmup_init{0.0f};
};

/**
 * Abstract base class for learning-rate schedulers that hold a shared Optimizer.
 *
 * Concrete schedulers must implement GetClosedFormLR(); GetChainedFormLR() has a base
 * implementation that subclasses may override. All non-inline members are defined out of line,
 * so the notes below describe only what this declaration itself establishes.
 *
 * Instances are non-copyable. Prefer constructing through Create(), which also runs InitialStep().
 */
class LRScheduler {
public:
    /**
     * Factory: builds a T by forwarding `args` to its constructor, calls InitialStep() on the
     * new object and returns it.
     *
     * InitialStep() is virtual and is invoked here on the fully constructed object, so overrides
     * in T are the ones that run (a call made from a base-class constructor would not reach them).
     *
     * @tparam T    Concrete scheduler type; must derive from LRScheduler.
     * @param args  Constructor arguments for T.
     * @return      Shared pointer to the initialized scheduler.
     */
    template <typename T, typename... Args> static std::shared_ptr<T> Create(Args &&...args) {
        // Reject unrelated types up front with a readable diagnostic instead of a failure deep
        // inside make_shared or on the InitialStep() call.
        static_assert(std::is_base_of_v<LRScheduler, T>,
                      "LRScheduler::Create<T>: T must derive from LRScheduler");
        auto scheduler = std::make_shared<T>(std::forward<Args>(args)...);
        scheduler->InitialStep();
        return scheduler;
    }

    /**
     * @param optimizer  Optimizer associated with this scheduler.
     * @param last_step  Starting step index; defaults to -1.
     */
    explicit LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1);
    virtual ~LRScheduler() = default;

    // Non-copyable.
    LRScheduler(const LRScheduler &) = delete;
    LRScheduler &operator=(const LRScheduler &) = delete;

    // Advances the schedule. The example training loops call this once per iteration, right
    // after the optimizer update.
    virtual void Step();
    // Overload taking an explicit index. NOTE(review): presumably moves the schedule to that
    // step/epoch — confirm against the definition.
    virtual void Step(int64_t epoch);
    // Hook invoked once by Create() immediately after construction.
    virtual void InitialStep();

    // Learning-rate accessor; the example training loops read it at the top of each iteration
    // for logging.
    float GetLR() const;
    // NOTE(review): presumably expose base_lr_ / last_step_ respectively — confirm.
    float BaseLR() const;
    int64_t LastStep() const;

    // NOTE(review): presumably rewinds the step counter (default -1, the same default as the
    // constructor's last_step) — confirm.
    void ResetStep(int64_t step = -1);
    // State export/import expressed as a StateDict; overridable by subclasses.
    virtual StateDict State() const;
    virtual void LoadState(const StateDict &state);

    // NOTE(review): presumably reports whether `opt` is the optimizer held in optimizer_ — confirm.
    bool SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const;

protected:
    // Must be provided by every concrete scheduler.
    virtual float GetClosedFormLR() const = 0;
    // Optional customization point with a base-class implementation.
    virtual float GetChainedFormLR() const;
    // NOTE(review): presumably pushes `lr` to the optimizer — confirm.
    void ApplyLR(float lr);

    std::shared_ptr<Optimizer> optimizer_;
    // Scalar state carries in-class defaults so that no member is ever left indeterminate,
    // whatever the out-of-line constructor chooses to initialize. -1 mirrors the constructor's
    // default `last_step` argument.
    int64_t last_step_ = -1;
    float recover_lr_ = 0.0f;
    float base_lr_ = 0.0f;
    bool is_initial_ = false;
};

// Builds a scheduler for `optimizer` from the settings in `config` (defined out of line).
// NOTE(review): the example training loops null-check the returned pointer before using it, so a
// null result is presumably possible for some configurations — confirm against the definition.
std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer,
const TrainingLRSchedulerConfig &config);

namespace lr_schedulers {

// Scheduler configured by a multiplicative `factor` (default 1/3) and an iteration count
// `total_iters` (default 5). Overrides both the chained-form and closed-form LR hooks.
// NOTE(review): parameter names and defaults look modelled on PyTorch's ConstantLR (factor applied
// for the first total_iters steps) — confirm against the definition.
// NOTE(review): `total_iters` is declared `int` here but stored in an int64_t member, while
// LinearLR takes int64_t; widening it requires changing the out-of-line definition as well.
class ConstantLR : public LRScheduler {
public:
ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor = 1.0f / 3.0f, int total_iters = 5,
int64_t last_step = -1);
~ConstantLR() override = default;

protected:
float GetChainedFormLR() const override;
float GetClosedFormLR() const override;

private:
const float factor_;
const int64_t total_iters_;
};

// Scheduler configured by a `step_size` (no default) and a multiplicative `gamma` (default 0.1).
// Overrides both the chained-form and closed-form LR hooks.
// NOTE(review): presumably scales the LR by gamma every step_size steps, as in PyTorch's StepLR —
// confirm against the definition.
class StepLR : public LRScheduler {
public:
StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1);
~StepLR() override = default;

protected:
float GetChainedFormLR() const override;
float GetClosedFormLR() const override;

private:
const int64_t step_size_;
const float gamma_;
};

// Scheduler configured by a `start_factor` (default 1/3), an `end_factor` (default 1) and an
// iteration count `total_iters` (default 5). Overrides both LR hooks.
// NOTE(review): presumably interpolates the LR multiplier linearly from start_factor to end_factor
// over total_iters steps, as in PyTorch's LinearLR — confirm against the definition.
class LinearLR : public LRScheduler {
public:
LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f,
int64_t total_iters = 5, int64_t last_step = -1);
~LinearLR() override = default;

protected:
float GetChainedFormLR() const override;
float GetClosedFormLR() const override;

private:
const float start_factor_;
const float end_factor_;
const int64_t total_iters_;
};

// Scheduler driven by a user-supplied callable. Only the closed-form hook is overridden; the
// chained-form hook is inherited from LRScheduler.
class LambdaLR : public LRScheduler {
public:
// Callable mapping an int64_t to a float.
// NOTE(review): presumably step index -> LR multiplier — confirm against the definition.
using LambdaFunc = std::function<float(int64_t)>;

LambdaLR(std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1);
~LambdaLR() override = default;

protected:
float GetClosedFormLR() const override;

private:
const LambdaFunc lr_lambda_;
};

// Composite scheduler holding a list of child schedulers plus a list of int64_t milestones.
// Customizes stepping, initial stepping and state (de)serialization.
// NOTE(review): presumably switches from one child to the next at each milestone, as in PyTorch's
// SequentialLR — confirm against the definition.
// NOTE(review): overriding Step() hides the inherited Step(int64_t) overload for calls made
// through this type; add `using LRScheduler::Step;` if that overload should stay reachable.
class SequentialLR : public LRScheduler {
public:
SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
std::vector<int64_t> milestones, int64_t last_step = -1);
~SequentialLR() override = default;

void Step() override;
void InitialStep() override;

StateDict State() const override;
void LoadState(const StateDict &state) override;

protected:
float GetClosedFormLR() const override;
// Non-virtual helper. NOTE(review): presumably reverts the effect of the children's own
// InitialStep() calls — confirm against the definition.
void UndoChildInitialSteps();

private:
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
std::vector<int64_t> milestones_;
};

// Composite scheduler holding a list of child schedulers. Customizes stepping, initial stepping
// and state (de)serialization.
// NOTE(review): presumably steps every child on each Step(), as in PyTorch's ChainedScheduler —
// confirm against the definition.
// NOTE(review): overriding Step() hides the inherited Step(int64_t) overload for calls made
// through this type; add `using LRScheduler::Step;` if that overload should stay reachable.
class ChainedScheduler : public LRScheduler {
public:
ChainedScheduler(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
int64_t last_step = -1);
~ChainedScheduler() override = default;

void Step() override;
void InitialStep() override;

StateDict State() const override;
void LoadState(const StateDict &state) override;

protected:
float GetClosedFormLR() const override;

private:
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
};

} // namespace lr_schedulers
} // namespace infini_train
3 changes: 3 additions & 0 deletions infini_train/include/nn/parallel/ddp/distributed_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class DistributedOptimizer final : public infini_train::Optimizer {
void StartParamSync(bool force_sync = false);
void FinishParamSync(bool skip_next_bucket_dispatch = false);

// Learning-rate setter/getter overriding the corresponding virtuals of the base class.
// NOTE(review): `virtual` is redundant next to `override`; consider dropping it together with any
// matching style cleanup elsewhere in this class.
virtual void set_learning_rate(float lr) override;
virtual float learning_rate() const override;

private:
void BuildShardParamsAndBindGrads();

Expand Down
Loading
Loading