Skip to content

[Feat] Intranode Dispatch&Combine Kernel #522

Draft
yanboshao wants to merge 23 commits into
mainfrom
yanbo/dispatch_combine
Draft

[Feat] Intranode Dispatch&Combine Kernel #522
yanboshao wants to merge 23 commits into
mainfrom
yanbo/dispatch_combine

Conversation

@yanboshao
Copy link
Copy Markdown
Contributor

@yanboshao yanboshao commented May 14, 2026

Motivation

Added dispatch & combine communication operators for MOE in intranode EP parallelism, replacing the alltoall operator.

Technical Details

Dispatch

  • Routes and packs input tokens into expert-aligned buffers for expert-parallel execution.
  • Produces routing metadata that is consumed by combine.
  • Supported dtypes
    • bf16, f32, fp8_ocp, fp8_fnuz, fp4
  • StdMoe
    • Flag: enable_std_moe=True
    • Effect: switches to Standard-MoE-compatible routing/packing path.
    • Memory layout note: when enable_std_moe=True, the output token layout is expert-major.

Combine

  • Reads expert outputs and routing metadata, then merges results back to token outputs.
  • Completes the inverse mapping of dispatch.
  • Supported dtypes
    • Runtime combine dtype supports bf16, f32, fp8_ocp, fp8_fnuz, fp4
    • Dispatch and combine dtypes can be decoupled (mixed-dtype path).
  • StdMoe
    • Flag: enable_std_moe=True
    • Effect: uses the Standard-MoE combine path and consumes StdMoe-style metadata.
    • Memory layout note: with enable_std_moe=True, intermediate/output organization follows expert-major semantics.
  • ZeroCopy
    • Flag: zero_copy=True
    • Effect: combine reads directly from registered communication/staging buffers, minimizing extra data movement.
  • fp8_direct_cast
    • Flag: quant_type=fp8_direct_cast
    • Effect: keeps external I/O in bf16 while using FP8 transport/staging in combine.
    • Constraints:
      • external dtype must be bf16
      • incompatible with StdMoe
      • incompatible with ZeroCopy

Test Plan

Device: MI355 *8
CudaGraph: true
Dispatch: fp 4/fp8
Combine: bf16, ZeroCopy:true/false, fp8_direct_cast:false
StdMoe: off
hidden dim=7168, expert_per_token=8, expert_per_rank=32,routing=random

Test Result

FlyDSL vs mori: fp4 -> bf16, Zero-Copy = False

max_tokens (batch size) FlyDSL best config (block_num/warp_per_block) FlyDSL dispatch time (us) FlyDSL combine time (us) FlyDSL total time (us) Mori best config (block_num/warp_per_block) Mori dispatch time (us) Mori combine time (us) Mori total time (us) speedup (mori_best_total / fly_best_total)
4 128/4 8.9 14.4 23.3 128/16 18.4 15.0 33.4 1.43x
8 128/8 9.0 14.2 23.2 128/8 18.9 15.2 34.1 1.47x
16 64/4 9.8 14.1 23.9 128/8 19.8 16.3 36.1 1.51x
32 128/16 12.2 15.1 27.3 64/16 21.3 18.0 39.3 1.44x
64 64/8 16.9 22.1 39.0 128/8 25.0 21.4 46.4 1.19x
128 256/16 22.6 35.2 57.8 128/16 37.9 29.9 67.8 1.17x
256 256/16 35.6 64.2 99.8 256/4 65.1 52.2 117.3 1.18x
512 128/8 59.8 101.8 161.6 128/16 120.4 100.6 221.0 1.37x
1024 64/4 109.8 179.8 289.6 128/8 235.3 201.7 437.0 1.51x
2048 128/4 206.8 356.5 563.3 64/4 458.5 402.3 860.8 1.53x
4096 64/8 399.9 736.0 1135.9 128/8 896.2 798.4 1694.6 1.49x

FlyDSL vs mori : fp4 -> bf16, Zero-Copy = True

