[WIP] Support Experts TP #1837
Open
jayhenry wants to merge 31 commits into
…alidation
- Add torch_all2all_tpep dispatcher and wire it in dispatcher __init__
- Add megatron_tp_ep.md and validate_xtuner_tpep_md script with shell runner
- Apply ruff formatting to validation script; fix mypy (ctx Any, combine_preprocess signature)
Made-with: Cursor
…on tolerance for tests
…ngth all-to-all operations and clarify the impact on computation overlap for domino ep.
Move TP communication into dispatch/combine, share sync and async collective cores, and use real TP reduce-scatter semantics. Update tests, docs, pseudocode, and validation snapshots for the new flow.
Co-authored-by: Cursor <cursoragent@cursor.com>
The Triton 3.4 pipeliner refuses to predicate `ttng.tensormap_create`, so the
original per-tile descriptors with dynamic `group_end` / `(group + 1) * N`
shapes crashed (`PassManager::run failed`) whenever the autotuner picked a
config that triggered outer-loop pipelining — observed at expert_tp=4 where
per-rank N=384 steers the autotuner there.
Restructure the kernel so the outer tile loop can pipeline normally:
* A and B descriptors get loop-invariant static shapes (`[M, K]` and
`[B_ROWS, K]` / `[B_ROWS, N]`) and are hoisted out of the loop. OOB reads
past `group_end` pull from the *next* group's tokens / weights, but the
contaminated output rows / columns are filtered by the masked C store
below, so correctness is preserved.
* C TMA store is replaced with a masked `tl.store`. C was the only remaining
per-tile `tensormap_create`; without removing it the pipeliner still
rejects the loop. Losing TMA store on a single BLOCK_M x BLOCK_N tile is
cheap relative to losing outer-loop pipelining.
* Extend autotune with BLOCK_N=128 / BLOCK_K=128 configs. The existing
{64,256} pair leaves N=384 and K=192 (typical at expert_tp=4) with no
cleanly-tiling option.
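The claim that out-of-bounds reads are harmless once the C store is masked can be checked with a toy model. The sketch below is plain Python, not Triton, and every name in it (`grouped_matmul_masked`, `block_m`, the sample data) is illustrative rather than taken from the kernel. Each tile always loads `block_m` rows, even when that runs past its group's end into the next group's tokens, multiplies them by its own group's weight, and stores only the rows that belong to the group.

```python
def rows_times_weight(rows, w):
    """Multiply each row ([k]) by the weight matrix w ([k][n])."""
    n = len(w[0])
    return [[sum(row[k] * w[k][j] for k in range(len(row))) for j in range(n)]
            for row in rows]


def grouped_matmul_masked(a, weights, group_sizes, block_m):
    """Toy grouped GEMM: static-shape tile loads plus a masked store."""
    m_total = len(a)
    c = [None] * m_total
    start = 0
    for g, size in enumerate(group_sizes):
        end = start + size
        for tile_start in range(start, end, block_m):
            # Static-shape load: clipped only by the full [M, K] bounds,
            # so it may pull rows that belong to the next group.
            tile_rows = list(range(tile_start, min(tile_start + block_m, m_total)))
            tile_out = rows_times_weight([a[r] for r in tile_rows], weights[g])
            for r, out_row in zip(tile_rows, tile_out):
                if r < end:  # masked store: drop contaminated rows
                    c[r] = out_row
        start = end
    return c


def grouped_matmul_reference(a, weights, group_sizes):
    """Reference: multiply each group's rows by that group's weight only."""
    out, start = [], 0
    for g, size in enumerate(group_sizes):
        out.extend(rows_times_weight(a[start:start + size], weights[g]))
        start += size
    return out


a = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
weights = [[[1, 0], [0, 1]], [[2, 1], [0, 3]]]
group_sizes = [3, 2]  # group 0's second tile reads row 3, which is group 1's
masked = grouped_matmul_masked(a, weights, group_sizes, block_m=2)
reference = grouped_matmul_reference(a, weights, group_sizes)
print(masked == reference)
```

Row 3 is computed twice here, once with the wrong weight (discarded by the mask) and once by its own group's tile, which is the redundancy the commit trades for loop-invariant descriptors.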
Microbenchmark (M_per_group=2048, ep=2/tp=4 shapes):
* w1w3 (N=384, K=2048): 401 -> 504 TFLOPS, +26%
* w2 (N=2048, K=192): 194 -> 261 TFLOPS, +35%
Real-training wall-clock gain is smaller (~2-3% tgs) because grouped_gemm is
~25% of the compute-stream critical path and load imbalance washes out the
uniform-M benchmark gain.
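As a rough cross-check of these numbers, Amdahl's law gives an upper bound on the end-to-end gain when only the grouped GEMM share of the critical path speeds up. The sketch below assumes the ~25% share quoted above applies uniformly and ignores load imbalance; the helper name `end_to_end_speedup` is illustrative.

```python
def end_to_end_speedup(fraction, kernel_speedup):
    """Amdahl's law: overall speedup when `fraction` of the runtime
    becomes `kernel_speedup` times faster and the rest is unchanged."""
    return 1.0 / ((1.0 - fraction) + fraction / kernel_speedup)


GEMM_FRACTION = 0.25  # grouped_gemm share of the compute-stream critical path

w1w3_speedup = 504 / 401  # microbenchmark TFLOPS ratio for w1w3
w2_speedup = 261 / 194    # microbenchmark TFLOPS ratio for w2

bound_w1w3 = end_to_end_speedup(GEMM_FRACTION, w1w3_speedup)
bound_w2 = end_to_end_speedup(GEMM_FRACTION, w2_speedup)

print(f"w1w3-only bound: +{(bound_w1w3 - 1) * 100:.1f}%")
print(f"w2-only bound:   +{(bound_w2 - 1) * 100:.1f}%")
```

Under these assumptions the bound lands at roughly 5 to 7%, so the observed ~2-3% is plausible once imbalanced per-expert M is taken into account.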
…lumbing
Two correctness bugs and a tactical cleanup, bundled because they all live in ``DeepEPDispatcher`` and any one of them in isolation leaves the file in a broken state.
1. Forward NaN under ``intra_layer_micro_batch>1`` with virtual expert TP.
``dispatch_preprocess`` previously called ``buffer_capture()`` *before* ``_expand_topk_ids_for_tp`` / ``topk_weights.repeat_interleave``. The captured event therefore did not cover the expand kernels, and DeepEP's ``stream_wait(comm_stream, previous_event)`` only synchronized to a point before those kernels' writes. At ``intra=1`` the expand always finished before DeepEP could start, so the race was invisible; at ``intra=2`` mb1's dispatch could enter the comm stream before mb1's expand had retired, reading stale ``topk_ids`` / ``topk_weights`` and producing NaN.
Move the expand into ``dispatch_preprocess`` (so it runs on Loop A's compute stream and overlaps the next microbatch's attention/gate) and capture the event after it. ``dispatch`` becomes a thin DeepEP launcher that consumes ``pre_dispatched["topk_weights"]``.
2. Backward grad_norm NaN: the symmetric race.
The existing ``hidden_states.grad_fn.register_prehook`` only covered the ``combined_grad_x`` path. ``combined_grad_recv_topk_weights`` flowed back through ``repeat_interleave_backward`` on the compute stream while DeepEP's dispatch backward wrote that gradient on the comm stream, with no wait inserted, so the output read stale memory. Register a matching prehook on the expanded ``topk_weights.grad_fn`` referencing the same shared ``backward_previous_event``.
3. Drop the unused ``ExpertTP`` plumbing.
``ExpertTP``-tracking fields (``num_recv_tokens_per_expert_group``, ``tp_rank_row_counts``, ``hidden_backward_finished_event``, ``topk_weights_backward_previous_event``, ``topk_weights_backward_finished_event``, ``tp_backward_finished_event``) were carried in every result TypedDict and threaded through the autograd Function signatures but never consumed after virtual TP encoding subsumed their role. Removing them simplifies the data flow that (1) and (2) operate on.
Base abstract ``dispatch_preprocess`` gains a ``topk_weights`` parameter; the non-DeepEP implementations (Naive, AGRS, TorchAll2All) accept and ignore it. Both callers in ``moe_decoder_layer`` pass ``router_results["topk_weights"]``.
Verified at ep=2/tp=4 + intra_layer_micro_batch=2: step-1 loss matches the intra=1 baseline bit-exact (2.4467), grad_norm decays normally (46→11 over 4 steps).
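The forward race in (1) comes down to which kernels an event covers at the moment it is captured. The following is a toy model only: it uses no CUDA or DeepEP calls, and `ToyStream`, `record_event`, and the kernel names are hypothetical stand-ins. It assumes an event covers exactly the kernels enqueued on the stream before the capture, so anything launched afterwards is unordered with respect to a stream that waits on that event.

```python
class ToyStream:
    """Ordered queue of kernel names standing in for a compute stream."""

    def __init__(self):
        self.enqueued = []

    def launch(self, name):
        self.enqueued.append(name)

    def record_event(self):
        # The event covers only what was enqueued before this call.
        return frozenset(self.enqueued)


def kernels_covered(steps):
    """Replay `steps` on a compute stream and return what the captured
    event guarantees to a comm stream that waits on it."""
    compute = ToyStream()
    event = frozenset()
    for step in steps:
        if step == "capture":
            event = compute.record_event()
        else:
            compute.launch(step)
    return event


# Before the fix: capture happens first, the expand kernels run after it.
before_fix = kernels_covered(["gate", "capture", "expand_topk"])
# After the fix: expand runs in preprocess, then the event is captured.
after_fix = kernels_covered(["gate", "expand_topk", "capture"])

print("expand covered before fix:", "expand_topk" in before_fix)
print("expand covered after fix: ", "expand_topk" in after_fix)
```

In the "before" ordering the comm stream may start as soon as `gate` has retired, which is the window in which stale ``topk_ids`` / ``topk_weights`` could be read.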
…mn-parallel weights
Introduce a ``DTensor`` placement that fits XTuner's per-expert column-parallel
MoE weights, plus the checkpoint integration needed to keep DCP snapshots
working alongside it.
``InterleavedShard(dim, num_local_stripes)`` (xtuner/v1/utils/interleaved_shard.py)
splits a tensor dim into ``num_local_stripes`` stripes and column-parallel
splits *each* stripe across the mesh dim, so a fused
``[local_experts * num_fused_projections * out, in]`` weight ends up with every
TP rank holding the same half of every (expert, projection) — exactly what
column-parallel ``fused_w1w3`` requires. The placement deliberately reports a
shard layout that PyTorch's ``redistribute`` / ``full_tensor`` cannot reverse
(``shard_order=None`` on torch>=2.10), so the module also provides:
* ``has_interleaved_placement(dt)`` — guarded against the missing
``DTensorSpec.shard_order`` attribute on torch<2.10 via ``getattr``.
* ``compute_runs(...)`` — the deterministic (global_row_start, local_row_start,
length) plan callers use to copy HF safetensor slices into local tensors
without going through DTensor.
* ``reconstruct_full_tensor(dt)`` — an all-gather-based rebuild that bypasses
``redistribute``, including the post-``fully_shard`` 3D layout where FSDP
prepends ``_StridedShard`` on top of the (Shard, InterleavedShard) pair.
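To make the stripe layout concrete, here is a minimal sketch of the kind of (global_row_start, local_row_start, length) plan described above. It is not the module's ``compute_runs``; the function name `interleaved_runs` and its signature are assumptions, and it only models the TP split of rows that are already local to one EP rank, with sizes that divide evenly.

```python
def interleaved_runs(dim_size, num_local_stripes, tp_size, tp_rank):
    """Return (global_row_start, local_row_start, length) runs for one TP rank.

    The dim is cut into `num_local_stripes` equal stripes, and each stripe
    is split column-parallel style into `tp_size` contiguous chunks.
    """
    if dim_size % num_local_stripes != 0:
        raise ValueError("dim_size must be divisible by num_local_stripes")
    stripe = dim_size // num_local_stripes
    if stripe % tp_size != 0:
        raise ValueError("stripe length must be divisible by tp_size")
    chunk = stripe // tp_size
    return [
        (s * stripe + tp_rank * chunk, s * chunk, chunk)
        for s in range(num_local_stripes)
    ]


# 2 local experts x 2 fused projections x out=4 rows, split over tp=2.
runs_rank0 = interleaved_runs(16, num_local_stripes=4, tp_size=2, tp_rank=0)
runs_rank1 = interleaved_runs(16, num_local_stripes=4, tp_size=2, tp_rank=1)
print(runs_rank0)
print(runs_rank1)
```

Each rank ends up with one chunk from every (expert, projection) stripe, and the runs of all ranks together tile the dim without gaps or overlap.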
DCP integration (xtuner/v1/engine/train_engine.py) drops InterleavedShard
parameters and their optimizer state from snapshots before save/load. DCP's
planner cannot represent ``_StridedShard(split_factor != 1)`` — without
filtering it raises during plan generation. The dropped parameters are already
covered by HF safetensors written separately, so resume reloads them via
``from_hf`` after DCP restores the rest of the state.
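The save/resume split can be pictured as a simple partition of the state. The sketch below uses plain dicts and a hypothetical predicate in place of real DTensors and ``has_interleaved_placement``; `split_state_for_dcp` and the parameter names are made up for illustration and are not the code in ``train_engine.py``.

```python
def split_state_for_dcp(state, is_interleaved):
    """Partition a name -> value mapping into the part DCP can plan
    and the names that must be reloaded from HF safetensors instead."""
    kept = {name: value for name, value in state.items()
            if not is_interleaved(value)}
    reload_from_hf = sorted(name for name, value in state.items()
                            if is_interleaved(value))
    return kept, reload_from_hf


# Stand-in values: a dict with a `placement` tag instead of a DTensor.
model_state = {
    "layers.0.attn.q_proj.weight": {"placement": "Shard"},
    "layers.0.experts.fused_w1w3.weight": {"placement": "InterleavedShard"},
    "layers.0.experts.w2.weight": {"placement": "Shard"},
}


def tagged_interleaved(value):
    return value["placement"] == "InterleavedShard"


dcp_state, hf_names = split_state_for_dcp(model_state, tagged_interleaved)
print(sorted(dcp_state))
print(hf_names)
```

The same partition would be applied to the optimizer state keyed by the same parameter names, so that neither side of the snapshot references a layout the planner cannot represent.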
Test coverage (tests/utils/test_interleaved_shard.py):
* world=4: plain ``(Shard, InterleavedShard)`` on a 2D (ep, tp) mesh —
layout matches a hand-computed per-expert column-parallel split and
``reconstruct_full_tensor`` round-trips bit-exact.
* world=8: the post-``fully_shard`` 3D layout — forward result matches the
unsharded reference and ``reconstruct_full_tensor`` still round-trips.
Both pass on torch 2.9 (py312-pt29 env).
… with HF save/load
End-to-end wiring of the InterleavedShard placement (from the previous commit)
through XTuner's MoE stack so that ``GroupedLinear`` column-parallel weights
are real DTensors at every step of the train / save / load loop.
Mesh plumbing
=============
* MoE model carries an ``(ep, tp)`` 2D sub-mesh and threads it through
``MoEDecoderLayer`` -> ``MoEBlock`` -> ``GroupedLinear`` constructors.
* ``module/dispatcher/__init__.py`` selects ``ep_tp_group`` (flattened
``ep × tp``) when ``tp_size > 1`` so DeepEP sees the full virtual-expert
group; otherwise it falls back to ``ep_group`` for the legacy path.
GroupedLinear weight layout
===========================
When ``ep_tp_mesh`` is supplied and ``tp_size > 1``:
* column-parallel: ``DTensor.from_local(local, mesh, (Shard(0),
InterleavedShard(0, num_local_stripes=local_experts*num_fused_projections)))``
— ``num_fused_projections=2`` for ``fused_w1w3`` so InterleavedShard cuts
inside each (expert, projection) stripe, otherwise the two TP ranks would
swap halves of gate_proj vs up_proj and silently corrupt
``silu(gate) * up``.
* row-parallel: ``(Shard(0), Shard(1))`` — TP cuts in_features, a different
tensor dim from EP, no shard_order conflict.
* ``from_local`` not ``distribute_tensor``: the latter goes through
``redistribute`` which crashes on the ``(Shard, InterleavedShard)`` pair
(``shard_order is None``).
* ``ep_tp_mesh`` absent or ``tp_size == 1`` keeps the legacy plain-tensor
path — no behavior change for non-TP MoE configs.
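The gate/up corruption mentioned for the column-parallel case can be reproduced on a single expert's fused output. This is a stdlib-only sketch that works on output features rather than weight rows (a row split of the weight selects the same output features); the helper names and sample values are illustrative, and the fused order ``[gate_proj; up_proj]`` follows the layout described above.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def swiglu(gate, up):
    return [silu(g) * u for g, u in zip(gate, up)]


def plain_shard(fused, tp_size, rank):
    """Contiguous split of the whole fused dim (what plain Shard(0) would do)."""
    chunk = len(fused) // tp_size
    return fused[rank * chunk:(rank + 1) * chunk]


def interleaved_shard(fused, num_stripes, tp_size, rank):
    """Split inside each stripe, then concatenate this rank's chunks."""
    stripe = len(fused) // num_stripes
    chunk = stripe // tp_size
    local = []
    for s in range(num_stripes):
        lo = s * stripe + rank * chunk
        local.extend(fused[lo:lo + chunk])
    return local


def local_swiglu(local):
    """Each rank treats the first half of its shard as gate, the rest as up."""
    half = len(local) // 2
    return swiglu(local[:half], local[half:])


gate = [0.5, -1.0, 2.0, 3.0]
up = [1.0, 2.0, -0.5, 4.0]
fused = gate + up  # one expert, num_fused_projections = 2
reference = swiglu(gate, up)

interleaved_out = [y for r in range(2)
                   for y in local_swiglu(interleaved_shard(fused, 2, 2, r))]
plain_out = [y for r in range(2)
             for y in local_swiglu(plain_shard(fused, 2, r))]

print("interleaved matches reference:", interleaved_out == reference)
print("plain shard matches reference:", plain_out == reference)
```

With the plain split, rank 0 holds only gate features and rank 1 only up features, so each rank's local ``silu(gate) * up`` pairs the wrong halves without raising any error.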
HF I/O (load_spec, base, init_weight)
=====================================
* ``LoadSpec.build_save_plan`` detects InterleavedShard DTensors and marks
them ``needs_full_reconstruct=True`` with empty ``shards``; the save planner
then calls ``reconstruct_full_tensor`` on rank 0 instead of trying to walk
per-shard offsets that the placement does not define.
* ``BaseModel.from_hf`` uses ``compute_runs(...)`` to copy HF safetensor
slices into local InterleavedShard tensors row-by-row, avoiding
``distribute_tensor`` / ``redistribute``.
* ``init_weight`` bypasses ``full_tensor`` / ``distribute_tensor`` for
InterleavedShard parameters — initializes the local shard in place with the
appropriate per-rank seed so init is consistent across the global tensor.
CI configs + design docs
========================
* ``ci/config/qwen3_moe_30BA3_tp.py`` — ep=2 / tp=4 / intra_layer_micro_batch=2
end-to-end smoke config.
* ``ci/config/qwen3_dense_8B_tp.py`` — dense TP path.
* ``ci/config/qwen3_moe_30BA3_ep8_il2.py`` — non-TP regression baseline
(ep=8, no TP) to catch regressions in the legacy path.
* ``docs/design/dense_tp.md`` and ``docs/design/load_spec_refactor_progress.md``
— design notes for the TP layout and the load-spec refactor that supports
``needs_full_reconstruct``.
Validation
==========
* InterleavedShard unit tests (added in the previous commit) cover
``reconstruct_full_tensor`` round-trip bit-exact, which is what HF save
relies on.
* Training validated end-to-end at ep=2/tp=4 + intra=2: loss=2.4467 at step 1
(bit-exact match to the intra=1 baseline), grad_norm decays normally
(46→11 over 4 steps), so HF load correctly populates the InterleavedShard
weights and forward/backward are numerically intact.
No description provided.