[NNX] NNX migration prep (6/N): NNX-native DPO#3773
Draft
ecnal-cienet wants to merge 4 commits intomainfrom
Draft
[NNX] NNX migration prep (6/N): NNX-native DPO#3773ecnal-cienet wants to merge 4 commits intomainfrom
ecnal-cienet wants to merge 4 commits intomainfrom
Conversation
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…raining fixes Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities: - modify print_shardings_params to support NNX (maxtext_utils.py) - add --pure_nnx flag to run_sharding_dump.py - add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py) - add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py) Part 2 — post-training bug fixes: - models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields) - optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn) - train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises conflicting outer_index error); refactored to jax.value_and_grad + explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad
Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
call from ToLinen after nnx.update to fix "Cannot extract graph node from
different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
(was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
(drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
(nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
{"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
use a Python for-loop instead of jax.lax.scan when quantization is FP8.
Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return nnx.state(new_state, nnx.Not
(nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for
QK-Clip) don't break leaf-count parity with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer
_norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules — layers_outside
_pipeline, layers_per_stage, num_activations, circular_repeats. Additive,
no-op without use_nnx_pipeline=True.
NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
_save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
to Linen with a clear log warning since NNX-format checkpoints will fail
at restore time.
Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
Transformer class — wraps NNXDecoder.apply_output_head with the
token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
chunk. The NNX model carries its parameters internally so no explicit
FSDP gather is needed (unlike the Linen gathered_params pattern). MVP
uses default autograd; custom_vjp memory-savings optimization is a
follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
and enable_nnx validation guards (no longer needed).
DPO + NNX retained as NotImplementedError but with a much more informative
message (points users at pure_nnx=False workaround). Full implementation
is deferred — needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.
Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two
_slices x2). All "Pure NNX support has not been implemented yet"
NotImplementedError sites cleared (was 17, now 0).
Implements NNX-native DPO so that the pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs. The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits. TrainStateNNX: - Add optional `reference_model: nnx.Module` field. apply_gradients continues to update only `self.model`, leaving `self.reference_model` bit-identical across steps. dpo_utils.py: - Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True). Signature mirrors the Linen dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's dispatcher (dropout_rng / params slots are unused for NNX; carried for parity, and reference_model is passed as the single extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves; the explicit jax.lax.stop_gradient on ref_logits is a belt-and-braces guard. - Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation aux pytree shape matches the non-DPO loss_fn. train.py: - Drop the NotImplementedError in train_step's NNX branch. When use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the same dispatch. - diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block, so both the GA and non-GA NNX paths route DPO identically. - Checkpoint-save _split_dpo_state stripping is now Linen-only; TrainStateNNX saves whole (reference_model included) — the step-0 reload later overwrites reference_model from the step-0 checkpoint. train_utils.py: - NNX init_state_fn materializes a frozen reference_model alongside the policy when config.use_dpo. Both are constructed by _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload. - Step-0 checkpoint reload: copy step0_state["model"] into state["reference_model"]. Linen path unchanged. Tests: - New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model init/hasattr semantics; apply_gradients leaves reference bit-identical; aux key set; identical policy/reference yields loss=log(2) and reward_accuracy=0.0 (strict > on equal logratios); dropout_rng/params slots are signature-compat only; nnx.value_and_grad(argnums=0) over the policy yields finite grads on policy params only. - train_nnx_test.py: drop the two stale negative tests (vocab_tiling_raises_not_implemented, train_step_dpo_raises_for_nnx) — both features are now real. Stats: 4 source files + 2 test files, +199/-22 source lines. Linen DPO path behaviorally unchanged (only adds two harmless aux-dict keys); NNX non-DPO path unchanged (all changes gated on config.use_dpo).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Implements NNX-native DPO. The
pure_nnx=Truetraining path no longer raisesNotImplementedErroronuse_dporuns.The Linen DPO overlay pattern (
model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead, the policy and reference models are held as separatennx.Moduleinstances onTrainStateNNX, and a newdpo_loss_fn_nnxruns both forwards withstop_gradienton the reference logits.This is the next step in the NNX migration after
feat/nnx-correctness-fixes— it closes the remaining hardNotImplementedErroron the NNX path and unblocks default-flippingpure_nnx=Truefor users who run DPO.What's in this PR
TrainStateNNX(src/maxtext/layers/train_state_nnx.py)reference_model: nnx.Modulefield.apply_gradientscontinues to update onlyself.model, leavingself.reference_modelbit-identical across steps.dpo_utils.py(src/maxtext/trainers/post_train/dpo/dpo_utils.py)dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True)mirroring the Linendpo_loss_fnsignature so it slots intogradient_accumulation_loss_and_grad's dispatcher (dropout_rng/paramsslots are unused for NNX;reference_modelis the singleextra_dpo_argsentry).jax.lax.stop_gradient. Withnnx.value_and_grad(..., argnums=0)over the policy, no gradient flows to the reference model'snnx.Paramleaves.dpo_loss_fn(Linen) anddpo_loss_fn_nnx(NNX) now includeindexer_loss=0.0/mtp_loss=0.0in their aux dicts so the gradient-accumulation aux pytree shape matches the non-DPOloss_fn._split_dpo_state/_merge_dpo_stateoverlay pattern has no NNX counterpart by design — NNX holds the reference as a sibling field onTrainStateNNX, andapply_gradientsalready only touchesself.model, so no split/merge is needed.train.py(src/maxtext/trainers/pre_train/train.py)NotImplementedErrorremoved fromtrain_stepandeval_stepNNX branches; both dispatch todpo_loss_fn_nnxwhenuse_dpo, withstate.reference_modelpassed asextra_dpo_args[0].diff_wrapperpicks_loss_fn/extra_dpo_argsfrom the per-path init block so GA and non-GA NNX paths route DPO identically._split_dpo_statecheckpoint-save stripping is now Linen-only (NNX saves the wholeTrainStateNNXincludingreference_model; the step-0 reload later overwritesreference_model).train_utils.py(src/maxtext/utils/train_utils.py)init_state_fnmaterializes a frozen reference model alongside the policy whenconfig.use_dpo. Both are constructed via_create_model_partial()withconfig.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload.step0_state["model"]intostate["reference_model"]. Linen path unchanged.Tests
TrainStateNNX(reference_model=...)init /hasattrsemanticsapply_gradientsleaves the reference model bit-identicaldpo_loss_fn_nnxaux key setloss = log(2)andreward_accuracy = 0.0(strict>on equal logratios)dropout_rng/paramsslots are signature-compat only — passing arbitrary values does not change the resultnnx.value_and_grad(..., argnums=0)over the policy yields finite grads on policy params onlyvocab_tiling_raises_not_implemented— vocab tiling on NNX was implemented infeat/nnx-correctness-fixestrain_step_dpo_raises_for_nnx— DPO on NNX is implemented hereLinen invariant
Linen DPO is behaviorally unchanged: the only Linen-side change is that
dpo_loss_fnnow includesindexer_loss=0.0/mtp_loss=0.0inaux. Existing callers either ignore those keys or read them with.get(..., 0.0), so this is a no-op for downstream code. The NNX non-DPO path is unchanged: every new code path is gated onconfig.use_dpo.Stats
train_nnx_test,dpo_nnx_test(new),train_state_nnx_test,train_utils_nnx_test,gradient_accumulation_nnx_testFollowups (not in this PR)
gradient_accumulation_steps > 1is wired up via the same dispatcher (extra_dpo_argscarriesreference_modelandgradient_accumulation_loss_and_grad's NNX branch already callsgrad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True)againstdpo_loss_fn_nnx's signature) but is not exercised by an explicit GA-DPO test. Tracked as a followup.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.