max_tokens (batch size) FlyDSL best config (block_num/warp_per_block) FlyDSL dispatch time (us) FlyDSL combine time (us) FlyDSL total time (us) Mori best config (block_num/warp_per_block) Mori dispatch time (us) Mori combine time (us) Mori total time (us) speedup (mori_best_total / fly_best_total)
4 128/4 9.3 17.3 26.6 64/4 18.6 11.1 29.7 1.12x
8 128/4 9.7 18.0 27.7 256/8 18.9 11.4 30.3 1.09x
16 128/8 9.9 19.1 29.0 128/16 19.6 12.5 32.1 1.11x
32 64/8 11.8 20.0 31.8 256/8 21.2 20.5 41.7 1.31x
64 64/8 16.8 26.0 42.8 64/16 25.4 18.7 44.1 1.03x
128 64/4 22.9 33.4 56.3 128/4 38.3 24.7 63.0 1.12x
256 256/16 35.7 48.8 84.5 256/8 65.2 40.9 106.1 1.26x
512 256/16 60.4 79.7 140.1 64/4 120.4 73.5 193.9 1.38x
1024 128/8 109.6 155.4 265.0 128/8 234.3 141.3 375.6 1.42x
2048 64/16 206.6 301.0 507.6 128/8 457.8 274.7 732.5 1.44x
4096 256/8 399.0 582.9 981.9 64/16 896.2 526.4 1422.6 1.45x

fp8_ocp -> bf16, zero-copy = False

max_tokens (batch size) FlyDSL best config (block_num/warp_per_block) FlyDSL dispatch time (us) FlyDSL combine time (us) FlyDSL total time (us) Mori best config (block_num/warp_per_block) Mori dispatch time (us) Mori combine time (us) Mori total time (us) speedup (mori_best_total / fly_best_total)
4 128/4 11.5 16.4 27.9 256/8 15.0 19.7 34.7 1.24x
8 256/8 10.7 15.8 26.5 64/16 15.2 20.8 36.0 1.36x
16 256/8 11.6 16.3 27.9 256/16 15.5 22.1 37.6 1.35x
32 64/8 14.5 21.1 35.6 64/4 17.5 23.4 40.9 1.15x
64 128/8 21.7 34.4 56.1 128/4 22.5 28.8 51.3 0.91x
128 64/4 31.3 58.8 90.1 256/8 35.2 50.8 86.0 0.95x
256 128/16 50.3 102.0 152.3 128/8 59.5 97.4 156.9 1.03x
512 64/8 86.6 184.4 271.0 256/8 104.7 192.3 297.0 1.10x
1024 256/16 161.5 384.5 546.0 256/8 194.5 379.8 574.3 1.05x
2048 256/16 307.0 755.1 1062.1 256/16 374.3 758.4 1132.7 1.07x
4096 128/16 599.4 1527.6 2127.0 256/8 731.5 1529.2 2260.7 1.06x

FlyDSL vs mori : fp8_ocp -> bf16, Zero-Copy = False

max_tokens (batch size) FlyDSL best config (block_num/warp_per_block) FlyDSL dispatch time (us) FlyDSL combine time (us) FlyDSL total time (us) Mori best config (block_num/warp_per_block) Mori dispatch time (us) Mori combine time (us) Mori total time (us) speedup (mori_best_total / fly_best_total)
4 64/4 11.2 17.1 28.3 256/8 15.0 11.2 26.2 0.93x
8 128/16 11.6 17.7 29.3 128/4 15.1 12.5 27.6 0.94x
16 64/8 12.1 20.3 32.4 64/16 15.6 20.3 35.9 1.11x
32 64/16 14.0 25.7 39.7 256/8 17.5 18.0 35.5 0.89x
64 64/16 21.5 33.4 54.9 256/4 22.4 23.3 45.7 0.83x
128 256/16 31.0 47.8 78.8 128/16 35.1 38.9 74.0 0.94x
256 64/16 49.3 79.8 129.1 256/4 57.8 72.0 129.8 1.01x
512 128/4 86.8 153.7 240.5 64/8 99.7 137.3 237.0 0.99x
1024 64/4 159.6 286.4 446.0 256/8 189.1 264.7 453.8 1.02x
2048 64/4 305.1 564.4 869.5 128/4 368.1 520.9 889.0 1.02x
4096 256/8 595.8 1108.5 1704.3 128/16 718.7 1031.6 1750.3 1.03x

