Fix Double Application of Softmax for Router Logits in MoE models #45346
ionut-anghelina wants to merge 3 commits into `huggingface:main`
Conversation
Several MoE routers applied softmax to raw logits inside `forward()` but returned the result as `router_logits`. `load_balancing_loss_func` then applied softmax again, computing the aux loss on `softmax(softmax(logits))`, which flattens the distribution toward uniform and renders the load-balancing loss ineffective.

Fix: use a separate `router_probs` variable for the softmaxed values used in top-k routing, keeping `router_logits` as raw logits so the loss function's single softmax is correct.

Source modular files fixed:
- `mixtral/modular_mixtral.py` (`MixtralTopKRouter`)
- `qwen2_moe/modular_qwen2_moe.py` (`Qwen2MoeTopKRouter`)
- `qwen3_vl_moe/modular_qwen3_vl_moe.py` (`Qwen3VLMoeTextTopKRouter`)

Downstream models regenerated by `make fix-repo`: mixtral, minimax, qwen2_moe, olmoe, flex_olmo, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe, qwen3_5_moe

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
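The before/after can be sketched in a few lines (a minimal pure-Python sketch, not the actual transformers router code; `route` and its signature are illustrative):

```python
import math

def softmax(xs):
    # Numerically stable softmax over one row of logits.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def route(logits, top_k):
    # Fixed behavior: keep the raw logits for the return value and use a
    # separate router_probs only for top-k expert selection. The buggy
    # version returned router_probs under the name router_logits, so the
    # aux-loss function softmaxed the values a second time.
    router_logits = logits                 # raw, returned unchanged
    router_probs = softmax(router_logits)  # used only for routing
    top_k_indices = sorted(
        range(len(router_probs)), key=router_probs.__getitem__, reverse=True
    )[:top_k]
    return router_logits, top_k_indices
```

For example, `route([2.0, 0.0, -1.0], 2)` returns the logits untouched along with expert indices `[0, 1]`.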
- Add regression tests in mixtral and qwen2_moe to verify `router_logits` are raw logits (not softmax probabilities)
- Fix the `.to()` dtype cast to use `router_logits.dtype` (the model dtype) instead of `router_probs.dtype` (float32)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```diff
@@ -89,6 +89,14 @@ def test_load_balancing_loss(self):
     self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts))
     torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)

+    # Verify router_logits are raw logits, not softmax probabilities (regression test for double-softmax bug)
```
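The assertion behind this comment can be approximated with a simple heuristic (a sketch; `looks_like_probabilities` is a hypothetical name, not the code in the PR): softmax output is non-negative and sums to 1 per row, which raw logits virtually never do:

```python
def looks_like_probabilities(row, tol=1e-4):
    # Softmax output: every entry non-negative and the row summing to ~1.
    # Raw router logits typically contain negatives and an arbitrary sum.
    return all(v >= 0.0 for v in row) and abs(sum(row) - 1.0) < tol
```

A regression test can then assert `not looks_like_probabilities(router_logits_row)` to catch a reintroduced double softmax.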
IIRC, we have more occurrences of that test in other models. It doesn't hurt to add the check to all of them, and maybe to generalize it in the causal LM tester (we now have ways to properly detect MoEs through the interface).

@Rocketknight1 I'm not sure about the current state here, so I left a comment on what seemed to be the most recent version of things. Let me know if that's wrong, or where I should be looking.

Let's also add the `Fixes`/`Closes` statements for the issue and the other PR, please.

OK, #45131 was merged instead; we can still use this PR for the tests though. @Rocketknight1 @ionut-anghelina

[For maintainers] Suggested jobs to run (before merge): run-slow: mixtral, qwen2_moe

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.