refactor: inline InfiniCore into InfiniLM, switch xmake -> CMake#324
zhangyue207 wants to merge 6 commits into main
Conversation
Migrate the three InfiniLM-facing components of InfiniCore (multi-platform
runtime, distributed CCL, and operators + C++ API) into this repository as
self-contained CMake subprojects:
runtime/ -> libinfinirt.so (multi-platform device/stream/event)
ccl/ -> libinfiniccl.so (NCCL-backed allreduce, depends on runtime)
ops/ -> libinfinicore.so (infiniop kernels + infinicore C++ API)
+ _infinicore.*.so (Python binding, infinicore package)
Top-level CMakeLists.txt aggregates the three subprojects and builds the
existing csrc/ pybind module as _infinilm.*.so. All five shared objects are
co-located under python/infinilm/lib/ (and python/infinicore/lib/) with
RPATH=$ORIGIN, so neither LD_LIBRARY_PATH nor INFINI_ROOT is needed.
Build entry point:
pip install -e . # default
INFINILM_BUILD_FLASH_ATTN=1 \
pip install -e . --no-build-isolation # with vllm-fa (optional)
setup.py replaces the xmake driver with a CMake driver and supports knobs
via env vars: INFINILM_BUILD_FLASH_ATTN, INFINILM_FLASH_ATTN_DIR/REPO/REF,
INFINILM_FLASH_ATTN_ARCHS, INFINILM_BUILD_TYPE, INFINILM_BUILD_JOBS.
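For reference, a minimal sketch of how such env knobs typically feed the CMake invocation inside setup.py — illustrative only; the environment variable names are the ones listed above, but the CMake cache option names and the helper below are assumptions, not the actual driver code:

```python
import multiprocessing
import os

def cmake_args_from_env():
    """Hypothetical helper: map INFINILM_* env knobs to CMake arguments."""
    args = [f"-DCMAKE_BUILD_TYPE={os.environ.get('INFINILM_BUILD_TYPE', 'Release')}"]
    if os.environ.get("INFINILM_BUILD_FLASH_ATTN") == "1":
        args.append("-DINFINILM_BUILD_FLASH_ATTN=ON")  # assumed cache option name
        archs = os.environ.get("INFINILM_FLASH_ATTN_ARCHS")
        if archs:
            args.append(f"-DINFINILM_FLASH_ATTN_ARCHS={archs}")  # assumed cache option name
    return args

build_jobs = os.environ.get("INFINILM_BUILD_JOBS", str(multiprocessing.cpu_count()))
# e.g. pass build_jobs to `cmake --build build -j <jobs>` in the build step
```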
Changes summary:
- New: runtime/, ccl/, ops/, CMakeLists.txt, python/infinicore/
- Modified: setup.py (CMake driver), pyproject.toml (cmake/pybind11 build deps),
.gitignore (allow CMakeLists.txt past the *.txt rule), examples/jiuge.py
(tokenizer compat for transformers >= 5.0, which drops batch_encode_plus; see the sketch after this summary)
- Tiny patches: metax_ht2mc.h include guarded behind ENABLE_METAX_API in
runtime/ and ccl/; mha_varlen_fwd call site + decl take vllm fork's
num_splits parameter
xmake.lua left in place but no longer drives the build.
src/ legacy ctypes models retained but not built (no longer needed by the
infinilm/infer_engine path used in examples/jiuge.py).
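The jiuge.py tokenizer-compat change could take roughly this shape — a hedged sketch, not the actual diff; encode_batch is a made-up name for illustration:

```python
def encode_batch(tokenizer, texts, **kwargs):
    # batch_encode_plus was removed in transformers >= 5.0; calling the
    # tokenizer object directly is the supported replacement and also works
    # on older versions. (Sketch only; examples/jiuge.py may differ.)
    if hasattr(tokenizer, "batch_encode_plus"):
        return tokenizer.batch_encode_plus(texts, **kwargs)
    return tokenizer(texts, **kwargs)
```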
Validation on 8x A100-80GB:
- Single-GPU 9g_8b_thinking_llama: decode 62 tok/s default,
81 tok/s with --enable-graph --enable-paged-attn --attn flash-attn
- tp=4 FM9G_70B_SFT_MHA_qwen2: decode 20 tok/s default,
29 tok/s with full graph+paged+flash-attn stack
InfiniLM post-refactor performance benchmark report
1. Test objective
2. Test environment
3. Test method
3.1 Server: max concurrency 64, KV cache capacity = 1024 blocks × 256 tok = 262,144 tokens.
3.2 Client
3.3 Swept parameters
4. Results
4.1 Single-GPU 8B Llama main table
4.2 Scalability analysis: decode throughput vs. concurrency
4.3 Key observations: (1) functional correctness ✓; (2) decode-bound inflection point at bs ≈ 16; (3) bs=128 limited by the server; (4) the flash-attn + paged-attn + graph combination is stable
5. Comparison with the pre-refactor setup (before the refactor, InfiniLM depended on the external InfiniCore)
6. Conclusion
7. TP=4 70B distributed
7.1 Configuration: per-GPU memory usage ~70 GB of 80 GB (35 GB model weights + KV cache + activations + flash-attn workspace)
7.2 Main table
7.3 Observations
7.4 8B single-GPU vs. 70B tp=4 (bs=1, per-token compute as reference): at tp=4 the 70B model is only 2.6× slower per token than the 8B single-GPU run, a TP speedup of about 3.34× (theoretical 4×, 84% efficiency); allreduce communication and TP scheduling overhead eat the remaining 16%.
8. Summary
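A back-of-the-envelope check of the 7.4 figure, assuming per-token decode work scales roughly with the nominal parameter counts (8B vs. 70B); this is not taken from the report itself:

```python
# Rough consistency check for the TP=4 scaling numbers quoted above.
work_ratio = 70 / 8        # ~8.75x more per-token compute for the 70B model
slowdown = 2.6             # measured: 70B tp=4 is 2.6x slower per token than 8B tp=1
speedup = work_ratio / slowdown
print(f"effective TP speedup ~= {speedup:.2f}x, efficiency ~= {speedup / 4:.0%}")
# -> ~3.37x and ~84%, consistent with the reported 3.34x / 84%
```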
ixblas selects a capture-unfriendly cublas algorithm when a GEMM's output dim exceeds ~65k (e.g. lm_head with vocab_size ~152k). The broken algo replays with corrupt output. Splitting along the output dim into chunks <= 65k routes each sub-GEMM through the standard capture-friendly path. Each chunk must use an independent contiguous buffer because cublas selects the algo by leading dim, not by shape; writing into a narrow view of the big output keeps ld large.
Supporting changes for the capture window:
- Dedicated cublas/cudnn capture handle, lazily created on the last warmup before BeginCapture so its stream binding is established before capture starts; pool-pop/push and per-call setStream are skipped during capture (they break ixblas semantics).
- PinnableBlockAllocator skips frozen blocks during eager allocs so buffers handed to a captured graph aren't reused.
- syncDevice() (not just syncStream) before/after the capture-mode warmup, to drain cublas/cudnn helper streams.
Enable with INFINILM_LINEAR_CHUNK_OUT=32768. Default behavior unchanged on all platforms; the eager forward path is byte-equivalent to the previous code on NVIDIA.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
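A minimal PyTorch sketch of the chunking idea, for illustration only — the actual change lives in the C++ linear path, the chunk size is the INFINILM_LINEAR_CHUNK_OUT value from this commit, and linear_chunked_out is a made-up name:

```python
import torch

def linear_chunked_out(x, weight, chunk_out=32768):
    """Split a GEMM along the output dimension so each sub-GEMM stays <= chunk_out wide."""
    out_features = weight.shape[0]
    if out_features <= chunk_out:
        return x @ weight.t()
    parts = []
    for start in range(0, out_features, chunk_out):
        w = weight[start:start + chunk_out]
        # Each chunk gets its own contiguous output buffer. Writing into a
        # narrow view of one big output would keep the leading dim large,
        # which is exactly what steers ixblas back to the broken algorithm.
        parts.append((x @ w.t()).contiguous())
    return torch.cat(parts, dim=-1)
```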
Iluvatar shares the NVIDIA kernel implementation for the per-tensor int8 quant/dequant ops. Add the ENABLE_ILUVATAR_API dispatch entries so the ops are callable on Iluvatar (previously a missing-dispatch error).
Also: tolerate missing kv_cache_*_scale weights in modeling_utils.check_parameters so models without int8 KV scales don't fail the parameter check.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
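A sketch of what that tolerance might look like — check_parameters and the pattern list below are assumptions about the shape of the code in modeling_utils, not the actual implementation:

```python
import fnmatch

# kv_cache int8 scales are optional: models exported without int8 KV quant
# simply do not ship them (assumed pattern, per the commit message).
OPTIONAL_WEIGHT_PATTERNS = ["*kv_cache_*_scale*"]

def _is_optional(name):
    return any(fnmatch.fnmatch(name, p) for p in OPTIONAL_WEIGHT_PATTERNS)

def check_parameters(expected_names, loaded_names):
    """Hypothetical checker: fail only on required weights that are missing."""
    missing = [n for n in expected_names
               if n not in loaded_names and not _is_optional(n)]
    if missing:
        raise ValueError(f"missing parameters: {missing}")
```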
Iluvatar MR-V100 adaptation acceptance report — supplements the end-to-end validation of this branch on Iluvatar GPUs. New commits:
1. Test environment
2. Build adaptation (already merged into a5f9634): three CMake additions
3. End-to-end inference performance (256 forced tokens, prompt 9–17 tok, BF16, batch=1):
python3 examples/jiuge.py \
--device=iluvatar --backend=cpp --tp=1 \
--model=/workspace/model/9g_8b_thinking_llama \
--max-new-tokens=256 --ignore-eos \
--prompt="写一篇关于人工智能的长文章。"
paged-attn gives +34% decode over baseline.
4. Root cause and fix for CUDA Graph on Iluvatar: enabling it directly produces garbled output. Capturing only op_list_[354] and running the first 354 ops eagerly still produces garbage, so the capture of lm_head itself is broken; it is not a pollution problem. Continued by bisecting the chunk size:
Key engineering detail for the ixblas single-chunk GEMM: each chunk must use an independent contiguous buffer and then be rearranged/copied back into the large out. cublas selects the algorithm by leading dim, not by shape — writing into a narrow view of the large buffer keeps ld at 152064 and still hits the broken algorithm. With the fix in place:
5. Accompanying capture-window changes (no breakage on the NVIDIA platform)
6. Current status and decision
7. Test plan (passed)
Detailed documentation:
Adds Hygon DCU support to the new CMake-driven `_infinilm` path. Reuses
NVIDIA's `.cu` source set under DTK's CUDA-compat shim (/opt/dtk/cuda/cuda-12)
— only ENABLE_HYGON_API is defined instead of ENABLE_NVIDIA_API, which
switches `infinirt::cuda` → `infinirt::hygon` and gates out FP8 +
cublasComputeType_t paths the DCU can't handle.
Build:
- INFINILM_ENABLE_HYGON=1 umbrella in setup.py / top-level CMakeLists flips
runtime/ccl/ops to their HYGON-enabled paths.
- runtime/ccl/ops add INFINI{RT,CCL,OPS}_ENABLE_HYGON options that compile
the same NVIDIA .cu sources with ENABLE_HYGON_API.
- ccl finds RCCL via find_library(rccl) under /opt/dtk/lib (NCCL-API compat).
- CUDA_SEPARABLE_COMPILATION forced OFF under Hygon — DTK's device-link drops
fatbin from the final shared lib otherwise.
- CMAKE_CUDA_ARCHITECTURES=75 — DTK's nvcc shim translates sm_75 to the full
DCU set internally.
Source-level dispatches:
- HYGON entries added to operator.cc dispatch tables in paged_attention/,
paged_attention_prefill/, paged_caching/, embedding/, mha_kvcache/,
multi_head_attention_varlen/, plus device whitelists in nn/embedding.cc
and nn/rmsnorm.cc.
- paged_attention_prefill kernel_v2.cuh's nvcuda::wmma block gated with
!defined(ENABLE_HYGON_API); default_prefill_kernel returns "warp" under
Hygon (DCU has no Tensor Cores).
Flash-attn integration:
- ops/src/infinicore/adaptor/flash_attn_hygon_dlsym.cc — dlsym wrapper
resolving mha_fwd_kvcache / vllm_mha_varlen_fwd / paged_attention from the
system flash_attn_2_cuda*.so (DTK fork wheel). Compiled with plain g++ +
__HIP_PLATFORM_AMD__ to pick up <c10/hip/...> headers without conflicting
with libinfinicore's DTK CUDA-compat path.
- python/infinilm/__init__.py RTLD_GLOBAL-preloads flash_attn_2_cuda*.so so
  dlsym(RTLD_DEFAULT, ...) resolves (see the sketch after this list).
- Decode path uses paged_attention (not mha_fwd_kvcache) to keep HIP graph
capture clean.
- HIPStreamGuard binds PyTorch's current stream to InfiniLM's stream so all
ATen calls participate in graph capture.
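A sketch of the RTLD_GLOBAL preload mentioned above — the lookup path and helper name are assumptions; the real code in python/infinilm/__init__.py may locate the wheel's extension differently:

```python
import ctypes
import glob
import os
import sysconfig

def _preload_flash_attn():
    """Load the DTK flash-attn extension with RTLD_GLOBAL so that later
    dlsym(RTLD_DEFAULT, ...) calls from flash_attn_hygon_dlsym.cc can
    resolve its symbols (hypothetical sketch)."""
    site_packages = sysconfig.get_paths()["purelib"]
    for so in glob.glob(os.path.join(site_packages, "flash_attn_2_cuda*.so")):
        ctypes.CDLL(so, mode=ctypes.RTLD_GLOBAL)
```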
bs=1 GEMV kernel (custom):
- ops/src/infiniop/ops/gemm/nvidia/gemv_hygon.cuh — purpose-built BF16 GEMV
  (64-thread wavefront, bf16x4 vector loads, FP32 accumulator, wave-shuffle
  reduce). DTK hipBLAS picks MT16x16x32 for M=1 GEMM which runs at ~10 GFLOPS
  / 0.02% peak; this kernel is HBM-bandwidth-bound (reference semantics sketched after this list).
- gemm_nvidia.cu fast-path under ENABLE_HYGON_API hits when n=batch=1 +
op_a=T + standard PyTorch weight layout (M×K row-major, lda=K).
- Verified +44% bs=1 decode tok/s (39.7 → 58.7), reaches 83% A100 baseline.
- Override with INFINIOPS_DISABLE_GEMV_HYGON=1.
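For reference, a hedged host-side sketch (PyTorch) of the contract the fast path targets — the eligibility predicate mirrors the conditions listed above, the function names are made up, and the real kernel is the device code in gemv_hygon.cuh:

```python
import torch

def gemv_fast_path_eligible(n, batch, op_a, weight_shape, lda):
    # Conditions from the commit: n = batch = 1, A transposed, standard
    # PyTorch weight layout (M x K row-major, lda = K).
    _rows, k = weight_shape
    return n == 1 and batch == 1 and op_a == "T" and lda == k

def gemv_reference(weight_bf16, x_bf16):
    # What the kernel computes: y[m] = sum_k W[m, k] * x[k], BF16 inputs,
    # products accumulated in FP32 (reference only, not the device code).
    return weight_bf16.float() @ x_bf16.float()
```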
Benchmark scripts and report:
- benchmark/server_hygon.sh / run_client_hygon.sh / run_client_tp4_70b_hygon.sh
- benchmark/REPORT.md (NVIDIA + Hygon results, cross-backend comparison)
- benchmark/REPORT_old.md (archeology — TP=4 deadlock investigation)
TP>1 deadlock workaround (in server_hygon.sh):
- HSA_ENABLE_SDMA=0: prevents Segment::Freeze BlitKernel signal-wait deadlock
at first-iter HIP module load (TP>=2).
- HSA_FORCE_FINE_GRAIN_PCIE=1: RCCL warns this is required for stability.
- NCCL_P2P_DISABLE=1 (TP>=4): RCCL P2P over PCIe deadlocks the ring at
first allreduce; routes through shared-memory staging instead.
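benchmark/server_hygon.sh applies these as shell exports; purely for illustration, a hypothetical Python-side equivalent (the TP variable below is invented for this sketch, and the settings must be in place before any HIP/RCCL library is loaded):

```python
import os

# Must run before torch / infinilm pull in HIP or RCCL.
os.environ.setdefault("HSA_ENABLE_SDMA", "0")            # avoid the Segment::Freeze BlitKernel deadlock at TP>=2
os.environ.setdefault("HSA_FORCE_FINE_GRAIN_PCIE", "1")  # RCCL stability requirement
if int(os.environ.get("TP", "1")) >= 4:                  # "TP" is a hypothetical knob for this sketch
    os.environ.setdefault("NCCL_P2P_DISABLE", "1")       # PCIe P2P ring deadlocks at the first allreduce
```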
Verified end-to-end on /root/models/9g_8b_thinking_llama (TP=1/2/4) and
/root/models/FM9G_70B_SFT_MHA_qwen2 (TP=4 eager + graph). Outputs
token-for-token consistent across modes; 4560 server-client requests with
0 failures at TP=1.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- generation/utils.py: dedup the duplicated next_tokens.to_numpy() call per decode step (one D2H copy + numpy alloc instead of two). Token-identical; jiuge.py decode +8% on Hygon DCU.
- infer_engine.py: cache the constant input_offsets ([0, 1] for bs=1, seq_len=1) outside the decode loop. Token-identical; small bench gain.
Verified on Hygon (9g_8b_thinking_llama, tp=1, graph + paged-attn + flash-attn): correctness diff vs. baseline = 0; bench decode 42.9 → 44.1 tok/s.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
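A sketch of the loop shape implied by the two changes — hypothetical code, not the actual generation/utils.py / infer_engine.py diff:

```python
def greedy_decode(model, tokens, max_new_tokens, eos_id):
    """Hypothetical decode loop illustrating the two optimizations above."""
    input_offsets = [0, 1]  # constant for bs=1, seq_len=1 decode, hoisted out of the loop
    for _ in range(max_new_tokens):
        next_tokens = model.decode_step(tokens, input_offsets)  # hypothetical API
        next_np = next_tokens.to_numpy()   # single D2H copy + numpy alloc per step
        token = int(next_np[0])
        tokens.append(token)
        if token == eos_id:                # previously this branch re-called to_numpy()
            break
    return tokens
```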
Summary
Absorb the three parts of InfiniCore that InfiniLM actually depends on into this repository as subprojects, removing the external INFINI_ROOT pre-installed dependency; the build system switches from xmake to CMake:
runtime/ → libinfinirt.so (multi-platform runtime)
ccl/ → libinfiniccl.so (NCCL allreduce, depends on runtime)
ops/ → libinfinicore.so + _infinicore.*.so (infiniop operators + infinicore C++ API + Python binding)
The top-level CMakeLists.txt aggregates the three subprojects and builds the existing csrc/ as _infinilm.*.so. All 5 .so files are co-located under python/{infinilm,infinicore}/lib/ with RPATH=$ORIGIN, so neither LD_LIBRARY_PATH nor INFINI_ROOT is needed.
Installation
Environment variables: INFINILM_BUILD_FLASH_ATTN, INFINILM_FLASH_ATTN_REPO/REF/DIR, INFINILM_FLASH_ATTN_ARCHS, INFINILM_BUILD_TYPE, INFINILM_BUILD_JOBS.
Main changes
New: runtime/, ccl/, ops/, top-level CMakeLists.txt, python/infinicore/
Modified: setup.py (CMake driver), pyproject.toml (build deps gain cmake/pybind11), .gitignore (so the *.txt wildcard no longer masks CMakeLists.txt), examples/jiuge.py (compat for batch_encode_plus, removed in transformers ≥ 5.0)
ENABLE_METAX_API guard; mha_varlen_fwd call site and declaration gain the vllm fork's num_splits parameter
xmake.lua kept but no longer drives the build; src/ legacy ctypes models kept but no longer compiled
Acceptance (8× A100-80GB)
Sample output: <think>Okay, the user is asking... Hello, I am Jiuge, a general-purpose foundation model developed under the leadership of the Qiyuan Lab.
Test plan
pip install -e . (default path)
INFINILM_BUILD_FLASH_ATTN=1 pip install -e . --no-build-isolation (flash-attn path)
python examples/jiuge.py --model <llama_path> --device nvidia --tp 1 --backend cpp
python examples/jiuge.py --model <qwen2_path> --device nvidia --tp 4 --backend cpp
--enable-graph --enable-paged-attn
--attn flash-attn --enable-paged-attn