[NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out#3849
Draft
ecnal-cienet wants to merge 10 commits intomainfrom
Draft
[NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out#3849ecnal-cienet wants to merge 10 commits intomainfrom
ecnal-cienet wants to merge 10 commits intomainfrom
Conversation
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
Part 1 — sharding diagnostics: - maxtext_utils.py: extend print_shardings_params to support NNX (nnx.State input) - run_sharding_dump.py: add --pure_nnx flag Part 2 — post-training bugfixes (NNX-side): - models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts the individual image/audio/mask fields) - optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn) - train_distill.py / train_sft.py / train_rl.py: avoid nesting nnx.value_and_grad inside nnx.jit (Tunix's default trainer), which raises "graph structure of a node added to cached_partial was mutated" — refactor to jax.value_and_grad with explicit nnx.split / nnx.merge; train_rl.py also adds with_sharding_constraint + dtype-cast compat shims for jax 0.9 / tpu_inference Linen<->NNX checkpoint conversion utility and validation tool moved to a follow-up PR (PR4.5) to keep this change reviewable.
Bidirectional Linen <-> NNX checkpoint conversion. Same on-disk shape
both directions; round-trips preserve byte values.
Top-level key mapping:
- Linen params/params/<model> <-> NNX model/<model> (double-nesting,
{value:} wrappers).
- Linen opt_state <-> NNX optimizer/opt_state (params level on mu/nu).
- Linen step <-> NNX optimizer/step.
Layer structure:
- scan_layers=True (default): stack layers_N -> layers tensor.
- scan_layers=False: rename layers_N -> integer-keyed layers/{N}.
NNX->Linen direction auto-detects which layer layout the source uses.
--direction=auto picks Linen vs NNX from top-level keys.
Pure utility addition. No production-code dependencies; PR5+ do not
depend on this branch. Comparison utility split into PR4.6.
a8a53c3 to
e015455
Compare
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
call from ToLinen after nnx.update to fix "Cannot extract graph node from
different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
(was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
(drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
(nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
{"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
use a Python for-loop instead of jax.lax.scan when quantization is FP8.
Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return nnx.state(new_state, nnx.Not
(nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for
QK-Clip) don't break leaf-count parity with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer
_norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules — layers_outside
_pipeline, layers_per_stage, num_activations, circular_repeats. Additive,
no-op without use_nnx_pipeline=True.
NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
_save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
to Linen with a clear log warning since NNX-format checkpoints will fail
at restore time.
Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
Transformer class — wraps NNXDecoder.apply_output_head with the
token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
chunk. The NNX model carries its parameters internally so no explicit
FSDP gather is needed (unlike the Linen gathered_params pattern). MVP
uses default autograd; custom_vjp memory-savings optimization is a
follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
and enable_nnx validation guards (no longer needed).
DPO + NNX retained as NotImplementedError but with a much more informative
message (points users at pure_nnx=False workaround). Full implementation
is deferred — needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.
Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two
_slices x2). All "Pure NNX support has not been implemented yet"
NotImplementedError sites cleared (was 17, now 0).
Implements NNX-native DPO so that the pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs. The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits. TrainStateNNX: - Add optional `reference_model: nnx.Module` field. apply_gradients continues to update only `self.model`, leaving `self.reference_model` bit-identical across steps. dpo_utils.py: - Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True). Signature mirrors the Linen dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's dispatcher (dropout_rng / params slots are unused for NNX; carried for parity, and reference_model is passed as the single extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves; the explicit jax.lax.stop_gradient on ref_logits is a belt-and-braces guard. - Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation aux pytree shape matches the non-DPO loss_fn. train.py: - Drop the NotImplementedError in train_step's NNX branch. When use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the same dispatch. - diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block, so both the GA and non-GA NNX paths route DPO identically. - Checkpoint-save _split_dpo_state stripping is now Linen-only; TrainStateNNX saves whole (reference_model included) — the step-0 reload later overwrites reference_model from the step-0 checkpoint. train_utils.py: - NNX init_state_fn materializes a frozen reference_model alongside the policy when config.use_dpo. Both are constructed by _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload. - Step-0 checkpoint reload: copy step0_state["model"] into state["reference_model"]. Linen path unchanged. Tests: - New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model init/hasattr semantics; apply_gradients leaves reference bit-identical; aux key set; identical policy/reference yields loss=log(2) and reward_accuracy=0.0 (strict > on equal logratios); dropout_rng/params slots are signature-compat only; nnx.value_and_grad(argnums=0) over the policy yields finite grads on policy params only. - train_nnx_test.py: drop the two stale negative tests (vocab_tiling_raises_not_implemented, train_step_dpo_raises_for_nnx) — both features are now real. Stats: 4 source files + 2 test files, +199/-22 source lines. Linen DPO path behaviorally unchanged (only adds two harmless aux-dict keys); NNX non-DPO path unchanged (all changes gated on config.use_dpo).
…e.py)
PR5 audited maxengine.py and routed the inference path to the Linen
implementation regardless of pure_nnx, with a comment block explaining
that "the flag affects training, not inference serving." That kept the
Linen serving path unchanged but meant pure_nnx=True users silently got
the Linen engine. This change replaces the route with a real NNX flow:
when config.pure_nnx=True, the engine builds an NNX Transformer, splits
out (params, cache, rest) with nnx.split, and at every JIT body merges
the model concretely with nnx.merge to run the forward pass. Linen is
preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:`
and pure_nnx=False is still the default.
maxengine.py (__init__):
- Build two abstract NNX Transformers on the NNX path: self.model with
model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar
with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on,
decode_state shape). Both are needed because NNX cache vars inherit
CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode,
and bulk_insert searches for the substring "cache_batch" in the
AR-mode logical-axes tuple. nnx.eval_shape is called directly inside
nn_partitioning.axis_rules rather than through create_nnx_abstract_model
to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only
axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh).
- Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT
bodies can pass (params, cache, rest) separately to nnx.merge. The
rest slot (RNG vars etc.) is materialized concretely in load_params.
maxengine.py (cache adapter + _nnx_run_model):
- bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the
cache via tree_map_with_path and switch on path[-1].key (the cache
variable name like "cached_prefill_key"). Linen mutable cache is a
plain nested dict. NNX Cache state would expose a ".value" accessor
at that position. Bridge via nnx.State.to_pure_dict() (after the
model run) and nnx.replace_by_pure_dict (before nnx.merge), so the
cache plumbing helpers see the same shape on both paths.
- Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True)
-> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True
avoids reusing Variable objects across traces (TraceContextError),
mirroring train.py's diff_wrapper workaround.
- Add _nnx_cache_state_template / _nnx_init_cache_dict helpers
parametrised by mode so prefill (batch 1) and decode_state (batch N)
pull from the right abstract model.
maxengine.py (load_params):
- New _load_params_nnx: accepts user-provided NNX-shape params or loads
via from_pretrained. For user-provided params, materializes a concrete
model once via _create_model_fn() to capture a real rest state for
nnx.merge (wasteful but simple; the from_pretrained branch avoids
this). Refreshes self.graphdef from the concrete model so subsequent
merges line up exactly.
- Builds self.abstract_params, populates self.prefill_kv_cache_annotations
and self.kv_cache_annotations (using model_ar for the latter so
bulk_insert's substring lookup hits), wraps both into NamedSharding.
- pure_nnx + quantization, pure_nnx + LoRA, pure_nnx +
stack_prefill_result_cache=True, pure_nnx + prefill_multisampling,
and pure_nnx + prefill_concat raise NotImplementedError for now;
the Linen path is the workaround. AOT compilation
(aot_compile / _compile_generate_and_get_layouts) is not gated and
may work as-is; not exercised by tests yet.
maxengine.py (init_decode_state, _prefill_jit, _generate_jit):
- _init_decode_state_nnx zero-initializes a pure-dict cache from
model_ar (so the leading batch dim matches generate's input shape)
and builds kv_cache_annotations_named per leaf by reading
nnx.Cache.metadata. Tries "out_sharding", "sharding", and
"sharding_names" because Flax 0.12.6 renamed these.
- _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch
that calls _nnx_run_model in place of self.model.apply with
mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict
cache directly (no params|{"cache":...} dict-merge — params is an
nnx.State, not a dict).
maxtext_utils.py:
- New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx
that mirror the Linen helpers' return shape (per-leaf PartitionSpec
tree). Both delegate to _nnx_cache_partition_specs which extracts
nnx.Cache state via nnx.split, calls
get_nnx_named_sharding_with_scan_axis inside
nn_partitioning.axis_rules so logical axes ("layers", "cache_batch",
"norm", ...) resolve to physical mesh axes, and converts the result
to a pure-dict tree.
tests/unit/maxengine_test.py:
- New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and
per-layer cache shape checks), test_basic_decode_nnx (4-step generate
with next_pos advancement check), test_quantize_raises_for_nnx,
test_lora_raises_for_nnx.
- New test_linen_nnx_parity_prefill: bridges Linen-init params into
the NNX engine via linen_nnx_converter (convert_linen_to_nnx ->
_strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the
NNX engine's prefill matches Linen on the same weights — logits
within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses
bf16 compute) and exact greedy first-token argmax.
- Existing Linen tests untouched.
Test summary: 9 passed, 1 skipped (test_chunked_prefill is a
pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink
all green.
Closes the QK-Clip TODO and migrates the remaining Linen-only checkpoint utilities to NNX. Linen paths preserved byte-for-byte (every NNX edit is gated on `config.pure_nnx` or runtime state-shape detection). QK-Clip: - qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict -> nnx.update. Accepts both the production NNX intermediate shape (self_attention.attention_op.max_logits) and the synthetic-fixture shape from the existing Linen tests (self_attention.max_logits). - train.py train_step dispatches to apply_qk_clip_nnx for NNX, removing the prior TODO at the QK-Clip call site. Checkpoint utilities (NNX paths added): - standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn under pure_nnx; add_entropy_to_checkpoint dispatches across Linen TrainState, NNX TrainStateNNX Module, and post-split nnx.State shapes. - generate_param_only_checkpoint: NNX init_state_fn under pure_nnx; _possibly_unroll_params_nnx slices scanned NNX layers via dict-style state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict tree to orbax. Parallel LoRA decode flow operates on the single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx. - convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr translation (.params['params']<rest> -> .model<rest>.value, etc.). End-to-end paxml -> NNX conversion is wired but not yet validated on hardware. Tests: - qk_clip_test: 7 new NNX cases covering attention-type guard, MLA wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen<->NNX numeric parity. - standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu overwrite on TrainStateNNX Module shape, no mutation of state.model params, post-split nnx.State shape from setup_training_state. - generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA delta unroll on the single-nested NNX shape). NNX + AQT in MaxEngine and the layerwise_quantization NNX path are split into the follow-up PR9.5.
Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill
bug surfaced by the AQT end-to-end validation.
NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
"serve") through from_config / create_model /
get_nnx_create_model_fn / create_nnx_abstract_model /
from_pretrained. Default "train" preserves existing callers; "serve"
propagates to configure_quantization so AQT layers don't materialize
the full-precision kernel when the on-disk checkpoint already
carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
config.checkpoint_is_quantized; _load_params_nnx drops its
NotImplementedError. Two paths: pre-quantized
(checkpoint_is_quantized=True) loads via quant_mode_str="serve";
full-precision + quantization=int8 loads in TRAIN mode and AQT
layers quantize per-forward (same numerical result for absmax
calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
convert path. Loads full-precision in TRAIN mode, transfers kernels
into a CONVERT-mode model, runs forward to populate qrhs.frozen via
the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
saves serve-mode-shaped state.
Sharding helpers and from_pretrained QTensor handling (5 chained fixes
that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
parallel-tree of replicated NamedSharding leaves when a Variable's
value is a composite pytree (AQT serve-mode QTensor with a qvalue
int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
_unwrap_for_align use Variable.get_value() instead of v[...]
indexing for QTensor leaves (QTensor.__getitem__ trips on the
LogicallyPartitioned wrapper around qvalue). Widens the restore
filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
type. Skips QTensor leaves in the per-axis shape-alignment dispatch
(their saved shape already matches the model). _build_value_target
strips Partitioned wrappers around composite-leaf values so the
restore tree path matches the on-disk layout (LogicallyPartitioned
was adding an extra .value key under each QTensor leaf, which made
orbax silently fill the path with zero-init values).
gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
ever calling update_kv_caches to build cached_values, so any
non-TRAIN forward (prefill or autoregressive) tripped the
`assert prefill_kv_cache` check. Mirror the standard Attention
plumbing in attentions.py: __init__ constructs a KVCache_0 module
when model_mode != MODEL_MODE_TRAIN, threads
max_prefill_predict_length into AttentionOp; __call__ calls
self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
cached_values to attention_op. TRAIN-mode shape unchanged.
Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
_strip_kernels_at_quantized_paths covering quantized removal,
non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
builds a small NNX model in CONVERT mode with int8, runs a forward
to populate qrhs.frozen via the ToNNX bridge, saves the
serve-mode-shape state to a tmp local orbax checkpoint, reloads via
from_pretrained(quant_mode_str="serve"), and asserts every saved
qrhs.frozen.qvalue array byte-matches what came back. Guards the
full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
test_quantize_passes_gate_for_nnx; added
test_load_pre_quantized_nnx_passes_quant_gate and
test_quantized_prefill_nnx_train_mode (real numerical verification
with quantization=int8 + random params + TRAIN mode).
End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.
Replaces the PR9.5 NNX vocab-tiling MVP (chunked forward + default
autograd backward) with a jax.custom_vjp that mirrors the Linen path's
backward-memory savings, then carves out the output-head params so the
custom_vjp's residuals + grad accumulator scale with LM-head size, not
with the full model. Linen vocab_tiling_linen_loss is byte-for-byte
unchanged. Call sites in train.py / pyconfig_deprecated.py /
configs/types.py are unchanged.
Custom_vjp + output-head carve-out (vocabulary_tiling.py):
- Outside the custom_vjp: 3-way nnx.split with a callable path filter
(_is_output_head_param_path) matching {token_embedder,
shared_embedding, decoder_norm, logits_dense} — the only nnx.Param
paths apply_output_head touches. Returns (graphdef, head_params,
other_params, rest).
- Custom_vjp primals: (head_params, other_params, rest, hidden_states,
labels, segmentation). Only head_params and hidden_states are
differentiated; other_params + rest are threaded through as
non-differentiated primals so their tracers don't have to cross both
the custom_vjp and the inner lax.scan boundary (which previously
caused UnexpectedTracerError under logits_via_embedding=True).
- Forward (_chunked_cross_entropy_loss_fwd): reshapes to
(num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body
rebuilds the model per chunk via nnx.merge(graphdef, chunk_head,
chunk_other, chunk_rest, copy=True) and calls
logits_from_hidden_states. Initial scan accumulator is fp32 (was
hidden_states.dtype previously — caused a lax.scan carry dtype
mismatch with bf16 hidden_states since cross_entropy_with_logits
always returns fp32). Residuals are (chunk_head, chunk_other,
chunk_rest, reshaped_*, batch/seq/emb).
- Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose
body builds loss_fn_for_vjp = lambda p, h: ..., calls
jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk),
accumulates grad_head via tree.map(add), emits per-chunk grad_hidden.
Chain-rules grad_head *= loss_cotangent and dtype-casts to each
primal's dtype (custom_vjp requires this). chunk_other_params and
chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero
pytrees, NOT None — None makes JAX synthesize zeros at AOT trace time
with axis-0 stacking (jax.scan convention) for nnx.scan-stacked
transformer-layer params, which carry axis-1 stacking (nnx
convention), and the cotangent-shape check fails as
"Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M],
but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent
shape to the primal shape exactly.
- Correctness: logits_from_hidden_states provably depends only on
head_params; the gradient w.r.t. other_params through this loss is
exactly zero. When train.py also calls the full model forward (which
produces hidden_states), transformer-layer gradients flow back
through grad_hidden_states → outer backward, unaffected by the
carve-out.
Supporting fixes (touched for the carve-out to work end-to-end):
- nnx_decoders.py::apply_output_head logits_via_embedding=True branch
reads embedding_table = shared_embedding.embedding[...] instead of
the deprecated .value shim. The .value shim registers the access in
NNX mutation tracking, which JAX detects as a tracer leak when the
embedding is closure-captured / threaded across the custom_vjp +
lax.scan boundaries. The Linen branch is unchanged.
- models.py: deletes dead-code self.hidden_states = None and
if num_vocab_tiling > 1: self.hidden_states = hidden_state from the
NNX Transformer class. Two lines left over from an early PR5
implementation idea — neither path actually reads
model.hidden_states (Linen reads via mutable=["intermediates"]; NNX
reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown
("decoder", "hidden_states") intermediate). Without this fix, AOT
compile under pure_nnx=True + num_vocab_tiling>1 raised
ValueError: Cannot assign data value of type 'LinearizeTracer' to
static attribute 'hidden_states' of Pytree type 'Transformer' —
would have silently broken any post-PR11 user with vocab tiling on.
Tests (tiling_test.py — new VocabTilingNNXTest class with 9 TPU tests):
- test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss +
grad parity vs. full-vocab xent reference for both LM-head modes.
- test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second
tuple element matches the reference (was untested before).
- test_nnx_vocab_tiling_padded_segmentation: half-padded mask;
exercises the segmentation != 0 mask branch and asserts padded loss
is strictly less than unpadded.
- test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1
differentiation; exercises the custom_vjp's second-primal cotangent
path (grad_reshaped_hidden_states), shape + dtype + value parity.
- test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol
loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast
preserves the primal's dtype). Caught the fp32-accumulator bug.
- test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0;
total_z_loss == 0.0 exactly and grad parity holds.
- test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8}
and asserts identical loss + grads (catches off-by-one in
vocab_tile_size and scan/reshape interactions).
- test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out
invariant): asserts every non-head leaf has gradient exactly zero
AND at least one head leaf has non-zero gradient (so the test can't
trivially pass by zeroing everything). Catches filter bugs (e.g.
forgetting that NNX names the embedder token_embedder while Linen
names it shared_embedding) and bwd zero-shape bugs.
AOT compile coverage (train_compile_test.py):
- Removed the now-stale pytest.skip("Vocab tiling not supported on
NNX.") in test_vocab_tiling_bf16.
- Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train
step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true
with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the
models.py dead-code regression and the cotangent-axis-ordering issue
the explicit-zeros bwd fixes.
Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one
Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx,
grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx)
— regression check for the nnx_decoders.py change.
e015455 to
54fa362
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
NNX Migration Route Map
pure_nnxflag,init_state_fn,TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)get_abstract_state_nnx,get_named_sharding_nnx,set_named_sharding_nnx,get_partition_spec_nnx,get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
4.6. ✅ Linen↔NNX checkpoint comparator. (PR [NNX] NNX migration prep (4.6/N): Linen<->NNX checkpoint comparator #3846)
9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
custom_vjpfor NNX (with output-head carve-out, originally scoped as a follow-up PR10.5 — bundled in here).True; regenerate sharding goldens; flip back integration-testpure_nnx=Falseannotations.Description
Replaces the PR5 NNX vocab-tiling MVP (chunked forward + default autograd backward) with a
jax.custom_vjpthat mirrors the Linen path's backward-memory savings, then carves out the output-head params so thecustom_vjp's residuals + grad accumulator scale with LM-head size, not with the full model. PR10.5 was originally scoped as a separate follow-up but is bundled here — the carve-out and its supporting fixes (nnx_decoders.py.valuemodernization, explicit zero cotangents in the bwd) are necessary for the NNX path to AOT-compile correctly undernum_vocab_tiling > 1 + pure_nnx=True.The Linen
vocab_tiling_linen_lossis byte-for-byte unchanged. Call sites intrain.py,pyconfig_deprecated.py, andconfigs/types.pyare unchanged.Diff: +517 / −44 across 5 files (3 src, 2 test).
What it does
src/maxtext/utils/vocabulary_tiling.py—vocab_tiling_nnx_lossrewrite:nnx.split— outside thecustom_vjp, a callable path filter (_is_output_head_param_path) matching{token_embedder, shared_embedding, decoder_norm, logits_dense}separateshead_params(the onlynnx.Parampathsapply_output_headtouches) fromother_params(transformer layers, etc.) andrest(rngs).(head_params, other_params, rest, hidden_states, labels, segmentation). Onlyhead_paramsandhidden_statesare differentiated;other_params + restare threaded through as non-differentiated primals so their tracers don't have to cross both thecustom_vjpand the innerlax.scanboundary (which previously causedUnexpectedTracerErrorunderlogits_via_embedding=True).(num_vocab_tiling, vocab_tile_size, ...), runlax.scanwhose body rebuilds the model per chunk viannx.merge(graphdef, chunk_head, chunk_other, chunk_rest, copy=True)and callslogits_from_hidden_states. Initial scan accumulator is fp32 (washidden_states.dtypepreviously — caused alax.scancarry dtype mismatch with bf16 hidden_states sincecross_entropy_with_logitsalways returns fp32). Residuals are(chunk_head, chunk_other, chunk_rest, reshaped_*, batch/seq/emb).lax.scanwhose body buildsloss_fn_for_vjp = lambda p, h: ..., callsjax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk), accumulatesgrad_headviatree.map(add), emits per-chunkgrad_hidden. Chain-rulesgrad_head *= loss_cotangentand dtype-casts each grad back to its primal's dtype (custom_vjprequires this).chunk_other_paramsandchunk_restcotangents are explicittree_map(jnp.zeros_like, ...)zero pytrees, NOTNone—Nonemakes JAX synthesize zeros at AOT trace time with axis-0 stacking (jax.scanconvention) fornnx.scan-stacked transformer-layer params, which carry axis-1 stacking (NNX convention), and the cotangent-shape check fails (Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M], but got bfloat16[L,E,M]). Materializing the zeros ties the cotangent shape to the primal shape exactly.logits_from_hidden_statesprovably depends only onhead_params; the gradient w.r.t.other_paramsthrough this loss is exactly zero. Whentrain.pyalso calls the full model forward (which produceshidden_states), transformer-layer gradients flow back throughgrad_hidden_states→ outer backward, unaffected by the carve-out.src/maxtext/layers/nnx_decoders.py—apply_output_headlogits_via_embedding=Truebranch readsembedding_table = shared_embedding.embedding[...](modern NNXVariable[Array]API) instead of the deprecated.valueshim. The.valueshim registers the access in NNX mutation tracking, which JAX detects as a tracer leak when the embedding is closure-captured / threaded across thecustom_vjp + lax.scanboundaries. The Linen branch (shared_embedding.variables["params"]["embedding"]) is unchanged.src/maxtext/models/models.py— deletes dead-codeself.hidden_states = Noneandif self.config.num_vocab_tiling > 1: self.hidden_states = hidden_statefrom the NNXTransformerclass. Two lines left over from an early PR5 implementation idea — neither path actually readsmodel.hidden_states(Linen reads viamutable=["intermediates"]; NNX reads viannx.pop(model, nnx.Intermediate)from the decoder's sown("decoder", "hidden_states")intermediate). Without this fix, AOT compile underpure_nnx=True + num_vocab_tiling>1raisedValueError: Cannot assign data value of type 'LinearizeTracer' to static attribute 'hidden_states' of Pytree type 'Transformer'— would have silently broken any post-PR11 user with vocab tiling on.Tests
tests/unit/tiling_test.py— newVocabTilingNNXTestclass with 9 TPU tests:test_nnx_vocab_tiling_non_tied_embedding/_tied_embedding— loss + grad parity vs. full-vocab xent reference for both LM-head modes.test_nnx_vocab_tiling_total_z_loss_value_parity— asserts the second tuple element matches the reference (was untested before).test_nnx_vocab_tiling_padded_segmentation— half-padded mask; exercises thesegmentation != 0mask branch and asserts padded loss is strictly less than unpadded.test_nnx_vocab_tiling_grad_over_hidden_states—argnums=1differentiation; exercises thecustom_vjp's second-primal cotangent path (grad_reshaped_hidden_states); shape + dtype + value parity.test_nnx_vocab_tiling_bf16_hidden_states— bf16 inputs with rtol/atol loosened to 5e-2; assertsgrad_h.dtype == bf16(the bwd dtype-cast preserves the primal's dtype). Caught the fp32-accumulator bug.test_nnx_vocab_tiling_z_loss_zero—z_loss_multiplier=0;total_z_loss == 0.0exactly and grad parity holds.test_nnx_vocab_tiling_num_vocab_tiling_variants— runsn ∈ {2, 4, 8}and asserts identical loss + grads (catches off-by-one invocab_tile_sizeand scan/reshape interactions).test_nnx_vocab_tiling_other_params_get_zero_grad— carve-out invariant: asserts every non-head leaf has gradient exactly zero AND at least one head leaf has non-zero gradient (so the test can't trivially pass by zeroing everything). Catches filter bugs (e.g. forgetting that NNX names the embeddertoken_embedderwhile Linen names itshared_embedding) and bwd zero-shape bugs.tests/unit/train_compile_test.py:pytest.skip("Vocab tiling not supported on NNX.")intest_vocab_tiling_bf16.test_vocab_tiling_bf16_nnx(@pytest.mark.cpu_only) which AOT-compiles the train step underpure_nnx=true + enable_nnx=true + pure_nnx_decoder=truewithnum_vocab_tiling=4andweight_dtype=bfloat16. Surfaced both themodels.pydead-code regression and the cotangent-axis-ordering issue the explicit-zeros bwd fixes.Existing tests untouched.
Stats
vocab_tiling_nnx_lossrewritten (replaces PR5's MVP).nnx_decoders.py::apply_output_head1-line modernization affects every NNX run withlogits_via_embedding=True.models.py2-line dead-code deletion has no observable behavior change (no readers).vocab_tiling_linen_lossbyte-for-byte unchanged.nnx_decoders.pyLinen branch unchanged.TransformerLinenPureunchanged.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.