Fix Double Application of Softmax for Router Logits in MoE models#45346
Open
ionut-anghelina wants to merge 2 commits into huggingface:main from
Conversation
Several MoE routers applied softmax to raw logits inside `forward()` but returned the result as `router_logits`. `load_balancing_loss_func` then applied softmax again, computing the aux loss on `softmax(softmax(logits))`, which flattens the distribution toward uniform and renders the load-balancing loss ineffective.

Fix: use a separate `router_probs` variable for the softmaxed values used in top-k routing, keeping `router_logits` as raw logits so the loss function's single softmax is correct.

Source modular files fixed:
- mixtral/modular_mixtral.py (MixtralTopKRouter)
- qwen2_moe/modular_qwen2_moe.py (Qwen2MoeTopKRouter)
- qwen3_vl_moe/modular_qwen3_vl_moe.py (Qwen3VLMoeTextTopKRouter)

Downstream models regenerated by `make fix-repo`: mixtral, minimax, qwen2_moe, olmoe, flex_olmo, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe, qwen3_5_moe

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
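The pattern described in the commit message might look like the following sketch. Class and attribute names (`gate`, `top_k`, the exact return tuple) are illustrative and not copied from the actual modular files:

```python
import torch
import torch.nn.functional as F


class TopKRouterSketch(torch.nn.Module):
    """Illustrative top-k router following the fix described above."""

    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = torch.nn.Linear(hidden_size, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, hidden_states):
        # Raw logits: returned unchanged so the aux loss applies softmax once.
        router_logits = self.gate(hidden_states)
        # Softmaxed values live in a separate variable and only feed routing.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        top_k_weights, top_k_index = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        # Cast back using router_logits.dtype (the model dtype), not
        # router_probs.dtype (float32) -- matching the second commit.
        top_k_weights = top_k_weights.to(router_logits.dtype)
        return top_k_index, top_k_weights, router_logits
```

The key point is that `router_probs` never escapes the router, so the single softmax inside `load_balancing_loss_func` sees genuine logits.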
- Add regression tests in mixtral and qwen2_moe to verify `router_logits` are raw logits (not softmax probabilities)
- Fix `.to()` dtype cast to use `router_logits.dtype` (model dtype) instead of `router_probs.dtype` (float32)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
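The idea behind such a regression check can be sketched independently of the actual test files; the shapes and values below are made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# If router_logits were secretly softmax outputs, every row would be
# nonnegative and sum to 1; raw gate outputs satisfy neither in general.
router_logits = torch.randn(91, 8)  # (num_tokens, num_experts), illustrative shape
row_sums = router_logits.sum(dim=-1)
looks_like_probs = bool((router_logits >= 0).all()) and torch.allclose(
    row_sums, torch.ones_like(row_sums), atol=1e-4
)
assert not looks_like_probs, "router_logits appear to be softmax probabilities"

# The flattening the bug caused, on one made-up logit vector:
logits = torch.tensor([2.0, 1.0, 0.0, -1.0])
probs = F.softmax(logits, dim=-1)   # what the aux loss should see
double = F.softmax(probs, dim=-1)   # buggy double application
print(probs.max().item(), double.max().item())  # peak shrinks: ~0.64 vs ~0.36
```

The squashed peak illustrates why the aux loss became ineffective: a near-uniform input gives the load-balancing term almost nothing to penalize.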
Contributor
[For maintainers] Suggested jobs to run (before merge): run-slow: flex_olmo, minimax, mixtral, olmoe, qwen2_moe, qwen3_5_moe, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe
vasqu reviewed Apr 9, 2026
@@ -89,6 +89,14 @@ def test_load_balancing_loss(self):
     self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts))
     torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)

+    # Verify router_logits are raw logits, not softmax probabilities (regression test for double-softmax bug)
Contributor
Iirc, we have more appearances of that test in other models. It doesn't hurt to add it to all of them, and maybe make it a generalized one in the causal LM tester (because we now have ways to properly detect MoEs with the interface).
Contributor
@Rocketknight1 I'm not sure about the current state here, so I just left a comment since this seemed to be the most recent state of things. Lmk if not, or where I should properly look.
Contributor
Let's also add the "Fixes" and "Closes" statements for the issue and the other PR, please.