From 3e53a124576e1802885d5ba02e9c43ce999c4518 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 17 Jun 2026 19:13:22 +0800 Subject: [PATCH] feat: concept deepseek v2 mla attn --- csrc/engine/infer_engine.cpp | 7 +- csrc/engine/infer_engine.hpp | 4 +- csrc/global_state/infinilm_config.hpp | 5 +- .../layers/attention/backends/static_attn.cpp | 2 +- .../deepseek_v2/deepseek_v2_decoder_layer.cpp | 11 +- .../deepseek_v2/deepseek_v2_decoder_layer.hpp | 6 +- .../deepseek_v2/deepseek_v2_mla_attention.cpp | 200 ++++++++++++++++++ .../deepseek_v2/deepseek_v2_mla_attention.hpp | 61 ++++++ csrc/models/infinilm_model.cpp | 6 + csrc/pybind11/engine/engine.hpp | 9 +- examples/test_infer.py | 3 + python/infinilm/base_config.py | 6 + python/infinilm/config/engine_config.py | 2 + python/infinilm/infer_engine.py | 4 +- python/infinilm/llm/llm.py | 6 + .../infinilm/llm/model_runner/model_runner.py | 1 + python/infinilm/server/inference_server.py | 6 +- 17 files changed, 325 insertions(+), 14 deletions(-) create mode 100644 csrc/models/deepseek_v2/deepseek_v2_mla_attention.cpp create mode 100644 csrc/models/deepseek_v2/deepseek_v2_mla_attention.hpp diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index e07c676ce..2b3498624 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -15,15 +15,16 @@ InferEngine::InferEngine( const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend, - std::optional kv_cache_dtype) // Changed parameter - : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { + std::optional kv_cache_dtype, + bool use_mla) + : communication_group_(distributed_config, device_type), attention_backend_(attention_backend), use_mla_(use_mla) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); } // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = infinilm::config::ConfigFactory::createConfig(config_str); - auto infinilm_config = std::make_shared(attention_backend, this->model_config_); + auto infinilm_config = std::make_shared(attention_backend, this->model_config_, use_mla); // Only support offline int8 kv cache quantization in this version if (kv_cache_dtype.has_value()) { diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 1b41be21e..632e9d953 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -28,7 +28,8 @@ class InferEngine { const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, - std::optional kv_cache_dtype = std::nullopt); + std::optional kv_cache_dtype = std::nullopt, + bool use_mla = false); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); @@ -68,6 +69,7 @@ class InferEngine { std::shared_ptr model_config_; backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default; bool weights_finalized_ = false; + bool use_mla_{false}; }; } // namespace infinilm::engine diff --git a/csrc/global_state/infinilm_config.hpp b/csrc/global_state/infinilm_config.hpp index 9b80706ca..dd7b13051 100644 --- a/csrc/global_state/infinilm_config.hpp +++ b/csrc/global_state/infinilm_config.hpp @@ -14,12 +14,15 @@ struct InfinilmConfig { public: InfinilmConfig() = default; InfinilmConfig(const infinilm::backends::AttentionBackend &backend, - const std::shared_ptr &model_config) + const std::shared_ptr &model_config, + bool use_mla = false) : attention_backend(backend), + use_mla(use_mla), model_config(model_config) {} public: infinilm::backends::AttentionBackend attention_backend; + bool use_mla{false}; std::shared_ptr model_config; }; diff --git a/csrc/layers/attention/backends/static_attn.cpp b/csrc/layers/attention/backends/static_attn.cpp index 2d1b7e11a..b71f0b52e 100644 --- a/csrc/layers/attention/backends/static_attn.cpp +++ b/csrc/layers/attention/backends/static_attn.cpp @@ -79,7 +79,7 @@ infinicore::Tensor StaticAttentionImpl::forward(const AttentionLayer &layer, // Compute attention size_t ngroup = num_heads_ / num_kv_heads_; - auto Q = q_reshaped->view({batch_size * num_kv_heads_, ngroup * seq_len, head_dim_}); + auto Q = q_reshaped->contiguous()->view({batch_size * num_kv_heads_, ngroup * seq_len, head_dim_}); auto K = k_total->view({batch_size * num_kv_heads_, total_seq_len, head_dim_}); auto V = v_total->view({batch_size * num_kv_heads_, total_seq_len, head_dim_}); diff --git a/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp index 2843bc86b..9a1e0c0b1 100644 --- a/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp +++ b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.cpp @@ -1,5 +1,7 @@ #include "deepseek_v2_decoder_layer.hpp" +#include "../../global_state/global_state.hpp" + namespace infinilm::models::deepseek_v2 { DeepseekV2DecoderLayer::DeepseekV2DecoderLayer(std::shared_ptr model_config, @@ -10,7 +12,11 @@ DeepseekV2DecoderLayer::DeepseekV2DecoderLayer(std::shared_ptrget("rms_norm_eps"); INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device); INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device); - INFINICORE_NN_MODULE_INIT(self_attn, model_config, layer_idx, device); + if (infinilm::global_state::get_infinilm_config().use_mla) { + self_attn_ = std::make_shared(this->register_module("self_attn", model_config, layer_idx, device)); + } else { + self_attn_ = std::make_shared(this->register_module("self_attn", model_config, layer_idx, device)); + } const size_t first_k_dense_replace = model_config->get_or("first_k_dense_replace", 0); const size_t moe_layer_freq = model_config->get_or("moe_layer_freq", 1); @@ -29,7 +35,8 @@ DeepseekV2DecoderLayer::forward(const infinicore::Tensor &positions, infinicore::Tensor &hidden_states, infinicore::Tensor &residual) const { input_layernorm_->forward_inplace(hidden_states, residual); - hidden_states = self_attn_->forward(positions, hidden_states); + hidden_states = std::visit( + [&](auto &attn_ptr) { return attn_ptr->forward(positions, hidden_states); }, *self_attn_); post_attention_layernorm_->forward_inplace(hidden_states, residual); hidden_states = use_moe_ ? moe_mlp_->forward(hidden_states) : dense_mlp_->forward(hidden_states); return {hidden_states, residual}; diff --git a/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp index e7a36d85a..34541fa5e 100644 --- a/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp +++ b/csrc/models/deepseek_v2/deepseek_v2_decoder_layer.hpp @@ -2,6 +2,7 @@ #include "../../config/model_config.hpp" #include "deepseek_v2_attention.hpp" +#include "deepseek_v2_mla_attention.hpp" #include "deepseek_v2_moe.hpp" #include "infinicore/device.hpp" #include "infinicore/nn/module.hpp" @@ -10,9 +11,12 @@ #include #include +#include namespace infinilm::models::deepseek_v2 { +using DeepseekV2SelfAttention = std::variant, std::shared_ptr>; + class DeepseekV2DecoderLayer : public infinicore::nn::Module { public: DeepseekV2DecoderLayer(std::shared_ptr model_config, @@ -26,7 +30,7 @@ class DeepseekV2DecoderLayer : public infinicore::nn::Module { private: INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); - INFINICORE_NN_MODULE(DeepseekV2Attention, self_attn); + INFINICORE_NN_MODULE(DeepseekV2SelfAttention, self_attn); INFINICORE_NN_MODULE(DeepseekV2MLP, dense_mlp); INFINICORE_NN_MODULE(DeepseekV2MoE, moe_mlp); bool use_moe_{false}; diff --git a/csrc/models/deepseek_v2/deepseek_v2_mla_attention.cpp b/csrc/models/deepseek_v2/deepseek_v2_mla_attention.cpp new file mode 100644 index 000000000..07de064c1 --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_mla_attention.cpp @@ -0,0 +1,200 @@ +#include "deepseek_v2_mla_attention.hpp" + +#include "../../global_state/global_state.hpp" +#include "../../utils.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/cat.hpp" +#include "infinicore/ops/pad.hpp" + +#include +#include + +namespace infinilm::models::deepseek_v2 { +namespace { + +float yarn_get_mscale(float scale, float mscale) { + if (scale <= 1.0f) { + return 1.0f; + } + return 0.1f * mscale * std::log(scale) + 1.0f; +} + +} // namespace + +DeepseekV2MLAAttention::DeepseekV2MLAAttention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + hidden_size_ = model_config->get("hidden_size"); + qk_nope_head_dim_ = model_config->get("qk_nope_head_dim"); + qk_rope_head_dim_ = model_config->get("qk_rope_head_dim"); + q_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; + v_head_dim_ = model_config->get("v_head_dim"); + kv_lora_rank_ = model_config->get("kv_lora_rank"); + mla_head_dim_ = kv_lora_rank_ + qk_rope_head_dim_; + + if (model_config->get_or("q_lora_rank", 0) != 0) { + throw std::runtime_error("DeepseekV2MLAAttention: q_lora_rank is not supported yet"); + } + + const auto &dtype{model_config->get_dtype()}; + const size_t total_num_heads = model_config->get("num_attention_heads"); + const bool attention_bias = model_config->get_or("attention_bias", false); + const double rms_norm_eps = model_config->get("rms_norm_eps"); + + const auto &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + const int tp_rank = rank_info.tp_rank; + const int tp_size = rank_info.tp_size; + if ((total_num_heads < static_cast(tp_size)) || (total_num_heads % static_cast(tp_size) != 0)) { + throw std::runtime_error("DeepseekV2MLAAttention: num_attention_heads must be divisible by tp_size"); + } + num_attention_heads_ = total_num_heads / static_cast(tp_size); + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + + auto quantization_method = model_config->get_quantization_method(); + INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, total_num_heads * q_head_dim_, quantization_method, false, dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(kv_a_proj_with_mqa, hidden_size_, kv_lora_rank_ + qk_rope_head_dim_, attention_bias, dtype, device); + INFINICORE_NN_MODULE_INIT(kv_a_layernorm, kv_lora_rank_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(kv_b_proj, kv_lora_rank_, total_num_heads * (qk_nope_head_dim_ + v_head_dim_), quantization_method, false, dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(o_proj, total_num_heads * v_head_dim_, hidden_size_, quantization_method, attention_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + + const size_t max_position_embeddings = model_config->get("max_position_embeddings"); + const double rope_theta = model_config->get("rope_theta"); + rotary_emb_ = std::make_shared( + qk_rope_head_dim_, qk_rope_head_dim_, max_position_embeddings, rope_theta, + infinicore::nn::RoPE::Algo::GPT_J, dtype, device, nullptr); + + softmax_scale_ = 1.0f / std::sqrt(static_cast(q_head_dim_)); + auto &config_json = model_config->get_config_json(); + if (config_json.contains("rope_scaling") && config_json["rope_scaling"].is_object()) { + const auto &rope_scaling = config_json["rope_scaling"]; + const float mscale_all_dim = rope_scaling.value("mscale_all_dim", 0.0f); + if (mscale_all_dim != 0.0f) { + const float scaling_factor = rope_scaling.value("factor", 1.0f); + const float mscale = yarn_get_mscale(scaling_factor, mscale_all_dim); + softmax_scale_ *= mscale * mscale; + } + } + + latent_attn_ = std::make_shared( + num_attention_heads_, mla_head_dim_, softmax_scale_, 1, layer_idx_, + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + infinilm::layers::attention::init_kv_cache_quant_params( + [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }, + device, kv_cache_k_scale_, kv_cache_v_scale_); +} + +infinicore::Tensor DeepseekV2MLAAttention::position_ids_for_rope_(const infinicore::Tensor &position_ids) const { + auto pos_shape = position_ids->shape(); + if (pos_shape.size() == 2) { + return position_ids->narrow({{0, 0, 1}})->contiguous()->view({pos_shape[1]}); + } + if (pos_shape.size() == 1) { + return position_ids->contiguous(); + } + throw std::runtime_error("DeepseekV2MLAAttention: unexpected position_ids shape"); +} + +infinicore::Tensor DeepseekV2MLAAttention::kv_b_weight_3d_() const { + return kv_b_proj_->weight()->view({num_attention_heads_, qk_nope_head_dim_ + v_head_dim_, kv_lora_rank_}); +} + +infinicore::Tensor DeepseekV2MLAAttention::project_q_nope_to_latent_(const infinicore::Tensor &q_nope) const { + const size_t ntokens = q_nope->shape()[0]; + auto q_nope_by_head = q_nope->permute({1, 0, 2})->contiguous(); + auto w_uk_t = kv_b_weight_3d_()->narrow({{1, 0, qk_nope_head_dim_}})->contiguous(); + auto q_latent = infinicore::op::matmul(q_nope_by_head, w_uk_t); + return q_latent->permute({1, 0, 2})->contiguous()->view({ntokens, num_attention_heads_, kv_lora_rank_}); +} + +infinicore::Tensor DeepseekV2MLAAttention::project_latent_to_value_(const infinicore::Tensor &attn_output, + size_t batch_size, + size_t seq_len) const { + const size_t ntokens = batch_size * seq_len; + auto latent = attn_output->view({ntokens, num_attention_heads_, mla_head_dim_}) + ->narrow({{2, 0, kv_lora_rank_}}) + ->contiguous(); + auto latent_by_head = latent->permute({1, 0, 2})->contiguous(); + auto w_uv = kv_b_weight_3d_() + ->narrow({{1, qk_nope_head_dim_, v_head_dim_}}) + ->permute({0, 2, 1}) + ->contiguous(); + auto value = infinicore::op::matmul(latent_by_head, w_uv) + ->permute({1, 0, 2}) + ->contiguous() + ->view({batch_size, seq_len, num_attention_heads_ * v_head_dim_}); + return o_proj_->forward(value); +} + +infinicore::Tensor DeepseekV2MLAAttention::forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const { + if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) { + return forward_static_(positions, hidden_states); + } + return forward_paged_(positions, hidden_states); +} + +infinicore::Tensor DeepseekV2MLAAttention::forward_static_(const infinicore::Tensor &position_ids, + const infinicore::Tensor &hidden_states) const { + auto shape = hidden_states->shape(); + const size_t batch_size = shape[0]; + const size_t seq_len = shape[1]; + const size_t ntokens = batch_size * seq_len; + auto hidden_states_mutable = hidden_states; + + auto q = q_proj_->forward(hidden_states_mutable)->view({ntokens, num_attention_heads_, q_head_dim_}); + auto q_nope = q->narrow({{2, 0, qk_nope_head_dim_}})->contiguous(); + auto q_pe = q->narrow({{2, qk_nope_head_dim_, qk_rope_head_dim_}})->contiguous(); + + auto compressed = kv_a_proj_with_mqa_->forward(hidden_states_mutable)->view({ntokens, kv_lora_rank_ + qk_rope_head_dim_}); + auto compressed_kv = compressed->narrow({{1, 0, kv_lora_rank_}})->contiguous(); + auto k_pe = compressed->narrow({{1, kv_lora_rank_, qk_rope_head_dim_}})->contiguous(); + + auto kv_norm = kv_a_layernorm_->forward(compressed_kv); + auto pos_ids = position_ids_for_rope_(position_ids); + q_pe = rotary_emb_->forward(q_pe, pos_ids, true); + auto k_pe_rope = rotary_emb_->forward(k_pe->view({ntokens, 1, qk_rope_head_dim_}), pos_ids, true); + + auto q_latent = project_q_nope_to_latent_(q_nope); + auto query_states = infinicore::op::cat({q_latent, q_pe}, 2)->view({batch_size, seq_len, num_attention_heads_, mla_head_dim_}); + auto key_states = infinicore::op::cat({kv_norm->view({ntokens, 1, kv_lora_rank_}), k_pe_rope}, 2) + ->view({batch_size, seq_len, 1, mla_head_dim_}); + auto value_states = infinicore::op::pad(kv_norm->view({batch_size, seq_len, 1, kv_lora_rank_}), + {0, static_cast(qk_rope_head_dim_)}, "constant", 0.0); + + auto attn_output = latent_attn_->forward(query_states, key_states, value_states); + return project_latent_to_value_(attn_output, batch_size, seq_len); +} + +infinicore::Tensor DeepseekV2MLAAttention::forward_paged_(const infinicore::Tensor &position_ids, + const infinicore::Tensor &hidden_states) const { + auto shape = hidden_states->shape(); + const size_t batch_size = shape[0]; + const size_t seq_len = shape[1]; + ASSERT_EQ(batch_size, 1); + auto hidden_states_mutable = hidden_states; + + auto q = q_proj_->forward(hidden_states_mutable)->view({seq_len, num_attention_heads_, q_head_dim_}); + auto q_nope = q->narrow({{2, 0, qk_nope_head_dim_}})->contiguous(); + auto q_pe = q->narrow({{2, qk_nope_head_dim_, qk_rope_head_dim_}})->contiguous(); + + auto compressed = kv_a_proj_with_mqa_->forward(hidden_states_mutable)->view({seq_len, kv_lora_rank_ + qk_rope_head_dim_}); + auto compressed_kv = compressed->narrow({{1, 0, kv_lora_rank_}})->contiguous(); + auto k_pe = compressed->narrow({{1, kv_lora_rank_, qk_rope_head_dim_}})->contiguous(); + + auto kv_norm = kv_a_layernorm_->forward(compressed_kv); + auto pos_ids = position_ids_for_rope_(position_ids); + q_pe = rotary_emb_->forward(q_pe, pos_ids, true); + auto k_pe_rope = rotary_emb_->forward(k_pe->view({seq_len, 1, qk_rope_head_dim_}), pos_ids, true); + + auto q_latent = project_q_nope_to_latent_(q_nope); + auto query_states = infinicore::op::cat({q_latent, q_pe}, 2); + auto key_states = infinicore::op::cat({kv_norm->view({seq_len, 1, kv_lora_rank_}), k_pe_rope}, 2); + auto value_states = infinicore::op::pad(kv_norm->view({seq_len, 1, kv_lora_rank_}), + {0, static_cast(qk_rope_head_dim_)}, "constant", 0.0); + + auto attn_output = latent_attn_->forward(query_states, key_states, value_states); + return project_latent_to_value_(attn_output, batch_size, seq_len); +} + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/deepseek_v2/deepseek_v2_mla_attention.hpp b/csrc/models/deepseek_v2/deepseek_v2_mla_attention.hpp new file mode 100644 index 000000000..26b42d97e --- /dev/null +++ b/csrc/models/deepseek_v2/deepseek_v2_mla_attention.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include "../../backends/attention_backends.hpp" +#include "../../config/model_config.hpp" +#include "../../layers/attention/attention.hpp" +#include "../../layers/linear/linear.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/tensor.hpp" + +#include + +namespace infinilm::models::deepseek_v2 { + +class DeepseekV2MLAAttention : public infinicore::nn::Module { +public: + DeepseekV2MLAAttention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + +private: + infinicore::Tensor forward_static_(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + infinicore::Tensor position_ids_for_rope_(const infinicore::Tensor &position_ids) const; + infinicore::Tensor kv_b_weight_3d_() const; + infinicore::Tensor project_q_nope_to_latent_(const infinicore::Tensor &q_nope) const; + infinicore::Tensor project_latent_to_value_(const infinicore::Tensor &attn_output, + size_t batch_size, + size_t seq_len) const; + + size_t layer_idx_{0}; + size_t hidden_size_{0}; + size_t num_attention_heads_{0}; + size_t qk_nope_head_dim_{0}; + size_t qk_rope_head_dim_{0}; + size_t q_head_dim_{0}; + size_t v_head_dim_{0}; + size_t kv_lora_rank_{0}; + size_t mla_head_dim_{0}; + float softmax_scale_{1.0f}; + infinilm::backends::AttentionBackend attention_backend_; + + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, q_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, kv_a_proj_with_mqa); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, kv_a_layernorm); + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, kv_b_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + + std::shared_ptr rotary_emb_; + std::shared_ptr latent_attn_; + infinicore::nn::Parameter kv_cache_k_scale_; + infinicore::nn::Parameter kv_cache_v_scale_; +}; + +} // namespace infinilm::models::deepseek_v2 diff --git a/csrc/models/infinilm_model.cpp b/csrc/models/infinilm_model.cpp index 1eaff4f20..8429fffba 100644 --- a/csrc/models/infinilm_model.cpp +++ b/csrc/models/infinilm_model.cpp @@ -30,6 +30,12 @@ std::vector InfinilmModel::default_allocate_kv_cache_tensors } size_t head_dim = text_config->get("head_dim"); size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const bool use_deepseek_mla = infinilm::global_state::get_infinilm_config().use_mla + && text_config->get_or("model_type", "") == "deepseek_v2"; + if (use_deepseek_mla) { + head_dim = text_config->get("kv_lora_rank") + text_config->get("qk_rope_head_dim"); + num_key_value_heads = 1; + } size_t max_position_embeddings = text_config->get("max_position_embeddings"); const auto &dtype = model_config_->get_kv_cache_dtype(); const size_t num_hidden_layers = text_config->get("num_hidden_layers"); diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 694387d39..983b0b88f 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -65,7 +65,8 @@ inline void bind_infer_engine(py::module &m) { std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend, - std::optional kv_cache_dtype) { + std::optional kv_cache_dtype, + bool use_mla) { return std::make_shared( config_str, dist, @@ -73,7 +74,8 @@ inline void bind_infer_engine(py::module &m) { cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, infinilm::backends::parse_attention_backend(attention_backend), - kv_cache_dtype); + kv_cache_dtype, + use_mla); }), py::arg("config_str") = "", py::arg("distributed_config") = distributed::DistConfig(), @@ -81,7 +83,8 @@ inline void bind_infer_engine(py::module &m) { py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, py::arg("attention_backend") = "default", - py::arg("kv_cache_dtype") = py::none()) + py::arg("kv_cache_dtype") = py::none(), + py::arg("use_mla") = false) .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") diff --git a/examples/test_infer.py b/examples/test_infer.py index a90a5d3e8..dffe64d7f 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -16,6 +16,7 @@ def test( top_p=1.0, temperature=1.0, attn_backend="default", + use_mla=False, image_path=None, skip_load=False, ): @@ -38,6 +39,7 @@ def test( top_p=top_p, enable_graph=enable_graph, attn_backend=attn_backend, + use_mla=use_mla, skip_load=skip_load, ) @@ -101,6 +103,7 @@ def test( top_p=cfg.top_p, temperature=cfg.temperature, attn_backend=cfg.attn, + use_mla=cfg.use_mla, image_path=cfg.image, skip_load=cfg.skip_load, ) diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index c07beff6a..6b5515275 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -60,6 +60,7 @@ def __init__(self): self.attn = self.args.attn self.enable_graph = self.args.enable_graph self.enable_paged_attn = self.args.enable_paged_attn + self.use_mla = self.args.use_mla self.num_blocks = self.args.num_blocks self.block_size = self.args.block_size self.max_cache_len = self.args.max_cache_len @@ -122,6 +123,11 @@ def _add_common_args(self): choices=["default", "paged-attn", "flash-attn"], ) self.parser.add_argument("--enable-graph", action="store_true") + self.parser.add_argument( + "--use-mla", + action="store_true", + help="use DeepSeek V2 MLA attention when supported", + ) self.parser.add_argument( "--enable-paged-attn", action="store_true", diff --git a/python/infinilm/config/engine_config.py b/python/infinilm/config/engine_config.py index 5799b5b05..044cfdda4 100644 --- a/python/infinilm/config/engine_config.py +++ b/python/infinilm/config/engine_config.py @@ -23,6 +23,7 @@ class EngineConfig: top_k: Default top-k sampling parameter. enable_graph: Whether to enable graph compiling. attn_backend: Attention backend to use ('default', 'flash-attn'). + use_mla: Whether to use DeepSeek V2 MLA attention when supported. skip_load: Whether to skip loading model weights (for testing). """ @@ -41,6 +42,7 @@ class EngineConfig: top_k: int = 1 enable_graph: bool = False attn_backend: str = "default" + use_mla: bool = False skip_load: bool = False kv_transfer_config: Optional[KVTransferConfig] = None diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index e95a4f7dc..17ee6c12f 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -3,7 +3,7 @@ import infinicore -from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig +from infinilm.cache import PagedKVCacheConfig from infinilm.distributed import DistConfig from infinilm.lib import _infinilm @@ -67,6 +67,7 @@ def __init__( enable_graph_compiling=False, attention_backend="default", kv_cache_dtype=None, + use_mla=False, ): self.hf_config = read_hf_config(model_path) self.hf_generation_config = read_hf_generation_config(model_path) @@ -87,6 +88,7 @@ def __init__( if kv_cache_dtype is not None else None ), + use_mla, ) self.use_cache = False diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 9177e3fad..f8171c0db 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -301,6 +301,7 @@ def __init__( top_k: int = 1, enable_graph: bool = False, attn_backend: str = "default", + use_mla: bool = False, skip_load: bool = False, ): """Initialize LLM. @@ -321,6 +322,7 @@ def __init__( top_k: Default top-k sampling parameter. enable_graph: Whether to enable graph compiling. attn_backend: Attention backend to use ('default', 'flash-attn'). + use_mla: Whether to use DeepSeek V2 MLA attention when supported. """ config = EngineConfig( model_path=model_path, @@ -338,6 +340,7 @@ def __init__( top_k=top_k, enable_graph=enable_graph, attn_backend=attn_backend, + use_mla=use_mla, skip_load=skip_load, ) self.engine = LLMEngine(config) @@ -493,6 +496,7 @@ def __init__( enable_graph: bool = False, attn_backend: str = "default", kv_transfer_config: Optional[KVTransferConfig] = None, + use_mla: bool = False, ): """Initialize AsyncLLMEngine. @@ -515,6 +519,7 @@ def __init__( kv_connector: KV connector type ('MooncakeConnector'). kv_role: Role in KV connector ('kv_producer' or 'kv_consumer'). kv_connector_extra_config: Extra config dict for KV connector. + use_mla: Whether to use DeepSeek V2 MLA attention when supported. """ config = EngineConfig( model_path=model_path, @@ -533,6 +538,7 @@ def __init__( enable_graph=enable_graph, attn_backend=attn_backend, kv_transfer_config=kv_transfer_config, + use_mla=use_mla, ) self.engine = LLMEngine(config) self.config = config diff --git a/python/infinilm/llm/model_runner/model_runner.py b/python/infinilm/llm/model_runner/model_runner.py index e551da0cb..5eca6d372 100644 --- a/python/infinilm/llm/model_runner/model_runner.py +++ b/python/infinilm/llm/model_runner/model_runner.py @@ -73,6 +73,7 @@ def __init__(self, config: EngineConfig): cache_config=cache_config, enable_graph_compiling=config.enable_graph, attention_backend=config.attn_backend, + use_mla=config.use_mla, ) # Load model weights diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 11708a183..645b00656 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -7,7 +7,6 @@ import time import json import uuid -import argparse import uvicorn import logging import os @@ -111,6 +110,7 @@ def __init__( port: int = 8000, enable_graph: bool = False, attn_backend: str = "default", + use_mla: bool = False, ignore_eos: bool = False, kv_transfer_config: Optional[KVTransferConfig] = None, ): @@ -134,6 +134,7 @@ def __init__( port: Server port number. enable_graph: Whether to enable graph compiling. attn_backend: Attention backend to use ('default', 'flash-attn'). + use_mla: Whether to use DeepSeek V2 MLA attention when supported. ignore_eos: Whether to ignore EOS tokens during generation. kv_transfer_config: Optional configuration for the KV transfer mechanism. """ @@ -156,6 +157,7 @@ def __init__( self.port = port self.enable_graph = enable_graph self.attn_backend = attn_backend + self.use_mla = use_mla self.ignore_eos = ignore_eos self.kv_transfer_config = kv_transfer_config @@ -189,6 +191,7 @@ async def lifespan(app: FastAPI): top_k=self.top_k, enable_graph=self.enable_graph, attn_backend=self.attn_backend, + use_mla=self.use_mla, kv_transfer_config=self.kv_transfer_config, ) self.engine.start() @@ -591,6 +594,7 @@ def main(): port=cfg.port, enable_graph=cfg.enable_graph, attn_backend=cfg.attn, + use_mla=cfg.use_mla, ignore_eos=cfg.ignore_eos, kv_transfer_config=kv_transfer_config, )