Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ InferEngine::InferEngine(
const cache::CacheConfig *cache_config,
bool enable_graph_compiling,
backends::AttentionBackend attention_backend,
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
std::optional<infinicore::DataType> kv_cache_dtype,
bool use_mla)
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend), use_mla_(use_mla) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}

// Load model config if model_path is provided, model_path must be valid, and config.json exists
this->model_config_ = infinilm::config::ConfigFactory::createConfig(config_str);
auto infinilm_config = std::make_shared<infinilm::global_state::InfinilmConfig>(attention_backend, this->model_config_);
auto infinilm_config = std::make_shared<infinilm::global_state::InfinilmConfig>(attention_backend, this->model_config_, use_mla);

// Only support offline int8 kv cache quantization in this version
if (kv_cache_dtype.has_value()) {
Expand Down
4 changes: 3 additions & 1 deletion csrc/engine/infer_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class InferEngine {
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt,
bool use_mla = false);

// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
Expand Down Expand Up @@ -68,6 +69,7 @@ class InferEngine {
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
bool weights_finalized_ = false;
bool use_mla_{false};
};

} // namespace infinilm::engine
5 changes: 4 additions & 1 deletion csrc/global_state/infinilm_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ struct InfinilmConfig {
public:
InfinilmConfig() = default;
InfinilmConfig(const infinilm::backends::AttentionBackend &backend,
const std::shared_ptr<infinilm::config::ModelConfig> &model_config)
const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
bool use_mla = false)
: attention_backend(backend),
use_mla(use_mla),
model_config(model_config) {}

public:
infinilm::backends::AttentionBackend attention_backend;
bool use_mla{false};
std::shared_ptr<infinilm::config::ModelConfig> model_config;
};

