
Fix Double Application of Softmax for Router Logits in MoE models#45346

Open
ionut-anghelina wants to merge 2 commits into huggingface:main from ionut-anghelina:dev/ionut/FixDoubleSoftmax

Conversation

@ionut-anghelina

No description provided.

ionut-anghelina and others added 2 commits March 30, 2026 08:18
Several MoE routers applied softmax to raw logits inside forward() but
returned the result as `router_logits`. The load_balancing_loss_func then
applied softmax again, computing the aux loss on softmax(softmax(logits))
which flattens the distribution toward uniform, rendering the load-balancing
loss ineffective.

Fix: use a separate `router_probs` variable for the softmaxed values used
in top-k routing, keeping `router_logits` as raw logits so the loss
function's single softmax is correct.
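The flattening effect described above can be illustrated with a small pure-Python sketch (hypothetical logit values, not the actual Transformers code):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw router logits for 4 experts.
logits = [4.0, 1.0, 0.5, -2.0]

once = softmax(logits)   # what load_balancing_loss_func should see
twice = softmax(once)    # what the buggy path effectively computed

print(once)   # sharply peaked on expert 0
print(twice)  # squashed toward uniform (all values near 1/4)
```

Because softmax outputs lie in [0, 1], applying it a second time compresses the input range to at most 1, so the resulting distribution is much closer to uniform and the aux loss barely penalizes imbalanced routing.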

Source modular files fixed:
- mixtral/modular_mixtral.py (MixtralTopKRouter)
- qwen2_moe/modular_qwen2_moe.py (Qwen2MoeTopKRouter)
- qwen3_vl_moe/modular_qwen3_vl_moe.py (Qwen3VLMoeTextTopKRouter)

Downstream models regenerated by make fix-repo:
mixtral, minimax, qwen2_moe, olmoe, flex_olmo, qwen3_moe, qwen3_next,
qwen3_omni_moe, qwen3_vl_moe, qwen3_5_moe

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add regression tests in mixtral and qwen2_moe to verify router_logits
  are raw logits (not softmax probabilities)
- Fix .to() dtype cast to use router_logits.dtype (model dtype) instead
  of router_probs.dtype (float32)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
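The regression check described in the commit boils down to a simple property: softmax outputs are non-negative and sum to 1 along the expert dimension, while raw logits almost never satisfy both. A minimal pure-Python sketch of such a check (names and values are hypothetical; the actual tests operate on torch tensors from the model output):

```python
import math

def looks_like_probabilities(row, tol=1e-4):
    # Heuristic behind the regression test: a softmaxed row is
    # non-negative and sums to ~1; raw logits generally are not.
    return all(v >= 0.0 for v in row) and abs(sum(row) - 1.0) < tol

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

raw_logits = [2.3, -0.7, 0.1, -1.9]  # hypothetical raw router output

# After the fix, returned router_logits should fail this check,
# while their softmax should pass it.
assert not looks_like_probabilities(raw_logits)
assert looks_like_probabilities(softmax(raw_logits))
```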
@github-actions

github-actions bot commented Apr 9, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: flex_olmo, minimax, mixtral, olmoe, qwen2_moe, qwen3_5_moe, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe

@@ -89,6 +89,14 @@ def test_load_balancing_loss(self):
self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts))
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)

# Verify router_logits are raw logits, not softmax probabilities (regression test for double-softmax bug)

IIRC, we have more appearances of that test in other models. It doesn't hurt to add it to all of them, and maybe make it a generalized one in the causal LM tester (because we now have ways to properly detect MoEs with the interface).

@vasqu

vasqu commented Apr 9, 2026

@Rocketknight1 I'm not sure about the current state here, so I just left a comment since this seemed like the most recent state of things. Let me know if not, or where I should properly look.

@vasqu

vasqu commented Apr 9, 2026

Let's also add the `fixes` and `closes` statements for the issue and the other PR, please.