Submission Checklist

@yanboshao yanboshao marked this pull request as draft May 14, 2026 07:19
@yanboshao yanboshao changed the title feat(dispatch_combine): intranode dispatch/combine kernel [Feat]: intranode dispatch/combine kernel May 14, 2026
@yanboshao yanboshao changed the title [Feat]: intranode dispatch/combine kernel [Feat] Intranode Dispatch&Combine Kernel May 14, 2026
@yanboshao yanboshao force-pushed the yanbo/dispatch_combine branch from 7f53c40 to 1a8596c Compare May 14, 2026 07:53
xudoyuan
xudoyuan previously approved these changes May 14, 2026
Comment thread python/flydsl/compiler/ast_rewriter.py Outdated
@yanboshao yanboshao force-pushed the yanbo/dispatch_combine branch 2 times, most recently from 495da61 to e3cd19d Compare May 24, 2026 04:59
yanboshao added 14 commits May 25, 2026 22:32
Trim the FlyDSL Python helper surface introduced by the dispatch/combine
kernel down to what is strictly necessary, by leaning on existing main-
branch idioms and pushing small kernel-only wrappers into the kernel
file itself.

FlyDSL helper modules
- python/flydsl/expr/arith.py: revert to origin/main. Drop the unused
  divui/remui/select_by_index extensions, and remove zext_i64 in favor
  of a kernel-local _to_i64 helper that wraps arith.extui(_lv_unwrap(...)).
- python/flydsl/expr/vector.py: revert to origin/main. Drop the
  bitcast_i32_to_v2bf16/bitcast_v2bf16_to_i32 helpers; the kernel now
  uses the standard vector.from_elements + vector.bitcast + vector.extract
  idiom (mirrors kernels/hgemm_splitk.py:578-585).
- python/flydsl/expr/rocdl/__init__.py: replace the bespoke ballot_i64 /
  readlane wrappers with generic ballot(res, pred, **kw) and
  readlane(res, src, lane, **kw) functions, aligned with the existing
  readfirstlane(res, src, **kw) style: capture the ODS-generated symbols
  as _ods_ballot / _ods_readlane up top, and use _to_ir coercion in the
  wrappers. Lets call sites pick the lane-mask width (i32 on wave32,
  i64 on wave64) explicitly.

Kernel
- kernels/dispatch_combine_intranode_kernel.py:
  - Add three file-local helpers: _to_i64, _i32_to_vec_bitcast,
    _vec_to_i32_bitcast (with docstrings pointing at the main-branch
    idioms they mirror).
  - Replace 31 arith.zext_i64(x) call sites with _to_i64(x); collapse
    two arith.zext_i64(arith.constant(rank)) sites into
    arith.constant(rank, type=T.i64()).
  - Update the 4 llvm_bitcast call sites to use the new
    _i32_to_vec_bitcast / _vec_to_i32_bitcast helpers.
  - Update ballot_i64(...) / readlane(...) call sites to the new generic
    APIs: ballot(T.i64(), pred), readlane(T.i32(), src, lane).

Net effect vs origin/main: arith.py and vector.py are now untouched;
rocdl/__init__.py keeps a +22 line delta (generic ballot/readlane
wrappers). All complexity that used to live in FlyDSL core has moved
into the kernel file where it belongs.

Verified
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify              -> ALL PASS (diff=0 on dispatch + combine)
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within
  StdMoE weighted tolerance)
Make the PR pass the `Check Python Code Style` CI step (.github/workflows/
pre-checks.yaml), which runs ``black --check --diff`` and ``ruff check``
on the set of Python files changed by the PR.

Auto-fixes (ruff --fix): I001 (5 unsorted-imports), F401 (3 unused-imports),
F811 (1 redefined-while-unused), W293 (1 blank-line-with-whitespace).

