diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index d27421c8..8d3b7efa 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -15,6 +15,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" @@ -60,8 +61,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run"); DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization -DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); +DEFINE_double(learning_rate, 1e-4, "Peak learning rate."); DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer"); +// lr scheduler +DEFINE_double(min_lr, 0.0, "Minimum learning rate."); +DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root"); +DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations."); +DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup."); +DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration)."); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -109,6 +116,8 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; +const std::unordered_set kSupportedLRDecayStyles + = {"none", "constant", "linear", "cosine", "inverse-square-root"}; // const std::unordered_map kModelToConfigs = { @@ -124,6 +133,8 @@ DEFINE_validator(model, [](const char *, const std::string &value) { return kSup DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); +DEFINE_validator(lr_decay_style, + [](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -314,6 +325,16 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(params_to_optimize); } + const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration; + TrainingLRSchedulerConfig sched_config; + sched_config.lr = static_cast(FLAGS_learning_rate); + sched_config.min_lr = static_cast(FLAGS_min_lr); + sched_config.lr_decay_style = FLAGS_lr_decay_style; + sched_config.lr_decay_iters = lr_decay_iters; + sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters; + sched_config.lr_warmup_init = static_cast(FLAGS_lr_warmup_init); + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast( @@ -332,7 +353,8 @@ void Train(const nn::parallel::Rank &rank) { .optimizer = optimizer, .model_config = model_config, .state = state, - .load_optimizer_state = false}); + .load_optimizer_state = false, + .lr_scheduler = scheduler}); start_step = resume_result.global_step; size_t consumed_batches = resume_result.consumed_batches; @@ -351,7 +373,7 @@ void Train(const nn::parallel::Rank &rank) { .save_dir = save_dir, .global_step = global_step, .consumed_batches = consumed_batches, - .last_lr = FLAGS_learning_rate, + .last_lr = optimizer->learning_rate(), .n_layer = model_config.n_layer, .n_head = model_config.n_head, .n_kv_head = model_config.n_kv_head, @@ -367,6 +389,7 @@ void Train(const nn::parallel::Rank &rank) { .rank = rank, .model = *model, .optimizer = *optimizer, + .lr_scheduler = scheduler.get(), }); }; @@ -403,6 +426,7 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); float lossf = 0.0f; // model->Train(); if (pp_world_size == 1) { @@ -448,6 +472,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -458,6 +485,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -473,11 +503,10 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 72538a73..9050c2b4 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -14,6 +14,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" @@ -59,8 +60,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run"); DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization -DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); +DEFINE_double(learning_rate, 1e-5, "Peak learning rate."); DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer"); +// lr scheduler +DEFINE_double(min_lr, 0.0, "Minimum learning rate."); +DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root"); +DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations."); +DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup."); +DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration)."); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -105,12 +112,16 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; +const std::unordered_set kSupportedLRDecayStyles + = {"none", "constant", "linear", "cosine", "inverse-square-root"}; } // namespace DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; }); +DEFINE_validator(lr_decay_style, + [](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -294,6 +305,16 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(params_to_optimize); } + const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration; + TrainingLRSchedulerConfig sched_config; + sched_config.lr = static_cast(FLAGS_learning_rate); + sched_config.min_lr = static_cast(FLAGS_min_lr); + sched_config.lr_decay_style = FLAGS_lr_decay_style; + sched_config.lr_decay_iters = lr_decay_iters; + sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters; + sched_config.lr_warmup_init = static_cast(FLAGS_lr_warmup_init); + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(std::make_shared()) @@ -311,7 +332,8 @@ void Train(const nn::parallel::Rank &rank) { .optimizer = optimizer, .model_config = model_config, .state = state, - .load_optimizer_state = true}); + .load_optimizer_state = true, + .lr_scheduler = scheduler}); start_step = resume_result.global_step; size_t consumed_batches = resume_result.consumed_batches; @@ -331,7 +353,7 @@ void Train(const nn::parallel::Rank &rank) { .save_dir = save_dir, .global_step = global_step, .consumed_batches = consumed_batches, - .last_lr = FLAGS_learning_rate, + .last_lr = optimizer->learning_rate(), .n_layer = model_config.n_layer, .n_head = model_config.n_head, .n_kv_head = model_config.n_kv_head, @@ -347,6 +369,7 @@ void Train(const nn::parallel::Rank &rank) { .rank = rank, .model = *model, .optimizer = *optimizer, + .lr_scheduler = scheduler.get(), }); }; @@ -381,6 +404,7 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); float lossf = 0.0f; if (pp_world_size == 1) { // model->Train(); @@ -425,6 +449,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -435,6 +462,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -450,11 +480,10 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/infini_train/include/checkpoint/checkpoint.h b/infini_train/include/checkpoint/checkpoint.h index b122a17a..c71f0d7b 100644 --- a/infini_train/include/checkpoint/checkpoint.h +++ b/infini_train/include/checkpoint/checkpoint.h @@ -9,6 +9,7 @@ namespace infini_train { class Optimizer; +class LRScheduler; class Tensor; namespace nn { class Module; @@ -34,10 +35,12 @@ struct TrainerState { class Checkpoint { public: static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer, - const TrainerState &state, bool save_optimizer_state); + const TrainerState &state, bool save_optimizer_state, const LRScheduler *lr_scheduler, + bool save_lr_scheduler_state); static void Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer, - TrainerState &state, bool load_optimizer_state); + TrainerState &state, bool load_optimizer_state, LRScheduler *lr_scheduler, + bool load_lr_scheduler_state); private: static void SaveStateDict(const std::filesystem::path &path, diff --git a/infini_train/include/checkpoint/checkpoint_manager.h b/infini_train/include/checkpoint/checkpoint_manager.h index cce14107..58434faa 100644 --- a/infini_train/include/checkpoint/checkpoint_manager.h +++ b/infini_train/include/checkpoint/checkpoint_manager.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "infini_train/include/checkpoint/checkpoint.h" #include "infini_train/include/dataloader.h" @@ -13,6 +14,10 @@ using namespace infini_train; namespace nn = infini_train::nn; +namespace infini_train { +class LRScheduler; +} + namespace infini_train::nn { class TransformerConfig; } @@ -25,6 +30,8 @@ struct ResumeFromCheckpointArgs { const nn::TransformerConfig &model_config; TrainerState &state; bool load_optimizer_state; + std::shared_ptr lr_scheduler = nullptr; + bool load_lr_scheduler_state = true; }; struct ResumeFromCheckpointResult { @@ -52,6 +59,8 @@ struct SaveCheckpointArgs { const nn::parallel::Rank &rank; const nn::Module &model; const Optimizer &optimizer; + const LRScheduler *lr_scheduler = nullptr; + bool save_lr_scheduler_state = true; }; ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args); diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h new file mode 100644 index 00000000..13c8e79a --- /dev/null +++ b/infini_train/include/lr_scheduler.h @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini_train { + +class Optimizer; + +using StateValue = std::variant>; +using StateDict = std::unordered_map; + +struct TrainingLRSchedulerConfig { + std::string lr_decay_style = "constant"; + float lr = 0.0f; + float min_lr = 0.0f; + int64_t lr_decay_iters = 1; + int64_t lr_warmup_iters = 0; + float lr_warmup_init = 0.0f; +}; + +class LRScheduler { +public: + template static std::shared_ptr Create(Args &&...args) { + auto scheduler = std::make_shared(std::forward(args)...); + scheduler->InitialStep(); + return scheduler; + } + + explicit LRScheduler(std::shared_ptr optimizer, int64_t last_step = -1); + virtual ~LRScheduler() = default; + + LRScheduler(const LRScheduler &) = delete; + LRScheduler &operator=(const LRScheduler &) = delete; + + virtual void Step(); + virtual void Step(int64_t epoch); + virtual void InitialStep(); + + float GetLR() const; + float BaseLR() const; + int64_t LastStep() const; + + void ResetStep(int64_t step = -1); + virtual StateDict State() const; + virtual void LoadState(const StateDict &state); + + bool SharesOptimizerWith(const std::shared_ptr &opt) const; + +protected: + virtual float GetClosedFormLR() const = 0; + virtual float GetChainedFormLR() const; + void ApplyLR(float lr); + + std::shared_ptr optimizer_; + int64_t last_step_; + float recover_lr_; + float base_lr_; + bool is_initial_ = false; +}; + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, + const TrainingLRSchedulerConfig &config); + +namespace lr_schedulers { + +class ConstantLR : public LRScheduler { +public: + ConstantLR(std::shared_ptr optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, + int64_t last_step = -1); + ~ConstantLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float factor_; + const int64_t total_iters_; +}; + +class StepLR : public LRScheduler { +public: + StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); + ~StepLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const int64_t step_size_; + const float gamma_; +}; + +class LinearLR : public LRScheduler { +public: + LinearLR(std::shared_ptr optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f, + int64_t total_iters = 5, int64_t last_step = -1); + ~LinearLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float start_factor_; + const float end_factor_; + const int64_t total_iters_; +}; + +class LambdaLR : public LRScheduler { +public: + using LambdaFunc = std::function; + + LambdaLR(std::shared_ptr optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); + ~LambdaLR() override = default; + +protected: + float GetClosedFormLR() const override; + +private: + const LambdaFunc lr_lambda_; +}; + +class SequentialLR : public LRScheduler { +public: + SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step = -1); + ~SequentialLR() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + float GetClosedFormLR() const override; + void UndoChildInitialSteps(); + +private: + std::vector> schedulers_; + std::vector milestones_; +}; + +class ChainedScheduler : public LRScheduler { +public: + ChainedScheduler(std::shared_ptr optimizer, std::vector> schedulers, + int64_t last_step = -1); + ~ChainedScheduler() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + float GetClosedFormLR() const override; + +private: + std::vector> schedulers_; +}; + +} // namespace lr_schedulers +} // namespace infini_train diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index 559c4312..d694ab2a 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -38,6 +38,9 @@ class DistributedOptimizer final : public infini_train::Optimizer { void StartParamSync(bool force_sync = false); void FinishParamSync(bool skip_next_bucket_dispatch = false); + virtual void set_learning_rate(float lr) override; + virtual float learning_rate() const override; + private: void BuildShardParamsAndBindGrads(); diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index 01051053..4cf8e6b7 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -17,7 +17,7 @@ using OptimizerCreator = std::function(const std::vec class Optimizer { public: - explicit Optimizer(const std::vector> ¶ms); + explicit Optimizer(const std::vector> ¶ms, float learning_rate = 0.0f); virtual void ZeroGrad(bool set_to_none = true); @@ -27,8 +27,19 @@ class Optimizer { virtual void LoadStateDict(const std::unordered_map> &state_dict) {} + virtual void set_learning_rate(float lr); + + virtual float learning_rate() const; + + float initial_learning_rate() const; + + void set_initial_learning_rate(float lr); + protected: std::vector> params_; + float learning_rate_ = 0.0f; + float initial_learning_rate_ = 0.0f; + bool initial_lr_set_ = false; }; namespace optimizers { @@ -39,9 +50,6 @@ class SGD : public Optimizer { void Step() override; static OptimizerCreator Create(float learning_rate); - -private: - const float learning_rate_ = 0.0; }; class Adam : public Optimizer { @@ -59,7 +67,6 @@ class Adam : public Optimizer { private: int64_t t_; - const float learning_rate_; const float beta1_; const float beta2_; const float eps_; diff --git a/infini_train/src/checkpoint/checkpoint.cc b/infini_train/src/checkpoint/checkpoint.cc index 892ec497..427d9323 100644 --- a/infini_train/src/checkpoint/checkpoint.cc +++ b/infini_train/src/checkpoint/checkpoint.cc @@ -5,9 +5,13 @@ #include #include #include +#include +#include +#include #include "glog/logging.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/optimizer.h" #include "infini_train/include/tensor.h" @@ -16,6 +20,16 @@ namespace infini_train { namespace { constexpr uint32_t kCkptMagic = 0x54504B43; // CKPT constexpr uint32_t kCkptVersion = 1; +constexpr uint32_t kLRSchedulerMagic = 0x53524C53; // SLRS +constexpr uint32_t kLRSchedulerVersion = 1; + +enum class LRSchedulerStateValueType : uint8_t { + kInt64 = 1, + kFloat = 2, + kDouble = 3, + kString = 4, + kFloatVector = 5, +}; void WriteString(std::ofstream *ofs, const std::string &value) { uint32_t len = static_cast(value.size()); @@ -31,6 +45,111 @@ std::string ReadString(std::ifstream *ifs) { return s; } +void WriteLRSchedulerStateValue(std::ofstream *ofs, const StateValue &value) { + if (std::holds_alternative(value)) { + const auto type = LRSchedulerStateValueType::kInt64; + const auto data = std::get(value); + ofs->write(reinterpret_cast(&type), sizeof(type)); + ofs->write(reinterpret_cast(&data), sizeof(data)); + } else if (std::holds_alternative(value)) { + const auto type = LRSchedulerStateValueType::kFloat; + const auto data = std::get(value); + ofs->write(reinterpret_cast(&type), sizeof(type)); + ofs->write(reinterpret_cast(&data), sizeof(data)); + } else if (std::holds_alternative(value)) { + const auto type = LRSchedulerStateValueType::kDouble; + const auto data = std::get(value); + ofs->write(reinterpret_cast(&type), sizeof(type)); + ofs->write(reinterpret_cast(&data), sizeof(data)); + } else if (std::holds_alternative(value)) { + const auto type = LRSchedulerStateValueType::kString; + ofs->write(reinterpret_cast(&type), sizeof(type)); + WriteString(ofs, std::get(value)); + } else if (std::holds_alternative>(value)) { + const auto type = LRSchedulerStateValueType::kFloatVector; + const auto &data = std::get>(value); + const auto size = static_cast(data.size()); + ofs->write(reinterpret_cast(&type), sizeof(type)); + ofs->write(reinterpret_cast(&size), sizeof(size)); + ofs->write(reinterpret_cast(data.data()), static_cast(size * sizeof(float))); + } else { + LOG(FATAL) << "Unsupported LR scheduler state value type."; + } +} + +StateValue ReadLRSchedulerStateValue(std::ifstream *ifs) { + LRSchedulerStateValueType type{}; + ifs->read(reinterpret_cast(&type), sizeof(type)); + switch (type) { + case LRSchedulerStateValueType::kInt64: { + int64_t data = 0; + ifs->read(reinterpret_cast(&data), sizeof(data)); + return data; + } + case LRSchedulerStateValueType::kFloat: { + float data = 0.0f; + ifs->read(reinterpret_cast(&data), sizeof(data)); + return data; + } + case LRSchedulerStateValueType::kDouble: { + double data = 0.0; + ifs->read(reinterpret_cast(&data), sizeof(data)); + return data; + } + case LRSchedulerStateValueType::kString: + return ReadString(ifs); + case LRSchedulerStateValueType::kFloatVector: { + uint64_t size = 0; + ifs->read(reinterpret_cast(&size), sizeof(size)); + std::vector data(size); + ifs->read(reinterpret_cast(data.data()), static_cast(size * sizeof(float))); + return data; + } + default: + LOG(FATAL) << "Unsupported LR scheduler state value type: " << static_cast(type); + } + return int64_t{0}; +} + +void SaveLRSchedulerState(const std::filesystem::path &path, const StateDict &state_dict) { + std::ofstream ofs(path, std::ios::binary); + CHECK(ofs.is_open()) << "Failed to open LR scheduler checkpoint file: " << path; + + const uint32_t magic = kLRSchedulerMagic; + const uint32_t version = kLRSchedulerVersion; + const uint32_t count = static_cast(state_dict.size()); + ofs.write(reinterpret_cast(&magic), sizeof(magic)); + ofs.write(reinterpret_cast(&version), sizeof(version)); + ofs.write(reinterpret_cast(&count), sizeof(count)); + + for (const auto &[name, value] : state_dict) { + WriteString(&ofs, name); + WriteLRSchedulerStateValue(&ofs, value); + } +} + +StateDict LoadLRSchedulerState(const std::filesystem::path &path) { + std::ifstream ifs(path, std::ios::binary); + CHECK(ifs.is_open()) << "Failed to open LR scheduler checkpoint file: " << path; + + uint32_t magic = 0; + uint32_t version = 0; + uint32_t count = 0; + ifs.read(reinterpret_cast(&magic), sizeof(magic)); + ifs.read(reinterpret_cast(&version), sizeof(version)); + ifs.read(reinterpret_cast(&count), sizeof(count)); + + CHECK_EQ(magic, kLRSchedulerMagic) << "Invalid LR scheduler checkpoint magic: " << path; + CHECK_EQ(version, kLRSchedulerVersion) << "Unsupported LR scheduler checkpoint version: " << path; + + StateDict state; + for (uint32_t i = 0; i < count; ++i) { + auto name = ReadString(&ifs); + state.emplace(std::move(name), ReadLRSchedulerStateValue(&ifs)); + } + return state; +} + // TODO: This is a hand-rolled JSON field extractor. Replace with a proper JSON library (e.g., nlohmann/json) once // available in the project dependencies. template T ExtractNumberField(const std::string &content, const std::string &key, T fallback) { @@ -63,7 +182,8 @@ template T ExtractNumberField(const std::string &content, const std } // namespace void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer, - const TrainerState &state, bool save_optimizer_state) { + const TrainerState &state, bool save_optimizer_state, const LRScheduler *lr_scheduler, + bool save_lr_scheduler_state) { std::filesystem::create_directories(checkpoint_dir); LOG(INFO) << "[CKPT] Save begin: dir=" << checkpoint_dir << ", global_step=" << state.global_step; @@ -80,12 +200,17 @@ void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Mod } } + if (save_lr_scheduler_state && lr_scheduler != nullptr) { + SaveLRSchedulerState(checkpoint_dir / "lr_scheduler.ckpt", lr_scheduler->State()); + } + SaveTrainerState(checkpoint_dir / "trainer_state.json", state); LOG(ERROR) << "[CKPT] Save done: dir=" << checkpoint_dir; } void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer, - TrainerState &state, bool load_optimizer_state) { + TrainerState &state, bool load_optimizer_state, LRScheduler *lr_scheduler, + bool load_lr_scheduler_state) { const auto model_path = checkpoint_dir / "model.ckpt"; LOG(INFO) << "[CKPT] Loading model: " << model_path; @@ -103,6 +228,18 @@ void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module &m } state = LoadTrainerState(checkpoint_dir / "trainer_state.json"); + + if (load_lr_scheduler_state && lr_scheduler != nullptr) { + const auto lr_scheduler_path = checkpoint_dir / "lr_scheduler.ckpt"; + if (std::filesystem::exists(lr_scheduler_path)) { + LOG(INFO) << "[CKPT] Loading LR scheduler: " << lr_scheduler_path; + lr_scheduler->LoadState(LoadLRSchedulerState(lr_scheduler_path)); + } else { + LOG(WARNING) << "[CKPT] LR scheduler checkpoint not found at: " << lr_scheduler_path + << ". Keeping the initialized scheduler state."; + } + } + LOG(ERROR) << "[CKPT] Load done: global_step=" << state.global_step << ", consumed_batches =" << state.consumed_batches << ", last_lr=" << state.last_lr << ", topology(ddp,tp,sp,pp)=(" << state.ddp_size << "," << state.tp_size << "," << state.sp_size << "," diff --git a/infini_train/src/checkpoint/checkpoint_manager.cc b/infini_train/src/checkpoint/checkpoint_manager.cc index 71e15c08..887446c9 100644 --- a/infini_train/src/checkpoint/checkpoint_manager.cc +++ b/infini_train/src/checkpoint/checkpoint_manager.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -38,7 +39,8 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs & } } - Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state, args.load_optimizer_state); + Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state, args.load_optimizer_state, + args.lr_scheduler.get(), args.load_lr_scheduler_state); result.global_step = static_cast(args.state.global_step); @@ -88,7 +90,8 @@ void SaveCheckpoint(const SaveCheckpointArgs &args) { state.sp_size = args.sp_size; state.pp_size = args.pp_size; - Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.save_optimizer_state); + Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.save_optimizer_state, args.lr_scheduler, + args.save_lr_scheduler_state); const auto ckpt_end = std::chrono::high_resolution_clock::now(); const double ckpt_ms = std::chrono::duration(ckpt_end - ckpt_start).count(); diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc new file mode 100644 index 00000000..42afb165 --- /dev/null +++ b/infini_train/src/lr_scheduler.cc @@ -0,0 +1,372 @@ +#include "infini_train/include/lr_scheduler.h" + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" + +namespace infini_train { + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, + const TrainingLRSchedulerConfig &config) { + if (config.lr_decay_style == "none") { + return nullptr; + } + + CHECK(optimizer) << "CreateLRScheduler: optimizer must not be null."; + const float max_lr = config.lr != 0.0f ? config.lr : optimizer->learning_rate(); + CHECK_GT(max_lr, 0.0f) << "CreateLRScheduler: max_lr must be > 0."; + CHECK_GE(config.lr_warmup_init, 0.0f) << "CreateLRScheduler: lr_warmup_init must be >= 0."; + CHECK_GE(config.min_lr, 0.0f) << "CreateLRScheduler: min_lr must be >= 0."; + CHECK_GE(max_lr, config.min_lr) << "CreateLRScheduler: max_lr must be >= min_lr."; + CHECK_LE(config.lr_warmup_init, max_lr) << "CreateLRScheduler: lr_warmup_init must be <= max_lr."; + CHECK_GE(config.lr_warmup_iters, 0) << "CreateLRScheduler: lr_warmup_iters must be >= 0."; + CHECK_GT(config.lr_decay_iters, 0) << "CreateLRScheduler: lr_decay_iters must be > 0."; + CHECK_LT(config.lr_warmup_iters, config.lr_decay_iters) + << "CreateLRScheduler: lr_warmup_iters must be < lr_decay_iters."; + CHECK(config.lr_decay_style == "constant" || config.lr_decay_style == "linear" || config.lr_decay_style == "cosine" + || config.lr_decay_style == "inverse-square-root") + << "CreateLRScheduler: unsupported lr_decay_style: " << config.lr_decay_style; + + std::shared_ptr main_scheduler; + const int64_t decay_iters_after_warmup = config.lr_decay_iters - config.lr_warmup_iters; + if (config.lr_decay_style == "constant") { + main_scheduler = LRScheduler::Create(optimizer, [](int64_t) { return 1.0f; }); + } else if (config.lr_decay_style == "linear") { + main_scheduler = LRScheduler::Create(optimizer, 1.0f, config.min_lr / max_lr, + decay_iters_after_warmup); + } else if (config.lr_decay_style == "cosine") { + main_scheduler = LRScheduler::Create( + optimizer, [max_lr, min_lr = config.min_lr, decay_iters_after_warmup](int64_t step) { + if (step > decay_iters_after_warmup) { + return min_lr / max_lr; + } + const float decay_ratio = static_cast(step) / static_cast(decay_iters_after_warmup); + CHECK_GE(decay_ratio, 0.0f) << "CreateLRScheduler: decay " + "ratio must be >= 0."; + CHECK_LE(decay_ratio, 1.0f) << "CreateLRScheduler: decay " + "ratio must be <= 1."; + const float coeff = 0.5f * (std::cos(std::numbers::pi_v * decay_ratio) + 1.0f); + return (min_lr + coeff * (max_lr - min_lr)) / max_lr; + }); + } else if (config.lr_decay_style == "inverse-square-root") { + main_scheduler = LRScheduler::Create( + optimizer, [max_lr, min_lr = config.min_lr, lr_warmup_iters = config.lr_warmup_iters, + lr_decay_iters = config.lr_decay_iters](int64_t step) { + const int64_t global_step = step + lr_warmup_iters; + if (global_step > lr_decay_iters) { + return min_lr / max_lr; + } + const auto warmup = static_cast(std::max(lr_warmup_iters, 1)); + const auto current = static_cast(std::max(global_step, 1)); + return std::max(min_lr, max_lr * std::sqrt(warmup) / std::sqrt(current)) / max_lr; + }); + } + + CHECK(main_scheduler) << "CreateLRScheduler: failed to create scheduler."; + if (config.lr_warmup_iters == 0) { + return main_scheduler; + } + + auto warmup_scheduler = LRScheduler::Create( + optimizer, + [lr_warmup_init = config.lr_warmup_init, max_lr, lr_warmup_iters = config.lr_warmup_iters](int64_t step) { + const float warmup_ratio = static_cast(step) / static_cast(lr_warmup_iters); + return (lr_warmup_init + (max_lr - lr_warmup_init) * warmup_ratio) / max_lr; + }); + return LRScheduler::Create( + std::move(optimizer), std::vector>{warmup_scheduler, main_scheduler}, + std::vector{config.lr_warmup_iters}); +} + +LRScheduler::LRScheduler(std::shared_ptr optimizer, int64_t last_step) + : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) { + CHECK(optimizer_) << "LRScheduler: optimizer must not be null."; + optimizer_->set_initial_learning_rate(optimizer_->learning_rate()); + base_lr_ = optimizer_->initial_learning_rate(); +} + +void LRScheduler::Step() { + ++last_step_; + ApplyLR(GetChainedFormLR()); +} + +void LRScheduler::Step(int64_t epoch) { + last_step_ = epoch; + ApplyLR(GetClosedFormLR()); +} + +void LRScheduler::InitialStep() { + is_initial_ = true; + Step(); + is_initial_ = false; +} + +void LRScheduler::ApplyLR(float lr) { optimizer_->set_learning_rate(lr); } + +float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); } + +float LRScheduler::GetLR() const { return optimizer_->learning_rate(); } + +float LRScheduler::BaseLR() const { return base_lr_; } + +int64_t LRScheduler::LastStep() const { return last_step_; } + +bool LRScheduler::SharesOptimizerWith(const std::shared_ptr &opt) const { return optimizer_ == opt; } + +void LRScheduler::ResetStep(int64_t step) { last_step_ = step; } + +StateDict LRScheduler::State() const { + return { + {"last_step", last_step_}, + {"recover_lr", optimizer_->learning_rate()}, + {"base_lr", base_lr_}, + }; +} + +void LRScheduler::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + recover_lr_ = std::get(state.at("recover_lr")); + base_lr_ = std::get(state.at("base_lr")); + optimizer_->set_learning_rate(recover_lr_); +} + +// Concrete LR Schedulers + +namespace lr_schedulers { + +// --- ConstantLR --- + +ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int total_iters, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) { + CHECK_GT(factor_, 0.0f) << "ConstantLR: factor must be > 0."; + CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1."; +} + +float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; } + +float ConstantLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0) { + return lr * factor_; + } else if (last_step_ < total_iters_) { + return lr; + } else if (last_step_ == total_iters_) { + return lr / factor_; + } + return lr; +} + +// --- StepLR --- + +StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) { + CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0."; + CHECK_GT(gamma_, 0.0f) << "StepLR: gamma must be > 0."; +} + +float StepLR::GetClosedFormLR() const { + return base_lr_ + * static_cast(std::pow(static_cast(gamma_), static_cast(last_step_ / step_size_))); +} + +float StepLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0 || (last_step_ % step_size_) != 0) { + return lr; + } + return lr * gamma_; +} + +LinearLR::LinearLR(std::shared_ptr optimizer, float start_factor, float end_factor, int64_t total_iters, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor), + total_iters_(total_iters) { + CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0."; + CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1."; + CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0."; + CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1."; + CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0."; +} + +float LinearLR::GetClosedFormLR() const { + if (last_step_ >= total_iters_) { + return base_lr_ * end_factor_; + } + return base_lr_ + * (start_factor_ + + (end_factor_ - start_factor_) * static_cast(last_step_) / static_cast(total_iters_)); +} + +float LinearLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0) { + return lr * start_factor_; + } + if (last_step_ > total_iters_ || is_initial_) { + return lr; + } + if (last_step_ == total_iters_) { + const float prev_factor + = start_factor_ + + (end_factor_ - start_factor_) * static_cast(total_iters_ - 1) / static_cast(total_iters_); + return lr * (end_factor_ / prev_factor); + } + + const float numerator = end_factor_ - start_factor_; + const float denominator + = start_factor_ * static_cast(total_iters_) + static_cast(last_step_ - 1) * numerator; + return lr * (1.0f + numerator / denominator); +} + +LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) { + CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null."; +} + +float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); } + +float SequentialLR::GetClosedFormLR() const { + LOG(FATAL) << "SequentialLR does not support closed-form LR. Use Step() without an explicit epoch."; + return base_lr_; +} + +SequentialLR::SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), + milestones_(std::move(milestones)) { + CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; + + for (size_t i = 0; i < schedulers_.size(); ++i) { + CHECK(schedulers_[i]) << "SequentialLR: scheduler at index " << i << " must not be null."; + CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_)) + << "SequentialLR: scheduler at index " << i << " must share the same optimizer."; + } + + CHECK_EQ(milestones_.size(), schedulers_.size() - 1) + << "SequentialLR: milestones count must be schedulers count - 1."; + + for (size_t i = 1; i < milestones_.size(); ++i) { + CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing."; + } +} + +void SequentialLR::InitialStep() { + + optimizer_->set_learning_rate(schedulers_[0]->BaseLR()); + + UndoChildInitialSteps(); + + ++last_step_; + schedulers_[0]->InitialStep(); +} + +void SequentialLR::UndoChildInitialSteps() { + for (auto &sched : schedulers_) { + if (auto nested = std::dynamic_pointer_cast(sched)) { + nested->UndoChildInitialSteps(); + } + sched->ResetStep(sched->LastStep() - 1); + } +} + +void SequentialLR::Step() { + ++last_step_; + size_t idx = std::upper_bound(milestones_.begin(), milestones_.end(), last_step_) - milestones_.begin(); + + auto &scheduler = schedulers_[idx]; + + if (idx > 0 && milestones_[idx - 1] == last_step_) { + scheduler->Step(0); + } else { + scheduler->Step(); + } +} + +StateDict SequentialLR::State() const { + StateDict state; + state["last_step"] = last_step_; + state["recover_lr"] = optimizer_->learning_rate(); + state["base_lr"] = base_lr_; + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void SequentialLR::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + recover_lr_ = std::get(state.at("recover_lr")); + base_lr_ = std::get(state.at("base_lr")); + + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } + optimizer_->set_learning_rate(recover_lr_); +} + +ChainedScheduler::ChainedScheduler(std::shared_ptr optimizer, + std::vector> schedulers, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) { + CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler."; + + for (size_t i = 0; i < schedulers_.size(); ++i) { + CHECK(schedulers_[i]) << "ChainedScheduler: scheduler at index " << i << " must not be null."; + CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_)) + << "ChainedScheduler: scheduler at index " << i << " must share the same optimizer."; + } +} + +float ChainedScheduler::GetClosedFormLR() const { + LOG(FATAL) << "ChainedScheduler does not support closed-form LR. Use Step() without an explicit epoch."; + return base_lr_; +} + +void ChainedScheduler::InitialStep() { last_step_ = 0; } + +void ChainedScheduler::Step() { + ++last_step_; + for (auto &sched : schedulers_) { sched->Step(); } +} + +StateDict ChainedScheduler::State() const { + StateDict state = LRScheduler::State(); + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void ChainedScheduler::LoadState(const StateDict &state) { + LRScheduler::LoadState(state); + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } +} + +} // namespace lr_schedulers +} // namespace infini_train diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index d1831c0f..022a4758 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -122,6 +122,20 @@ void DistributedOptimizer::ZeroGrad(bool set_to_none) { } } +void DistributedOptimizer::set_learning_rate(float lr) { + Optimizer::set_learning_rate(lr); + if (base_optimizer_) { + base_optimizer_->set_learning_rate(lr); + } +} + +float DistributedOptimizer::learning_rate() const { + if (base_optimizer_) { + return base_optimizer_->learning_rate(); + } + return Optimizer::learning_rate(); +} + void DistributedOptimizer::Step() { // 1. Ensure grads are synced FinishGradSync(); diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 15925b2f..00b3ca3b 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -9,16 +9,32 @@ #include "infini_train/include/tensor.h" namespace infini_train { -Optimizer::Optimizer(const std::vector> ¶ms) : params_(params) {} +Optimizer::Optimizer(const std::vector> ¶ms, float learning_rate) + : params_(params), learning_rate_(learning_rate) {} void Optimizer::ZeroGrad(bool set_to_none) { for (auto param : params_) { param->ZeroGrad(set_to_none); } } +void Optimizer::set_learning_rate(float lr) { learning_rate_ = lr; } + +float Optimizer::learning_rate() const { return learning_rate_; } + +float Optimizer::initial_learning_rate() const { + CHECK(initial_lr_set_) << "Optimizer: initial_learning_rate not set. " + "Use with an LRScheduler first."; + return initial_learning_rate_; +} + +void Optimizer::set_initial_learning_rate(float lr) { + if (!initial_lr_set_) { + initial_learning_rate_ = lr; + initial_lr_set_ = true; + } +} namespace optimizers { -SGD::SGD(const std::vector> ¶ms, float learning_rate) - : Optimizer(params), learning_rate_(learning_rate) {} +SGD::SGD(const std::vector> ¶ms, float learning_rate) : Optimizer(params, learning_rate) {} void SGD::Step() { for (auto param : params_) { @@ -40,7 +56,7 @@ OptimizerCreator SGD::Create(float learning_rate) { } Adam::Adam(const std::vector> ¶ms, float learning_rate, float beta1, float beta2, float eps) - : Optimizer(params), t_(0), learning_rate_(learning_rate), beta1_(beta1), beta2_(beta2), eps_(eps) { + : Optimizer(params, learning_rate), t_(0), beta1_(beta1), beta2_(beta2), eps_(eps) { for (const auto ¶m : params_) { m_.emplace_back(std::make_shared(param->Dims(), param->Dtype(), param->GetDevice())); diff --git a/scripts/test_config.json b/scripts/test_config.json index eca3d26f..fc9f6d2c 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -378,6 +378,182 @@ } ] }, + { + "tag": "lr_scheduler", + "tests": [ + { + "id": "3_none_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "lr_decay_style": "none" + } + }, + { + "id": "4_constant_tp4", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "constant", + "lr_warmup_iters": 0, + "lr_decay_iters": 0 + } + }, + { + "id": "5_linear_tp4_sp_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "linear", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "6_cosine_pp8", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 8, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "cosine", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "7_inverse_sqrt_pp4_vpp2", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 4, + "virtual_pipeline_parallel": 2, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "inverse-square-root", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "8_cosine_all_parallel_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "cosine", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "3_bfloat16_linear", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "linear", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 0 + } + }, + { + "id": "4_bfloat16_inverse_sqrt_tp4_distopt", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "inverse-square-root", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "5_bfloat16_constant_tp4_sp", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "constant", + "lr_warmup_iters": 0, + "lr_decay_iters": 10 + } + }, + { + "id": "8_bfloat16_none_all_parallel", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "learning_rate": 0.00001, + "lr_decay_style": "none" + } + } + ] + }, { "tag": "lora", "tests": [ diff --git a/tests/autograd/test_autograd.cc b/tests/autograd/test_autograd.cc index 1d6d129a..b766b6c0 100644 --- a/tests/autograd/test_autograd.cc +++ b/tests/autograd/test_autograd.cc @@ -8,8 +8,8 @@ #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/linear.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/normalization.h" #include "infini_train/include/autograd/no_op.h" +#include "infini_train/include/autograd/normalization.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" diff --git a/tests/checkpoint/test_checkpoint_serialization.cc b/tests/checkpoint/test_checkpoint_serialization.cc index b6834c95..dd7250df 100644 --- a/tests/checkpoint/test_checkpoint_serialization.cc +++ b/tests/checkpoint/test_checkpoint_serialization.cc @@ -30,7 +30,8 @@ TEST_P(CheckpointSerializationTest, SaveAndLoadModelFP32) { auto opt1 = std::make_shared(model1->Parameters(), 0.01); TrainerState saved{.global_step = 42, .consumed_batches = 100}; - Checkpoint::Save(dir, *model1, opt1.get(), saved, true); + Checkpoint::Save(dir, *model1, opt1.get(), saved, /*save_optimizer_state=*/true, + /*lr_scheduler=*/nullptr, /*save_lr_scheduler_state=*/false); auto model2 = std::make_shared(3, 2, true, GetDevice()); auto q1 = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32, GetDevice()); @@ -42,7 +43,8 @@ TEST_P(CheckpointSerializationTest, SaveAndLoadModelFP32) { auto opt2 = std::make_shared(model2->Parameters(), 0.01); TrainerState loaded; - Checkpoint::Load(dir, *model2, opt2.get(), loaded, true); + Checkpoint::Load(dir, *model2, opt2.get(), loaded, /*load_optimizer_state=*/true, + /*lr_scheduler=*/nullptr, /*load_lr_scheduler_state=*/false); EXPECT_EQ(loaded.global_step, 42); EXPECT_EQ(loaded.consumed_batches, 100); diff --git a/tests/checkpoint/test_lr_scheduler_state.cc b/tests/checkpoint/test_lr_scheduler_state.cc new file mode 100644 index 00000000..fad9bd3a --- /dev/null +++ b/tests/checkpoint/test_lr_scheduler_state.cc @@ -0,0 +1,122 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "infini_train/include/checkpoint/checkpoint.h" +#include "infini_train/include/lr_scheduler.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +#include "tests/common/test_utils.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +namespace { +constexpr float kBaseLR = 0.1f; +constexpr float kEps = 1e-6f; + +TrainingLRSchedulerConfig MakeSchedulerConfig() { + return { + .lr_decay_style = "linear", + .lr = kBaseLR, + .min_lr = 0.01f, + .lr_decay_iters = 8, + .lr_warmup_iters = 2, + .lr_warmup_init = 0.0f, + }; +} + +std::shared_ptr MakeModel(Device device) { + auto model = std::make_shared(1, 2, true, device); + auto weight = std::make_shared(std::vector{2}, DataType::kFLOAT32, device); + weight->Fill(0.5f); + *model->mutable_parameter("weight") = weight; + return model; +} + +void StepTimes(const std::shared_ptr &scheduler, int times) { + for (int i = 0; i < times; ++i) { scheduler->Step(); } +} +} // namespace + +class LRSchedulerCheckpointTest : public test::InfiniTrainTest {}; + +TEST_P(LRSchedulerCheckpointTest, SaveAndLoadLRSchedulerState) { + auto dir = std::filesystem::temp_directory_path() / "test_lr_scheduler_ckpt"; + std::filesystem::remove_all(dir); + + auto model_ref = MakeModel(GetDevice()); + auto opt_ref = std::make_shared(model_ref->Parameters(), kBaseLR); + auto sched_ref = CreateLRScheduler(opt_ref, MakeSchedulerConfig()); + StepTimes(sched_ref, 6); + + auto model1 = MakeModel(GetDevice()); + auto opt1 = std::make_shared(model1->Parameters(), kBaseLR); + auto sched1 = CreateLRScheduler(opt1, MakeSchedulerConfig()); + StepTimes(sched1, 3); + + TrainerState saved{.global_step = 3, .consumed_batches = 12, .last_lr = sched1->GetLR()}; + Checkpoint::Save(dir, *model1, opt1.get(), saved, /*save_optimizer_state=*/false, sched1.get(), + /*save_lr_scheduler_state=*/true); + EXPECT_TRUE(std::filesystem::exists(dir / "lr_scheduler.ckpt")); + + auto model2 = MakeModel(GetDevice()); + auto opt2 = std::make_shared(model2->Parameters(), kBaseLR); + auto sched2 = CreateLRScheduler(opt2, MakeSchedulerConfig()); + + TrainerState loaded; + Checkpoint::Load(dir, *model2, opt2.get(), loaded, /*load_optimizer_state=*/false, sched2.get(), + /*load_lr_scheduler_state=*/true); + + EXPECT_EQ(loaded.global_step, 3); + EXPECT_EQ(loaded.consumed_batches, 12); + EXPECT_EQ(sched2->LastStep(), sched1->LastStep()); + EXPECT_NEAR(sched2->GetLR(), sched1->GetLR(), kEps); + + StepTimes(sched2, 3); + EXPECT_EQ(sched2->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched2->GetLR(), sched_ref->GetLR(), kEps); + EXPECT_NEAR(opt2->learning_rate(), opt_ref->learning_rate(), kEps); + + std::filesystem::remove_all(dir); +} + +TEST_P(LRSchedulerCheckpointTest, HonorsLRSchedulerStateFlags) { + auto dir = std::filesystem::temp_directory_path() / "test_lr_scheduler_ckpt_flags"; + std::filesystem::remove_all(dir); + + auto model1 = MakeModel(GetDevice()); + auto opt1 = std::make_shared(model1->Parameters(), kBaseLR); + auto sched1 = CreateLRScheduler(opt1, MakeSchedulerConfig()); + StepTimes(sched1, 3); + + TrainerState saved{.global_step = 3, .last_lr = sched1->GetLR()}; + Checkpoint::Save(dir, *model1, opt1.get(), saved, /*save_optimizer_state=*/false, sched1.get(), + /*save_lr_scheduler_state=*/false); + EXPECT_FALSE(std::filesystem::exists(dir / "lr_scheduler.ckpt")); + + Checkpoint::Save(dir, *model1, opt1.get(), saved, /*save_optimizer_state=*/false, sched1.get(), + /*save_lr_scheduler_state=*/true); + ASSERT_TRUE(std::filesystem::exists(dir / "lr_scheduler.ckpt")); + + auto model2 = MakeModel(GetDevice()); + auto opt2 = std::make_shared(model2->Parameters(), kBaseLR); + auto sched2 = CreateLRScheduler(opt2, MakeSchedulerConfig()); + const auto initial_step = sched2->LastStep(); + const auto initial_lr = sched2->GetLR(); + + TrainerState loaded; + Checkpoint::Load(dir, *model2, opt2.get(), loaded, /*load_optimizer_state=*/false, sched2.get(), + /*load_lr_scheduler_state=*/false); + + EXPECT_EQ(sched2->LastStep(), initial_step); + EXPECT_NEAR(sched2->GetLR(), initial_lr, kEps); + + std::filesystem::remove_all(dir); +} + +INFINI_TRAIN_REGISTER_TEST(LRSchedulerCheckpointTest); diff --git a/tests/checkpoint/test_trainer_state.cc b/tests/checkpoint/test_trainer_state.cc index 532f2eff..76fa04a4 100644 --- a/tests/checkpoint/test_trainer_state.cc +++ b/tests/checkpoint/test_trainer_state.cc @@ -45,7 +45,8 @@ TEST_P(TrainerStateTest, TrainerStateFileCreated) { *model->mutable_parameter("weight") = p; auto opt = std::make_shared(model->Parameters(), 0.01); - Checkpoint::Save(dir, *model, opt.get(), saved, true); + Checkpoint::Save(dir, *model, opt.get(), saved, /*save_optimizer_state=*/true, + /*lr_scheduler=*/nullptr, /*save_lr_scheduler_state=*/false); EXPECT_TRUE(std::filesystem::exists(dir / "trainer_state.json")); @@ -82,7 +83,8 @@ TEST_P(TrainerStateTest, RoundTrip) { *model1->mutable_parameter("weight") = p1; auto opt1 = std::make_shared(model1->Parameters(), 0.01); - Checkpoint::Save(dir, *model1, opt1.get(), saved, false); + Checkpoint::Save(dir, *model1, opt1.get(), saved, /*save_optimizer_state=*/false, + /*lr_scheduler=*/nullptr, /*save_lr_scheduler_state=*/false); auto model2 = std::make_shared(1, 3, true, GetDevice()); auto p2 = std::make_shared(std::vector{3}, DataType::kFLOAT32, GetDevice()); @@ -91,7 +93,8 @@ TEST_P(TrainerStateTest, RoundTrip) { auto opt2 = std::make_shared(model2->Parameters(), 0.01); TrainerState loaded; - Checkpoint::Load(dir, *model2, opt2.get(), loaded, false); + Checkpoint::Load(dir, *model2, opt2.get(), loaded, /*load_optimizer_state=*/false, + /*lr_scheduler=*/nullptr, /*load_lr_scheduler_state=*/false); EXPECT_EQ(loaded.global_step, 99); EXPECT_EQ(loaded.consumed_batches, 5000); diff --git a/tests/optimizer/test_lr_scheduler.cc b/tests/optimizer/test_lr_scheduler.cc new file mode 100644 index 00000000..6d346bce --- /dev/null +++ b/tests/optimizer/test_lr_scheduler.cc @@ -0,0 +1,336 @@ +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "infini_train/include/lr_scheduler.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +#include "tests/common/test_utils.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { + +constexpr float kBaseLR = 0.1f; +constexpr float kEps = 1e-6f; + +class LRSchedulerTest : public infini_train::test::InfiniTrainTest {}; + +class LinearDecayScheduler : public LRScheduler { +public: + LinearDecayScheduler(std::shared_ptr optimizer, int64_t total_steps, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step), total_steps_(total_steps) {} + +protected: + float GetClosedFormLR() const override { + if (last_step_ >= total_steps_) { + return 0.0f; + } + return base_lr_ * (1.0f - static_cast(last_step_) / static_cast(total_steps_)); + } + +private: + int64_t total_steps_; +}; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +void ExpectStepSequence(const std::shared_ptr &scheduler, std::initializer_list expected, + float eps = kEps) { + for (float expected_lr : expected) { + scheduler->Step(); + EXPECT_NEAR(scheduler->GetLR(), expected_lr, eps); + } +} + +std::shared_ptr MakeSequentialScheduler(std::shared_ptr opt) { + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + return LRScheduler::Create(opt, + /*schedulers=*/std::vector>{linear, step_lr}, + /*milestones=*/std::vector{3}); +} + +std::shared_ptr MakeChainedScheduler(std::shared_ptr opt) { + auto step_lr = LRScheduler::Create(opt, /*step_size=*/2, /*gamma=*/0.5f); + auto lambda_lr = LRScheduler::Create(opt, /*lr_lambda=*/[](int64_t step) { return 1.0f - 0.05f * step; }); + return LRScheduler::Create( + opt, /*schedulers=*/std::vector>{step_lr, lambda_lr}); +} + +} // namespace + +TEST_P(LRSchedulerTest, BaseSchedulerStateRoundTripAndResume) { + constexpr int64_t kTotalSteps = 20; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, /*total_steps=*/kTotalSteps); + ExpectStepSequence(sched_ref, {0.095f, 0.09f, 0.085f, 0.08f, 0.075f, 0.07f, 0.065f}); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, /*total_steps=*/kTotalSteps); + ExpectStepSequence(sched_a, {0.095f, 0.09f, 0.085f}); + + StateDict state = sched_a->State(); + EXPECT_EQ(state.count("last_step"), 1); + EXPECT_EQ(state.count("recover_lr"), 1); + EXPECT_EQ(state.count("base_lr"), 1); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, /*total_steps=*/kTotalSteps); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.08f, 0.075f, 0.07f, 0.065f}); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + EXPECT_NEAR(opt_b->learning_rate(), sched_ref->GetLR(), kEps); +} + +TEST_P(LRSchedulerTest, ConstantLRMatchesExpectedSchedule) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + + EXPECT_EQ(sched->LastStep(), 0); + EXPECT_NEAR(sched->GetLR(), 0.05f, kEps); + EXPECT_NEAR(opt->learning_rate(), 0.05f, kEps); + + ExpectStepSequence(sched, {0.05f, 0.05f, 0.1f, 0.1f, 0.1f}); +} + +TEST_P(LRSchedulerTest, LinearLRMatchesExpectedScheduleAndClosedForm) { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + + EXPECT_NEAR(chainable->GetLR(), 0.02f, kEps); + ExpectStepSequence(chainable, {0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}, 1e-7f); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + auto opt_c = MakeDummyOptimizer(kBaseLR); + auto chainable_again = LRScheduler::Create(opt_c, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + + for (int epoch = 1; epoch <= 10; ++epoch) { + chainable_again->Step(); + closed_form->Step(epoch); + EXPECT_NEAR(chainable_again->GetLR(), closed_form->GetLR(), kEps); + } +} + +TEST_P(LRSchedulerTest, StepLRMatchesExpectedScheduleAndClosedForm) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + + EXPECT_NEAR(sched->GetLR(), kBaseLR, kEps); + ExpectStepSequence(sched, {0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}, 1e-7f); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, /*step_size=*/3, /*gamma=*/0.1f); + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, /*step_size=*/3, /*gamma=*/0.1f); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + EXPECT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + +TEST_P(LRSchedulerTest, LambdaLRMatchesExpectedScheduleAndRestoresState) { + auto lambda_fn = [](int64_t step) { return static_cast(std::pow(0.95, step)); }; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, /*lr_lambda=*/lambda_fn); + ExpectStepSequence(sched_ref, {0.095f, 0.09025f, 0.0857375f, 0.08145062f}, 1e-6f); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, /*lr_lambda=*/lambda_fn); + ExpectStepSequence(sched_a, {0.095f, 0.09025f}, 1e-6f); + StateDict state = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, /*lr_lambda=*/lambda_fn); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.0857375f, 0.08145062f}, 1e-6f); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), 1e-6f); +} + +TEST_P(LRSchedulerTest, SequentialLRSwitchesAtMilestonesAndRestoresState) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = MakeSequentialScheduler(opt); + + EXPECT_NEAR(sched->GetLR(), 0.0f, kEps); + ExpectStepSequence(sched, {0.1f / 3.0f, 0.2f / 3.0f, 0.1f, 0.1f, 0.1f, 0.05f}, 1e-5f); + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = MakeSequentialScheduler(opt_ref); + ExpectStepSequence(sched_ref, {0.1f / 3.0f, 0.2f / 3.0f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.05f}, 1e-5f); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = MakeSequentialScheduler(opt_a); + ExpectStepSequence(sched_a, {0.1f / 3.0f, 0.2f / 3.0f, 0.1f, 0.1f}, 1e-5f); + StateDict state = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = MakeSequentialScheduler(opt_b); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.1f, 0.05f, 0.05f, 0.05f}, 1e-5f); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); +} + +TEST_P(LRSchedulerTest, ChainedSchedulerComposesChildrenAndRestoresState) { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = MakeChainedScheduler(opt); + + EXPECT_NEAR(sched->GetLR(), 0.1f, kEps); + ExpectStepSequence(sched, {0.095f, 0.09f, 0.085f, 0.08f}, kEps); + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = MakeChainedScheduler(opt_ref); + ExpectStepSequence(sched_ref, {0.095f, 0.09f, 0.085f, 0.08f, 0.075f, 0.07f}, kEps); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = MakeChainedScheduler(opt_a); + ExpectStepSequence(sched_a, {0.095f, 0.09f, 0.085f}, kEps); + StateDict state = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = MakeChainedScheduler(opt_b); + sched_b->LoadState(state); + ExpectStepSequence(sched_b, {0.08f, 0.075f, 0.07f}, kEps); + + EXPECT_EQ(sched_b->LastStep(), sched_ref->LastStep()); + EXPECT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); +} + +TEST_P(LRSchedulerTest, TrainingSchedulerFactoryBuildsCommonDecayStyles) { + { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "constant", + .lr = 0.1f, + .min_lr = 0.0f, + .lr_decay_iters = 10, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + EXPECT_NEAR(sched->GetLR(), 0.1f, kEps); + ExpectStepSequence(sched, {0.1f, 0.1f, 0.1f}, kEps); + } + + { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "linear", + .lr = 1.0f, + .min_lr = 0.1f, + .lr_decay_iters = 6, + .lr_warmup_iters = 2, + .lr_warmup_init = 0.0f, + }); + EXPECT_NEAR(sched->GetLR(), 0.0f, kEps); + ExpectStepSequence(sched, {0.5f, 1.0f, 0.775f, 0.55f, 0.325f, 0.1f}, kEps); + } + + { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "cosine", + .lr = 1.0f, + .min_lr = 0.0f, + .lr_decay_iters = 4, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + ExpectStepSequence(sched, {0.853553f, 0.5f, 0.146447f, 0.0f}, 1e-5f); + } + + { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "inverse-square-root", + .lr = 1.0f, + .min_lr = 0.1f, + .lr_decay_iters = 10, + .lr_warmup_iters = 2, + .lr_warmup_init = 0.0f, + }); + ExpectStepSequence(sched, {0.5f, 1.0f, 0.8164966f, 0.7071068f, 0.6324555f}, 1e-5f); + } +} + +TEST_P(LRSchedulerTest, TrainingSchedulerFactoryReturnsNullForNoneStyle) { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "none", + .lr = 0.1f, + .min_lr = 0.0f, + .lr_decay_iters = 10, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + + EXPECT_EQ(sched, nullptr); +} + +TEST_P(LRSchedulerTest, RejectsInvalidSchedulerConfigurations) { + EXPECT_DEATH( + { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create(opt, /*step_size=*/0, /*gamma=*/0.1f); + (void)sched; + }, + ""); + + EXPECT_DEATH( + { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create(opt, /*lr_lambda=*/LambdaLR::LambdaFunc{}); + (void)sched; + }, + ""); + + EXPECT_DEATH( + { + auto opt1 = MakeDummyOptimizer(0.1f); + auto opt2 = MakeDummyOptimizer(0.1f); + auto linear = LRScheduler::Create(opt1, /*start_factor=*/0.5f, /*end_factor=*/1.0f, + /*total_iters=*/2); + auto step_lr = LRScheduler::Create(opt2, /*step_size=*/2, /*gamma=*/0.5f); + auto sched = LRScheduler::Create( + opt1, /*schedulers=*/std::vector>{linear, step_lr}, + /*milestones=*/std::vector{1}); + (void)sched; + }, + ""); + + EXPECT_DEATH( + { + auto opt = MakeDummyOptimizer(0.1f); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/2, /*gamma=*/0.5f); + std::shared_ptr sched = LRScheduler::Create( + opt, + /*schedulers=*/std::vector>{step_lr}); + sched->Step(0); + }, + ""); +} + +INFINI_TRAIN_REGISTER_TEST(LRSchedulerTest);