[Feat] Intranode Dispatch&Combine Kernel #522
Draft
yanboshao wants to merge 23 commits into
Draft
Conversation
7f53c40 to
1a8596c
Compare
xudoyuan
previously approved these changes
May 14, 2026
495da61 to
e3cd19d
Compare
Trim the FlyDSL Python helper surface introduced by the dispatch/combine
kernel down to what is strictly necessary, by leaning on existing main-
branch idioms and pushing small kernel-only wrappers into the kernel
file itself.
FlyDSL helper modules
- python/flydsl/expr/arith.py: revert to origin/main. Drop the unused
divui/remui/select_by_index extensions, and remove zext_i64 in favor
of a kernel-local _to_i64 helper that wraps arith.extui(_lv_unwrap(...)).
- python/flydsl/expr/vector.py: revert to origin/main. Drop the
bitcast_i32_to_v2bf16/bitcast_v2bf16_to_i32 helpers; the kernel now
uses the standard vector.from_elements + vector.bitcast + vector.extract
idiom (mirrors kernels/hgemm_splitk.py:578-585).
- python/flydsl/expr/rocdl/__init__.py: replace the bespoke ballot_i64 /
readlane wrappers with generic ballot(res, pred, **kw) and
readlane(res, src, lane, **kw) functions, aligned with the existing
readfirstlane(res, src, **kw) style: capture the ODS-generated symbols
as _ods_ballot / _ods_readlane up top, and use _to_ir coercion in the
wrappers. Lets call sites pick the lane-mask width (i32 on wave32,
i64 on wave64) explicitly.
Kernel
- kernels/dispatch_combine_intranode_kernel.py:
- Add three file-local helpers: _to_i64, _i32_to_vec_bitcast,
_vec_to_i32_bitcast (with docstrings pointing at the main-branch
idioms they mirror).
- Replace 31 arith.zext_i64(x) call sites with _to_i64(x); collapse
two arith.zext_i64(arith.constant(rank)) sites into
arith.constant(rank, type=T.i64()).
- Update the 4 llvm_bitcast call sites to use the new
_i32_to_vec_bitcast / _vec_to_i32_bitcast helpers.
- Update ballot_i64(...) / readlane(...) call sites to the new generic
APIs: ballot(T.i64(), pred), readlane(T.i32(), src, lane).
Net effect vs origin/main: arith.py and vector.py are now untouched;
rocdl/__init__.py keeps a +22 line delta (generic ballot/readlane
wrappers). All complexity that used to live in FlyDSL core has moved
into the kernel file where it belongs.
Verified
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
--mode verify -> ALL PASS (diff=0 on dispatch + combine)
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
--mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within
StdMoE weighted tolerance)
Make the PR pass the `Check Python Code Style` CI step (.github/workflows/ pre-checks.yaml), which runs ``black --check --diff`` and ``ruff check`` on the set of Python files changed by the PR. Auto-fixes (ruff --fix): I001 (5 unsorted-imports), F401 (3 unused-imports), F811 (1 redefined-while-unused), W293 (1 blank-line-with-whitespace). Manual fixes: - F841 (7 unused-variable): drop dead assignments to ``tok_stride`` / ``inp_n_i32`` in dispatch_combine_intranode_kernel.py, and four ``hdim`` + one ``esz`` in dispatch_combine_intranode_op.py. - E702 (23 multiple-statements-on-one-line): split ``a; b; c`` boilerplate in tests/kernels/test_profiler_dispatch_combine.py (mostly ``dist.all_reduce`` aggregation patterns). - E402 (2 module-import-not-at-top): add ``# noqa: E402`` to the two imports that intentionally follow ``sys.path.insert(0, _p)`` in the test script. Formatting: run ``black`` (line-length=120, per pyproject.toml) on the four PR-modified Python files. ast_rewriter.py was already compliant. CI parity locally: ``black --check`` + ``ruff check`` both clean on all PR files. Verified end-to-end (8x GPU, gfx942, bf16) after the style sweep: - torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py --mode verify -> ALL PASS (diff=0 on dispatch + combine). - torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py --mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within StdMoE weighted tolerance).
The dispatch/combine intranode test depends on mori shmem, which is
only installed on the 8-GPU multi-gpu CI runners. Previously pytest
collection on single-GPU / Navi-2-GPU runners would crash because
``import mori`` raises ModuleNotFoundError at module load time.
* tests/kernels/test_profiler_dispatch_combine.py: when imported under
pytest collection (detected via ``"pytest" in sys.modules``), call
``pytest.importorskip("mori")`` so single/dual-GPU jobs cleanly skip
this file. Direct ``torchrun``/``python`` invocations are unaffected
and still surface a normal ImportError when mori is genuinely missing.
* .github/workflows/flydsl.yaml: add two explicit multi-GPU steps to the
multi-gpu job that run the dispatch/combine verify torchrun script
(default config + --enable-std-moe). These only execute when the PR
carries the ``multi-gpu`` label, providing real 8-GPU coverage for the
new kernel.
* kernels/dispatch_combine_intranode_op.py: drop unused local
``_disp_wpb`` alias, use ``config.warp_num_per_block`` directly.
Align arith module ordering/ruff pragmas with mainline formatting so the Python style pre-check passes reliably in PR CI.
- StdMoE Phase 4: in-kernel ticket-based race-free grid barrier (i64 monotonic counter, no host-side reset; CUDAGraph-replay safe). - SmemAllocator arch from cached device probe; LDS capacity check at JIT. - New ROCDL fp4/fp8 (un)pack helpers; ast_rewriter recurses into synthesized wrapper bodies so nested control-flow still gets lowered. - Verify bypasses mori's broken P2P-read combine ref via self-check; fp4_dispatch gated to gfx950 with liveness check. - CI sweep collapses two verify steps into a single --ci-sweep.
…ep note Gate combine_no_stage1 behind _ENABLE_COMBINE_NO_STAGE1 so the standalone wrapper raises NotImplementedError; move the docstring above the body and document that Stage 3 consumes the caller-supplied cur_tok rather than total_recv. Update the CI sweep comment to drop the stale "P2P-read known-fail" note now that the path is verified by the FlyDSL self-check.
* recv-cap (mori MaxNumTokensToRecvPerRank parity): max_total_recv_tokens is the total budget; per-rank slots = min(ceil(cap/ws), M). cap==0 keeps the legacy worst-case (ws*M). * Fix OOB: combine Stage 1 P2P-scatters tokens at (rank*M + dest_lid), so shmem_comb_inp_* must stay ws*M regardless of cap; only dispatch- side buffers shrink. * Decouple max_token_bytes from per-call dtype; override only when max_token_type_size > 0 (mori EpDispatchCombineConfig parity). * JIT: single launcher; schema v3 -> v4 invalidates stale on-disk cache (previously caused hipErrorInvalidHandle). * verify_self: when effective_max_recv < ws*M, byte-exact vs mori is undefined; downgrade to NaN/Inf liveness + max-diff DIAG. Byte-exact remains the hard gate otherwise. * CI: add bf16_recv_cap_half (cap = ws*M/2). 12-case sweep PASS. Perf vs token_size baseline: dispatch -25%, combine -32%.
- kernels/dispatch_combine_intranode_kernel.py: move flydsl.expr.rocdl / flydsl.expr.typing / flydsl.utils.smem_allocator imports back to the top of the file (they were dropped below the _DISPATCH_COMBINE_JIT_ SCHEMA_VERSION / _S3_WIDE_PATH_THRESHOLD_I32 / _SLC_CACHE constants, which split the import block and tripped E402 + I001). - kernels/rmsnorm_kernel.py: drop dead arch = get_hip_arch() at L631 (F841; the other three call sites at L106/398/1005 still use arch). - tests/kernels/test_profiler_dispatch_combine.py: drop unused tol block (F841) — verify path uses mori-byte oracle (strict equal) now, the tolerance dict has no reader. - black auto-format on the three touched files (120-col line merges and a couple of redundant blank lines).
… op-wrapper comments
combine() API surface
---------------------
Drop block_num / warp_per_block / use_external_inp_buf from
EpDispatchCombineOp.combine() and combine_no_stage1(): launch
geometry and the P2P-read vs zero-copy mode are frozen at op
construction (taken from cfg.{block_num,warp_num_per_block,
use_external_inp_buf}), matching mori EpDispatchCombineHandle::
launchSettings. Rebuild the op to switch geometry; callers must
not pass these dynamically.
The previous shim that accepted these as -1 sentinels and raised
on mismatch is removed -- no in-tree caller passes them
(verified by grep over FlyDSL/), and tests/_ckw never staged
them.
Wrapper comments / docstrings
-----------------------------
Trim ~270 lines of redundant docstring / narration from
kernels/dispatch_combine_intranode_op.py without losing any
non-obvious invariant:
- Module docstring: dedup the three "Mori-parity surface" bullets
against the dataclass field comments below; keep the buffer-
sizing design note (combine_inp stays worst-case while dispatch
buffers shrink with the cap) since that is the one deliberate
drift from mori.
- Dataclass property docstrings (max_recv, effective_max_recv*,
max_token_bytes): collapse to 3-4 line contract statements.
- __init__ / _alloc_buffers: remove narration headers, keep root-
cause notes (xdev_flag must start at 1, mr vs mr_worst rationale,
zero-init purpose for tok_id_to_src / xdev_bar_mem /
comb_inp_{tok,wts}, StdMoE i64 ticket counter).
- _check_config: drop "Type sanity" / "Scales consistency" /
"Launch geometry" narration headers; keep the k <= 64 ballot
constraint, fp8_direct_cast + std_moe mutex, max_total_recv_
tokens clamp policy, total-experts i32 overflow guard.
- dispatch(): replace the historical "Earlier this list was
missing ..." bug story with just the JIT-slot 17..24 layout
table.
- combine() / combine_no_stage1(): keep the contract requirements
(cur_tok semantics, enable_weights rationale, fused-path
shmem_comb_inp pre-population) but drop repeated fp8_direct_cast
prose and the legacy barrier()/reset() tombstone block.
Remove the now-superseded comment about combine_no_stage1 reading
fp8 vs bf16: combine writes back cfg.data_type symmetrically and
the kernel handles the bf16->fp8 cast inline (UseFp8DirectCast).
CI sweep description
--------------------
.github/workflows/flydsl.yaml: update the inline notes for the
multi-GPU dispatch/combine sweep step to reflect 14 cases (was
12) -- adds the cap=ws minimum-cap path and the forced single-PE
hot-spot routing case introduced in P2-ci.
No semantic change in the kernel; black + ruff pre-checks pass
on all 160 changed files.
…ng + f32 bitcast JIT Three bugs surfaced by the L1/L2 main-matrix sweep that were blocking 100% pass. 1. Stage 3 wide-path silently emits no ``out_tok`` writes when ``warp_num_per_block == 16`` because the existing const_expr branch was gated on ``... and warp_num_per_block < 16`` with no else, so the wide-path block compiled to a no-op for wpb=16 and ``out_tok`` stayed zero (max_diff ~= max(|0 - k*inp|) ~= 44.75 on randn input). Add a step=64 fallback main loop covering ``[0, eff_end_128)`` for the ``wpb >= 16`` / ``n_i32 % 256 != 0`` case so bf16 + fp8_direct_cast both write correctly at wpb=16. The fp8_direct_cast wpb=16 case had been silently pseudo-passing because mori has no bf16+fp8cast oracle, so verify_self fell back to a NaN/Inf gate (DIAG showed |fly - k*inp|.max=43.5 -- same root cause). 2. ``_verify_dispatch_self_consistency`` early-returned on ``total_recv == 0`` for the local rank, but the ensuing ``_allgather_rows`` calls are collective ops -- surviving ranks then blocked on ``dist.all_gather`` until the 30-min NCCL watchdog timeout. L2 bs=1+k=1 hit this routinely. Now every rank reaches the collective in lock-step; only the byte-compare body is skipped when there is nothing to slice. 3. Stage 3 f32 path used scalar ``arith.bitcast(T.f32(), i32_val)`` directly, but the auto-generated MLIR ``BitcastOp`` requires operand[0] to be a raw ``mlir.ir.Value`` and ``i32_val`` arrives wrapped in an ``fx.Int32`` Numeric shell. bf16 / fp4 / fp8 paths route through ``vector.*`` builders that auto-unwrap, so they never tripped it; f32 is the only scalar-bitcast path and silently failed JIT compile (``Operand 0 of operation "arith.bitcast" must be a Value``). Extract the raw mlir Value via ``.ir_value()`` before invoking the dialect builder, then re-wrap with ``fx.Float32`` / ``fx.Int32`` so downstream ``acc + ...`` and ``arith.select(...)`` keep working. Also includes: ``soft_hot_pe`` routing for the CI sweep (used in place of ``forced_hot_spot`` which livelocks ROCm IPC on the single PE-0 ``shmem_tok_off`` slot), ``_apply_ci_case`` wiring for the external JSON case-file harness, and ``known_block_size`` kernel metadata so wpb=8/16 launches don't trip the default 256-thread guard. All three fixes verified end-to-end: 5 wpb=16 cases (bf16 + fp8dc) + bs=1/k=1 hang repro + 2 f32 dtype cases all PASS byte-equal vs mori, plus a 5-case wpb=8 regression confirms the existing quad/dual-unroll paths are unaffected.
Black auto-format split the two-line print("..." "...") introduced
in the previous commit 52c09f6 (Bug 2 fix) onto its own
print(\n ... \n) form to satisfy line-wrap rules. No behavior change.
Restores Check Python Code Style CI parity.
…stead of editing flydsl JIT core The combine launcher's ``_key_data_type`` was the raw ``torch.dtype`` object whereas ``_collect_closure_scalar_vals`` in ``flydsl.compiler.jit_function`` only materializes ``(int, float, bool, str, type(None), tuple)`` into the cache key, so distinct dtype variants of the combine kernel were silently sharing one on-disk artifact and the second variant would trip ``hipErrorInvalidHandle`` at module load. The dispatch launcher (line 1726) already addressed this with ``_key_data_type = str(data_type)``; mirror that on the combine side (line 1883) so the kernel's cache key carries dtype identity on its own, no flydsl/compiler changes required. Reverts the +49-line ``_is_stable_hashable_atom`` / ``_stable_atom_repr`` extension to ``python/flydsl/compiler/jit_function.py``; the fix now lives entirely in the kernel file.
08b23e3 to
68f158e
Compare
* Prune three CI cases that no longer represent kernel correctness:
- bf16_zerocopy_fp8_direct_cast (mutually-exclusive options after
the mori-parity zero-copy refactor; only ever ran as xfail)
- bf16_recv_cap_min (1-slot-per-peer stress that triggers ROCm
IPC fabric multi-second stalls; not a FlyDSL correctness bug)
- bf16_forced_hot_spot (all-traffic-to-PE-0 atomic_add hot-spot
livelock on ROCm; mori exhibits the same hang)
The forced_hot_spot routing mode is kept behind --routing CLI for
manual IPC-atomic regression hunting; cross-reference comments in
build_mori_ref / soft_hot_pe / argparse help updated accordingly.
* Relax fp8 verify_self byte-equality to a value-equivalent fallback
so +0/-0 fp8 sign-bit encoding differences don't flag as failures
(the dispatched floats compare exactly equal).
* Slim intranode kernel/op surface based on CI-debug findings.
AST-verified: CI_CASES count 17 -> 14, no dangling references to the
three removed cases.
The mori-parity refactor renamed the P2P-read path to ``zero_copy``. Update the CI workflow step comment to match so the description stays in sync with the actual kernel/op API surface.
…ofilers + black squeeze The mori-parity mixed-dtype CI cases (mixed_fp8_ocp_dispatch_bf16_combine_*, mixed_fp4_dispatch_bf16_combine_zero_copy) dispatch with one dtype and combine with another, so the mori kernel-name lookups in the cudagraph profilers need separate suffixes for the dispatch and combine sides. Plumb a ``combine_dtype_key`` (default = ``dtype_key``) through ``_allreduce_cudagraph_stats_from_key_averages``, ``_cudagraph_stats_from_trace`` and ``profile_cudagraph_op`` so each picks its own ``msuf_d`` / ``msuf_c`` and resolves the right ``EpDispatchIntraNodeKernel_*`` / ``EpCombineIntraNodeKernel_*`` labels for the mori comparison path. Same-dtype callers pass nothing and fall back to the original single-suffix behaviour. Also a one-line black squeeze on the ``forced_hot_spot`` assert that slipped through the earlier style pass.
…ulti-gpu CI rename Three fp8_ocp dispatch -> bf16 combine cases at M = 4 / 32 / 8192 covering the small / moderate / large batch regimes. zero_copy is picked at random per spawn via a new ``_random_fields`` case attribute (resolved once in the parent before mp.spawn so all 8 ranks see the same value), so both buffer modes get exercised at every M without doubling the case count. Also rename the multi-gpu CI job from "AllReduce Tests" to "Communication Operator Tests" to match what it actually runs (shmem regression + dispatch/combine sweep + allreduce), and drop a stale comment block that still referenced 14 cases / the removed recv_cap_min + forced_hot_spot paths.
…eep local) The mixed-dtype perf-sweep driver, its best-per-shape post-processor and the regenerated sweep report/jsonl are local working artifacts that don't belong in the PR. Untrack them via ``git rm --cached`` so they stay on disk for offline analysis without flowing into git. * dispatch_combine_mixed_dtype_sweep.md * dispatch_combine_mixed_dtype_sweep.jsonl * scripts/perf_sweep/append_best_per_shape.py * scripts/perf_sweep/run_mixed_dtype_sweep.py
dbc7b5c to
886560d
Compare
Co-authored-by: Cursor <cursoragent@cursor.com>
a83fe1f to
3e6d2e3
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Added dispatch & combine communication operators for MOE in intranode EP parallelism, replacing the alltoall operator.
Technical Details
Dispatch
combine.bf16,f32,fp8_ocp,fp8_fnuz,fp4enable_std_moe=Trueenable_std_moe=True, the output token layout is expert-major.Combine
dispatch.bf16,f32,fp8_ocp,fp8_fnuz,fp4enable_std_moe=Trueenable_std_moe=True, intermediate/output organization follows expert-major semantics.zero_copy=Truequant_type=fp8_direct_castbf16while using FP8 transport/staging in combine.bf16StdMoeZeroCopyTest Plan
Device: MI355 *8
CudaGraph: true
Dispatch: fp 4/fp8
Combine: bf16, ZeroCopy:true/false, fp8_direct_cast:false
StdMoe: off
hidden dim=7168, expert_per_token=8, expert_per_rank=32,routing=random
Test Result
FlyDSL vs mori: fp4 -> bf16, Zero-Copy = False
FlyDSL vs mori : fp4 -> bf16, Zero-Copy = True
fp8_ocp -> bf16, zero-copy = False
FlyDSL vs mori : fp8_ocp -> bf16, Zero-Copy = False
Submission Checklist