diff --git a/csrc/layers/mlp/moe_mlp.hpp b/csrc/layers/mlp/moe_mlp.hpp index 970ea3c5..45bbe745 100644 --- a/csrc/layers/mlp/moe_mlp.hpp +++ b/csrc/layers/mlp/moe_mlp.hpp @@ -15,6 +15,9 @@ class MoeMLP : public infinicore::nn::Module { size_t hidden_size() const { return hidden_size_; } size_t moe_intermediate_size() const { return moe_intermediate_size_; } + infinicore::Tensor gate_weight() const { return gate_proj_->weight(); } + infinicore::Tensor up_weight() const { return up_proj_->weight(); } + infinicore::Tensor down_weight() const { return down_proj_->weight(); } void set_alpha(float alpha) { down_proj_->set_alpha(alpha); } protected: diff --git a/csrc/models/deepseek_v2/deepseek_v2_attention.cpp b/csrc/models/deepseek_v2/deepseek_v2_attention.cpp new file mode 100644 index 00000000..274b9a38 --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_attention.cpp @@ -0,0 +1,183 @@ +#include "deepseek_v2_attention.hpp" + +#include "../../global_state/global_state.hpp" +#include "../../utils.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/broadcast_to.hpp" +#include "infinicore/ops/cat.hpp" +#include "infinicore/ops/pad.hpp" + +#include +#include + +namespace infinilm::models::deepseek_v2 { +namespace { + +float yarn_get_mscale(float scale, float mscale) { + if (scale <= 1.0f) { + return 1.0f; + } + return 0.1f * mscale * std::log(scale) + 1.0f; +} + +} // namespace + +DeepseekV2Attention::DeepseekV2Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + hidden_size_ = model_config->get("hidden_size"); + qk_nope_head_dim_ = model_config->get("qk_nope_head_dim"); + qk_rope_head_dim_ = model_config->get("qk_rope_head_dim"); + q_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; + v_head_dim_ = model_config->get("v_head_dim"); + + const auto &dtype{model_config->get_dtype()}; + const size_t total_num_heads = model_config->get("num_attention_heads"); + const size_t kv_lora_rank = model_config->get("kv_lora_rank"); + const bool attention_bias = model_config->get_or("attention_bias", false); + const double rms_norm_eps = model_config->get("rms_norm_eps"); + + const auto &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + const int tp_rank = rank_info.tp_rank; + const int tp_size = rank_info.tp_size; + if ((total_num_heads < static_cast(tp_size)) || (total_num_heads % static_cast(tp_size) != 0)) { + throw std::runtime_error("DeepseekV2Attention: num_attention_heads must be divisible by tp_size"); + } + num_attention_heads_ = total_num_heads / static_cast(tp_size); + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + + auto quantization_method = model_config->get_quantization_method(); + INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, total_num_heads * q_head_dim_, quantization_method, false, dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(kv_a_proj_with_mqa, hidden_size_, kv_lora_rank + qk_rope_head_dim_, attention_bias, dtype, device); + INFINICORE_NN_MODULE_INIT(kv_a_layernorm, kv_lora_rank, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(kv_b_proj, kv_lora_rank, total_num_heads * (qk_nope_head_dim_ + v_head_dim_), quantization_method, false, dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(o_proj, total_num_heads * v_head_dim_, hidden_size_, quantization_method, attention_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + + const size_t max_position_embeddings = model_config->get("max_position_embeddings"); + const double rope_theta = model_config->get("rope_theta"); + rotary_emb_ = std::make_shared( + qk_rope_head_dim_, qk_rope_head_dim_, max_position_embeddings, rope_theta, + infinicore::nn::RoPE::Algo::GPT_J, dtype, device, nullptr); + + softmax_scale_ = 1.0f / std::sqrt(static_cast(q_head_dim_)); + auto &config_json = model_config->get_config_json(); + if (config_json.contains("rope_scaling") && config_json["rope_scaling"].is_object()) { + const auto &rope_scaling = config_json["rope_scaling"]; + const float mscale_all_dim = rope_scaling.value("mscale_all_dim", 0.0f); + if (mscale_all_dim != 0.0f) { + const float scaling_factor = rope_scaling.value("factor", 1.0f); + const float mscale = yarn_get_mscale(scaling_factor, mscale_all_dim); + softmax_scale_ *= mscale * mscale; + } + } + + attn_ = std::make_shared( + num_attention_heads_, q_head_dim_, softmax_scale_, num_attention_heads_, layer_idx_, + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + infinilm::layers::attention::init_kv_cache_quant_params( + [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }, + device, kv_cache_k_scale_, kv_cache_v_scale_); +} + +infinicore::Tensor DeepseekV2Attention::position_ids_for_rope_(const infinicore::Tensor &position_ids) const { + auto pos_shape = position_ids->shape(); + if (pos_shape.size() == 2) { + return position_ids->narrow({{0, 0, 1}})->contiguous()->view({pos_shape[1]}); + } + if (pos_shape.size() == 1) { + return position_ids->contiguous(); + } + throw std::runtime_error("DeepseekV2Attention: unexpected position_ids shape"); +} + +infinicore::Tensor DeepseekV2Attention::trim_value_padding_(const infinicore::Tensor &attn_output) const { + const auto shape = attn_output->shape(); + const size_t batch_size = shape[0]; + const size_t seq_len = shape[1]; + return attn_output->view({batch_size, seq_len, num_attention_heads_, q_head_dim_}) + ->narrow({{3, 0, v_head_dim_}}) + ->contiguous() + ->view({batch_size, seq_len, num_attention_heads_ * v_head_dim_}); +} + +infinicore::Tensor DeepseekV2Attention::forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const { + if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) { + return forward_static_(positions, hidden_states); + } + return forward_paged_(positions, hidden_states); +} + +infinicore::Tensor DeepseekV2Attention::forward_static_(const infinicore::Tensor &position_ids, + const infinicore::Tensor &hidden_states) const { + auto shape = hidden_states->shape(); + const size_t batch_size = shape[0]; + const size_t seq_len = shape[1]; + auto hidden_states_mutable = hidden_states; + + auto q = q_proj_->forward(hidden_states_mutable)->view({batch_size, seq_len, num_attention_heads_, q_head_dim_}); + auto q_nope = q->narrow({{3, 0, qk_nope_head_dim_}}); + auto q_pe = q->narrow({{3, qk_nope_head_dim_, qk_rope_head_dim_}})->contiguous(); + + auto compressed = kv_a_proj_with_mqa_->forward(hidden_states_mutable); + auto compressed_kv = compressed->narrow({{2, 0, kv_a_layernorm_->normalized_shape()}})->contiguous(); + auto k_pe = compressed->narrow({{2, kv_a_layernorm_->normalized_shape(), qk_rope_head_dim_}})->contiguous(); + + auto kv_norm = kv_a_layernorm_->forward(compressed_kv); + auto kv = kv_b_proj_->forward(kv_norm)->view({batch_size, seq_len, num_attention_heads_, qk_nope_head_dim_ + v_head_dim_}); + auto k_nope = kv->narrow({{3, 0, qk_nope_head_dim_}}); + auto value_states = kv->narrow({{3, qk_nope_head_dim_, v_head_dim_}})->contiguous(); + + auto pos_ids = position_ids_for_rope_(position_ids); + q_pe = rotary_emb_->forward(q_pe, pos_ids, true); + auto k_pe_broadcast = infinicore::op::broadcast_to(k_pe->view({batch_size, seq_len, 1, qk_rope_head_dim_}), + {static_cast(batch_size), static_cast(seq_len), static_cast(num_attention_heads_), static_cast(qk_rope_head_dim_)}); + k_pe_broadcast = rotary_emb_->forward(k_pe_broadcast, pos_ids, true); + + auto query_states = infinicore::op::cat({q_nope, q_pe}, 3); + auto key_states = infinicore::op::cat({k_nope, k_pe_broadcast}, 3); + auto value_padded = infinicore::op::pad(value_states, {0, static_cast(q_head_dim_ - v_head_dim_)}, "constant", 0.0); + + auto attn_output = attn_->forward(query_states, key_states, value_padded); + auto trimmed_output = trim_value_padding_(attn_output); + return o_proj_->forward(trimmed_output); +} + +infinicore::Tensor DeepseekV2Attention::forward_paged_(const infinicore::Tensor &position_ids, + const infinicore::Tensor &hidden_states) const { + auto shape = hidden_states->shape(); + const size_t batch_size = shape[0]; + const size_t seq_len = shape[1]; + ASSERT_EQ(batch_size, 1); + auto hidden_states_mutable = hidden_states; + + auto q = q_proj_->forward(hidden_states_mutable)->view({seq_len, num_attention_heads_, q_head_dim_}); + auto q_nope = q->narrow({{2, 0, qk_nope_head_dim_}}); + auto q_pe = q->narrow({{2, qk_nope_head_dim_, qk_rope_head_dim_}})->contiguous(); + + auto compressed = kv_a_proj_with_mqa_->forward(hidden_states_mutable)->view({seq_len, kv_a_layernorm_->normalized_shape() + qk_rope_head_dim_}); + auto compressed_kv = compressed->narrow({{1, 0, kv_a_layernorm_->normalized_shape()}})->contiguous(); + auto k_pe = compressed->narrow({{1, kv_a_layernorm_->normalized_shape(), qk_rope_head_dim_}})->contiguous(); + + auto kv_norm = kv_a_layernorm_->forward(compressed_kv); + auto kv = kv_b_proj_->forward(kv_norm)->view({seq_len, num_attention_heads_, qk_nope_head_dim_ + v_head_dim_}); + auto k_nope = kv->narrow({{2, 0, qk_nope_head_dim_}}); + auto value_states = kv->narrow({{2, qk_nope_head_dim_, v_head_dim_}})->contiguous(); + + auto pos_ids = position_ids_for_rope_(position_ids); + q_pe = rotary_emb_->forward(q_pe, pos_ids, true); + auto k_pe_broadcast = infinicore::op::broadcast_to(k_pe->view({seq_len, 1, qk_rope_head_dim_}), + {static_cast(seq_len), static_cast(num_attention_heads_), static_cast(qk_rope_head_dim_)}); + k_pe_broadcast = rotary_emb_->forward(k_pe_broadcast, pos_ids, true); + + auto query_states = infinicore::op::cat({q_nope, q_pe}, 2); + auto key_states = infinicore::op::cat({k_nope, k_pe_broadcast}, 2); + auto value_padded = infinicore::op::pad(value_states, {0, static_cast(q_head_dim_ - v_head_dim_)}, "constant", 0.0); + + auto attn_output = attn_->forward(query_states, key_states, value_padded); + auto trimmed_output = trim_value_padding_(attn_output); + return o_proj_->forward(trimmed_output); +} + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/deepseek_v2/deepseek_v2_attention.hpp b/csrc/models/deepseek_v2/deepseek_v2_attention.hpp new file mode 100644 index 00000000..afedb071 --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_attention.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "../../config/model_config.hpp" +#include "../../layers/attention/attention.hpp" +#include "../../layers/linear/linear.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/tensor.hpp" + +#include + +namespace infinilm::models::deepseek_v2 { + +class DeepseekV2Attention : public infinicore::nn::Module { +public: + DeepseekV2Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + +private: + infinicore::Tensor forward_static_(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + infinicore::Tensor trim_value_padding_(const infinicore::Tensor &attn_output) const; + infinicore::Tensor position_ids_for_rope_(const infinicore::Tensor &position_ids) const; + + size_t layer_idx_{0}; + size_t hidden_size_{0}; + size_t num_attention_heads_{0}; + size_t qk_nope_head_dim_{0}; + size_t qk_rope_head_dim_{0}; + size_t q_head_dim_{0}; + size_t v_head_dim_{0}; + float softmax_scale_{1.0f}; + infinilm::backends::AttentionBackend attention_backend_; + + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, q_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, kv_a_proj_with_mqa); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, kv_a_layernorm); + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, kv_b_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + + std::shared_ptr rotary_emb_; + std::shared_ptr attn_; + infinicore::nn::Parameter kv_cache_k_scale_; + infinicore::nn::Parameter kv_cache_v_scale_; +}; + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp new file mode 100644 index 00000000..2843bc86 --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp @@ -0,0 +1,38 @@ +#include "deepseek_v2_decoder_layer.hpp" + +namespace infinilm::models::deepseek_v2 { + +DeepseekV2DecoderLayer::DeepseekV2DecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + const auto &dtype{model_config->get_dtype()}; + const size_t hidden_size = model_config->get("hidden_size"); + const double rms_norm_eps = model_config->get("rms_norm_eps"); + INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(self_attn, model_config, layer_idx, device); + + const size_t first_k_dense_replace = model_config->get_or("first_k_dense_replace", 0); + const size_t moe_layer_freq = model_config->get_or("moe_layer_freq", 1); + use_moe_ = model_config->get_or("n_routed_experts", 0) > 0 + && layer_idx >= first_k_dense_replace + && (moe_layer_freq == 0 || layer_idx % moe_layer_freq == 0); + if (use_moe_) { + moe_mlp_ = this->register_module("mlp", model_config, device); + } else { + dense_mlp_ = this->register_module("mlp", model_config, device); + } +} + +std::tuple +DeepseekV2DecoderLayer::forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states, + infinicore::Tensor &residual) const { + input_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = self_attn_->forward(positions, hidden_states); + post_attention_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = use_moe_ ? moe_mlp_->forward(hidden_states) : dense_mlp_->forward(hidden_states); + return {hidden_states, residual}; +} + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp new file mode 100644 index 00000000..e7a36d85 --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "../../config/model_config.hpp" +#include "deepseek_v2_attention.hpp" +#include "deepseek_v2_moe.hpp" +#include "infinicore/device.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/tensor.hpp" + +#include +#include + +namespace infinilm::models::deepseek_v2 { + +class DeepseekV2DecoderLayer : public infinicore::nn::Module { +public: + DeepseekV2DecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + std::tuple forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states, + infinicore::Tensor &residual) const; + +private: + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); + INFINICORE_NN_MODULE(DeepseekV2Attention, self_attn); + INFINICORE_NN_MODULE(DeepseekV2MLP, dense_mlp); + INFINICORE_NN_MODULE(DeepseekV2MoE, moe_mlp); + bool use_moe_{false}; +}; + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/deepseek_v2/deepseek_v2_for_causal_lm.cpp b/csrc/models/deepseek_v2/deepseek_v2_for_causal_lm.cpp new file mode 100644 index 00000000..59d52f45 --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_for_causal_lm.cpp @@ -0,0 +1,83 @@ +#include "deepseek_v2_for_causal_lm.hpp" + +#include "../models_registry.hpp" +#include "infinicore/ops.hpp" + +#include +#include + +namespace infinilm::models::deepseek_v2 { + +DeepseekV2Model::DeepseekV2Model(std::shared_ptr model_config, + const infinicore::Device &device) { + const auto &dtype{model_config->get_dtype()}; + const size_t vocab_size = model_config->get("vocab_size"); + const size_t hidden_size = model_config->get("hidden_size"); + const size_t num_hidden_layers = model_config->get("num_hidden_layers"); + const double rms_norm_eps = model_config->get("rms_norm_eps"); + + INFINICORE_NN_MODULE_INIT(embed_tokens, vocab_size, hidden_size, std::nullopt, dtype, device); + layers_.reserve(num_hidden_layers); + for (size_t i = 0; i < num_hidden_layers; ++i) { + layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device)); + } + INFINICORE_NN_MODULE_INIT(norm, hidden_size, rms_norm_eps, dtype, device); +} + +infinicore::Tensor DeepseekV2Model::forward(const infinilm::InfinilmModel::Input &input) const { + auto input_ids = input.input_ids.value(); + auto positions = input.position_ids.value(); + auto hidden_states = embed_tokens_->forward(input_ids); + + infinicore::Tensor residual; + for (const auto &layer : layers_) { + layer->forward(positions, hidden_states, residual); + } + norm_->forward_inplace(hidden_states, residual); + return hidden_states; +} + +DeepseekV2ForCausalLM::DeepseekV2ForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + const auto &dtype{model_config->get_dtype()}; + const size_t hidden_size = model_config->get("hidden_size"); + const size_t vocab_size = model_config->get("vocab_size"); + INFINICORE_NN_MODULE_INIT(model, model_config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); +} + +infinilm::InfinilmModel::Output DeepseekV2ForCausalLM::forward(const infinilm::InfinilmModel::Input &input) const { + auto hidden_states = model_->forward(input); + auto logits = lm_head_->forward(hidden_states); + return {logits}; +} + +std::shared_ptr create_deepseek_v2_model_config(std::shared_ptr model_config) { + const std::string model_type = model_config->get("model_type"); + if ("deepseek_v2" != model_type) { + throw std::runtime_error("create_deepseek_v2_model_config: model_type is not deepseek_v2"); + } + + auto &config_json = model_config->get_config_json(); + const size_t q_head_dim = config_json.at("qk_nope_head_dim").get() + config_json.at("qk_rope_head_dim").get(); + config_json["head_dim"] = q_head_dim; + config_json["num_experts"] = config_json.value("n_routed_experts", 0); + config_json["mlp_bias"] = false; + if (!config_json.contains("attention_output_bias")) { + config_json["attention_output_bias"] = config_json.value("attention_bias", false); + } + if (!config_json.contains("dtype") && config_json.contains("torch_dtype")) { + config_json["dtype"] = config_json["torch_dtype"]; + } + return model_config; +} + +} // namespace infinilm::models::deepseek_v2 + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + deepseek_v2, + infinilm::models::deepseek_v2::DeepseekV2ForCausalLM, + infinilm::models::deepseek_v2::create_deepseek_v2_model_config); +} // namespace diff --git a/csrc/models/deepseek_v2/deepseek_v2_for_causal_lm.hpp b/csrc/models/deepseek_v2/deepseek_v2_for_causal_lm.hpp new file mode 100644 index 00000000..b5e22f6a --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_for_causal_lm.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "../../layers/linear/linear.hpp" +#include "../infinilm_model.hpp" +#include "deepseek_v2_decoder_layer.hpp" +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/tensor.hpp" + +#include + +namespace infinilm::models::deepseek_v2 { + +class DeepseekV2Model : public infinicore::nn::Module { +public: + DeepseekV2Model(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const; + +private: + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); + INFINICORE_NN_MODULE_VEC(DeepseekV2DecoderLayer, layers); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); +}; + +class DeepseekV2ForCausalLM : public infinilm::InfinilmModel { +public: + DeepseekV2ForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + +private: + INFINICORE_NN_MODULE(DeepseekV2Model, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr create_deepseek_v2_model_config(std::shared_ptr model_config); + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/deepseek_v2/deepseek_v2_moe.cpp b/csrc/models/deepseek_v2/deepseek_v2_moe.cpp new file mode 100644 index 00000000..d486fa3b --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_moe.cpp @@ -0,0 +1,137 @@ +#include "deepseek_v2_moe.hpp" + +#include "../../global_state/global_state.hpp" +#include "../../utils.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/distributed/allreduce.hpp" + +#include + +namespace infinilm::models::deepseek_v2 { + +DeepseekV2TopKRouter::DeepseekV2TopKRouter(std::shared_ptr model_config, + const infinicore::Device &device) { + const auto &dtype{model_config->get_dtype()}; + const size_t hidden_size = model_config->get("hidden_size"); + num_experts_ = model_config->get("num_experts"); + num_experts_per_tok_ = model_config->get("num_experts_per_tok"); + norm_topk_prob_ = model_config->get("norm_topk_prob"); + + ASSERT((num_experts_ > 0) && (num_experts_per_tok_ > 0) && (num_experts_per_tok_ <= num_experts_)); + INFINICORE_NN_PARAMETER_INIT(weight, ({num_experts_, hidden_size}, dtype, device)); +} + +std::tuple +DeepseekV2TopKRouter::forward(const infinicore::Tensor &hidden_states) const { + ASSERT(hidden_states->ndim() == 2); + const size_t ntoken = hidden_states->shape()[0]; + auto router_logits = infinicore::op::linear(hidden_states, weight_, std::nullopt, 1.0f); + auto router_scores = infinicore::Tensor::empty({ntoken, num_experts_per_tok_}, infinicore::DataType::F32, hidden_states->device()); + auto router_indices = infinicore::Tensor::empty({ntoken, num_experts_per_tok_}, infinicore::DataType::I32, hidden_states->device()); + infinicore::op::topksoftmax(router_scores, router_indices, router_logits, num_experts_per_tok_, norm_topk_prob_); + return {router_scores, router_indices}; +} + +DeepseekV2Experts::DeepseekV2Experts(std::shared_ptr model_config, + const infinicore::Device &device) { + hidden_size_ = model_config->get("hidden_size"); + moe_intermediate_size_ = model_config->get("moe_intermediate_size"); + num_experts_ = model_config->get("num_experts"); + num_experts_per_tok_ = model_config->get("num_experts_per_tok"); + ASSERT((num_experts_ > 0) && (num_experts_per_tok_ > 0) && (num_experts_per_tok_ <= num_experts_)); + const auto &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + tp_size_ = static_cast(rank_info.tp_size); + communicator_ = rank_info.comm; + + experts_.reserve(num_experts_); + gate_weights_.reserve(num_experts_); + up_weights_.reserve(num_experts_); + down_weights_.reserve(num_experts_); + for (size_t i = 0; i < num_experts_; ++i) { + auto expert = this->register_module(std::to_string(i), model_config, device); + gate_weights_.push_back(expert->gate_weight()); + up_weights_.push_back(expert->up_weight()); + down_weights_.push_back(expert->down_weight()); + experts_.push_back(std::move(expert)); + } + local_moe_intermediate_size_ = gate_weights_.empty() ? moe_intermediate_size_ : gate_weights_.front()->shape()[0]; +} + +infinicore::Tensor DeepseekV2Experts::forward_cpu_routed_(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &top_k_index, + const infinicore::Tensor &top_k_weights) const { + auto top_k_weights_cpu = top_k_weights->to(infinicore::Device::Type::CPU); + auto top_k_index_cpu = top_k_index->to(infinicore::Device::Type::CPU); + auto *top_k_index_ptr = reinterpret_cast(top_k_index_cpu->data()); + auto *top_k_weights_ptr = reinterpret_cast(top_k_weights_cpu->data()); + + const size_t ntoken = hidden_states->shape()[0]; + auto final_hidden_states = infinicore::Tensor::empty(hidden_states->shape(), hidden_states->dtype(), hidden_states->device()); + for (size_t itok = 0; itok < ntoken; ++itok) { + auto hidden_states_i = hidden_states->narrow({{0, itok, 1}}); + const size_t route_row = itok * num_experts_per_tok_; + + infinicore::Tensor final_hidden_states_i; + for (size_t k = 0; k < num_experts_per_tok_; ++k) { + const int index = top_k_index_ptr[route_row + k]; + const float score = top_k_weights_ptr[route_row + k]; + ASSERT(index >= 0 && static_cast(index) < num_experts_); + experts_[index]->set_alpha(score); + auto expert_out = experts_[index]->forward(hidden_states_i); + if (k == 0) { + final_hidden_states_i = expert_out; + } else { + infinicore::op::add_(final_hidden_states_i, final_hidden_states_i, expert_out); + } + } + final_hidden_states->narrow({{0, itok, 1}})->copy_from(final_hidden_states_i); + } + return final_hidden_states; +} + +infinicore::Tensor DeepseekV2Experts::forward(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &top_k_index, + const infinicore::Tensor &top_k_weights) const { + ASSERT(hidden_states->ndim() == 2); + if (hidden_states->device().getType() == infinicore::Device::Type::NVIDIA && (hidden_states->dtype() == infinicore::DataType::BF16 || hidden_states->dtype() == infinicore::DataType::F16)) { + auto output = infinicore::op::deepseek_moe(hidden_states, top_k_index, top_k_weights, + gate_weights_, up_weights_, down_weights_, + local_moe_intermediate_size_, num_experts_); + if (tp_size_ > 1 && communicator_ != nullptr) { + infinicore::op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_); + } + return output; + } + return forward_cpu_routed_(hidden_states, top_k_index, top_k_weights); +} + +DeepseekV2MoE::DeepseekV2MoE(std::shared_ptr model_config, + const infinicore::Device &device) { + INFINICORE_NN_MODULE_INIT(gate, model_config, device); + INFINICORE_NN_MODULE_INIT(experts, model_config, device); + + const size_t n_shared_experts = model_config->get_or("n_shared_experts", 0); + has_shared_experts_ = n_shared_experts > 0; + if (has_shared_experts_) { + auto shared_config_json = model_config->get_config_json(); + shared_config_json["intermediate_size"] = model_config->get("moe_intermediate_size") * n_shared_experts; + auto shared_config = std::make_shared(shared_config_json); + INFINICORE_NN_MODULE_INIT(shared_experts, shared_config, device); + } +} + +infinicore::Tensor DeepseekV2MoE::forward(const infinicore::Tensor &hidden_states) const { + ASSERT(hidden_states->ndim() == 3); + const auto shape = hidden_states->shape(); + auto hidden_states_reshaped = hidden_states->view({shape[0] * shape[1], shape[2]}); + + auto [routing_weights, selected_experts] = gate_->forward(hidden_states_reshaped); + auto final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights)->view(shape); + if (has_shared_experts_) { + auto shared_out = shared_experts_->forward(hidden_states); + final_hidden_states = infinicore::op::add(final_hidden_states, shared_out); + } + return final_hidden_states; +} + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/deepseek_v2/deepseek_v2_moe.hpp b/csrc/models/deepseek_v2/deepseek_v2_moe.hpp new file mode 100644 index 00000000..9f788fc9 --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_moe.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include "../../config/model_config.hpp" +#include "../../layers/common_modules.hpp" +#include "../../layers/linear/linear.hpp" +#include "../../layers/mlp/mlp.hpp" +#include "infinicore/device.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" +#include + +#include +#include +#include + +namespace infinilm::models::deepseek_v2 { + +using DeepseekV2MLP = infinilm::layers::mlp::MLP; +using DeepseekV2ExpertMLP = infinilm::layers::MoeMLP; + +class DeepseekV2TopKRouter : public infinicore::nn::Module { +public: + DeepseekV2TopKRouter(std::shared_ptr model_config, + const infinicore::Device &device); + + std::tuple forward(const infinicore::Tensor &hidden_states) const; + +protected: + INFINICORE_NN_PARAMETER(weight); + size_t num_experts_per_tok_{0}; + size_t num_experts_{0}; + bool norm_topk_prob_{false}; +}; + +class DeepseekV2Experts : public infinicore::nn::Module { +public: + DeepseekV2Experts(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &top_k_index, + const infinicore::Tensor &top_k_weights) const; + +protected: + infinicore::Tensor forward_cpu_routed_(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &top_k_index, + const infinicore::Tensor &top_k_weights) const; + + INFINICORE_NN_MODULE_VEC(DeepseekV2ExpertMLP, experts); + std::vector gate_weights_; + std::vector up_weights_; + std::vector down_weights_; + size_t hidden_size_{0}; + size_t moe_intermediate_size_{0}; + size_t local_moe_intermediate_size_{0}; + size_t num_experts_per_tok_{0}; + size_t num_experts_{0}; + size_t tp_size_{1}; + infinicclComm_t communicator_{nullptr}; +}; + +class DeepseekV2MoE : public infinicore::nn::Module { +public: + DeepseekV2MoE(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +protected: + INFINICORE_NN_MODULE(DeepseekV2TopKRouter, gate); + INFINICORE_NN_MODULE(DeepseekV2Experts, experts); + INFINICORE_NN_MODULE(DeepseekV2MLP, shared_experts); + bool has_shared_experts_{false}; +}; + +} // namespace infinilm::models::deepseek_v2 diff --git a/python/infinilm/processors/sentencepiece_processor.py b/python/infinilm/processors/sentencepiece_processor.py index 1741041b..19eefbb3 100644 --- a/python/infinilm/processors/sentencepiece_processor.py +++ b/python/infinilm/processors/sentencepiece_processor.py @@ -73,3 +73,8 @@ class InternLMProcessor(SentencePieceProcessor): @register_processor("mistral") class MistralProcessor(SentencePieceProcessor): pass + + +@register_processor("deepseek_v2") +class DeepSeekV2Processor(BasicLLMProcessor): + pass