[NNX] NNX migration prep (6/N): NNX-native DPO #3773

Draft
ecnal-cienet wants to merge 4 commits into main from feat/nnx-native-dpo

@ecnal-cienet
Collaborator

Summary

Implements NNX-native DPO. The pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs.

The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead, the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits.
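
As a rough illustration of the two-forward pattern described above, the sketch below computes a DPO-style loss with `jax.lax.stop_gradient` on the reference outputs. It is not the code in this PR: `toy_logprob`, the `beta` parameter, and the dict-shaped parameters are hypothetical stand-ins for the real model forwards, and gradients are requested for both argument slots only to show that the reference side receives none.

```python
import jax
import jax.numpy as jnp


def toy_logprob(params, features):
  # Hypothetical stand-in for a model forward: one log-probability per sequence.
  return features @ params["w"]


def dpo_loss(policy_params, reference_params, chosen, rejected, beta=0.1):
  policy_logratio = toy_logprob(policy_params, chosen) - toy_logprob(policy_params, rejected)
  # The reference forwards are treated as constants by autodiff.
  ref_chosen = jax.lax.stop_gradient(toy_logprob(reference_params, chosen))
  ref_rejected = jax.lax.stop_gradient(toy_logprob(reference_params, rejected))
  ref_logratio = ref_chosen - ref_rejected
  return -jnp.mean(jax.nn.log_sigmoid(beta * (policy_logratio - ref_logratio)))


policy = {"w": jnp.array([0.5, -0.2, 0.1])}
reference = {"w": jnp.array([0.4, 0.0, 0.3])}
chosen = jnp.array([[1.0, 0.0, 2.0], [0.5, 1.0, 0.0]])
rejected = jnp.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5]])

# Differentiate w.r.t. both pytrees to make the effect of stop_gradient visible.
loss, (policy_grad, ref_grad) = jax.value_and_grad(dpo_loss, argnums=(0, 1))(
    policy, reference, chosen, rejected
)
```

Here `ref_grad["w"]` comes back as all zeros while `policy_grad["w"]` does not. In the PR itself only the policy is differentiated (`argnums=0`), so the `stop_gradient` acts as an extra guard rather than the sole mechanism.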

This is the next step in the NNX migration after feat/nnx-correctness-fixes: it closes the remaining hard NotImplementedError on the NNX path and unblocks flipping the default to pure_nnx=True for users who run DPO.

What's in this PR

TrainStateNNX (src/maxtext/layers/train_state_nnx.py)

  • Optional reference_model: nnx.Module field. apply_gradients continues to update only self.model, leaving self.reference_model bit-identical across steps.
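
The invariant above (only the policy is updated, the reference is carried along untouched) can be sketched without Flax as follows. `ToyTrainState` is a hypothetical stand-in, not the real TrainStateNNX: it uses plain dict pytrees, a fixed SGD step, and returns a new state instead of updating in place.

```python
import dataclasses
from typing import Optional

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class ToyTrainState:
  """Hypothetical stand-in for TrainStateNNX using plain dict pytrees."""

  model: dict
  reference_model: Optional[dict] = None
  step: int = 0

  def apply_gradients(self, grads, learning_rate=0.1):
    # Plain SGD on the policy only; reference_model is passed through unchanged.
    new_model = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, self.model, grads
    )
    return dataclasses.replace(self, model=new_model, step=self.step + 1)


state = ToyTrainState(
    model={"w": jnp.ones(3)},
    reference_model={"w": jnp.ones(3)},
)
new_state = state.apply_gradients({"w": jnp.ones(3)})
```

After the update, `new_state.reference_model` is the same object as `state.reference_model`, while `new_state.model["w"]` has moved by one SGD step.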

dpo_utils.py (src/maxtext/trainers/post_train/dpo/dpo_utils.py)

  • New dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True) mirroring the Linen dpo_loss_fn signature so it slots into gradient_accumulation_loss_and_grad's dispatcher (dropout_rng / params slots are unused for NNX; reference_model is the single extra_dpo_args entry).
  • Reference forward wrapped in jax.lax.stop_gradient. With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves.
  • Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 / mtp_loss=0.0 in their aux dicts so the gradient-accumulation aux pytree shape matches the non-DPO loss_fn.
  • The Linen _split_dpo_state / _merge_dpo_state overlay pattern has no NNX counterpart by design — NNX holds the reference as a sibling field on TrainStateNNX, and apply_gradients already only touches self.model, so no split/merge is needed.
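
To illustrate why the two zero-valued aux entries matter, the sketch below compares pytree structures of two aux dicts. It is only an assumption-laden illustration: `pad_aux` is a hypothetical helper, `total_loss` is a made-up key, and the real aux dicts may carry additional entries.

```python
import jax


def pad_aux(aux, required_keys=("indexer_loss", "mtp_loss")):
  """Return a copy of aux with any missing required key filled with 0.0."""
  return {**{k: 0.0 for k in required_keys}, **aux}


# Hypothetical aux dicts; only indexer_loss and mtp_loss come from the PR text.
non_dpo_aux = {"total_loss": 1.2, "indexer_loss": 0.0, "mtp_loss": 0.0}
dpo_aux = {"total_loss": 0.7}

before = jax.tree_util.tree_structure(dpo_aux)
after = jax.tree_util.tree_structure(pad_aux(dpo_aux))
target = jax.tree_util.tree_structure(non_dpo_aux)
```

`before` differs from `target`, while `after` equals it. Code that accumulates aux values across microbatches generally needs a fixed pytree structure, which is the mismatch the added keys avoid.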

train.py (src/maxtext/trainers/pre_train/train.py)

  • NotImplementedError removed from train_step and eval_step NNX branches; both dispatch to dpo_loss_fn_nnx when use_dpo, with state.reference_model passed as extra_dpo_args[0].
  • diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block so GA and non-GA NNX paths route DPO identically.
  • _split_dpo_state checkpoint-save stripping is now Linen-only (NNX saves the whole TrainStateNNX including reference_model; the step-0 reload later overwrites reference_model).
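
The dispatch described in this section can be sketched roughly as below. The function bodies are placeholders that only report which path ran; `select_loss` and `run_loss` are hypothetical names, and the non-DPO `loss_fn` signature is an assumption inferred from the shared call shape, not copied from train.py.

```python
from types import SimpleNamespace


def loss_fn(model, config, data, dropout_rng, params, is_train=True):
  # Placeholder for the regular loss; signature assumed for illustration.
  return {"path": "loss_fn", "reference": None}


def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params,
                    reference_model, is_train=True):
  # Placeholder for the DPO loss; dropout_rng / params are unused on NNX.
  return {"path": "dpo_loss_fn_nnx", "reference": reference_model}


def select_loss(config, state):
  """Pick the loss function and its extra positional args once per path."""
  if config.use_dpo:
    return dpo_loss_fn_nnx, (state.reference_model,)
  return loss_fn, ()


def run_loss(config, state, data, is_train=True):
  _loss_fn, extra_dpo_args = select_loss(config, state)
  # Same call shape for both losses: the DPO args are spliced in positionally.
  return _loss_fn(state.model, config, data, None, None, *extra_dpo_args,
                  is_train=is_train)


state = SimpleNamespace(model="policy", reference_model="reference")
dpo_result = run_loss(SimpleNamespace(use_dpo=True), state, data=None)
plain_result = run_loss(SimpleNamespace(use_dpo=False), state, data=None)
```

Selecting the pair `(_loss_fn, extra_dpo_args)` in one place is what lets the GA and non-GA paths share a single call site.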

train_utils.py (src/maxtext/utils/train_utils.py)

  • NNX init_state_fn materializes a frozen reference model alongside the policy when config.use_dpo. Both are constructed via _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload.
  • Step-0 checkpoint reload writes step0_state["model"] into state["reference_model"]. Linen path unchanged.
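
A minimal sketch of the step-0 reload, assuming both states are dict-like with the `"model"` and `"reference_model"` keys named above. `overwrite_reference_from_step0` is a hypothetical helper, and the string values stand in for real parameter trees.

```python
def overwrite_reference_from_step0(state, step0_state):
  """Return a copy of state whose reference_model is the step-0 policy."""
  return {**state, "reference_model": step0_state["model"]}


# Placeholder values stand in for real parameter pytrees.
state = {"model": "policy@step_100", "reference_model": "freshly_initialized"}
step0_state = {"model": "policy@step_0"}

restored = overwrite_reference_from_step0(state, step0_state)
```

The policy entry is left alone; only the reference slot is replaced, so the reference ends up matching the step-0 policy regardless of how it was initialized.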

Tests

  • New tests/unit/dpo_nnx_test.py (7 tests):
    • TrainStateNNX(reference_model=...) init / hasattr semantics
    • apply_gradients leaves the reference model bit-identical
    • dpo_loss_fn_nnx aux key set
    • Identical policy/reference yields loss = log(2) and reward_accuracy = 0.0 (strict > on equal logratios)
    • dropout_rng / params slots are signature-compat only — passing arbitrary values does not change the result
    • nnx.value_and_grad(..., argnums=0) over the policy yields finite grads on policy params only
  • tests/unit/train_nnx_test.py: dropped two stale negative tests:
    • vocab_tiling_raises_not_implemented — vocab tiling on NNX was implemented in feat/nnx-correctness-fixes
    • train_step_dpo_raises_for_nnx — DPO on NNX is implemented here
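
The log(2) / 0.0 expectation in the identical-models test follows from the standard DPO formulas, sketched below in plain Python. The exact metric definitions in dpo_utils.py are assumed rather than quoted, and `dpo_metrics` and `beta` are hypothetical names.

```python
import math


def dpo_metrics(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
  """Single-pair DPO loss and reward accuracy from sequence log-probabilities."""
  chosen_reward = beta * (policy_chosen - ref_chosen)
  rejected_reward = beta * (policy_rejected - ref_rejected)
  margin = chosen_reward - rejected_reward
  loss = math.log1p(math.exp(-margin))  # equals -log(sigmoid(margin))
  reward_accuracy = float(chosen_reward > rejected_reward)  # strict comparison
  return loss, reward_accuracy


# Identical policy and reference: every log-probability matches its counterpart.
loss, reward_accuracy = dpo_metrics(
    policy_chosen=-3.0, policy_rejected=-5.0, ref_chosen=-3.0, ref_rejected=-5.0
)
```

With identical models both rewards are 0, so the margin is 0, the loss is -log(sigmoid(0)) = log(2), and the strict `>` comparison yields an accuracy of 0.0.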

Linen invariant

Linen DPO is behaviorally unchanged: the only Linen-side change is that dpo_loss_fn now includes indexer_loss=0.0 / mtp_loss=0.0 in aux. Existing callers either ignore those keys or read them with .get(..., 0.0), so this is a no-op for downstream code. The NNX non-DPO path is unchanged: every new code path is gated on config.use_dpo.

Stats

  • 4 source files + 2 test files modified, +412 / −41 lines
  • All 27 tests pass across train_nnx_test, dpo_nnx_test (new), train_state_nnx_test, train_utils_nnx_test, gradient_accumulation_nnx_test

Followups (not in this PR)

  • DPO + NNX with gradient_accumulation_steps > 1 is wired up via the same dispatcher (extra_dpo_args carries reference_model and gradient_accumulation_loss_and_grad's NNX branch already calls grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) against dpo_loss_fn_nnx's signature) but is not exercised by an explicit GA-DPO test. Tracked as a followup.
  • Step-0 abstract-state shape mismatch (DPO-shape abstract state used to restore a non-DPO step-0 checkpoint) is mirrored from the Linen path's existing behavior; revisit only if integration testing surfaces issues.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…raining fixes

Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the
  whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
  (callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit
  raises conflicting outer_index error); refactored to jax.value_and_grad + explicit
  nnx.split/merge pattern; teacher inference moved outside value_and_grad
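
The optimizers.py guard above amounts to a callable() check before invoking the learning-rate argument. A minimal sketch, with `resolve_learning_rate` as a hypothetical helper name rather than the actual code in adam_pax:

```python
def resolve_learning_rate(learning_rate_fn, step):
  """Return the learning rate whether it is given as a schedule or a scalar."""
  if callable(learning_rate_fn):
    return learning_rate_fn(step)
  # A plain scalar (for example one injected as a hyperparameter) is used as-is.
  return learning_rate_fn


scalar_lr = resolve_learning_rate(0.1, step=5)
scheduled_lr = resolve_learning_rate(lambda step: 0.1 * step, step=2)
```
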
Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
  call from ToLinen after nnx.update to fix "Cannot extract graph node from
  different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
  linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
  shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
  (was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
  (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
  to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
  after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
  carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
  (nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
  helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
  pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
  traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
  diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
  {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
  retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
  use a Python for-loop instead of jax.lax.scan when quantization is FP8.
  Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return nnx.state(new_state, nnx.Not
  (nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for
  QK-Clip) don't break leaf-count parity with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer
  _norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules — layers_outside
  _pipeline, layers_per_stage, num_activations, circular_repeats. Additive,
  no-op without use_nnx_pipeline=True.
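
For the nnx_decoders.py item above, the sketch below shows the general idea of swapping jax.lax.scan for a Python for-loop over stacked layer weights. The `layer` function and the `use_fp8` flag are hypothetical; the real decoder layers and the FP8 ops are not modeled here.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
  # Toy layer standing in for a decoder block.
  return jnp.tanh(x @ w)


def run_scanned(x, stacked_w):
  def body(carry, w):
    return layer(carry, w), None

  out, _ = jax.lax.scan(body, x, stacked_w)
  return out


def run_unrolled(x, stacked_w):
  # A Python loop traces each layer separately instead of one scanned body.
  for i in range(stacked_w.shape[0]):
    x = layer(x, stacked_w[i])
  return x


def run_layers(x, stacked_w, use_fp8):
  return run_unrolled(x, stacked_w) if use_fp8 else run_scanned(x, stacked_w)


x = jnp.ones((2, 3))
stacked_w = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 3))
scanned_out = run_layers(x, stacked_w, use_fp8=False)
unrolled_out = run_layers(x, stacked_w, use_fp8=True)
```

Both paths compute the same function; the unrolled one trades longer trace and compile time for not placing the layer body inside a scan.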

NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
  via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
  tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
  state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
  _save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
  paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
  Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
  to Linen with a clear log warning since NNX-format checkpoints will fail
  at restore time.

Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
  Transformer class — wraps NNXDecoder.apply_output_head with the
  token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
  via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
  chunk. The NNX model carries its parameters internally so no explicit
  FSDP gather is needed (unlike the Linen gathered_params pattern). MVP
  uses default autograd; custom_vjp memory-savings optimization is a
  follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
  to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
  and enable_nnx validation guards (no longer needed).
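
A rough sketch of a tiled loss driven by jax.lax.scan is given below. It assumes each chunk is a tile of hidden states handed to a per-chunk logits call, which is one reading of the vocabulary_tiling.py item above and may not match the real chunking scheme. `tiled_cross_entropy`, `logits_fn`, and the toy shapes are hypothetical, and default autograd is used as in the MVP.

```python
import jax
import jax.numpy as jnp


def tiled_cross_entropy(hidden, labels, logits_fn, num_tiles):
  """Mean token cross-entropy computed one tile of positions at a time."""
  num_tokens, dim = hidden.shape
  tile = num_tokens // num_tiles  # assumes num_tokens is divisible by num_tiles
  hidden_tiles = hidden.reshape(num_tiles, tile, dim)
  label_tiles = labels.reshape(num_tiles, tile)

  def step(total, xs):
    h, y = xs
    logits = logits_fn(h)  # [tile, vocab]; full logits are never materialized
    log_z = jax.nn.logsumexp(logits, axis=-1)
    picked = jnp.take_along_axis(logits, y[:, None], axis=-1)[:, 0]
    return total + jnp.sum(log_z - picked), None

  init = jnp.zeros((), dtype=hidden.dtype)
  total, _ = jax.lax.scan(step, init, (hidden_tiles, label_tiles))
  return total / num_tokens


key_h, key_w, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
hidden = jax.random.normal(key_h, (8, 4))
out_weights = jax.random.normal(key_w, (4, 16))
labels = jax.random.randint(key_y, (8,), 0, 16)
logits_fn = lambda h: h @ out_weights  # stand-in for a per-chunk logits call

tiled = tiled_cross_entropy(hidden, labels, logits_fn, num_tiles=4)
dense_logits = logits_fn(hidden)
dense = jnp.mean(
    jax.nn.logsumexp(dense_logits, axis=-1) - dense_logits[jnp.arange(8), labels]
)
```

The tiled and dense losses agree up to floating-point error; the tiled version only ever holds one tile of logits at a time in the forward pass.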

DPO + NNX is retained as a NotImplementedError, but with a much more
informative message (it points users at the pure_nnx=False workaround).
Full implementation is deferred: it needs a new TrainState shape carrying
both policy and reference NNX models plus an NNX dpo_loss_fn.

Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two
_slices x2). All "Pure NNX support has not been implemented yet"
NotImplementedError sites cleared (was 17, now 0).
Implements NNX-native DPO so that the pure_nnx=True training path no
longer raises NotImplementedError on use_dpo runs. The Linen DPO
overlay pattern (model.apply(params=..., reference_params=...)) does
not translate to NNX modules, which carry their parameters internally.
Instead the policy and reference models are held as separate
nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx
runs both forwards with stop_gradient on the reference logits.

TrainStateNNX:
- Add optional `reference_model: nnx.Module` field. apply_gradients
  continues to update only `self.model`, leaving `self.reference_model`
  bit-identical across steps.

dpo_utils.py:
- Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params,
  reference_model, is_train=True). Signature mirrors the Linen
  dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's
  dispatcher (dropout_rng / params slots are unused for NNX; carried
  for parity, and reference_model is passed as the single
  extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over
  the policy, no gradient flows to the reference model's nnx.Param
  leaves; the explicit jax.lax.stop_gradient on ref_logits is a
  belt-and-braces guard.
- Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include
  indexer_loss=0.0 and mtp_loss=0.0 in aux so the
  gradient_accumulation aux pytree shape matches the non-DPO loss_fn.

train.py:
- Drop the NotImplementedError in train_step's NNX branch. When
  use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as
  extra_dpo_args; otherwise use the regular loss_fn. eval_step gains
  the same dispatch.
- diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init
  block, so both the GA and non-GA NNX paths route DPO identically.
- Checkpoint-save _split_dpo_state stripping is now Linen-only;
  TrainStateNNX saves whole (reference_model included) — the step-0
  reload later overwrites reference_model from the step-0 checkpoint.

train_utils.py:
- NNX init_state_fn materializes a frozen reference_model alongside
  the policy when config.use_dpo. Both are constructed by
  _create_model_partial() with config.init_weights_seed, so they
  start identical (standard DPO practice) until the step-0 reload.
- Step-0 checkpoint reload: copy step0_state["model"] into
  state["reference_model"]. Linen path unchanged.

Tests:
- New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX
  reference_model init/hasattr semantics; apply_gradients leaves
  reference bit-identical; aux key set; identical policy/reference
  yields loss=log(2) and reward_accuracy=0.0 (strict > on equal
  logratios); dropout_rng/params slots are signature-compat only;
  nnx.value_and_grad(argnums=0) over the policy yields finite grads
  on policy params only.
- train_nnx_test.py: drop the two stale negative tests
  (vocab_tiling_raises_not_implemented,
  train_step_dpo_raises_for_nnx) — both features are now real.

Stats: 4 source files + 2 test files, +199/-22 source lines. Linen
DPO path behaviorally unchanged (only adds two harmless aux-dict
keys); NNX non-DPO path unchanged (all changes gated on
config.use_dpo).