Expand Down
2 changes: 1 addition & 1 deletion csrc/layers/attention/backends/static_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ infinicore::Tensor StaticAttentionImpl::forward(const AttentionLayer &layer,

// Compute attention
size_t ngroup = num_heads_ / num_kv_heads_;
auto Q = q_reshaped->view({batch_size * num_kv_heads_, ngroup * seq_len, head_dim_});
auto Q = q_reshaped->contiguous()->view({batch_size * num_kv_heads_, ngroup * seq_len, head_dim_});
auto K = k_total->view({batch_size * num_kv_heads_, total_seq_len, head_dim_});
auto V = v_total->view({batch_size * num_kv_heads_, total_seq_len, head_dim_});

Expand Down
11 changes: 9 additions & 2 deletions csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "deepseek_v2_decoder_layer.hpp"

#include "../../global_state/global_state.hpp"

namespace infinilm::models::deepseek_v2 {

DeepseekV2DecoderLayer::DeepseekV2DecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
Expand All @@ -10,7 +12,11 @@ DeepseekV2DecoderLayer::DeepseekV2DecoderLayer(std::shared_ptr<infinilm::config:
const double rms_norm_eps = model_config->get<double>("rms_norm_eps");
INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device);
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device);
INFINICORE_NN_MODULE_INIT(self_attn, model_config, layer_idx, device);
if (infinilm::global_state::get_infinilm_config().use_mla) {
self_attn_ = std::make_shared<DeepseekV2SelfAttention>(this->register_module<DeepseekV2MLAAttention>("self_attn", model_config, layer_idx, device));
} else {
self_attn_ = std::make_shared<DeepseekV2SelfAttention>(this->register_module<DeepseekV2Attention>("self_attn", model_config, layer_idx, device));
}

const size_t first_k_dense_replace = model_config->get_or<size_t>("first_k_dense_replace", 0);
const size_t moe_layer_freq = model_config->get_or<size_t>("moe_layer_freq", 1);
Expand All @@ -29,7 +35,8 @@ DeepseekV2DecoderLayer::forward(const infinicore::Tensor &positions,
infinicore::Tensor &hidden_states,
infinicore::Tensor &residual) const {
input_layernorm_->forward_inplace(hidden_states, residual);
hidden_states = self_attn_->forward(positions, hidden_states);
hidden_states = std::visit(
[&](auto &attn_ptr) { return attn_ptr->forward(positions, hidden_states); }, *self_attn_);
post_attention_layernorm_->forward_inplace(hidden_states, residual);
hidden_states = use_moe_ ? moe_mlp_->forward(hidden_states) : dense_mlp_->forward(hidden_states);
return {hidden_states, residual};
Expand Down
6 changes: 5 additions & 1 deletion csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "../../config/model_config.hpp"
#include "deepseek_v2_attention.hpp"
#include "deepseek_v2_mla_attention.hpp"
#include "deepseek_v2_moe.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/module.hpp"
Expand All @@ -10,9 +11,12 @@

#include <memory>
#include <tuple>
#include <variant>

namespace infinilm::models::deepseek_v2 {

using DeepseekV2SelfAttention = std::variant<std::shared_ptr<DeepseekV2Attention>, std::shared_ptr<DeepseekV2MLAAttention>>;

class DeepseekV2DecoderLayer : public infinicore::nn::Module {
public:
DeepseekV2DecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
Expand All @@ -26,7 +30,7 @@ class DeepseekV2DecoderLayer : public infinicore::nn::Module {
private:
INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm);
INFINICORE_NN_MODULE(DeepseekV2Attention, self_attn);
INFINICORE_NN_MODULE(DeepseekV2SelfAttention, self_attn);
INFINICORE_NN_MODULE(DeepseekV2MLP, dense_mlp);
INFINICORE_NN_MODULE(DeepseekV2MoE, moe_mlp);
bool use_moe_{false};
Expand Down
200 changes: 200 additions & 0 deletions csrc/models/deepseek_v2/deepseek_v2_mla_attention.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
#include "deepseek_v2_mla_attention.hpp"

#include "../../global_state/global_state.hpp"
#include "../../utils.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/cat.hpp"
#include "infinicore/ops/pad.hpp"

#include <cmath>
#include <stdexcept>

namespace infinilm::models::deepseek_v2 {
namespace {

float yarn_get_mscale(float scale, float mscale) {
if (scale <= 1.0f) {
return 1.0f;
}
return 0.1f * mscale * std::log(scale) + 1.0f;
}

} // namespace

DeepseekV2MLAAttention::DeepseekV2MLAAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
size_t layer_idx,
const infinicore::Device &device) {
layer_idx_ = layer_idx;
hidden_size_ = model_config->get<size_t>("hidden_size");
qk_nope_head_dim_ = model_config->get<size_t>("qk_nope_head_dim");
qk_rope_head_dim_ = model_config->get<size_t>("qk_rope_head_dim");
q_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_;
v_head_dim_ = model_config->get<size_t>("v_head_dim");
kv_lora_rank_ = model_config->get<size_t>("kv_lora_rank");
mla_head_dim_ = kv_lora_rank_ + qk_rope_head_dim_;

if (model_config->get_or<size_t>("q_lora_rank", 0) != 0) {
throw std::runtime_error("DeepseekV2MLAAttention: q_lora_rank is not supported yet");
}

const auto &dtype{model_config->get_dtype()};
const size_t total_num_heads = model_config->get<size_t>("num_attention_heads");
const bool attention_bias = model_config->get_or<bool>("attention_bias", false);
const double rms_norm_eps = model_config->get<double>("rms_norm_eps");

const auto &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
const int tp_rank = rank_info.tp_rank;
const int tp_size = rank_info.tp_size;
if ((total_num_heads < static_cast<size_t>(tp_size)) || (total_num_heads % static_cast<size_t>(tp_size) != 0)) {
throw std::runtime_error("DeepseekV2MLAAttention: num_attention_heads must be divisible by tp_size");
}
num_attention_heads_ = total_num_heads / static_cast<size_t>(tp_size);
attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend;

auto quantization_method = model_config->get_quantization_method();
INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, total_num_heads * q_head_dim_, quantization_method, false, dtype, device, tp_rank, tp_size);
INFINICORE_NN_MODULE_INIT(kv_a_proj_with_mqa, hidden_size_, kv_lora_rank_ + qk_rope_head_dim_, attention_bias, dtype, device);
INFINICORE_NN_MODULE_INIT(kv_a_layernorm, kv_lora_rank_, rms_norm_eps, dtype, device);
INFINICORE_NN_MODULE_INIT(kv_b_proj, kv_lora_rank_, total_num_heads * (qk_nope_head_dim_ + v_head_dim_), quantization_method, false, dtype, device, tp_rank, tp_size);
INFINICORE_NN_MODULE_INIT(o_proj, total_num_heads * v_head_dim_, hidden_size_, quantization_method, attention_bias, dtype, device, tp_rank, tp_size, rank_info.comm);

const size_t max_position_embeddings = model_config->get<size_t>("max_position_embeddings");
const double rope_theta = model_config->get<double>("rope_theta");
rotary_emb_ = std::make_shared<infinicore::nn::RoPE>(
qk_rope_head_dim_, qk_rope_head_dim_, max_position_embeddings, rope_theta,
infinicore::nn::RoPE::Algo::GPT_J, dtype, device, nullptr);

softmax_scale_ = 1.0f / std::sqrt(static_cast<float>(q_head_dim_));
auto &config_json = model_config->get_config_json();
if (config_json.contains("rope_scaling") && config_json["rope_scaling"].is_object()) {
const auto &rope_scaling = config_json["rope_scaling"];
const float mscale_all_dim = rope_scaling.value("mscale_all_dim", 0.0f);
if (mscale_all_dim != 0.0f) {
const float scaling_factor = rope_scaling.value("factor", 1.0f);
const float mscale = yarn_get_mscale(scaling_factor, mscale_all_dim);
softmax_scale_ *= mscale * mscale;
}
}

latent_attn_ = std::make_shared<infinilm::layers::attention::AttentionLayer>(
num_attention_heads_, mla_head_dim_, softmax_scale_, 1, layer_idx_,
kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_);
infinilm::layers::attention::init_kv_cache_quant_params(
[this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); },
device, kv_cache_k_scale_, kv_cache_v_scale_);
}

infinicore::Tensor DeepseekV2MLAAttention::position_ids_for_rope_(const infinicore::Tensor &position_ids) const {
auto pos_shape = position_ids->shape();
if (pos_shape.size() == 2) {
return position_ids->narrow({{0, 0, 1}})->contiguous()->view({pos_shape[1]});
}
if (pos_shape.size() == 1) {
return position_ids->contiguous();
}
throw std::runtime_error("DeepseekV2MLAAttention: unexpected position_ids shape");
}

infinicore::Tensor DeepseekV2MLAAttention::kv_b_weight_3d_() const {
return kv_b_proj_->weight()->view({num_attention_heads_, qk_nope_head_dim_ + v_head_dim_, kv_lora_rank_});
}

infinicore::Tensor DeepseekV2MLAAttention::project_q_nope_to_latent_(const infinicore::Tensor &q_nope) const {
const size_t ntokens = q_nope->shape()[0];
auto q_nope_by_head = q_nope->permute({1, 0, 2})->contiguous();
auto w_uk_t = kv_b_weight_3d_()->narrow({{1, 0, qk_nope_head_dim_}})->contiguous();
auto q_latent = infinicore::op::matmul(q_nope_by_head, w_uk_t);
return q_latent->permute({1, 0, 2})->contiguous()->view({ntokens, num_attention_heads_, kv_lora_rank_});
}

infinicore::Tensor DeepseekV2MLAAttention::project_latent_to_value_(const infinicore::Tensor &attn_output,
size_t batch_size,
size_t seq_len) const {
const size_t ntokens = batch_size * seq_len;
auto latent = attn_output->view({ntokens, num_attention_heads_, mla_head_dim_})
->narrow({{2, 0, kv_lora_rank_}})
->contiguous();
auto latent_by_head = latent->permute({1, 0, 2})->contiguous();
auto w_uv = kv_b_weight_3d_()
->narrow({{1, qk_nope_head_dim_, v_head_dim_}})
->permute({0, 2, 1})
->contiguous();
auto value = infinicore::op::matmul(latent_by_head, w_uv)
->permute({1, 0, 2})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * v_head_dim_});
return o_proj_->forward(value);
}

