[Enhance] use fp32 acc for liger fused ce loss #1747
Merged
HAOCHENYE merged 1 commit into InternLM:main on May 1, 2026
Conversation
HAOCHENYE approved these changes on May 1, 2026
Harold-lkk added a commit that referenced this pull request on May 2, 2026
* [Fix] Fix torch2.10 EP>1 init load spec error. (#1688) [Feature] Add wrapper for compute_local_shape_and_global_offset to handle meta tensor errors
* 【CI】change dir on npu (#1693) change dir
* [Chore] clean commented debug code
* Propagate dataflow cleanup cancellation to rollout Ray tasks (#1699)
* 【CI】update resume cases (#1687)
* add new validation on resume cases
* fix f-string error
* update priority
* add qwen3.5 case about 8nums vs 16nums
* install tilelang
* limit version
* optimizer code
* Pin flash-linear-attention to 0.4.2 (#1710)
* upgrade deps to pytorch 2.9.1 and transformers 5.2.0 (#1596)
* chore(build): update deps transformers -> 5.2.0
* chore(build): update cudnn to 9.15.1.9 of torch 2.9.1 in dockerfile
* chore(ci): use wider tolerance in test_qwen3_5 sp case
* refactor: add RopeParametersConfig due to transformers 5.2.0 bc
* chore(build): conditional path for lmdeploy and sglang in Dockerfile
* chore(build): update dockerfile for deepep, deep_gemm and ci proxy speed fix
* fix(ci): ep>1 clip_grad_norm fails due to pt2.9 check
* fix(ci): clean hf dynamic modules before test setup
* chore(docker): update lmdeploy deps
* refactor: Move compile config from FSDPConfig to model_cfg
* fix(engine): use field rope_parameters when save_hf
* [Feature] MTP RL with KL Loss (#1727)
* mtp kl loss in rl
* support chunk kl
* resolve comments
* Complete MTP RL support wiring
* fix lint
* 【CI】fix format (#1726) fix format
* [Adapt] add DominoEP for InternS1 Pro VL (#1720) Co-authored-by: wentiange <tiangewen@qq.com>
* fix: follow up for add dominoep for VL model (#1732) Co-authored-by: wentiange <tiangewen@qq.com>
* [Bug]: Fix batch reward (#1738)
* fix batch reward
* fix lint
* [Enhance] use fp32 acc for liger fused ce loss (#1747)

---------

Co-authored-by: RangiLyu <lyuchqi@gmail.com>
Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com>
Co-authored-by: nil0x9 <nil.0x9@proton.me>
Co-authored-by: duanyanhui <45005871+YanhuiDua@users.noreply.github.com>
Co-authored-by: CyCle1024 <ccy_justin@163.com>
Co-authored-by: tina-wen <61722970+tina-wen@users.noreply.github.com>
Co-authored-by: wentiange <tiangewen@qq.com>
This PR replaces the default Liger CE loss accumulator dtype with fp32. There is a known issue in the fused CE loss implementation: in mixed-precision training, when the gradient accumulator is kept in low precision, training becomes unstable in the late training stage.
The official liger-kernel implementation introduced an extra arg `accum_dtype` to `LigerFusedLinearCrossEntropyLoss`, probably in light of this issue, but kept the default behavior of following the input dtype. Therefore, when we don't designate `accum_dtype` in `LigerFusedLinearCrossEntropyLoss`, the accumulator follows the dtype of the weight, which in most cases is `bfloat16`. This is likely to cause training instability.

See also:

- `accum_dtype` option for `FusedLinearCrossEntropy` linkedin/Liger-Kernel#830
- `LigerFusedLinearCrossEntropyLoss` Causes Training Loss to Diverge After Reaching ~8 linkedin/Liger-Kernel#512
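Below is a minimal sketch of opting into the fp32 accumulator, assuming the installed liger-kernel version exposes the `accum_dtype` argument described above; the shapes and variable names are illustrative, not taken from this repo.

```python
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Keep the CE loss accumulator in fp32 even though weights and
# activations stay in bf16 during mixed-precision training.
loss_fn = LigerFusedLinearCrossEntropyLoss(accum_dtype=torch.float32)

# Illustrative shapes: (num_tokens, hidden) activations and a
# (vocab, hidden) lm_head weight, both in bf16.
hidden_states = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
lm_head_weight = torch.randn(
    32000, 4096, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
labels = torch.randint(0, 32000, (8,), device="cuda")

# The fused loss takes the lm_head weight and hidden states directly,
# so the full (num_tokens, vocab) logits tensor is never materialized.
loss = loss_fn(lm_head_weight, hidden_states, labels)
loss.backward()  # gradients accumulated with fp32 intermediates
```

Without `accum_dtype=torch.float32`, the accumulator falls back to the weight dtype (typically `bfloat16`), which is exactly the unstable configuration this PR avoids.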