refactor: inline InfiniCore into InfiniLM, switch xmake -> CMake#324

Draft
zhangyue207 wants to merge 6 commits into main from refactor/inline-infinicore

Conversation

@zhangyue207
Collaborator

Summary

Absorb the three pieces of InfiniCore that InfiniLM actually depends on into this repository as subprojects, drop the external pre-installed INFINI_ROOT dependency, and switch the build system from xmake to CMake:

  • runtime/ → libinfinirt.so (multi-platform runtime)
  • ccl/ → libinfiniccl.so (NCCL allreduce, depends on runtime)
  • ops/ → libinfinicore.so + _infinicore.*.so (infiniop operators + infinicore C++ API + Python binding)

The top-level CMakeLists.txt aggregates the three subprojects and builds the existing csrc/_infinilm.*.so. All five .so files are co-located under python/{infinilm,infinicore}/lib/ with RPATH=$ORIGIN, so neither LD_LIBRARY_PATH nor INFINI_ROOT is needed.

Installation

# default (no flash-attn)
pip install -e .

# enable flash-attn (automatically clones vllm-project/flash-attention)
INFINILM_BUILD_FLASH_ATTN=1 \
    pip install -e . --no-build-isolation

Environment variables: INFINILM_BUILD_FLASH_ATTN, INFINILM_FLASH_ATTN_REPO/REF/DIR, INFINILM_FLASH_ATTN_ARCHS, INFINILM_BUILD_TYPE, INFINILM_BUILD_JOBS

Main changes

  • New: runtime/, ccl/, ops/, top-level CMakeLists.txt, python/infinicore/
  • Modified: setup.py (CMake driver), pyproject.toml (cmake/pybind11 added to build deps), .gitignore (so the *.txt wildcard no longer hides CMakeLists.txt), examples/jiuge.py (compatibility with transformers ≥ 5.0, which removed batch_encode_plus; see the sketch after this list)
  • Minor patches: guard the metax include behind ENABLE_METAX_API; add the vllm fork's num_splits parameter to the mha_varlen_fwd call site and declaration
  • xmake.lua is kept but no longer drives the build; the legacy ctypes models under src/ are kept but no longer compiled
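
A hedged sketch of the batch_encode_plus compatibility shim referenced in the list above; the actual guard in examples/jiuge.py may be written differently.

```python
# Hedged sketch of the tokenizer-compat shim; examples/jiuge.py may differ.
def encode_prompts(tokenizer, prompts):
    if hasattr(tokenizer, "batch_encode_plus"):
        # transformers < 5.0: the legacy batched API still exists.
        enc = tokenizer.batch_encode_plus(prompts)
    else:
        # transformers >= 5.0 removed batch_encode_plus; calling the
        # tokenizer directly is the supported equivalent.
        enc = tokenizer(prompts)
    return enc["input_ids"]
```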

Acceptance (8× A100-80GB)

| Configuration | TTFT | Decode | Output |
|---|---|---|---|
| Single-GPU llama 8B, default | 841 ms | 62 tok/s | <think>Okay, the user is asking... |
| Single-GPU llama 8B + graph + paged + flash-attn | 20 ms | 81 tok/s | |
| tp=4 qwen2 70B, default | 2100 ms | 20 tok/s | 您好, 我是由启元实验室主导研发的九格通用基础大模型。 |
| tp=4 qwen2 70B + full optimization stack | 47 ms | 29 tok/s | |

Test plan

  • pip install -e . (default path)
  • INFINILM_BUILD_FLASH_ATTN=1 pip install -e . --no-build-isolation (flash-attn path)
  • python examples/jiuge.py --model <llama_path> --device nvidia --tp 1 --backend cpp
  • python examples/jiuge.py --model <qwen2_path> --device nvidia --tp 4 --backend cpp
  • CUDA Graph: --enable-graph --enable-paged-attn
  • flash-attn: --attn flash-attn --enable-paged-attn

Migrate the three InfiniLM-facing components of InfiniCore (multi-platform
runtime, distributed CCL, and operators + C++ API) into this repository as
self-contained CMake subprojects:

  runtime/  -> libinfinirt.so   (multi-platform device/stream/event)
  ccl/      -> libinfiniccl.so  (NCCL-backed allreduce, depends on runtime)
  ops/      -> libinfinicore.so (infiniop kernels + infinicore C++ API)
              + _infinicore.*.so (Python binding, infinicore package)

Top-level CMakeLists.txt aggregates the three subprojects and builds the
existing csrc/ pybind module as _infinilm.*.so. All five shared objects are
co-located under python/infinilm/lib/ (and python/infinicore/lib/) with
RPATH=$ORIGIN, so neither LD_LIBRARY_PATH nor INFINI_ROOT is needed.

Build entry point:

  pip install -e .                                 # default
  INFINILM_BUILD_FLASH_ATTN=1 \
      pip install -e . --no-build-isolation        # with vllm-fa (optional)

setup.py replaces the xmake driver with a CMake driver and supports knobs
via env vars: INFINILM_BUILD_FLASH_ATTN, INFINILM_FLASH_ATTN_DIR/REPO/REF,
INFINILM_FLASH_ATTN_ARCHS, INFINILM_BUILD_TYPE, INFINILM_BUILD_JOBS.
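
A minimal sketch of what such an env-var-driven CMake invocation could look like; the -D option names forwarded to CMake are assumptions, and the real setup.py build hook may differ.

```python
# Illustrative sketch only; the real setup.py build hook may differ, and the
# -D option names forwarded to CMake are assumptions.
import multiprocessing
import os
import subprocess

def cmake_configure_and_build(source_dir: str, build_dir: str) -> None:
    build_type = os.environ.get("INFINILM_BUILD_TYPE", "Release")
    jobs = os.environ.get("INFINILM_BUILD_JOBS", str(multiprocessing.cpu_count()))
    cfg = ["cmake", "-S", source_dir, "-B", build_dir,
           f"-DCMAKE_BUILD_TYPE={build_type}"]
    if os.environ.get("INFINILM_BUILD_FLASH_ATTN") == "1":
        cfg.append("-DINFINILM_BUILD_FLASH_ATTN=ON")  # hypothetical cache variable
        for knob in ("INFINILM_FLASH_ATTN_REPO", "INFINILM_FLASH_ATTN_REF",
                     "INFINILM_FLASH_ATTN_DIR", "INFINILM_FLASH_ATTN_ARCHS"):
            if knob in os.environ:
                cfg.append(f"-D{knob}={os.environ[knob]}")  # hypothetical pass-through
    subprocess.check_call(cfg)
    subprocess.check_call(["cmake", "--build", build_dir, "-j", jobs])
```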

Changes summary:
- New: runtime/, ccl/, ops/, CMakeLists.txt, python/infinicore/
- Modified: setup.py (CMake driver), pyproject.toml (cmake/pybind11 build deps),
  .gitignore (allow CMakeLists.txt past *.txt rule), examples/jiuge.py
  (tokenizer compat for transformers >= 5.0 dropping batch_encode_plus)
- Tiny patches: metax_ht2mc.h include guarded behind ENABLE_METAX_API in
  runtime/ and ccl/; mha_varlen_fwd call site + decl take vllm fork's
  num_splits parameter

xmake.lua left in place but no longer drives the build.
src/ legacy ctypes models retained but not built (no longer needed by the
infinilm/infer_engine path used in examples/jiuge.py).

Validation on 8x A100-80GB:
- Single-GPU 9g_8b_thinking_llama: decode 62 tok/s default,
  81 tok/s with --enable-graph --enable-paged-attn --attn flash-attn
- tp=4 FM9G_70B_SFT_MHA_qwen2: decode 20 tok/s default,
  29 tok/s with full graph+paged+flash-attn stack
@zhangyue207 zhangyue207 requested a review from a team April 25, 2026 16:15
@zhangyue207 zhangyue207 marked this pull request as draft April 25, 2026 16:16
@zhangyue207
Collaborator Author

zhangyue207 commented Apr 25, 2026

InfiniLM Post-Refactor Performance Benchmark Report

1. Test Objective

Verify the functional correctness and throughput of the refactored InfiniLM on the refactor/inline-infinicore branch in a production-style serving scenario, as the final acceptance evaluation of the refactor.

2. Test Environment

| Item | Value |
|---|---|
| Hardware | 1× NVIDIA A100-SXM4-80GB (sm_80) |
| CUDA | 13.0 (driver, nvcc) |
| cuDNN | 9.16.0-cuda-13 |
| NCCL | 2.28.9 |
| OS | Linux 5.15 (Ubuntu 22.04) |
| Python | 3.10.19 |
| PyTorch | 2.11.0+cu130 |
| InfiniLM | refactor/inline-infinicore @ 8a2e268 |
| flash-attention | vllm-project/flash-attention (FA2 path, sm_80) |

3. Test Method

3.1 Server

benchmark/server.sh launches the InfiniLM HTTP server (OpenAI-compatible) with the key configuration:

--device nvidia
--model /data-aisoft/mechdancer/models/9g_8b_thinking_llama
--tp 1
--cache-type paged --num-blocks 1024 --block-size 256
--max-new-tokens 4096 --max-batch-size 64
--enable-graph
--attn flash-attn

Maximum concurrency 64; KV cache capacity = 1024 blocks × 256 tok = 262,144 tokens.

3.2 Client

benchmark/bench_client.py: an SSE benchmark client built on httpx.AsyncClient, reporting metrics equivalent to vllm bench serve:

  • Protocol: POST /v1/chat/completions with stream=true, timed per SSE chunk
  • Input: fixed seed=42; prompts are generated by decoding random token ids, so prompts stay identical and comparable across batches
  • Termination: ignore_eos=true forces the model to generate the full max_tokens, so an early EOS cannot skew throughput
  • Metrics: TTFT, ITL, TPOT, E2EL (mean / p50 / p99 / min / max) plus request/output token throughput; a sketch of how these are derived follows below
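
The definitions below are a hedged sketch of how the streaming metrics can be derived from per-chunk timestamps; bench_client.py may compute them differently in detail.

```python
# Hedged sketch: deriving TTFT / ITL / TPOT / E2EL from SSE chunk timestamps.
from dataclasses import dataclass

@dataclass
class RequestTiming:
    start: float               # time the request was sent
    chunk_times: list[float]   # arrival time of each SSE content chunk
    output_tokens: int         # total generated tokens

def metrics(t: RequestTiming) -> dict[str, float]:
    ttft = t.chunk_times[0] - t.start                   # time to first token
    e2el = t.chunk_times[-1] - t.start                  # end-to-end latency
    itl = [b - a for a, b in zip(t.chunk_times, t.chunk_times[1:])]
    tpot = (e2el - ttft) / max(t.output_tokens - 1, 1)  # time per output token
    return {"ttft": ttft, "e2el": e2el, "tpot": tpot,
            "itl_mean": sum(itl) / len(itl) if itl else 0.0}
```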

3.3 Sweep Parameters

| Dimension | Value |
|---|---|
| Concurrency (max_concurrency) | 1, 4, 16, 64, 128 |
| input_len | 256 tokens |
| output_len | 256 tokens |
| Prompts | 20 × concurrency, with a floor of 200 |

4. Results

4.1 Single-GPU 8B Llama: Main Table

| bs | Requests | Success | TTFT mean (ms) | TTFT p99 (ms) | TPOT mean (ms) | E2EL mean (s) | req/s | Decode tok/s |
|---|---|---|---|---|---|---|---|---|
| 1 | 200 | 200/200 | 94 | 159 | 13.85 | 3.6 | 0.28 | 70.6 |
| 4 | 200 | 200/200 | 137 | 331 | 14.31 | 3.8 | 1.06 | 270.2 |
| 16 | 320 | 320/320 | 798 | 1503 | 20.05 | 5.9 | 2.71 | 691.5 |
| 64 | 1280 | 1280/1280 | 1317 | 3687 | 47.90 | 13.5 | 4.68 | 1192.9 |
| 128 | 2560 | 2560/2560 | 7819 | 14011 | 59.44 | 22.9 | 5.56 | 1416.9 |

4.2 Scalability Analysis

Decode throughput vs. concurrency:

| Concurrency | 1 | 4 | 16 | 64 | 128 |
|---|---|---|---|---|---|
| tok/s | 71 | 270 | 692 | 1193 | 1417 |
| Step ratio | – | 3.83× | 2.56× | 1.72× | 1.19× |

| Concurrency step | Theoretical speedup | Measured speedup | Efficiency |
|---|---|---|---|
| 1 → 4 | 4× | 3.83× | 96% |
| 4 → 16 | 4× | 2.56× | 64% |
| 16 → 64 | 4× | 1.72× | 43% |
| 64 → 128 | 2× | 1.19× | 60% (throttled) |

4.3 Key Observations

(1) Functional correctness ✓
4560 requests with 0 failures, coherent output (no end-to-end regression).

(2) The decode-bound turning point is bs ≈ 16

  • bs ≤ 4: TPOT ≈ 14 ms (essentially the same as bs=1) → this region is latency-bound; adding batch is almost free
  • bs = 16: TPOT rises to 20 ms; adding batch starts to have a cost
  • bs ≥ 64: TPOT doubles to 48–60 ms, but throughput keeps growing → now throughput-bound

(3) bs=128 is throttled by the server's max-batch-size=64

  • TTFT p99 = 14 s (vs. 3.7 s at bs=64); the extra 64 requests sit in the queue
  • Decode throughput gains only 19%, confirming that max-batch-size is the real bottleneck
  • Recommended production configuration: align max-batch-size with --max-concurrency

(4) The flash-attn + paged + graph combination is stable
No graph capture errors, KV cache OOMs, or NCCL errors throughout.

5. Comparison with the Pre-Refactor Code

Before the refactor, InfiniLM depended on an external InfiniCore ($INFINI_ROOT/lib/), so results are not directly comparable even on the same hardware/model/parameters (different build system, different SO loading paths). Within the same source tree:

  • One-step build: xmake build && xmake install is replaced by pip install -e ., with no pre-installed InfiniCore required
  • Deployed SO footprint: previously libinfiniop.so + libinfinirt.so + libinfiniccl.so + libinfinicore_cpp_api.so + libinfinicore_infer.so + 2× python modules = 7 SOs; now merged into libinfinirt + libinfiniccl + libinfinicore + libflash-attn-nvidia + 2× python modules = 6 (flash-attn is optional, 180 MB for the sm_80 single arch)
  • The validation cases (single-GPU + tp=4) produce output tokens identical to the pre-refactor build (semantically equivalent)

6. Conclusions

  1. With CMake + flash-attn + paged + graph all enabled, the refactored InfiniLM reaches 1417 tok/s decode for 8B Llama on a single A100-80GB (bs=128; the practical saturation point is bs=64 at ~1193 tok/s)
  2. Stability passes: 4560 requests, zero failures
  3. bs=64 is the best throughput-latency balance under the current server configuration: 1193 tok/s decode, TTFT p99 < 4 s
  4. Recommended production configuration: align --max-batch-size with the workload's concurrency; --num-blocks 1024 leaves ample headroom for the in=out=256 scenario

7. TP=4 70B Distributed

7.1 Configuration

| Item | Value |
|---|---|
| Model | FM9G_70B_SFT_MHA_qwen2 (Qwen2ForCausalLM, 80 layers, hidden=8192, attention heads=64, kv heads=64) |
| Hardware | 4× A100-SXM4-80GB |
| TP degree | 4 |
| dtype | bf16 |
| --max-batch-size | 16 |
| --num-blocks | 128 (block_size=256, KV capacity 32,768 tokens) |
| --max-new-tokens | 1024 |
| Other | same as above: --enable-graph --cache-type paged --attn flash-attn |
| Output len | 128 tokens (70B decode is roughly 8× the per-token cost of 8B; shortened to bound total test time) |

Per-GPU memory usage is ~70 GB of 80 GB (35 GB model weights + KV cache + activations + flash-attn workspace).

7.2 Main Table

| bs | Requests | Success | TTFT mean (ms) | TTFT p99 (ms) | TPOT mean (ms) | E2EL mean (s) | req/s | Decode tok/s |
|---|---|---|---|---|---|---|---|---|
| 1 | 100 | 100/100 | 195 | 591 | 36.28 | 2.2 | 0.45 | 24.7 |
| 4 | 100 | 100/100 | 621 | 1271 | 114.28 | 5.5 | 0.72 | 38.8 |
| 16 | 160 | 160/160 | 1250 | 2578 | 228.57 | 12.9 | 1.22 | 66.2 |

7.3 Observations

  • bs=1 decode 25 tok/s: compared with 8B single-GPU bs=1 (70.6 tok/s), the 70B model is 8.75× larger but only 2.86× slower, so TP=4 speeds up per-token compute by roughly 3× (theoretical ceiling 4×)
  • TPOT rises to 229 ms at bs=16: with 4 GPUs and a 70B model, decode is heavily memory-bound, and batching amortizes KV cache reads with only limited benefit
  • TTFT grows far less than the compute ratio (bs=1: 195 ms vs. 94 ms on 8B single-GPU): tp=4 splits the prefill GEMMs into 4 parallel shards, and both attention and QKV benefit
  • Stability: 360 requests, 0 failures; AllReduce communication over NCCL was error-free throughout

7.4 8B Single-GPU vs. 70B tp=4 (bs=1, per-token compute reference)

| Model | Params | Hardware | TPOT | Decode tok/s | Compute ratio | Measured latency ratio |
|---|---|---|---|---|---|---|
| 8B Llama | 8 B | 1× A100 | 13.85 ms | 70.6 | – | – |
| 70B Qwen2 | 70 B | 4× A100 | 36.28 ms | 24.7 | 8.75× | 2.62× |

At tp=4, each 70B token is only ~2.6× slower than 8B on a single GPU, a TP speedup of about 3.34× (theoretical 4×, 84% efficiency); allreduce communication plus TP scheduling overhead eats the remaining 16%. The arithmetic is spelled out below.
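
Spelled out, treating per-token compute as proportional to parameter count:

$$
\text{TP-4 speedup} \approx \frac{70/8}{36.28\ \text{ms} / 13.85\ \text{ms}} = \frac{8.75}{2.62} \approx 3.34,
\qquad
\eta \approx \frac{3.34}{4} \approx 84\%.
$$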

8. Summary

| Item | Result |
|---|---|
| Single-GPU function (8B llama) | ✓ 0 failures / 4560 requests |
| Distributed function (tp=4 70B) | ✓ 0 failures / 360 requests |
| Single-GPU saturated throughput | 1193 tok/s @ bs=64 (close to the server max-batch-size limit) |
| tp=4 70B throughput | 66 tok/s @ bs=16 (KV cache size prevents larger batches) |
| TP=4 speedup | 3.34× / 4× = 84% efficiency |
| Path stability | flash-attn + graph + paged stack stable throughout |

zhangyue207 and others added 3 commits April 26, 2026 20:44
ixblas selects a capture-unfriendly cublas algorithm when a GEMM's
output dim exceeds ~65k (e.g. lm_head with vocab_size ~152k). The
broken algo replays with corrupt output. Splitting along output dim
into chunks <= 65k routes each sub-GEMM through the standard
capture-friendly path. Each chunk must use an independent contiguous
buffer because cublas selects algo by leading dim, not by shape;
writing into a narrow view of the big output keeps ld large.

Supporting changes for the capture window:
- Dedicated cublas/cudnn capture handle, lazily created on the last
  warmup before BeginCapture so its stream binding is established
  before capture starts; pool-pop/push and per-call setStream are
  skipped during capture (they break ixblas semantics).
- PinnableBlockAllocator skips frozen blocks during eager allocs so
  buffers handed to a captured graph aren't reused.
- syncDevice() (not just syncStream) before/after the capture-mode
  warmup, to drain cublas/cudnn helper streams.

Enable with INFINILM_LINEAR_CHUNK_OUT=32768. Default behavior
unchanged on all platforms; eager forward path is byte-equivalent
to the previous code on NVIDIA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Iluvatar shares the nvidia kernel implementation for per-tensor int8
quant/dequant ops. Add the ENABLE_ILUVATAR_API dispatch entries so
the ops are callable on iluvatar (was a missing-dispatch error).

Also: tolerate missing kv_cache_*_scale weights in
modeling_utils.check_parameters so models without int8 KV scales
don't fail the parameter check.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@zhangyue207
Collaborator Author

Iluvatar MR-V100 Adaptation Acceptance Report

This report adds end-to-end validation of this branch on Iluvatar GPUs. New commits:

  • 32599ec feat(iluvatar): enable CUDA Graph via lm_head GEMM chunking
  • 338df19 feat(iluvatar): register per_tensor_(de)quant_int8 dispatch

1. Test Environment

| Item | Value |
|---|---|
| Hardware | 1× Iluvatar MR-V100, 32 GiB, CC 7.1 (ivcore11) |
| Toolchain | Iluvatar Corex 4.1.3 (CUDA 10.2 compatible, clang 16 + cuDNN 7 + NCCL 2.14.3) |
| Docker image | corex:4.1.3 (PyTorch 2.1.1 / Python 3.10.12) |
| Model | 9g_8b_thinking_llama (LlamaForCausalLM, hidden 4096, 32 layers, 32 q-heads / 2 KV-head GQA, BF16, max_pos=65536) |
| Build | pip install -e ., CMake path |

2. Build Adaptation (merged as a5f9634)

CMake gains an INFINI_NV_BACKEND cache variable in three places; ENABLE_NVIDIA_API becomes ENABLE_<X>_API, so X=ILUVATAR defines ENABLE_ILUVATAR_API. When the CUDA compile path detects Clang (Iluvatar's "nvcc" is a clang-CUDA wrapper), nvcc-only flags are not passed. The ops/CMakeLists.txt GLOB additionally collects ops/<op>/<backend>/*.cu when INFINI_NV_BACKEND is not NVIDIA.

The cuda_fp8.h include in nvidia_kernel_common.cuh is guarded behind #ifndef ENABLE_ILUVATAR_API, and the cuda_fp8_e4m3 alias falls back to an unsigned char stub (Corex 4.1.3 ships no FP8 header). Three training-only kernels get fallbacks (fmin/fmax for relu6; gaussian_nll_loss / hinge_embedding_loss skip atomicAdd(double*, double)); the inference path never reaches them.

3. End-to-End Inference Performance (256 tok forced generation, prompt 9–17 tok, BF16, batch=1)

python3 examples/jiuge.py \
  --device=iluvatar --backend=cpp --tp=1 \
  --model=/workspace/model/9g_8b_thinking_llama \
  --max-new-tokens=256 --ignore-eos \
  --prompt="写一篇关于人工智能的长文章。"

| Configuration | Total | Prefill TTFT | Decode ITL | Decode tok/s | Output |
|---|---|---|---|---|---|
| baseline (default attention) | 12.95 s | 187.68 ms | 50.05 ms | 19.98 | |
| --enable-paged-attn | 9.63 s | 128.26 ms | 37.24 ms | 26.85 | |
| --enable-paged-attn --enable-graph | – | – | – | – | garbled (see below) |
| --enable-paged-attn --enable-graph + INFINILM_LINEAR_CHUNK_OUT=32768 | 9.56 s | 64.96 ms | 37.23 ms | 26.86 | |

paged-attn improves decode by +34% over baseline.

4. Root Cause and Fix for CUDA Graph on Iluvatar

Enabling --enable-graph --enable-paged-attn directly produces garbled output. A bisect (binary search over op_list, capturing only the first N ops) pinpointed op_list_[354] in 9 steps: the last op of forward, the lm_head GEMM. Its shape [1, 152064] = [1, 4096] × [4096, 152064] makes it the only large GEMM in the entire forward with output_dim = vocab_size. (A sketch of the bisection idea follows below.)

Capturing only op_list_[354] while running the first 354 ops eagerly still produces garbage, so the lm_head capture is broken by itself; this is not a contamination problem.
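
A hedged sketch of the prefix-capture bisection described above; capture_prefix() and output_is_garbled() are hypothetical stand-ins for whatever instrumentation was actually used.

```python
# Hypothetical illustration of the bisection; both hooks are stand-ins.
from typing import Callable

def find_first_bad_op(num_ops: int,
                      capture_prefix: Callable[[int], None],
                      output_is_garbled: Callable[[], bool]) -> int:
    # Invariant: a prefix of length lo replays correctly, a prefix of
    # length hi replays garbled (all ops captured = garbled, none = correct).
    lo, hi = 0, num_ops
    while hi - lo > 1:
        mid = (lo + hi) // 2
        capture_prefix(mid)      # capture ops [0, mid), run the rest eagerly
        if output_is_garbled():
            hi = mid
        else:
            lo = mid
    return hi - 1                # index of the op whose capture corrupts output

# With 355 ops, ceil(log2(355)) = 9 runs are enough to land on op_list_[354].
```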

Continuing to bisect the chunk size:

| chunk_size | Chunks | Per-chunk output_dim | Output |
|---|---|---|---|
| 131072 / 98304 / 76032 | 2 | > 65k | ❌ garbled |
| 65536 | 3 | 65536 | ✓ |
| 32768 | 5 | 32768 | ✓ |
| 16384 | 10 | 16384 | ✓ |

ixblas switches to a capture-unfriendly algorithm when a single GEMM's output_dim exceeds ~65k (suspected split-K / stream-K, with host-side state inconsistent between capture and replay). Any chunk ≤ 65536 captures cleanly.

Key engineering detail: each chunk must use its own contiguous buffer and then be rearranged/copied back into the big output. cublas picks its algorithm by leading dimension, not by shape; writing into a narrow view of the big buffer keeps ld at 152064 and still hits the broken algorithm.

With the fix, --enable-graph captures the full forward: one cudaGraphLaunch per forward, zero syncs. A torch-level sketch of the chunking idea follows below.
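
A torch-level sketch of the chunking under the assumptions above (split along the output dim, compute each chunk into its own contiguous buffer, copy back); the real implementation lives in the C++ linear path and is gated by INFINILM_LINEAR_CHUNK_OUT, so details may differ.

```python
# Torch-level illustration only; the real code is in the C++ linear path.
import torch

def chunked_lm_head(x: torch.Tensor, weight: torch.Tensor,
                    chunk_out: int = 32768) -> torch.Tensor:
    # x: [tokens, hidden]; weight: [vocab, hidden] (standard PyTorch layout)
    vocab = weight.shape[0]
    out = torch.empty(x.shape[0], vocab, device=x.device, dtype=x.dtype)
    for start in range(0, vocab, chunk_out):
        end = min(start + chunk_out, vocab)
        # Each sub-GEMM writes into its own contiguous buffer so the BLAS call
        # sees a small leading dimension; computing directly into
        # out[:, start:end] would keep ld == vocab and still hit the broken algo.
        chunk = torch.matmul(x, weight[start:end].t()).contiguous()
        out[:, start:end].copy_(chunk)
    return out
```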

5. Accompanying Capture-Window Changes (non-breaking on NV)

| Change | File(s) | Impact on NV |
|---|---|---|
| INFINILM_LINEAR_CHUNK_OUT chunked path | linear.cc | env unset by default → chunk_threshold=0 → original path, byte-identical |
| Dedicated cublas/cudnn capture handle | nvidia_handle.cuh, nvidia_common.cu | with g_capture_mode=false the original pool path is taken, byte-identical |
| PinnableBlockAllocator frozen-block isolation | pinnable_block_allocator.cc | outside capture, frozen is always false and the predicate short-circuits; same semantics as vllm's cudaMemPool_t isolation |
| syncDevice() instead of syncStream() around capture | graph.cc | +~ms at the one-shot instantiate; drains cublas/cudnn helper streams |

Graph::run() has been reverted to the original two-line implementation, so the NV decode hot path is unaffected.

6. Current Status and Decisions

  • On a single Iluvatar GPU at batch=1, graph gives an essentially negligible measured gain (26.86 vs. 26.85 tok/s), which is near the realistic ceiling for launch-overhead savings at an ITL of 37 ms. The default remains --enable-graph=False; enabling it explicitly requires INFINILM_LINEAR_CHUNK_OUT=32768
  • paged-attn is the main win on Iluvatar: +34% decode with no side effects; recommended on by default in production
  • Scenarios that could change the picture (unverified): batch ≥ 8, multi-GPU NCCL, or ixblas fixing the large-output_dim algorithm; graph gains would then be expected to rise to 5–15%.

7. Test Plan (passed)

  • pip install -e . on the Iluvatar path (-DINFINI_NV_BACKEND=ILUVATAR -DCMAKE_CUDA_COMPILER=/usr/local/corex-4.1.3/bin/clang++)
  • BF16 end-to-end: baseline 19.98 / paged-attn 26.85 tok/s decode @ 256 tok
  • --enable-paged-attn --enable-graph with INFINILM_LINEAR_CHUNK_OUT=32768: 26.86 tok/s decode, coherent output
  • OpenAI-compatible inference_server.py: health / models / sync / stream / 4-way concurrent batching at 95 tok/s
  • NV decode path byte-identical after reverting Graph::run() to the original (verified by review)

Detailed documents: iluvatar_perf.md and iluvatar_deploy.md (outside the repo).

zhangyue and others added 2 commits April 27, 2026 03:50
Adds Hygon DCU support to the new CMake-driven `_infinilm` path. Reuses
NVIDIA's `.cu` source set under DTK's CUDA-compat shim (/opt/dtk/cuda/cuda-12)
— only ENABLE_HYGON_API is defined instead of ENABLE_NVIDIA_API, which
switches `infinirt::cuda` → `infinirt::hygon` and gates out FP8 +
cublasComputeType_t paths the DCU can't handle.

Build:
- INFINILM_ENABLE_HYGON=1 umbrella in setup.py / top-level CMakeLists flips
  runtime/ccl/ops to their HYGON-enabled paths.
- runtime/ccl/ops add INFINI{RT,CCL,OPS}_ENABLE_HYGON options that compile
  the same NVIDIA .cu sources with ENABLE_HYGON_API.
- ccl finds RCCL via find_library(rccl) under /opt/dtk/lib (NCCL-API compat).
- CUDA_SEPARABLE_COMPILATION forced OFF under Hygon — DTK's device-link drops
  fatbin from the final shared lib otherwise.
- CMAKE_CUDA_ARCHITECTURES=75 — DTK's nvcc shim translates sm_75 to the full
  DCU set internally.

Source-level dispatches:
- HYGON entries added to operator.cc dispatch tables in paged_attention/,
  paged_attention_prefill/, paged_caching/, embedding/, mha_kvcache/,
  multi_head_attention_varlen/, plus device whitelists in nn/embedding.cc
  and nn/rmsnorm.cc.
- paged_attention_prefill kernel_v2.cuh's nvcuda::wmma block gated with
  !defined(ENABLE_HYGON_API); default_prefill_kernel returns "warp" under
  Hygon (DCU has no Tensor Cores).

Flash-attn integration:
- ops/src/infinicore/adaptor/flash_attn_hygon_dlsym.cc — dlsym wrapper
  resolving mha_fwd_kvcache / vllm_mha_varlen_fwd / paged_attention from the
  system flash_attn_2_cuda*.so (DTK fork wheel). Compiled with plain g++ +
  __HIP_PLATFORM_AMD__ to pick up <c10/hip/...> headers without conflicting
  with libinfinicore's DTK CUDA-compat path.
- python/infinilm/__init__.py RTLD_GLOBAL-preloads flash_attn_2_cuda*.so so
  dlsym(RTLD_DEFAULT, ...) resolves.
- Decode path uses paged_attention (not mha_fwd_kvcache) to keep HIP graph
  capture clean.
- HIPStreamGuard binds PyTorch's current stream to InfiniLM's stream so all
  ATen calls participate in graph capture.

bs=1 GEMV kernel (custom):
- ops/src/infiniop/ops/gemm/nvidia/gemv_hygon.cuh — purpose-built BF16 GEMV
  (64-thread wavefront, bf16x4 vector loads, FP32 accumulator, wave-shuffle
  reduce). DTK hipBLAS picks MT16x16x32 for M=1 GEMM which runs at ~10 GFLOPS
  / 0.02% peak; this kernel is HBM-bandwidth-bound.
- gemm_nvidia.cu fast-path under ENABLE_HYGON_API hits when n=batch=1 +
  op_a=T + standard PyTorch weight layout (M×K row-major, lda=K).
- Verified +44% bs=1 decode tok/s (39.7 → 58.7), reaches 83% A100 baseline.
- Override with INFINIOPS_DISABLE_GEMV_HYGON=1.

Benchmark scripts and report:
- benchmark/server_hygon.sh / run_client_hygon.sh / run_client_tp4_70b_hygon.sh
- benchmark/REPORT.md (NVIDIA + Hygon results, cross-backend comparison)
- benchmark/REPORT_old.md (archeology — TP=4 deadlock investigation)

TP>1 deadlock workaround (in server_hygon.sh):
- HSA_ENABLE_SDMA=0: prevents Segment::Freeze BlitKernel signal-wait deadlock
  at first-iter HIP module load (TP>=2).
- HSA_FORCE_FINE_GRAIN_PCIE=1: RCCL warns this is required for stability.
- NCCL_P2P_DISABLE=1 (TP>=4): RCCL P2P over PCIe deadlocks the ring at
  first allreduce; routes through shared-memory staging instead.

Verified end-to-end on /root/models/9g_8b_thinking_llama (TP=1/2/4) and
/root/models/FM9G_70B_SFT_MHA_qwen2 (TP=4 eager + graph). Outputs
token-for-token consistent across modes; 4560 server-client requests with
0 failures at TP=1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
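
Below is a hedged sketch of the RTLD_GLOBAL preload described in the Hygon commit above, so that a later dlsym(RTLD_DEFAULT, ...) in the C++ adaptor can resolve symbols from the flash_attn wheel; the actual python/infinilm/__init__.py may locate the library differently.

```python
# Hedged sketch of the RTLD_GLOBAL preload; the real __init__.py may differ.
import ctypes
import glob
import os
import sysconfig

def preload_flash_attn() -> None:
    site_packages = sysconfig.get_paths()["purelib"]
    # The DTK flash-attn wheel ships an extension named flash_attn_2_cuda*.so;
    # searching site-packages recursively is an assumption about its location.
    candidates = glob.glob(os.path.join(site_packages, "flash_attn_2_cuda*.so")) or \
                 glob.glob(os.path.join(site_packages, "**", "flash_attn_2_cuda*.so"),
                           recursive=True)
    for so in candidates:
        # RTLD_GLOBAL exports the symbols so dlsym(RTLD_DEFAULT, ...) can later
        # resolve mha_fwd_kvcache / vllm_mha_varlen_fwd / paged_attention.
        ctypes.CDLL(so, mode=ctypes.RTLD_GLOBAL)
```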
- generation/utils.py: dedup the duplicate next_tokens.to_numpy() call per
  decode step (one D2H + numpy alloc instead of two). Token-identical;
  jiuge.py decode +8% on Hygon DCU.
- infer_engine.py: cache the constant input_offsets ([0, 1] for bs=1
  seq_len=1) outside the decode loop. Token-identical; small bench gain.

Verified on hygon (9g_8b_thinking_llama, tp=1, graph + paged-attn +
flash-attn): correctness diff vs baseline = 0; bench decode 42.9 → 44.1
tok/s.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
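
A hedged sketch of the two decode-loop micro-optimizations described in the commit message above; `engine`, `make_offsets`, and `to_numpy()` are stand-ins for the real infer_engine and tensor helpers, so the surrounding loop is illustrative only.

```python
# Illustrative sketch; helper names are hypothetical stand-ins.
def decode_loop(engine, make_offsets, max_new_tokens: int):
    # infer_engine.py change: input_offsets is constant for bs=1 / seq_len=1,
    # so build it once instead of rebuilding it on every decode step.
    input_offsets = make_offsets([0, 1])
    for _ in range(max_new_tokens):
        next_tokens = engine.step(input_offsets)
        # generation/utils.py change: a single to_numpy() per step (one D2H
        # copy + one numpy allocation); the old code called it twice.
        tokens_host = next_tokens.to_numpy()
        yield tokens_host
```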