Skip to content

Make MoE dispatch/MLP expert-axis batch sharding configurable (fix Mixtral EP throughput)#4179

Open
gulsumgudukbay wants to merge 5 commits into
AI-Hypercomputer:mainfrom
ROCm:fix-moe-expert-parallel-sharding
Open

Make MoE dispatch/MLP expert-axis batch sharding configurable (fix Mixtral EP throughput)#4179
gulsumgudukbay wants to merge 5 commits into
AI-Hypercomputer:mainfrom
ROCm:fix-moe-expert-parallel-sharding

Conversation

@gulsumgudukbay

@gulsumgudukbay gulsumgudukbay commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Description

PR #4007 added 'expert' to activation_batch_moe to fix a DeepSeek MoE throughput regression. That change is applied to the post-dispatch core MoE activations (dispatch_axis/mlp_axis) as well, where the tensor's expert dim E is already
sharded via activation_exp.
For models with few large experts (e.g. Mixtral), mapping the batch dim onto 'expert' too double-maps two tensor dims onto one mesh axis, so GSPMD abandons the expert-local layout and falls back to FSDP-style AllGather+ReduceScatter instead of expert-parallel AllToAll, creatiing a large throughput regression under single-node expert parallelism (ici_expert_parallelism=-1).
This PR adds a config flag moe_dispatch_no_expert_sharding (default false) that selects, for the training dispatch/MLP batch axis only, a new activation_batch_no_exp rule (['data','fsdp','fsdp_transpose'], i.e. without 'expert'). Mixtral-8x7b uses it as true.
This is the per-model knob anticipated in the #4007 review discussion (sharding core MoE activations by the expert physical axes rather than the batch dimension), without changing the default for any other model.

Behavior

  • Default false is byte-identical to current main for every model.
  • Only mixtral-8x7b opts in; no other config inherits it.
  • For mixtral-8x7b, the flag only changes sharding when the expert mesh axis size > 1 (expert parallelism active). When expert is size 1 (TPU/FSDP-primary, or non-expert parallelism GPU) the two axis rules are identical, so those paths are unaffected.

Scope (why only the training dispatch/MLP axes)

  • mask_axes is intentionally left on activation_batch_moe: switching it too regresses Mixtral bs=11 (~10,900 -> ~8,400 tok/s/device, train-step all-to-all 5 -> 3).
  • The inference-mode dispatch_axis/mlp_axis are left unchanged: the inference path is already expert-parallel with activation_batch_moe (no FSDP fallback) at both prefill (bs=1: all-to-all=290, all-gather=0) and decode (bs=8: all-to-all=193, all-gather=0), so the regression does not occur there. The change is scoped to the training path.

Tests

Measured on 1x MI355X node (8 GPUs), JAX 0.9.1, ici_expert_parallelism=-1, capacity_factor>0:

Model bs tok/s/device (before) tok/s/device (this PR) train-step all-to-all
Mixtral-8x7b 11 ~7,400 ~10,900 0 -> 5
Mixtral-8x7b 6 ~7,500 ~11,400 0 -> 5
DeepSeek-v2-lite-16b 8 ~17,800 ~17,800 (unchanged) 0 (unchanged)

Mixtral recovers ~47% throughput via restored expert-parallel AllToAll; DeepSeek (flag off by default) is unchanged.

Added a unit regression test test_moe_dispatch_keeps_expert_on_expert_dim (tests/unit/moe_test.py), parametrized over mixtral-8x7b, mixtral-8x22b, and deepseek3-671b: when moe_dispatch_no_expert_sharding is set it asserts the expert dim is sharded by the expert mesh axis and the batch dim is not (guarding the expert-parallel dispatch/MLP sharding). Passing locally.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 16, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 84.00000% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/utils/sharding.py 69.23% 2 Missing and 2 partials ⚠️

📢 Thoughts on this report? Let us know!

Comment thread src/maxtext/configs/models/mixtral-8x7b.yml
Comment thread src/maxtext/layers/moe.py Outdated

@i-chaochen i-chaochen left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add test into moe_test.py for activation_batch_no_exp and moe_dispatch_batch_axis ? a pure logical check to make sure batch doesn't steal expert

@gulsumgudukbay

Copy link
Copy Markdown
Collaborator Author

can we add test into moe_test.py for activation_batch_no_exp and moe_dispatch_batch_axis ? a pure logical check to make sure batch doesn't steal expert

you beat me! I just pushed it seconds after this comment

@gulsumgudukbay gulsumgudukbay force-pushed the fix-moe-expert-parallel-sharding branch from 795a115 to 0fa5152 Compare June 17, 2026 20:35
@yeandy

yeandy commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

I tested across many models on MI355 single node. Mixtral regression is resolved without any impact to deepseek-v2-lite.

Seeing slight regression to llama2-7b, not suspecting that this change is the cause.

The dispatch/MLP MoE activations are already expert-sharded via
activation_exp. Since AI-Hypercomputer#4007, their batch dim also maps to
activation_batch_moe, which includes 'expert'. Under single-node expert
parallelism (ici_expert_parallelism=-1) this double-maps two tensor dims
onto the 'expert' mesh axis, so GSPMD falls back from expert-parallel
AllToAll to FSDP-style AllGather+ReduceScatter, regressing throughput for
few-large-expert models (e.g. Mixtral-8x7b: ~7.4k -> ~10.9k tok/s/device
at bs=11 on 8x MI355X).

Add a config flag moe_dispatch_no_expert_sharding (default false) that
selects a new activation_batch_no_exp rule ([data, fsdp, fsdp_transpose],
no 'expert') for the training dispatch/MLP batch axis. Enable it for
mixtral-8x7b. Default-false keeps every other model and all TPU/non-EP
paths byte-identical; the flag only changes sharding when the 'expert'
mesh axis size > 1.
…e-expert geometry as 8x7b, so it benefits from the same expert-parallel MoE dispatch/MLP sharding.
…oe_dispatch_no_expert_sharding the expert dim is sharded by 'expert' and the batch dim is not, guarding the expert-parallel dispatch/MLP sharding.
@gulsumgudukbay gulsumgudukbay force-pushed the fix-moe-expert-parallel-sharding branch from 0fa5152 to b6b511c Compare June 18, 2026 14:22
Comment thread tests/unit/moe_test.py Outdated
… of a new logical rule

Replace the activation_batch_no_exp logical rule with a remove_expert_from_partition_spec util (mirrors remove_fsdp_sharding), applied at the training dispatch/MLP sites when moe_dispatch_no_expert_sharding is set. Avoids a logical name that every custom_mesh_and_rule set would have to redefine. Same result (verified on 8xMI355X): Mixtral stays expert-parallel (a2a=5, ~11k tok/s/device), DeepSeek unchanged (a2a=0, ~17.8k).
@gulsumgudukbay gulsumgudukbay force-pushed the fix-moe-expert-parallel-sharding branch from ddc4b30 to edb3286 Compare June 18, 2026 19:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants