Make MoE dispatch/MLP expert-axis batch sharding configurable (fix Mixtral EP throughput)#4179
Open
gulsumgudukbay wants to merge 5 commits into
i-chaochen
reviewed
Jun 17, 2026
i-chaochen
left a comment
Can we add a test to moe_test.py for activation_batch_no_exp and moe_dispatch_batch_axis? A purely logical check to make sure the batch dim doesn't steal the 'expert' axis.
Collaborator
Author
You beat me! I pushed it seconds after this comment.
795a115 to
0fa5152
Compare
Contributor
I tested across many models on a single MI355 node. The Mixtral regression is resolved without any impact on deepseek-v2-lite. I see a slight regression on llama2-7b, but I don't suspect this change is the cause.
The dispatch/MLP MoE activations are already expert-sharded via activation_exp. Since AI-Hypercomputer#4007, their batch dim also maps to activation_batch_moe, which includes 'expert'. Under single-node expert parallelism (ici_expert_parallelism=-1) this double-maps two tensor dims onto the 'expert' mesh axis, so GSPMD falls back from expert-parallel AllToAll to FSDP-style AllGather+ReduceScatter, regressing throughput for few-large-expert models (e.g. Mixtral-8x7b at bs=11 on 8x MI355X: ~7.4k tok/s/device with the fallback vs. ~10.9k with expert-parallel AllToAll).

Add a config flag moe_dispatch_no_expert_sharding (default false) that selects a new activation_batch_no_exp rule ([data, fsdp, fsdp_transpose], no 'expert') for the training dispatch/MLP batch axis. Enable it for mixtral-8x7b.

Default-false keeps every other model and all TPU/non-EP paths byte-identical; the flag only changes sharding when the 'expert' mesh axis size > 1.
…e-expert geometry as 8x7b, so it benefits from the same expert-parallel MoE dispatch/MLP sharding.
…oe_dispatch_no_expert_sharding the expert dim is sharded by 'expert' and the batch dim is not, guarding the expert-parallel dispatch/MLP sharding.
0fa5152 to
b6b511c
Compare
yeandy
reviewed
Jun 18, 2026
… of a new logical rule

Replace the activation_batch_no_exp logical rule with a remove_expert_from_partition_spec util (mirrors remove_fsdp_sharding), applied at the training dispatch/MLP sites when moe_dispatch_no_expert_sharding is set. This avoids a logical name that every custom_mesh_and_rule set would have to redefine.

Same result (verified on 8xMI355X): Mixtral stays expert-parallel (a2a=5, ~11k tok/s/device), DeepSeek unchanged (a2a=0, ~17.8k).
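To illustrate the idea behind that commit, here is a minimal sketch of stripping one mesh axis from a single dimension of a partition spec. It is not the PR's implementation: the function name drop_mesh_axis_from_dim, its signature, and the use of plain tuples in place of a real partition-spec object are all assumptions made for illustration.

```python
def drop_mesh_axis_from_dim(spec, dim, axis="expert"):
    """Return a copy of `spec` with `axis` removed from the entry at `dim`.

    `spec` is a plain tuple standing in for a partition spec (an assumption):
    each entry is None, a mesh-axis name, or a tuple of mesh-axis names.
    """
    entry = spec[dim]
    if entry is None:
        new_entry = None
    elif isinstance(entry, str):
        # A single axis name: drop it entirely if it matches.
        new_entry = None if entry == axis else entry
    else:
        # A tuple of axis names: keep every axis except the one removed.
        kept = tuple(a for a in entry if a != axis)
        new_entry = kept if kept else None
    return spec[:dim] + (new_entry,) + spec[dim + 1:]


# Hypothetical layout: dim 0 is the expert dim, dim 1 is the batch dim.
spec = ("expert", ("data", "fsdp", "fsdp_transpose", "expert"), None)
print(drop_mesh_axis_from_dim(spec, dim=1))
# ('expert', ('data', 'fsdp', 'fsdp_transpose'), None)
```

Operating on an already-resolved spec, rather than adding a logical name, is what lets custom rule sets stay untouched, as the commit message notes.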
ddc4b30 to
edb3286
Compare
NuojCheng
approved these changes
Jun 18, 2026
Description
PR #4007 added 'expert' to activation_batch_moe to fix a DeepSeek MoE throughput regression. That change is applied to the post-dispatch core MoE activations (dispatch_axis/mlp_axis) as well, where the tensor's expert dim E is already sharded via activation_exp.

For models with few large experts (e.g. Mixtral), mapping the batch dim onto 'expert' too double-maps two tensor dims onto one mesh axis, so GSPMD abandons the expert-local layout and falls back to FSDP-style AllGather+ReduceScatter instead of expert-parallel AllToAll, creating a large throughput regression under single-node expert parallelism (ici_expert_parallelism=-1).

This PR adds a config flag moe_dispatch_no_expert_sharding (default false) that selects, for the training dispatch/MLP batch axis only, a new activation_batch_no_exp rule (['data', 'fsdp', 'fsdp_transpose'], i.e. without 'expert'). Mixtral-8x7b sets it to true.

This is the per-model knob anticipated in the #4007 review discussion (sharding core MoE activations by the expert physical axes rather than the batch dimension), without changing the default for any other model.
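The double-mapping described above can be sketched with a toy resolver. This is only a simplified model under stated assumptions: the rule tables are cut down to the two logical names discussed here, the helper names resolve and double_mapped are hypothetical, mesh axes of size 1 are treated as not sharding anything, and the compiler's actual fallback behaviour is not modelled.

```python
# Simplified, hypothetical rule tables (logical name -> mesh axes).
RULES_WITH_EXPERT = {
    "activation_exp": ("expert",),
    "batch": ("data", "fsdp", "fsdp_transpose", "expert"),  # activation_batch_moe
}
RULES_NO_EXPERT = {
    "activation_exp": ("expert",),
    "batch": ("data", "fsdp", "fsdp_transpose"),  # activation_batch_no_exp
}


def resolve(logical_axes, rules, mesh_sizes):
    """Map each tensor dim to the mesh axes (size > 1) that shard it."""
    return [
        tuple(a for a in rules[name] if mesh_sizes.get(a, 1) > 1)
        for name in logical_axes
    ]


def double_mapped(resolved):
    """Return mesh axes claimed by more than one tensor dim."""
    seen, dup = set(), set()
    for axes in resolved:
        for a in axes:
            (dup if a in seen else seen).add(a)
    return dup


dims = ("activation_exp", "batch")
ep_mesh = {"expert": 8}  # assumed: all 8 devices on the 'expert' axis
fsdp_mesh = {"fsdp": 8}  # assumed: 'expert' axis has size 1

print(double_mapped(resolve(dims, RULES_WITH_EXPERT, ep_mesh)))  # {'expert'}
print(double_mapped(resolve(dims, RULES_NO_EXPERT, ep_mesh)))    # set()
# With 'expert' of size 1 both rule sets resolve identically.
print(resolve(dims, RULES_WITH_EXPERT, fsdp_mesh) == resolve(dims, RULES_NO_EXPERT, fsdp_mesh))  # True
```

In this toy model the conflict appears only when the 'expert' mesh axis is larger than 1, which matches the claim that the flag leaves other configurations unchanged.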
Behavior
- The default false is byte-identical to current main for every model.
- mixtral-8x7b opts in; no other config inherits it.
- For mixtral-8x7b, the flag only changes sharding when the expert mesh axis size > 1 (expert parallelism active). When expert is size 1 (TPU/FSDP-primary, or GPU without expert parallelism) the two axis rules are identical, so those paths are unaffected.

Scope (why only the training dispatch/MLP axes)
- mask_axes is intentionally left on activation_batch_moe: switching it too regresses Mixtral bs=11 (~10,900 -> ~8,400 tok/s/device, train-step all-to-all 5 -> 3).
- The inference dispatch_axis/mlp_axis are left unchanged: the inference path is already expert-parallel with activation_batch_moe (no FSDP fallback) at both prefill (bs=1: all-to-all=290, all-gather=0) and decode (bs=8: all-to-all=193, all-gather=0), so the regression does not occur there. The change is scoped to the training path.

Tests
Measured on 1x MI355X node (8 GPUs), JAX 0.9.1, ici_expert_parallelism=-1, capacity_factor>0: Mixtral recovers ~47% throughput via restored expert-parallel AllToAll; DeepSeek (flag off by default) is unchanged.
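As a quick arithmetic check, the ~47% figure follows from the approximate throughput numbers quoted in the commit message (~7.4k -> ~10.9k tok/s/device), and the same calculation gives the size of the mask_axes regression quoted in the Scope section. The helper name rel_change is hypothetical, and the inputs are the rounded figures from this page, not fresh measurements.

```python
def rel_change(before, after):
    """Relative change from `before` to `after`, as a fraction."""
    return (after - before) / before


# Mixtral-8x7b, bs=11: FSDP-style fallback vs. restored expert-parallel AllToAll.
print(f"{rel_change(7_400, 10_900):+.1%}")   # +47.3%

# Switching mask_axes as well: ~10,900 -> ~8,400 tok/s/device.
print(f"{rel_change(10_900, 8_400):+.1%}")   # -22.9%
```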
Added a unit regression test, test_moe_dispatch_keeps_expert_on_expert_dim (tests/unit/moe_test.py), parametrized over mixtral-8x7b, mixtral-8x22b, and deepseek3-671b: when moe_dispatch_no_expert_sharding is set, it asserts that the expert dim is sharded by the expert mesh axis and the batch dim is not (guarding the expert-parallel dispatch/MLP sharding). Passing locally.
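The property the regression test guards can be expressed as a purely logical check on a resolved spec. The sketch below is not the test from tests/unit/moe_test.py: the helper names, the plain-tuple spec representation, and the assumed layout (dim 0 = expert, dim 1 = batch) are all hypothetical.

```python
def mesh_axes_of(entry):
    """Normalize a spec entry (None, a name, or a tuple of names) to a tuple."""
    if entry is None:
        return ()
    if isinstance(entry, str):
        return (entry,)
    return tuple(entry)


def check_expert_on_expert_dim_only(spec, expert_dim=0, batch_dim=1, axis="expert"):
    """Assert `axis` shards the expert dim and does not shard the batch dim."""
    assert axis in mesh_axes_of(spec[expert_dim]), "expert dim lost the expert axis"
    assert axis not in mesh_axes_of(spec[batch_dim]), "batch dim stole the expert axis"


# Passes: 'expert' appears only on the expert dim.
check_expert_on_expert_dim_only(("expert", ("data", "fsdp", "fsdp_transpose"), None))

# Fails: the batch dim also maps onto 'expert'.
try:
    check_expert_on_expert_dim_only(("expert", ("data", "fsdp", "expert"), None))
except AssertionError as err:
    print(err)  # batch dim stole the expert axis
```

A check of this kind needs no devices, since it only inspects which mesh axes each tensor dimension names.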