infinicore::Tensor DeepseekV2MLAAttention::forward(const infinicore::Tensor &positions,
const infinicore::Tensor &hidden_states) const {
if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) {
return forward_static_(positions, hidden_states);
}
return forward_paged_(positions, hidden_states);
}

infinicore::Tensor DeepseekV2MLAAttention::forward_static_(const infinicore::Tensor &position_ids,
const infinicore::Tensor &hidden_states) const {
auto shape = hidden_states->shape();
const size_t batch_size = shape[0];
const size_t seq_len = shape[1];
const size_t ntokens = batch_size * seq_len;
auto hidden_states_mutable = hidden_states;

auto q = q_proj_->forward(hidden_states_mutable)->view({ntokens, num_attention_heads_, q_head_dim_});
auto q_nope = q->narrow({{2, 0, qk_nope_head_dim_}})->contiguous();
auto q_pe = q->narrow({{2, qk_nope_head_dim_, qk_rope_head_dim_}})->contiguous();

auto compressed = kv_a_proj_with_mqa_->forward(hidden_states_mutable)->view({ntokens, kv_lora_rank_ + qk_rope_head_dim_});
auto compressed_kv = compressed->narrow({{1, 0, kv_lora_rank_}})->contiguous();
auto k_pe = compressed->narrow({{1, kv_lora_rank_, qk_rope_head_dim_}})->contiguous();

auto kv_norm = kv_a_layernorm_->forward(compressed_kv);
auto pos_ids = position_ids_for_rope_(position_ids);
q_pe = rotary_emb_->forward(q_pe, pos_ids, true);
auto k_pe_rope = rotary_emb_->forward(k_pe->view({ntokens, 1, qk_rope_head_dim_}), pos_ids, true);

auto q_latent = project_q_nope_to_latent_(q_nope);
auto query_states = infinicore::op::cat({q_latent, q_pe}, 2)->view({batch_size, seq_len, num_attention_heads_, mla_head_dim_});
auto key_states = infinicore::op::cat({kv_norm->view({ntokens, 1, kv_lora_rank_}), k_pe_rope}, 2)
->view({batch_size, seq_len, 1, mla_head_dim_});
auto value_states = infinicore::op::pad(kv_norm->view({batch_size, seq_len, 1, kv_lora_rank_}),
{0, static_cast<int>(qk_rope_head_dim_)}, "constant", 0.0);

auto attn_output = latent_attn_->forward(query_states, key_states, value_states);
return project_latent_to_value_(attn_output, batch_size, seq_len);
}