Manual fixes:
- F841 (7 unused-variable): drop dead assignments to ``tok_stride`` /
  ``inp_n_i32`` in dispatch_combine_intranode_kernel.py, and four
  ``hdim`` + one ``esz`` in dispatch_combine_intranode_op.py.
- E702 (23 multiple-statements-on-one-line): split ``a; b; c`` boilerplate
  in tests/kernels/test_profiler_dispatch_combine.py (mostly
  ``dist.all_reduce`` aggregation patterns).
- E402 (2 module-import-not-at-top): add ``# noqa: E402`` to the two
  imports that intentionally follow ``sys.path.insert(0, _p)`` in the
  test script.

Formatting: run ``black`` (line-length=120, per pyproject.toml) on the
four PR-modified Python files. ast_rewriter.py was already compliant.

CI parity locally: ``black --check`` + ``ruff check`` both clean on all
PR files.

Verified end-to-end (8x GPU, gfx942, bf16) after the style sweep:
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify              -> ALL PASS (diff=0 on dispatch + combine).
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within
  StdMoE weighted tolerance).
The dispatch/combine intranode test depends on mori shmem, which is
only installed on the 8-GPU multi-gpu CI runners.  Previously pytest
collection on single-GPU / Navi-2-GPU runners would crash because
``import mori`` raises ModuleNotFoundError at module load time.

* tests/kernels/test_profiler_dispatch_combine.py: when imported under
  pytest collection (detected via ``"pytest" in sys.modules``), call
  ``pytest.importorskip("mori")`` so single/dual-GPU jobs cleanly skip
  this file.  Direct ``torchrun``/``python`` invocations are unaffected
  and still surface a normal ImportError when mori is genuinely missing.

* .github/workflows/flydsl.yaml: add two explicit multi-GPU steps to the
  multi-gpu job that run the dispatch/combine verify torchrun script
  (default config + --enable-std-moe).  These only execute when the PR
  carries the ``multi-gpu`` label, providing real 8-GPU coverage for the
  new kernel.

* kernels/dispatch_combine_intranode_op.py: drop unused local
  ``_disp_wpb`` alias, use ``config.warp_num_per_block`` directly.
Align arith module ordering/ruff pragmas with mainline formatting so the Python style pre-check passes reliably in PR CI.
- StdMoE Phase 4: in-kernel ticket-based race-free grid barrier
  (i64 monotonic counter, no host-side reset; CUDAGraph-replay safe).
- SmemAllocator arch from cached device probe; LDS capacity check at JIT.
- New ROCDL fp4/fp8 (un)pack helpers; ast_rewriter recurses into
  synthesized wrapper bodies so nested control-flow still gets lowered.
- Verify bypasses mori's broken P2P-read combine ref via self-check;
  fp4_dispatch gated to gfx950 with liveness check.
- CI sweep collapses two verify steps into a single --ci-sweep.
…ep note

Gate combine_no_stage1 behind _ENABLE_COMBINE_NO_STAGE1 so the standalone wrapper raises NotImplementedError; move the docstring above the body and document that Stage 3 consumes the caller-supplied cur_tok rather than total_recv. Update the CI sweep comment to drop the stale "P2P-read known-fail" note now that the path is verified by the FlyDSL self-check.
* recv-cap (mori MaxNumTokensToRecvPerRank parity): max_total_recv_tokens
  is the total budget; per-rank slots = min(ceil(cap/ws), M). cap==0
  keeps the legacy worst-case (ws*M).
* Fix OOB: combine Stage 1 P2P-scatters tokens at (rank*M + dest_lid),
  so shmem_comb_inp_* must stay ws*M regardless of cap; only dispatch-
  side buffers shrink.
* Decouple max_token_bytes from per-call dtype; override only when
  max_token_type_size > 0 (mori EpDispatchCombineConfig parity).
* JIT: single launcher; schema v3 -> v4 invalidates stale on-disk
  cache (previously caused hipErrorInvalidHandle).
