Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -60,8 +61,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
DEFINE_double(learning_rate, 1e-4, "Peak learning rate.");
DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer");
// lr scheduler
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -109,6 +116,8 @@ constexpr char kDeviceCPU[] = "cpu";
constexpr char kDeviceCUDA[] = "cuda";
constexpr char kDtypeFP32[] = "float32";
constexpr char kDtypeBF16[] = "bfloat16";
const std::unordered_set<std::string> kSupportedLRDecayStyles
= {"none", "constant", "linear", "cosine", "inverse-square-root"};

//
const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
Expand All @@ -124,6 +133,8 @@ DEFINE_validator(model, [](const char *, const std::string &value) { return kSup
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; });
DEFINE_validator(lr_decay_style,
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -314,6 +325,16 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(params_to_optimize);
}

const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
TrainingLRSchedulerConfig sched_config;
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
sched_config.lr_decay_style = FLAGS_lr_decay_style;
sched_config.lr_decay_iters = lr_decay_iters;
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
Expand All @@ -332,7 +353,8 @@ void Train(const nn::parallel::Rank &rank) {
.optimizer = optimizer,
.model_config = model_config,
.state = state,
.load_optimizer_state = false});
.load_optimizer_state = false,
.lr_scheduler = scheduler});
start_step = resume_result.global_step;
size_t consumed_batches = resume_result.consumed_batches;

Expand All @@ -351,7 +373,7 @@ void Train(const nn::parallel::Rank &rank) {
.save_dir = save_dir,
.global_step = global_step,
.consumed_batches = consumed_batches,
.last_lr = FLAGS_learning_rate,
.last_lr = optimizer->learning_rate(),
.n_layer = model_config.n_layer,
.n_head = model_config.n_head,
.n_kv_head = model_config.n_kv_head,
Expand All @@ -367,6 +389,7 @@ void Train(const nn::parallel::Rank &rank) {
.rank = rank,
.model = *model,
.optimizer = *optimizer,
.lr_scheduler = scheduler.get(),
});
};

Expand Down Expand Up @@ -403,6 +426,7 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
float lossf = 0.0f;
// model->Train();
if (pp_world_size == 1) {
Expand Down Expand Up @@ -448,6 +472,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -458,6 +485,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -473,11 +503,10 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
41 changes: 35 additions & 6 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -59,8 +60,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_double(learning_rate, 1e-5, "Peak learning rate.");
DEFINE_int32(zero_stage, 0, "ZeRO stage (0/1/2/3); 0 disables DistributedOptimizer");
// lr scheduler
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -105,12 +112,16 @@ constexpr char kDeviceCPU[] = "cpu";
constexpr char kDeviceCUDA[] = "cuda";
constexpr char kDtypeFP32[] = "float32";
constexpr char kDtypeBF16[] = "bfloat16";
const std::unordered_set<std::string> kSupportedLRDecayStyles
= {"none", "constant", "linear", "cosine", "inverse-square-root"};
} // namespace

DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 0 && value <= 3; });
DEFINE_validator(lr_decay_style,
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -294,6 +305,16 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(params_to_optimize);
}

const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
TrainingLRSchedulerConfig sched_config;
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
sched_config.lr_decay_style = FLAGS_lr_decay_style;
sched_config.lr_decay_iters = lr_decay_iters;
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
Expand All @@ -311,7 +332,8 @@ void Train(const nn::parallel::Rank &rank) {
.optimizer = optimizer,
.model_config = model_config,
.state = state,
.load_optimizer_state = true});
.load_optimizer_state = true,
.lr_scheduler = scheduler});

start_step = resume_result.global_step;
size_t consumed_batches = resume_result.consumed_batches;
Expand All @@ -331,7 +353,7 @@ void Train(const nn::parallel::Rank &rank) {
.save_dir = save_dir,
.global_step = global_step,
.consumed_batches = consumed_batches,
.last_lr = FLAGS_learning_rate,
.last_lr = optimizer->learning_rate(),
.n_layer = model_config.n_layer,
.n_head = model_config.n_head,
.n_kv_head = model_config.n_kv_head,
Expand All @@ -347,6 +369,7 @@ void Train(const nn::parallel::Rank &rank) {
.rank = rank,
.model = *model,
.optimizer = *optimizer,
.lr_scheduler = scheduler.get(),
});
};

Expand Down Expand Up @@ -381,6 +404,7 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
float lossf = 0.0f;
if (pp_world_size == 1) {
// model->Train();
Expand Down Expand Up @@ -425,6 +449,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -435,6 +462,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -450,11 +480,10 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
7 changes: 5 additions & 2 deletions infini_train/include/checkpoint/checkpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

namespace infini_train {
class Optimizer;
class LRScheduler;
class Tensor;
namespace nn {
class Module;
Expand All @@ -34,10 +35,12 @@ struct TrainerState {
class Checkpoint {
public:
static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer,
const TrainerState &state, bool save_optimizer_state);
const TrainerState &state, bool save_optimizer_state, const LRScheduler *lr_scheduler,
bool save_lr_scheduler_state);

static void Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer,
TrainerState &state, bool load_optimizer_state);
TrainerState &state, bool load_optimizer_state, LRScheduler *lr_scheduler,
bool load_lr_scheduler_state);

private:
static void SaveStateDict(const std::filesystem::path &path,
Expand Down
9 changes: 9 additions & 0 deletions infini_train/include/checkpoint/checkpoint_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <cstdint>
#include <cstring>
#include <filesystem>
#include <memory>

#include "infini_train/include/checkpoint/checkpoint.h"
#include "infini_train/include/dataloader.h"
Expand All @@ -13,6 +14,10 @@
using namespace infini_train;
namespace nn = infini_train::nn;

namespace infini_train {
class LRScheduler;
}

namespace infini_train::nn {
class TransformerConfig;
}
Expand All @@ -25,6 +30,8 @@ struct ResumeFromCheckpointArgs {
const nn::TransformerConfig &model_config;
TrainerState &state;
bool load_optimizer_state;
std::shared_ptr<LRScheduler> lr_scheduler = nullptr;
bool load_lr_scheduler_state = true;
};

struct ResumeFromCheckpointResult {
Expand Down Expand Up @@ -52,6 +59,8 @@ struct SaveCheckpointArgs {
const nn::parallel::Rank &rank;
const nn::Module &model;
const Optimizer &optimizer;
const LRScheduler *lr_scheduler = nullptr;
bool save_lr_scheduler_state = true;
};

ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);
Expand Down
Loading
Loading