
[Speculative Decoding] Refine ngram kernel signature and adapt ngram proposer #7774

Open
NKNaN wants to merge 3 commits into PaddlePaddle:develop from NKNaN:spec-ngram

Conversation

NKNaN (Contributor) commented May 11, 2026

Motivation

Simplify the ngram match kernel interface: merge the previously separate input_ids / input_ids_len parameters into token_ids_all (with prompt_lens marking the boundary between prompt and generated tokens), and fix an ngram pointer-offset bug (the semantics of step_idx are unified from a 0-based index of the last position to a token count). End-to-end results were verified on a single A800 GPU, confirming the end-to-end correctness of the ngram speculative decoding method.

Modifications

  1. custom_ops/gpu_ops/speculate_decoding/ngram_match.cu / cpp_extensions.cc: remove the input_ids, input_ids_len, and input_ids_stride parameters; both the GPU kernel and the CPU fallback now read the prompt (the search domain) directly from token_ids_all[:, :prompt_len] and pre_ids (the ngram source) from token_ids_all[:, prompt_len:]; fix the ngram pointer-offset bug by changing cur_step_idx + 1 - ngram_size to cur_step_idx - ngram_size (see the sketch after this list).
  2. fastdeploy/spec_decode/ngram.py: remove the input_ids_len-related tensors and the update() method; align the _run_impl call signature with the new kernel interface.
  3. fastdeploy/config.py: include SpecMethod.NGRAM in the expected_decode_len computation for CUDAGraph capture.
  4. fastdeploy/worker/gpu_model_runner.py: add a warmup path for the NGRAM method in capture_model(), consistent with MTP/SUFFIX.
  5. Tests: update tests/operators/test_ngram_match.py, tests/spec_decode/test_benchmark_ngram_kernel.py, and tests/spec_decode/test_ngram_gpu_kernel.py; add tests/spec_decode/test_ngram_proposer.py and tests/e2e/test_ernie_03b_ngram.py.
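For orientation, the addressing scheme in item 1 can be pictured with a minimal single-row NumPy sketch of a CPU reference. This is an illustration only, not the repository's kernel or fallback code: the function name ngram_match_ref is hypothetical, and it assumes step_idx counts generated tokens (the new count semantics) with prompt and generated tokens stored contiguously in one token_ids_all row.

```python
import numpy as np

def ngram_match_ref(token_ids_all, prompt_len, step_idx, ngram_size, max_draft):
    """Hypothetical CPU reference mirroring the refined kernel semantics.

    token_ids_all: 1-D row holding [prompt | generated tokens].
    step_idx: number of tokens generated so far (count semantics).
    """
    prompt = token_ids_all[:prompt_len]                         # search domain
    pre_ids = token_ids_all[prompt_len:prompt_len + step_idx]   # generated so far
    if step_idx < ngram_size:
        return []  # not enough history to form an ngram
    # Fixed offset: the trailing ngram starts at step_idx - ngram_size,
    # not step_idx + 1 - ngram_size as in the buggy version.
    ngram = pre_ids[step_idx - ngram_size:step_idx]
    # Scan the prompt; stop before the last window so a hit always
    # leaves at least one token to propose.
    for start in range(prompt_len - ngram_size):
        if np.array_equal(prompt[start:start + ngram_size], ngram):
            end = min(start + ngram_size + max_draft, prompt_len)
            return prompt[start + ngram_size:end].tolist()
    return []

# Tiny demo: prompt = [5, 6, 7, 8, 9], generated = [6, 7, 8];
# the trailing 3-gram [6, 7, 8] matches at prompt offset 1, proposing [9].
row = np.array([5, 6, 7, 8, 9, 6, 7, 8])
print(ngram_match_ref(row, prompt_len=5, step_idx=3, ngram_size=3, max_draft=2))  # [9]
```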

Usage or Command

N/A

Accuracy Tests

End-to-end verification (single A800 GPU): token_ids_all read/write logs confirm that the prompt read back matches what was written; all draft_tokens were verified to hit within token_ids_all[:prompt_len + step_idx]; cross-step step_idx increments (e.g. 25→31) confirm that a single decode step accepts multiple speculative tokens.
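A minimal sketch of the kind of containment check this verification describes; the helper below is hypothetical (the actual verification was done by inspecting logs, not by test code), and it approximates "hit within the range" as token membership in the visible context.

```python
import numpy as np

def assert_draft_in_context(draft_tokens, token_ids_all, prompt_len, step_idx):
    """Hypothetical helper: every proposed draft token must already occur
    in the visible context token_ids_all[:prompt_len + step_idx]."""
    context = set(np.asarray(token_ids_all[:prompt_len + step_idx]).tolist())
    missing = [t for t in draft_tokens if t not in context]
    assert not missing, f"draft tokens outside visible context: {missing}"
```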

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. If no unit tests are added, state the reason in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

paddle-bot (Bot) commented May 11, 2026

Thanks for your contribution!

paddle-bot added the contributor (External developers) label on May 11, 2026
PaddlePaddle-bot

This comment was marked as outdated.

PaddlePaddle-bot commented May 11, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-14 14:28:02

This CI report was generated from the code below (updated every 30 minutes):


1 Task Overview

⚠️ CI in progress: 1 Required task failed and 6 Required tasks are still running; please address the failed task first.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Waiting | Skipped |
| --- | --- | --- | --- | --- | --- | --- |
| 38 (0) | 38 | 26 | 2 | 10 | 0 | 0 |

2 Task Status Summary

2.1 Required tasks: 3/10 passed

Required tasks block merging; failures must be addressed first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
| --- | --- | --- | --- | --- | --- | --- |
| ❌ | Approval | 11s | PR issue: modifies the spec_decode directory; needs designated RD approval | Ask @freeliuzc or @Deleter-D to approve this PR | Job | - |
| ⏳ | xpu_4cards_case_test / run_xpu_4cards_cases | - | Running | - | Job | - |
| ⏳ | xpu_8cards_case_test / run_xpu_8cards_cases | - | Running | - | Job | - |
| ⏳ | Extracted partial CE model tasks / run_ce_cases | - | Running | - | Job | - |
| ⏳ | Run Base Tests / base_tests | - | Running | - | Job | - |
| ⏳ | Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage | - | Running | - | Job | - |
| ⏳ | Run Four Cards Tests / run_4_cards_tests | - | Running | - | Job | - |
| ✅ | The remaining 3 required tasks passed | - | - | - | - | - |

2.2 Optional tasks — 23/28 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
| --- | --- | --- | --- | --- |
| ❌ | Check PR Template | 16s | Job | - |
| ⏳ | Run iluvatar Tests / run_iluvatar_cases | - | Job | - |
| ⏳ | CI_HPU | - | Job | - |
| ⏳ | xpu_unit_test / run_xpu_unit_test | - | Job | - |
| ⏳ | Trigger Jenkins for PR | - | Job | - |
| ✅ | The remaining 23 optional tasks passed | - | - | - |

3 Failure Details (required only)

Approval — process approval (confidence: high)

Approval

  • Status: ❌ Failed
  • Error type: process approval
  • Confidence: high
  • Root-cause summary: the PR modifies the spec_decode directory and needs approval from freeliuzc or Deleter-D
  • Analyzer: generic analysis (fallback)

Root-cause details:
The PR modifies files under fastdeploy/spec_decode and/or custom_ops/gpu_ops/speculate_decoding. According to the repository's approval rules, changes on these paths must receive a Review Approval from a designated FastDeploy RD (@freeliuzc (liuzichang01) or @Deleter-D (wangyanpeng04)) to pass the check. The PR has not yet been approved by either of them, so the script exited with code 6.

Key log:

0. You must have one FastDeploy RD (freeliuzc(liuzichang01), Deleter-D(wangyanpeng04)) approval for modifing [fastdeploy/spec_decode,custom_ops/gpu_ops/speculate_decoding].
There are 1 approved errors.
##[error]Process completed with exit code 6.

Suggested fix:

  1. Have @freeliuzc (liuzichang01) or @Deleter-D (wangyanpeng04) review this PR and click Approve.

Fix summary: ask @freeliuzc or @Deleter-D to approve this PR.

Related changes: the PR title indicates modifications to speculative-decoding code (ngram kernel signature and ngram proposer adaptation) in the fastdeploy/spec_decode and custom_ops/gpu_ops/speculate_decoding directories, which triggered this approval gate.

Link: view log

codecov-commenter commented May 11, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@d70f33d). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7774   +/-   ##
==========================================
  Coverage           ?   72.35%           
==========================================
  Files              ?      398           
  Lines              ?    55882           
  Branches           ?     8726           
==========================================
  Hits               ?    40435           
  Misses             ?    12673           
  Partials           ?     2774           
| Flag | Coverage Δ |
| --- | --- |
| GPU | 72.35% <100.00%> (?) |

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

PaddlePaddle-bot

This comment was marked as outdated.

Copilot AI (Contributor) left a comment

Pull request overview

This PR focuses on the ngram path of speculative decoding: it simplifies the ngram_match custom-op interface (removing the separate input_ids / input_ids_len and reading the prompt / pre_ids uniformly from token_ids_all + prompt_lens), adapts NgramProposer, the CUDA Graph warmup logic, and the related unit tests / benchmarks accordingly, and completes end-to-end usability verification.

Changes:

  • Adjust the ngram_match GPU custom-op signature and internal addressing: the prompt is read from the leading segment of token_ids_all, and the ngram from pre_ids[step_idx-ngram_size:step_idx].
  • Adapt the proposer call in fastdeploy/spec_decode/ngram.py, and include SpecMethod.NGRAM in the CUDA Graph capture warmup branch.
  • Add/update tests and benchmark data construction for the ngram proposer and kernel, unifying the step_idx semantics ("number of generated tokens" in the ngram_match scenario).

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| tests/spec_decode/test_ngram_proposer.py | Add CUDA unit-test coverage for NgramProposer (no proposal / with match / max_dec_len clipping). |
| tests/spec_decode/test_ngram_gpu_kernel.py | Update the CPU reference implementation and test data construction to match the new ngram slice and step_idx semantics, and adapt to the new op signature. |
| tests/spec_decode/test_benchmark_ngram_kernel.py | Adapt benchmark data construction and invocation to the new op signature (based on token_ids_all/prompt_lens). |
| tests/operators/test_ngram_match.py | Adapt operator-level unit tests to the new signature and the token_ids_all layout. |
| fastdeploy/worker/gpu_model_runner.py | Include NGRAM in the CUDA Graph capture warmup branch (same shape coverage as MTP/SUFFIX). |
| fastdeploy/spec_decode/ngram.py | NgramProposer now calls ngram_match with the new signature directly (no longer depends on input_ids_cpu/input_ids_len). |
| fastdeploy/config.py | Include SpecMethod.NGRAM in the speculative shape derivation of the cudagraph size initialization logic. |
| custom_ops/gpu_ops/speculate_decoding/ngram_match.cu | Update the ngram_match kernel/CPU path and static op registration signature: remove input_ids/input_ids_len, use token_ids_all + prompt_lens instead. |
| custom_ops/gpu_ops/cpp_extensions.cc | Update the NgramMatch extension declaration to match the new interface. |

PaddlePaddle-bot left a comment

🤖 Paddle-CI-Agent | pr_review | 2026-05-14 14:24:03

📋 Review Summary

PR overview: simplifies the ngram match kernel interface (removes input_ids/input_ids_len; everything is carried by token_ids_all), fixes the ngram pointer-offset bug (step_idx + 1 - ngram_size → step_idx - ngram_size), and adds end-to-end and unit tests.
Scope of changes: custom_ops/gpu_ops/speculate_decoding/, fastdeploy/spec_decode/ngram.py, fastdeploy/config.py, fastdeploy/worker/gpu_model_runner.py, test files
Impact tags: [Speculative Decoding] [OP]

📝 PR Convention Check

The PR title carries an official tag, [Speculative Decoding], and all five required sections of the description are complete, so the PR is compliant. Since this PR also contains a clear bug fix (the pointer offset), a [BugFix] tag could be added to aid later retrieval, but the current single-tag format meets the convention and no change is required.

Issues

| Level | File | Summary |
| --- | --- | --- |
| 🟡 Suggestion | tests/spec_decode/test_ngram_gpu_kernel.py:283 | step_idx = gen_len - 1 in _make_mixed_test_data is inconsistent with the new count semantics and may shift what the test covers |
| 🟡 Suggestion | tests/e2e/test_ernie_03b_ngram.py:168 | Hard-coded exact floating-point baselines (e.g. 0.31000000000000005) are strongly hardware-bound and will easily destabilize cross-machine CI |

Overall Assessment

The kernel-signature simplification and the bug-fix logic are clear, the E2E verification is solid, and overall PR quality is high. Two test-robustness issues are suggested for follow-up; neither blocks merging.

@@ -281,7 +281,7 @@ def _make_mixed_test_data(batch_size=4, input_len=64, pre_ids_len=256, max_draft
gen_len = 20
src = rng.randint(0, max(1, input_len - gen_len))
pre_ids[b, :gen_len] = input_ids[b, src : src + gen_len]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Suggestion: the step_idx semantics in _make_mixed_test_data are inconsistent with the new kernel.

_make_ngram_test_data has been updated to step_idx = gen_len (count semantics), but _make_mixed_test_data still uses step_idx = gen_len - 1, with its comment changed to "matches hybrid kernel semantics".

However, both the GPU kernel and the CPU fallback have been updated to the cur_step_idx - ngram_size semantics, i.e. both interpret step_idx as a count. If this data is used to call the new kernel:

  • the ngram actually read starts at pre_ids[gen_len - 1 - ngram_size]
  • while the planted match pattern is based on pre_ids[:gen_len], whose start should be gen_len - ngram_size
  • producing an off-by-one: the test exercises an ngram at the wrong position

Suggestion: change step_idx[b] = gen_len - 1 to step_idx[b] = gen_len and update the data-planting logic to match the count semantics; or explicitly document that this function is only for a specific CPU legacy-path test (though that path also appears to have migrated to the new semantics).
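To make the off-by-one concrete, here is a small illustrative snippet (values and setup are hypothetical, not taken from the test file): under count semantics the kernel reads pre_ids[step_idx - ngram_size : step_idx], so passing gen_len - 1 shifts the window one position left of the last planted ngram.

```python
import numpy as np

gen_len, ngram_size = 20, 3
pre_ids = np.arange(100, 100 + gen_len)  # stand-in generated tokens

# Intended window under count semantics: the last ngram_size tokens.
intended = pre_ids[gen_len - ngram_size:gen_len]   # [117 118 119]

# What the kernel actually reads if step_idx is passed as gen_len - 1.
step_idx = gen_len - 1
actual = pre_ids[step_idx - ngram_size:step_idx]   # [116 117 118]

assert not np.array_equal(intended, actual)  # off-by-one window
```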

NKNaN (Contributor, Author) replied:

hybrid_mtp_ngram has not yet been verified end-to-end; leaving this unchanged for now.

result = response["choices"][0]["message"]["content"]
assert result != "", "Generation result is empty"
usage = response["usage"]
assert usage["completion_tokens"] <= payload["max_tokens"]

🟡 Suggestion: hard-coded exact floating-point baselines will destabilize cross-machine CI.

The current assertion uses exact equality:

assert speculate_metrics == baseline

where baseline contains exact values such as 0.31000000000000005 and 0.18840579710144928. These values depend on:

  • the arithmetic behavior of a specific GPU model (A800)
  • a specific model version / weights
  • runtime scheduling order

It is suggested to switch to range assertions that verify functional correctness rather than exact values:

# Verify accept_ratio is within a reasonable range rather than an exact value
assert 0.1 <= speculate_metrics["accept_ratio"] <= 0.9
assert speculate_metrics["accepted_tokens"] > 0
assert speculate_metrics["average_accept_length"] >= 1.0
# Verify the per_head structures have the correct lengths
assert len(speculate_metrics["accepted_tokens_per_head"]) == 6
assert len(speculate_metrics["accept_ratio_per_head"]) == 5
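If exact baselines are to be kept for regression tracking, a softer variant (assuming the suite runs under pytest; the tolerance value here is an assumption, not from the PR) is a tolerant comparison:

```python
import pytest

# Compare against the recorded baseline with a 5% relative tolerance
# instead of exact float equality.
assert speculate_metrics["accept_ratio"] == pytest.approx(baseline["accept_ratio"], rel=0.05)
```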


Labels

contributor (External developers)


4 participants