* verify_self: when effective_max_recv < ws*M, byte-exact vs mori is
  undefined; downgrade to NaN/Inf liveness + max-diff DIAG. Byte-exact
  remains the hard gate otherwise.
* CI: add bf16_recv_cap_half (cap = ws*M/2). 12-case sweep PASS.
  Perf vs token_size baseline: dispatch -25%, combine -32%.
- kernels/dispatch_combine_intranode_kernel.py: move flydsl.expr.rocdl /
  flydsl.expr.typing / flydsl.utils.smem_allocator imports back to the
  top of the file (they were dropped below the _DISPATCH_COMBINE_JIT_
  SCHEMA_VERSION / _S3_WIDE_PATH_THRESHOLD_I32 / _SLC_CACHE constants,
  which split the import block and tripped E402 + I001).
- kernels/rmsnorm_kernel.py: drop dead arch = get_hip_arch() at L631
  (F841; the other three call sites at L106/398/1005 still use arch).
- tests/kernels/test_profiler_dispatch_combine.py: drop unused tol
  block (F841) — verify path uses mori-byte oracle (strict equal)
  now, the tolerance dict has no reader.
- black auto-format on the three touched files (120-col line merges
  and a couple of redundant blank lines).
… op-wrapper comments

combine() API surface
---------------------
Drop block_num / warp_per_block / use_external_inp_buf from
EpDispatchCombineOp.combine() and combine_no_stage1(): launch
geometry and the P2P-read vs zero-copy mode are frozen at op
construction (taken from cfg.{block_num,warp_num_per_block,
use_external_inp_buf}), matching mori EpDispatchCombineHandle::
launchSettings.  Rebuild the op to switch geometry; callers must
not pass these dynamically.

The previous shim that accepted these as -1 sentinels and raised
on mismatch is removed -- no in-tree caller passes them
(verified by grep over FlyDSL/), and tests/_ckw never staged
them.

Wrapper comments / docstrings
-----------------------------
Trim ~270 lines of redundant docstring / narration from
kernels/dispatch_combine_intranode_op.py without losing any
non-obvious invariant:

- Module docstring: dedup the three "Mori-parity surface" bullets
  against the dataclass field comments below; keep the buffer-
  sizing design note (combine_inp stays worst-case while dispatch
  buffers shrink with the cap) since that is the one deliberate
  drift from mori.
- Dataclass property docstrings (max_recv, effective_max_recv*,
  max_token_bytes): collapse to 3-4 line contract statements.
- __init__ / _alloc_buffers: remove narration headers, keep root-
  cause notes (xdev_flag must start at 1, mr vs mr_worst rationale,
  zero-init purpose for tok_id_to_src / xdev_bar_mem /
  comb_inp_{tok,wts}, StdMoE i64 ticket counter).
- _check_config: drop "Type sanity" / "Scales consistency" /
  "Launch geometry" narration headers; keep the k <= 64 ballot
  constraint, fp8_direct_cast + std_moe mutex, max_total_recv_
  tokens clamp policy, total-experts i32 overflow guard.
- dispatch(): replace the historical "Earlier this list was
  missing ..." bug story with just the JIT-slot 17..24 layout
  table.
- combine() / combine_no_stage1(): keep the contract requirements
  (cur_tok semantics, enable_weights rationale, fused-path
  shmem_comb_inp pre-population) but drop repeated fp8_direct_cast
  prose and the legacy barrier()/reset() tombstone block.

Remove the now-superseded comment about combine_no_stage1 reading
fp8 vs bf16: combine writes back cfg.data_type symmetrically and
the kernel handles the bf16->fp8 cast inline (UseFp8DirectCast).

CI sweep description
--------------------
.github/workflows/flydsl.yaml: update the inline notes for the
multi-GPU dispatch/combine sweep step to reflect 14 cases (was
12) -- adds the cap=ws minimum-cap path and the forced single-PE
hot-spot routing case introduced in P2-ci.

No semantic change in the kernel; black + ruff pre-checks pass
on all 160 changed files.
…ng + f32 bitcast JIT

