[Enhance] use fp32 acc for liger fused ce loss #1747

Merged
HAOCHENYE merged 1 commit into InternLM:main from nil0x9:linty/fix-liger-acc-dtype
May 1, 2026

Conversation

@nil0x9 (Collaborator) commented Apr 30, 2026

This PR replaces the default accumulator dtype of the Liger fused CE loss with fp32. There is a known issue in the fused CE loss implementation: in mixed-precision training, when the gradient accumulator is kept in low precision, training becomes unstable in the late training stage.

The official liger-kernel implementation introduced an extra argument, accum_dtype, to LigerFusedLinearCrossEntropyLoss, probably in light of this issue, but the default behavior is still to follow the input dtype:
[screenshot: liger-kernel docstring showing accum_dtype defaulting to the input dtype]

Therefore, when we do not specify accum_dtype in LigerFusedLinearCrossEntropyLoss, the accumulator follows the dtype of the weight, which in most cases is bfloat16. This is likely to cause training instability.
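
For context, here is a minimal sketch of what opting into the fp32 accumulator looks like at the call site. This is illustrative, not the actual patch in this PR; it assumes a liger-kernel release that exposes the accum_dtype argument mentioned above, and the tensor shapes, vocab size, and device below are placeholders.

```python
# Sketch only: pin the fused linear CE loss accumulator to fp32.
# Assumes a liger-kernel version exposing `accum_dtype`; the forward
# signature (lin_weight, input, target) may vary across releases.
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

loss_fn = LigerFusedLinearCrossEntropyLoss(accum_dtype=torch.float32)

# In bf16 mixed-precision training, the lm_head weight and hidden states are
# typically bfloat16; without accum_dtype the accumulator would follow that
# dtype. Placeholder shapes: (tokens, hidden_dim) and (vocab, hidden_dim).
hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
lm_head_weight = torch.randn(32000, 4096, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, 32000, (8,), device="cuda")

# The loss is fused with the lm_head projection; only the accumulation
# is promoted to fp32.
loss = loss_fn(lm_head_weight, hidden, labels)
```

The idea is that only the accumulation is promoted to fp32, so the memory benefit of the fused kernel is largely retained while avoiding the low-precision accumulation error described above.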

See also:

HAOCHENYE merged commit 31dca88 into InternLM:main on May 1, 2026
7 checks passed
Harold-lkk added a commit that referenced this pull request May 2, 2026
* [Fix] Fix torch2.10 EP>1 init load spec error. (#1688)

[Feature] Add wrapper for compute_local_shape_and_global_offset to handle meta tensor errors

* 【CI】change dir on npu (#1693)

change dir

* [Chore] clean commented debug code

* Propagate dataflow cleanup cancellation to rollout Ray tasks (#1699)

* 【CI】update resume cases (#1687)

* add new validation on resume cases

* fix f-string error

* update priority

* add qwen3.5 case about 8nums vs 16nums

* install tilelang

* limit version

* optimizer code

* Pin flash-linear-attention to 0.4.2 (#1710)

* upgrade deps to pytorch 2.9.1 and transformers 5.2.0 (#1596)

* chore(build): update deps transformers -> 5.2.0

* chore(build): update cudnn to 9.15.1.9 of torch 2.9.1 in dockerfile

* chore(ci): use wider tolerance in test_qwen3_5 sp case

* refactor: add RopeParametersConfig due to transformers 5.2.0 bc

* chore(build): conditional path for lmdeploy and sglang in Dockerfile

* chore(build): update dockerfile for deepep, deep_gemm and ci proxy speed fix

* fix(ci): ep>1 clip_grad_norm fails due to pt2.9 check

* fix(ci): clean hf dynamic modules before test setup

* chore(docker): update lmdeploy deps

* refactor: Move compile config from FSDPConfig to model_cfg

* fix(engine): use field rope_parameters when save_hf

* [Feature] MTP RL with KL Loss (#1727)

* mtp kl loss in rl

* support chunk kl

* resolve comments

* Complete MTP RL support wiring

* fix lint

* 【CI】fix format (#1726)

fix format

* [Adapt] add DominoEP for InternS1 Pro VL (#1720)

Co-authored-by: wentiange <tiangewen@qq.com>

* fix: follow up for add dominoep for VL model (#1732)

Co-authored-by: wentiange <tiangewen@qq.com>

* [Bug]: Fix batch reward (#1738)

* fix batch reward

* fix lint

* [Enhance] use fp32 acc for liger fused ce loss (#1747)

---------

Co-authored-by: RangiLyu <lyuchqi@gmail.com>
Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com>
Co-authored-by: nil0x9 <nil.0x9@proton.me>
Co-authored-by: duanyanhui <45005871+YanhuiDua@users.noreply.github.com>
Co-authored-by: CyCle1024 <ccy_justin@163.com>
Co-authored-by: tina-wen <61722970+tina-wen@users.noreply.github.com>
Co-authored-by: wentiange <tiangewen@qq.com>