[WIP] Support Experts TP#1837

Open
jayhenry wants to merge 31 commits into InternLM:main from jayhenry:ep_tp
Conversation

@jayhenry

Collaborator

No description provided.

jayhenry and others added 30 commits April 28, 2026 06:56
…alidation

- Add torch_all2all_tpep dispatcher and wire it in dispatcher __init__
- Add megatron_tp_ep.md and validate_xtuner_tpep_md script with shell runner
- Apply ruff formatting to validation script; fix mypy (ctx Any, combine_preprocess signature)

Made-with: Cursor
…ngth all-to-all operations and clarify the impact on computation overlap for domino ep.
Move TP communication into dispatch/combine, share sync and async collective cores, and use real TP reduce-scatter semantics. Update tests, docs, pseudocode, and validation snapshots for the new flow.
Co-authored-by: Cursor <cursoragent@cursor.com>
The Triton 3.4 pipeliner refuses to predicate `ttng.tensormap_create`, so the
original per-tile descriptors with dynamic `group_end` / `(group + 1) * N`
shapes crashed (`PassManager::run failed`) whenever the autotuner picked a
config that triggered outer-loop pipelining — observed at expert_tp=4 where
per-rank N=384 steers the autotuner there.

Restructure the kernel so the outer tile loop can pipeline normally:

* A and B descriptors get loop-invariant static shapes (`[M, K]` and
  `[B_ROWS, K]` / `[B_ROWS, N]`) and are hoisted out of the loop. OOB reads
  past `group_end` pull from the *next* group's tokens / weights, but the
  contaminated output rows / columns are filtered by the masked C store
  below, so correctness is preserved.
* C TMA store is replaced with a masked `tl.store`. C was the only remaining
  per-tile `tensormap_create`; without removing it the pipeliner still
  rejects the loop. Losing TMA store on a single BLOCK_M x BLOCK_N tile is
  cheap relative to losing outer-loop pipelining.
* Extend autotune with BLOCK_N=128 / BLOCK_K=128 configs. The existing
  {64,256} pair leaves N=384 and K=192 (typical at expert_tp=4) with no
  cleanly-tiling option.
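
The argument that over-reading past ``group_end`` is harmless as long as the C store is masked can be checked with a small pure-Python model. This is only a sketch of the idea, not the Triton kernel: ``grouped_gemm_masked``, ``reference`` and the tiny tile sizes are made up for illustration.

```python
def grouped_gemm_masked(a, b, group_sizes, n, block_m=4, block_n=4):
    """Toy grouped GEMM: C[row] = A[row] @ B[g]^T for the group g owning row.

    a: [M][K] tokens sorted by group; b: [G * n][K] stacked per-group weights.
    """
    m_total, k = len(a), len(a[0])
    b_rows = len(b)
    c = [[0] * n for _ in range(m_total)]
    group_start = 0
    for g, size in enumerate(group_sizes):
        group_end = group_start + size
        for tile_m in range(group_start, group_end, block_m):
            for tile_n in range(0, n, block_n):
                for i in range(block_m):
                    row = tile_m + i
                    for j in range(block_n):
                        col = tile_n + j
                        w_row = g * n + col
                        # Static-shape loads: anything inside [M, K] and
                        # [B_ROWS, K] is read, even when it belongs to the
                        # next group's tokens or weights.
                        if row >= m_total or w_row >= b_rows:
                            continue
                        acc = sum(a[row][kk] * b[w_row][kk] for kk in range(k))
                        # Masked store: contaminated rows/columns are dropped.
                        if row < group_end and col < n:
                            c[row][col] = acc
        group_start = group_end
    return c


def reference(a, b, group_sizes, n):
    """Per-row reference without any tiling."""
    out, row = [], 0
    for g, size in enumerate(group_sizes):
        for _ in range(size):
            out.append([sum(x * w for x, w in zip(a[row], b[g * n + col]))
                        for col in range(n)])
            row += 1
    return out
```

With group sizes that are not multiples of ``block_m`` and an ``n`` that is not a multiple of ``block_n``, the tiled result should still equal the reference, because every value computed from a neighbouring group's data is discarded at the store.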

Microbenchmark (M_per_group=2048, ep=2/tp=4 shapes):
* w1w3 (N=384, K=2048): 401 -> 504 TFLOPS, +26%
* w2   (N=2048, K=192): 194 -> 261 TFLOPS, +35%

Real-training wall-clock gain is smaller (~2-3% tgs) because grouped_gemm is
~25% of the compute-stream critical path and load imbalance washes out the
uniform-M benchmark gain.
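
As a rough sanity check on that gap, an Amdahl-style estimate gives an upper bound on the end-to-end gain. The helper name ``end_to_end_gain`` is hypothetical, and the sketch assumes the 25% share applies uniformly and that both GEMMs speed up by the benchmarked ratio.

```python
def end_to_end_gain(fraction, kernel_speedup):
    """Upper bound on throughput gain when only `fraction` of the
    critical path is accelerated by `kernel_speedup`."""
    new_time = (1.0 - fraction) + fraction / kernel_speedup
    return 1.0 / new_time - 1.0


# TFLOPS figures from the microbenchmark above; 0.25 is the stated
# grouped_gemm share of the compute-stream critical path.
for name, before, after in [("w1w3", 401, 504), ("w2", 194, 261)]:
    gain = end_to_end_gain(0.25, after / before)
    print(f"{name}: kernel x{after / before:.2f} -> at most {gain:.1%} end to end")
```

Under these assumptions the bound lands at roughly 5-7%, so a measured ~2-3% is plausible once load imbalance erodes the uniform-M kernel speedup.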
…lumbing

Two correctness bugs and a tactical cleanup, bundled because they all live in
``DeepEPDispatcher`` and any one of them in isolation leaves the file in a
broken state.

1. Forward NaN under ``intra_layer_micro_batch>1`` with virtual expert TP.

   ``dispatch_preprocess`` previously called ``buffer_capture()`` *before*
   ``_expand_topk_ids_for_tp`` / ``topk_weights.repeat_interleave``. The
   captured event therefore did not cover the expand kernels, and DeepEP's
   ``stream_wait(comm_stream, previous_event)`` only synchronized to a point
   before those kernels' writes. At ``intra=1`` the expand always finished
   before DeepEP could start, so the race was invisible; at ``intra=2`` mb1's
   dispatch could enter the comm stream before mb1's expand had retired,
   reading stale ``topk_ids`` / ``topk_weights`` and producing NaN.

   Move the expand into ``dispatch_preprocess`` (so it runs on Loop A's
   compute stream and overlaps the next microbatch's attention/gate) and
   capture the event after it. ``dispatch`` becomes a thin DeepEP launcher
   that consumes ``pre_dispatched["topk_weights"]``.

2. Backward grad_norm NaN — symmetric race.

   The existing ``hidden_states.grad_fn.register_prehook`` only covered
   the ``combined_grad_x`` path. ``combined_grad_recv_topk_weights`` flowed
   back through ``repeat_interleave_backward`` on the compute stream while
   DeepEP's dispatch backward wrote that gradient on the comm stream. No
   wait was inserted between the two, so the backward read stale memory.
   Register a matching prehook on the expanded ``topk_weights.grad_fn``
   referencing the same shared ``backward_previous_event``.

3. Drop the unused ``ExpertTP`` plumbing.

   ``ExpertTP``-tracking fields (``num_recv_tokens_per_expert_group``,
   ``tp_rank_row_counts``, ``hidden_backward_finished_event``,
   ``topk_weights_backward_previous_event``,
   ``topk_weights_backward_finished_event``,
   ``tp_backward_finished_event``) were carried in every result TypedDict and
   threaded through the autograd Function signatures but never consumed
   after virtual TP encoding subsumed their role. Removing them simplifies
   the data flow that (1) and (2) operate on.
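
The ordering problem behind (1) and (2) can be illustrated with a toy model of stream events. It is deliberately not the DeepEP or CUDA API: ``ToyStream`` and ``waiter_sees`` are invented names, and the only property modelled is that an event covers just the work enqueued before it was recorded.

```python
class ToyStream:
    """Ordered queue of kernel names; stands in for a CUDA stream."""

    def __init__(self):
        self.launched = []

    def launch(self, kernel):
        self.launched.append(kernel)

    def record_event(self):
        # An event only covers work enqueued before it was recorded.
        return tuple(self.launched)


def waiter_sees(event, kernel):
    """True if a stream waiting on `event` is ordered after `kernel`."""
    return kernel in event


# Old order: capture the event, then launch the expand kernels.
compute = ToyStream()
compute.launch("gate")
buggy_event = compute.record_event()
compute.launch("expand_topk")

# Fixed order: expand first, capture afterwards.
compute = ToyStream()
compute.launch("gate")
compute.launch("expand_topk")
fixed_event = compute.record_event()

print(waiter_sees(buggy_event, "expand_topk"))  # comm stream may race the expand
print(waiter_sees(fixed_event, "expand_topk"))  # comm stream waits for the expand
```

In this model the comm stream that waits on the early event has no ordering against the expand, which matches the forward race; the backward fix applies the same reasoning to the gradient written on the comm stream.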

Base abstract ``dispatch_preprocess`` gains a ``topk_weights`` parameter; the
non-DeepEP implementations (Naive, AGRS, TorchAll2All) accept and ignore it.
Both callers in ``moe_decoder_layer`` pass ``router_results["topk_weights"]``.

Verified at ep=2/tp=4 + intra_layer_micro_batch=2: step-1 loss matches the
intra=1 baseline bit-exact (2.4467), grad_norm decays normally (46→11 over
4 steps).
…mn-parallel weights

Introduce a ``DTensor`` placement that fits XTuner's per-expert column-parallel
MoE weights, plus the checkpoint integration needed to keep DCP snapshots
working alongside it.

``InterleavedShard(dim, num_local_stripes)`` (xtuner/v1/utils/interleaved_shard.py)
splits a tensor dim into ``num_local_stripes`` stripes and then splits *each*
stripe column-parallel across the mesh dim. A fused
``[local_experts * num_fused_projections * out, in]`` weight therefore ends up
with every TP rank holding the same relative slice of every (expert,
projection) stripe, which is exactly what column-parallel ``fused_w1w3``
requires. The placement deliberately reports a shard layout that PyTorch's
``redistribute`` / ``full_tensor`` cannot reverse (``shard_order=None`` on
torch>=2.10), so the module also provides:

  * ``has_interleaved_placement(dt)`` — guarded against the missing
    ``DTensorSpec.shard_order`` attribute on torch<2.10 via ``getattr``.
  * ``compute_runs(...)`` — the deterministic (global_row_start, local_row_start,
    length) plan callers use to copy HF safetensor slices into local tensors
    without going through DTensor.
  * ``reconstruct_full_tensor(dt)`` — an all-gather-based rebuild that bypasses
    ``redistribute``, including the post-``fully_shard`` 3D layout where FSDP
    prepends ``_StridedShard`` on top of the (Shard, InterleavedShard) pair.
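
To make the stripe layout concrete, here is a minimal sketch of the kind of copy plan described above. It is not the real ``compute_runs`` signature: ``interleaved_runs`` and its parameters are hypothetical, rows are indexed within the EP-local fused weight, and even divisibility is assumed.

```python
def interleaved_runs(dim_size, num_local_stripes, tp_size, tp_rank):
    """Return (source_row_start, local_row_start, length) runs for one TP rank.

    Each of the `num_local_stripes` stripes is cut into `tp_size` equal
    chunks; rank `tp_rank` keeps the chunk at the same position in every
    stripe and packs the chunks back to back in its local tensor.
    """
    if dim_size % num_local_stripes:
        raise ValueError("dim_size must be divisible by num_local_stripes")
    stripe_len = dim_size // num_local_stripes
    if stripe_len % tp_size:
        raise ValueError("stripe length must be divisible by tp_size")
    chunk = stripe_len // tp_size
    return [
        (stripe * stripe_len + tp_rank * chunk, stripe * chunk, chunk)
        for stripe in range(num_local_stripes)
    ]


# 2 local experts x 2 fused projections x out=4 rows, split over tp=2.
for rank in range(2):
    print(rank, interleaved_runs(16, num_local_stripes=4, tp_size=2, tp_rank=rank))
```

Each run says "copy ``length`` rows starting at ``source_row_start`` into the local tensor at ``local_row_start``", which is enough to fill a local shard from a full weight without going through DTensor.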

DCP integration (xtuner/v1/engine/train_engine.py) drops InterleavedShard
parameters and their optimizer state from snapshots before save/load. DCP's
planner cannot represent ``_StridedShard(split_factor != 1)`` — without
filtering it raises during plan generation. The dropped parameters are already
covered by HF safetensors written separately, so resume reloads them via
``from_hf`` after DCP restores the rest of the state.

Test coverage (tests/utils/test_interleaved_shard.py):

  * world=4: plain ``(Shard, InterleavedShard)`` on a 2D (ep, tp) mesh —
    layout matches a hand-computed per-expert column-parallel split and
    ``reconstruct_full_tensor`` round-trips bit-exact.
  * world=8: the post-``fully_shard`` 3D layout — forward result matches the
    unsharded reference and ``reconstruct_full_tensor`` still round-trips.

Both pass on torch 2.9 (py312-pt29 env).
… with HF save/load

End-to-end wiring of the InterleavedShard placement (from the previous commit)
through XTuner's MoE stack so that ``GroupedLinear`` column-parallel weights
are real DTensors at every step of the train / save / load loop.

Mesh plumbing
=============
* MoE model carries an ``(ep, tp)`` 2D sub-mesh and threads it through
  ``MoEDecoderLayer`` -> ``MoEBlock`` -> ``GroupedLinear`` constructors.
* ``module/dispatcher/__init__.py`` selects ``ep_tp_group`` (flattened
  ``ep × tp``) when ``tp_size > 1`` so DeepEP sees the full virtual-expert
  group; otherwise it falls back to ``ep_group`` for the legacy path.

GroupedLinear weight layout
===========================
When ``ep_tp_mesh`` is supplied and ``tp_size > 1``:
  * column-parallel: ``DTensor.from_local(local, mesh, (Shard(0),
    InterleavedShard(0, num_local_stripes=local_experts*num_fused_projections)))``
    — ``num_fused_projections=2`` for ``fused_w1w3`` so InterleavedShard cuts
    inside each (expert, projection) stripe, otherwise the two TP ranks would
    swap halves of gate_proj vs up_proj and silently corrupt
    ``silu(gate) * up``.
  * row-parallel: ``(Shard(0), Shard(1))`` — TP cuts in_features, a different
    tensor dim from EP, no shard_order conflict.
  * ``from_local`` not ``distribute_tensor``: the latter goes through
    ``redistribute`` which crashes on the ``(Shard, InterleavedShard)`` pair
    (``shard_order is None``).
  * ``ep_tp_mesh`` absent or ``tp_size == 1`` keeps the legacy plain-tensor
    path — no behavior change for non-TP MoE configs.
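
The gate/up corruption mentioned for the column-parallel case can be reproduced on row labels alone. This sketch assumes a single expert whose fused weight stacks all ``gate_proj`` rows before all ``up_proj`` rows; ``rows_for_rank`` is an illustrative helper, not XTuner code.

```python
def rows_for_rank(rows, num_stripes, tp_size, tp_rank):
    """Pick the rows one TP rank keeps when each stripe is split tp_size ways."""
    stripe_len = len(rows) // num_stripes
    chunk = stripe_len // tp_size
    picked = []
    for stripe in range(num_stripes):
        start = stripe * stripe_len + tp_rank * chunk
        picked.extend(rows[start:start + chunk])
    return picked


out = 4
# Assumed fused layout for one expert: gate rows first, then up rows.
fused = [f"gate{i}" for i in range(out)] + [f"up{i}" for i in range(out)]

# Striping per expert only (num_fused_projections ignored).
per_expert = [rows_for_rank(fused, 1, 2, r) for r in range(2)]
# Striping per (expert, projection), i.e. num_fused_projections=2.
per_projection = [rows_for_rank(fused, 2, 2, r) for r in range(2)]

print(per_expert)
print(per_projection)
```

With one stripe per expert, rank 0 ends up with only gate rows and rank 1 with only up rows, so neither rank can form ``silu(gate) * up`` for its output columns. With one stripe per (expert, projection), each rank holds matching gate and up rows.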

HF I/O (load_spec, base, init_weight)
=====================================
* ``LoadSpec.build_save_plan`` detects InterleavedShard DTensors and marks
  them ``needs_full_reconstruct=True`` with empty ``shards``; the save planner
  then calls ``reconstruct_full_tensor`` on rank 0 instead of trying to walk
  per-shard offsets that the placement does not define.
* ``BaseModel.from_hf`` uses ``compute_runs(...)`` to copy HF safetensor
  slices into local InterleavedShard tensors row-by-row, avoiding
  ``distribute_tensor`` / ``redistribute``.
* ``init_weight`` bypasses ``full_tensor`` / ``distribute_tensor`` for
  InterleavedShard parameters — initializes the local shard in place with the
  appropriate per-rank seed so init is consistent across the global tensor.

CI configs + design docs
========================
* ``ci/config/qwen3_moe_30BA3_tp.py`` — ep=2 / tp=4 / intra_layer_micro_batch=2
  end-to-end smoke config.
* ``ci/config/qwen3_dense_8B_tp.py`` — dense TP path.
* ``ci/config/qwen3_moe_30BA3_ep8_il2.py`` — non-TP regression baseline
  (ep=8, no TP) to catch regressions in the legacy path.
* ``docs/design/dense_tp.md`` and ``docs/design/load_spec_refactor_progress.md``
  — design notes for the TP layout and the load-spec refactor that supports
  ``needs_full_reconstruct``.

Validation
==========
* InterleavedShard unit tests (added in the previous commit) cover
  ``reconstruct_full_tensor`` round-trip bit-exact, which is what HF save
  relies on.
* Training validated end-to-end at ep=2/tp=4 + intra=2: loss=2.4467 at step 1
  (bit-exact match to the intra=1 baseline), grad_norm decays normally
  (46→11 over 4 steps), so HF load correctly populates the InterleavedShard
  weights and forward/backward are numerically intact.

2 participants