Three bugs surfaced by the L1/L2 main-matrix sweep that were blocking
100% pass.

1. Stage 3 wide-path silently emits no ``out_tok`` writes when
   ``warp_num_per_block == 16`` because the existing const_expr branch
   was gated on ``... and warp_num_per_block < 16`` with no else, so the
   wide-path block compiled to a no-op for wpb=16 and ``out_tok`` stayed
   zero (max_diff ~= max(|0 - k*inp|) ~= 44.75 on randn input). Add a
   step=64 fallback main loop covering ``[0, eff_end_128)`` for the
   ``wpb >= 16`` / ``n_i32 % 256 != 0`` case so bf16 + fp8_direct_cast
   both write correctly at wpb=16. The fp8_direct_cast wpb=16 case had
   been silently pseudo-passing because mori has no bf16+fp8cast oracle,
   so verify_self fell back to a NaN/Inf gate (DIAG showed
   |fly - k*inp|.max=43.5 -- same root cause).

2. ``_verify_dispatch_self_consistency`` early-returned on
   ``total_recv == 0`` for the local rank, but the ensuing
   ``_allgather_rows`` calls are collective ops -- surviving ranks then
   blocked on ``dist.all_gather`` until the 30-min NCCL watchdog timeout.
   L2 bs=1+k=1 hit this routinely. Now every rank reaches the collective
   in lock-step; only the byte-compare body is skipped when there is
   nothing to slice.

3. Stage 3 f32 path used scalar ``arith.bitcast(T.f32(), i32_val)``
   directly, but the auto-generated MLIR ``BitcastOp`` requires
   operand[0] to be a raw ``mlir.ir.Value`` and ``i32_val`` arrives
   wrapped in an ``fx.Int32`` Numeric shell. bf16 / fp4 / fp8 paths route
   through ``vector.*`` builders that auto-unwrap, so they never tripped
   it; f32 is the only scalar-bitcast path and silently failed JIT
   compile (``Operand 0 of operation "arith.bitcast" must be a Value``).
   Extract the raw mlir Value via ``.ir_value()`` before invoking the
   dialect builder, then re-wrap with ``fx.Float32`` / ``fx.Int32`` so
   downstream ``acc + ...`` and ``arith.select(...)`` keep working.

Also includes: ``soft_hot_pe`` routing for the CI sweep (used in place
of ``forced_hot_spot`` which livelocks ROCm IPC on the single PE-0
``shmem_tok_off`` slot), ``_apply_ci_case`` wiring for the external
JSON case-file harness, and ``known_block_size`` kernel metadata so
wpb=8/16 launches don't trip the default 256-thread guard.

All three fixes verified end-to-end: 5 wpb=16 cases (bf16 + fp8dc) +
bs=1/k=1 hang repro + 2 f32 dtype cases all PASS byte-equal vs mori,
plus a 5-case wpb=8 regression confirms the existing quad/dual-unroll
paths are unaffected.
Black auto-format split the two-line print("..." "...") introduced
in the previous commit 52c09f6 (Bug 2 fix) onto its own
print(\n ... \n) form to satisfy line-wrap rules. No behavior change.
Restores Check Python Code Style CI parity.
…stead of editing flydsl JIT core

The combine launcher's ``_key_data_type`` was the raw ``torch.dtype``
object whereas ``_collect_closure_scalar_vals`` in
``flydsl.compiler.jit_function`` only materializes
``(int, float, bool, str, type(None), tuple)`` into the cache key, so
distinct dtype variants of the combine kernel were silently sharing
one on-disk artifact and the second variant would trip
``hipErrorInvalidHandle`` at module load.

The dispatch launcher (line 1726) already addressed this with
``_key_data_type = str(data_type)``; mirror that on the combine
side (line 1883) so the kernel's cache key carries dtype identity
on its own, no flydsl/compiler changes required.