infinicore::Tensor DeepseekV2MLAAttention::forward_paged_(const infinicore::Tensor &position_ids,
const infinicore::Tensor &hidden_states) const {
auto shape = hidden_states->shape();
const size_t batch_size = shape[0];
const size_t seq_len = shape[1];
ASSERT_EQ(batch_size, 1);
auto hidden_states_mutable = hidden_states;

auto q = q_proj_->forward(hidden_states_mutable)->view({seq_len, num_attention_heads_, q_head_dim_});
auto q_nope = q->narrow({{2, 0, qk_nope_head_dim_}})->contiguous();
auto q_pe = q->narrow({{2, qk_nope_head_dim_, qk_rope_head_dim_}})->contiguous();

auto compressed = kv_a_proj_with_mqa_->forward(hidden_states_mutable)->view({seq_len, kv_lora_rank_ + qk_rope_head_dim_});
auto compressed_kv = compressed->narrow({{1, 0, kv_lora_rank_}})->contiguous();
auto k_pe = compressed->narrow({{1, kv_lora_rank_, qk_rope_head_dim_}})->contiguous();

auto kv_norm = kv_a_layernorm_->forward(compressed_kv);
auto pos_ids = position_ids_for_rope_(position_ids);
q_pe = rotary_emb_->forward(q_pe, pos_ids, true);
auto k_pe_rope = rotary_emb_->forward(k_pe->view({seq_len, 1, qk_rope_head_dim_}), pos_ids, true);

auto q_latent = project_q_nope_to_latent_(q_nope);
auto query_states = infinicore::op::cat({q_latent, q_pe}, 2);
auto key_states = infinicore::op::cat({kv_norm->view({seq_len, 1, kv_lora_rank_}), k_pe_rope}, 2);
auto value_states = infinicore::op::pad(kv_norm->view({seq_len, 1, kv_lora_rank_}),
{0, static_cast<int>(qk_rope_head_dim_)}, "constant", 0.0);

auto attn_output = latent_attn_->forward(query_states, key_states, value_states);
return project_latent_to_value_(attn_output, batch_size, seq_len);
}

} // namespace infinilm::models::deepseek_v2
61 changes: 61 additions & 0 deletions csrc/models/deepseek_v2/deepseek_v2_mla_attention.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include "../../backends/attention_backends.hpp"
#include "../../config/model_config.hpp"
#include "../../layers/attention/attention.hpp"
#include "../../layers/linear/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"

#include <memory>

namespace infinilm::models::deepseek_v2 {

class DeepseekV2MLAAttention : public infinicore::nn::Module {
public:
DeepseekV2MLAAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
size_t layer_idx,
const infinicore::Device &device);

infinicore::Tensor forward(const infinicore::Tensor &positions,
const infinicore::Tensor &hidden_states) const;

private:
infinicore::Tensor forward_static_(const infinicore::Tensor &positions,
const infinicore::Tensor &hidden_states) const;
infinicore::Tensor forward_paged_(const infinicore::Tensor &positions,
const infinicore::Tensor &hidden_states) const;
infinicore::Tensor position_ids_for_rope_(const infinicore::Tensor &position_ids) const;
infinicore::Tensor kv_b_weight_3d_() const;
infinicore::Tensor project_q_nope_to_latent_(const infinicore::Tensor &q_nope) const;
infinicore::Tensor project_latent_to_value_(const infinicore::Tensor &attn_output,
size_t batch_size,
size_t seq_len) const;

size_t layer_idx_{0};
size_t hidden_size_{0};
size_t num_attention_heads_{0};
size_t qk_nope_head_dim_{0};
size_t qk_rope_head_dim_{0};
size_t q_head_dim_{0};
size_t v_head_dim_{0};
size_t kv_lora_rank_{0};
size_t mla_head_dim_{0};
float softmax_scale_{1.0f};
infinilm::backends::AttentionBackend attention_backend_;

INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, q_proj);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, kv_a_proj_with_mqa);
INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, kv_a_layernorm);
INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, kv_b_proj);
INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj);

std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
std::shared_ptr<infinilm::layers::attention::AttentionLayer> latent_attn_;
infinicore::nn::Parameter kv_cache_k_scale_;
infinicore::nn::Parameter kv_cache_v_scale_;
};

} // namespace infinilm::models::deepseek_v2
Loading