Triton mla #7804
Conversation
Pull request overview
This PR ports SGLang's Triton-based MLA (Multi-head Latent Attention) decode attention into FastDeploy as a new attention backend, `TRITON_MLA_ATTN`, wiring up backend selection, KV cache writes, and the decode kernel path.

Changes:
- Adds `TritonMLAAttentionBackend`: extend reuses FlashAttention, decode runs a two-stage Triton split-KV kernel, with CUDA Graph metadata and buffer pre-allocation.
- Adds Triton kernels: paged KV cache write (`mla_cache_kernel.py`) and decode attention (`decode_attention.py`), exported from `triton_ops/__init__.py`.
- Platform/config/runtime adaptation: new `_Backend.TRITON_MLA_ATTN` enum value, CUDA platform routing, `use_mla_cache` recognition, MLA-cache detection in the GPU runner, and an empty-batch guard for DeepSeek-V3.

Points to note (not tied to specific code lines):
- The PR title is currently "Triton mla", which does not follow the repository's `[CLASS]Title` convention; suggested: `[Feature] Add Triton MLA attention backend` (adjust the class as appropriate).
- The PR introduces a new backend and a new environment-variable value (`FD_ATTENTION_BACKEND=TRITON_MLA_ATTN`); the related documentation (e.g., the environment-variable reference) should be checked and updated so users do not miss or misconfigure it.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| fastdeploy/worker/gpu_model_runner.py | Recognize TRITON_MLA_ATTN as an MLA cache path and extend position_ids/slot_mapping computation to the new backend |
| fastdeploy/platforms/cuda.py | Add TRITON_MLA_ATTN routing on the CUDA platform; update logging and error messages |
| fastdeploy/platforms/base.py | Add the _Backend.TRITON_MLA_ATTN enum value |
| fastdeploy/config.py | CacheConfig.use_mla_cache recognizes TRITON_MLA_ATTN |
| fastdeploy/model_executor/layers/attention/__init__.py | Register and export TritonMLAAttentionBackend |
| fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py | Core Triton MLA backend implementation (extend/decode/mixed plus metadata/buffer pre-allocation) |
| fastdeploy/model_executor/layers/attention/triton_ops/__init__.py | Export the core functions under triton_ops |
| fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py | New split-KV decode attention Triton kernel (paged KV addressing) |
| fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py | New Triton KV cache write kernel (writes the paged latent cache) |
| fastdeploy/model_executor/models/deepseek_v3.py | Guard attn_out on empty batches so None is never passed into o_proj |
| custom_ops/gpu_ops/helper.h | C++-side checkAttentionBackend() recognizes TRITON_MLA_ATTN |
```python
tl.store(
    O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
    acc / e_sum,
```
```diff
 raise ValueError(
     "Invalid attention backend you specified.\n"
-    "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place."
+    "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN, TRITON_MLA_ATTN] in cuda place."
 )
```
```python
elif selected_backend == _Backend.TRITON_MLA_ATTN:
    logger.info("Using TRITON MLA ATTN backend.")
    return "fastdeploy.model_executor.layers.attention.TritonMLAAttentionBackend"
```
CI report generated from the code below (updated every 30 minutes):

1 Task overview: no required task failures; both optional tasks passed.
2 Task status summary
2.1 Required tasks: 0/0 passed
2.2 Optional tasks: 2/2 passed
3 Failure details (required only): no required task failures.
```python
attn_out = paddle.zeros(
    [hidden_states.shape[0], self.num_attention_heads_tp * self.v_head_dim],
    dtype=hidden_states.dtype,
)
```
If this is None, just letting it raise directly seems fine; is this extra logic really necessary?
```python
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.num_heads: int = num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.max_kv_splits: int = 32

self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.useless_tensor = paddle.randn([1]).cast("int32")
```
```python
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
decode_mask = seq_lens_decoder > 0
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
    decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
    decode_block_tables = forward_meta.block_tables[decode_mask]
    total_kv_len = int(paddle.sum(decode_seq_lens).item())
```
```python
total_tokens = q.shape[0]
Lv = self.kv_lora_rank

# Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
cu_seqlens = forward_meta.cu_seqlens_q
seq_lens_decoder = forward_meta.seq_lens_decoder
decode_mask = seq_lens_decoder > 0
max_num_seqs = seq_lens_decoder.shape[0]
seq_indices = paddle.arange(max_num_seqs, dtype="int32")
decode_seq_indices = seq_indices[decode_mask]
decode_token_positions = cu_seqlens[decode_seq_indices]

q_decode = q[decode_token_positions]
decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)

output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output
```
```python
bs = q.shape[0]
Lv = self.kv_lora_rank
latent_dim = self.kv_lora_rank + self.qk_rope_head_dim
q_reshaped = q.reshape([bs, self.num_heads, latent_dim])

attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
```
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (4)
fastdeploy/model_executor/models/deepseek_v3.py:1065
- The early return yields a single `hidden_states`, but the caller (`hidden_states, residual = self.layers[i](...)` in `DeepSeekV3Model.forward`) unpacks a 2-tuple, so when `need_do_prefill` and `need_do_decode` are both False this raises a ValueError. It should return `(hidden_states, residual)`.

```python
if not need_do_prefill and not need_do_decode:
    return hidden_states
```
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:335
In the `k is None` branch, `forward_mixed` runs only the decode kernel on a mixed extend+decode batch; the output at extend token positions is always zero-filled (lines 333-334), silently discarding the attention output of extend tokens. If this branch really must support mixed batches, it also needs to run extend; if the contract is that this backend never sees extend tokens in mixed calls, it should at least `assert decode_bs == batch_size`, otherwise this is a silent error.

```python
# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
# Extract decode tokens (1 per decode sequence), run kernel, scatter back.
decode_bs = metadata.decode_bs
if decode_bs == 0:
    Lv = self.kv_lora_rank
    return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)
total_tokens = q.shape[0]
Lv = self.kv_lora_rank
# Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
cu_seqlens = forward_meta.cu_seqlens_q
seq_lens_decoder = forward_meta.seq_lens_decoder
decode_mask = seq_lens_decoder > 0
max_num_seqs = seq_lens_decoder.shape[0]
seq_indices = paddle.arange(max_num_seqs, dtype="int32")
decode_seq_indices = seq_indices[decode_mask]
decode_token_positions = cu_seqlens[decode_seq_indices]
q_decode = q[decode_token_positions]
decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)
output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output
```
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:351
`_run_decode_kernel` allocates `attn_logits`, `attn_lse`, and `o` fresh via `paddle.empty` on every call. CUDA Graph requires capture and replay to use the same memory addresses, so per-call allocation breaks capture stability (contradicting the deliberate pre-allocation of `_kv_indptr_buf` and friends at lines 129-134 of this file). These intermediate buffers should also be pre-allocated at `max_num_seqs` and sliced to the actual bs at use.

```python
attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
```
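To make the suggestion concrete, here is a minimal sketch of the pre-allocate-then-slice pattern (the function name and sizes are illustrative, not the PR's code); whether a basic slice aliases the parent's storage is something the PR would need to confirm for Paddle:

```python
import paddle


def alloc_decode_buffers(max_num_seqs, num_heads, max_kv_splits, kv_lora_rank, dtype):
    # Allocate once at init, sized for the largest possible batch, so decode
    # launches reuse the same device memory across CUDA Graph capture/replay.
    logits = paddle.empty([max_num_seqs, num_heads, max_kv_splits, kv_lora_rank], dtype="float32")
    lse = paddle.empty([max_num_seqs, num_heads, max_kv_splits], dtype="float32")
    out = paddle.empty([max_num_seqs, num_heads, kv_lora_rank], dtype=dtype)
    return logits, lse, out


# Per step: slice to the live batch size instead of calling paddle.empty again.
logits_buf, lse_buf, o_buf = alloc_decode_buffers(64, 16, 32, 512, "bfloat16")
bs = 8  # actual decode batch this step
attn_logits, attn_lse, o = logits_buf[:bs], lse_buf[:bs], o_buf[:bs]
```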
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:262
`flash_attention_v3_varlen` and `flash_attn_unpadded` use different keyword-argument names (the former typically `softmax_scale`, the latter `scale`), and their return structures differ (v3 returns an `(out, softmax_lse)`-style tuple; `flash_attn_unpadded` also returns a tuple, but its elements mean different things). Taking the first element via `[0]` on both paths needs verification that it is the output tensor in each case. Also consider a fallback for `flash_attention_v3_varlen is None` on SM>=90: the current try/except only covers import time, so if the import succeeds but the call fails at runtime it will raise. Please confirm the two API calls are actually compatible.

```python
fmha_out = self.flash_attn_func(
    q,
    k,
    v,
    forward_meta.cu_seqlens_q,
    forward_meta.cu_seqlens_k,
    metadata.max_enc_len_this_time,
    metadata.max_enc_len_this_time,
    causal=self.causal,
    **self.flash_attn_kwargs,
)[0]
```
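One way to address the runtime-fallback concern is to fold both the function choice and its kwargs into one selector, so an SM>=90 machine whose v3 import failed still degrades gracefully. A minimal sketch; the import paths and keyword names follow the review's description and are assumptions, not verified signatures:

```python
# Assumed pattern: pick the attention entry point and its kwargs in one place.
try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:
    flash_attention_v3_varlen = None
from paddle.nn.functional.flash_attention import flash_attn_unpadded


def select_flash_attn(sm_version: int, softmax_scale: float):
    """Return (attention function, kwargs dict it expects)."""
    if sm_version >= 90 and flash_attention_v3_varlen is not None:
        # v3 path: per the review, typically expects `softmax_scale`
        return flash_attention_v3_varlen, {"softmax_scale": softmax_scale}
    # SM80 fallback: per the review, expects `scale` instead
    return flash_attn_unpadded, {"scale": softmax_scale}
```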
```python
# Pre-compute decode kv_indptr/kv_indices into stable pre-allocated buffers.
# CUDAGraph requires tensors at the same memory address between capture and replay.
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
decode_mask = seq_lens_decoder > 0
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
    decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
    decode_block_tables = forward_meta.block_tables[decode_mask]
    total_kv_len = int(paddle.sum(decode_seq_lens).item())
```
```python
def tanh(x):
    # Workaround when tl.tanh is unavailable: tanh(x) = 2*sigmoid(2x) - 1
    return 2 * tl.sigmoid(2 * x) - 1


@enable_compat_on_triton_kernel
@triton.jit
```
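For reference, the identity this helper relies on follows directly from the logistic sigmoid $\sigma(y) = 1/(1+e^{-y})$:

$$\tanh x = \frac{e^{x}-e^{-x}}{e^{x}+e^{-x}} = \frac{1-e^{-2x}}{1+e^{-2x}} = \frac{2}{1+e^{-2x}} - 1 = 2\,\sigma(2x) - 1$$

so expressing `tanh` through `tl.sigmoid` is exact, not an approximation.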
```python
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0

if not need_do_prefill and not need_do_decode:
    return hidden_states
```
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:335
The CUDA Graph branch condition `q.shape[0] == metadata.decode_bs` in `forward_mixed` depends on `decode_bs`, which is synchronized to the host via `.item()` in `init_attention_metadata`. Together these make CUDA Graph effectively unusable for this backend: (a) there is a device→host sync inside the capture window, and (b) the per-step `decode_bs` steers branching through a Python int, breaking graph consistency across steps. Better to take a single path keyed on the capture batch size (e.g., always `_run_decode_kernel`) and mask invalid batches inside the kernel via pre-allocated padded buffers.
```python
# CUDAGraph path: q contains exactly the captured batch size of decode tokens.
# Must always take this path during CUDAGraph capture/replay to keep the
# execution trace identical (same kernel launches, same tensor shapes).
if forward_meta.step_use_cudagraph or q.shape[0] == metadata.decode_bs:
    return self._run_decode_kernel(q, latent_cache, metadata)

# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
# Extract decode tokens (1 per decode sequence), run kernel, scatter back.
decode_bs = metadata.decode_bs
if decode_bs == 0:
    Lv = self.kv_lora_rank
    return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)
total_tokens = q.shape[0]
Lv = self.kv_lora_rank
# Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
cu_seqlens = forward_meta.cu_seqlens_q
seq_lens_decoder = forward_meta.seq_lens_decoder
decode_mask = seq_lens_decoder > 0
max_num_seqs = seq_lens_decoder.shape[0]
seq_indices = paddle.arange(max_num_seqs, dtype="int32")
decode_seq_indices = seq_indices[decode_mask]
decode_token_positions = cu_seqlens[decode_seq_indices]
q_decode = q[decode_token_positions]
decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)
output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output
```
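A sketch of the single-path alternative the comment proposes; the names mirror the PR's, but the control flow shown is illustrative, relying on the padded metadata buffers the backend already fills:

```python
# Sketch: decode under CUDA Graph always runs at the captured batch size.
# Real sequences occupy the first decode_bs slots; padded slots are
# neutralized entirely through the pre-filled metadata buffers:
#   kv_indptr[i+1] - kv_indptr[i] == 0  for i >= decode_bs  (0 KV tokens)
#   num_kv_splits[i] == 1               for i >= decode_bs  (no div-by-zero)
# With that, the Python-level branch on q.shape[0] vs. decode_bs disappears:
def forward_mixed(self, q, latent_cache, forward_meta, metadata):
    return self._run_decode_kernel(q, latent_cache, metadata)
```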
```python
attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
```
```python
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
    decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
    decode_block_tables = forward_meta.block_tables[decode_mask]
    total_kv_len = int(paddle.sum(decode_seq_lens).item())

    build_kv_indices_from_block_tables(
        decode_block_tables, decode_seq_lens, self.block_size, decode_bs,
        total_kv_len=total_kv_len,
        kv_indptr_buf=self._kv_indptr_buf,
        kv_indices_buf=self._kv_indices_buf,
    )
```
```python
self.max_kv_splits: int = 32

self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.useless_tensor = paddle.randn([1]).cast("int32")
```
```python
fmha_out = self.flash_attn_func(
    q,
    k,
    v,
    forward_meta.cu_seqlens_q,
    forward_meta.cu_seqlens_k,
    metadata.max_enc_len_this_time,
    metadata.max_enc_len_this_time,
    causal=self.causal,
    **self.flash_attn_kwargs,
)[0]
```
```python
decode_attention_fwd(
    q_reshaped,
    latent_cache,
    latent_cache[:, :, :, :self.kv_lora_rank],
    o,
    metadata.kv_indptr,
    metadata.kv_indices,
    attn_logits,
    attn_lse,
    metadata.num_kv_splits,
    self.max_kv_splits,
    self.attn_softmax_scale,
    self.block_size,
)
```
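For intuition about the two-stage split-KV scheme these arguments feed: each split produces a partial, split-normalized output plus its log-sum-exp, and stage 2 combines the splits with softmax weights over the LSEs. A NumPy sketch of that standard reduction (not necessarily the PR's exact kernel):

```python
import numpy as np


def combine_splits(attn_logits, attn_lse):
    # attn_logits: [num_heads, num_splits, Lv], each split's value-weighted
    #              output, normalized within that split
    # attn_lse:    [num_heads, num_splits], log-sum-exp of scores per split
    lse_max = attn_lse.max(axis=1, keepdims=True)        # numerical stability
    w = np.exp(attn_lse - lse_max)
    w /= w.sum(axis=1, keepdims=True)                    # softmax over splits
    return (w[..., None] * attn_logits).sum(axis=1)      # [num_heads, Lv]
```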
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (2)
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:203
- The comment reads "kv_indptr[decode_bs] = total_kv_len; positions beyond must equal the same so that (kv_indptr[i+1] - kv_indptr[i]) = 0 for padded batches", but the code `self._kv_indptr_buf[decode_bs + 1:] = total_kv_len` only covers positions `decode_bs + 1` onward and never explicitly writes position `decode_bs` itself. It implicitly relies on `build_kv_indices_from_block_tables` having already set `kv_indptr[decode_bs] = total_kv_len`. Either make that dependency explicit in the comment, or change the assignment to `self._kv_indptr_buf[decode_bs:self.max_num_seqs + 1] = total_kv_len` so the invariant holds on its own.

```python
# Fill padded entries in kv_indptr so out-of-range batches see 0 KV length.
# kv_indptr[decode_bs] = total_kv_len; positions beyond must equal the same
# so that (kv_indptr[i+1] - kv_indptr[i]) = 0 for padded batches.
if decode_bs < self.max_num_seqs:
    self._kv_indptr_buf[decode_bs + 1:] = total_kv_len
```
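A worked example of the CSR-style layout under discussion, assuming `max_num_seqs = 4` and two live decode sequences with KV lengths 5 and 3 (so `total_kv_len = 8`):

```python
# decode_bs = 2, total_kv_len = 8, kv_indptr has max_num_seqs + 1 = 5 entries
# kv_indptr = [0, 5, 8, 8, 8]
#   seq 0: kv_indices[0:5] -> 5 KV slots
#   seq 1: kv_indices[5:8] -> 3 KV slots
#   padded seqs 2, 3: kv_indptr[i+1] - kv_indptr[i] = 0 -> read no KV at all
# The reviewer's point: kv_indptr[2] (= kv_indptr[decode_bs]) is written by
# build_kv_indices_from_block_tables, while the `[decode_bs + 1:]` fill only
# touches indices 3 and 4; writing `[decode_bs:]` instead would make the
# invariant hold without depending on the helper.
```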
fastdeploy/model_executor/models/deepseek_v3.py:1065
- The PR description and the code change are inconsistent: the description says the `deepseek_v3.py` change is "an `attn_out` guard so empty batches never pass None into `o_proj`", but the actual diff sits at the entry of `DeepSeekV3DecoderLayer.forward` (a whole-layer early return based on `max_len_tensor_cpu[1]`/`[2]`). Please update the PR description, or move the change to a None guard at `o_proj`, to keep the change auditable.

```python
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
if not need_do_prefill and not need_do_decode:
    return hidden_states
```
```python
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0

if not need_do_prefill and not need_do_decode:
    return hidden_states
```
```python
decode_mask = seq_lens_decoder > 0
decode_bs = int(decode_mask.sum().item())
metadata.decode_bs = decode_bs

if decode_bs > 0:
    decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask]
    decode_block_tables = forward_meta.block_tables[decode_mask]
    total_kv_len = int(paddle.sum(decode_seq_lens).item())

    build_kv_indices_from_block_tables(
        decode_block_tables, decode_seq_lens, self.block_size, decode_bs,
        total_kv_len=total_kv_len,
        kv_indptr_buf=self._kv_indptr_buf,
        kv_indices_buf=self._kv_indices_buf,
    )
    # Fill padded entries in kv_indptr so out-of-range batches see 0 KV length.
    # kv_indptr[decode_bs] = total_kv_len; positions beyond must equal the same
    # so that (kv_indptr[i+1] - kv_indptr[i]) = 0 for padded batches.
    if decode_bs < self.max_num_seqs:
        self._kv_indptr_buf[decode_bs + 1:] = total_kv_len

    # Compute num_kv_splits into the pre-allocated buffer
    compute_num_kv_splits(decode_seq_lens, decode_bs, self.max_kv_splits,
                          out_buf=self._num_kv_splits_buf)
    # Padded entries must be >= 1 to avoid division by zero in kernel
    if decode_bs < self.max_num_seqs:
        self._num_kv_splits_buf[decode_bs:] = 1
else:
    # No decode sequences: fill buffers with safe defaults
    self._kv_indptr_buf[:] = 0
    self._num_kv_splits_buf[:] = 1
```
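`compute_num_kv_splits` itself is not shown in this diff; purely for illustration, a heuristic of this general shape (the threshold and formula are assumptions, not the PR's implementation) would split longer sequences more aggressively up to `max_kv_splits`:

```python
import paddle


def compute_num_kv_splits_sketch(decode_seq_lens, max_kv_splits, min_tokens_per_split=64):
    # Hypothetical heuristic: roughly one split per min_tokens_per_split KV
    # tokens, clamped to [1, max_kv_splits] so short sequences skip the
    # split-reduction overhead and long ones gain parallelism.
    splits = (decode_seq_lens + min_tokens_per_split - 1) // min_tokens_per_split
    return paddle.clip(splits, min=1, max=max_kv_splits)
```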
```python
attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
```
```python
# CUDAGraph path: q contains exactly the captured batch size of decode tokens.
# Must always take this path during CUDAGraph capture/replay to keep the
# execution trace identical (same kernel launches, same tensor shapes).
if forward_meta.step_use_cudagraph or q.shape[0] == metadata.decode_bs:
    return self._run_decode_kernel(q, latent_cache, metadata)

# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
# Extract decode tokens (1 per decode sequence), run kernel, scatter back.
decode_bs = metadata.decode_bs
if decode_bs == 0:
    Lv = self.kv_lora_rank
    return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)
```
```python
@dataclass
class TritonMLAAttentionMetadata(AttentionMetadata):
    _dtype: str = "bfloat16"
```
Codecov Report

❌ Patch coverage flagged.

```
@@           Coverage Diff            @@
##           develop    #7804   +/-   ##
==========================================
  Coverage         ?   63.63%
==========================================
  Files            ?      466
  Lines            ?    64648
  Branches         ?     9883
==========================================
  Hits             ?    41137
  Misses           ?    20714
  Partials         ?     2797
==========================================
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
```python
self.prefix = prefix

prop = paddle.device.cuda.get_device_properties()
```
CI report generated from the code below (updated every 30 minutes):

1 Task overview: 1 required task failed (Approval not granted); it must be resolved before the PR can be merged.
2 Task status summary
2.1 Required tasks: 2/3 passed
2.2 Optional tasks: 19/23 passed
3 Failure details (required only)
Approval: code review (confidence: high). Fix: ask xyxinyang or zyyzghb to approve the review on the PR page. Related change: the PR adds new Triton MLA code.
PaddlePaddle-bot left a comment
🤖 Paddle-CI-Agent | pr_review | 2026-05-16 03:06:30
📋 Review Summary

PR overview: ports the Triton-based MLA decode attention from SGLang into FastDeploy, adding the TRITON_MLA_ATTN backend.
Scope of change: layers/attention/triton_ops/, models/deepseek_v3.py, config.py, platforms/, worker/
Impact tags: [OP] [Models] [FDConfig]

Issues

| Level | File | Summary |
|---|---|---|
| 🟡 Suggestion | triton_mla_attention_backend.py:267 | forward_extend passes max_enc_len_this_time as max_seqlen_k; with a prefix KV cache this should probably be max_kv_len_this_time |
| ❓ Question | deepseek_v3.py:1108 | max_len_tensor_cpu[1] and [2] are magic-number indices with no comment documenting the layout |
| ❓ Question | triton_mla_attention_backend.py:319 | In forward_mixed's k=None mixed-batch path, extend token outputs are zeroed; is that intended? |
| 📝 PR convention | — | Title lacks a tag; description is missing the ## Accuracy Tests and ## Checklist sections |
📝 PR Convention Check

The title "Triton mla" lacks an official tag, and the PR description is missing the two required sections ## Accuracy Tests and ## Checklist.

Suggested title (copy-paste ready):

[Feature][OP] Add Triton MLA Attention Backend (TRITON_MLA_ATTN)
Suggested PR description (copy-paste ready):
## Motivation
Port SGLang's Triton-based MLA (Multi-head Latent Attention) decode attention into FastDeploy as a new attention backend (`TRITON_MLA_ATTN`), providing a pure Python/Triton inference path for MLA models such as DeepSeek-V3 without depending on custom CUDA ops.
## Modifications
### Added files
| File | Description |
|------|------|
| `fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py` | Core Triton MLA backend: extend/decode/mixed forward plus CUDA Graph compatible metadata initialization |
| `fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py` | Two-stage split-KV decode attention Triton kernel (ported from SGLang, adapted to FastDeploy's paged KV cache addressing) |
| `fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py` | Triton KV cache write kernel that stores `[compressed_kv || k_pe]` into the paged cache |
### Modified files
| File | Description |
|------|------|
| `fastdeploy/platforms/base.py` | Add the `_Backend.TRITON_MLA_ATTN` enum value |
| `fastdeploy/platforms/cuda.py` | Route the backend to `TritonMLAAttentionBackend` |
| `fastdeploy/config.py` | `CacheConfig.use_mla_cache` recognizes `TRITON_MLA_ATTN` |
| `fastdeploy/worker/gpu_model_runner.py` | MLA cache path recognition; `_apply_position_ids_if_needed` supports the new backend |
| `fastdeploy/model_executor/layers/attention/__init__.py` | Register and export `TritonMLAAttentionBackend` |
| `fastdeploy/model_executor/models/deepseek_v3.py` | Guard `attn_out` so an empty batch never passes None into `o_proj` |
| `custom_ops/gpu_ops/helper.h` | C++-side `checkAttentionBackend()` recognizes the new backend |
## Usage or Command
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FD_ATTENTION_BACKEND="TRITON_MLA_ATTN"
export FLAGS_flash_attn_version=3
export FD_SAMPLING_CLASS=rejection
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/GLM-4.7-Flash \
--port 8380 \
--tensor-parallel-size 4 \
--max-model-len 32768 \
--max-num-seqs 32
```
## Accuracy Tests
N/A
## Checklist
- [ ] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitted to the `release` branch, make sure it has also been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Overall assessment

The overall design is sound: the two-stage split-KV Triton kernel, the paged-cache addressing adaptation, and the CUDAGraph pre-allocated buffers are all clearly reasoned. The author should clarify how `max_seqlen_k` is chosen in `forward_extend`, and whether `forward_mixed`'s handling of extend tokens when `k=None` is intended, to ensure numerical correctness.
```python
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
```
🟡 Suggestion: `forward_extend` passes `max_enc_len_this_time` (the maximum query length of the current step) as `max_seqlen_k`, rather than `max_kv_len_this_time`.
When a prefix KV cache or chunked prefill is in play, the sequence lengths in `cu_seqlens_k` can exceed `max_enc_len_this_time`. For `flash_attn_unpadded`, `max_seqlen_k` affects kernel correctness; if it is smaller than the true maximum KV length, KV gets truncated.
Please verify that `cu_seqlens_k` is strictly equal to `cu_seqlens_q` during extend (pure prefill, no prefix-cache reuse); if that holds, add a comment saying so, otherwise change the argument to:

```python
metadata.max_kv_len_this_time,  # max_seqlen_k
```

```python
    residual: paddle.Tensor,
):
    """ """
    need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
```
❓ Question: `max_len_tensor_cpu[1]` and `[2]` are magic-number indices with no explanation. The meaning of these indices (index 1 = max encode len, index 2 = max decode len) is coupled to an implicit convention in `init_attention_metadata`; if the layout ever changes, this fails silently.
Suggest adding a comment:
```python
# max_len_tensor_cpu layout: [_, max_enc_len, max_dec_len, ...]
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
```

```python
if forward_meta.step_use_cudagraph or q.shape[0] == metadata.decode_bs:
    return self._run_decode_kernel(q, latent_cache, metadata)

# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
```
❓ Question: the comment on this branch says it handles a mixed "extend + decode" batch, but the code computes attention only for decode tokens; extend token outputs are zeroed and returned as-is:

```python
output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output
```

If the caller guarantees that no extend sequences exist when `k` is None (the extend path goes through `forward_extend`, i.e. `k is not None`), the logic is correct. Please add a comment stating that assumption, for example:

```python
# Invariant: when k is None, no extend sequences are present.
# Mixed extend+decode is always routed through forward_extend (k is not None).
```

Otherwise zero outputs at extend token positions feed wrong inputs into `o_proj`, producing a silent accuracy bug.
Motivation

Port SGLang's Triton-based MLA (Multi-head Latent Attention) decode attention to FastDeploy as a new attention backend (`TRITON_MLA_ATTN`).

Modifications
Added files

- `fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py`
- `fastdeploy/model_executor/layers/attention/triton_ops/__init__.py`
- `fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py`
- `fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py`

Modified files

- `fastdeploy/platforms/base.py`: add the `_Backend.TRITON_MLA_ATTN` enum value
- `fastdeploy/platforms/cuda.py`: route to `TritonMLAAttentionBackend`
- `fastdeploy/config.py`: `CacheConfig.use_mla_cache` recognizes `TRITON_MLA_ATTN`
- `fastdeploy/worker/gpu_model_runner.py`: `_apply_position_ids_if_needed` supports the new backend
- `fastdeploy/model_executor/layers/attention/__init__.py`: register and export `TritonMLAAttentionBackend`
- `fastdeploy/model_executor/models/deepseek_v3.py`: guard `attn_out` so an empty batch never passes None into `o_proj`
- `custom_ops/gpu_ops/helper.h`: `checkAttentionBackend()` recognizes the new backend

Key design

- The decode kernel addresses paged blocks via `kv_loc // block_size` and `kv_loc % block_size` (see the sketch below).
- CUDA Graph support: pre-allocated index buffers (`_kv_indptr_buf`, `_kv_indices_buf`, `_num_kv_splits_buf`); a Triton cumsum kernel replaces thrust (avoiding cudaMalloc); padding keeps the kernel grid dims constant.
- Prefill/extend reuses FlashAttention: `flash_attention_v3_varlen` on SM90, `flash_attn_unpadded` on SM80; no reinventing the wheel.
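The paged addressing named above, as a short illustrative sketch (the `[num_blocks, block_size, ...]` cache layout and the helper name are assumptions for illustration; `kv_indices` holds flat physical token slots):

```python
def gather_latent_for_seq(latent_cache, kv_indices, kv_indptr, seq_id, block_size):
    # Physical token slots for this sequence, via the CSR-style index buffers
    kv_locs = kv_indices[kv_indptr[seq_id]:kv_indptr[seq_id + 1]]
    block_ids = kv_locs // block_size  # which cache page
    offsets = kv_locs % block_size     # slot within that page
    return latent_cache[block_ids, offsets]
```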
Usage or Command