Reverts the +49-line ``_is_stable_hashable_atom`` /
``_stable_atom_repr`` extension to ``python/flydsl/compiler/jit_function.py``;
the fix now lives entirely in the kernel file.
@yanboshao yanboshao force-pushed the yanbo/dispatch_combine branch from 08b23e3 to 68f158e Compare May 26, 2026 04:15
yanboshao added 6 commits May 26, 2026 10:17
* Prune three CI cases that no longer represent kernel correctness:
  - bf16_zerocopy_fp8_direct_cast (mutually-exclusive options after
    the mori-parity zero-copy refactor; only ever ran as xfail)
  - bf16_recv_cap_min (1-slot-per-peer stress that triggers ROCm
    IPC fabric multi-second stalls; not a FlyDSL correctness bug)
  - bf16_forced_hot_spot (all-traffic-to-PE-0 atomic_add hot-spot
    livelock on ROCm; mori exhibits the same hang)
  The forced_hot_spot routing mode is kept behind --routing CLI for
  manual IPC-atomic regression hunting; cross-reference comments in
  build_mori_ref / soft_hot_pe / argparse help updated accordingly.
* Relax fp8 verify_self byte-equality to a value-equivalent fallback
  so +0/-0 fp8 sign-bit encoding differences don't flag as failures
  (the dispatched floats compare exactly equal).
* Slim intranode kernel/op surface based on CI-debug findings.

AST-verified: CI_CASES count 17 -> 14, no dangling references to the
three removed cases.
The mori-parity refactor renamed the P2P-read path to ``zero_copy``.
Update the CI workflow step comment to match so the description
stays in sync with the actual kernel/op API surface.
…ofilers + black squeeze

The mori-parity mixed-dtype CI cases (mixed_fp8_ocp_dispatch_bf16_combine_*,
mixed_fp4_dispatch_bf16_combine_zero_copy) dispatch with one dtype and
combine with another, so the mori kernel-name lookups in the cudagraph
profilers need separate suffixes for the dispatch and combine sides.
Plumb a ``combine_dtype_key`` (default = ``dtype_key``) through
``_allreduce_cudagraph_stats_from_key_averages``, ``_cudagraph_stats_from_trace``
and ``profile_cudagraph_op`` so each picks its own ``msuf_d`` / ``msuf_c``
and resolves the right ``EpDispatchIntraNodeKernel_*`` / ``EpCombineIntraNodeKernel_*``
labels for the mori comparison path.  Same-dtype callers pass nothing
and fall back to the original single-suffix behaviour.

Also a one-line black squeeze on the ``forced_hot_spot`` assert that
slipped through the earlier style pass.
yanboshao added 2 commits May 27, 2026 23:01
…ulti-gpu CI rename

Three fp8_ocp dispatch -> bf16 combine cases at M = 4 / 32 / 8192
covering the small / moderate / large batch regimes.  zero_copy is
picked at random per spawn via a new ``_random_fields`` case attribute
(resolved once in the parent before mp.spawn so all 8 ranks see the
same value), so both buffer modes get exercised at every M without
doubling the case count.

Also rename the multi-gpu CI job from "AllReduce Tests" to
"Communication Operator Tests" to match what it actually runs (shmem
regression + dispatch/combine sweep + allreduce), and drop a stale
comment block that still referenced 14 cases / the removed
recv_cap_min + forced_hot_spot paths.
…eep local)

The mixed-dtype perf-sweep driver, its best-per-shape post-processor
and the regenerated sweep report/jsonl are local working artifacts
that don't belong in the PR.  Untrack them via ``git rm --cached`` so
they stay on disk for offline analysis without flowing into git.

* dispatch_combine_mixed_dtype_sweep.md
* dispatch_combine_mixed_dtype_sweep.jsonl
* scripts/perf_sweep/append_best_per_shape.py
* scripts/perf_sweep/run_mixed_dtype_sweep.py
@yanboshao yanboshao force-pushed the yanbo/dispatch_combine branch from dbc7b5c to 886560d Compare May 28, 2026 05:11
Co-authored-by: Cursor <cursoragent@cursor.com>
@yanboshao yanboshao force-pushed the yanbo/dispatch_combine branch from a83fe1f to 3e6d2e3 Compare May 29, 2026 10:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants