
Triton mla #7804

Open
Linboyan-trc wants to merge 7 commits into PaddlePaddle:develop from Linboyan-trc:triton_mla

Conversation

@Linboyan-trc

Motivation

Port the Triton-based MLA (Multi-head Latent Attention) decode attention from SGLang to FastDeploy as a new attention backend (TRITON_MLA_ATTN).

Modifications

New files

| File | Description |
|------|-------------|
| fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py | Core Triton MLA backend implementation, including extend/decode/mixed forward and CUDA Graph compatible metadata initialization |
| fastdeploy/model_executor/layers/attention/triton_ops/__init__.py | Package exports for triton_ops |
| fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py | Split-KV two-stage decode attention Triton kernel (ported from SGLang, adapted to FastDeploy paged KV cache addressing) |
| fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py | Triton KV cache write kernel that writes the concatenated `compressed_kv` and `k_pe` latent into the paged cache (see the sketch after this table) |
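
The cache-write row above can be pictured with a short reference sketch. This is only an illustration of what the kernel does conceptually, not the PR's Triton code; the helper name and exact shapes are assumptions.

```python
# Reference sketch (assumed shapes): scatter each token's latent vector into the
# paged cache. The real PR does this in a Triton kernel; this loop is only for clarity.
import paddle

def write_latent_cache_reference(compressed_kv, k_pe, latent_cache, slot_mapping, block_size):
    # compressed_kv: [num_tokens, kv_lora_rank], k_pe: [num_tokens, qk_rope_head_dim]
    # latent_cache:  [num_blocks, block_size, kv_lora_rank + qk_rope_head_dim]
    latent = paddle.concat([compressed_kv, k_pe], axis=-1)  # [num_tokens, latent_dim]
    block_ids = slot_mapping // block_size                   # physical block per token
    offsets = slot_mapping % block_size                      # slot inside that block
    for i in range(latent.shape[0]):
        b, s = int(block_ids[i]), int(offsets[i])
        latent_cache[b, s] = latent[i]
```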

Modified files

| File | Description |
|------|-------------|
| fastdeploy/platforms/base.py | Add the _Backend.TRITON_MLA_ATTN enum value |
| fastdeploy/platforms/cuda.py | Add backend routing to TritonMLAAttentionBackend |
| fastdeploy/config.py | CacheConfig.use_mla_cache recognizes TRITON_MLA_ATTN |
| fastdeploy/worker/gpu_model_runner.py | Recognize the MLA cache path and extend _apply_position_ids_if_needed to the new backend |
| fastdeploy/model_executor/layers/attention/__init__.py | Register and export TritonMLAAttentionBackend |
| fastdeploy/model_executor/models/deepseek_v3.py | Guard attn_out so an empty batch does not pass None into o_proj |
| custom_ops/gpu_ops/helper.h | checkAttentionBackend() on the C++ side recognizes the new backend |

Key design points

  1. Paged KV cache adaptation: SGLang uses a flat cache, while FastDeploy uses a paged cache. The decode kernel resolves block addressing via kv_loc // block_size and kv_loc % block_size (see the sketch after this list).
  2. CUDA Graph compatibility: fixed-size buffers (_kv_indptr_buf, _kv_indices_buf, _num_kv_splits_buf) are pre-allocated, a Triton cumsum kernel replaces thrust (avoiding cudaMalloc), and padding keeps the kernel grid dimensions constant.
  3. Prefill reuses Flash Attention: flash_attention_v3_varlen on SM90+ and flash_attn_unpadded on SM80, rather than reimplementing it.
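
A minimal sketch of the paged addressing in design point 1, assuming a [num_blocks, block_size, latent_dim] cache layout and a per-sequence block table; the helper itself is hypothetical, not code from this PR.

```python
# Hedged sketch: map flat KV positions (kv_locs) within one sequence onto the
# paged latent cache using that sequence's block table (all indices int32/int64).
import paddle

def gather_paged_kv(latent_cache, block_table, kv_locs, block_size):
    # latent_cache: [num_blocks, block_size, latent_dim]
    # block_table:  [max_blocks_per_seq] physical block ids for this sequence
    # kv_locs:      [n] logical token positions within the sequence
    block_ids = paddle.gather(block_table, kv_locs // block_size)  # physical block per token
    offsets = kv_locs % block_size                                  # slot inside that block
    idx = paddle.stack([block_ids, offsets], axis=-1)               # [n, 2]
    return paddle.gather_nd(latent_cache, idx)                      # [n, latent_dim]
```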

Usage or Command

export CUDA_VISIBLE_DEVICES=0,1,2,3
export FD_ATTENTION_BACKEND="TRITON_MLA_ATTN"
export FLAGS_flash_attn_version=3
export FD_SAMPLING_CLASS=rejection

python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/GLM-4.7-Flash \
    --port 8380 \
    --tensor-parallel-size 4 \
    --max-model-len 32768 \
    --max-num-seqs 32

Copilot AI review requested due to automatic review settings May 13, 2026 10:15
Contributor

Copilot AI left a comment


Pull request overview

This PR ports the Triton-based MLA (Multi-head Latent Attention) decode attention from SGLang to FastDeploy as a new attention backend, TRITON_MLA_ATTN, and wires up the corresponding backend selection, KV cache writes, and decode kernel logic.

Changes:

  • Add TritonMLAAttentionBackend: extend reuses FlashAttention, decode goes through the Triton split-KV two-stage kernel, with CUDA Graph related metadata/buffer pre-allocation.
  • Add Triton kernels: paged KV cache write (mla_cache_kernel.py) and decode attention (decode_attention.py), exported in triton_ops/__init__.py.
  • Platform/config/runtime adaptation: add _Backend.TRITON_MLA_ATTN, CUDA platform routing, use_mla_cache recognition, the GPU runner's MLA cache check, and the DeepSeek-V3 empty-batch guard.

Points to note (not tied to specific code lines):

  • The PR title is currently "Triton mla", which does not follow the repository's [CLASS]Title convention; consider e.g. [Feature] Add Triton MLA attention backend (or adjust the class as appropriate).
  • This PR introduces a new backend and a new environment variable value (FD_ATTENTION_BACKEND=TRITON_MLA_ATTN); the related usage docs (e.g. the environment variable reference) should be checked and updated so users do not misconfigure it.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.

Show a summary per file

| File | Description |
|------|-------------|
| fastdeploy/worker/gpu_model_runner.py | Recognize TRITON_MLA_ATTN as an MLA cache path and make the position_ids/slot_mapping computation cover the new backend |
| fastdeploy/platforms/cuda.py | Add TRITON_MLA_ATTN routing on the CUDA platform and update the log/error messages |
| fastdeploy/platforms/base.py | Add the _Backend.TRITON_MLA_ATTN enum value |
| fastdeploy/config.py | CacheConfig.use_mla_cache recognizes TRITON_MLA_ATTN |
| fastdeploy/model_executor/layers/attention/__init__.py | Register and export TritonMLAAttentionBackend |
| fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py | Core Triton MLA backend implementation (extend/decode/mixed plus metadata/buffer pre-allocation) |
| fastdeploy/model_executor/layers/attention/triton_ops/__init__.py | Export the core functions under triton_ops |
| fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py | Add the split-KV decode attention Triton kernel (paged KV addressing) |
| fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py | Add the Triton KV cache write kernel (writes into the paged latent cache) |
| fastdeploy/model_executor/models/deepseek_v3.py | Guard attn_out on empty batches to avoid passing None into o_proj |
| custom_ops/gpu_ops/helper.h | checkAttentionBackend() on the C++ side recognizes TRITON_MLA_ATTN |

Comment on lines +279 to +281
tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
Comment on lines 80 to 83
raise ValueError(
"Invalid attention backend you specified.\n"
"Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place."
"Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN, TRITON_MLA_ATTN] in cuda place."
)
Comment on lines +76 to +78
elif selected_backend == _Backend.TRITON_MLA_ATTN:
logger.info("Using TRITON MLA ATTN backend.")
return "fastdeploy.model_executor.layers.attention.TritonMLAAttentionBackend"
PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot

PaddlePaddle-bot commented May 13, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-16 02:45:40

CI report generated against the following code (refreshed every 30 minutes):


1 Task overview

No required tasks are currently failing, and both optional tasks passed. ⚠️ 7 workflows are in the action_required state and will only run after manual approval (including the main CI pipeline).

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Pending | Skipped |
|---------------------|-------------|----------|----------|-----------|-----------|---------|
| 2 (0) | 2 | 2 | 0 | 0 | 0 | 0 |

⚠️ Note: the following 7 workflows are in the action_required state (they run only after approval): Approval, CI_HPU, Check PR Template, ILUVATAR-CI, CI_XPU, Codestyle-Check, PR Build and Test. These workflows need manual approval to trigger and include key steps such as unit tests and code-style checks; it is recommended to approve them and wait for the complete CI results.

Note: action_required workflows are not counted in the task statistics above.


2 Task status summary

2.1 Required tasks: 0/0 passed

No required tasks configured by branch-protection rules were detected.

2.2 Optional tasks — 2/2 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
|--------|------|----------|-----|-------|
| ✅ | Trigger Jenkins for PR | 58s | - | - |
| ✅ | Remove skip-ci labels on new commits | 4s | - | - |

3 Failure details (required only)

No required tasks failed.

attn_out = paddle.zeros(
[hidden_states.shape[0], self.num_attention_heads_tp * self.v_head_dim],
dtype=hidden_states.dtype,
)
Collaborator

If this is None, just letting it raise an error is fine; this extra logic doesn't seem necessary?

@@ -0,0 +1,368 @@
"""
Collaborator

Please add unit tests for this operator.

@@ -0,0 +1,499 @@
"""
Collaborator

Unit tests are needed here as well.

@@ -0,0 +1,147 @@
"""
Collaborator

Please add unit tests for this operator.

@CLAassistant

CLAassistant commented May 15, 2026

CLA assistant check
All committers have signed the CLA.

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

self.causal: bool = getattr(fd_config.model_config, "causal", True)

self.num_heads: int = num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.max_kv_splits: int = 32

self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.useless_tensor = paddle.randn([1]).cast("int32")
Comment on lines +182 to +192
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
decode_mask = seq_lens_decoder > 0
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
decode_block_tables = forward_meta.block_tables[decode_mask]
total_kv_len = int(paddle.sum(decode_seq_lens).item())

Comment on lines +318 to +335
total_tokens = q.shape[0]
Lv = self.kv_lora_rank

# Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
cu_seqlens = forward_meta.cu_seqlens_q
seq_lens_decoder = forward_meta.seq_lens_decoder
decode_mask = seq_lens_decoder > 0
max_num_seqs = seq_lens_decoder.shape[0]
seq_indices = paddle.arange(max_num_seqs, dtype="int32")
decode_seq_indices = seq_indices[decode_mask]
decode_token_positions = cu_seqlens[decode_seq_indices]

q_decode = q[decode_token_positions]
decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)

output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output
Comment on lines +344 to +351
bs = q.shape[0]
Lv = self.kv_lora_rank
latent_dim = self.kv_lora_rank + self.qk_rope_head_dim
q_reshaped = q.reshape([bs, self.num_heads, latent_dim])

attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
PaddlePaddle-bot

This comment was marked as outdated.

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (4)

fastdeploy/model_executor/models/deepseek_v3.py:1065

  • The early return hands back a single hidden_states, but the caller (hidden_states, residual = self.layers[i](...) in DeepSeekV3Model.forward) unpacks a two-tuple. When both need_do_prefill and need_do_decode are False, this raises a ValueError right away. It should return (hidden_states, residual).
        if not need_do_prefill and not need_do_decode:
            return hidden_states

fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:335

  • In the k is None branch of forward_mixed, a mixed extend+decode batch only runs the decode kernel, and the output at extend token positions is always filled with 0 (lines 333-334), effectively discarding the attention output of extend tokens. If this branch really needs to support mixed batches, the extend pass must run as well; if the contract is that this backend never sees extend tokens in mixed calls, there should at least be an assert decode_bs == batch_size, otherwise the error will be silent.
        # Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
        # Extract decode tokens (1 per decode sequence), run kernel, scatter back.
        decode_bs = metadata.decode_bs
        if decode_bs == 0:
            Lv = self.kv_lora_rank
            return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)

        total_tokens = q.shape[0]
        Lv = self.kv_lora_rank

        # Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
        cu_seqlens = forward_meta.cu_seqlens_q
        seq_lens_decoder = forward_meta.seq_lens_decoder
        decode_mask = seq_lens_decoder > 0
        max_num_seqs = seq_lens_decoder.shape[0]
        seq_indices = paddle.arange(max_num_seqs, dtype="int32")
        decode_seq_indices = seq_indices[decode_mask]
        decode_token_positions = cu_seqlens[decode_seq_indices]

        q_decode = q[decode_token_positions]
        decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)

        output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
        output[decode_token_positions] = decode_out
        return output

fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:351

  • _run_decode_kernel allocates attn_logits, attn_lse, and o with paddle.empty on every call. CUDA Graph requires capture and replay to use the same memory addresses, so fresh allocations undermine graph-capture stability (and contradict the deliberate pre-allocation of _kv_indptr_buf and friends at lines 129-134 of this file). Consider pre-allocating these intermediate buffers at max_num_seqs as well and slicing them to the actual bs when used.
        attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
        attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
        o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
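
A hedged sketch of the suggestion above: pre-allocate the decode scratch buffers once at max_num_seqs and slice them per step. The buffer names (_attn_logits_buf and friends) are hypothetical, not identifiers from this PR.

```python
# Hedged sketch: allocate once so CUDA Graph capture and replay see the same
# storage; slice per step instead of calling paddle.empty in the hot path.
def _alloc_decode_buffers(self, out_dtype):
    self._attn_logits_buf = paddle.zeros(
        [self.max_num_seqs, self.num_heads, self.max_kv_splits, self.kv_lora_rank],
        dtype="float32",
    )
    self._attn_lse_buf = paddle.zeros(
        [self.max_num_seqs, self.num_heads, self.max_kv_splits], dtype="float32"
    )
    self._o_buf = paddle.zeros(
        [self.max_num_seqs, self.num_heads, self.kv_lora_rank], dtype=out_dtype
    )

# Per step: reuse the same storage rather than allocating.
#   attn_logits = self._attn_logits_buf[:bs]
#   attn_lse = self._attn_lse_buf[:bs]
#   o = self._o_buf[:bs]
```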

fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:262

  • flash_attention_v3_varlen and flash_attn_unpadded take different keyword argument names (the former typically softmax_scale, the latter scale), and their return structures differ as well (v3 returns an (out, softmax_lse)-style tuple, while flash_attn_unpadded also returns a tuple, but with different element semantics). Taking the first element with [0] on both paths needs to be checked to confirm it really is the output tensor in each case. Also consider a fallback for the case where flash_attention_v3_varlen is None but SM>=90: the current try/except only covers the import, so if the import succeeds but the call fails at runtime it will raise. Please confirm the two API calls are actually compatible.
        fmha_out = self.flash_attn_func(
            q,
            k,
            v,
            forward_meta.cu_seqlens_q,
            forward_meta.cu_seqlens_k,
            metadata.max_enc_len_this_time,
            metadata.max_enc_len_this_time,
            causal=self.causal,
            **self.flash_attn_kwargs,
        )[0]
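
One way to reconcile the keyword-name mismatch described above is to choose the kwargs once at backend init. A minimal sketch, assuming the names softmax_scale and scale from the comment are correct; sm_version is a hypothetical variable standing in for however the backend detects compute capability.

```python
# Hedged sketch: select the flash-attention entry point and its scale kwarg once,
# so the call site stays uniform. Kwarg names follow the review comment above and
# are assumptions, not verified signatures.
if flash_attention_v3_varlen is not None and sm_version >= 90:
    self.flash_attn_func = flash_attention_v3_varlen
    self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
else:
    self.flash_attn_func = flash_attn_unpadded
    self.flash_attn_kwargs = {"scale": self.attn_softmax_scale}

# Shared call site (mirrors the excerpt above):
#   fmha_out = self.flash_attn_func(q, k, v, cu_q, cu_k, max_q, max_k,
#                                   causal=self.causal, **self.flash_attn_kwargs)[0]
```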

Comment on lines +180 to +191
# Pre-compute decode kv_indptr/kv_indices into stable pre-allocated buffers.
# CUDAGraph requires tensors at the same memory address between capture and replay.
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
decode_mask = seq_lens_decoder > 0
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
decode_block_tables = forward_meta.block_tables[decode_mask]
total_kv_len = int(paddle.sum(decode_seq_lens).item())
Comment on lines +37 to +42
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1


@enable_compat_on_triton_kernel
@triton.jit
Comment on lines +1061 to +1065
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0

if not need_do_prefill and not need_do_decode:
return hidden_states
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

Comments suppressed due to low confidence (1)

fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:335

  • The CUDA Graph branch condition in forward_mixed, q.shape[0] == metadata.decode_bs, depends on decode_bs being synchronized via .item() in init_attention_metadata. Together these make CUDA Graph effectively unusable for this backend: (a) there is a device-to-host sync during capture, and (b) decode_bs values from different steps steer branching through a Python int, which breaks graph consistency. Consider always taking a single path sized by the capture batch size (for example, always _run_decode_kernel) and masking out invalid batches inside the kernel via pre-allocated padded buffers.
        # CUDAGraph path: q contains exactly the captured batch size of decode tokens.
        # Must always take this path during CUDAGraph capture/replay to keep the
        # execution trace identical (same kernel launches, same tensor shapes).
        if forward_meta.step_use_cudagraph or q.shape[0] == metadata.decode_bs:
            return self._run_decode_kernel(q, latent_cache, metadata)

        # Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
        # Extract decode tokens (1 per decode sequence), run kernel, scatter back.
        decode_bs = metadata.decode_bs
        if decode_bs == 0:
            Lv = self.kv_lora_rank
            return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)

        total_tokens = q.shape[0]
        Lv = self.kv_lora_rank

        # Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
        cu_seqlens = forward_meta.cu_seqlens_q
        seq_lens_decoder = forward_meta.seq_lens_decoder
        decode_mask = seq_lens_decoder > 0
        max_num_seqs = seq_lens_decoder.shape[0]
        seq_indices = paddle.arange(max_num_seqs, dtype="int32")
        decode_seq_indices = seq_indices[decode_mask]
        decode_token_positions = cu_seqlens[decode_seq_indices]

        q_decode = q[decode_token_positions]
        decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)

        output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
        output[decode_token_positions] = decode_out
        return output
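
A minimal sketch of the single-path idea, assuming the padded buffer invariants already established in init_attention_metadata (zero KV length and num_kv_splits of 1 for padded slots); the wrapper name is hypothetical, not code from this PR.

```python
# Hedged sketch: size the decode batch to the captured size and always launch the
# same kernel, relying on the padded metadata prepared in init_attention_metadata:
#   _kv_indptr_buf[decode_bs:]     == total_kv_len  -> padded slots see 0 KV tokens
#   _num_kv_splits_buf[decode_bs:] == 1             -> avoids div-by-zero in the kernel
# so capture and replay trace identical launches regardless of the live decode_bs.
def run_decode_cudagraph_safe(self, q_padded, latent_cache, metadata):
    return self._run_decode_kernel(q_padded, latent_cache, metadata)
```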

Comment on lines +349 to +351
attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
Comment on lines +185 to +198
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
decode_block_tables = forward_meta.block_tables[decode_mask]
total_kv_len = int(paddle.sum(decode_seq_lens).item())

build_kv_indices_from_block_tables(
decode_block_tables, decode_seq_lens, self.block_size, decode_bs,
total_kv_len=total_kv_len,
kv_indptr_buf=self._kv_indptr_buf,
kv_indices_buf=self._kv_indices_buf,
)
self.max_kv_splits: int = 32

self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.useless_tensor = paddle.randn([1]).cast("int32")
Comment on lines +252 to +262
fmha_out = self.flash_attn_func(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
Comment on lines +353 to +366
decode_attention_fwd(
q_reshaped,
latent_cache,
latent_cache[:, :, :, :self.kv_lora_rank],
o,
metadata.kv_indptr,
metadata.kv_indices,
attn_logits,
attn_lse,
metadata.num_kv_splits,
self.max_kv_splits,
self.attn_softmax_scale,
self.block_size,
)
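
The call above hands the per-split partial results (attn_logits, attn_lse) to a second combine stage. For reference, a standard split-KV merge looks like the sketch below; it mirrors the buffer shapes above but illustrates the usual logsumexp combine, not the PR's Triton kernel.

```python
# Hedged reference for the split-KV combine stage (shapes follow the buffers above:
# attn_logits [bs, heads, splits, Lv], attn_lse [bs, heads, splits]). Illustrative only.
import paddle

def combine_kv_splits(attn_logits, attn_lse):
    global_lse = paddle.logsumexp(attn_lse, axis=-1, keepdim=True)   # [bs, heads, 1]
    weights = paddle.exp(attn_lse - global_lse)                      # [bs, heads, splits]
    # Each split's partial output is normalized within its split, so re-weight by
    # exp(lse_split - lse_global) and sum over the split dimension.
    return (attn_logits * weights.unsqueeze(-1)).sum(axis=2)         # [bs, heads, Lv]
```
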
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

Comments suppressed due to low confidence (2)

fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:203

  • The comment says "kv_indptr[decode_bs] = total_kv_len; positions beyond must equal the same so that (kv_indptr[i+1] - kv_indptr[i]) = 0 for padded batches", but the code self._kv_indptr_buf[decode_bs + 1:] = total_kv_len only covers positions decode_bs+1 onward and never writes position decode_bs explicitly; it implicitly relies on build_kv_indices_from_block_tables having already set kv_indptr[decode_bs] = total_kv_len. Either make that dependency explicit in the comment, or change the assignment to self._kv_indptr_buf[decode_bs:self.max_num_seqs + 1] = total_kv_len so the invariant holds on its own.
            # Fill padded entries in kv_indptr so out-of-range batches see 0 KV length.
            # kv_indptr[decode_bs] = total_kv_len; positions beyond must equal the same
            # so that (kv_indptr[i+1] - kv_indptr[i]) = 0 for padded batches.
            if decode_bs < self.max_num_seqs:
                self._kv_indptr_buf[decode_bs + 1:] = total_kv_len

fastdeploy/model_executor/models/deepseek_v3.py:1065

  • The PR description and the actual change are inconsistent: the description says the deepseek_v3.py change is a "null guard for attn_out to avoid passing None into o_proj on an empty batch", but the diff actually adds an early return at the entry of DeepSeekV3DecoderLayer.forward (skipping the whole layer based on max_len_tensor_cpu[1]/[2]). Please either update the PR description or move the change to a None guard at o_proj, to keep the PR auditable.
        need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
        need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0

        if not need_do_prefill and not need_do_decode:
            return hidden_states

need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0

if not need_do_prefill and not need_do_decode:
return hidden_states
Collaborator

Please take a look at this.

Comment on lines +184 to +214
decode_mask = seq_lens_decoder > 0
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
decode_block_tables = forward_meta.block_tables[decode_mask]
total_kv_len = int(paddle.sum(decode_seq_lens).item())

build_kv_indices_from_block_tables(
decode_block_tables, decode_seq_lens, self.block_size, decode_bs,
total_kv_len=total_kv_len,
kv_indptr_buf=self._kv_indptr_buf,
kv_indices_buf=self._kv_indices_buf,
)
# Fill padded entries in kv_indptr so out-of-range batches see 0 KV length.
# kv_indptr[decode_bs] = total_kv_len; positions beyond must equal the same
# so that (kv_indptr[i+1] - kv_indptr[i]) = 0 for padded batches.
if decode_bs < self.max_num_seqs:
self._kv_indptr_buf[decode_bs + 1:] = total_kv_len

# Compute num_kv_splits into the pre-allocated buffer
compute_num_kv_splits(decode_seq_lens, decode_bs, self.max_kv_splits,
out_buf=self._num_kv_splits_buf)
# Padded entries must be >= 1 to avoid division by zero in kernel
if decode_bs < self.max_num_seqs:
self._num_kv_splits_buf[decode_bs:] = 1
else:
# No decode sequences: fill buffers with safe defaults
self._kv_indptr_buf[:] = 0
self._num_kv_splits_buf[:] = 1
Comment on lines +349 to +351
attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
Comment on lines +305 to +316
# CUDAGraph path: q contains exactly the captured batch size of decode tokens.
# Must always take this path during CUDAGraph capture/replay to keep the
# execution trace identical (same kernel launches, same tensor shapes).
if forward_meta.step_use_cudagraph or q.shape[0] == metadata.decode_bs:
return self._run_decode_kernel(q, latent_cache, metadata)

# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
# Extract decode tokens (1 per decode sequence), run kernel, scatter back.
decode_bs = metadata.decode_bs
if decode_bs == 0:
Lv = self.kv_lora_rank
return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)

@dataclass
class TritonMLAAttentionMetadata(AttentionMetadata):
_dtype: str = "bfloat16"
PaddlePaddle-bot

This comment was marked as outdated.

Copilot AI review requested due to automatic review settings May 15, 2026 12:09
Contributor

Copilot AI left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

PaddlePaddle-bot

This comment was marked as outdated.

@codecov-commenter

codecov-commenter commented May 15, 2026

Codecov Report

❌ Patch coverage is 33.60000% with 249 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@12c6ae0). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
|--------------------------|---------|-------|
| ...r/layers/attention/triton_mla_attention_backend.py | 24.26% | 127 Missing and 1 partial ⚠️ |
| ...or/layers/attention/triton_ops/decode_attention.py | 35.86% | 93 Missing ⚠️ |
| ...or/layers/attention/triton_ops/mla_cache_kernel.py | 48.88% | 23 Missing ⚠️ |
| fastdeploy/platforms/cuda.py | 0.00% | 2 Missing and 1 partial ⚠️ |
| fastdeploy/model_executor/models/deepseek_v3.py | 50.00% | 1 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7804   +/-   ##
==========================================
  Coverage           ?   63.63%           
==========================================
  Files              ?      466           
  Lines              ?    64648           
  Branches           ?     9883           
==========================================
  Hits               ?    41137           
  Misses             ?    20714           
  Partials           ?     2797           
| Flag | Coverage Δ |
|------|------------|
| GPU | 72.73% <33.60%> (?) |
| XPU | 7.08% <0.26%> (?) |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.



self.prefix = prefix

prop = paddle.device.cuda.get_device_properties()
Collaborator

Why change the original logic here?

@PaddlePaddle-bot

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-15 23:51:50

CI report generated against the following code (refreshed every 30 minutes):


1 Task overview

1 required task failed (Approval has not been approved); it must be resolved before this PR can be merged.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Pending | Skipped |
|---------------------|-------------|----------|----------|-----------|-----------|---------|
| 26 (0) | 26 | 21 | 4 | 0 | 1 | 0 |

2 Task status summary

2.1 Required tasks: 2/3 passed

Required tasks block merging; failures must be addressed first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
|--------|------|----------|------------|---------------|-----|-------|
| ❌ | Approval | 9s | PR issue: new logger.info/debug calls require RD approval | Contact xyxinyang or zyyzghb to approve on the PR page | Job | - |
| ✅ | The remaining 2 required tasks passed | - | - | - | - | - |

2.2 Optional tasks — 19/23 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
|--------|------|----------|-----|-------|
| ❌ | Run iluvatar Tests / run_iluvatar_cases | 10m48s | Job | - |
| ❌ | Agent (Copilot code review) | 2m13s | Job | - |
| ❌ | Cleanup artifacts (Copilot code review) | 5s | Job | - |
| ⏸️ | CI_HPU | - | Pending | - |
| ✅ | The remaining 19 optional tasks passed | - | - | - |

3 Failure details (required only)

Approval — code style (confidence: high)

Approval

  • Status: ❌ failed
  • Error type: code style
  • Confidence: high
  • Root cause summary: the PR adds new logger.info/debug calls, which require RD approval
  • Analyzer: generic analysis (fallback)

Root cause details:
The PR adds several new logging calls (logger.debug and logger.info) in the Triton MLA related code, which triggered the approval check in scripts/check_approval.sh. That script requires a GitHub review approval from a FastDeploy RD (xyxinyang/zhouchong or zyyzghb/zhangyongyue) whenever logging behavior (.info/.debug/.error/log_request) is modified.

Key log:

Detected log modification in diff:
+    logger.debug(f"flash_attention_v3_varlen not available: {e}")
+                logger.info("TritonMLAAttentionBackend: Using Flash Attention V3.")
+                logger.info("TritonMLAAttentionBackend: Using Flash Attention V2.")
+            logger.info("Using TRITON MLA ATTN backend.")
0. You must have one FastDeploy RD (xyxinyang(zhouchong), zyyzghb(zhangyongyue)) approval for modifying logging behavior.
There are 1 approved errors.
##[error]Process completed with exit code 6.

Suggested fix:

  1. Contact a FastDeploy RD (xyxinyang or zyyzghb) to approve the review on the GitHub PR page.
  2. After approval, CI will be re-triggered automatically.

Fix summary: contact xyxinyang or zyyzghb to approve the review on the PR page.

Related change: the PR adds Triton MLA related logger.debug/logger.info calls.

Link: view logs

PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot left a comment

🤖 Paddle-CI-Agent | pr_review | 2026-05-16 03:06:30

📋 Review summary

PR overview: ports the Triton-based MLA decode attention from SGLang to FastDeploy, adding the TRITON_MLA_ATTN backend
Scope of changes: layers/attention/triton_ops/, models/deepseek_v3.py, config.py, platforms/, worker/
Impact tags: [OP] [Models] [FDConfig]

Issues

| Severity | File | Summary |
|----------|------|---------|
| 🟡 Suggestion | triton_mla_attention_backend.py:267 | forward_extend passes max_enc_len_this_time as max_seqlen_k; with a prefix KV cache this should probably be max_kv_len_this_time |
| ❓ Question | deepseek_v3.py:1108 | max_len_tensor_cpu[1] and [2] are magic-number indices with no comment explaining the layout |
| ❓ Question | triton_mla_attention_backend.py:319 | In the k=None mixed-batch path of forward_mixed, extend token output is zeroed; is that intended? |
| 📝 PR convention | | Title lacks a tag; description is missing ## Accuracy Tests and ## Checklist |

📝 PR convention check

The title "Triton mla" lacks an official tag, and the PR description is missing the two required sections ## Accuracy Tests and ## Checklist.

Suggested title (copy-paste ready):

  • [Feature][OP] Add Triton MLA Attention Backend (TRITON_MLA_ATTN)

Suggested PR description (copy-paste ready):

## Motivation

Port the Triton-based MLA (Multi-head Latent Attention) decode attention from SGLang to FastDeploy as a new attention backend (`TRITON_MLA_ATTN`), providing a pure Python/Triton inference path for MLA models such as DeepSeek-V3 without relying on custom CUDA ops.

## Modifications

### New files

| File | Description |
|------|------|
| `fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py` | Core Triton MLA backend implementation, including extend/decode/mixed forward and CUDA Graph compatible metadata initialization |
| `fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py` | Split-KV two-stage decode attention Triton kernel (ported from SGLang, adapted to FastDeploy paged KV cache addressing) |
| `fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py` | Triton KV cache write kernel that writes the concatenated `compressed_kv` and `k_pe` latent into the paged cache |

### Modified files

| File | Description |
|------|------|
| `fastdeploy/platforms/base.py` | Add the `_Backend.TRITON_MLA_ATTN` enum value |
| `fastdeploy/platforms/cuda.py` | Add backend routing to `TritonMLAAttentionBackend` |
| `fastdeploy/config.py` | `CacheConfig.use_mla_cache` recognizes `TRITON_MLA_ATTN` |
| `fastdeploy/worker/gpu_model_runner.py` | Recognize the MLA cache path and extend `_apply_position_ids_if_needed` to the new backend |
| `fastdeploy/model_executor/layers/attention/__init__.py` | Register and export `TritonMLAAttentionBackend` |
| `fastdeploy/model_executor/models/deepseek_v3.py` | Guard `attn_out` so an empty batch does not pass None into `o_proj` |
| `custom_ops/gpu_ops/helper.h` | `checkAttentionBackend()` on the C++ side recognizes the new backend |

## Usage or Command

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FD_ATTENTION_BACKEND="TRITON_MLA_ATTN"
export FLAGS_flash_attn_version=3
export FD_SAMPLING_CLASS=rejection

python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/GLM-4.7-Flash \
    --port 8380 \
    --tensor-parallel-size 4 \
    --max-model-len 32768 \
    --max-num-seqs 32
```

## Accuracy Tests

N/A

## Checklist

- [ ] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Overall assessment

The overall design is sound: the split-KV two-stage Triton kernel, the paged cache addressing adaptation, and the CUDA Graph pre-allocated buffers are clearly thought out. The author should clarify how max_seqlen_k is chosen in forward_extend, and whether forward_mixed's handling of extend tokens when k=None is intended, to ensure accuracy correctness.

forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,

🟡 Suggestion: forward_extend passes max_enc_len_this_time (the maximum query length of the current step) as max_seqlen_k, rather than max_kv_len_this_time.

When a prefix KV cache or chunked prefill is present, the sequence lengths in cu_seqlens_k can exceed max_enc_len_this_time. For flash_attn_unpadded, max_seqlen_k affects kernel correctness; if it is smaller than the true maximum KV length, the KV will be truncated.

Please verify whether cu_seqlens_k is strictly equal to cu_seqlens_q during the extend phase (pure prefill with no prefix cache reuse); if so, add a comment saying that; otherwise change it to:

metadata.max_kv_len_this_time,  # max_seqlen_k

residual: paddle.Tensor,
):
""" """
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0

❓ Question: max_len_tensor_cpu[1] and [2] are magic-number indices without explanation. Their meaning (index 1 = max encode len, index 2 = max decode len) is coupled to an implicit convention in init_attention_metadata; if that layout changes, this will break silently.

Suggested comment:

# max_len_tensor_cpu layout: [_, max_enc_len, max_dec_len, ...]
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0

if forward_meta.step_use_cudagraph or q.shape[0] == metadata.decode_bs:
return self._run_decode_kernel(q, latent_cache, metadata)

# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).

❓ Question: the comment on this branch says it handles a mixed "extend + decode" batch, but the code only computes attention for decode tokens; the extend token outputs are zeroed and returned directly:

output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output

If the caller guarantees there are no extend sequences when k=None (the extend path is handled by forward_extend, i.e. when k is not None), the logic is correct. Please add a comment documenting that assumption, for example:

# Invariant: when k is None, no extend sequences are present.
# Mixed extend+decode is always routed through forward_extend (k is not None).

Otherwise the zeroed extend token outputs will feed incorrect input to o_proj and cause silent accuracy problems.

