diff --git a/src/maxtext/checkpoint_conversion/linen_nnx_converter.py b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py new file mode 100644 index 0000000000..015d3b5a56 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py @@ -0,0 +1,581 @@ +# Copyright 2023-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bidirectional conversion between Linen and NNX checkpoint formats. + +Top-level key mapping: + Linen → NNX: + params/params/ → model/ (remove double-nesting, rename, add {value:} wrappers) + opt_state → optimizer/opt_state (remove 'params' level from mu/nu) + step → optimizer/step (move inside optimizer) + + NNX → Linen: + model/ → params/params/ (strip {value:} wrappers, add double-nesting) + optimizer/opt_state → opt_state (add 'params' level to mu/nu) + optimizer/step → step (move to top level) + +Layer structure (--scan_layers): + linen_to_nnx: + scan_layers=True (default): stack layers_N arrays → 'layers' tensor with layer dim at axis 1 + scan_layers=False: rename layers_N → integer-keyed 'layers/{N}' + + nnx_to_linen (auto-detected): + Stacked 'layers' tensor → unstack along axis 1 → layers_N per-layer arrays + Integer-keyed layers/{N} → rename to layers_N + +Usage: + python linen_nnx_converter.py \\ + --source_path="gs://bucket/checkpoint/0/items" \\ + --target_path="gs://bucket/converted/" \\ + --direction=auto +""" + +import argparse +import os +import re +import time +from typing import Any + +# MUST 
set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import numpy as np +from etils import epath +import orbax.checkpoint as ocp + + +def log(message: str) -> None: + print(f"[linen_nnx_converter] {message}") + + +# ── Format detection ─────────────────────────────────────────────────────────── + + +def detect_format(state: dict) -> str: + """Detects checkpoint format ('linen' or 'nnx') from top-level keys.""" + # NNX: uses 'model' as the top-level params key + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Cannot detect checkpoint format: no 'model' or 'params' key. " f"Found: {list(state.keys())}") + + params = state["params"] + + # Linen: double-nested params/params/decoder + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Old NNX format: params/decoder (single-nested with value wrappers) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + if _has_value_wrappers(params): + return "nnx" + + if "optimizer" in state: + return "nnx" + if "opt_state" in state: + return "linen" + + raise ValueError( + f"Could not detect checkpoint format. 
Keys: {list(state.keys())}, " + f"params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +# ── Value wrapper helpers ────────────────────────────────────────────────────── + + +def _has_value_wrappers(tree: Any) -> bool: + """Returns True if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {value: array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _add_value_wrappers(tree: Any) -> Any: + """Recursively wraps leaf arrays in {value: array} (NNX nnx.Param format).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return tree # Already wrapped + return {k: _add_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_add_value_wrappers(item) for item in tree) + elif hasattr(tree, "shape") or isinstance(tree, np.ndarray): + return {"value": tree} + else: + return tree + + +# ── Layer structure helpers ──────────────────────────────────────────────────── + + +def _stack_layers(decoder: dict) -> tuple[dict, bool]: + """Stacks per-layer parameters (layers_N) into a single 'layers' dict at axis 0. + + Returns (result_dict, was_stacked). 
+ """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder, False + + sorted_indices = sorted(layer_indices.keys()) + num_layers = len(sorted_indices) + log(f" Found {num_layers} individual layers, stacking into 'layers'") + + def stack_arrays(layers_data: list) -> Any: + first = layers_data[0] + if hasattr(first, "shape") or isinstance(first, np.ndarray): + return np.stack([np.asarray(layers_data[i]) for i in range(len(layers_data))], axis=0) + elif isinstance(first, dict): + result = {} + for key in first.keys(): + child_data = [layers_data[i].get(key) for i in range(len(layers_data))] + if all(c is not None for c in child_data): + result[key] = stack_arrays(child_data) + return result + else: + return first + + layers_data = [layer_indices[i] for i in sorted_indices] + stacked = stack_arrays(layers_data) + + result = dict(other_keys) + result["layers"] = stacked + return result, True + + +def _rename_layers_to_integer_keys(decoder: dict) -> dict: + """Converts layers_N keys to integer-keyed dict under 'layers' (no stacking). + + Converts {layers_0: {...}, layers_1: {...}} → {layers: {'0': {...}, '1': {...}}}. + Used for scan_layers=False linen→nnx conversion (Pattern C). 
+ """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder + + sorted_indices = sorted(layer_indices.keys()) + log(f" Found {len(sorted_indices)} individual layers, renaming to integer-keyed 'layers/N'") + result = dict(other_keys) + result["layers"] = {str(i): layer_indices[i] for i in sorted_indices} + return result + + +def _transpose_layers_axes(tree: Any, src_axis: int, dst_axis: int) -> Any: + """Transposes the layers dimension in arrays within a tree (src_axis ↔ dst_axis).""" + if src_axis == dst_axis: + return tree + if isinstance(tree, dict): + return {k: _transpose_layers_axes(v, src_axis, dst_axis) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_transpose_layers_axes(item, src_axis, dst_axis) for item in tree) + elif hasattr(tree, "shape") and len(tree.shape) >= 2: + axes = list(range(len(tree.shape))) + axes[src_axis], axes[dst_axis] = axes[dst_axis], axes[src_axis] + result = np.transpose(np.asarray(tree), axes=axes) + log(f" Transposed: {tree.shape} → {result.shape}") + return result + else: + return tree + + +def _detect_num_layers(tree: Any, scan_axis: int) -> int | None: + """Detects num_layers from the first array with ndim > scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + shape = getattr(tree, "shape", None) or np.asarray(tree).shape + if len(shape) > scan_axis: + return shape[scan_axis] + return None + if isinstance(tree, dict): + for v in tree.values(): + result = _detect_num_layers(v, scan_axis) + if result is not None: + return result + return None + + +def _unstack_single_layer(tree: Any, idx: int, scan_axis: int) -> Any: + """Extracts a single layer by indexing at scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, 
np.ndarray): + arr = np.asarray(tree) + if arr.ndim > scan_axis: + return np.take(arr, idx, axis=scan_axis) + return arr + if isinstance(tree, dict): + return {k: _unstack_single_layer(v, idx, scan_axis) for k, v in tree.items()} + if isinstance(tree, (list, tuple)): + return type(tree)(_unstack_single_layer(v, idx, scan_axis) for v in tree) + return tree + + +def _convert_layers_to_linen_format(decoder: dict) -> dict: + """Converts NNX 'layers' back to Linen's layers_N format (auto-detects NNX style). + + Handles: + - Stacked tensor (Pattern B): layers/ + → layers_0, layers_1, ... (unstack along axis 1) + - Integer-keyed (Pattern C): layers/0, layers/1, ... + → layers_0, layers_1, ... (rename) + """ + if "layers" not in decoder: + return decoder + + layers_val = decoder["layers"] + other_keys = {k: v for k, v in decoder.items() if k != "layers"} + + if not isinstance(layers_val, dict): + # Already a non-dict (shouldn't happen normally), keep as-is + return decoder + + # Pattern C: integer-keyed per-layer dict → rename + if all(k.isdigit() for k in layers_val.keys()): + result = dict(other_keys) + for idx_str, layer_data in sorted(layers_val.items(), key=lambda x: int(x[0])): + result[f"layers_{idx_str}"] = layer_data + log(f" Renamed integer-keyed layers/N → layers_N ({len(layers_val)} layers)") + return result + + # Pattern B: stacked tensor (layer dim at axis 1) → unstack + num_layers = _detect_num_layers(layers_val, scan_axis=1) + if num_layers is None: + log(" WARNING: Could not detect num_layers for unstacking, keeping 'layers' as-is") + result = dict(other_keys) + result["layers"] = layers_val + return result + + result = dict(other_keys) + for i in range(num_layers): + result[f"layers_{i}"] = _unstack_single_layer(layers_val, idx=i, scan_axis=1) + log(f" Unstacked scanned 'layers' → layers_N ({num_layers} layers at axis 1)") + return result + + +# ── Optimizer state helpers ──────────────────────────────────────────────────── + + +def 
_convert_opt_state_linen_to_nnx(opt_state: Any) -> Any: + """Removes 'params' nesting from mu/nu in linen opt_state. + + NNX optimizer state has plain arrays (no {value:} wrappers). + Linen opt_state mirrors the params structure (params/decoder/...), + so we remove the 'params' level to get decoder/... directly. + """ + if isinstance(opt_state, dict): + result = {} + for k, v in opt_state.items(): + if k == "params": + # Remove this level by merging its contents up + converted = _convert_opt_state_linen_to_nnx(v) + if isinstance(converted, dict): + result.update(converted) + else: + result[k] = converted + else: + result[k] = _convert_opt_state_linen_to_nnx(v) + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_linen_to_nnx(item) for item in opt_state) + else: + return opt_state # Plain array or scalar — no value wrapper for opt_state + + +def _convert_opt_state_nnx_to_linen(opt_state: Any, depth: int = 0) -> Any: + """Adds 'params' nesting to mu/nu, removes any stray {value:} wrappers. + + NNX optimizer mu/nu contains decoder/... directly. + Linen expects mu/params/decoder/... (one 'params' level mirroring the params structure). 
+ """ + if isinstance(opt_state, dict): + # Strip any {value:} wrappers in opt_state (shouldn't be there but handle gracefully) + if set(opt_state.keys()) == {"value"}: + inner = opt_state["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + + result = {} + for k, v in opt_state.items(): + converted = _convert_opt_state_nnx_to_linen(v, depth + 1) + # Add one 'params' level after mu/nu (mirrors linen's params structure) + if k in ("mu", "nu") and isinstance(converted, dict): + result[k] = {"params": converted} + else: + result[k] = converted + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_nnx_to_linen(item, depth + 1) for item in opt_state) + else: + return opt_state + + +# ── Main conversion functions ────────────────────────────────────────────────── + + +def convert_linen_to_nnx(state: dict, scan_layers: bool = True) -> dict: + """Converts Linen checkpoint to NNX format. + + Args: + state: Linen checkpoint dict with keys ['params', 'opt_state', 'step']. + scan_layers: If True (default), stack per-layer arrays and insert layer + dim at axis 1 (for NNX with scan_layers=True). + If False, rename layers_N → integer-keyed layers/N + (for NNX with scan_layers=False). 
+ """ + result = {} + + if "params" in state: + linen_params = state["params"] + # Remove double 'params' nesting: params/params/decoder → decoder + if isinstance(linen_params, dict) and "params" in linen_params: + nnx_params = linen_params["params"] + log(" params: Removed double 'params' nesting (params/params → model)") + else: + nnx_params = linen_params + log(" params: No double nesting found") + + stripped = _strip_value_wrappers(nnx_params) + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + if scan_layers: + stripped[component], was_stacked = _stack_layers(stripped[component]) + if was_stacked and "layers" in stripped[component]: + log(f" {component}/layers: Transposing stacked (layers, ...) → (..., layers, ...) at axis 1") + stripped[component]["layers"] = _transpose_layers_axes(stripped[component]["layers"], src_axis=0, dst_axis=1) + else: + stripped[component] = _rename_layers_to_integer_keys(stripped[component]) + + result["model"] = _add_value_wrappers(stripped) + log(" model: Saved with {value:} wrappers under 'model' key") + + # optimizer: move step inside, keep opt_state + optimizer_dict = {} + if "step" in state: + optimizer_dict["step"] = state["step"] + log(f" optimizer/step: Moved from top-level (step={state['step']})") + if "opt_state" in state: + optimizer_dict["opt_state"] = _convert_opt_state_linen_to_nnx(state["opt_state"]) + log(" optimizer/opt_state: Removed 'params' nesting from mu/nu") + if optimizer_dict: + result["optimizer"] = optimizer_dict + + return result + + +def convert_nnx_to_linen(state: dict) -> dict: + """Converts NNX checkpoint to Linen format. + + Reads from 'model'/'optimizer' keys (or falls back to old 'params'/'opt_state' format). + Layer structure is auto-detected (stacked vs integer-keyed). 
+ """ + result = {} + + model_key = "model" if "model" in state else "params" + if model_key in state: + nnx_params = state[model_key] + stripped = _strip_value_wrappers(nnx_params) + log(f" {model_key}: Removed {{value:}} wrappers") + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + stripped[component] = _convert_layers_to_linen_format(stripped[component]) + + # Add double 'params' nesting: decoder → params/params/decoder + result["params"] = {"params": stripped} + log(" params: Added double 'params' nesting (model → params/params)") + + # optimizer: extract step and opt_state back to top level + if "optimizer" in state: + optimizer = state["optimizer"] + if "step" in optimizer: + result["step"] = optimizer["step"] + log(" step: Extracted from optimizer/step to top level") + if "opt_state" in optimizer: + result["opt_state"] = _convert_opt_state_nnx_to_linen(optimizer["opt_state"]) + log(" opt_state: Added 'params' nesting to mu/nu") + elif "opt_state" in state: + # Backward compat: old format with opt_state at top level + result["opt_state"] = _convert_opt_state_nnx_to_linen(state["opt_state"]) + log(" opt_state: Converted from top-level opt_state (old format)") + + if "step" in state and "step" not in result: + result["step"] = state["step"] + + return result + + +# ── Checkpoint I/O ───────────────────────────────────────────────────────────── + + +def load_checkpoint(checkpoint_path: str) -> dict: + """Loads checkpoint from local or GCS path.""" + log(f"Loading checkpoint from: {checkpoint_path}") + + checkpoint_dir = epath.Path(checkpoint_path) + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + metadata = ckptr.metadata(checkpoint_dir) + + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + restore_args = jax.tree_util.tree_map( + lambda 
x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + log(f" Loaded keys: {list(state.keys())}") + return state + + +def save_checkpoint(state: dict, output_path: str) -> None: + """Saves checkpoint to local or GCS path.""" + log(f"Saving checkpoint to: {output_path}") + + output_dir = epath.Path(output_path) + output_dir.mkdir(exist_ok=True, parents=True) + + ckptr = ocp.PyTreeCheckpointer() + ckptr.save(output_dir, state, force=True) + log(" Checkpoint saved successfully") + + +# ── CLI ──────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="Convert between Linen and NNX checkpoint formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source_path", + type=str, + required=True, + help="Path to source checkpoint items directory (e.g. gs://bucket/ckpt/0/items).", + ) + parser.add_argument( + "--target_path", + type=str, + required=True, + help="Path to save converted checkpoint.", + ) + parser.add_argument( + "--direction", + type=str, + choices=["auto", "linen_to_nnx", "nnx_to_linen"], + default="auto", + help="Conversion direction. 'auto' detects from source format.", + ) + parser.add_argument( + "--scan_layers", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For linen_to_nnx only: if True (default), stack per-layer arrays into a " + "scanned 'layers' tensor with layer dim at axis 1 (for NNX with scan_layers=True). " + "If False, rename layers_N to integer-keyed layers/N without stacking " + "(for NNX with scan_layers=False)." 
+ ), + ) + + args = parser.parse_args() + + print("=" * 80) + print("Linen <-> NNX Checkpoint Converter") + print("=" * 80) + + start_time = time.time() + + state = load_checkpoint(args.source_path) + + if args.direction == "auto": + source_format = detect_format(state) + target_format = "nnx" if source_format == "linen" else "linen" + log(f"Auto-detected: {source_format} → {target_format}") + else: + source_format = args.direction.split("_to_")[0] + target_format = args.direction.split("_to_")[1] + log(f"Using specified direction: {source_format} → {target_format}") + + log(f"Converting: {source_format} → {target_format}") + if source_format == "linen": + log(f"scan_layers={args.scan_layers}") + + if source_format == "linen" and target_format == "nnx": + converted_state = convert_linen_to_nnx(state, scan_layers=args.scan_layers) + elif source_format == "nnx" and target_format == "linen": + converted_state = convert_nnx_to_linen(state) + else: + raise ValueError(f"Invalid conversion: {source_format} → {target_format}") + + save_checkpoint(converted_state, args.target_path) + + elapsed = time.time() - start_time + print("\n" + "=" * 80) + print(f"Conversion complete in {elapsed:.2f} seconds") + print(f" Source: {args.source_path}") + print(f" Target: {args.target_path}") + print(f" Direction: {source_format} → {target_format}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 9b5f0cfb21..62f0e0da77 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -40,6 +40,7 @@ import os import sys +from flax import nnx import jax from jax import random from jax.sharding import Mesh @@ -48,11 +49,15 @@ from maxtext.common import checkpointing 
from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations +from maxtext.layers import train_state_nnx from maxtext.models.models import transformer_as_linen from maxtext.optimizers import optimizers from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils +from maxtext.utils import train_utils import numpy as np from psutil import Process import tensorstore as ts @@ -87,13 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) - quant = quantizations.configure_quantization(cfg) if cfg.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") + rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng) + model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(cfg, model) + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + else: + quant = quantizations.configure_quantization(cfg) model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) - tx = optimizers.get_optimizer(cfg, learning_rate_schedule) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) + tx = optimizers.get_optimizer(cfg, learning_rate_schedule) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( cfg.checkpoint_dir, @@ -102,11 +117,6 @@ def 
convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - if cfg.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") @@ -191,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), } - state_map = { - ".step": ("step", None), - ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), - } + if cfg.pure_nnx: + # NNX state-tree paths after `nnx.split(TrainStateNNX)`: + # model params -> ['model'].value + # adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu'].value + # step -> ['optimizer']['step'].value + # opt count -> ['optimizer']['opt_state']['count'].value + state_map = { + ".optimizer.step.value": ("step", None), + ".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None), + } + else: + state_map = { + ".step": ("step", None), + ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), + } def get_layer_prefix(keystr_pax): # different path format between decoder_layer variable @@ -206,19 +227,27 @@ def get_layer_prefix(keystr_pax): return prefix_pax_opt_state for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items(): - # model variable - state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) prefix_pax_opt_state = get_layer_prefix(keystr_pax) - # first momentum in optimizer state - state_map[f".opt_state.mu['params']{keystr_maxtext}"] = ( - f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", - transform_fn, - ) - # second momentum in optimizer 
state - state_map[f".opt_state.nu['params']{keystr_maxtext}"] = ( - f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", - transform_fn, - ) + if cfg.pure_nnx: + state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn) + state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = ( + f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", + transform_fn, + ) + state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = ( + f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", + transform_fn, + ) + else: + state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) + state_map[f".opt_state.mu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", + transform_fn, + ) + state_map[f".opt_state.nu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", + transform_fn, + ) def verify_fn(key_path, _): keystr = jax.tree_util.keystr(key_path) @@ -270,10 +299,11 @@ def map_fn(key_path, value): max_logging.log("converted state finished") max_utils.print_mem_stats("converted state finished") - if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state): - max_logging.log(f"saved a checkpoint at step {converted_state.step}") + step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step + if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state): + max_logging.log(f"saved a checkpoint at step {step_value}") # Upon preemption, exit when and only when all ongoing saves are complete. 
- if checkpoint_manager.reached_preemption(converted_state.step): + if checkpoint_manager.reached_preemption(step_value): checkpoint_manager.wait_until_finished() sys.exit() diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index dc01262e6c..ad7618868a 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -20,6 +20,7 @@ from absl import flags import datetime from etils import epath +from flax import nnx from flax.training import train_state import jax from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE @@ -571,7 +572,7 @@ def load_state_if_possible( load_parameters_from_path: str, load_full_state_from_path: str, checkpoint_storage_concurrent_gb: int, - abstract_unboxed_pre_state: train_state.TrainState, + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, enable_single_replica_ckpt_restoring: bool | None = False, dataset_type: str | None = "tfds", step: int = -1, # -1 means latest @@ -639,9 +640,14 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX + restore_target = abstract_unboxed_pre_state + if isinstance(abstract_unboxed_pre_state, nnx.State): + restore_target = abstract_unboxed_pre_state.to_pure_dict() + + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) checkpoint_args = ocp.args.PyTreeRestore( - item=abstract_unboxed_pre_state, + item=restore_target, restore_args=restore_args, partial_restore=True, ) @@ -679,9 +685,14 @@ def map_to_pspec(data): return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) if load_parameters_from_path != "": + if isinstance(abstract_unboxed_pre_state, nnx.State): + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) 
+ else: + params = abstract_unboxed_pre_state.params + restored_params = load_params_from_path( load_parameters_from_path, - abstract_unboxed_pre_state.params, + params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, @@ -773,7 +784,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Determine the effective step for saving a checkpoint. # If 'step' is not provided, this call is for a potential final checkpoint # and use the last completed step from the state. - actual_step = (int(state.step) - 1) if step is None else int(step) + if step is not None: + actual_step = int(step) + else: + if config.pure_nnx: + actual_step = int(state.optimizer.step) - 1 + else: + # Linen TrainState has .step attribute + actual_step = int(state.step) - 1 + + if config.pure_nnx: + # Convert nnx.State to dict. + state = state.to_pure_dict() # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. # This occurs if this function was called: diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index cc2f674fd4..fa51f00d62 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -560,6 +560,13 @@ logical_axis_rules: [ ['tokens_per_page', []], ['paged_kv_head_dim_size', []], # ========================================== + # Pipeline Parallelism + # ========================================== + ['layers_outside_pipeline', []], + ['layers_per_stage', []], + ['num_activations', []], + ['circular_repeats', []], + # ========================================== # Deprecated / Scheduled for Removal # ========================================== ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 406ba92523..c14d87cd4b 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -195,10 
+195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): + del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration - raise ValueError("We currently don't support vocab tiling on NNX module.") def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples): diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c35274cd24..12c4136f2d 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2833,8 +2833,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: - raise ValueError("We currently don't support vocab tiling on NNX module.") + # Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py. 
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: raise ValueError( diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..b788ccd13e 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -54,9 +54,12 @@ from jax import random from flax.linen import partitioning as nn_partitioning +from flax import nnx from flax import struct from flax.nnx import TrainState +from maxtext.layers import train_state_nnx + from cloud_tpu_diagnostics import diagnostic from cloud_tpu_diagnostics.configuration import debug_configuration from cloud_tpu_diagnostics.configuration import diagnostic_configuration @@ -85,11 +88,12 @@ from maxtext.experimental.rl import grpo_utils from maxtext.common.metric_logger import MetricLogger from maxtext.common.vertex_tensorboard import VertexTensorboardManager -from maxtext.inference import offline_engine from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging from maxtext.utils import max_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding from maxtext.utils import train_utils @@ -335,34 +339,190 @@ def grpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_ return loss, aux +def grpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """GRPO loss for the NNX path. + + Signature matches the Linen `grpo_loss_fn` so callers can dispatch on the + same shape. `dropout_rng` and `params` are unused (NNX models carry these + themselves); `reference_model` is the frozen reference `nnx.Module`. The + reference forward is wrapped in `stop_gradient` so grads only flow into the + policy. Returns `(loss, LossAux)`. 
+ """ + del dropout_rng, params # NNX models carry these themselves + + prompt_with_completions = data[f"{config.train_data_columns}_completions"] + prompt_completions_position = data[f"{config.train_data_columns}_completions_position"] + prompt_completions_segmentation = data[f"{config.train_data_columns}_completions_segmentation"] + completions_segmentation = data["ar_completions_segmentation"] + + token_logps_policy, intermediate_outputs = grpo_utils.compute_log_probs_nnx( + policy_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=is_train, + ) + + completion_target_segmentation = data["ar_completions_segmentation"][..., 1:] + valid_seq_mask = completion_target_segmentation != 0 + + rewards = grpo_utils.dummy_reward_len(valid_seq_mask) + rewards = jnp.array(rewards) + + G = config.num_generations + rewards_grouped = rewards.reshape(-1, G) + group_mean = jnp.mean(rewards_grouped, axis=1) + group_std = jnp.std(rewards_grouped, axis=1) + repeated_group_mean = jnp.repeat(group_mean, G) + repeated_group_std = jnp.repeat(group_std, G) + advantages = (rewards - repeated_group_mean) / (repeated_group_std + EPS) + advantages_exp = advantages[:, None] + + if data["completions_logprobs"] is None: # off-policy + old_per_token_logps = jax.lax.stop_gradient(token_logps_policy) + else: # on-policy + old_per_token_logps = data["completions_logprobs"] + + policy_diff = token_logps_policy - old_per_token_logps + coef_1 = jnp.exp(policy_diff) + coef_2 = jnp.clip(coef_1, 1 - config.grpo_epsilon, 1 + config.grpo_epsilon) + loss_tokens = -jnp.minimum(coef_1 * advantages_exp, coef_2 * advantages_exp) + + if config.grpo_beta != 0.0: + token_logps_ref, _ = grpo_utils.compute_log_probs_nnx( + reference_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=False, + ) + token_logps_ref = 
jax.lax.stop_gradient(token_logps_ref) + token_diff_logps_ref_policy = token_logps_ref - token_logps_policy + per_token_kl = jnp.exp(token_diff_logps_ref_policy) - token_diff_logps_ref_policy - 1 + per_token_kl = per_token_kl * valid_seq_mask + loss_tokens += config.grpo_beta * per_token_kl + + loss_per_example = jnp.sum(loss_tokens * valid_seq_mask, axis=1) / jnp.clip(jnp.sum(valid_seq_mask, axis=1), min=1) + loss = jnp.mean(loss_per_example) + total_weights = jnp.sum(valid_seq_mask) + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + + if config.grpo_beta != 0.0: + avg_kl = jnp.mean((per_token_kl * valid_seq_mask) / jnp.clip(jnp.sum(valid_seq_mask, axis=1, keepdims=True), min=1)) + else: + avg_kl = None + avg_completion_length = jnp.mean(jnp.sum(data["ar_completions_segmentation"] != 0, axis=1)) + aux = LossAux( + total_loss=loss, + avg_reward=jnp.mean(rewards), + avg_reward_std=jnp.mean(repeated_group_std), + avg_advantage=jnp.mean(advantages), + avg_kl=avg_kl, + completion_length=avg_completion_length, + moe_lb_loss=moe_lb_loss, + total_weights=total_weights, + ) + return loss, aux + + # ----------------------------------------------------------------------------- # Trainer and top level training functions # ----------------------------------------------------------------------------- -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """Performs a single training step of the GRPO algorithm. +def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data): + """GRPO train_step body for the NNX path. - This function computes the GRPO loss, calculates gradients, and updates the - model's parameters. It handles gradient accumulation and clipping as configured. 
- The reference model's parameters are held constant during the update. + Reconstructs `TrainStateNNX` from `(model_graphdef, state)`, splits out + the policy params for value_and_grad, applies gradients, and returns the + new state with `nnx.Intermediate` filtered out (transient sown values + must not persist across steps). + """ + del state_mesh_shardings # host-offload paths not yet wired up here - Args: - model: The transformer model to be trained. - config: The training configuration object. - state_mesh_shardings: Pytree of sharding specifications for the training state. - params_shardings: Pytree of sharding specifications for the model parameters. - This argument is not used and is kept to match the signature of other trainers. - state: The current training state, including parameters and optimizer state. - data: A batch of training data, including prompts and generated completions. - dropout_rng: JAX PRNG key for dropout. + if config.gradient_accumulation_steps > 1: + raise NotImplementedError( + "GRPO + pure_nnx + gradient_accumulation_steps>1 not supported yet. " + "Set gradient_accumulation_steps=1 or pure_nnx=False." + ) - Returns: - A tuple containing: - - new_state: The updated training state after applying gradients. - - metrics: A dictionary of metrics for logging, including loss, reward, - and gradient norms. + state = nnx.merge(model_graphdef, state) # reconstruct TrainStateNNX + policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(policy_graphdef, param, rest, copy=True) + loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, state.reference_model, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) 
+ return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) + + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + state.apply_gradients(grads) + new_state = state + + scalar_metrics = { + "learning/loss": loss, + "learning/avg_reward": aux.avg_reward, + "learning/avg_reward_std": aux.avg_reward_std, + "learning/avg_advantage": aux.avg_advantage, + "learning/avg_kl": aux.avg_kl, + "learning/completion_length": aux.completion_length, + "learning/moe_lb_loss": aux.moe_lb_loss, + "learning/total_weights": aux.total_weights, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + } + _, new_policy_params, _ = nnx.split(new_state.model, nnx.Param, ...) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params) + metrics = {"scalar": scalar_metrics, "scalars": {}} + + return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics + + +def _eval_step_nnx(model_graphdef, config, state, data): + """GRPO eval_step body for the NNX path. No state update.""" + state = nnx.merge(model_graphdef, state) + loss, aux = grpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + metrics = { + "scalar": { + "evaluation/loss": loss, + "evaluation/total_loss": aux.total_loss, + "evaluation/total_weights": aux.total_weights, + "evaluation/moe_lb_loss": aux.moe_lb_loss, + }, + } + return metrics + + +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): + """Single GRPO training step. + + Computes the GRPO loss, gradients, and applies them to the policy. The + reference model is held constant. 
Linen and NNX paths split below; on + NNX, `model` is a GraphDef and `state` is a flat `nnx.State` of a + `TrainStateNNX` (with `model`, `optimizer`, and `reference_model`). + + Returns `(new_state, metrics)`. """ + if config.pure_nnx: + return _train_step_nnx(model, config, state_mesh_shardings, state, data) + state, reference_params = _split_grpo_state(state) state_mesh_shardings, reference_params_sharding = _split_grpo_state(state_mesh_shardings) extra_grpo_args = [reference_params] @@ -473,6 +633,8 @@ def eval_step(model, config, state, data, dropout_rng): Returns: A dictionary of evaluation metrics. """ + if config.pure_nnx: + return _eval_step_nnx(model, config, state, data) reference_params, extra_grpo_args, _loss_fn = [], [], grpo_loss_fn state, reference_params = _split_grpo_state(state) @@ -542,27 +704,48 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. """ + if config.pure_nnx != config_inference.pure_nnx: + raise ValueError( + f"config.pure_nnx ({config.pure_nnx}) and config_inference.pure_nnx " f"({config_inference.pure_nnx}) must agree." 
+ ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] + init_rng = jax.random.PRNGKey(config.init_weights_seed) + if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") + training_mesh = maxtext_utils.get_mesh_from_config(config, devices=training_devices) + training_rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=init_rng) + model = mt.from_config(config, devices=training_devices, mesh=training_mesh, rngs=training_rngs) else: model = mt.from_config(config, devices=training_devices) mesh = model.mesh + max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] if config_inference.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") + inference_mesh_obj = maxtext_utils.get_mesh_from_config(config_inference, devices=inference_devices) + inference_rngs = maxtext_utils_nnx.create_nnx_rngs(config_inference, rng_key=init_rng) + inference_model = mt.from_config( + config_inference, devices=inference_devices, mesh=inference_mesh_obj, rngs=inference_rngs + ) else: inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng = jax.random.PRNGKey(config.init_weights_seed) + learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh, devices=training_devices) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + # Reference uses the same init seed so it starts identical to the policy. + reference_model = _create_model_partial() + return train_state_nnx.TrainStateNNX(nnx_model, optimizer, reference_model=reference_model) + else: init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) @@ -573,10 +756,15 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh if config_inference.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_inference_partial, _ = model_creation_utils.create_nnx_abstract_model( + config_inference, inference_mesh, devices=inference_devices + ) + + def init_inference_state_fn(): + inference_nnx_model = _create_inference_partial() + return train_state_nnx.TrainStateNNX(inference_nnx_model, None) + else: init_inference_state_fn = functools.partial( maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng @@ -586,7 +774,11 @@ def setup_train_loop( )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + if config.pure_nnx: + _, params_for_check, _ = nnx.split(state.model, nnx.Param, ...) 
+ sharding.assert_params_sufficiently_sharded(params_for_check, mesh, config.sharding_tolerance) + else: + sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) return ( init_rng, @@ -699,10 +891,15 @@ def train_loop(config, config_inference, recorder, state=None): token=config.hf_access_token, ) - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_grpo_state(state, reference_params) - state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if config.pure_nnx: + # `reference_model` is set up by init_state_fn as a sibling field — nothing to merge. + if not hasattr(state, "reference_model"): + raise RuntimeError("NNX GRPO state is missing reference_model; check setup_train_loop.") + else: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_grpo_state(state, reference_params) + state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator @@ -710,6 +907,9 @@ def train_loop(config, config_inference, recorder, state=None): data_sharding = sharding.get_input_data_sharding(config, mesh) + # Lazy import: pulls in maxengine and jetstream stubs. 
+ from maxtext.inference import offline_engine # pylint: disable=import-outside-toplevel + inference_engine = offline_engine.OfflineEngine( config=config_inference, mesh=inference_mesh, @@ -724,7 +924,11 @@ def train_loop(config, config_inference, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params["params"]) + if config.pure_nnx: + _, _params_for_metrics, _ = nnx.split(state.model, nnx.Param, ...) + metric_logger.write_setup_info_to_tensorboard(_params_for_metrics) + else: + metric_logger.write_setup_info_to_tensorboard(state.params["params"]) def generation_worker_fn( worker_inference_engine, @@ -848,21 +1052,32 @@ def generation_worker_fn( state, metrics = p_train_step(state, example_batch, train_rng) with jax.profiler.StepTraceAnnotation("transfer data", step_num=step): if step != 0 and step % config.inference_rollouts == 0: - grpo_utils.pathways_reshard( - config_inference, - inference_engine, - {"params": state.params["params"]}, - {"params": state_mesh_shardings.params["params"]}, - mesh, - {"params": inference_state_mesh_shardings.params["params"]}, - ) + if config.pure_nnx: + grpo_utils.pathways_reshard_nnx( + config_inference, + inference_engine, + state.model, + state_mesh_shardings.model, + inference_state_mesh_shardings.model, + ) + else: + grpo_utils.pathways_reshard( + config_inference, + inference_engine, + {"params": state.params["params"]}, + {"params": state_mesh_shardings.params["params"]}, + mesh, + {"params": inference_state_mesh_shardings.params["params"]}, + ) with data_buffer_lock: data_buffer.clear() step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = _split_grpo_state(state)[0] + # Linen embeds reference in `state.params` and strips it for save; NNX + # holds it as a 
sibling field on TrainStateNNX so the whole state goes. + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == start_step: @@ -900,7 +1115,7 @@ def generation_worker_fn( metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = _split_grpo_state(state)[0] + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) elif checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/experimental/rl/grpo_utils.py b/src/maxtext/experimental/rl/grpo_utils.py index 352a2b3b8d..34d437867e 100644 --- a/src/maxtext/experimental/rl/grpo_utils.py +++ b/src/maxtext/experimental/rl/grpo_utils.py @@ -21,8 +21,9 @@ import jaxtyping from typing import Any, Callable +from flax import nnx + from maxtext.common.common_types import DecoderBlockType -from maxtext.inference.offline_engine import InputData from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -112,6 +113,48 @@ def compute_log_probs( return token_log_probs, intermediate_outputs +def compute_log_probs_nnx( + model, + inputs, + inputs_position, + inputs_segmentation, + completion_segmentation, + config, + is_train=False, +): + """`compute_log_probs` for the NNX path. + + `model` is an `nnx.Module` (carries its own params + RNG state), so there's + no `params` arg. Intermediates are pulled off the model after the forward + via `nnx.state(model, nnx.Intermediate).to_pure_dict()`. 
+ """ + logits = model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=(config.enable_dropout if is_train else False), + ) + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + logits = logits / config.decode_sampling_temperature + + targets = inputs[:, 1:] + shifted_completion_segmentation = jax.lax.dynamic_slice( + completion_segmentation, (0, 1), (completion_segmentation.shape[0], completion_segmentation.shape[1] - 1) + ) + shifted_completion_segmentation = jnp.pad( + shifted_completion_segmentation, ((0, 0), (0, 1)), mode="constant", constant_values=0 + ) + mask = shifted_completion_segmentation[..., None] + mask = jnp.broadcast_to(mask, logits.shape) + masked_logits = jnp.where(mask, logits, -jnp.inf) + log_probs = jax.nn.log_softmax(masked_logits, axis=-1) + log_probs = jnp.where(mask, log_probs, -0.0) + log_probs = log_probs[:, :-1, :] + token_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0] + token_log_probs = token_log_probs * shifted_completion_segmentation[:, :-1] + return token_log_probs, intermediate_outputs + + def generate_offline_completions(config, tokenizer_model, inference_engine, data): """Generates completions for a batch of prompts using an offline engine. @@ -125,6 +168,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data The input `data` dictionary updated with the generated completions, segmentations, positions, and log-probabilities. """ + # Lazy import: pulls in maxengine and jetstream stubs, which we only want to + # touch when this function is actually called (i.e. during a real GRPO run). 
+ from maxtext.inference.offline_engine import InputData # pylint: disable=import-outside-toplevel + data[config.train_data_columns] = np.asarray( jnp.repeat(data[config.train_data_columns], config.num_generations, axis=0) ) @@ -175,6 +222,30 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data return data +def pathways_reshard_nnx( + config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model +): + """`pathways_reshard` for the NNX path. + + Reshard the policy params onto the inference mesh and push them into the + inference engine. Requires `scan_layers=True` (no NNX-aware unscan helper yet). + """ + if not config.scan_layers: + raise NotImplementedError( + "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False." + ) + _, policy_params, _ = nnx.split(policy_state_model, nnx.Param, ...) + _, source_param_shardings, _ = nnx.split(source_shardings_model, nnx.Param, ...) + _, dest_param_shardings, _ = nnx.split(destination_shardings_model, nnx.Param, ...) + del source_param_shardings # already encoded on policy_params + with ( + jax.transfer_guard_device_to_host("disallow_explicit"), + jax.transfer_guard_host_to_device("disallow_explicit"), + ): + resharded_params = reshard_pytree(policy_params, dest_param_shardings) + inference_engine.update_params(resharded_params) + + def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings): """Reshards model parameters from training to inference sharding. 
diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 5bb0a87b5a..9c42929f5e 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -32,6 +32,7 @@ from jax.experimental.layout import DeviceLocalLayout as DLL # type: ignore from flax import linen as nn +from flax import nnx from flax import struct from flax.linen import partitioning as nn_partitioning import flax @@ -44,8 +45,10 @@ from maxtext.inference.page_manager import PageManager, PageState from maxtext.multimodal import processor as mm_processor from maxtext.utils import lora_utils +from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled from maxtext.common.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE @@ -111,12 +114,42 @@ def __init__(self, config: Any, devices: Any | None = None): devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and Optimizer definition + # Model and Optimizer definition. quant = quantizations.configure_quantization(config) if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # `serve` only when the on-disk checkpoint already carries `qrhs.frozen` + # (no full-precision kernel). For `checkpoint_is_quantized=False` with + # quant enabled we stay in `train` mode and let AQT quantize per-forward + # against the full-precision kernel — same numerical result as `serve` + # for absmax calibration, just slower. 
+ nnx_quant_mode_str = "serve" if (quant is not None and config.checkpoint_is_quantized) else "train" + # We need both PREFILL and AR abstract models because the cache vars inherit + # CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and + # bulk_insert searches for the substring "cache_batch" in the AR-mode names. + # Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids + # the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm". + _create_model = model_creation_utils.get_nnx_create_model_fn( + config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL, quant_mode_str=nnx_quant_mode_str + ) + _create_model_ar = model_creation_utils.get_nnx_create_model_fn( + config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE, quant_mode_str=nnx_quant_mode_str + ) + self._nnx_quant_mode_str = nnx_quant_mode_str + with nn_partitioning.axis_rules(config.logical_axis_rules): + abstract_model = nnx.eval_shape(_create_model) + abstract_model_ar = nnx.eval_shape(_create_model_ar) + self.model = abstract_model + self.model_ar = abstract_model_ar + # 3-way split so JIT bodies can pass (params, cache, rest) separately to + # nnx.merge. `rest` (RNG state etc.) is materialized in load_params. + graphdef, _, _, _ = nnx.split(abstract_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._create_model_fn = _create_model + self._nnx_rest_state = None else: self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.graphdef = None + self._create_model_fn = None self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -142,6 +175,65 @@ def print_stats(self, label: str): max_utils.print_mem_stats(label) max_utils.print_cpu_ram_stats(label) + # NNX cache adapter: bulk_insert / _insert_jit / _maybe_stack_* switch on + # path[-1].key (e.g. "cached_prefill_key"). 
NNX state would expose ".value" at + # that position, so we convert NNX state <-> plain dict at the JIT boundary + # via to_pure_dict / replace_by_pure_dict. The cache helpers stay unchanged. + + def _nnx_cache_state_template(self, mode: str = MODEL_MODE_PREFILL) -> Any: + """Empty nnx.State template for the model's nnx.Cache vars (PREFILL=batch 1, AR=batch N).""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + return cache_state + + def _nnx_init_cache_dict(self, mode: str = MODEL_MODE_PREFILL) -> dict: + """Zero-filled pure-dict cache matching the abstract NNX model.""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + cache_dict = cache_state.to_pure_dict() + return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict) + + def _nnx_run_model( + self, + params, + cache_dict, + decoder_input_tokens, + decoder_positions, + *, + decoder_segment_ids=None, + enable_dropout=False, + model_mode, + previous_chunk=None, + true_length=None, + slot=None, + page_state=None, + encoder_images=None, + encoder_image_masks=None, + encoder_audios=None, + ): + """NNX equivalent of `model.apply(..., mutable=["cache"])`. Returns (logits, new_cache_dict).""" + cache_state = self._nnx_cache_state_template(mode=model_mode) + nnx.replace_by_pure_dict(cache_state, cache_dict) + # copy=True avoids reusing Variable objects across traces (TraceContextError), + # mirroring the workaround in train.py's diff_wrapper. 
+ model = nnx.merge(self.graphdef, params, cache_state, self._nnx_rest_state, copy=True) + logits = model( + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + encoder_images=encoder_images, + encoder_image_masks=encoder_image_masks, + encoder_audios=encoder_audios, + enable_dropout=enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_cache = nnx.state(model, nnx.Cache).to_pure_dict() + return logits, new_cache + def generate_aot( self, params: Params, decode_state: DecodeState, rng: PRNGKeyType | None = None ): # returns (new_decode_state, result_tokens) @@ -225,6 +317,9 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + return self._load_params_nnx(params=params, rng=rng) + if self.model.quant and self.config.checkpoint_is_quantized: print("Loading from the quantized checkpoint...") self.model.quant.quant_mode = quantizations.get_quant_mode("serve") @@ -232,11 +327,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") - if self.config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( self.config, self._mesh, init_state_fn, False ) @@ -245,11 +336,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( @@ -292,10 +379,115 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar return params - def load_single_adapter(self, adapter_path): + def _load_params_nnx(self, params, rng): + """NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings. + + Quantization handling: + * `checkpoint_is_quantized=True`: model built in `serve` mode (no full + kernel), `from_pretrained` reads `qrhs.frozen` from disk. + * `checkpoint_is_quantized=False` + `quantization=...`: model built in + `train` mode, full-precision kernel loaded; AQT layers quantize per + forward. Same output as serve mode (absmax calibration), slower. 
""" - Load Single adapter from adapter_path. - Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`. + + if params: + print("Resharding given NNX params") + _, params_abs, _ = nnx.split(self.model, nnx.Param, ...) + target_shardings = jax.tree.map( + lambda x: x.sharding if hasattr(x, "sharding") else None, + params_abs, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + params_state = jax.device_put(params, target_shardings) + # Build a concrete model once to capture a real `rest` (RNG vars) for nnx.merge. + # Wasteful but simple — the from_pretrained branch below avoids this. + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + concrete_model = self._create_model_fn() + graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del concrete_model + else: + max_logging.log("Loading NNX params via from_pretrained") + with self._mesh: + nnx_model = model_creation_utils.from_pretrained( + self.config, + mesh=self._mesh, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + quant_mode_str=self._nnx_quant_mode_str, + ) + # 4-way split keeps the loaded AQT `qrhs.frozen` leaves (and any other + # non-Param/non-Cache vars) in `loaded_rest_state` so they survive into + # `_nnx_rest_state`. Param-only filtering would silently drop them and + # the model would run with random qrhs values. + _, params_state, _, loaded_rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...) + # `_prefill_jit` re-merges with `self.graphdef`, which must be the PREFILL + # graphdef built in `__init__` (matching `_create_model_fn`). Don't + # overwrite with the AR-mode graphdef from `from_pretrained` — the + # PREFILL/AR attention ops have different cache variable shapes, and a + # mismatch trips the `assert prefill_kv_cache` check inside attention_op. 
+ with nn_partitioning.axis_rules(self.config.logical_axis_rules): + concrete_model = self._create_model_fn() + graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...) + # Overlay loaded non-Param/non-Cache leaves (e.g. AQT qrhs.frozen) onto + # the PREFILL-mode rest_state. The PREFILL concrete_model already has + # placeholder qrhs vars at the right paths; we just swap in the loaded + # values. Anything only in `loaded_rest_state` (e.g. AR-only RNG slots) + # is ignored. We keep PREFILL rest_state as the base so RNG variables + # match the PREFILL graphdef's expectations. + loaded_rest_dict = loaded_rest_state.to_pure_dict() + rest_dict = rest_state.to_pure_dict() + def _overlay(dst, src): + if isinstance(dst, dict): + for k, v in dst.items(): + if k in src: + dst[k] = _overlay(v, src[k]) + return dst + return src if not isinstance(src, dict) else dst + rest_dict = _overlay(rest_dict, loaded_rest_dict) + nnx.replace_by_pure_dict(rest_state, rest_dict) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del nnx_model, concrete_model + + self.abstract_params = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + if isinstance(x, jax.Array) + else None, + params_state, + ) + + self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations_nnx( + self.model, self.config, self._mesh + ) + self.prefill_kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.prefill_kv_cache_annotations, + ) + if self.config.stack_prefill_result_cache: + # With scan_layers=True the NNX cache leaves are already stacked on axis 0, + # so the engine's manual-stack helper (which assumes an unstacked Linen tree) + # doesn't apply. Wiring this up cleanly is a Phase-2 follow-up. 
+ raise NotImplementedError("pure_nnx + stack_prefill_result_cache=True not yet supported.") + # AR-mode abstract model so axis names use CACHE_BATCH (not CACHE_BATCH_PREFILL); + # bulk_insert / _insert_jit search for "cache_batch" in the per-leaf logical axes. + self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(self.model_ar, self.config, self._mesh) + self.kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.kv_cache_annotations, + ) + # state_mesh_annotations is unused on the NNX path; callers reading it + # (e.g. set_engine_vars_from_base_engine) need to be NNX-aware first. + self.state_mesh_annotations = None + + self.print_stats("After load_params (NNX)") + return params_state + + def load_single_adapter(self, adapter_path): + """Load a single LoRA adapter from `adapter_path`. + + Expects `adapter_config.json` plus adapter weights at `/0/items`. + The returned `params` shape matches `self.abstract_params` (NNX or Linen). 
""" adapter_config_path = os.path.join(adapter_path, "adapter_config.json") adapter_weights_path = os.path.join(adapter_path, "0", "items") @@ -319,19 +511,36 @@ def apply_adapter(self, base_params, adapter_config, adapter_params): lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) def unapply_adapter(self, base_params, adapter_config, adapter_params): """Unapply the adapter params from the merged params to get back the base params.""" lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.unapply_lora_from_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) def quantize_params(self, state, rng: PRNGKeyType | None = None): """Forward pass to quantize decode params.""" if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + # NNX takes a different code path: convert-on-load lives in `_load_params_nnx` + # via `_convert_and_quantize_nnx`, which runs the dummy forward against a + # CONVERT-mode model and transfers `qrhs.frozen` into the SERVE model. + # The standalone `quantize_params(state, rng)` API expects a Linen-shape + # `state.params` dict and isn't reachable on the NNX pathway in maxengine + # (load_params already dispatched to _load_params_nnx). + raise NotImplementedError( + "Use load_params() on NNX — the convert step runs inside _load_params_nnx via " + "_convert_and_quantize_nnx. 
quantize_params(state, rng) is the Linen API." + ) self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @@ -486,7 +695,10 @@ def _prefill_jit( if existing_prefix is not None: if not self.use_chunked_prefill: raise ValueError("Using chunked prefill is needed for existing_prefix.") - input_params = params | {"cache": existing_prefix.cache} + # NNX threads existing_prefix.cache via the nnx_cache local below; only + # the Linen path merges cache into input_params (params is a dict there). + if not self.config.pure_nnx: + input_params = params | {"cache": existing_prefix.cache} start_position = existing_prefix.common_prefix_tokens.shape[0] # TODO(yuyanpeng): rename previous_chunk previous_chunk = jnp.expand_dims(existing_prefix.common_prefix_tokens, 0) @@ -518,24 +730,48 @@ def _prefill_jit( sequence_indicator = jnp.expand_dims(one_d_output, 0) rng, new_rng = jax.random.split(rng) - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - flat_logits, new_vars = self.model.apply( - input_params, - input_tokens, - positions, - encoder_images=images, - encoder_image_masks=image_masks, - encoder_audios=audio_values, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=MODEL_MODE_PREFILL, - rngs={"params": new_rng}, - mutable=["cache"], - previous_chunk=previous_chunk, - true_length=true_length, - slot=slot, - page_state=page_state, + if self.config.pure_nnx: + # Prefill always operates on batch=1 (one padded prompt at a time). 
+ nnx_cache = ( + existing_prefix.cache if existing_prefix is not None else self._nnx_init_cache_dict(mode=MODEL_MODE_PREFILL) ) + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_cache_dict = self._nnx_run_model( + params=input_params, + cache_dict=nnx_cache, + decoder_input_tokens=input_tokens, + decoder_positions=positions, + decoder_segment_ids=sequence_indicator, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_vars = self.model.apply( + input_params, + input_tokens, + positions, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + rngs={"params": new_rng}, + mutable=["cache"], + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) if return_prompt_logp: prompt_logp = inference_utils.prompt_logprobs_from_prefill(flat_logits, input_tokens, true_length) else: @@ -744,6 +980,9 @@ def _prefill_multisampling_jit( prefilling stage. The number of tokens is specified by num_samples. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_multisampling not yet supported. Use pure_nnx=False.") + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) @@ -869,6 +1108,9 @@ def prefill_concat( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_concat not yet supported. 
Use pure_nnx=False.") + if rng is None: rng = jax.random.PRNGKey(0) input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] @@ -1038,17 +1280,30 @@ def _generate_jit( previous_token = decode_state["tokens"] rng, new_rng = jax.random.split(rng) # run one step generation - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - out_logits, new_vars = self.model.apply( - params | {"cache": decode_state["cache"]}, - previous_token, - decode_state["next_pos"], - enable_dropout=False, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - rngs={"params": new_rng}, - mutable=["cache"], - page_state=page_state, - ) + if self.config.pure_nnx: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_cache_dict = self._nnx_run_model( + params=params, + cache_dict=decode_state["cache"], + decoder_input_tokens=previous_token, + decoder_positions=decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_vars = self.model.apply( + params | {"cache": decode_state["cache"]}, + previous_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": new_rng}, + mutable=["cache"], + page_state=page_state, + ) out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) # sampling tokens @@ -1606,6 +1861,9 @@ def init_decode_state( if self.config.attention == "paged" and self.page_manager is not None: page_state = self.page_manager.get_initial_page_state() # pytype: disable=attribute-error + if self.config.pure_nnx: + return self._init_decode_state_nnx(rng=rng, page_state=page_state) + # pylint: disable=unused-argument def init(abstract_params, page_state): x 
= jnp.ones( @@ -1699,6 +1957,51 @@ def is_lp(k): zeroed = max_utils.unbox_logicallypartioned(init_state) return zeroed + def _init_decode_state_nnx(self, rng, page_state) -> DecodeState: + """NNX equivalent of init_decode_state. Returns a decode_state dict with a pure-dict cache.""" + del rng, page_state # cache shape comes from the abstract model + batch = int(self.config.per_device_batch_size * self.mesh.size) + vocab = self.config.vocab_size + + # AR-mode cache so the batch dim matches generate's input shape. + cache_dict_abs = self._nnx_init_cache_dict(mode=MODEL_MODE_AUTOREGRESSIVE) + + @functools.partial(jax.jit, out_shardings=(self.kv_cache_shardings,)) + def _init_cache(): + return (jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict_abs),) + + (cache,) = _init_cache() + + # Per-leaf logical axes for bulk_insert's "cache_batch" lookup. Use model_ar + # so segment_id leaves carry CACHE_BATCH (under PREFILL they'd carry + # CACHE_BATCH_PREFILL, which doesn't contain the "cache_batch" substring). + _, cache_state, _ = nnx.split(self.model_ar, nnx.Cache, ...) + + def _logical_axes_for(var): + # Flax 0.12.6 renamed "sharding" to "out_sharding"; older code may still + # use "sharding_names". Try all three. 
+ meta = var.get_metadata() if hasattr(var, "get_metadata") else {} + out = meta.get("out_sharding") or meta.get("sharding") or meta.get("sharding_names") + if out is None: + return () + return (out,) if isinstance(out, str) else tuple(out) + + annotations_state = jax.tree.map( + _logical_axes_for, + cache_state, + is_leaf=lambda v: isinstance(v, nnx.Variable), + ) + self.kv_cache_annotations_named = annotations_state.to_pure_dict() + + return { + "logits": jnp.zeros((batch, 1, vocab), dtype=jnp.float32), + "cache": cache, + "next_pos": jnp.zeros((batch, 1), dtype=jnp.int32), + "generated_tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "token_logp": jnp.zeros((batch, 1), dtype=jnp.float32), + } + @property def max_concurrent_decodes(self) -> int: """Free slots.""" diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 509e1ef7d3..66215fe011 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -525,14 +525,14 @@ def __init__( elif self.is_qwen3_hybrid: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e23c3eba9f..48d1f78108 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2242,8 +2242,8 @@ def __call__( w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) - if self.per_expert_scale is not None: - wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], 
self.dtype)[:, None, None] + if self.per_expert_scale is not None: + wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] if self.wi_0_sparsity_module is not None: _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 4c6a3fb9c5..5d18b91ba6 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -35,6 +35,7 @@ MODEL_MODE_TRAIN, Config, DecoderBlockType, + MultimodalInput, ShardMode, ) from maxtext.inference import page_manager @@ -543,8 +544,16 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.checkpoint re-traces the scan body during backward (remat), + # but the Linen scope retains JAX tracers from the first trace, causing + # UnexpectedTracerError. Skip checkpoint for these quantization types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + out, new_state = pure_layer_fn(state, y) + else: + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -623,13 +632,12 @@ def layer_fn(carry, scanned_vars): return new_carry, (new_current_state, updated_kv) return new_carry, new_current_state - layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) - if use_kv: # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan # because scanning requires stacking the kv_caches list, which creates a copy # and breaks the in-place memory updates required by vLLM's PagedAttention. 
# Therefore, we must unroll the loop statically when kv_caches is provided. + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) # kv_caches_stacked is actually the original kv_caches list in this new flow kv_caches_list = kv_caches_stacked @@ -651,7 +659,24 @@ def layer_fn(carry, scanned_vars): # inference with vLLM, parameters do not change and we don't need intermediates. return current_carry, layers, None else: - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates + # intermediate tracer values (amax_history float32[1024]) that escape the scan scope, + # causing UnexpectedTracerError. Use a Python for loop instead for these types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + carry = x_in + per_layer_states = [] + for i in range(length): + current_params = jax.tree.map(lambda x, i=i: x[i], params) + current_state = jax.tree.map(lambda x, i=i: x[i], state) + carry, new_state_i = layer_fn(carry, (current_params, current_state)) + per_layer_states.append(new_state_i) + final_carry = carry + scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) + else: + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: @@ -937,7 +962,10 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode): if cfg.logits_via_embedding: # Use the transpose of embedding matrix for logit transform. 
if isinstance(shared_embedding, nnx.Module): - embedding_table = shared_embedding.embedding.value + # Modern NNX API; the deprecated `.value` shim registers the access in NNX's + # mutation tracking, which JAX detects as a tracer leak when the embedding is + # closure-captured across a custom_vjp boundary (e.g. vocab_tiling_nnx_loss). + embedding_table = shared_embedding.embedding[...] else: embedding_table = shared_embedding.variables["params"]["embedding"] if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): @@ -1061,10 +1089,10 @@ def __call__( previous_chunk=None, slot: None | int = None, page_state: None | page_manager.PageState = None, - multimodal_input: None | Any = None, kv_caches: list[jax.Array] | None = None, attention_metadata=None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, + multimodal_input: None | MultimodalInput = None, ): cfg = self.config assert decoder_input_tokens.ndim == 2 # [batch, len] diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index d41d924456..24ebecd492 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -26,6 +26,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph +from flax.nnx import tracers as nnx_tracers from flax.nnx import variablelib from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module @@ -167,6 +168,39 @@ def current_linen_module() -> linen.Module | None: return None +def is_linen_initializing() -> bool: + """Check if the current execution context is inside a Linen init() call. + + Returns True when called from within a ``to_linen_class`` wrapper's + ``init()`` path. Uses :func:`current_linen_module` to access the Linen + module stack (private API already used by this module). + + This is used by NNX pipeline modules to short-circuit the full scan + during Linen init, where only the output shape/dtype is needed. 
+ """ + module = current_linen_module() + if module is not None and hasattr(module, "is_initializing") and callable(module.is_initializing): + return module.is_initializing() + return False + + +def _refresh_variable_trace_state(module: Module) -> None: + """Refresh _trace_state for Variables that have stale trace state. + + When nnx.update() is called with tracer values from a JAX transformation + (e.g. jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True which + updates the raw value but not _trace_state. This leaves Variables with a + stale _trace_state from the outer (Python) context, causing nnx.split() to + fail with "Cannot extract graph node from different trace level" errors. + + This function resets _trace_state on any Variables whose _can_update is False + so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed. + """ + for _, v in nnx.graph.iter_graph(module): + if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access + object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) + + class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. 
@@ -464,6 +498,7 @@ def maybe_unbox(x): warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) + _refresh_variable_trace_state(module) _fix_for_qwix_quantization(module) method_fn = _get_module_method(module, nnx_method) diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index bf91262bf1..35611b2166 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float = 1e-6, + dtype: DType = None, + weight_dtype: DType = None, + shard_mode=None, + kernel_axes=None, + parameter_memory_host_offload=None, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, scale_init=linen_initializers.zeros, diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py index 9ef0e6dffd..3f9ee1ce29 100644 --- a/src/maxtext/layers/train_state_nnx.py +++ b/src/maxtext/layers/train_state_nnx.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" The NNX Unified TrainState. """ +"""The NNX Unified TrainState.""" from typing import Any @@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module): This replaces Linen's TrainState for checkpointing. 
Linen TrainState pytree: - {“params”: {...}, “opt_state”: {}...} + {"params": {...}, "opt_state": {}...} TrainStateNNX state pytree: - {“model”: {...}, “optimizer”: {“opt_state”: {...}} + {"model": {...}, "optimizer": {"opt_state": {...}}} + + For DPO (Direct Preference Optimization), an optional `reference_model` + carries a frozen copy of the same architecture used to compute reference + log-probabilities. Only `model` is updated by `apply_gradients`; the + reference is held alongside so it is sharded, jit-traced, and checkpointed + with the rest of the train state. """ - def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + def __init__( + self, + model: nnx.Module, + optimizer: nnx.Optimizer | None, + reference_model: nnx.Module | None = None, + ): self.model = model self.optimizer = optimizer + if reference_model is not None: + self.reference_model = reference_model def apply_gradients(self, grads: Any): """ Mimics the Linen apply_gradients function. Updates the optimizer state, applies updates to parameters, - and increments the step counter. + and increments the step counter. Only updates `self.model`; + `self.reference_model` (if present) is left untouched. 
""" if self.optimizer is None: raise RuntimeError( diff --git a/src/maxtext/models/gpt3.py b/src/maxtext/models/gpt3.py index 2736b8aafb..8a34e8395e 100644 --- a/src/maxtext/models/gpt3.py +++ b/src/maxtext/models/gpt3.py @@ -28,6 +28,7 @@ from flax import nnx from maxtext.common.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN +from maxtext.inference import kvcache from maxtext.layers import initializers, nnx_wrappers from maxtext.layers.linears import DenseGeneral, MlpBlock, canonicalize_tuple, normalize_axes from maxtext.layers import quantizations @@ -235,6 +236,7 @@ def __init__( self.key_axis_names = key_axis_names self.value_axis_names = value_axis_names self.out_axis_names = out_axis_names + self.model_mode = model_mode self.rngs = rngs if self.fused_qkv: self.qkv_proj = self.create_projection_layer( @@ -252,6 +254,7 @@ def __init__( mesh=self.mesh, attention_kernel=self.attention_kernel, max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, float32_qk_product=self.float32_qk_product, float32_logits=self.float32_logits, quant=self.quant, @@ -260,6 +263,30 @@ def __init__( num_kv_heads=self.num_heads, dtype=self.dtype, ) + # KV cache only matters in non-TRAIN modes. Mirrors Attention.__init__ in + # attentions.py so prefill / autoregressive get a real KVCache_0 module + # whose update_kv_caches() builds the cached_values tuple that + # AttentionOp.__call__ requires. 
+ batch_size, _ = max_utils.get_batch_seq_len_for_mode(config, model_mode) + self.KVCache_0 = ( + kvcache.KVCache( + max_prefill_length=self.max_prefill_predict_length, + max_target_length=self.max_target_length, + batch=batch_size, + key_seq_len=1, + value_seq_len=1, + key_heads=self.num_heads, + value_heads=self.num_heads, + key_head_size=self.head_dim, + value_head_size=self.head_dim, + dtype=self.dtype, + kv_quant=self.kv_quant, + model_mode=model_mode, + rngs=self.rngs, + ) + if model_mode != MODEL_MODE_TRAIN + else None + ) def create_projection_layer( self, @@ -328,7 +355,18 @@ def __call__( value = nn.with_logical_constraint(value, self.value_axis_names) value = checkpoint_name(value, "value_proj") - out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode) + cached_values = [None, None] + if model_mode != MODEL_MODE_TRAIN and self.KVCache_0 is not None: + prefill_kv_cache, ar_kv_cache = self.KVCache_0( + key=key, + value=value, + decoder_segment_ids=decoder_segment_ids, + model_mode=model_mode, + use_ragged_attention=False, + previous_chunk=None, + ) + cached_values = [prefill_kv_cache, ar_kv_cache] + out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode, cached_values) out = nn.with_logical_constraint(out, self.out_axis_names) diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 9401d01d9f..5f4a2f3fb6 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -29,6 +29,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations @@ -132,6 +133,8 @@ def __init__( rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + def __call__( self, inputs, @@ -189,7 +192,7 @@ def __call__( 
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 6a215c5dbe..244eed03bb 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -71,6 +71,7 @@ def __init__( shard_mode=config.shard_mode, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, + parameter_memory_host_offload=config.parameter_memory_host_offload, rngs=rngs, ) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 1b0d4b4cd3..3a884af799 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -347,7 +347,6 @@ def __init__( else: decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) - self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) @@ -398,6 +397,19 @@ def no_op(self, *args, **kwargs): """A no-op method to allow the model to be used in a lazy context.""" return + def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): + """Compute logits from hidden states (wraps NNXDecoder.apply_output_head). + + Mirrors the Linen TransformerLinenPure.logits_from_hidden_states method; + used by vocabulary tiling to recompute logits from chunked hidden states. 
+ """ + return self.decoder.apply_output_head( + shared_embedding=self.token_embedder, + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32): """Initializes the KV cache for the Transformer. @@ -532,10 +544,6 @@ def __call__( mutable=mutable_collections, ) # pytype: disable=wrong-keyword-args - # Materialize hidden state when vocab tiling is enabled - if self.config.num_vocab_tiling > 1: - self.hidden_states = hidden_state - # If we are initializing the model AND MTP is enabled, we must create # dummy target tensors. This allows Flax to trace the MTPBlock and create # all its necessary parameters, without requiring the main training pipeline diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index 09c5b4e079..b743e8d4b7 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -30,6 +30,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -142,6 +143,7 @@ def __init__( model_mode=model_mode, rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) def __call__( self, @@ -202,7 +204,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..87cb4cc7ef 100644 --- 
a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -966,7 +966,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -991,7 +991,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index b25ecf09e8..143bf63a07 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -139,7 +139,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -164,7 +164,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 2ae7e5f8e5..9992d7674f 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu): else: updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params) - step_size = -1.0 * learning_rate_fn(count) + # learning_rate_fn may be a callable schedule or a scalar (e.g. 
when wrapped + # by optax.inject_hyperparams, it is passed as a pre-evaluated scalar). + step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn) # Finally, fold in step size. updates = jax.tree_util.tree_map(lambda x: step_size * x, updates) diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index a9ef64631a..39d84a89dc 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -26,6 +26,7 @@ from typing import Any, Callable import drjax +from flax import nnx from flax import struct from flax.training import train_state import jax @@ -153,7 +154,15 @@ def add_diloco_dim(x): momentum=config.diloco_outer_momentum, nesterov=True, ) - outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params) + # For NNX, model params (Param variables only) live under abstract_state.model; + # for Linen under abstract_state.params. + if config.pure_nnx: + model_params = abstract_state.model.filter(nnx.Param) + model_params_sharding = state_mesh_shardings.model.filter(nnx.Param) + else: + model_params = abstract_state.params + model_params_sharding = state_mesh_shardings.params + outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params) # Create abstract step abstract_step = jax.ShapeDtypeStruct((), jnp.int32) @@ -161,7 +170,7 @@ def add_diloco_dim(x): # Build abstract DiLoCo state diloco_state = DiLoCoTrainState( inner_state=inner_state, - params=abstract_state.params, + params=model_params, outer_opt_state=outer_opt_state, step=abstract_step, ) @@ -171,12 +180,12 @@ def add_diloco_dim(x): # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) # We shard the momentum trace the same way as the parameters. 
outer_opt_state_sharding = ( - optax.TraceState(trace=state_mesh_shardings.params), + optax.TraceState(trace=model_params_sharding), optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, - params=state_mesh_shardings.params, + params=model_params_sharding, outer_opt_state=outer_opt_state_sharding, step=None, ) @@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: # mesh automatically when jax.set_mesh is used. inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. - outer_params = state.params + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. + outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params outer_opt_state = outer_optimizer.init(outer_params) outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step), + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), outer_opt_state_sharding, ) @@ -244,7 +257,11 @@ def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) - model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. 
+ inner_model_params = ( + nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params + ) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. averaged_pseudo_grad = drjax.reduce_mean(model_delta) @@ -253,7 +270,27 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) + if config.pure_nnx: + # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state). + def replace_nnx_model_params(s, new_params): + non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param)) + new_model = nnx.merge_state(non_param_model, new_params) + # Build result via __setitem__ so nested States are stored as plain dicts + # internally, matching the pytree structure produced by nnx.state(). + # (Passing State objects via the constructor dict literal stores them + # as-is, causing jax.lax.cond to see mismatched pytree structures.) 
+ result = type(s)({}) + result["model"] = new_model + result["optimizer"] = s["optimizer"] + return result + + new_inner_state = drjax.map_fn( + lambda s: replace_nnx_model_params(s, new_outer_params), + state.inner_state, + mesh=mesh, + ) + else: + new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng): broadcast_rng = drjax.broadcast(prng, mesh=mesh) inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) + # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step. + new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0] state = state.replace( inner_state=inner_state, - step=inner_state.step[0], + step=new_step, ) # Either synchronize the model, or no-op, depending on whether the current # step falls on the synchronization period. state = jax.lax.cond( - inner_state.step[0] % config.diloco_sync_period == 0, + new_step % config.diloco_sync_period == 0, synchronize, lambda x: x, # no-op state, diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 1a66a532fb..40b866d415 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -273,30 +273,45 @@ def wrt_filter(path, x): # Inherits _shard_optimizer from PeftTrainer. def _train_step(self, model, optimizer, inputs): - """Overrides the main JIT block to natively handle ModelBundle module.""" + """Overrides the main JIT block to natively handle ModelBundle module. 
+ Uses jax.value_and_grad with explicit split/merge to avoid nesting + nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign + conflicting outer_index values and raises: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. + """ batch = self.gen_model_input_fn(inputs) + student = model.student_model + teacher = model.teacher_model current_step = model.training_step[...] - def loss_wrapper(student, teacher, batch): - if "teacher_output" in batch: - teacher_output = batch["teacher_output"] - else: - teacher_output = self.strategy.teacher_forward_fn( - model=teacher, - input_tokens=batch["input_tokens"], - positions=batch["positions"], - attention_mask=batch.get("attention_mask"), - decoder_segment_ids=batch.get("decoder_segment_ids"), - decoder_target_tokens=batch.get("targets", None), - decoder_target_mask=batch.get("targets_segmentation", None), - cache=None, - ) + # Run teacher inference outside of value_and_grad. + # The teacher is frozen (stop_gradient), so its output is a constant + # from the perspective of the student gradient computation. + if "teacher_output" in batch: + teacher_output = batch["teacher_output"] + else: + teacher_output = self.strategy.teacher_forward_fn( + model=teacher, + input_tokens=batch["input_tokens"], + positions=batch["positions"], + attention_mask=batch.get("attention_mask"), + decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), + cache=None, + ) + teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) - teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) + # Split student into differentiable params and non-differentiable rest. + # Capture graphdef outside of jax.value_and_grad for stable graph tracking. + student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...) 
+ def loss_wrapper_pure(diff_params, rest): + local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True) student_output = self.strategy.student_forward_fn( - model=student, + model=local_student, input_tokens=batch["input_tokens"], positions=batch["positions"], attention_mask=batch.get("attention_mask"), @@ -305,29 +320,26 @@ def loss_wrapper(student, teacher, batch): decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) - # we should apply a mask for labels to disable segment-separator tokens labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) - return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) - - # Because student is the 0th argument, argnums=0 guarantees - # we only compute gradients for the student. - grad_fn = nnx.value_and_grad( - loss_wrapper, - argnums=nnx.DiffState(0, self.wrt_filter), - has_aux=True, - ) + loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) + # Capture updated non-param state (e.g. RNG counters) from local_student. + _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...) + return loss, (aux, new_rest) - out, grads = grad_fn(model.student_model, model.teacher_model, batch) + grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True) + (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) - model.training_step.set_value(current_step + 1) + # Propagate updated non-param state back to student. 
+ nnx.update(student, new_rest) - tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) + optimizer.update(student, grads) - optimizer.update(model.student_model, grads) + model.training_step.set_value(current_step + 1) + tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) if tunix_expects_grad_norm: - return out[0], out[1], optax.global_norm(grads) - return out[0], out[1] + return loss, aux, optax.global_norm(grads) + return loss, aux def _eval_step(self, model, inputs): """Evaluation only needs the student.""" diff --git a/src/maxtext/trainers/post_train/dpo/dpo_utils.py b/src/maxtext/trainers/post_train/dpo/dpo_utils.py index eeda1c1a7f..fd5faa5c9c 100644 --- a/src/maxtext/trainers/post_train/dpo/dpo_utils.py +++ b/src/maxtext/trainers/post_train/dpo/dpo_utils.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +from flax import nnx + from maxtext.utils import maxtext_utils @@ -148,6 +150,8 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility } return loss, aux @@ -155,3 +159,138 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t def _merge_dpo_state(state, reference_params): """Merge reference parameters back into DPO state.""" return state.replace(params=dict(state.params, reference_params=reference_params)) + + +# NNX DPO has no split/merge counterpart: the Linen path overlays +# `reference_params` inside `state.params`, so it must be peeled off and +# reattached around `apply_gradients`. The NNX path holds the reference as a +# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already +# only touches `self.model`, so no split/merge is needed. 
+ + +def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """NNX DPO loss_fn for both train and eval. + + Signature mirrors the Linen `dpo_loss_fn` so it slots into the same + dispatcher in `gradient_accumulation_loss_and_grad`: + `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)` + + Differences from the Linen `dpo_loss_fn`: + * `policy_model` is an `nnx.Module` (carries its own params + RNG state). + * `dropout_rng` and `params` are unused for NNX (kept positional for + signature parity; NNX models manage these internally). + * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference + `nnx.Module`, not a `reference_params` pytree. + * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with + `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows + to the reference's `nnx.Param` leaves. + + Args: + policy_model: Policy `nnx.Module` (the model being trained). + config: Config of parameters. + data: Batch of preference data with `chosen` / `rejected` fields. + dropout_rng: Unused for NNX (kept for signature parity with Linen). + params: Unused for NNX (kept for signature parity with Linen). + reference_model: Frozen reference `nnx.Module` for DPO logratio computation. + is_train: True for train_step and False for eval_step. + + Returns: + loss: DPO preference loss + MoE load balance loss (if applicable). + aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss, + total_weights, moe_lb_loss, reward_accuracy. 
+ """ + del dropout_rng, params # unused for NNX + # decimate proportion of data when per_device_batch_size<1 + if is_train: + for k, v in data.items(): + data[k] = v[: config.micro_batch_size_to_train_on, :] + + # for DPO we don't support packed sequences (they shouldn't be present in the first place) + data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) + data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) + data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) + data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) + + # concatenated policy/reference forward pass + inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) + inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) + inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) + + logits = policy_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=config.enable_dropout if is_train else False, + ) + intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict() + + ref_logits = reference_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=False, + ) + ref_logits = jax.lax.stop_gradient(ref_logits) + + # extract token ids, segmentation and logits for chosen and rejected sequences + chosen_ids = data["chosen"][..., 1:] + rejected_ids = data["rejected"][..., 1:] + chosen_segmentation = data["chosen_segmentation"][..., 1:] + rejected_segmentation = data["rejected_segmentation"][..., 1:] + n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] + chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] + chosen_ref_logits, rejected_ref_logits = 
ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] + + # common subsequence and padding mask + common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] + valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] + + # compute logratios from the sequence-reduced observed token log-probability + chosen_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_logratios = chosen_logps - chosen_ref_logps # [B] + + rejected_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_logratios = rejected_logps - rejected_ref_logps # [B] + + # DPO loss from chosen and rejected logratios + LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, config.dpo_beta + logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] + losses = ( # [B] + -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) + - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING + ) + total_loss, total_weights = jnp.mean(losses), losses.shape[0] + loss = total_loss + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = 
maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) + aux = { + "intermediate_outputs": intermediate_outputs, + "xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility + "dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training + "total_weights": total_weights, + "moe_lb_loss": moe_lb_loss, + "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility + } + return loss, aux diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 0af37dc10f..5e15697127 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -55,6 +55,42 @@ import os import pathwaysutils +# JAX 0.9+ changed with_sharding_constraint to assert (not reshard) when all +# mesh axes are Explicit. tpu_inference still expects resharding semantics. +# Patch: try the original (works for Auto axes); on AssertionError (Explicit +# mesh) fall back to jax.sharding.reshard. +_orig_wsc = jax.lax.with_sharding_constraint + + +def _compat_wsc(x, shardings): + try: + return _orig_wsc(x, shardings) + except AssertionError: + return jax.sharding.reshard(x, shardings) + + +jax.lax.with_sharding_constraint = _compat_wsc + +# tpu_inference JaxEinsum defaults param_dtype=float32, so tpu_inference model weights +# initialize as float32. During weight sync, tunix._apply_dtype_cast then upcasts the +# incoming bfloat16 MaxText weights → float32 to match the target. 
This leaves v_proj +# as float32 while k_proj output appears bfloat16 (due to k_norm dtype promotion), +# causing a dtype mismatch in the ragged paged attention kernel. +# Fix: skip bfloat16→float32 upcasts during weight sync so synced weights stay bfloat16. +import jax.numpy as _jnp +import tunix.generate.utils as _tunix_utils + +_orig_apply_dtype_cast = _tunix_utils._apply_dtype_cast # pylint: disable=protected-access + + +def _no_bf16_to_f32_cast(val, tgt_dtype, src_key): + if hasattr(val, "dtype") and val.dtype == _jnp.bfloat16 and tgt_dtype == _jnp.float32: + return val # keep bfloat16; tpu_inference model dtype is bfloat16 despite float32 init + return _orig_apply_dtype_cast(val, tgt_dtype, src_key) + + +_tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast # pylint: disable=protected-access + from absl import app from absl import logging as absl_logging from etils import epath @@ -418,6 +454,8 @@ def create_rl_components( "hf_overrides": trainer_config.vllm_hf_overrides, "enable_expert_parallel": sampler_config.enable_expert_parallel, "enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts + # Ensures vLLM model initializes with correct dtype (not float32 default) + "dtype": trainer_config.weight_dtype, }, rollout_vllm_sampling_kwargs={ "stop": trainer_config.stop_strings, @@ -563,7 +601,10 @@ def rl_train(argv: Sequence[str], kwargs: dict): max_train_steps = get_max_train_steps(trainer_config) # Create model tokenizer - model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path) + model_tokenizer = AutoTokenizer.from_pretrained( + trainer_config.tokenizer_path, + token=trainer_config.hf_access_token or None, + ) train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index c7c726cec9..a6c80d27dc 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ 
b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -35,7 +35,7 @@ eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16 """ -from typing import Sequence +from typing import Any, Sequence from absl import app import os @@ -43,6 +43,7 @@ import optax import pathwaysutils +from flax import nnx from flax.linen import partitioning as nn_partitioning from orbax import checkpoint as ocp @@ -68,6 +69,70 @@ from maxtext.utils import model_creation_utils +class MaxTextPeftTrainer(peft_trainer.PeftTrainer): + """MaxText-specific PeftTrainer that avoids nested NNX transformations. + + Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside + nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index + values to graph nodes, resulting in: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. + + This subclass overrides create_train_step_fn to use jax.value_and_grad + with an explicit split/merge pattern (matching MaxText's pre-training NNX + train_step), which avoids the nested NNX transformation issue entirely. + """ + + def create_train_step_fn(self): + """Creates a train step using jax.value_and_grad with explicit NNX split/merge.""" + loss_fn_ref = self.loss_fn + has_aux = self._has_aux + gen_fn = self.gen_model_input_fn + is_lora_enabled = self._lora_enabled + wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param + tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) + + # Capture the graphdef once outside of JIT so that split/merge inside + # jax.value_and_grad can use a stable (non-traced) structural descriptor. + graphdef, _, _ = nnx.split(self.model, wrt, ...) + + def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any): + inputs = gen_fn(inputs) + + # Split model into differentiable params and non-differentiable rest. 
+ # Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX + # transforms inside nnx.jit, which would corrupt outer_index tracking. + _, diff_params, rest = nnx.split(model, wrt, ...) + + def loss_wrapper(diff_params, rest, **inputs_kw): + local_model = nnx.merge(graphdef, diff_params, rest, copy=True) + out = loss_fn_ref(local_model, **inputs_kw) + # Capture updated non-param state (e.g. RNG counters) from local_model. + _, _, new_rest = nnx.split(local_model, wrt, ...) + if has_aux: + loss, aux = out + return loss, (aux, new_rest) + else: + return out, (None, new_rest) + + grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True) + (out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs) + + # Propagate updated non-param state (RNG counters, etc.) back to model. + nnx.update(model, new_rest) + + # Apply optimizer update. grads has the same nnx.State(wrt) structure + # as diff_params, which is compatible with optimizer.update. + optimizer.update(model, grads) + + aux_out = aux if has_aux else None + if tunix_expects_grad_norm: + return out_val, aux_out, optax.global_norm(grads) + return out_val, aux_out + + return train_step + + def get_tunix_config(mt_config): """Gets the Tunix training configurations from the MaxText config. 
@@ -109,6 +174,7 @@ def get_tunix_config(mt_config): checkpointing_options=checkpointing_options, metrics_logging_options=metrics_logging_options, profiler_options=profiler_options, + data_sharding_axis=tuple(mt_config.data_sharding), ) @@ -162,7 +228,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) # Provide rules context so 'norm' is translated to mesh axes during maybe_restore with nn_partitioning.axis_rules(mt_config.logical_axis_rules): - trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) + trainer = MaxTextPeftTrainer(model, optimizer, tunix_config) trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 1011563a7b..a30170f6f5 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -35,8 +35,9 @@ import jax import jax.numpy as jnp +from jax.sharding import NamedSharding -from flax import linen as nn +from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig @@ -60,7 +61,7 @@ from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common import metric_logger from maxtext.common.metric_logger import record_activation_metrics -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn +from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -68,9 +69,10 @@ from maxtext.utils import maxtext_utils from maxtext.utils import qk_clip_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx from maxtext.utils import 
train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad -from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss +from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss, vocab_tiling_nnx_loss _diag_modules = _cloud_diag() diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules @@ -92,11 +94,11 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr """loss_fn for both train and eval. Args: - model: A nn.Module + model: A nn.Module (Linen) or nnx.Module (NNX). config: Config of parameters data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout - params: Model params + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. + params: Model params (Linen); unused for NNX (params are part of the model). is_train: True for train_step and False for eval_step Returns: @@ -183,7 +185,7 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr xent_sum = jnp.sum(xent) total_z_loss = jnp.sum(z_loss) else: - # Flax NNX model + # Flax NNX model: forward pass, then pop Intermediates sown during it. 
logits = model( decoder_input_tokens=data["inputs"], decoder_positions=data["inputs_position"], @@ -194,9 +196,14 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr decoder_target_tokens=data["targets"], decoder_target_mask=data["targets_segmentation"], ) - intermediate_outputs = {} + intermediates = nnx.pop(model, nnx.Intermediate) + intermediate_outputs = intermediates.to_pure_dict() - if (config.use_indexer and not config.indexer_sparse_training) and is_train: + if config.num_vocab_tiling > 1: + hidden_state_key = ("decoder", "hidden_states") + hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] + xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) + elif (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. # The main model parameters are frozen and only the indexer is trained via KL divergence. xent_sum = 0.0 @@ -286,74 +293,121 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr return loss, aux -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """ +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng=None): + """Training step for both Linen and NNX models. Args: - model: A nn.Module - state: A pytree of the current state of the model - data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout + model: A nn.Module (Linen) or nnx.GraphDef of the TrainStateNNX (NNX). + config: Hyperparameters. + state_mesh_shardings: PyTree of PartitionSpecs for the train state. + params_shardings: PyTree of PartitionSpecs for model parameters, used for gradient accumulation. + state: Linen TrainState or NNX pure State. + data: Training data batch. 
+ dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. Returns: - new_state: Same format as state. + new_state: Updated Linen TrainState or NNX pure State. metrics: Dictionary of model metrics such as loss, training rate, etc. - rng2: A new rng key that can be used in future calls. - """ - reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = ( - [], - [], - [], - loss_fn, - ) - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn + # --- Per-path initialization --- + if isinstance(model, nn.Module): + reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = [], [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + params = state.params + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args + else: + state = nnx.merge(model, state) # reconstruct TrainStateNNX + if config.use_dpo: + # NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by + # init_initial_state when config.use_dpo=True). dpo_loss_fn_nnx mirrors + # the Linen dpo_loss_fn signature, so it slots into the same dispatcher + # with reference_model passed as the single extra_dpo_args entry. 
+ ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model]) + else: + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] - params = state.params + # --- Gradient computation --- if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( - _loss_fn, + ga_fn, config, - model, - params, + ga_model, + ga_params, params_shardings, data, - dropout_rng, - extra_dpo_args, + ga_rng, + ga_dpo, ) else: - if config.optimizer_memory_host_offload: - if config.use_dpo: + if isinstance(model, nn.Module): + if config.optimizer_memory_host_offload and config.use_dpo: reference_params = jax.device_put( reference_params, max_utils.with_memory_kind(reference_params_sharding, "device"), ) extra_dpo_args = [reference_params] - if config.shard_optimizer_over_data: - params = jax.tree.map( - functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), - params, - params_shardings, + if config.shard_optimizer_over_data: + params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + params, + params_shardings, + ) + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + pure_params = params["params"] if sparsity_enabled else params + batch_stats = params.get("batch_stats", {}) + + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func( + model, + config, + data, + dropout_rng, + pure_params, + *extra_dpo_args, + sparsity_state=batch_stats, + is_train=True, ) - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = params["params"] if sparsity_enabled else params - batch_stats = params.get("batch_stats", {}) + else: + model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + if config.parameter_memory_host_offload: + # Params are kept on host (pinned_host) in in_shardings. 
Move only Param + # variables to device before the forward/backward pass so that all dot_general + # operands share the same memory space (XLA on GPU requires this). + # Using params_shardings (Param-only) avoids Shardy rank mismatches that + # occur when applying PartitionSpec() (rank-0 in SDY) to rank-1 RNG key tensors. + device_param_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + params_shardings, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + curr_params = jax.device_put(curr_params, device_param_shardings) + nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update + if config.shard_optimizer_over_data: + curr_params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + curr_params, + params_shardings, + ) + nnx.update(state.model, curr_params) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func( - model, - config, - data, - dropout_rng, - pure_params, - *extra_dpo_args, - sparsity_state=batch_stats, - is_train=True, - ) + # `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx; + # ga_dpo carries the frozen reference_model when use_dpo, else empty). + _nnx_loss_fn = ga_fn + _nnx_extra_dpo_args = ga_dpo + + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(model_graphdef, param, rest, copy=True) + loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) 
+ return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -364,6 +418,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat raw_grads, max_utils.with_memory_kind(params_shardings, "device"), ) + + # Extract aux fields into locals intermediate_outputs = aux["intermediate_outputs"] xent_sum = aux["xent_sum"] total_weights = aux["total_weights"] @@ -373,67 +429,90 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) - if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) - else: - grads = raw_grads - - if config.optimizer_memory_host_offload: - state = state.replace( - opt_state=jax.device_put( - state.opt_state, - jax.tree_util.tree_map( - lambda x: x.with_memory_kind(kind="device"), - state_mesh_shardings.opt_state, - ), - ) - ) - # Move all parameters to device before optimizer update - if config.parameter_memory_host_offload: - max_logging.log("\nMoving all parameters to device before optimizer update") - - def move(path, value): - max_logging.log(f"train.py: Moving f{path} to device") - return value.with_memory_kind(kind="device") - - state = state.replace( - params=jax.device_put( - state.params, - jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), - ) - ) - # Re-wrap grads to match state.params structure if it's a dict of collections - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled: - full_grads = {"params": grads} - if sparsity_enabled and "batch_stats" in state.params: - batch_stats_grads = 
jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) - full_grads["batch_stats"] = batch_stats_grads - full_grads = max_utils.unbox_logicallypartioned(full_grads) - else: - full_grads = grads - - if getattr(config, "skip_step_on_spikes", False): - grad_norm = max_utils.l2norm_pytree(grads) - # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. - updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) - new_params = optax.apply_updates(state.params, updates) - - new_state = state.replace( - step=state.step + 1, - params=new_params, - opt_state=new_opt_state, - ) + if isinstance(model, nn.Module): + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + state = state.replace( + opt_state=jax.device_put( + state.opt_state, + jax.tree_util.tree_map( + lambda x: x.with_memory_kind(kind="device"), + state_mesh_shardings.opt_state, + ), + ) + ) + # Move all parameters to device before optimizer update + if config.parameter_memory_host_offload: + max_logging.log("\nMoving all parameters to device before optimizer update") + + def move(path, value): + max_logging.log(f"train.py: Moving f{path} to device") + return value.with_memory_kind(kind="device") + + state = state.replace( + params=jax.device_put( + state.params, + jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), + ) + ) + # Re-wrap grads to match state.params structure if it's a dict of collections + # (when weight_sparsity is enabled, params has both 'params' and 'batch_stats' keys). 
+ sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + full_grads = {"params": grads} + if "batch_stats" in state.params: + batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) + full_grads["batch_stats"] = batch_stats_grads + full_grads = max_utils.unbox_logicallypartioned(full_grads) + else: + full_grads = grads + + if getattr(config, "skip_step_on_spikes", False): + grad_norm = max_utils.l2norm_pytree(grads) + # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. + updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) + new_params = optax.apply_updates(state.params, updates) + + new_state = state.replace( + step=state.step + 1, + params=new_params, + opt_state=new_opt_state, + ) + else: + new_state = state.apply_gradients(grads=full_grads) + + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") + # Updates the shape to be aligned with state. + moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() + new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) else: - new_state = state.apply_gradients(grads=full_grads) + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + # state.optimizer is an NNX Optimizer module; state_mesh_shardings.optimizer + # is an NNX State. Use nnx.state() to get a compatible State for device_put. 
+ device_opt_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + opt_state = nnx.state(state.optimizer) + new_opt_state = jax.device_put(opt_state, device_opt_shardings) + nnx.update(state.optimizer, new_opt_state) + state.apply_gradients(grads) + new_state = state - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Flax 'sow' returns a tuple, so we take the first element [0]. - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias + target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -447,10 +526,11 @@ def move(path, value): "learning/total_weights": total_weights, } if config.use_qk_clip: - # Apply QK-Clip - new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + if isinstance(model, nn.Module): + new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + else: + new_state = qk_clip_utils.apply_qk_clip_nnx(new_state, intermediate_outputs, config) - # Report max_logits metric global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) if global_max_logit is not None: scalar_metrics["learning/max_logits"] = global_max_logit @@ -458,7 +538,11 
@@ def move(path, value): if not config.optimizer_memory_host_offload: scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + if isinstance(model, nn.Module): + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + else: + model_params = nnx.state(new_state.model, nnx.Param) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(model_params) if config.use_dpo: scalar_metrics["learning/dpo_loss"] = aux["dpo_loss"] scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] @@ -466,31 +550,41 @@ def move(path, value): "scalar": scalar_metrics, "scalars": {}, } - if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) - if config.use_dpo: - new_state = _merge_dpo_state(new_state, reference_params) - - return new_state, metrics + if isinstance(model, nn.Module): + if config.use_dpo: + new_state = _merge_dpo_state(new_state, reference_params) + return new_state, metrics + # Exclude Intermediate variables (e.g., sowed max_logits for QK-Clip) from the + # returned state. Intermediates are transient forward-pass artifacts and must not + # persist across steps: they're absent from the abstract state used to build + # state_mesh_shardings, so including them would cause a leaf-count mismatch in JAX. 
+ return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics -def eval_step(model, config, state, data, dropout_rng): +def eval_step(model, config, state, data, dropout_rng=None): """eval_step no backprop and new state compared with train_step.""" + if isinstance(model, nn.Module): + reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn - reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = state.params["params"] if sparsity_enabled else state.params - batch_stats = state.params.get("batch_stats", {}) + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + pure_params = state.params["params"] if sparsity_enabled else state.params + batch_stats = state.params.get("batch_stats", {}) - eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) + eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) + else: + state = nnx.merge(model, state) # reconstruct TrainStateNNX + if config.use_dpo: + loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + else: + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -518,7 +612,7 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - if config.use_dpo: + if isinstance(model, 
nn.Module) and config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -540,32 +634,46 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if isinstance(model, nn.Module): + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + jit_model = model + else: + if config.use_dpo: + raise NotImplementedError("DPO is not supported for NNX models.") + jit_model, state = nnx.split(state) params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + jit_model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) shaped_batch = maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng)) + elif 
config.shard_optimizer_over_data: + # NNX: reshard state so params match the data-sharded in_shardings (Zero-1 layout) + state = jax.device_put(state, state_mesh_shardings) + if isinstance(model, nn.Module): + lower_args = (state, shaped_batch, init_rng) + else: + lower_args = (state, shaped_batch) + maxtext_utils.maybe_dump_jaxpr(config, p_train_step, lower_args) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled = p_train_step.lower(*lower_args).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) @@ -574,7 +682,11 @@ def train_loop(config, recorder, state=None): metric_logger_instance = metric_logger.MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger_instance.write_setup_info_to_tensorboard(state.params) + if isinstance(model, nn.Module): + setup_params = state.params + else: + _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) 
+ metric_logger_instance.write_setup_info_to_tensorboard(setup_params) _job_completed_gracefully = False try: @@ -584,62 +696,65 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + if isinstance(model, nn.Module): + # pylint: disable=not-callable + step_rng_args = (jax.jit(jax.random.fold_in)(init_rng, step),) + else: + step_rng_args = () with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - state, metrics = p_train_step(state, example_batch, nextrng) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. 
- gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - # Explicitly reset the eval iterator and counters before starting the eval loop - eval_data_iterator.reset() - metric_logger_instance.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - # Shard input eval data - eval_batch = jax.device_put(eval_batch, sharding.get_input_data_sharding(config, mesh)) - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger_instance.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger_instance.record_eval_metrics(step, eval_step_count=eval_step_count) - if metric_logger_instance.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) + state, metrics = p_train_step(state, example_batch, *step_rng_args) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, 
data_iterator, step) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. + gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger_instance.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + # Shard input eval data + eval_batch = jax.device_put(eval_batch, sharding.get_input_data_sharding(config, mesh)) + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) + metric_logger_instance.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger_instance.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger_instance.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) 
else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) if checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index a2981f67ed..c593d3c540 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -30,6 +30,7 @@ from flax import nnx from flax.linen import partitioning as nn_partitioning import jax +import jax.numpy as jnp from jax.experimental.serialize_executable import serialize from jax.experimental.topologies import get_topology_desc from jax.sharding import AxisType, Mesh @@ -92,6 +93,27 @@ def get_topology_mesh(config): return topology_mesh +def _collect_nnx_activation_shardings(create_model_fn, config, mesh): + """Run an NNX forward pass in abstract mode to populate _ACTIVATION_SHARDINGS_DUMP. + + get_abstract_state_nnx uses nnx.eval_shape which only traces model initialization, + not __call__. Activation shardings are only collected during a forward pass. 
+ """ + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + + def _nnx_forward(): + model_instance = create_model_fn() + return model_instance( + decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32), + decoder_positions=jnp.ones(input_shape, dtype=jnp.int32), + decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32), + enable_dropout=False, + ) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + jax.eval_shape(_nnx_forward) + + def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state @@ -129,7 +151,8 @@ def create_train_state_fn(): # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). with nn_partitioning.axis_rules(config.logical_axis_rules): - graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + abs_train_state = nnx.eval_shape(init_state_fn) + graphdef, _ = nnx.split(abs_train_state) model = graphdef else: # unsharded logical annotations @@ -139,10 +162,17 @@ def create_train_state_fn(): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.pure_nnx: - shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng + shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng else: shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} + + # Collect activation shardings for NNX by running an abstract forward pass. + # This must happen after get_abstract_state (which uses nnx.eval_shape and only + # traces __init__, not __call__). 
+ if config.debug_sharding and config.pure_nnx: + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -280,7 +310,9 @@ def main(argv: Sequence[str]) -> None: diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( config, abstract_state, state_mesh_shardings, topology_mesh ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. + shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 2fd14b87a2..2292a073d0 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -28,13 +28,16 @@ from absl import app from etils import epath +from flax import nnx import jax +import jax.numpy as jnp from jax import random from jax.sharding import Mesh from maxtext.configs import pyconfig from maxtext.common import checkpointing from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN from maxtext.layers import quantizations +from maxtext.layers import train_state_nnx from maxtext.models import models from maxtext.optimizers import optimizers from maxtext.utils import gcs_utils @@ -42,12 +45,18 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils +from maxtext.utils import train_utils def 
_possibly_unroll_params(config, training_state, training_state_annotations, mesh): """Unroll scanned input layers when force_unroll is set.""" if not config.scan_layers or not config.force_unroll: return + if config.pure_nnx: + _possibly_unroll_params_nnx(config, training_state, training_state_annotations, mesh) + return def unroll_layer_group(num_layers, layer_name="layers"): """Helper function to unroll layers (e.g. dense or MoE) into individual layers.""" @@ -88,25 +97,85 @@ def slice_ith(input_layers): unroll_layer_group(config.num_decoder_layers, layer_name="layers") +def _possibly_unroll_params_nnx(config, state, state_mesh_shardings, mesh): + """NNX equivalent of _possibly_unroll_params. + + `state` is a flat `nnx.State` (post-split TrainStateNNX) with `state.model` + as a sub-State whose tree mirrors the model module hierarchy. Slices + `state.model.decoder[layer_name]` into per-index `layer_name_0..N` siblings + and removes the original collection. Mirrors the same operation on + `state_mesh_shardings` so downstream sharding stays correct. 
+ """ + decoder_state = state.model.decoder + decoder_shardings = state_mesh_shardings.model.decoder + + def unroll_layer_group(num_layers, layer_name="layers"): + layers = decoder_state.get(layer_name, None) + layers_shardings = decoder_shardings.get(layer_name, None) + if layers is None or layers_shardings is None: + raise ValueError(f"Missing {layer_name} in NNX state.model.decoder or state_mesh_shardings.") + + def drop_scan_axis(named_sharding): + ps = named_sharding.spec + return jax.sharding.PartitionSpec(*(ps[0 : config.param_scan_axis] + ps[config.param_scan_axis + 1 :])) + + new_layer_pspec = jax.tree_util.tree_map( + drop_scan_axis, layers_shardings, is_leaf=lambda x: isinstance(x, jax.sharding.NamedSharding) + ) + new_layer_sharding = jax.tree_util.tree_map(lambda ps: jax.sharding.NamedSharding(mesh, ps), new_layer_pspec) + + for i in range(num_layers): + + def slice_ith(input_layers): + return jax.tree_util.tree_map(lambda x: jnp.take(x, i, axis=config.param_scan_axis), input_layers) + + # pylint: disable=not-callable + new_layer = jax.jit(slice_ith, out_shardings=new_layer_sharding)(layers) + + decoder_state[f"{layer_name}_{i}"] = new_layer + decoder_shardings[f"{layer_name}_{i}"] = new_layer_sharding + + decoder_state.pop(layer_name) + decoder_shardings.pop(layer_name) + jax.tree_util.tree_map(lambda x: x.delete() if hasattr(x, "delete") else None, layers) + + if config.decoder_block == DecoderBlockType.DEEPSEEK: + unroll_layer_group(config.first_num_dense_layers, layer_name="dense_layers") + unroll_layer_group(config.num_decoder_layers - config.first_num_dense_layers, layer_name="moe_layers") + else: + unroll_layer_group(config.num_decoder_layers, layer_name="layers") + + def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition - quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure 
NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) - learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) - tx = optimizers.get_optimizer(config, learning_rate_schedule) if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=rng) + model = model_creation_utils.from_config(config, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(config, model) + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + else: + quant = quantizations.configure_quantization(config) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) + tx = optimizers.get_optimizer(config, learning_rate_schedule) init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) - state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( + + state, state_mesh_notations, state_mesh_shardings, _ = maxtext_utils.setup_training_state( None, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + # On NNX, state is a flat nnx.State; params live under state.model and the + # legacy notations are unused (callers receive shardings directly). 
+ num_params = max_utils.calculate_num_params_from_pytree(state.model) + max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") + return state, state_mesh_shardings num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") return state, state_mesh_notations @@ -114,12 +183,11 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition - quant = quantizations.configure_quantization(config) if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + _generate_lora_decode_checkpoints_nnx(config, mesh) + return + quant = quantizations.configure_quantization(config) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) @@ -154,6 +222,9 @@ def _generate_lora_decode_checkpoints(config, mesh): def _save_decode_checkpoint(config, state, checkpoint_manager): """Generate checkpoint for decode from the training_state.""" + if config.pure_nnx: + _save_decode_checkpoint_nnx(config, state, checkpoint_manager) + return decode_state = maxtext_utils.init_decode_state( None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params) ) @@ -163,6 +234,121 @@ def _save_decode_checkpoint(config, state, checkpoint_manager): checkpoint_manager.wait_until_finished() +def _save_decode_checkpoint_nnx(config, state, checkpoint_manager): + """Save a bf16 NNX-format param-only decode checkpoint. 
+ + The on-disk shape mirrors what a vanilla NNX-trained checkpoint produces: a + plain dict tree of arrays (one per nnx.Param), with no Linen-style "params" + wrapper. This is the shape `from_pretrained` reads via its NNX-detection + branch (see model_creation_utils._adjust_target_for_moe_fusion / "is_nnx_checkpoint"). + """ + pure_model = state.model.to_pure_dict() if hasattr(state.model, "to_pure_dict") else dict(state.model) + bf16_model = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pure_model) + if checkpoint_manager is not None: + if checkpointing.save_checkpoint(checkpoint_manager, 0, bf16_model): + max_logging.log(f"saved an NNX decode checkpoint at {config.checkpoint_dir}") + checkpoint_manager.wait_until_finished() + + +def _possibly_unroll_lora_params_nnx(config, lora_state, lora_state_annotations, mesh): + """Unroll scanned LoRA delta layers when force_unroll is set on the NNX path. + + `lora_state` is a Linen-style `TrainState` (returned by `get_lora_abstract_state_nnx`) + whose `.params` is single-nested (`{"decoder": {...}}`, no outer `params` wrap) + and whose leaves at target attention paths are `lora_a.kernel`/`lora_b.kernel`. + """ + if not config.scan_layers or not config.force_unroll: + return + + decoder_params = lora_state.params["decoder"] + decoder_annotations = lora_state_annotations.params["decoder"] + + def unroll_layer_group(num_layers, layer_name="layers"): + layers = decoder_params.get(layer_name) + layers_annotations = decoder_annotations.get(layer_name) + if layers is None or layers_annotations is None: + return # No LoRA on this layer group; nothing to unroll. 
+ + def new_pspec(x): + return jax.sharding.PartitionSpec(*(x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :])) + + new_layer_annotation = jax.tree_util.tree_map(new_pspec, layers_annotations) + new_layer_sharding = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_layer_annotation) + + for i in range(num_layers): + + def slice_ith(input_layers): + return jax.tree_util.tree_map(lambda x: jnp.take(x, i, axis=config.param_scan_axis), input_layers) + + # pylint: disable=not-callable + new_layer = jax.jit(slice_ith, out_shardings=new_layer_sharding)(layers) + decoder_params[f"{layer_name}_{i}"] = new_layer + decoder_annotations[f"{layer_name}_{i}"] = new_layer_annotation + + del decoder_params[layer_name] + del decoder_annotations[layer_name] + jax.tree_util.tree_map(lambda x: x.delete() if hasattr(x, "delete") else None, layers) + + if config.decoder_block == DecoderBlockType.DEEPSEEK: + unroll_layer_group(config.first_num_dense_layers, layer_name="dense_layers") + unroll_layer_group(config.num_decoder_layers - config.first_num_dense_layers, layer_name="moe_layers") + else: + unroll_layer_group(config.num_decoder_layers, layer_name="layers") + + +def _save_lora_decode_checkpoint_nnx(config, lora_state, checkpoint_manager): + """Save a bf16 LoRA-only decode checkpoint (NNX path). + + `lora_state.params` is single-nested (NNX-derived shape). The on-disk + format mirrors the Linen LoRA decode shape so existing serving consumers + can keep reading it: a `TrainState` wrapper with `params` set to the + bf16-cast LoRA delta tree. The base model is loaded separately at serve + time via `apply_lora_on_base_params_nnx` (PR8). 
+ """ + decode_state = maxtext_utils.init_decode_state( + None, jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), lora_state.params) + ) + if checkpoint_manager is not None: + if checkpointing.save_checkpoint(checkpoint_manager, 0, decode_state): + max_logging.log(f"saved a LoRA decode checkpoint at {config.checkpoint_dir}") + checkpoint_manager.wait_until_finished() + + +def _generate_lora_decode_checkpoints_nnx(config, mesh): + """NNX-shaped sibling of `_generate_lora_decode_checkpoints`. + + Builds the NNX abstract base model so `setup_initial_lora_state` (PR8) + produces an NNX-derived `lora_state`, then runs an NNX-shape unroll/save. + """ + rng = random.PRNGKey(0) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=rng) + model = model_creation_utils.from_config(config, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(config, model) + + lora_adapters = gcs_utils.gcs_list_directories(config.lora_input_adapters_path) + for lora_id in lora_adapters: + lora_checkpoint_dir = os.path.join(config.checkpoint_dir, "loras", lora_id, "") + lora_adapter_path = os.path.join(config.lora_input_adapters_path, lora_id, "") + + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( + lora_checkpoint_dir, + config.enable_checkpointing, + config.async_checkpointing, + config.checkpoint_period, + ) + + lora_config, lora_state, lora_state_annotations = lora_utils.setup_initial_lora_state( + model, None, tx, config, rng, mesh, checkpoint_manager, lora_adapter_path + ) + + _possibly_unroll_lora_params_nnx(config, lora_state, lora_state_annotations, mesh) + + gcs_utils.write_dict_to_gcs_json(lora_config, os.path.join(lora_checkpoint_dir, "adapter_config.json")) + + _save_lora_decode_checkpoint_nnx(config, lora_state, checkpoint_manager) + max_logging.log(f"Successfully saved LoRA checkpoint at: {os.path.join(lora_checkpoint_dir, '0', 'items')}") + + def generate_decode_checkpoint(config): """ Generate an decode checkpoint from a given 
training checkpoint. diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index 9bad1cfb35..cf84577dbd 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from jax.sharding import NamedSharding +from flax import nnx from maxtext.common.common_types import ShardMode from maxtext.utils.sharding import maybe_shard_with_name @@ -49,7 +50,8 @@ def gradient_accumulation_loss_and_grad( config: Model and training configuration object. Must contain `gradient_accumulation_steps` and `shard_optimizer_over_data`. model: The model module. - params: The model parameters (PyTree). + params: The model parameters (PyTree). This is only used for Linen. For NNX, + we can get the params from the model. params_shardings: The sharding constraints for the parameters (PyTree). data: A PyTree of batched data. The leading dimension is assumed to be the total batch size (microbatch_size * num_accumulations). @@ -67,12 +69,24 @@ def _maybe_shard_with_name(inputs, sharding_names): """Wrapper of maybe_shard_with_name with fixed shard_mode""" return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding) - # For more efficient DP/ZeRO-1 + GA - if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: - ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) - grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + is_nnx = isinstance(model, nnx.Module) + + # For more efficient DP/ZeRO-1 + GA. + # config.ici_data_parallelism may be -1 (auto-fill: resolved at mesh creation time, but + # the config field remains -1). Treat any value != 1 as "data parallelism is active". 
+ if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # jax.lax.scan traces its body with an AbstractMesh where all axis types are Auto, + # which rejects reduced/unreduced PartitionSpec in scan carry tensors (raises ValueError). + # Use plain params_shardings for ga_params and init_grad in the carry. + # The all-reduce for data parallelism is applied to raw_grads after the scan instead. + ga_params_shardings = params_shardings + grad_shardings = params_shardings else: ga_params_shardings = grad_shardings = params_shardings + + if is_nnx: + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints # so that all-gather is done once in the lower precision before the gradient accumulation loop if config.shard_optimizer_over_data: @@ -87,11 +101,27 @@ def convert_to_bf16(param): ga_params = params ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + if is_nnx: + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + else: + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) def accumulate_gradient(acc_grad_and_loss, data): ga_params = acc_grad_and_loss["ga_params"] - (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) + if is_nnx: + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry. + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) 
+ acc_grad_and_loss["rest_state"] = next_rest_state + else: + rng = ( + jax.random.fold_in(dropout_rng, acc_grad_and_loss["total_weights"].astype(jnp.int32)) + if dropout_rng is not None + else None + ) + (_, aux), cur_batch_gradient = grad_func(model, config, data, rng, ga_params, *extra_dpo_args, is_train=True) acc_grad_and_loss["loss"] += aux["xent_sum"] + aux.get("dpo_loss", 0.0) acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] acc_grad_and_loss["indexer_loss"] += aux["indexer_loss"] @@ -119,6 +149,8 @@ def reshape_to_microbatch_accumulations(batch_arr): "mtp_loss": 0.0, "ga_params": ga_params, } + if is_nnx: + init_grad_and_loss["rest_state"] = rest grad_and_loss, aux = jax.lax.scan( accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps @@ -130,10 +162,18 @@ def reshape_to_microbatch_accumulations(batch_arr): + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps ) raw_grads = grad_and_loss["grad"] + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # Apply unreduced annotation after the scan to trigger all-reduce across data-parallel + # devices (reduced/unreduced cannot be used inside jax.lax.scan carry tensors). 
+ unreduced_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, unreduced_shardings) raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings) raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + if is_nnx: + nnx.update(model, grad_and_loss["rest_state"]) + return loss, aux, raw_grads diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 29fa928656..d9cb997b01 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -47,6 +47,8 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils import orbax.checkpoint as ocp from tqdm import tqdm from maxtext.configs import pyconfig @@ -164,28 +166,29 @@ def __init__(self, config: Any, rng: PRNGKeyType): self.config = config self.rng = rng - # TODO(ranlihao): Remove this assertion once the Layerwise quantization is supported for other decoder blocks. - assert ( - config.decoder_block == common_types.DecoderBlockType.DEEPSEEK - ), f"Layerwise quantization is only supported for {common_types.DecoderBlockType.DEEPSEEK}\ - , but got {config.decoder_block}." + # The Linen path runs layer-by-layer (memory-efficient for big DeepSeek + # models) and is DeepSeek-specific because it relies on the per-layer + # `DeepSeek*ToLinen` wrappers. The NNX path runs whole-model convert + # forward and is model-agnostic — see `_load_and_quantize_nnx`. 
+ if not config.pure_nnx: + assert config.decoder_block == common_types.DecoderBlockType.DEEPSEEK, ( + f"Linen layerwise quantization only supports {common_types.DecoderBlockType.DEEPSEEK}, " + f"got {config.decoder_block}." + ) # Mesh definition devices_array = maxtext_utils.create_device_mesh(config=config) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and quantization config self.quant = quantizations.configure_quantization(config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + if config.pure_nnx: + # NNX takes a separate code path that builds the model via from_pretrained; + # no Linen abstract-state bookkeeping is needed here. + self.unboxed_abstract_state = None + return + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) @@ -193,6 +196,9 @@ def load_and_quantize(self) -> None: """ Load parameters layer by layer and quantize them. 
""" + if self.config.pure_nnx: + self._load_and_quantize_nnx() + return quantized_params = {} quantized_params["params"] = {"decoder": {}} quantized_params["aqt"] = {"decoder": {}} @@ -278,6 +284,131 @@ def model_apply(_p, _rng, layer): maxtext_utils.save_quantized_checkpoint_if_configured(self.config, quantized_params) + def _load_and_quantize_nnx(self) -> None: + """Whole-model NNX convert: load full-precision via TRAIN-mode `from_pretrained`, + transfer kernels into a fresh CONVERT-mode model, run a forward (the + `ToNNX(AqtDotGeneral)` bridge auto-captures `qrhs.frozen`), strip kernels at + quantized paths, and save the serve-mode-shaped state. + + Two-step load: input checkpoints are typically full-precision (no AQT state + on disk), so we can't `from_pretrained(quant_mode_str="convert")` directly — + orbax would fail to find the missing `qrhs.frozen` leaves. Instead we load + in TRAIN mode (which has only kernels), then copy them into a randomly + initialized CONVERT model that already has the AQT variables provisioned. + """ + config = self.config + # MODEL_MODE_TRAIN avoids the PREFILL/AUTOREGRESSIVE cache plumbing — AQT + # layers populate `qrhs.frozen` regardless of model_mode, so train mode is + # simpler and faster. 
+ max_logging.log("Loading full-precision NNX checkpoint in TRAIN mode...") + with self._mesh: + train_model = model_creation_utils.from_pretrained( + config, + mesh=self._mesh, + model_mode=common_types.MODEL_MODE_TRAIN, + quant_mode_str="train", + ) + + max_logging.log("Building CONVERT-mode model (random init) and copying kernels in...") + rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=self.rng) + with nn_partitioning.axis_rules(config.logical_axis_rules): + convert_model = model_creation_utils.from_config( + config, + mesh=self._mesh, + rngs=rngs, + model_mode=common_types.MODEL_MODE_TRAIN, + quant_mode_str="convert", + ) + self._copy_kernel_leaves_(convert_model, train_model) + del train_model + + # Forward populates AqtDotGeneral_0.qrhs.frozen on every quantized layer. + L = config.max_target_length + decoder_input_tokens = jnp.zeros((1, L), dtype=jnp.int32) + decoder_positions = jnp.arange(L, dtype=jnp.int32)[None, :] + decoder_segment_ids = jnp.ones((1, L), dtype=jnp.int32) + max_logging.log("Running CONVERT-mode forward to populate AQT scale factors...") + with nn_partitioning.axis_rules(config.logical_axis_rules): + _ = convert_model( + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_TRAIN, + ) + + # Convert-mode state has both `kernel` (full precision) and `AqtDotGeneral_0.qrhs.frozen` + # at every quantized DenseGeneral; the serve-mode reader expects only the latter. + convert_state = nnx.state(convert_model).to_pure_dict() + serve_state = self._strip_kernels_at_quantized_paths(convert_state) + + if config.save_quantized_params_path: + max_logging.log(f"Saving NNX-format quantized checkpoint to {config.save_quantized_params_path}") + # Wrap each leaf in `{"value": }` so the on-disk shape matches what + # `from_pretrained`'s NNX-detection branch reads back (it later does + # `tree.map(lambda v: v["value"], ...)` on each leaf). 
Save directly via + # orbax — `save_params_to_path` would add an outer `{"params": ...}` wrap + # that the NNX path doesn't expect. + def _wrap_value(node): + if isinstance(node, dict): + return {k: _wrap_value(v) for k, v in node.items()} + return {"value": node} + + wrapped = _wrap_value(serve_state) + orbax_checkpointer = ocp.PyTreeCheckpointer( + use_ocdbt=config.checkpoint_storage_use_ocdbt, + use_zarr3=config.checkpoint_storage_use_zarr3, + ) + orbax_checkpointer.save(config.save_quantized_params_path, wrapped, force=True) + max_logging.log(f"Saved NNX-format quantized checkpoint at: {config.save_quantized_params_path}") + else: + max_logging.log("Skipping save: save_quantized_params_path is null.") + + @staticmethod + def _copy_kernel_leaves_(dst_model, src_model): + """Copy the full-precision parameter leaves (kernel/embedding/scale/bias) + from src into dst, leaving dst's AQT and RNG variables untouched. + """ + src_dict = nnx.state(src_model).to_pure_dict() + dst_state = nnx.state(dst_model) + dst_dict = dst_state.to_pure_dict() + + def walk(d_node, s_node): + if not (isinstance(d_node, dict) and isinstance(s_node, dict)): + return + for key, d_child in d_node.items(): + if key not in s_node: + continue + s_child = s_node[key] + if key in ("kernel", "embedding", "scale", "bias") and not isinstance(d_child, dict): + d_node[key] = s_child + elif isinstance(d_child, dict): + walk(d_child, s_child) + + walk(dst_dict, src_dict) + nnx.replace_by_pure_dict(dst_state, dst_dict) + nnx.update(dst_model, dst_state) + + @staticmethod + def _strip_kernels_at_quantized_paths(state_dict): + """Drop `kernel` keys at any node that has a sibling `AqtDotGeneral_0`. + + In convert mode each quantized DenseGeneral keeps both the full-precision + `kernel` (an nnx.Param) and the AQT-quantized `AqtDotGeneral_0.qrhs.frozen` + side-by-side. 
Serve mode (the on-disk shape `from_pretrained` reads back) + only carries the latter; the kernel is recreated as a dummy zero in + `linears.DenseGeneral.__call__`. + """ + if not isinstance(state_dict, dict): + return state_dict + has_aqt = "AqtDotGeneral_0" in state_dict + out = {} + for k, v in state_dict.items(): + if k == "kernel" and has_aqt: + continue + out[k] = LayerwiseQuantization._strip_kernels_at_quantized_paths(v) if isinstance(v, dict) else v + return out + def _load_layer(self, layer_name): """Loads a specific layer's parameters from the checkpoint.""" diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 24099ef22a..a49f4c1f25 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Common LoRA utils needed to support LoRA adapters.""" +"""Common LoRA utils needed to support LoRA adapters.""" +from collections.abc import Mapping from functools import partial import json @@ -29,6 +30,10 @@ from maxtext.utils import maxtext_utils from maxtext.utils import max_logging +# NNX-only imports (`flax.nnx`, `train_state_nnx`, `model_creation_utils`) are +# loaded lazily inside the NNX dispatch branches so the Linen-only flow's +# import chain stays identical to pre-PR8. + def apply_lora_on_base_params(base_params, lora_params, lora_scale_factor=1.0): """ @@ -109,8 +114,10 @@ def unapply_lora_recursively(base_params, lora_params, module_name): def load_adapter(config, base_abstract_state_params, adapter_config_path, adapter_weights_path): - """ - Load the LoRA weights into a PyTree and return it. + """Load LoRA weights into a PyTree and return it. + + On the NNX path, `base_abstract_state_params` and the returned `lora_params` + are `nnx.State`-shaped (no outer `{"params": ...}` wrap). 
""" # Load LoRA weights lora_params = None @@ -128,7 +135,10 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte if not gcs_utils.gcs_path_exists(f"{adapter_weights_path}/commit_success.txt"): raise FileNotFoundError(f"Failed to read lora_weights from {adapter_weights_path}.") - lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) + if config.pure_nnx: + lora_state, _ = get_lora_abstract_state_nnx(base_abstract_state_params, lora_config) + else: + lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) with nn_partitioning.axis_rules(config.logical_axis_rules): lora_params = checkpointing.load_params_from_path( @@ -143,22 +153,12 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, lora_adapter_path): - """We initialize the model and optimizer state, and optionally load from a - checkpoint as necessary. + """Initialize the LoRA train state and optionally load weights from disk. - Args: - model: the flax model to initialize - tx: the optax.GradientTransformation - config: config object - rng: jax.prng key - mesh: jax.devices() mesh - checkpoint_manager: an Orbax checkpointing.CheckpointManager object - lora_adapter_path: Path of the LoRA adapter which is expected to have - `adapter_config.json` and adapter weights - - Returns: - state: the initialized train state - state_mesh_annotations: the mesh annotations for the train state + Returns `(lora_config, lora_state, lora_state_annotations)`. On the NNX path + `model` is unused (the NNX abstract state is built via + `model_creation_utils.create_nnx_abstract_model`) and `lora_state.params` + is `nnx.State`-shaped; on Linen it is the original `{"params": ...}` tree. 
""" lora_state = None @@ -168,8 +168,19 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # pylint: disable=import-outside-toplevel + from flax import nnx + from maxtext.layers import train_state_nnx + from maxtext.utils import model_creation_utils + + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + init_state_fn = create_train_state_fn else: init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) @@ -178,7 +189,11 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp lora_config = gcs_utils.read_json_from_gcs(lora_config_path) - lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) + if config.pure_nnx: + base_abstract_params = _nnx_param_subtree(unboxed_abstract_state) + lora_state, lora_state_annotations = get_lora_abstract_state_nnx(base_abstract_params, lora_config) + else: + lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) lora_weights_path = f"{lora_adapter_path}/0/items" @@ -349,3 +364,165 @@ def get_lora_annotations(lora_abstract_params): ) return unboxed_abstract_lora_state, lora_state_mesh_annotations + + +# NNX-shaped LoRA helpers. 
The Linen walkers above key on `isinstance(x, dict)` +# and bare leaves; NNX trees use `nnx.State` (Mapping but not dict) and +# Variable-wrapped leaves, so we need separate mirrors. The math (W += B @ A * s) +# is identical. + + +def _is_nnx_branch(x): + return isinstance(x, Mapping) + + +def _nnx_param_subtree(unboxed_abstract_state): + """Drop the outer TrainStateNNX wrapping and return the model substate.""" + return unboxed_abstract_state["model"] if "model" in unboxed_abstract_state else unboxed_abstract_state + + +def apply_lora_on_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """NNX variant of `apply_lora_on_base_params`. Mutates `base_params` in place.""" + + def lora_update_or_base(base_weight, lora_a, lora_b): + if lora_a is not None and lora_b is not None: + return base_weight + jnp.einsum("br,rnd->bnd", lora_b, lora_a) * lora_scale_factor + return base_weight + + def recurse(base_node, lora_node, path): + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + if name not in ("lora_a.kernel", "lora_b.kernel"): + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + lora_b = lora_node["lora_a.kernel"] + lora_a = lora_node["lora_b.kernel"] + base_node["kernel"] = lora_update_or_base(base_node["kernel"], lora_a, lora_b) + return + + recurse(base_params, lora_params, "") + + +def unapply_lora_from_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """NNX-shaped variant of `unapply_lora_from_base_params`. 
Mutates `base_params`.""" + + def lora_update_or_base(base_weight, lora_a, lora_b): + if lora_a is not None and lora_b is not None: + return base_weight - jnp.einsum("br,rnd->bnd", lora_b, lora_a) * lora_scale_factor + return base_weight + + def recurse(base_node, lora_node, path): + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + if name not in ("lora_a.kernel", "lora_b.kernel"): + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + lora_b = lora_node["lora_a.kernel"] + lora_a = lora_node["lora_b.kernel"] + base_node["kernel"] = lora_update_or_base(base_node["kernel"], lora_a, lora_b) + return + + recurse(base_params, lora_params, "") + + +def get_lora_abstract_state_nnx(base_abstract_params, lora_config): + """`get_lora_abstract_state` for the NNX path. + + Walks the abstract `state.model` substate and emits a parallel tree with + `lora_a.kernel` / `lora_b.kernel` leaves at target attention paths and + `None` elsewhere. 
+ """ + other_lora_format_to_jax_format = { + "q_proj": "self_attention.query", + "k_proj": "self_attention.key", + "v_proj": "self_attention.value", + "o_proj": "self_attention.out", + } + + lora_target_modules = [other_lora_format_to_jax_format.get(s, s) for s in lora_config["target_modules"]] + lora_rank = int(lora_config["r"]) + + def get_lora_param_shape(base_array_shape, lora_module): + if len(base_array_shape) > 4: + raise ValueError(f"Unsupported base array shape {base_array_shape} (>4D)") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_shape = base_array_shape[:-2] + (lora_rank,) + lora_b_shape = (lora_rank,) + base_array_shape[1:] + elif lora_module == "self_attention.out": + lora_a_shape = base_array_shape[:-1] + (lora_rank,) + if len(base_array_shape) == 4: + lora_b_shape = (lora_rank, base_array_shape[1], base_array_shape[-1]) + else: + lora_b_shape = (lora_rank, base_array_shape[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + return lora_a_shape, lora_b_shape + + def get_lora_param_sharding(base_param_sharding, lora_module): + if base_param_sharding is None: + return None, None + base_pspec = base_param_sharding.spec + if len(base_pspec) > 4: + raise ValueError("PartitionSpec size > 4 not supported") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-2] + ((),))) + lora_b_pspec = jax.sharding.PartitionSpec(*(((),) + base_pspec[1:])) + elif lora_module == "self_attention.out": + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-1] + ((),))) + if len(base_pspec) == 4: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[1], base_pspec[-1]) + else: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + mesh = base_param_sharding.mesh + mem_kind = base_param_sharding.memory_kind + 
return ( + jax.sharding.NamedSharding(mesh=mesh, spec=lora_a_pspec, memory_kind=mem_kind), + jax.sharding.NamedSharding(mesh=mesh, spec=lora_b_pspec, memory_kind=mem_kind), + ) + + def module_is_target(module_path): + for tgt in lora_target_modules: + if tgt in module_path: + return tgt + return None + + def add_lora(out_node, base_node, path): + for name, child in base_node.items(): + if _is_nnx_branch(child): + out_node[name] = {} + add_lora(out_node[name], child, f"{path}.{name}") + else: + if name not in ("kernel", "scale", "embedding"): + raise ValueError(f"Unexpected key={name} in base abstract params at {path}") + if not isinstance(child, jax.ShapeDtypeStruct): + raise ValueError(f"Unexpected leaf type {type(child).__name__} at {path}.{name}") + target_module = module_is_target(path) + if target_module is not None: + a_shape, b_shape = get_lora_param_shape(child.shape, target_module) + a_sharding, b_sharding = get_lora_param_sharding(child.sharding, target_module) + out_node["lora_a.kernel"] = jax.ShapeDtypeStruct(shape=a_shape, dtype=child.dtype, sharding=a_sharding) + out_node["lora_b.kernel"] = jax.ShapeDtypeStruct(shape=b_shape, dtype=child.dtype, sharding=b_sharding) + else: + out_node[name] = None + + lora_abstract_params = {} + add_lora(lora_abstract_params, base_abstract_params, "") + + unboxed_abstract_lora_state = train_state.TrainState( + step=0, apply_fn=None, params=lora_abstract_params, tx=None, opt_state={} # type: ignore + ) + lora_state_mesh_annotations = train_state.TrainState( + step=0, + apply_fn=None, + params=jax.tree_util.tree_map( + lambda x: x.sharding.spec if x.sharding is not None else None, + lora_abstract_params, + ), + tx=None, # type: ignore + opt_state={}, + ) + return unboxed_abstract_lora_state, lora_state_mesh_annotations diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index ae0496b002..f24cbfefab 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ 
-20,21 +20,20 @@ import os from typing import Sequence -from flax import linen as nn +from flax import nnx, linen as nn +from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning -from flax.training import train_state +from flax.training.train_state import TrainState import numpy as np -from jax.experimental import mesh_utils -from jax.experimental.serialize_executable import deserialize_and_load -from jax.sharding import AxisType, Mesh - import jax import jax.numpy as jnp +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils +from jax.experimental.serialize_executable import deserialize_and_load import optax - import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager @@ -55,6 +54,7 @@ from maxtext.utils import max_utils from maxtext.utils import sharding from maxtext.utils import elastic_utils +from maxtext.utils import maxtext_utils_nnx OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -102,7 +102,10 @@ def get_functional_train_with_signature( """Get the shardings (both state and data) for `train_step`.""" functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) functional_train.__name__ = "train_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = (state_mesh_shardings, None) # State, metrics static_argnums = () # We partial out the static argnums of model and config donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
@@ -113,7 +116,10 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) functional_eval.__name__ = "eval_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = None # metrics static_argnums = () # We partial out the static argnums of model, config donate_argnums = () # state will be kept instead of being donated in eval_step @@ -1236,15 +1242,15 @@ def _apply_update(path, param): return state.replace(params=new_params) -def init_decode_state(apply_fn, params) -> train_state.TrainState: +def init_decode_state(apply_fn, params) -> TrainState: """Init train state with null opt state for decode.""" - state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state @@ -1372,7 +1378,7 @@ def setup_initial_state( is_training: True to initialize training state, False for decode state Returns: - state: the initialized train state + train_state: the initialized train state. 
For NNX, this is a TrainStateNNX instance state_mesh_annotations: the mesh annotations for the train state """ @@ -1411,29 +1417,48 @@ def setup_initial_state( else: # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] + + # For NNX, convert the pure dict to nnx.State using the abstract state as template + if config.pure_nnx: + nnx.replace_by_pure_dict(unboxed_abstract_state, state) + state = unboxed_abstract_state else: init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled and raw_params: # If we loaded a partial state, we need to merge it. - - def _merge_params(p_raw, p_init): - if isinstance(p_raw, jax.ShapeDtypeStruct): - return p_init - return p_raw - - merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) - state = state.replace(params=merged_params) - elif raw_params: - state = state.replace(params=raw_params) - - state = max_utils.unbox_logicallypartioned(state) + if config.pure_nnx: + state = jax.jit( + lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + if raw_params: # If we loaded a partial state, we need to merge it. 
+ if config.pure_nnx: + # raw_params should have the same sharding info as in the model + nnx.update(state.model, raw_params) + else: + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + # Sparsity-init keeps freshly initialized params for any leaf still + # represented as an abstract ShapeDtypeStruct in raw_params (i.e. not + # actually restored), and uses the restored value otherwise. + def _merge_params(p_raw, p_init): + if isinstance(p_raw, jax.ShapeDtypeStruct): + return p_init + return p_raw + + merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) + state = state.replace(params=merged_params) + else: + state = state.replace(params=raw_params) + if not config.pure_nnx: + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator @@ -1448,6 +1473,9 @@ def get_logical_annotations(config, mesh, init_state_fn): def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" + if config.pure_nnx: + return get_abstract_state_nnx(config, mesh, init_state_fn, is_training) + init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -1491,6 +1519,157 @@ def move(path, x): ) +def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. 
+ + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. + """ + + def _make_named_sharding(v): + val = v.get_value() + if not hasattr(val, "shape"): + # `val` is either truly leafless (e.g. optax MaskedNode) or a composite + # pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables — + # a `qvalue` int8 array + a list of `scale` bf16 arrays). For the latter + # we must emit a parallel tree of NamedSharding leaves so the downstream + # `jax.tree.map(lambda a, s: ShapeDtypeStruct(..., sharding=s), abs, names)` + # finds a real Sharding at every position. Replicated sharding is a safe + # default — AQT serve-mode QTensors are normally small (per-channel scale + # factors and packed int8 weights) and don't need axis-aware sharding. + if jax.tree_util.tree_leaves(val): + replicated = NamedSharding(mesh, PartitionSpec()) + return v.replace(jax.tree.map(lambda _: replicated, val)) + return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = PartitionSpec() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the + # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits + # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode + # scan_axis=0 for non-Param types. stacked_rest non-Param variables have + # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. 
+ scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames + # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted + # the scan axis. Only insert if not already present. + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. + context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + rules = composite_rules(context_rules, local_rules) + pspec = PartitionSpec(*from_sharding_rules(out_sharding, rules)) + else: + pspec = PartitionSpec(*out_sharding) + return v.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): + """Calculates the abstract sharded state and memory placement for an NNX TrainState. + + This function performs an abstract trace of the NNX model and optimizer using + `nnx.get_abstract_model`. It resolves logical sharding annotations into physical + JAX shardings and applies memory placement optimizations such as optimizer + sharding and host memory offloading (pinning to CPU RAM). + + Args: + config: Configuration object containing sharding and offloading hyperparameters + (e.g., shard_optimizer_over_data, optimizer_memory_host_offload). + mesh: JAX physical mesh used to resolve logical axis names to physical devices. + nnx_init_trainstate_fn: A zero-argument factory function that produces a + TrainStateNNX instance during the abstract trace. + is_training: Boolean indicating if the state is for training. 
If True, + optimizer state is processed and memory offloading strategies are applied. + + Returns: + A tuple containing (abstract_sharded_state, None, state_mesh_shardings): + abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with + fully resolved physical sharding and memory_kind metadata. + state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec + objects corresponding to each parameter/variable. + state_mesh_shardings: An nnx.State tree consisting of the raw JAX + Sharding objects corresponding to each parameter/variable. + """ + assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given." + + with nn_partitioning.axis_rules(config.logical_axis_rules): + # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply + # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally + # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, + # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. + # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls + # var.shape for every variable when a global mesh is active, but masked optimizer + # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() + # which has no .shape and would raise AttributeError. We handle sharding + # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not + # needed here. 
+ abs_model = nnx.eval_shape(nnx_init_trainstate_fn) + _, abs_var_state = nnx.split(abs_model) + named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + + state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + if is_training and config.shard_optimizer_over_data: + # Add data to sharding for optimizer state + optimizer_sharding = jax.tree_util.tree_map_with_path( + functools.partial(sharding.add_data_to_sharding, mesh), + abstract_state.optimizer, + state_mesh_shardings.optimizer, + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.optimizer_memory_host_offload: + optimizer_sharding = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.parameter_memory_host_offload: + assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading." + _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...) 
+ state_params = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_params, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + nnx.update(state_mesh_shardings, state_params) + + abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings) + state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + return ( + abstract_sharded_state, + state_mesh_annotations, + state_mesh_shardings, + ) + + def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None): """Get a shaped abstraction of the state (including optimizer)""" @@ -1556,6 +1735,30 @@ def init_kv_cache(model, config): return state_mesh_annotations +def _nnx_cache_partition_specs(abstract_model, config, mesh): + """Per-leaf PartitionSpec tree for the abstract model's nnx.Cache vars. + + Returned as a pure dict so the engine can wrap it in NamedSharding the same + way it does for the Linen helpers below. + """ + _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...) + # get_nnx_named_sharding_with_scan_axis reads logical axis rules from the + # active flax partitioning context, so wrap. 
+ with nn_partitioning.axis_rules(config.logical_axis_rules): + named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh) + return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict()) + + +def get_prefill_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_prefill_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + +def get_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + def save_quantized_checkpoint_if_configured(config, params): """Save quantized checkpoint if configured""" assert config.quantization, "quantization must be configured" @@ -1736,26 +1939,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No """ Print state shardings comparing Logical Definition vs Physical Result. """ - if not hasattr(params, "params"): - params = {"params": params} - if not hasattr(params_sharding, "params"): - params_sharding = {"params": params_sharding} - if logical_annotations and not hasattr(logical_annotations, "params"): - logical_annotations = {"params": logical_annotations} + if not isinstance(params, nnx.State): + if not hasattr(params, "params"): + params = {"params": params} + if not hasattr(params_sharding, "params"): + params_sharding = {"params": params_sharding} + if logical_annotations and not hasattr(logical_annotations, "params"): + logical_annotations = {"params": logical_annotations} leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) - for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = 
jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - pspec_str = str(tuple(pspec)) - logical_str = str(leaf_logical_val) - - message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" - max_logging.info(message) + if logical_annotations is not None: + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) + for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip( + leaves_params, leaves_sharding, leaves_logical + ): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + logical_str = str(leaf_logical_val) + + message = ( + f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" + ) + max_logging.info(message) + else: + for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + + message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}" + max_logging.info(message) print(flush=True) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index ab85894832..34fe18ae77 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -1,3 +1,17 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright 2023–2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,16 +32,17 @@ import dataclasses import collections from collections.abc import Sequence +from typing import Callable, overload from functools import partial import os import subprocess import sys -from typing import overload from etils import epath from flax import nnx import flax.linen as nn import jax import jax.numpy as jnp +import numpy as np from jax.sharding import Mesh from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN @@ -452,6 +467,7 @@ def from_config( *, model_mode: str = MODEL_MODE_TRAIN, rngs: None = None, + quant_mode_str: str = "train", ) -> nn.Module: ... @@ -464,6 +480,7 @@ def from_config( *, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs, + quant_mode_str: str = "train", ) -> models.Transformer: ... @@ -475,25 +492,18 @@ def from_config( *, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs | None = None, + quant_mode_str: str = "train", ) -> nn.Module | models.Transformer: """Load a pretrained MaxText model from checkpoint. - This function loads a model from a checkpoint. - - Args: - config: Config object. - devices: Sequence of devices to use for the model. If None, use all - available devices. - - Returns: - Transformer: The loaded model instance (only the model) - - Example: - model = from_config(config) + `quant_mode_str` is one of "train", "convert", "serve" — controls the AQT + quantization mode at model construction time. 
NNX layers freeze their + param shape on `quant_mode_str` (e.g. SERVE skips the full-precision kernel), + so callers loading a pre-quantized checkpoint must pass `"serve"`. """ if mesh is None: mesh = maxtext_utils.get_mesh_from_config(config, devices) - model = create_model(config, mesh, model_mode=model_mode, rngs=rngs) + model = create_model(config, mesh, model_mode=model_mode, rngs=rngs, quant_mode_str=quant_mode_str) # Return only the model return model @@ -507,43 +517,118 @@ def get_transformer_model(config, mesh, quant, model_mode: str = MODEL_MODE_TRAI return models.transformer_as_linen(config, mesh, quant=quant, model_mode=model_mode) -def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs | None = None): +def create_model( + config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs | None = None, *, quant_mode_str: str = "train" +): """Instantiates and returns the model object, sharded across the mesh.""" # Model definition - quant = quantizations.configure_quantization(config) + quant = quantizations.configure_quantization(config, quant_mode_str=quant_mode_str) model = get_transformer_model(config, mesh, quant, model_mode=model_mode, rngs=rngs) model = quantizations.maybe_quantize_model(model, config) return model -def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Returns (_create_model_partial, abstract_model) for AOT compilation. +def get_nnx_create_model_fn( + config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None, *, quant_mode_str: str = "train" +) -> Callable: - This does not shard parameters or load checkpoints. It only builds the - abstract shape/dtype structure needed by get_abstract_state and optimizer - construction (e.g. Muon). 
+ def _create_model(): + rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) + return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode, quant_mode_str=quant_mode_str) - Args: - config: the configuration - mesh: the device mesh - model_mode: train or inference - rng_key: optional RNG key + return _create_model + + +def create_nnx_abstract_model( + config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None, *, quant_mode_str: str = "train" +) -> tuple[Callable, nnx.Module]: + """Creates an abstract NNX model. + + `quant_mode_str` is forwarded to model construction so AQT layers freeze + the right param shape (e.g. SERVE skips the full-precision kernel). Returns: - (_create_model_partial, abstract_model) where _create_model_partial() creates - a concrete model instance and abstract_model is the eval_shape result. + A tuple containing (create_model_fn, abstract_model): + create_model_fn: A zero-argument callable that produces a new model instance. + abstract_model: The stateful NNX model instance in an abstract state. """ - def _create_model(rng_key=None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode) + with nn.logical_axis_rules(config.logical_axis_rules): + _create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key, quant_mode_str=quant_mode_str) + if mesh is None: + _tmp = nnx.eval_shape(_create_model) + mesh = _tmp.mesh + # Use nnx.eval_shape + our scan-axis-aware sharding helper instead of + # nnx.get_abstract_model, which uses get_var_pspec internally and ignores + # param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers, + # causing the stacked layers axis to be missing from the PartitionSpec. 
+ # Wrapping in `jax.set_mesh(mesh)` trips Flax 0.12.6's `_to_variable` for + # serve-mode AQT variables (NamedSharding with `spec=None` rejected under + # AbstractMesh). Sharding is resolved afterwards via the helper, so the + # wrap is unnecessary here. + abs_model = nnx.eval_shape(_create_model) + graphdef, abs_var_state = nnx.split(abs_model) + named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + return _create_model, nnx.merge(graphdef, abstract_state) - _create_model_partial = partial(_create_model, rng_key=rng_key) + +def create_nnx_sharded_model_hybrid(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a sharded model for hybrid NNX modules containing Linen sub-modules. + + DEPRECATED: This function is a transitional utility for the Linen-to-NNX + migration. It should be removed once all model components are ported to + pure NNX modules. + + This function specifically handles the complexity of "mixed" state initialization, + where logical sharding annotations must be resolved for both NNX native + Parameters and legacy Linen variables wrapped via the NNX-Linen bridge. + It ensures that both systems correctly respect the provided mesh and + logical axis rules during the abstraction/sharding planning phase. + """ + _create_model_partial = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) with nn.logical_axis_rules(config.logical_axis_rules): abstract_model = nnx.eval_shape(_create_model_partial) + graphdef, abstract_state = nnx.split(abstract_model) + specs = nnx.get_partition_spec(abstract_state) - return _create_model_partial, abstract_model + if mesh is None: + mesh = abstract_model.mesh + + # JIT a function that creates the model state with proper sharding from the start. 
+ # By providing out_shardings, we instruct JAX to produce sharded output directly, + # avoiding a large intermediate allocation on a single device. + with nn.logical_axis_rules(config.logical_axis_rules): + out_shardings = nn.logical_to_mesh_sharding(specs, mesh) + + @partial(jax.jit, out_shardings=out_shardings) + def create_sharded_state(): + # This will be JIT-compiled. JAX knows the output sharding and can + # initialize the parameters directly on the target devices in a sharded way. + model = _create_model_partial() + return nnx.state(model) + + with mesh: + # Create the model with sharded parameters. + with nn.logical_axis_rules(config.logical_axis_rules): + sharded_state = create_sharded_state() + model = nnx.merge(graphdef, sharded_state) + + # print weights sharding info under debug sharding mode + if config.debug_sharding: + max_utils.print_non_trivial_mesh_axis(model.mesh) + maxtext_utils.print_shardings_params( + params=sharded_state, + params_sharding=out_shardings, + mesh=model.mesh, + logical_annotations=specs, + ) + return model def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs): @@ -670,9 +755,21 @@ def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sa def from_pretrained( - config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None, wrap_with_tunix_adapter=False + config, + mesh=None, + devices=None, + model_mode=MODEL_MODE_TRAIN, + rng_key=None, + wrap_with_tunix_adapter=False, + *, + quant_mode_str: str = "train", ): - """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" + """Creates a NNX model with sharded parameters, possibly loading from a checkpoint. + + `quant_mode_str` is forwarded to model construction. Pass `"serve"` when + loading a pre-quantized checkpoint so AQT layers materialize the on-disk + scale factors instead of full-precision kernels. 
+ """ original_mesh = mesh if config.convert_checkpoint_if_possible and not config.load_parameters_path: if not (epath.Path(config.base_output_directory) / "0" / "items").exists(): @@ -728,60 +825,34 @@ def from_pretrained( ) config = pyconfig.HyperParameters(new_config) - def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - - _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) - + if config.pure_nnx: + _create_model, abstract_model = create_nnx_abstract_model( + config, mesh, devices, model_mode, rng_key, quant_mode_str=quant_mode_str + ) + model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh) + # TODO: print debug_sharding info + else: + model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key) + + # Compute logical-axis specs for downstream checkpoint alignment. + # The model-creation helpers above resolve specs internally for sharding, but + # the checkpoint-loading branch below needs the logical PartitionSpec tree + # (axis names like "kv_heads", "mlp_moe") for repeat/zero-pad dispatch in + # _align_checkpoint_to_model_shapes. nnx.eval_shape is cheap (abstract trace). 
+ _create_model_for_specs = get_nnx_create_model_fn( + config, mesh, devices, model_mode, rng_key, quant_mode_str=quant_mode_str + ) with nn.logical_axis_rules(config.logical_axis_rules): - abstract_model = nnx.eval_shape(_create_model_partial) - graphdef, abstract_state = nnx.split(abstract_model) - specs = nnx.get_partition_spec(abstract_state) + _abs_model_for_specs = nnx.eval_shape(_create_model_for_specs) + _, _abs_state_for_specs = nnx.split(_abs_model_for_specs) + specs = nnx.get_partition_spec(_abs_state_for_specs) - if mesh is None: - mesh = abstract_model.mesh + sharded_state = nnx.state(model) - # Note for pure_nnx: - # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and - # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen - # LogicallyPartitioned structure. - # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned - # structure in the abstract state and we can get the sharded state with the following code: - # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) - # abstract_model = nnx.merge(graphdef, state) - # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) - # sharded_state = nnx.state(model) - - # JIT a function that creates the model state with proper sharding from the start. - # By providing out_shardings, we instruct JAX to produce sharded output directly, - # avoiding a large intermediate allocation on a single device. - with nn.logical_axis_rules(config.logical_axis_rules): - out_shardings = nn.logical_to_mesh_sharding(specs, mesh) - - @partial(jax.jit, out_shardings=out_shardings) - def create_sharded_state(): - # This will be JIT-compiled. JAX knows the output sharding and can - # initialize the parameters directly on the target devices in a sharded way. 
- model = _create_model_partial() - return nnx.state(model) + if mesh is None: + mesh = model.mesh with mesh: - # Create the model with sharded parameters. - with nn.logical_axis_rules(config.logical_axis_rules): - sharded_state = create_sharded_state() - model = nnx.merge(graphdef, sharded_state) - - # print weights sharding info under debug sharding mode - if config.debug_sharding: - max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - params=sharded_state, - params_sharding=out_shardings, - mesh=model.mesh, - logical_annotations=specs, - ) - if config.load_parameters_path: try: ckptr = ocp.Checkpointer( @@ -874,10 +945,51 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx): } else: # NNX checkpoint: {'decoder': {'value': ...}}, or NNX-RL with extra 'base' nesting. - # Restore only nnx.Param — RNG variable shapes may differ between checkpoint and model. + # Restore only nnx.Param — RNG variable shapes may differ between checkpoint and model, + # and pure-dict checkpoints written by `layerwise_quantization._load_and_quantize_nnx` + # don't carry RNG/dropout state at all (they only persist nnx.Param leaves, including + # AQT serve-mode `qrhs.frozen` which is a Param subclass). + def _build_value_target(v): + # `v[...]` (a.k.a. `v.get_value(index=...)`) descends into the inner + # value with `value[Ellipsis]`. AQT serve-mode `qrhs.frozen` variables + # wrap a QTensor whose `__getitem__` calls `qvalue[idx]` on a + # `LogicallyPartitioned` wrapper — that fails. For QTensor (and any + # composite pytree value), use the unwrapped value directly so the + # restore target preserves the QTensor's qvalue/scale sub-structure. + inner = v.get_value() if hasattr(v, "get_value") else v[...] + if hasattr(inner, "shape"): + return {"value": v[...]} + # AQT QTensor: qvalue/scale leaves come back wrapped in flax + # `Partitioned` (a logical-axis sharding box). 
The on-disk save in + # `_load_and_quantize_nnx` flushes the QTensor as plain arrays — + # paths look like `qrhs.frozen.value.qvalue` / `...scale.0`. If we + # leave Partitioned in place, jax.tree adds an extra `.value` key + # under each leaf (`qrhs.frozen.value.qvalue.value`) and orbax + # silently fills with zeros because that path doesn't exist on + # disk. Strip Partitioned wrappers so the target tree matches. + from flax.core.meta import Partitioned + inner = jax.tree.map( + lambda x: x.value if isinstance(x, Partitioned) else x, + inner, + is_leaf=lambda x: isinstance(x, Partitioned), + ) + return {"value": inner} + + # Keep persisted weight-like leaves: `nnx.Param` plus AQT serve-mode + # `qrhs.frozen` (a separate `aqt` Variable type, NOT a Param subclass). + # Excluded: `nnx.RngState` (regenerated per load, shapes can drift) and + # `nnx.Cache` (PREFILL/AR scratch, not persisted). Pure-dict checkpoints + # written by `layerwise_quantization._load_and_quantize_nnx` carry both + # Param kernels and `aqt`-typed `qrhs.frozen` quantized payloads. + if hasattr(sharded_state, "filter"): + param_state = sharded_state.filter( + lambda path, var: not isinstance(var, (nnx.RngState, nnx.Cache)) + ) + else: + param_state = sharded_state target_for_restore = jax.tree.map( - lambda v: {"value": v[...]}, - sharded_state, + _build_value_target, + param_state, is_leaf=lambda n: isinstance(n, nnx.Variable), ) has_base_key = "base" in metadata.item_metadata.tree @@ -891,10 +1003,14 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx): # Free memory used by initial sharded_state before restore, to make room for the incoming checkpoint arrays. def _free_device_memory(node): + val = node if isinstance(node, nnx.Variable) and not isinstance(node, nnx.RngState): - val = node[...] - else: - val = node + inner = node.get_value() if hasattr(node, "get_value") else node[...] 
+ # Same QTensor caveat as `_build_value_target`: AQT serve-mode `qrhs.frozen` + # wraps a QTensor whose `__getitem__` fails on `LogicallyPartitioned`. + # We only need to free a single jax.Array leaf — for composite values + # there's nothing to free at this level, so skip. + val = inner if hasattr(inner, "shape") else None if isinstance(val, jax.Array) and not val.is_deleted(): val.delete() @@ -921,8 +1037,14 @@ def _free_device_memory(node): checkpoint = restored["params"]["params"] if checkpoint: + # Same QTensor caveat as `_build_value_target` / `_free_device_memory`: + # `v[...]` fails on Variables wrapping QTensors. Use `get_value()` to + # access the inner value directly without index-style descent. + def _unwrap_for_align(v): + return v.get_value() if hasattr(v, "get_value") else v[...] + model_arrays = jax.tree.map( - lambda v: v[...], + _unwrap_for_align, sharded_state, is_leaf=lambda n: isinstance(n, nnx.Variable), ) @@ -971,6 +1093,12 @@ def _walk_align(ckpt, model_arr, axes): ) for k, v in ckpt.items() } + # AQT serve-mode `qrhs.frozen` wraps a QTensor (composite pytree of + # qvalue+scale arrays), not a single jax.Array. Shape alignment + # only makes sense for full-precision kernels — quantized payloads + # are saved in the exact shape the model expects, so pass through. 
+ if not isinstance(ckpt, (jax.Array, jax.ShapeDtypeStruct, np.ndarray)): + return ckpt return _align_checkpoint_to_model_shapes(ckpt, model_arr, axes) checkpoint = _walk_align(checkpoint, model_arrays, logical_axes_tree) diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3ba60d7371..049a084979 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -24,25 +24,23 @@ python3 -m maxtext.utils.muon_utils qwen3-4b True """ - import os import sys from typing import Optional, Tuple import flax.linen as nn +from flax import nnx import jax from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils, model_creation_utils from optax.contrib._muon import MuonDimensionNumbers as mdn -Transformer = models.transformer_as_linen - - def _is_path_contain_any(tuples, path): + """Checks if any element in 'tuples' is present in 'path'.""" return any(x in path for x in tuples) @@ -107,10 +105,26 @@ def get_transform_tree(tree, path=()): def get_muon_weight_dimension_numbers(model, config, verbose=False): """Extract muon dimension number from model structure.""" - # quickly get param structure without materialization - abstract_param = maxtext_utils.get_abstract_param(model, config) - # get muon dimension number from param - muon_weight_dimension_numbers = get_transform_tree(abstract_param) + + if isinstance(model, nnx.Module): + _, abstract_param, _ = nnx.split(model, nnx.Param, ...) + + def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): + # Convert jax.tree_util.KeyEntry path to Tuple[str, ...] + path_strings = tuple(p.key for p in path if isinstance(p, jax.tree_util.DictKey)) + return transform_logic(path_strings) + + # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. 
+ # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + # The result is an nnx.State with the same structure, where each Param's value holds the mdn result. + muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) + + else: # Linen + # quickly get param structure without materialization + abstract_param = maxtext_utils.get_abstract_param(model, config) + # get muon dimension number from param + muon_weight_dimension_numbers = get_transform_tree(abstract_param) + if verbose: _print_structure_debug(abstract_param, muon_weight_dimension_numbers) return muon_weight_dimension_numbers @@ -118,19 +132,30 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False): def _print_structure_debug(abstract_param, muon_weight_dimension_numbers): """Prints the model structure and the resulting Muon config.""" - # Access the shape from the inner ShapeDtypeStruct and names from the wrapper - # Return a new tree with the same structure containing only shapes/names + + def get_leaf_info(leaf): + # For linen: + # Access the shape from the inner ShapeDtypeStruct and names from the wrapper + # Return a new tree with the same structure containing only shapes/names + if isinstance(leaf, nn.LogicallyPartitioned): + return {"shape": leaf.value.shape, "names": leaf.names} + # For nnx: + # Only return the shape because it doesn't have a wrapper. 
+ elif isinstance(leaf, jax.ShapeDtypeStruct): + return {"shape": leaf.shape} + return {"shape": "N/A"} + info_tree = jax.tree_util.tree_map( - lambda leaf: {"shape": leaf.value.shape, "names": leaf.names}, + get_leaf_info, abstract_param, - is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), + is_leaf=lambda x: isinstance(x, (nn.LogicallyPartitioned, jax.ShapeDtypeStruct)), ) print(f"\n=== Model Structure ===\n{info_tree}") print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}") print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=True): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -154,15 +179,21 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): f"model_name={model_name}", f"scan_layers={scan_layers}", "attention=dot_product", + f"pure_nnx={pure_nnx}", ] config = pyconfig.initialize(argv) # Setup model devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh=mesh, quant=quant) + if pure_nnx: + _, model = model_creation_utils.create_nnx_abstract_model(config, mesh) + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) + if pure_nnx: + muon_weight_dimension_numbers = {"params": nnx.to_pure_dict(muon_weight_dimension_numbers)} return muon_weight_dimension_numbers @@ -172,4 +203,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): sys.exit(1) model_name_arg = sys.argv[1] scan_layers_arg = sys.argv[2].lower() == "true" - get_model_mdn(model_name_arg, scan_layers_arg, verbose=True) + get_model_mdn(model_name_arg, 
scan_layers_arg, verbose=True, pure_nnx=False) diff --git a/src/maxtext/utils/qk_clip_utils.py b/src/maxtext/utils/qk_clip_utils.py index 64848b8ffb..d3a7b926e4 100644 --- a/src/maxtext/utils/qk_clip_utils.py +++ b/src/maxtext/utils/qk_clip_utils.py @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp +from flax import nnx def _get_key_name(k): @@ -30,132 +31,150 @@ def _get_key_name(k): def calculate_max_logit_metric(intermediate_outputs): """Extracts and computes the global maximum logit from intermediate outputs. - Args: - intermediate_outputs: A pytree containing model intermediates, potentially - including 'max_logits' sowed by Attention layers. + Recognizes two shapes: Linen sow stores `(array,)` so the leaf path ends in + `max_logits, 0`; NNX `nnx.Intermediate(array)` stores the array directly so + the leaf path ends in `max_logits`. - Returns: - The global maximum logit scalar, or None if no logits were found. + Returns the global max scalar, or None if no logits were found. """ all_max_logits = [] def extract_logits(path, val): - # 'sow' stores values in a tuple/list. tree_map descends into it. - # The path to the leaf array will look like: (..., 'max_logits', 0) - # So we check if the parent key (path[-2]) is 'max_logits'. - if len(path) >= 2: - parent_key = _get_key_name(path[-2]) - if parent_key == "max_logits": - all_max_logits.append(val) + if not path: + return + last_key = _get_key_name(path[-1]) + parent_key = _get_key_name(path[-2]) if len(path) >= 2 else None + if last_key == "max_logits" or parent_key == "max_logits": + all_max_logits.append(val) jax.tree_util.tree_map_with_path(extract_logits, intermediate_outputs) if not all_max_logits: return None - # Compute max per layer first to handle potential shape mismatches return jnp.max(jnp.stack([jnp.max(x) for x in all_max_logits])) -def apply_qk_clip(state, intermediate_outputs, config): - """Applies QK-Clip to MLA weights based on max_logits. - - Iterates over parameters. 
If a parameter belongs to an MLA attention layer, - it finds the corresponding max_logits statistics from intermediate_outputs, - calculates the clipping factor, and applies it to W_q and W_k components. - - Args: - state: The current training state containing model parameters. - intermediate_outputs: A dictionary of intermediate outputs from the model - forward pass. It is expected to contain 'max_logits' entries sowed by - Attention layers if QK-Clip is enabled. - config: The model configuration object, containing QK-Clip hyperparameters - (e.g. qk_clip_threshold, qk_nope_head_dim) and attention_type. - - Returns: - A new training state with updated (clipped) parameters. - - Raises: - ValueError: If the configured attention_type is not 'mla'. - """ +def _check_attention_type(config): if getattr(config, "attention_type", None) != "mla": raise ValueError( f"QK-Clip is only supported for MLA attention (attention_type='mla'). " f"Current configuration: {getattr(config, 'attention_type', 'None')}" ) - tau = float(config.qk_clip_threshold) - def clip_mla_weights(path, param): - """Applies QK-Clip to a single parameter if it's an MLA projection weight. +def _max_logits_at(curr): + """Read max_logits from a node in the intermediates tree. + + Returns the [batch, num_heads] array, or None if not present. Handles both + the Linen sow shape (`{"max_logits": (array,)}`) and the NNX shape + (`{"max_logits": array}` or `{"attention_op": {"max_logits": array}}`). + """ + if not isinstance(curr, dict): + return None + ml = curr.get("max_logits") + if ml is None and "attention_op" in curr and isinstance(curr["attention_op"], dict): + ml = curr["attention_op"].get("max_logits") + if ml is None: + return None + if isinstance(ml, (tuple, list)): + return ml[0] if ml else None + return ml - Args: - path: A tuple of JAX Key objects representing the hierarchy path to the parameter in the state PyTree. - param: The actual JAX array (weight tensor) at the given path. 
- Returns: - The scaled parameter if it is an MLA projection ('wq_b' or 'wkv_b'), otherwise the original parameter. - """ - # Skip irrelevant weights (embeddings, norms, etc.). - # We only care about specific MLA projection matrices ('wq_b', 'wkv_b'). +def _scale_from_max_logits(max_logits_batch, tau): + s_max = jnp.max(max_logits_batch, axis=0) + return jnp.minimum(1.0, tau / (s_max + 1e-6)) + + +def _clip_mla_weight(layer_name, param, scale, qk_nope): + """Apply the per-head scale to a wq_b or wkv_b kernel.""" + scale_b = scale[None, :, None] # broadcasts over [rank, heads, dim] + head = param[..., :qk_nope] + tail = param[..., qk_nope:] + head_new = head * jnp.sqrt(scale_b) + if layer_name == "wq_b": + tail_new = tail * scale_b + else: # wkv_b: tail is the V slice, untouched + tail_new = tail + return jnp.concatenate([head_new, tail_new], axis=-1) + + +def apply_qk_clip(state, intermediate_outputs, config): + """Applies QK-Clip to MLA weights based on max_logits (Linen path). + + Returns a new TrainState with `wq_b`/`wkv_b` kernels rescaled per-head. 
+ """ + _check_attention_type(config) + tau = float(config.qk_clip_threshold) + + def clip_mla_weights(path, param): if len(path) < 2: return param - layer_name = _get_key_name(path[-2]) if layer_name not in ("wq_b", "wkv_b"): return param - # Search for max_logits in intermediate_outputs curr = intermediate_outputs.get("intermediates", intermediate_outputs) for node in path[:-2]: key = _get_key_name(node) if isinstance(curr, dict) and key in curr: curr = curr[key] else: - return param # Path not found in intermediates, skip + return param - if not isinstance(curr, dict) or "max_logits" not in curr: + max_logits_batch = _max_logits_at(curr) + if max_logits_batch is None: return param - # max_logits was sowed as a tuple (array,) - # shape: [batch, num_heads] - max_logits_sowed = curr["max_logits"] - if not max_logits_sowed: - return param + scale = _scale_from_max_logits(max_logits_batch, tau) + return _clip_mla_weight(layer_name, param, scale, config.qk_nope_head_dim) - max_logits_batch = max_logits_sowed[0] - - # Calculate S_max (per head) - # We want the global maximum across the batch dimension. - # Result shape: [num_heads] - s_max = jnp.max(max_logits_batch, axis=0) - - # Calculate scaling factor gamma - # gamma = tau / s_max. Clip if s_max > tau. 
- scale = jnp.minimum(1.0, tau / (s_max + 1e-6)) - - # Apply qk clipping based on weight type - if layer_name == "wq_b": - # MLA Up-projection for Query [rank, heads, q_head_dim] - qk_nope = config.qk_nope_head_dim - w_qc = param[..., :qk_nope] - w_qr = param[..., qk_nope:] - scale_b = scale[None, :, None] # Broadcast: [1, heads, 1] - w_qc_new = w_qc * jnp.sqrt(scale_b) - w_qr_new = w_qr * scale_b - return jnp.concatenate([w_qc_new, w_qr_new], axis=-1) - - elif layer_name == "wkv_b": - # MLA Up-projection for Key/Value [rank, heads, kv_head_dim] - qk_nope = config.qk_nope_head_dim - w_kc = param[..., :qk_nope] - w_v = param[..., qk_nope:] - scale_b = scale[None, :, None] - w_kc_new = w_kc * jnp.sqrt(scale_b) - return jnp.concatenate([w_kc_new, w_v], axis=-1) - - return param - - # Apply transformation new_params = jax.tree_util.tree_map_with_path(clip_mla_weights, state.params) return state.replace(params=new_params) + + +def apply_qk_clip_nnx(state, intermediate_outputs, config): + """Applies QK-Clip to MLA weights on an NNX TrainStateNNX. + + `state.model` is mutated in place (NNX modules are mutable). Returns `state` + so call sites can use the same `new_state = apply_qk_clip(...)` pattern as + the Linen path. + + The intermediates tree mirrors the NNX module hierarchy, so `max_logits` + sowed by `AttentionOp` lives at `...self_attention.attention_op.max_logits`. + We accept either that shape or `...self_attention.max_logits` (matching the + Linen-side fixtures and small-test setups). + """ + _check_attention_type(config) + tau = float(config.qk_clip_threshold) + + _, params_state, _ = nnx.split(state.model, nnx.Param, ...) 
+ params_dict = params_state.to_pure_dict() + + def clip_mla_weights(path, param): + if len(path) < 2: + return param + layer_name = _get_key_name(path[-2]) + if layer_name not in ("wq_b", "wkv_b"): + return param + + curr = intermediate_outputs + for node in path[:-2]: + key = _get_key_name(node) + if isinstance(curr, dict) and key in curr: + curr = curr[key] + else: + return param + + max_logits_batch = _max_logits_at(curr) + if max_logits_batch is None: + return param + + scale = _scale_from_max_logits(max_logits_batch, tau) + return _clip_mla_weight(layer_name, param, scale, config.qk_nope_head_dim) + + new_params_dict = jax.tree_util.tree_map_with_path(clip_mla_weights, params_dict) + nnx.replace_by_pure_dict(params_state, new_params_dict) + nnx.update(state.model, params_state) + return state diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index d4bb64f016..4a500e2fe1 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -15,7 +15,7 @@ # pylint: disable=line-too-long, disable=bare-except, consider-using-generator """ Utils that are only interesting to MaxText and sharding related. 
""" -from flax import linen as nn +from flax import linen as nn, nnx from collections.abc import Iterable @@ -25,6 +25,7 @@ import optax +from maxtext.configs import pyconfig from maxtext.common.common_types import ShardMode from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -483,6 +484,8 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): - updated_state_mesh_shardings: State mesh shardings with updated params field (unchanged if shard_optimizer_over_data is False) """ + if config.pure_nnx: + return maybe_update_params_sharding_with_opt_nnx(config, state_mesh_shardings) prev_params_shardings = state_mesh_shardings.params if config.shard_optimizer_over_data: if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState): @@ -501,6 +504,122 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): return prev_params_shardings, state_mesh_shardings +def maybe_update_params_sharding_with_opt_nnx( + config: pyconfig.HyperParameters, state_mesh_shardings: nnx.State +) -> tuple[nnx.State, nnx.State]: + """ + NNX version of parameter sharding update. Updates parameter sharding configuration + when optimizer state sharding is enabled. + + When shard_optimizer_over_data is enabled (Zero-1 style sharding), this function + extracts the optimizer state shardings from the Adam optimizer's first moment (mu) + and merges them with the parameter shardings. This ensures parameter sharding is + consistent with how the optimizer state is distributed across the compute mesh. + + Args: + config: Configuration with shard_optimizer_over_data flag. + state_mesh_shardings: The sharding state for a TrainStateNNX container. 
+ + Returns: + A tuple of (prev_params_shardings, updated_state_mesh_shardings): + - prev_params_shardings: Original parameter shardings before the update + - updated_state_mesh_shardings: State mesh shardings with updated params field + (unchanged if shard_optimizer_over_data is False)""" + # In TrainStateNNX, parameters are under 'model' + model_shardings = state_mesh_shardings.model + + def _extract_param_only(state): + """Recursively extract nnx.Param variables from an nnx.State into a nested plain dict. + + Constructs nnx.State({'key': nested_dict, ...}) which produces the same pytree + structure as nnx.split(model, nnx.Param, ...)[1], enabling jax.tree.map + to work correctly between ga_params (Param-only) and params_shardings. + """ + result = {} + for k, v in state.items(): + if isinstance(v, nnx.Param): + result[k] = v + elif isinstance(v, nnx.Variable): + pass # skip non-Param variables (RngKey, RngCount, OptVariable, etc.) + elif hasattr(v, "items"): + sub = _extract_param_only(v) + if sub: + result[k] = sub + return result + + # prev_params_shardings must match the pytree structure of ga_params from + # nnx.split(model, nnx.Param, ...) — Param variables only, no rngs. + prev_params_shardings = nnx.State(_extract_param_only(model_shardings)) + + if not config.shard_optimizer_over_data: + return prev_params_shardings, state_mesh_shardings + + sharded_fp32_params = None + # Check if the optimizer has any state at all (stateless optimizers like SGD omit this key) + if "opt_state" in state_mesh_shardings.optimizer: + # Access the optimizer branch to find the optax state + # state_mesh_shardings.optimizer contains the sharding for the nnx.Optimizer + opt_state = state_mesh_shardings.optimizer.opt_state + + def find_adam_mu(obj): + # 1. Direct hit on ScaleByAdamState (Linen path or unflattened NNX) + if isinstance(obj, optax.ScaleByAdamState): + return obj.mu + + # 2. 
Check for flattened ScaleByAdamState (nnx.State/dict) + # These nodes contain 'mu', 'nu', and 'count' as keys. + if hasattr(obj, "__getitem__") and "mu" in obj and "nu" in obj: + return obj["mu"] + + # 3. Recursive search through containers (nnx.State, dict, list, tuple) + values = None + if hasattr(obj, "values"): # Handles nnx.State and dict + values = obj.values() + elif isinstance(obj, (list, tuple)): + values = obj + + if values: + for v in values: + res = find_adam_mu(v) + if res is not None: + return res + return None + + sharded_fp32_params = find_adam_mu(opt_state) + if sharded_fp32_params is None: + actual_type = type(state_mesh_shardings.optimizer.get("opt_state", "None")) + raise NotImplementedError(f"Could not find Adam optimizer state in: {actual_type}") + + # Update model parameter sharding to match the mu (first moment) sharding. + # This ensures parameter sharding is consistent with the Zero-1 distributed layout. + # Build a path → new_PS lookup from sharded_fp32_params (mu), then update model_shardings + # at those paths while preserving rngs and any other non-Param variables. + mu_leaves_with_paths = list( + jax.tree_util.tree_leaves_with_path(sharded_fp32_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + ) + mu_lookup = {path: mu_var.get_value() for path, mu_var in mu_leaves_with_paths} + + def _update_model_var(path, var): + if path in mu_lookup: + return var.replace(mu_lookup[path]) + return var + + new_model_shardings = jax.tree_util.tree_map_with_path( + _update_model_var, model_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + # Use jax.tree_util.tree_map (identity) to create a new nnx.State via JAX's unflatten + # mechanism (not the nnx.State constructor). This is critical because: + # 1. nnx.State({...}) constructor recursively converts nested plain dicts to nnx.State, + # causing a pytree type mismatch with the actual state from nnx.split (which stores + # nested module states as plain dicts). 
JAX's unflatten preserves the original types. + # 2. copy.deepcopy fails because NamedSharding contains non-picklable jaxlib.Device objects. + # Direct __setattr__ assignment stores new_model_shardings as-is (no type conversion). + updated_state = jax.tree_util.tree_map(lambda x: x, state_mesh_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) + updated_state.model = new_model_shardings + + return prev_params_shardings, updated_state + + def logical_axis_rules_pp_act_as_dp(logical_rules): """Add stage as a physical axes before data for each rule, so stage acts just like data instead of PP. This is used when we want to pipeline only a subset of layers, and leave the rest like DP. diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index ba6b148b04..658c74e788 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -24,15 +24,19 @@ from typing import Sequence from absl import app +from flax import nnx from flax.linen import partitioning as nn_partitioning import jax from jax import numpy as jnp from maxtext.configs import pyconfig from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx from maxtext.models import models from maxtext.trainers.pre_train.train import get_first_step from maxtext.utils import max_logging from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils from maxtext.utils import train_utils from maxtext.utils.model_creation_utils import from_config import numpy as np @@ -41,28 +45,29 @@ def checkpoint_loop(config, state=None): - """Main Checkpointing loop. + """Save/restore exerciser. - Saves checkpoints. 
- - Args: - config: - state: - ckpt_path: - - Returns: + Builds an abstract train state, restores or initializes it, perturbs the + optimizer moments via `add_entropy_to_checkpoint`, then writes checkpoints + on the configured cadence. Works on both Linen and NNX state shapes. """ - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = from_config(config) - mesh = model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) - _, tx = train_utils.create_training_optimizer(config, model) if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + mesh = maxtext_utils.get_mesh_from_config(config) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=init_rng) + model = from_config(config, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(config, model) + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + else: + model = from_config(config) + mesh = model.mesh + _, tx = train_utils.create_training_optimizer(config, model) init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) @@ -113,22 +118,38 @@ def checkpoint_loop(config, state=None): def add_entropy_to_checkpoint(state): - """Introduce randomness in checkpoints. - - This is useful to simulate real checkpoints, without training. - - Args: - state: Initial state - - Returns: - state: Returns state with entropy added to the optimizer state. + """Replace adam mu/nu with cos/sin of params. + + Stand-in for real training when exercising checkpoint save/restore. 
Handles + three shapes: + * Linen `TrainState`: `state.params` + `state.opt_state` (tuple). + * NNX `TrainStateNNX` (Module): `state.model` is an `nnx.Module`; the + optimizer's `opt_state` is the optax tuple of NamedTuples. + * NNX `nnx.State` (post-split, what `setup_training_state` returns under + `pure_nnx`): `state.model` and `state.optimizer.opt_state` are sub-States; + `opt_state[0].mu`/`nu` are themselves States that can be reassigned. """ + if hasattr(state, "model"): + if isinstance(state, nnx.Module): + params = nnx.state(state.model, nnx.Param) + else: + params = state.model.filter(nnx.Param) if hasattr(state.model, "filter") else state.model + new_mu = jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), params) + new_nu = jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), params) + + if isinstance(state, nnx.Module): + opt = state.optimizer + opt.opt_state = (opt.opt_state[0]._replace(mu=new_mu, nu=new_nu),) + tuple(opt.opt_state[1:]) + else: + state.optimizer.opt_state[0].mu = new_mu + state.optimizer.opt_state[0].nu = new_nu + return state + opt_0 = state.opt_state[0] opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params)) opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params)) new_opt = [opt_0] + list(state.opt_state[1:]) - state = state.replace(opt_state=new_opt) - return state + return state.replace(opt_state=new_opt) def main(argv: Sequence[str]) -> None: diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 906a597728..80229b05be 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,12 +15,14 @@ # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting for training in MaxText.""" +import functools import os from functools import partial import jax -import functools +from flax import nnx from flax.linen import partitioning as nn_partitioning +from maxtext.layers import 
train_state_nnx from maxtext.common import checkpointing from maxtext.common.data_loader import create_dataloader from maxtext.common.goodput import GoodputEvent, maybe_record_goodput @@ -205,7 +207,7 @@ def setup_train_loop(config, recorder, devices=None): data_iterator: data_loader: rampup_manager: the class managing rampup batch sizes - state: the initialized train state + train_state: the initialized train state. For NNX, this is a TrainStateNNX instance """ # pylint: disable=import-outside-toplevel from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator @@ -213,16 +215,28 @@ def setup_train_loop(config, recorder, devices=None): with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): is_training = True init_rng = jax.random.PRNGKey(config.init_weights_seed) + mesh = maxtext_utils.get_mesh_from_config(config, devices) if config.pure_nnx: # Create abstract NNX model. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) else: model = model_creation_utils.from_config(config, devices) - mesh = model.mesh learning_rate_schedule, tx = create_training_optimizer(config, model) + if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # For NNX, the train state is wrapped in the TrainStateNNX module. + # When DPO is enabled, also materialize a frozen reference model alongside + # the policy. Both are constructed by `_create_model_partial()` (which uses + # `config.init_weights_seed`), so the reference starts identical to the + # policy — standard DPO practice. The reference is later overwritten by + # the step-0 checkpoint in `setup_post_setup_state` below. 
+ def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + reference_model = _create_model_partial() if config.use_dpo else None + return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model) + + init_state_fn = create_train_state_fn else: init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) @@ -266,6 +280,15 @@ def setup_train_loop(config, recorder, devices=None): state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + with nn_partitioning.axis_rules(config.logical_axis_rules): + # train_state is instance of TrainStateNNX + state_graphdef, _ = nnx.get_abstract_model(init_state_fn, mesh) + _, state_params, _ = nnx.split(state.model, nnx.Param, ...) + _, state_mesh_shardings_params, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) 
+ else: + state_params = state.params + state_mesh_shardings_params = state_mesh_shardings.params if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): @@ -283,15 +306,20 @@ def setup_train_loop(config, recorder, devices=None): # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + sharding.assert_params_sufficiently_sharded(state_params, mesh, config.sharding_tolerance) # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + if config.pure_nnx: + # TODO: Study how to get logical annotations of NNX module. Because of eager sharding, we + # probably already lost the logical partition info at this moment. 
+ logical_annotations_params = None + else: + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + logical_annotations_params = logical_annotations.params + max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params - ) + maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) @@ -318,12 +346,26 @@ def setup_train_loop(config, recorder, devices=None): except FileNotFoundError: step0_restored = None if step0_restored is not None: - reference_params = step0_restored["items"].params["params"] - state = _merge_dpo_state(state, reference_params) + if config.pure_nnx: + # step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX + # (typically from a non-DPO pre-training run, so its top-level fields are + # `model` and `optimizer` — no `reference_model`). Copy its `model` substate + # into our current state's `reference_model` slot. 
+ step0_state = step0_restored["items"] + step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state + state["reference_model"] = step0_model_substate + else: + reference_params = step0_restored["items"].params["params"] + state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) + if config.pure_nnx: + train_state = nnx.merge(state_graphdef, state) + model = train_state.model + else: + train_state = state return ( init_rng, @@ -336,7 +378,7 @@ def setup_train_loop(config, recorder, devices=None): data_loader, rampup_manager, eval_data_iterator, - state, + train_state, ) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index e7b155416c..685bd32599 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -17,6 +17,7 @@ import functools from flax import linen as nn +from flax import nnx import jax import jax.numpy as jnp @@ -29,6 +30,25 @@ from maxtext.utils import max_utils +# Submodule-name keys whose `nnx.Param` leaves are touched by +# `Transformer.logits_from_hidden_states` (= `decoder.apply_output_head`): +# * `token_embedder` / `shared_embedding` — token embedder; used for tied logits. +# * `decoder_norm` — final layer norm. +# * `logits_dense` — LM-head dense; used for non-tied logits. +# Path filter for the 3-way `nnx.split` in `vocab_tiling_nnx_loss`'s output-head +# carve-out: matching leaves go into `head_params` (the custom_vjp's differentiated +# primal); everything else ends up in `other_params` and is threaded through as a +# non-differentiated primal so the bwd can rebuild the model without crossing +# trace boundaries. 
+_OUTPUT_HEAD_PATH_KEYS = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense") + + +def _is_output_head_param_path(path, _value): + """nnx.split callable filter: True iff `path` lies under an output-head submodule.""" + keys = [str(getattr(k, "key", k)) for k in path] + return any(k in keys for k in _OUTPUT_HEAD_PATH_KEYS) + + def vocab_tiling_linen_loss( hidden_states, data, @@ -247,3 +267,223 @@ def _bwd_scan_body(grad_params_acc, chunk_data): ) return total_loss, total_z_loss + + +def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train): + """Calculates cross-entropy loss using vocab tiling for NNX models. + + NNX equivalent of `vocab_tiling_linen_loss`. The model is partitioned via + `nnx.split` into output-head params (`token_embedder`/`shared_embedding`, + `decoder_norm`, `logits_dense`), other params (transformer layers, etc.), and + non-Param state (rngs). Only the output-head params are the differentiated + primal of the custom_vjp; other params + rest are threaded through as + non-differentiated primals (bwd returns explicit zero pytrees of the same + shape/dtype as each primal). Forward and backward scans both rebuild the model + per chunk via `nnx.merge(..., copy=True)` and call `logits_from_hidden_states`. + + Backward memory is bounded by one chunk's logits (same as the Linen path). + The output-head carve-out additionally shrinks the custom_vjp's residual + + grad-accumulator scope from O(model params) to O(head params). + + Args: + model: The NNX model instance (must implement `logits_from_hidden_states`). + hidden_states: The final hidden states from the decoder. + data: A dictionary containing the input data, including 'targets' and 'targets_segmentation'. + config: The model and training configuration. + is_train: A boolean indicating if the model is in training mode. + + Returns: + A tuple (total_loss, total_z_loss). 
+ """ + labels = data["targets"] + segmentation = data["targets_segmentation"] + deterministic = not config.enable_dropout if is_train else True + model_mode = "train" + + hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length", "activation_embed"), + ) + label_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length"), + ) + reshaped_hidden_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + reshaped_data_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence"), + ) + chunked_hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + chunked_data_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence",), + ) + chunked_logits_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_vocab"), + ) + + _maybe_shard_with_name = functools.partial( + maybe_shard_with_name, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + extra_stack_level=1, + ) + + def _reshape(inputs, out_shape, out_sharding): + reshape_out_sharding = out_sharding if config.shard_mode == ShardMode.EXPLICIT else None + inputs = jax.lax.reshape(inputs, out_shape, out_sharding=reshape_out_sharding) + return _maybe_shard_with_name(inputs, out_sharding) + + hidden_states = _maybe_shard_with_name(hidden_states, hidden_spec) + labels = _maybe_shard_with_name(labels, label_spec) + segmentation = _maybe_shard_with_name(segmentation, label_spec) + + # 3-way split outside the custom_vjp: + # * head_params — only leaves `logits_from_hidden_states` touches + # (token_embedder/shared_embedding, decoder_norm, logits_dense). Differentiated. + # * other_params — every other `nnx.Param` (transformer layers, etc.). 
+ # Threaded through the custom_vjp as a primal; bwd returns explicit zeros. + # * rest — non-Param state (rngs). Threaded through as a primal too. + # Threading non-head primals (instead of closure-capture) is required to avoid + # `UnexpectedTracerError` when the embedded variables are accessed through the + # custom_vjp + lax.scan boundaries (manifests on `logits_via_embedding=True`). + graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...) + + def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk): + local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True) + chunk_logits = local_model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode) + return _maybe_shard_with_name(chunk_logits, chunked_logits_spec) + + @jax.custom_vjp + def chunked_cross_entropy_loss(chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation): + (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd( + chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation + ) + return total_loss, total_z_loss + + def _chunked_cross_entropy_loss_fwd( + chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation + ): + batch_size, seq_len, emb_dim = hidden_states.shape + vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + + reshaped_hidden_states = _reshape( + hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + + def _fwd_scan_body(accumulators, chunk_data): + loss_accumulator, z_loss_accumulator = accumulators + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = 
_maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + chunk_logits = _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk) + one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) + chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( + chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + ) + + masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) + masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + + return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None + + # Always accumulate in fp32 — `cross_entropy_with_logits` returns fp32 regardless of + # logits dtype, and a bf16 carry would mismatch the body output type under lax.scan. + initial_acc = (jnp.zeros((), dtype=jnp.float32), jnp.zeros((), dtype=jnp.float32)) + (total_loss, total_z_loss), _ = jax.lax.scan( + _fwd_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + residuals = ( + chunk_head_params, + chunk_other_params, + chunk_rest, + reshaped_hidden_states, + reshaped_labels, + reshaped_segmentation, + batch_size, + seq_len, + emb_dim, + ) + return (total_loss, total_z_loss), residuals + + def _chunked_cross_entropy_loss_bwd(residuals, cotangents): + # z_loss is folded into the xent loss inside cross_entropy_with_logits. 
+ loss_cotangent, _ = cotangents + + ( + chunk_head_params, + chunk_other_params, + chunk_rest, + reshaped_hidden_states, + reshaped_labels, + reshaped_segmentation, + batch_size, + seq_len, + emb_dim, + ) = residuals + + def _single_chunk_loss_fn(input_head_params, input_hidden_chunk, input_label_chunk, input_segmentation_chunk): + chunk_logits = _logits_for_chunk(input_head_params, chunk_other_params, chunk_rest, input_hidden_chunk) + one_hot_label_chunk = jax.nn.one_hot(input_label_chunk, config.vocab_size) + xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier) + return jnp.sum(xent * (input_segmentation_chunk != 0)) + + def _bwd_scan_body(grad_head_acc, chunk_data): + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + # pylint: disable=unnecessary-lambda-assignment + loss_fn_for_vjp = lambda p, h: _single_chunk_loss_fn(p, h, label_chunk, segmentation_chunk) + _, vjp_fn = jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk) + (grad_head_update, grad_hidden_chunk) = vjp_fn(1.0) + grad_hidden_chunk = _maybe_shard_with_name(grad_hidden_chunk, chunked_hidden_spec) + + grad_head_acc = jax.tree_util.tree_map(lambda acc, update: acc + update, grad_head_acc, grad_head_update) + return grad_head_acc, grad_hidden_chunk + + initial_grad_head = jax.tree_util.tree_map(jnp.zeros_like, chunk_head_params) + + grad_head, grad_reshaped_hidden_states = jax.lax.scan( + _bwd_scan_body, initial_grad_head, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec) + grad_head = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_head) + grad_head = 
jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), chunk_head_params, grad_head) + grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec) + + # Explicit zero cotangents for `chunk_other_params` and `chunk_rest`. Returning `None` + # makes JAX synthesize zeros at AOT time with the wrong axis convention for nnx-scanned + # transformer layer params (axis-0 instead of nnx's axis-1 stacking), causing + # `Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M], but got + # bfloat16[L,E,M]` at trace check. Materializing the zeros here ties the cotangent + # shape to the primal shape exactly. + grad_other = jax.tree_util.tree_map(jnp.zeros_like, chunk_other_params) + grad_rest = jax.tree_util.tree_map(jnp.zeros_like, chunk_rest) + return ( + grad_head, + grad_other, + grad_rest, + grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype), + None, + None, + ) + + chunked_cross_entropy_loss.defvjp(_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd) + + total_loss, total_z_loss = chunked_cross_entropy_loss( + head_params, other_params, rest, hidden_states, labels, segmentation + ) + return total_loss, total_z_loss diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py new file mode 100644 index 0000000000..05a7fcffec --- /dev/null +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -0,0 +1,131 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for setup_train_loop with pure_nnx=True. + +setup_train_loop wires together create_nnx_abstract_model, the training optimizer, +the checkpoint manager, the data iterator, and finally nnx.split / nnx.merge to +return a fully-formed TrainStateNNX. This test exercises that wiring end-to-end +on a tiny synthetic config — the goal is to cover the integration glue that the +unit tests in tests/unit/train_utils_nnx_test.py cannot reach. +""" + +import os +import sys +import unittest + +import pytest + +import jax +from flax import nnx + +from maxtext.configs import pyconfig +from maxtext.layers import train_state_nnx +from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT +from maxtext.utils.train_utils import setup_train_loop +from tests.utils.test_helpers import get_test_config_path + + +def _tiny_nnx_pyconfig(**overrides): + """Build a tiny pyconfig suitable for a single-host setup_train_loop run.""" + init_kwargs = { + "run_name": "setup_train_loop_nnx_test", + "enable_checkpointing": False, + "dataset_type": "synthetic", + "model_name": "default", + "pure_nnx": True, + "per_device_batch_size": 1.0, + "base_emb_dim": 8, + "base_num_query_heads": 4, + "base_num_kv_heads": 4, + "base_mlp_dim": 32, + "base_num_decoder_layers": 2, + "head_dim": 128, + "max_target_length": 128, + "vocab_size": 256, + "steps": 1, + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"), + "enable_goodput_recording": False, + "enable_checkpoint_cloud_logger": False, + "monitor_goodput": False, + } + init_kwargs.update(overrides) + return pyconfig.initialize([sys.argv[0], get_test_config_path()], **init_kwargs) + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +class SetupTrainLoopNNXIntegrationTest(unittest.TestCase): + """End-to-end check that setup_train_loop returns a usable TrainStateNNX.""" + + def 
test_pure_nnx_setup_returns_train_state_nnx(self): + config = _tiny_nnx_pyconfig() + + ( + init_rng, + checkpoint_manager, + state_mesh_shardings, + model, + mesh, + learning_rate_schedule, + data_iterator, + data_loader, + rampup_manager, + eval_data_iterator, + train_state, + ) = setup_train_loop(config, recorder=None) + + # The NNX path returns a fully-merged TrainStateNNX (lines 352-354 in train_utils.py). + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Optimizer.step starts at 0 for a fresh init. + self.assertEqual(int(train_state.optimizer.step.get_value()), 0) + # The returned model is train_state.model, an NNX module. + self.assertIsInstance(model, nnx.Module) + self.assertIs(model, train_state.model) + + # Sanity for sibling outputs: + self.assertIsNotNone(init_rng) + self.assertIsNotNone(mesh) + self.assertTrue(callable(learning_rate_schedule)) + # data_loader is mandatory; data_iterator may be wrapped/unwrapped. + self.assertIsNotNone(data_loader) + self.assertIsNotNone(data_iterator) + + # state_mesh_shardings (NNX) is an nnx.State and contains a 'model' branch. + self.assertIsInstance(state_mesh_shardings, nnx.State) + self.assertIn("model", state_mesh_shardings) + + # Cleanup: the rest are not asserted on but referenced so linters don't + # flag them as unused — they're part of the public return contract. + del checkpoint_manager, rampup_manager, eval_data_iterator + + def test_pure_nnx_setup_param_only_split_matches_model(self): + """nnx.split(state.model, nnx.Param, ...) must yield a non-empty Param tree + whose structure matches state_mesh_shardings.model after the same split.""" + config = _tiny_nnx_pyconfig() + *_, state_mesh_shardings, model, _, _, _, _, _, _, train_state = setup_train_loop(config, recorder=None) + + _, params, _ = nnx.split(train_state.model, nnx.Param, ...) + _, params_shardings, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) 
+ + # Same key-set after nnx.split — this is what setup_train_loop relies on at + # train_utils.py:281-282 to pair state_params with state_mesh_shardings_params. + self.assertEqual(jax.tree_util.tree_structure(params), jax.tree_util.tree_structure(params_shardings)) + self.assertGreater(len(jax.tree.leaves(params)), 0) + + del model + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/post_training/unit/distillation_scheduling_test.py b/tests/post_training/unit/distillation_scheduling_test.py index 21e22839b4..24b9b6d721 100644 --- a/tests/post_training/unit/distillation_scheduling_test.py +++ b/tests/post_training/unit/distillation_scheduling_test.py @@ -412,9 +412,15 @@ def __call__(self, x): self.assertEqual(int(bundle.training_step[...]), 2) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_increments_and_passes_step( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_global_norm + ): """_train_step passes pre-increment step to compute_loss and increments after.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() @@ -442,37 +448,54 @@ def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_g # Simulate resume from step 5 model_bundle.training_step.set_value(jnp.array(5, 
dtype=jnp.int32)) - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # nnx.split returns (graphdef, diff_params, rest); loss_wrapper_pure takes (diff_params, rest). + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # grad_fn returns ((loss, (aux, new_rest)), grads) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) trainer._train_step(model_bundle, optimizer, mock.Mock()) # Step should have incremented to 6 self.assertEqual(int(model_bundle.training_step[...]), 6) - # Trigger loss_wrapper to verify step=5 was passed to compute_loss + # Trigger loss_wrapper_pure to verify step=5 was passed to compute_loss. + # Signature is (diff_params, rest). loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + loss_wrapper(mock_diff_params, mock_rest) call_kwargs = trainer.strategy.compute_loss.call_args self.assertIn("step", call_kwargs.kwargs) self.assertEqual(int(call_kwargs.kwargs["step"]), 5) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_consecutive_train_steps_increment( + self, mock_value_and_grad, mock_split, mock_merge, 
mock_update, mock_global_norm + ): """training_step increments 0→1→2→3 across consecutive _train_step calls.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() trainer.wrt_filter = lambda path, x: True # type: ignore + # Use a real DistillationForwardOutput so jax.tree.map(stop_gradient, ...) works. + fake_teacher_output = distillation_utils.DistillationForwardOutput( + logits=jnp.zeros((1, 2, 4)), out_projection_activations=None + ) mock_batch = { "input_tokens": mock.Mock(), "positions": mock.Mock(), "targets": mock.Mock(), - "teacher_output": mock.Mock(), + "teacher_output": fake_teacher_output, } trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) @@ -480,7 +503,10 @@ def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_globa model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer = mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index 80b7cbfce7..ca57e13f7a 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -162,9 +162,12 @@ def test_prepare_inputs_logic(self): @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - 
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_skips_teacher_forward_when_output_present( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is skipped when model_output is already in the batch.""" # 1. Initialize Trainer @@ -189,21 +192,28 @@ def test_train_step_skips_teacher_forward_when_output_present( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. 
Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions trainer.strategy.teacher_forward_fn.assert_not_called() trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -215,9 +225,12 @@ def test_train_step_skips_teacher_forward_when_output_present( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_calls_teacher_forward_when_output_missing( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is called when model_output is missing from the batch.""" # 1. Initialize Trainer @@ -242,19 +255,27 @@ def test_train_step_calls_teacher_forward_when_output_missing( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. 
Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_gn = mock.Mock() mock_global_norm.return_value = mock_gn + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure train_step_out = trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. 
Assertions + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], @@ -266,8 +287,9 @@ def test_train_step_calls_teacher_forward_when_output_missing( decoder_target_mask=None, ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -291,8 +313,13 @@ def test_train_step_calls_teacher_forward_when_output_missing( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_passes_targets_segmentation( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm + ): """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask.""" # 1. 
Initialize Trainer # pylint: disable=no-value-for-parameter @@ -317,22 +344,30 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. 
Assertions trainer.strategy.create_labels.assert_called_once_with( mock_batch["targets"], targets_segmentation=mock_targets_segmentation ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -341,6 +376,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ decoder_target_mask=mock_targets_segmentation, cache=None, ) + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], diff --git a/tests/unit/aqt_serve_roundtrip_nnx_test.py b/tests/unit/aqt_serve_roundtrip_nnx_test.py new file mode 100644 index 0000000000..83b79b95aa --- /dev/null +++ b/tests/unit/aqt_serve_roundtrip_nnx_test.py @@ -0,0 +1,148 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Round-trip test for the NNX serve-mode AQT checkpoint path. 
+ +Builds a small NNX model in CONVERT mode with int8 quantization, runs a forward +to populate `qrhs.frozen`, saves the serve-mode-shape state to a local orbax +checkpoint, then reloads via `from_pretrained(quant_mode_str="serve")` and +checks that the loaded QTensor leaves match what was saved. + +This guards the chain of issues fixed in PR9 (sharding helper for QTensor, +v[...] vs get_value() for composite values, Param-only filter dropping +aqt-typed leaves, Partitioned-unwrap for matching on-disk paths). +""" + +import os +import sys +import tempfile +import unittest + +import jax +import jax.numpy as jnp +import orbax.checkpoint as ocp +from flax import nnx +from flax.core.meta import Partitioned +from flax.linen import partitioning as nn_partitioning + +from maxtext.configs import pyconfig +from maxtext.utils import maxtext_utils, model_creation_utils, maxtext_utils_nnx +from maxtext.utils.globals import MAXTEXT_PKG_DIR +from maxtext.utils.layerwise_quantization import LayerwiseQuantization + + +def _wrap_value(node): + """Add `{"value": ...}` per-leaf wrap matching `_load_and_quantize_nnx` save format.""" + if isinstance(node, dict): + return {k: _wrap_value(v) for k, v in node.items()} + return {"value": node} + + +def _unbox(x): + return x.value if isinstance(x, Partitioned) else x + + +def _walk_qrhs(state): + """Yield (path_str, variable) pairs for every qrhs.frozen entry in an nnx.State.""" + for path, var in state.flat_state(): + keys = [str(getattr(k, "key", k)) for k in path] + if "qrhs" in keys and "frozen" in keys: + yield ".".join(keys), var + + +class ServeModeRoundTripTest(unittest.TestCase): + """End-to-end save+reload of a serve-mode NNX AQT checkpoint.""" + + def _init_cfg(self, ckpt_path, *, checkpoint_is_quantized): + # Use base.yml + gpt3-52k. The decoupled test config strips + # logical_axis_rules (e.g. "norm"), which the AQT serve-mode model + # construction needs. 
+ base_yml = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml") + args = [ + sys.argv[0], base_yml, + "model_name=gpt3-52k", + "pure_nnx=true", "enable_nnx=true", "pure_nnx_decoder=true", + "max_target_length=64", "max_prefill_predict_length=16", + "per_device_batch_size=1", "scan_layers=true", + "quantization=int8", + "checkpoint_storage_use_ocdbt=false", "checkpoint_storage_use_zarr3=false", + "skip_jax_distributed_system=true", + ] + if checkpoint_is_quantized: + args += [ + f"load_parameters_path={ckpt_path}", + "checkpoint_is_quantized=true", + "enable_checkpointing=true", # required by config validator when load_parameters_path is set + ] + else: + args += ["enable_checkpointing=false"] + return pyconfig.initialize(args) + + def test_save_then_reload_preserves_qrhs_frozen(self): + """Save a serve-mode-shape NNX checkpoint, then reload it and compare qvalue arrays.""" + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_path = os.path.join(tmpdir, "quantized_ckpt") + + # Step 1: build CONVERT-mode model + run forward to populate qrhs.frozen. + cfg_save = self._init_cfg(ckpt_path, checkpoint_is_quantized=False) + mesh = maxtext_utils.get_mesh_from_config(cfg_save) + rngs = maxtext_utils_nnx.create_nnx_rngs(cfg_save) + with nn_partitioning.axis_rules(cfg_save.logical_axis_rules): + convert_model = model_creation_utils.from_config( + cfg_save, mesh=mesh, rngs=rngs, model_mode="train", quant_mode_str="convert", + ) + L = cfg_save.max_prefill_predict_length + tokens = jnp.zeros((1, L), dtype=jnp.int32) + pos = jnp.arange(L, dtype=jnp.int32)[None, :] + seg = jnp.ones((1, L), dtype=jnp.int32) + with nn_partitioning.axis_rules(cfg_save.logical_axis_rules): + _ = convert_model(tokens, pos, decoder_segment_ids=seg, enable_dropout=False, model_mode="train") + + # Step 2: capture the qrhs.frozen leaves we expect to round-trip, then save. 
+ convert_state = nnx.state(convert_model).to_pure_dict() + serve_state = LayerwiseQuantization._strip_kernels_at_quantized_paths(convert_state) # pylint: disable=protected-access + saved_qrhs = {} + for path, var in _walk_qrhs(nnx.state(convert_model)): + qt = var.value if hasattr(var, "value") else var + saved_qrhs[path] = _unbox(qt.qvalue) + + orbax_checkpointer = ocp.PyTreeCheckpointer(use_ocdbt=False, use_zarr3=False) + orbax_checkpointer.save(ckpt_path, _wrap_value(serve_state), force=True) + self.assertGreater(len(saved_qrhs), 0, "Test config must produce at least one qrhs.frozen leaf") + + # Step 3: reload via from_pretrained in serve mode. + cfg_load = self._init_cfg(ckpt_path, checkpoint_is_quantized=True) + with nn_partitioning.axis_rules(cfg_load.logical_axis_rules): + loaded_model = model_creation_utils.from_pretrained( + cfg_load, mesh=mesh, model_mode="autoregressive", quant_mode_str="serve", + ) + + # Step 4: assert every saved qrhs.frozen leaf matches what was persisted. 
+ loaded_state = nnx.state(loaded_model) + loaded_qrhs = dict(_walk_qrhs(loaded_state)) + self.assertEqual(set(saved_qrhs.keys()), set(loaded_qrhs.keys())) + for path, saved_qv in saved_qrhs.items(): + var = loaded_qrhs[path] + qt = var.value if hasattr(var, "value") else var + loaded_qv = _unbox(qt.qvalue) + self.assertEqual(loaded_qv.shape, saved_qv.shape, f"shape mismatch at {path}") + self.assertEqual(loaded_qv.dtype, saved_qv.dtype, f"dtype mismatch at {path}") + self.assertTrue( + jnp.array_equal(loaded_qv.astype(jnp.int32), saved_qv.astype(jnp.int32)), + f"qvalue not preserved at {path}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/checkpointing_nnx_load_test.py b/tests/unit/checkpointing_nnx_load_test.py new file mode 100644 index 0000000000..622f19323a --- /dev/null +++ b/tests/unit/checkpointing_nnx_load_test.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the NNX branches of load_state_if_possible.""" + +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + """Tiny single-linear NNX model for restore tests.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +def _abstract_nnx_state(): + """Build an nnx.State from a TrainStateNNX — same shape that pre_train passes in.""" + model = _Model(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + return nnx.state(train_state_nnx.TrainStateNNX(model, optimizer)) + + +class TestLoadStateIfPossibleNNX(unittest.TestCase): + """Cover the NNX branches in load_state_if_possible.""" + + def test_load_parameters_from_path_splits_nnx_state_for_param_view(self): + """When abstract_unboxed_pre_state is an nnx.State, the function must call + nnx.split(model, nnx.Param, ...) to get the params and forward them to load_params_from_path.""" + abstract = _abstract_nnx_state() + sentinel_restored = {"linear": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros((1,))}} + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel_restored) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=abstract, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel_restored) + m.assert_called_once() + forwarded_params = m.call_args[0][1] # second positional arg = abstract_unboxed_params + # The forwarded params come from nnx.split(..., nnx.Param, ...) — same key shape as the model. 
+ leaves = jax.tree.leaves(forwarded_params) + self.assertEqual(len(leaves), 2) # linear.kernel + linear.bias + + def test_load_parameters_from_path_uses_state_params_for_linen(self): + """For Linen TrainState, the function must use state.params (not nnx.split).""" + fake_state = mock.Mock(spec=["params"]) + fake_state.params = {"layer": {"kernel": jnp.ones((2, 2))}} + sentinel = object() + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=fake_state, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel) + forwarded_params = m.call_args[0][1] + self.assertIs(forwarded_params, fake_state.params) + + def test_no_paths_returns_none_none(self): + """Sanity: with no checkpoint manager and no load paths, the function returns (None, None).""" + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=_abstract_nnx_state(), + ) + self.assertIsNone(full) + self.assertIsNone(params) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dpo_nnx_test.py b/tests/unit/dpo_nnx_test.py new file mode 100644 index 0000000000..461c3cb2aa --- /dev/null +++ b/tests/unit/dpo_nnx_test.py @@ -0,0 +1,215 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX DPO unit tests. + +Covers the NNX-native DPO surface: + * `TrainStateNNX(model, optimizer, reference_model=...)` — reference model + sits alongside policy and is not touched by `apply_gradients`. + * `dpo_loss_fn_nnx(policy, config, data, None, None, reference, is_train)` — + aux structure, identical-model invariant (loss = log(2), reward_accuracy = 0.5). +""" + +import math +import types +import unittest + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.post_train.dpo import dpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX transformer-shaped module for DPO tests. + + Accepts the same keyword args that `dpo_loss_fn_nnx` passes: + `decoder_input_tokens`, `decoder_positions`, `decoder_segment_ids`, + `enable_dropout`. Other args are tolerated via **kwargs. 
+ """ + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_dpo_config(**overrides): + """Build the minimal config surface that `dpo_loss_fn_nnx` reads.""" + base = { + "dpo_label_smoothing": 0.0, + "dpo_beta": 0.1, + "enable_dropout": False, + "num_experts": 1, + "micro_batch_size_to_train_on": 2, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_dpo_batch(batch_size=2, seq_len=5): + """Build a tiny DPO-shaped batch. + + `chosen` and `rejected` share the first 2 tokens (common prefix is masked + out in the loss), differ at positions 2 and 3, and are padded at position 4. 
+ """ + chosen = jnp.array([[1, 2, 3, 4, 0]] * batch_size, dtype=jnp.int32) + rejected = jnp.array([[1, 2, 5, 6, 0]] * batch_size, dtype=jnp.int32) + positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1)) + segmentation = jnp.array([[1, 1, 1, 1, 0]] * batch_size, dtype=jnp.int32) + return { + "chosen": chosen, + "rejected": rejected, + "chosen_position": positions, + "rejected_position": positions, + "chosen_segmentation": segmentation, + "rejected_segmentation": segmentation, + } + + +class TestTrainStateNNXWithReferenceModel(unittest.TestCase): + """`TrainStateNNX(reference_model=...)` semantics.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + self.tx = optax.adam(1e-3) + + def test_init_with_reference(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + self.assertIs(state.model, self.policy) + self.assertIs(state.reference_model, self.reference) + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_reference_omits_attribute(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer) + self.assertFalse(hasattr(state, "reference_model")) + + def test_apply_gradients_does_not_touch_reference(self): + """Gradient update on policy must leave reference model bit-identical.""" + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + + ref_kernel_before = jnp.asarray(state.reference_model.proj.kernel.value).copy() + + def policy_loss(m): + return jnp.mean(m(jnp.array([[1, 2]])) ** 2) + + grads = nnx.grad(policy_loss)(state.model) + state.apply_gradients(grads) + + ref_kernel_after = 
jnp.asarray(state.reference_model.proj.kernel.value) + self.assertTrue(jnp.array_equal(ref_kernel_before, ref_kernel_after)) + + +class TestDPOLossFnNNX(unittest.TestCase): + """`dpo_loss_fn_nnx` numerical and structural sanity checks.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Reference initialized with the same seed to make policy and reference + # bit-identical at construction time. + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.config = _make_dpo_config() + self.data = _make_dpo_batch() + + def test_aux_has_expected_keys(self): + _, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + expected_keys = { + "intermediate_outputs", + "xent_sum", + "dpo_loss", + "total_weights", + "moe_lb_loss", + "reward_accuracy", + "indexer_loss", + "mtp_loss", + } + self.assertEqual(set(aux.keys()), expected_keys) + self.assertEqual(aux["xent_sum"], 0.0) + self.assertEqual(aux["moe_lb_loss"], 0.0) # num_experts=1 + self.assertEqual(aux["total_weights"], self.data["chosen"].shape[0]) + + def test_identical_policy_and_reference_yields_log2_loss(self): + """When policy == reference, all logratios are 0; with label_smoothing=0 + the per-example loss is `-log(sigmoid(0)) = log(2)`. `reward_accuracy` + uses strict `chosen > rejected`, so equal logratios score 0.0 (no example + is strictly preferred). 
+ """ + loss, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + self.assertAlmostEqual(float(loss), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["dpo_loss"]), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["reward_accuracy"]), 0.0, places=4) + + def test_dropout_rng_and_params_args_are_unused(self): + """The 4th and 5th positional args are signature-compat slots for the + Linen dispatcher; passing arbitrary values must not affect the result. + """ + loss_a, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + loss_b, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, + self.config, + dict(self.data), + jax.random.PRNGKey(123), # dropout_rng — unused + {"params": "garbage"}, # params — unused + self.reference, + is_train=True, + ) + self.assertAlmostEqual(float(loss_a), float(loss_b), places=6) + + def test_value_and_grad_argnums0_only_diffs_policy(self): + """`nnx.value_and_grad(..., argnums=0)` over the policy should produce + finite grads on policy params and not require reference grads. + """ + + def _loss(policy_module): + loss, _ = dpo_utils.dpo_loss_fn_nnx( + policy_module, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + return loss + + grad_fn = nnx.value_and_grad(_loss, argnums=0) + loss, grads = grad_fn(self.policy) + self.assertTrue(jnp.isfinite(loss)) + # Grads is an nnx.State of the policy's nnx.Param leaves; check at least one + # leaf is finite and non-trivially shaped. 
+ leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/generate_param_only_checkpoint_nnx_test.py b/tests/unit/generate_param_only_checkpoint_nnx_test.py new file mode 100644 index 0000000000..49c43a6dae --- /dev/null +++ b/tests/unit/generate_param_only_checkpoint_nnx_test.py @@ -0,0 +1,205 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX path of generate_param_only_checkpoint. + +Covers `_possibly_unroll_params_nnx` (slicing scanned NNX layers) and the +shape parity of `_save_decode_checkpoint_nnx`'s bf16 cast. 
+""" + +from types import SimpleNamespace +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from flax.training import train_state as linen_train_state + +from maxtext.common.common_types import DecoderBlockType +from maxtext.layers import train_state_nnx +from maxtext.utils.generate_param_only_checkpoint import ( + _possibly_unroll_lora_params_nnx, + _possibly_unroll_params_nnx, +) + + +class _ScanLayerLeaf(nnx.Module): + """One scanned-layer kernel with leading shape `[num_layers, *]`.""" + + def __init__(self, num_layers: int, in_dim: int, out_dim: int): + self.kernel = nnx.Param( + jnp.arange(num_layers * in_dim * out_dim, dtype=jnp.float32).reshape(num_layers, in_dim, out_dim) + ) + + +class _Decoder(nnx.Module): + + def __init__(self, num_layers: int): + self.layers = _ScanLayerLeaf(num_layers, 3, 5) + + +class _Model(nnx.Module): + + def __init__(self, num_layers: int): + self.decoder = _Decoder(num_layers) + + +def _make_split_state(num_layers: int): + model = _Model(num_layers) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + _, state = nnx.split(ts) + return state + + +def _make_shardings_state(state, mesh): + """Build a sibling shardings tree where each Variable is replaced by NamedSharding(replicated).""" + + def to_named(v): + return NamedSharding(mesh, PartitionSpec()) + + return jax.tree_util.tree_map(to_named, state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +class PossiblyUnrollParamsNNXTest(unittest.TestCase): + + def setUp(self): + devices = np.array(jax.devices()).reshape(-1) + self.mesh = Mesh(devices, ("data",)) + + def test_unrolls_scanned_layers(self): + num_layers = 3 + state = _make_split_state(num_layers) + shardings = _make_shardings_state(state, self.mesh) + + original_kernel = np.asarray(state.model.decoder.layers.kernel[...]) + 
+ config = SimpleNamespace( + scan_layers=True, + force_unroll=True, + pure_nnx=True, + param_scan_axis=0, + decoder_block=DecoderBlockType.LLAMA2, + num_decoder_layers=num_layers, + ) + + _possibly_unroll_params_nnx(config, state, shardings, self.mesh) + + self.assertNotIn("layers", state.model.decoder) + self.assertNotIn("layers", shardings.model.decoder) + for i in range(num_layers): + self.assertIn(f"layers_{i}", state.model.decoder) + self.assertIn(f"layers_{i}", shardings.model.decoder) + sliced = state.model.decoder[f"layers_{i}"]["kernel"][...] + expected = jnp.take(original_kernel, i, axis=0) + self.assertTrue(jnp.array_equal(sliced, expected)) + + def test_deepseek_split(self): + """DeepSeek decoder has separate dense/moe layer collections.""" + + # Build a DeepSeek-flavored synthetic model with two scanned groups. + class _DeepSeekDecoder(nnx.Module): + + def __init__(self): + self.dense_layers = _ScanLayerLeaf(2, 3, 5) + self.moe_layers = _ScanLayerLeaf(3, 3, 5) + + class _DSModel(nnx.Module): + + def __init__(self): + self.decoder = _DeepSeekDecoder() + + model = _DSModel() + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + _, state = nnx.split(ts) + shardings = _make_shardings_state(state, self.mesh) + + config = SimpleNamespace( + scan_layers=True, + force_unroll=True, + pure_nnx=True, + param_scan_axis=0, + decoder_block=DecoderBlockType.DEEPSEEK, + num_decoder_layers=5, + first_num_dense_layers=2, + ) + + _possibly_unroll_params_nnx(config, state, shardings, self.mesh) + + self.assertNotIn("dense_layers", state.model.decoder) + self.assertNotIn("moe_layers", state.model.decoder) + for i in range(2): + self.assertIn(f"dense_layers_{i}", state.model.decoder) + for i in range(3): + self.assertIn(f"moe_layers_{i}", state.model.decoder) + + +class PossiblyUnrollLoraParamsNNXTest(unittest.TestCase): + """The LoRA delta tree is single-nested (`{"decoder": {...}}`) and held in a + 
Linen `TrainState` even on the NNX path — the unroll has to walk that shape.""" + + def setUp(self): + devices = np.array(jax.devices()).reshape(-1) + self.mesh = Mesh(devices, ("data",)) + + def _make_lora_state(self, num_layers: int, lora_rank: int = 4): + """Build a synthetic LoRA delta TrainState mirroring `get_lora_abstract_state_nnx`'s output shape.""" + lora_a = jnp.arange(num_layers * 8 * lora_rank, dtype=jnp.float32).reshape(num_layers, 8, lora_rank) + lora_b = jnp.arange(num_layers * lora_rank * 4 * 2, dtype=jnp.float32).reshape(num_layers, lora_rank, 4, 2) + params = { + "decoder": { + "layers": { + "self_attention": { + "query": {"lora_a.kernel": lora_a, "lora_b.kernel": lora_b}, + } + } + } + } + annotations_params = jax.tree_util.tree_map(lambda _: PartitionSpec(), params) + state = linen_train_state.TrainState(step=0, apply_fn=None, params=params, tx=None, opt_state={}) + annotations = linen_train_state.TrainState(step=0, apply_fn=None, params=annotations_params, tx=None, opt_state={}) + return state, annotations + + def test_unrolls_scanned_lora_layers(self): + num_layers = 3 + state, annotations = self._make_lora_state(num_layers) + original_a = np.asarray(state.params["decoder"]["layers"]["self_attention"]["query"]["lora_a.kernel"]) + + config = SimpleNamespace( + scan_layers=True, + force_unroll=True, + pure_nnx=True, + param_scan_axis=0, + decoder_block=DecoderBlockType.LLAMA2, + num_decoder_layers=num_layers, + ) + + _possibly_unroll_lora_params_nnx(config, state, annotations, self.mesh) + + self.assertNotIn("layers", state.params["decoder"]) + self.assertNotIn("layers", annotations.params["decoder"]) + for i in range(num_layers): + self.assertIn(f"layers_{i}", state.params["decoder"]) + sliced_a = state.params["decoder"][f"layers_{i}"]["self_attention"]["query"]["lora_a.kernel"] + expected = jnp.take(original_a, i, axis=0) + self.assertTrue(jnp.array_equal(sliced_a, expected)) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/tests/unit/gradient_accumulation_nnx_test.py b/tests/unit/gradient_accumulation_nnx_test.py new file mode 100644 index 0000000000..6353f02397 --- /dev/null +++ b/tests/unit/gradient_accumulation_nnx_test.py @@ -0,0 +1,159 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX branch of gradient_accumulation_loss_and_grad.""" + +import unittest +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from maxtext.common.common_types import ShardMode +from maxtext.utils import gradient_accumulation + + +@dataclass +class _Cfg: + gradient_accumulation_steps: int = 2 + shard_optimizer_over_data: bool = False + shard_mode: int = ShardMode.AUTO + ici_data_parallelism: int = 1 + debug_sharding: bool = False + + +class _TinyNNX(nnx.Module): + """Single linear layer NNX model.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +def _fake_loss_fn(model, config, data, dropout_rng, params, is_train=True): + """A loss_fn shaped like the production loss_fn but for a tiny linear model. + + Returns (loss, aux) where aux follows the schema gradient_accumulation_loss_and_grad + reads from: xent_sum / total_weights / moe_lb_loss / indexer_loss / mtp_loss. 
+ """ + del config, dropout_rng, params, is_train + pred = model(data["inputs"]) + per_sample_loss = jnp.mean((pred - data["targets"]) ** 2, axis=-1) + xent_sum = jnp.sum(per_sample_loss) + total_weights = jnp.array(per_sample_loss.shape[0], dtype=jnp.float32) + aux = { + "xent_sum": xent_sum, + "total_weights": total_weights, + "moe_lb_loss": jnp.array(0.0), + "indexer_loss": jnp.array(0.0), + "mtp_loss": jnp.array(0.0), + } + return xent_sum / total_weights, aux + + +class TestGradientAccumulationNNX(unittest.TestCase): + """Cover the NNX path of gradient_accumulation_loss_and_grad.""" + + def setUp(self): + self.model = _TinyNNX(rngs=nnx.Rngs(0)) + self.cfg = _Cfg(gradient_accumulation_steps=2) + # 4 examples → 2 microbatches of 2 each + self.data = { + "inputs": jnp.arange(8.0).reshape(4, 2), + "targets": jnp.zeros((4, 1)), + } + + def _params_shardings(self): + """Build a per-leaf NamedSharding tree shaped like nnx.split(model, nnx.Param, ...)[1]. + + Uses a trivial single-device mesh so jax.lax.with_sharding_constraint accepts the + sharding without contradicting the actual device topology. + """ + _, params, _ = nnx.split(self.model, nnx.Param, ...) 
+ mesh = Mesh( + np.array(jax.local_devices()[:1]).reshape( + 1, + ), + ("x",), + ) + ns = NamedSharding(mesh, PartitionSpec()) + return jax.tree.map(lambda _: ns, params) + + def test_nnx_path_runs_and_returns_grad_for_every_param(self): + """The NNX branch must call nnx.value_and_grad and return one gradient per Param.""" + loss, aux, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, # NNX branch ignores params + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertIn("xent_sum", aux) + self.assertIn("total_weights", aux) + grad_leaves = jax.tree.leaves(raw_grads) + self.assertEqual(len(grad_leaves), 2) # linear.kernel + linear.bias + for g in grad_leaves: + self.assertTrue(jnp.all(jnp.isfinite(g))) + + def test_nnx_path_updates_model_rest_state_after_scan(self): + """After accumulation, nnx.update is called on the model with the rest_state from the scan. + + For a TinyNNX (no rngs/dropout), the rest tree is empty but the call path must still + succeed end-to-end without raising — covering the `if is_nnx: nnx.update(...)` branch. + """ + pre_kernel = self.model.linear.kernel.value.copy() + gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + # The kernel itself is a Param — gradient_accumulation_loss_and_grad does not apply + # gradients to params, so the value should be untouched. 
+ self.assertTrue(jnp.allclose(self.model.linear.kernel.value, pre_kernel)) + + def test_nnx_with_shard_optimizer_over_data_casts_to_bf16(self): + """Zero-1 path must convert fp32 params to bf16 before the scan loop.""" + self.cfg.shard_optimizer_over_data = True + # Should not raise; just verify the function runs and returns sensible outputs. + loss, _, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + for g in jax.tree.leaves(raw_grads): + self.assertTrue(jnp.all(jnp.isfinite(g))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/grpo_nnx_test.py b/tests/unit/grpo_nnx_test.py new file mode 100644 index 0000000000..6f72c43723 --- /dev/null +++ b/tests/unit/grpo_nnx_test.py @@ -0,0 +1,231 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for `grpo_loss_fn_nnx`, `compute_log_probs_nnx`, plus a small +Linen-path regression block (the repo's existing Linen GRPO integration test +is TPU-only).""" + +import types +import unittest + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx + +from maxtext.experimental.rl import grpo_trainer +from maxtext.experimental.rl import grpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX module that responds to the kwargs `compute_log_probs_nnx` uses.""" + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_grpo_config(**overrides): + """Minimal config namespace covering every field `grpo_loss_fn_nnx` reads.""" + base = { + "train_data_columns": "prompt", + "num_generations": 2, + "grpo_epsilon": 0.2, + "grpo_beta": 0.1, + "num_experts": 1, + "decode_sampling_temperature": 1.0, + "enable_dropout": False, + "use_dpo": False, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_grpo_batch(B=2, G=2, S=6): + """Minimal GRPO batch: `B` prompts, `G` generations each (total `B*G`), seq length `S`.""" + total = B * G + prompts = jnp.tile(jnp.arange(S, dtype=jnp.int32), (total, 1)) + return { + "prompt_completions": prompts, + "prompt_completions_position": prompts, + "prompt_completions_segmentation": jnp.ones((total, S), dtype=jnp.int32), + "ar_completions_segmentation": jnp.array([[0, 0, 1, 1, 1, 0]] * total, dtype=jnp.int32), + "completions_logprobs": None, # off-policy + } + + +class TestGrpoLossFnNnx(unittest.TestCase): + """Behavior of 
`grpo_loss_fn_nnx` on a synthetic policy / reference pair.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) # identical seed + self.config = _make_grpo_config() + self.data = _make_grpo_batch() + + def test_aux_structure_matches_linen(self): + """`grpo_loss_fn_nnx` returns the same `LossAux` dataclass shape as `grpo_loss_fn`.""" + loss, aux = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, None, None, self.reference, is_train=True + ) + self.assertIsInstance(aux, grpo_trainer.LossAux) + for field in ( + "total_loss", + "avg_reward", + "avg_reward_std", + "avg_advantage", + "completion_length", + "moe_lb_loss", + "total_weights", + ): + self.assertTrue(hasattr(aux, field), f"aux missing field {field}") + self.assertTrue(jnp.isfinite(loss)) + + def test_unused_dropout_rng_and_params_args_are_ignored(self): + """`dropout_rng` and `params` are positional placeholders only — values shouldn't matter.""" + a = grpo_trainer.grpo_loss_fn_nnx(self.policy, self.config, self.data, None, None, self.reference, is_train=True) + b = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, jax.random.key(99), {"junk": 1}, self.reference, is_train=True + ) + np.testing.assert_allclose(np.asarray(a[0]), np.asarray(b[0]), rtol=1e-6) + + def test_identical_policy_and_reference_zero_kl(self): + """Identical policy and reference → per-token KL is zero, so `aux.avg_kl ≈ 0`.""" + cfg = _make_grpo_config(grpo_beta=0.5) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, is_train=True) + self.assertIsNotNone(aux.avg_kl) + np.testing.assert_allclose(np.asarray(aux.avg_kl), 0.0, atol=1e-5) + + def test_grpo_beta_zero_avg_kl_is_none(self): + cfg = _make_grpo_config(grpo_beta=0.0) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, 
is_train=True) + self.assertIsNone(aux.avg_kl) + + def test_value_and_grad_flows_only_to_policy(self): + """`nnx.value_and_grad` over the policy yields finite grads; reference is left alone.""" + + def loss_only(policy_model): + loss, _ = grpo_trainer.grpo_loss_fn_nnx( + policy_model, self.config, self.data, None, None, self.reference, is_train=True + ) + return loss + + # nnx.value_and_grad returns (value, grad_state) where grad_state holds nnx.Param leaves. + _, grads = nnx.value_and_grad(loss_only, argnums=0)(self.policy) + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(np.all(np.isfinite(np.asarray(leaf))), "policy grad has non-finite entries") + + +class TestComputeLogProbsNnx(unittest.TestCase): + """Shape contract of `compute_log_probs_nnx`.""" + + def test_returns_correct_shape(self): + config = _make_grpo_config() + data = _make_grpo_batch() + model = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + log_probs, _ = grpo_utils.compute_log_probs_nnx( + model, + data["prompt_completions"], + data["prompt_completions_position"], + data["prompt_completions_segmentation"], + data["ar_completions_segmentation"], + config, + is_train=False, + ) + # Inputs are [B, S] → log_probs are [B, S-1]. 
+ self.assertEqual(log_probs.shape, (data["prompt_completions"].shape[0], data["prompt_completions"].shape[1] - 1)) + + +# --------------------------------------------------------------------------- +# Linen-path regression smoke tests +# --------------------------------------------------------------------------- + + +class _MockLinenTransformer(nn.Module): + """Tiny Linen module that responds to the same `model.apply(...)` shape Linen `compute_log_probs` uses.""" + + vocab_size: int + embed_dim: int + + @nn.compact + def __call__(self, inputs, positions, decoder_segment_ids=None, enable_dropout=False): + del positions, decoder_segment_ids, enable_dropout + embed = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim, name="embed")(inputs) + return nn.Dense(features=self.vocab_size, name="proj")(embed) + + +class TestLinenGrpoRegression(unittest.TestCase): + """Smoke test that the Linen `grpo_loss_fn` and `compute_log_probs` still run + end-to-end with `pure_nnx=False`-style inputs.""" + + def setUp(self): + self.config = _make_grpo_config() + self.config.pure_nnx = False # explicit Linen mode + self.config.gradient_accumulation_steps = 1 + self.data = _make_grpo_batch() + self.model = _MockLinenTransformer(vocab_size=8, embed_dim=4) + rng = jax.random.key(0) + inputs = self.data["prompt_completions"] + self.params = self.model.init(rng, inputs, inputs, decoder_segment_ids=jnp.ones_like(inputs), enable_dropout=False) + self.reference_params = jax.tree_util.tree_map(jnp.copy, self.params) + + def test_linen_grpo_loss_fn_still_runs(self): + """Linen `grpo_loss_fn` returns a finite loss + a `LossAux`.""" + loss, aux = grpo_trainer.grpo_loss_fn( + self.model, + self.config, + self.data, + jax.random.key(1), + self.params, + self.reference_params["params"], # Linen reference_params is the inner subtree + is_train=True, + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertTrue(hasattr(aux, "total_loss")) + self.assertTrue(hasattr(aux, "moe_lb_loss")) + 
self.assertTrue(hasattr(aux, "total_weights")) + + def test_linen_compute_log_probs_still_runs(self): + """Linen `compute_log_probs` produces shape `[B, S-1]`.""" + log_probs, _ = grpo_utils.compute_log_probs( + self.model, + self.params, + self.data["prompt_completions"], + self.data["prompt_completions_position"], + self.data["prompt_completions_segmentation"], + self.data["ar_completions_segmentation"], + self.config, + is_train=False, + rngs={"dropout": jax.random.key(2), "params": jax.random.key(3)}, + ) + S = self.data["prompt_completions"].shape[1] + self.assertEqual(log_probs.shape, (self.data["prompt_completions"].shape[0], S - 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/layerwise_quantization_nnx_test.py b/tests/unit/layerwise_quantization_nnx_test.py new file mode 100644 index 0000000000..bbd43f7964 --- /dev/null +++ b/tests/unit/layerwise_quantization_nnx_test.py @@ -0,0 +1,77 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX path of layerwise_quantization. + +Covers `_strip_kernels_at_quantized_paths` — the convert→serve shape converter +that drops the redundant full-precision kernel from quantized DenseGeneral +nodes while leaving non-quantized kernels (norms, embeddings) intact. 
+""" + +import unittest + +from maxtext.utils.layerwise_quantization import LayerwiseQuantization + + +class StripKernelsTest(unittest.TestCase): + + def test_drops_kernel_at_quantized_dense(self): + """A node with both `kernel` and `AqtDotGeneral_0` loses the kernel.""" + state = { + "decoder": { + "layers": { + "mlp": { + "wi": { + "kernel": "FULL_PRECISION_W", + "AqtDotGeneral_0": {"qrhs": {"frozen": "AQT_STATE"}}, + } + } + } + } + } + out = LayerwiseQuantization._strip_kernels_at_quantized_paths(state) # pylint: disable=protected-access + wi = out["decoder"]["layers"]["mlp"]["wi"] + self.assertNotIn("kernel", wi) + self.assertIn("AqtDotGeneral_0", wi) + self.assertEqual(wi["AqtDotGeneral_0"]["qrhs"]["frozen"], "AQT_STATE") + + def test_preserves_non_quantized_kernel(self): + """A non-quantized kernel (e.g. embedding, norm) survives.""" + state = { + "decoder": { + "decoder_norm": {"scale": "NORM_SCALE"}, + "logits_dense": {"kernel": "LOGITS_KERNEL"}, # no AqtDotGeneral_0 sibling + }, + "token_embedder": {"embedding": "EMB"}, + } + out = LayerwiseQuantization._strip_kernels_at_quantized_paths(state) # pylint: disable=protected-access + self.assertEqual(out["decoder"]["logits_dense"]["kernel"], "LOGITS_KERNEL") + self.assertEqual(out["decoder"]["decoder_norm"]["scale"], "NORM_SCALE") + self.assertEqual(out["token_embedder"]["embedding"], "EMB") + + def test_mixed_tree(self): + """Quantized + non-quantized at the same depth: only the quantized one strips.""" + state = { + "self_attention": { + "qkv_proj": {"kernel": "QKV", "AqtDotGeneral_0": "AQT"}, + "out": {"kernel": "OUT_FULL"}, # non-quantized output projection + } + } + out = LayerwiseQuantization._strip_kernels_at_quantized_paths(state) # pylint: disable=protected-access + self.assertNotIn("kernel", out["self_attention"]["qkv_proj"]) + self.assertEqual(out["self_attention"]["out"]["kernel"], "OUT_FULL") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/linen_nnx_converter_test.py 
b/tests/unit/linen_nnx_converter_test.py new file mode 100644 index 0000000000..808990f8cf --- /dev/null +++ b/tests/unit/linen_nnx_converter_test.py @@ -0,0 +1,869 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for linen_nnx_converter utilities.""" + +import unittest +import numpy as np +from unittest.mock import MagicMock, patch + +from maxtext.checkpoint_conversion.linen_nnx_converter import ( + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _add_value_wrappers, + _transpose_layers_axes, + _stack_layers, + convert_linen_to_nnx, + convert_nnx_to_linen, + _convert_opt_state_linen_to_nnx, + _convert_opt_state_nnx_to_linen, + load_checkpoint, + save_checkpoint, + main, +) + + +def _make_array(*shape): + """Helper to create a numpy array with given shape.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + +class TestDetectFormat(unittest.TestCase): + """Tests for the detect_format function.""" + + def test_raises_when_no_params_key(self): + with self.assertRaises(ValueError): + detect_format({"step": 0}) + + def test_detects_nnx_format_via_model_key(self): + # NNX: top-level "model" key + state = {"model": {"decoder": {"layers": {}}}, "optimizer": {}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_format_double_nested(self): + state = {"params": {"params": {"decoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def 
test_detects_nnx_format_single_nested_with_value_wrappers(self): + # Old NNX format: params/decoder with {value:} wrappers + arr = _make_array(2, 2) + state = {"params": {"decoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_encoder(self): + state = {"params": {"params": {"encoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_encoder_with_value_wrappers(self): + arr = _make_array(2, 2) + state = {"params": {"encoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_nnx_via_optimizer_key(self): + arr = _make_array(2, 2) + state = {"params": {"something": arr}, "optimizer": {"step": 0}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_opt_state(self): + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "opt_state": {"params": {"mu": {"decoder": {"kernel": arr}}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_optimizer_over_opt_state(self): + # "optimizer" key takes precedence for NNX detection + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "optimizer": {"step": 0, "opt_state": {}}, + } + self.assertEqual(detect_format(state), "nnx") + + def test_raises_on_undetectable_format(self): + state = {"params": {"some_unknown_key": 42}} + with self.assertRaises(ValueError): + detect_format(state) + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for the _has_value_wrappers helper.""" + + def test_returns_true_for_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"value": arr})) + + def test_returns_true_for_nested_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"mu": {"value": arr}})) + + def test_returns_false_for_plain_array(self): + # A plain array is not a {"value": ...} wrapper dict + 
self.assertFalse(_has_value_wrappers(_make_array(2, 2))) + + def test_returns_false_for_multi_key_dict(self): + arr = _make_array(2, 2) + self.assertFalse(_has_value_wrappers({"value": arr, "extra": arr})) + + def test_returns_false_for_non_array_value(self): + self.assertFalse(_has_value_wrappers({"value": "string"})) + + +class TestStripValueWrappers(unittest.TestCase): + """Tests for the _strip_value_wrappers helper.""" + + def test_strips_single_wrapper(self): + arr = _make_array(3, 4) + result = _strip_value_wrappers({"value": arr}) + np.testing.assert_array_equal(result, arr) + + def test_strips_nested_wrappers(self): + arr = _make_array(2, 2) + wrapped = {"decoder": {"layers": {"kernel": {"value": arr}}}} + stripped = _strip_value_wrappers(wrapped) + np.testing.assert_array_equal(stripped["decoder"]["layers"]["kernel"], arr) + + def test_passes_through_plain_array(self): + arr = _make_array(2, 3) + result = _strip_value_wrappers(arr) + np.testing.assert_array_equal(result, arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _strip_value_wrappers([{"value": arr}]) + result_tuple = _strip_value_wrappers(({"value": arr},)) + np.testing.assert_array_equal(result_list[0], arr) + np.testing.assert_array_equal(result_tuple[0], arr) + + def test_passes_through_non_array_value(self): + # A dict with key "value" but scalar content should not be unwrapped + d = {"value": 42} + result = _strip_value_wrappers(d) + self.assertEqual(result, d) + + +class TestAddValueWrappers(unittest.TestCase): + """Tests for the _add_value_wrappers helper.""" + + def test_wraps_array(self): + arr = _make_array(3, 4) + result = _add_value_wrappers(arr) + self.assertIsInstance(result, dict) + self.assertIn("value", result) + np.testing.assert_array_equal(result["value"], arr) + + def test_wraps_nested_arrays(self): + arr = _make_array(2, 2) + nested = {"decoder": {"layers": {"kernel": arr}}} + wrapped = _add_value_wrappers(nested) + 
self.assertEqual(set(wrapped["decoder"]["layers"]["kernel"].keys()), {"value"}) + np.testing.assert_array_equal(wrapped["decoder"]["layers"]["kernel"]["value"], arr) + + def test_idempotent_on_already_wrapped(self): + arr = _make_array(2) + already_wrapped = {"value": arr} + result = _add_value_wrappers(already_wrapped) + # Should not double-wrap + self.assertEqual(set(result.keys()), {"value"}) + np.testing.assert_array_equal(result["value"], arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _add_value_wrappers([arr]) + result_tuple = _add_value_wrappers((arr,)) + self.assertEqual(set(result_list[0].keys()), {"value"}) + self.assertEqual(set(result_tuple[0].keys()), {"value"}) + + def test_passes_through_non_array_scalars(self): + result = _add_value_wrappers(42) + self.assertEqual(result, 42) + result_str = _add_value_wrappers("text") + self.assertEqual(result_str, "text") + + +class TestTransposeLayersAxes(unittest.TestCase): + """Tests for the _transpose_layers_axes helper.""" + + def test_noop_when_same_axis(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=0) + np.testing.assert_array_equal(result, arr) + + def test_transposes_axis_0_to_1(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + self.assertEqual(result.shape, (2, 4, 3)) + + def test_transposes_axis_1_to_0(self): + arr = _make_array(2, 4, 3) + result = _transpose_layers_axes(arr, src_axis=1, dst_axis=0) + self.assertEqual(result.shape, (4, 2, 3)) + + def test_transposes_nested_dict(self): + arr = _make_array(4, 2, 3) + tree = {"decoder": {"layers": {"kernel": arr}}} + result = _transpose_layers_axes(tree, src_axis=0, dst_axis=1) + self.assertEqual(result["decoder"]["layers"]["kernel"].shape, (2, 4, 3)) + + def test_passes_through_1d_array(self): + arr = _make_array(5) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + # 1D array has no axis 1, should be 
returned unchanged + np.testing.assert_array_equal(result, arr) + + def test_handles_list(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes([arr], src_axis=0, dst_axis=1) + self.assertIsInstance(result, list) + self.assertEqual(result[0].shape, (2, 4, 3)) + + def test_handles_tuple(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes((arr,), src_axis=0, dst_axis=1) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0].shape, (2, 4, 3)) + + +class TestStackLayers(unittest.TestCase): + """Tests for the _stack_layers helper.""" + + def test_stacks_individual_layers(self): + arr0 = _make_array(3, 4) + arr1 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "layers_1": {"mlp": {"kernel": arr1}}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + stacked = result["layers"]["mlp"]["kernel"] + self.assertEqual(stacked.shape, (2, 3, 4)) + np.testing.assert_array_equal(stacked[0], arr0) + np.testing.assert_array_equal(stacked[1], arr1) + + def test_noop_when_no_layer_pattern(self): + arr = _make_array(3, 4) + decoder = {"layers": {"mlp": {"kernel": arr}}} + result, was_stacked = _stack_layers(decoder) + self.assertFalse(was_stacked) + self.assertIs(result, decoder) + + def test_preserves_non_layer_keys(self): + norm_weight = _make_array(4) + arr0 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "final_norm": {"scale": norm_weight}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("final_norm", result) + np.testing.assert_array_equal(result["final_norm"]["scale"], norm_weight) + + def test_stacks_three_layers(self): + arrays = [_make_array(2, 2) for _ in range(3)] + decoder = {f"layers_{i}": {"w": arrays[i]} for i in range(3)} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + stacked = result["layers"]["w"] + 
self.assertEqual(stacked.shape, (3, 2, 2)) + + def test_non_array_non_dict_leaf(self): + # Scalar leaf — stack_arrays returns first element + decoder = {"layers_0": {"count": 1}, "layers_1": {"count": 2}} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + + def test_with_missing_key_in_some_layers(self): + arr = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr, "bias": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, # no "bias" + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("kernel", result["layers"]["mlp"]) + + +class TestConvertLinenToNNX(unittest.TestCase): + """Tests for the convert_linen_to_nnx function.""" + + def _make_linen_state(self, add_opt_state=False): + """Creates a minimal Linen checkpoint structure.""" + arr = _make_array(2, 4, 3) + state = { + "step": 10, + "params": { + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + "decoder_norm": {"scale": _make_array(4)}, + } + } + }, + } + if add_opt_state: + state["opt_state"] = {"params": {"mu": {"decoder": {"layers": {"kernel": arr}}}}} + return state + + def test_converts_step_under_optimizer(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 10) + + def test_step_not_at_top_level(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + + def test_params_stored_under_model_key(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertIn("model", result) + self.assertNotIn("params", result) + + def test_removes_double_nesting(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # model should have 'decoder' directly, not 'params.decoder' + self.assertIn("decoder", result["model"]) + self.assertNotIn("params", result["model"]) + + def 
test_adds_value_wrappers(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # Arrays should be wrapped in {"value": array} + kernel = result["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_opt_state_under_optimizer(self): + state = self._make_linen_state(add_opt_state=True) + result = convert_linen_to_nnx(state) + self.assertIn("opt_state", result["optimizer"]) + # Linen opt_state had nested 'params' level; it should be removed + self.assertNotIn("params", result["optimizer"]["opt_state"]) + + def test_no_step_produces_no_optimizer_step(self): + arr = _make_array(2, 4, 3) + state = {"params": {"params": {"decoder": {"layers": {"kernel": arr}}}}} + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + self.assertIn("model", result) + + def test_no_double_nesting_still_converts(self): + # Linen state without double-nesting (unusual but handled) + arr = _make_array(2, 4) + state = {"params": {"decoder": {"layers": {"kernel": arr}}}} + result = convert_linen_to_nnx(state) + self.assertIn("decoder", result["model"]) + + def test_no_params_key_only_step(self): + state = {"step": 3} + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 3) + self.assertNotIn("model", result) + + def test_with_per_layer_params_stacked_and_transposed(self): + # Linen checkpoint with layers_0, layers_1 → stacked + transposed to axis 1 + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["decoder"]["layers"]["mlp"]["kernel"]["value"] + # Original (3, 4) stacked → (2, 3, 4), transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestConvertNNXToLinen(unittest.TestCase): + """Tests for the 
convert_nnx_to_linen function.""" + + def _make_nnx_state(self, add_opt_state=False): + """Creates an NNX checkpoint with 'model' and 'optimizer' keys. + + Uses 'attention' (not 'layers') as the sub-key so _convert_layers_to_linen_format + does not try to unstack the data. + """ + arr = _make_array(2, 4, 3) + state = { + "model": { + "decoder": { + "attention": {"wi": {"kernel": {"value": arr}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + "optimizer": {"step": 5}, + } + if add_opt_state: + state["optimizer"]["opt_state"] = { + "mu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "nu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + } + return state + + def test_converts_step(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + + def test_adds_double_nesting(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertIn("params", result["params"]) + self.assertIn("decoder", result["params"]["params"]) + + def test_strips_value_wrappers(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + kernel = result["params"]["params"]["decoder"]["attention"]["wi"]["kernel"] + self.assertIsInstance(kernel, np.ndarray) + + def test_converts_opt_state(self): + state = self._make_nnx_state(add_opt_state=True) + result = convert_nnx_to_linen(state) + self.assertIn("opt_state", result) + # mu/nu should get a 'params' level added + self.assertIn("params", result["opt_state"]["mu"]) + self.assertIn("params", result["opt_state"]["nu"]) + + def test_backward_compat_params_key(self): + # Old NNX format: "params" instead of "model", top-level "step" + arr = _make_array(2, 4, 3) + state = { + "step": 5, + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + } + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 
5) + self.assertIn("decoder", result["params"]["params"]) + + def test_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + self.assertIn("params", result) + + +class TestRoundTrip(unittest.TestCase): + """Verifies that linen->nnx->linen round-trip preserves data.""" + + def test_linen_to_nnx_to_linen(self): + # Use "attention" (not "layers") so _convert_layers_to_linen_format + # does not try to unstack the dict as a stacked-layers tensor. + arr = _make_array(2, 4, 3) + linen_state = { + "step": 42, + "params": { + "params": { + "decoder": { + "attention": {"mlp": {"wi": {"kernel": arr}}}, + "norm": {"scale": _make_array(4)}, + } + } + }, + } + nnx_state = convert_linen_to_nnx(linen_state) + recovered_state = convert_nnx_to_linen(nnx_state) + + self.assertEqual(recovered_state["step"], 42) + recovered_kernel = recovered_state["params"]["params"]["decoder"]["attention"]["mlp"]["wi"]["kernel"] + np.testing.assert_array_equal(recovered_kernel, arr) + + def test_nnx_to_linen_to_nnx(self): + arr = _make_array(2, 4, 3) + nnx_state = { + "model": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + } + }, + "optimizer": {"step": 7}, + } + linen_state = convert_nnx_to_linen(nnx_state) + recovered_state = convert_linen_to_nnx(linen_state) + + self.assertEqual(recovered_state["optimizer"]["step"], 7) + recovered_kernel = recovered_state["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIn("value", recovered_kernel) + np.testing.assert_array_equal(recovered_kernel["value"], arr) + + +class TestConvertOptState(unittest.TestCase): + """Tests for the _convert_opt_state_linen_to_nnx and _convert_opt_state_nnx_to_linen helpers.""" + + def test_linen_to_nnx_removes_params_level(self): + arr = _make_array(3, 4) + opt_state = {"mu": {"params": {"decoder": {"kernel": arr}}}} + result = 
_convert_opt_state_linen_to_nnx(opt_state) + # 'params' key removed; decoder promoted + self.assertNotIn("params", result["mu"]) + self.assertIn("decoder", result["mu"]) + # Arrays are plain (no value wrappers in NNX opt_state) + np.testing.assert_array_equal(result["mu"]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": {"kernel": arr}}, {"decoder": {"kernel": arr}}] + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, list) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_tuple_input(self): + arr = _make_array(2, 2) + opt_state = ({"decoder": {"kernel": arr}},) + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, tuple) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_non_array_non_dict(self): + # Scalars should be passed through unchanged + result = _convert_opt_state_linen_to_nnx(42) + self.assertEqual(result, 42) + + def test_linen_to_nnx_params_key_with_non_dict_value(self): + # When k == "params" but converted value is not a dict, store it as-is + opt_state = {"params": 99} + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIn("params", result) + self.assertEqual(result["params"], 99) + + def test_nnx_to_linen_adds_params_level_and_strips(self): + arr = _make_array(3, 4) + opt_state = { + "mu": {"decoder": {"kernel": {"value": arr}}}, + "nu": {"decoder": {"kernel": {"value": arr}}}, + } + result = _convert_opt_state_nnx_to_linen(opt_state) + # mu/nu should have 'params' nested inside + self.assertIn("params", result["mu"]) + self.assertIn("params", result["nu"]) + # Arrays unwrapped + kernel = result["mu"]["params"]["decoder"]["kernel"] + np.testing.assert_array_equal(kernel, arr) + + def test_nnx_to_linen_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": 
{"kernel": {"value": arr}}}]
+    result = _convert_opt_state_nnx_to_linen(opt_state)
+    self.assertIsInstance(result, list)
+    np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr)
+
+  def test_nnx_to_linen_handles_tuple_input(self):
+    arr = _make_array(2, 2)
+    opt_state = ({"decoder": {"kernel": {"value": arr}}},)
+    result = _convert_opt_state_nnx_to_linen(opt_state)
+    self.assertIsInstance(result, tuple)
+    np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr)
+
+  def test_nnx_to_linen_passes_through_scalars(self):
+    result = _convert_opt_state_nnx_to_linen("scalar_string")
+    self.assertEqual(result, "scalar_string")
+
+  def test_nnx_to_linen_value_wrapper_with_non_array_inner(self):
+    # {"value": scalar} should NOT be unwrapped (only arrays get unwrapped)
+    d = {"value": 42}
+    result = _convert_opt_state_nnx_to_linen(d)
+    self.assertIn("value", result)
+    self.assertEqual(result["value"], 42)
+
+
+class TestConvertLinenToNNXEncoder(unittest.TestCase):
+  """Tests encoder path in convert_linen_to_nnx."""
+
+  def test_converts_encoder_params(self):
+    arr = _make_array(2, 4, 3)
+    state = {
+        "params": {
+            "params": {
+                "encoder": {
+                    "layers": {"mlp": {"wi": {"kernel": arr}}},
+                }
+            }
+        }
+    }
+    result = convert_linen_to_nnx(state)
+    self.assertIn("encoder", result["model"])
+    kernel = result["model"]["encoder"]["layers"]["mlp"]["wi"]["kernel"]
+    self.assertIsInstance(kernel, dict)
+    self.assertIn("value", kernel)
+
+  def test_converts_encoder_with_per_layer_stacking(self):
+    arr = _make_array(3, 4)
+    state = {
+        "params": {
+            "params": {
+                "encoder": {
+                    "layers_0": {"mlp": {"kernel": arr}},
+                    "layers_1": {"mlp": {"kernel": arr}},
+                }
+            }
+        }
+    }
+    result = convert_linen_to_nnx(state)
+    stacked = result["model"]["encoder"]["layers"]["mlp"]["kernel"]["value"]
+    # Stacked at axis 0 → (2, 3, 4), then transposed to (3, 2, 4)
+    self.assertEqual(stacked.shape, (3, 2, 4))
+
+
+class TestAdditionalEdgeCases(unittest.TestCase):
+  """Covers remaining edge cases."""
+
+  def test_detect_format_params_has_params_but_no_decoder_encoder(self):
+    # params["params"] exists but inner has no decoder/encoder -> falls through
+    # no optimizer/opt_state -> should raise
+    state = {"params": {"params": {"some_other_key": {}}}}
+    with self.assertRaises(ValueError):
+      detect_format(state)
+
+  def test_detect_format_opt_state_returns_linen(self):
+    # Any state with "opt_state" (but no "model"/"optimizer") detects as linen
+    arr = _make_array(2)
+    state = {
+        "params": {"something": arr},
+        "opt_state": {"mu": {"decoder": {"kernel": arr}}},
+    }
+    self.assertEqual(detect_format(state), "linen")
+
+  def test_add_value_wrappers_value_key_with_non_array(self):
+    # {"value": "text"} is not a wrapper (inner is not an array), recurse normally
+    d = {"value": "not_an_array"}
+    result = _add_value_wrappers(d)
+    self.assertEqual(result, {"value": "not_an_array"})
+
+  def test_convert_nnx_to_linen_no_step(self):
+    arr = _make_array(2, 4)
+    state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}}
+    result = convert_nnx_to_linen(state)
+    self.assertNotIn("step", result)
+    self.assertIn("params", result)
+
+  def test_convert_nnx_to_linen_already_has_params_nesting(self):
+    arr = _make_array(2, 4)
+    state = {"params": {"params": {"decoder": {"layers": {"kernel": {"value": arr}}}}}}
+    result = convert_nnx_to_linen(state)
+    self.assertIn("params", result)
+
+  def test_convert_nnx_to_linen_no_params_key(self):
+    state = {"optimizer": {"step": 8}}
+    result = convert_nnx_to_linen(state)
+    self.assertEqual(result["step"], 8)
+    self.assertNotIn("params", result)
+
+
+class TestLoadCheckpoint(unittest.TestCase):
+  """Tests for load_checkpoint with mocked orbax/epath."""
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
+  def test_load_checkpoint_calls_checkpointer_and_returns_state(self, mock_epath, mock_ocp):
+    arr = _make_array(2, 2)
+    expected_state = {"params": arr, "step": 0}
+
+    mock_path = MagicMock()
+    mock_epath.Path.return_value = mock_path
+
+    mock_metadata = MagicMock()
+    mock_metadata.item_metadata.tree = {"params": arr}
+
+    mock_ckptr = MagicMock()
+    mock_ckptr.metadata.return_value = mock_metadata
+    mock_ckptr.restore.return_value = expected_state
+    mock_ocp.Checkpointer.return_value = mock_ckptr
+    mock_ocp.ArrayRestoreArgs.return_value = MagicMock()
+
+    result = load_checkpoint("/tmp/test_ckpt")
+
+    mock_epath.Path.assert_called_once_with("/tmp/test_ckpt")
+    mock_ocp.Checkpointer.assert_called_once()
+    mock_ckptr.metadata.assert_called_once_with(mock_path)
+    mock_ckptr.restore.assert_called_once()
+    self.assertEqual(result, expected_state)
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
+  def test_load_checkpoint_with_empty_tree_metadata(self, mock_epath, mock_ocp):
+    expected_state = {"step": 5}
+
+    mock_path = MagicMock()
+    mock_epath.Path.return_value = mock_path
+
+    mock_metadata = MagicMock()
+    mock_metadata.item_metadata.tree = {}
+
+    mock_ckptr = MagicMock()
+    mock_ckptr.metadata.return_value = mock_metadata
+    mock_ckptr.restore.return_value = expected_state
+    mock_ocp.Checkpointer.return_value = mock_ckptr
+
+    result = load_checkpoint("/tmp/empty_ckpt")
+
+    self.assertEqual(result["step"], 5)
+
+
+class TestSaveCheckpoint(unittest.TestCase):
+  """Tests for save_checkpoint with mocked orbax/epath."""
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
+  def test_save_checkpoint_creates_dir_and_saves(self, mock_epath, mock_ocp):
+    state = {"params": _make_array(2, 2), "step": 1}
+
+    mock_path = MagicMock()
+    mock_epath.Path.return_value = mock_path
+
+    mock_ckptr = MagicMock()
+    mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr
+
+    save_checkpoint(state, "/tmp/output")
+
+    mock_epath.Path.assert_called_once_with("/tmp/output")
+    mock_path.mkdir.assert_called_once_with(exist_ok=True, parents=True)
+    mock_ocp.PyTreeCheckpointer.assert_called_once()
+    mock_ckptr.save.assert_called_once_with(mock_path, state, force=True)
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
+  def test_save_checkpoint_passes_state_unchanged(self, mock_epath, mock_ocp):
+    state = {"step": 99, "params": {"decoder": {}}}
+
+    mock_path = MagicMock()
+    mock_epath.Path.return_value = mock_path
+    mock_ckptr = MagicMock()
+    mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr
+
+    save_checkpoint(state, "/tmp/out2")
+
+    call_args = mock_ckptr.save.call_args
+    self.assertIs(call_args[0][1], state)
+
+
+class TestMain(unittest.TestCase):
+  """Tests for the main() CLI entry point."""
+
+  def _run_main(self, argv):
+    with patch("sys.argv", ["prog"] + argv):
+      main()
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
+  def test_main_explicit_linen_to_nnx(self, mock_load, mock_save):
+    arr = _make_array(2, 4, 3)
+    mock_load.return_value = {
+        "step": 1,
+        "params": {"params": {"decoder": {"layers": {"kernel": arr}}}},
+    }
+    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx"])
+    mock_load.assert_called_once_with("/src")
+    mock_save.assert_called_once()
+    saved_state = mock_save.call_args[0][0]
+    # NNX format: decoder at top level of model
+    self.assertIn("decoder", saved_state["model"])
+    self.assertEqual(mock_save.call_args[0][1], "/dst")
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
+  def test_main_explicit_nnx_to_linen(self, mock_load, mock_save):
+    arr = _make_array(2, 4, 3)
+    mock_load.return_value = {
+        "model": {"decoder": {"layers": {"kernel": {"value": arr}}}},
+        "optimizer": {"step": 2},
+    }
+    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=nnx_to_linen"])
+    mock_load.assert_called_once_with("/src")
+    mock_save.assert_called_once()
+    saved_state = mock_save.call_args[0][0]
+    # Linen format: double nesting
+    self.assertIn("params", saved_state["params"])
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
+  def test_main_auto_detects_linen_converts_to_nnx(self, mock_load, mock_save):
+    arr = _make_array(2, 4, 3)
+    mock_load.return_value = {
+        "step": 3,
+        "params": {"params": {"decoder": {"layers": {"kernel": arr}}}},
+    }
+    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"])
+    mock_save.assert_called_once()
+    saved_state = mock_save.call_args[0][0]
+    # Auto-detected linen → NNX format: model key
+    self.assertIn("decoder", saved_state["model"])
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
+  def test_main_auto_detects_nnx_converts_to_linen(self, mock_load, mock_save):
+    arr = _make_array(2, 4, 3)
+    mock_load.return_value = {
+        "model": {"decoder": {"layers": {"kernel": {"value": arr}}}},
+        "optimizer": {"step": 4},
+    }
+    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"])
+    mock_save.assert_called_once()
+    saved_state = mock_save.call_args[0][0]
+    # Auto-detected nnx → Linen format
+    self.assertIn("params", saved_state["params"])
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
+  def test_main_default_direction_is_auto(self, mock_load, mock_save):
+    arr = _make_array(2, 4, 3)
+    mock_load.return_value = {
+        "params": {"params": {"decoder": {"layers": {"kernel": arr}}}},
+    }
+    # No --direction arg -> defaults to "auto"
+    self._run_main(["--source_path=/src", "--target_path=/dst"])
+    mock_save.assert_called_once()
+
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
+  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
+  def test_main_scan_layers_false(self, mock_load, mock_save):
+    arr = _make_array(3, 4)
+    mock_load.return_value = {
+        "params": {
+            "params": {
+                "decoder": {
+                    "layers_0": {"mlp": {"kernel": arr}},
+                    "layers_1": {"mlp": {"kernel": arr}},
+                }
+            }
+        }
+    }
+    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx", "--no-scan_layers"])
+    saved_state = mock_save.call_args[0][0]
+    # With scan_layers=False: integer-keyed layers/N
+    layers = saved_state["model"]["decoder"]["layers"]
+    self.assertIsInstance(layers, dict)
+    self.assertTrue(all(k.isdigit() for k in layers.keys()))
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/tests/unit/lora_utils_nnx_test.py b/tests/unit/lora_utils_nnx_test.py
new file mode 100644
index 0000000000..e0e8cbb529
--- /dev/null
+++ b/tests/unit/lora_utils_nnx_test.py
@@ -0,0 +1,293 @@
+# Copyright 2023-2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the NNX-shaped LoRA helpers in `lora_utils`, plus a small
+Linen regression block."""
+
+import unittest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from maxtext.utils.lora_utils import (
+    apply_lora_on_base_params,
+    apply_lora_on_base_params_nnx,
+    get_lora_abstract_state_nnx,
+    unapply_lora_from_base_params,
+    unapply_lora_from_base_params_nnx,
+)
+
+
+# ---------------------------------------------------------------------------
+# Fake abstract state builders (mirror the NNX vs. Linen tree shapes)
+# ---------------------------------------------------------------------------
+
+
+def _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32):
+  """Tiny NNX-shaped abstract state for one attention block."""
+
+  def _sds(shape):
+    return jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=None)
+
+  return {
+      "decoder": {
+          "layers": {
+              "self_attention": {
+                  "query": {"kernel": _sds((emb, num_heads, head_dim))},
+                  "key": {"kernel": _sds((emb, num_heads, head_dim))},
+                  "value": {"kernel": _sds((emb, num_heads, head_dim))},
+                  "out": {"kernel": _sds((emb, num_heads, head_dim))},
+              },
+              "mlp": {"wi": {"kernel": _sds((emb, 4 * emb))}},
+          },
+          "shared_embedding": {"embedding": _sds((100, emb))},
+      },
+  }
+
+
+def _make_linen_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32):
+  """Linen-shaped equivalent (with the `{"params": ...}` outer wrap)."""
+  return {"params": _make_nnx_attention_abstract(emb, num_heads, head_dim, dtype)}
+
+
+def _lora_config(rank=4, alpha=8.0, target_modules=("q_proj", "v_proj")):
+  return {
+      "r": rank,
+      "lora_alpha": alpha,
+      "target_modules": list(target_modules),
+  }
+
+
+# ---------------------------------------------------------------------------
+# get_lora_abstract_state_nnx
+# ---------------------------------------------------------------------------
+
+
+class TestGetLoraAbstractStateNnx(unittest.TestCase):
+  """`get_lora_abstract_state_nnx` shape, sharding, and error-path coverage."""
+
+  def test_lora_shapes_for_query_and_value(self):
+    abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4)
+    state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=4))
+    attn = state.params["decoder"]["layers"]["self_attention"]
+
+    a = attn["query"]["lora_a.kernel"]
+    b = attn["query"]["lora_b.kernel"]
+    self.assertEqual(a.shape, (8, 4))
+    self.assertEqual(b.shape, (4, 2, 4))
+    self.assertEqual(a.dtype, jnp.float32)
+    self.assertEqual(b.dtype, jnp.float32)
+
+    a = attn["value"]["lora_a.kernel"]
+    b = attn["value"]["lora_b.kernel"]
+    self.assertEqual(a.shape, (8, 4))
+    self.assertEqual(b.shape, (4, 2, 4))
+
+  def test_non_target_modules_emit_none_leaves(self):
+    abs_params = _make_nnx_attention_abstract()
+    state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(target_modules=("q_proj",)))
+    attn = state.params["decoder"]["layers"]["self_attention"]
+    self.assertIn("lora_a.kernel", attn["query"])
+    self.assertIsNone(attn["key"]["kernel"])
+    self.assertIsNone(attn["value"]["kernel"])
+    self.assertIsNone(attn["out"]["kernel"])
+    self.assertIsNone(state.params["decoder"]["layers"]["mlp"]["wi"]["kernel"])
+    self.assertIsNone(state.params["decoder"]["shared_embedding"]["embedding"])
+
+  def test_o_proj_has_distinct_shape(self):
+    abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4)
+    state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=3, target_modules=("o_proj",)))
+    out = state.params["decoder"]["layers"]["self_attention"]["out"]
+    a = out["lora_a.kernel"]
+    b = out["lora_b.kernel"]
+    # 3D base (emb, num_heads, head_dim) → lora_a.shape = (..., r), lora_b = (r, last)
+    self.assertEqual(a.shape, (8, 2, 3))
+    self.assertEqual(b.shape, (3, 4))
+
+  def test_unsupported_leaf_type_raises(self):
+    bad = {"decoder": {"layers": {"self_attention": {"query": {"kernel": jnp.zeros((4, 2, 2))}}}}}
+    with self.assertRaises(ValueError):
+      get_lora_abstract_state_nnx(bad, _lora_config())
+
+  def test_unexpected_leaf_name_raises(self):
+    bad = {"decoder": {"layers": {"self_attention": {"query": {"weight": jax.ShapeDtypeStruct((4, 2), jnp.float32)}}}}}
+    with self.assertRaises(ValueError):
+      get_lora_abstract_state_nnx(bad, _lora_config())
+
+  # Linen-vs-NNX numerical parity is covered by TestApplyLoraNnx.test_numerical_parity_with_linen_apply.
+
+
+# ---------------------------------------------------------------------------
+# apply / unapply on NNX-shape pure dicts
+# ---------------------------------------------------------------------------
+
+
+def _concrete_base(rng=None, emb=4, num_heads=2, head_dim=3):
+  """Concrete arrays mirroring the abstract structure used above (NNX-shape)."""
+  if rng is None:
+    rng = jax.random.key(0)
+  k1, k2, k3, k4, k5, k6 = jax.random.split(rng, 6)
+  shape_attn = (emb, num_heads, head_dim)
+  return {
+      "decoder": {
+          "layers": {
+              "self_attention": {
+                  "query": {"kernel": jax.random.normal(k1, shape_attn)},
+                  "key": {"kernel": jax.random.normal(k2, shape_attn)},
+                  "value": {"kernel": jax.random.normal(k3, shape_attn)},
+                  "out": {"kernel": jax.random.normal(k4, shape_attn)},
+              },
+              "mlp": {"wi": {"kernel": jax.random.normal(k5, (emb, 4 * emb))}},
+          },
+          "shared_embedding": {"embedding": jax.random.normal(k6, (100, emb))},
+      },
+  }
+
+
+def _build_lora_params(base, lora_config_dict, rng):
+  """Build a concrete LoRA tree (random arrays) matching `base`."""
+  abs_tree = jax.tree_util.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=None), base)
+  lora_state, _ = get_lora_abstract_state_nnx(abs_tree, lora_config_dict)
+
+  def _to_concrete(leaf, rng_key):
+    if leaf is None:
+      return None
+    return jax.random.normal(rng_key, leaf.shape, leaf.dtype)
+
+  leaves, tree = jax.tree_util.tree_flatten(lora_state.params, is_leaf=lambda x: x is None)
+  rngs = jax.random.split(rng, max(1, len(leaves)))
+  out_leaves = [_to_concrete(l, r) for l, r in zip(leaves, rngs)]
+  return jax.tree_util.tree_unflatten(tree, out_leaves)
+
+
+class TestApplyLoraNnx(unittest.TestCase):
+  """`apply_lora_on_base_params_nnx` round-trip and Linen-vs-NNX parity."""
+
+  def test_apply_then_unapply_is_identity(self):
+    rng = jax.random.key(42)
+    base_orig = _concrete_base(rng)
+    base = jax.tree_util.tree_map(jnp.copy, base_orig)
+    lora = _build_lora_params(base, _lora_config(rank=2, target_modules=("q_proj", "v_proj")), jax.random.key(7))
+    apply_lora_on_base_params_nnx(base, lora, lora_scale_factor=0.5)
+    # query/value kernels were modified
+    self.assertFalse(
+        jnp.allclose(
+            base["decoder"]["layers"]["self_attention"]["query"]["kernel"],
+            base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"],
+        )
+    )
+    # key/out are untouched
+    np.testing.assert_array_equal(
+        np.asarray(base["decoder"]["layers"]["self_attention"]["key"]["kernel"]),
+        np.asarray(base_orig["decoder"]["layers"]["self_attention"]["key"]["kernel"]),
+    )
+    np.testing.assert_array_equal(
+        np.asarray(base["decoder"]["layers"]["self_attention"]["out"]["kernel"]),
+        np.asarray(base_orig["decoder"]["layers"]["self_attention"]["out"]["kernel"]),
+    )
+    unapply_lora_from_base_params_nnx(base, lora, lora_scale_factor=0.5)
+    np.testing.assert_allclose(
+        np.asarray(base["decoder"]["layers"]["self_attention"]["query"]["kernel"]),
+        np.asarray(base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"]),
+        rtol=1e-5,
+        atol=1e-6,
+    )
+    np.testing.assert_allclose(
+        np.asarray(base["decoder"]["layers"]["self_attention"]["value"]["kernel"]),
+        np.asarray(base_orig["decoder"]["layers"]["self_attention"]["value"]["kernel"]),
+        rtol=1e-5,
+        atol=1e-6,
+    )
+
+  def test_numerical_parity_with_linen_apply(self):
+    """Same base+lora numbers → same kernel after apply, on either tree shape."""
+    rng = jax.random.key(123)
+    base_nnx = _concrete_base(rng)
+    base_linen = {"params": jax.tree_util.tree_map(jnp.copy, base_nnx)}
+    lora = _build_lora_params(base_nnx, _lora_config(rank=2, target_modules=("q_proj",)), jax.random.key(5))
+    apply_lora_on_base_params_nnx(base_nnx, lora, lora_scale_factor=0.7)
+    apply_lora_on_base_params(base_linen, {"params": lora}, lora_scale_factor=0.7)
+    np.testing.assert_allclose(
+        np.asarray(base_nnx["decoder"]["layers"]["self_attention"]["query"]["kernel"]),
+        np.asarray(base_linen["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]),
+        rtol=1e-6,
+    )
+
+  def test_apply_with_unexpected_lora_key_raises(self):
+    base = _concrete_base()
+    bad = {"decoder": {"layers": {"self_attention": {"query": {"unexpected": jnp.zeros((4, 2))}}}}}
+    with self.assertRaises(ValueError):
+      apply_lora_on_base_params_nnx(base, bad)
+
+
+class TestLinenLoraRegression(unittest.TestCase):
+  """Smoke tests for the Linen apply / unapply helpers (no other unit test exercises them)."""
+
+  def _linen_pair(self, rng=None):
+    """Build a Linen-shape (with `{"params": ...}` outer wrapper) base + lora pair."""
+    if rng is None:
+      rng = jax.random.key(99)
+    base_inner = _concrete_base(rng)
+    base = {"params": jax.tree_util.tree_map(jnp.copy, base_inner)}
+    lora_inner = _build_lora_params(
+        base_inner,
+        _lora_config(rank=2, target_modules=("q_proj", "v_proj")),
+        jax.random.key(7),
+    )
+    lora = {"params": lora_inner}
+    return base, lora
+
+  def test_linen_apply_then_unapply_is_identity(self):
+    base, lora = self._linen_pair()
+    base_orig = jax.tree_util.tree_map(jnp.copy, base)
+    apply_lora_on_base_params(base, lora, lora_scale_factor=0.5)
+    unapply_lora_from_base_params(base, lora, lora_scale_factor=0.5)
+    np.testing.assert_allclose(
+        np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]),
+        np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]),
+        rtol=1e-5,
+        atol=1e-6,
+    )
+    np.testing.assert_allclose(
+        np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]),
+        np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]),
+        rtol=1e-5,
+        atol=1e-6,
+    )
+
+  def test_linen_apply_only_modifies_target_modules(self):
+    base, lora = self._linen_pair()
+    base_orig = jax.tree_util.tree_map(jnp.copy, base)
+    apply_lora_on_base_params(base, lora, lora_scale_factor=1.0)
+    # query and value are targets — must change.
+    self.assertFalse(
+        jnp.allclose(
+            base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"],
+            base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"],
+        )
+    )
+    # key and out are non-target — must be untouched.
+    np.testing.assert_array_equal(
+        np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]),
+        np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]),
+    )
+    np.testing.assert_array_equal(
+        np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]),
+        np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]),
+    )
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/tests/unit/maxengine_test.py b/tests/unit/maxengine_test.py
index 944d34bfef..c9c880024e 100644
--- a/tests/unit/maxengine_test.py
+++ b/tests/unit/maxengine_test.py
@@ -23,6 +23,8 @@
 from jax.sharding import Mesh
 import numpy as np
 import pytest
+from flax import nnx
+from flax.linen import partitioning as nn_partitioning
 from maxtext.configs import pyconfig
 from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL
 from maxtext.layers import quantizations
@@ -30,7 +32,10 @@
 pytest.importorskip("jetstream", reason="jetstream not installed")
 from maxtext.inference.maxengine import maxengine
 from maxtext.models import models
+from maxtext.checkpoint_conversion import linen_nnx_converter
+from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
+from maxtext.utils import model_creation_utils
 from
tests.utils.test_helpers import get_test_config_path
 
 pytestmark = [pytest.mark.external_serving]
@@ -162,6 +167,214 @@
     self.assertEqual(result_token.data.ndim, 2)
     self.assertEqual(result_token.data.shape[1], 3)
 
+  def _init_nnx_pyconfig(self, **kwargs):
+    """init_pyconfig with NNX flags on."""
+    return self.init_pyconfig(pure_nnx=True, enable_nnx=True, pure_nnx_decoder=True, **kwargs)
+
+  def _build_nnx_params(self, cfg, mesh):
+    """Materialize an NNX Transformer and return its nnx.Param state."""
+    _create_model = model_creation_utils.get_nnx_create_model_fn(cfg, mesh=mesh, model_mode=MODEL_MODE_PREFILL)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      model = _create_model()
+    _, params_state, _ = nnx.split(model, nnx.Param, ...)
+    return params_state
+
+  def test_init_nnx(self):
+    """NNX engine init exposes graphdef + abstract Transformer."""
+    cfg = self._init_nnx_pyconfig()
+    engine = maxengine.MaxEngine(cfg, jax.devices())
+    self.assertIsNotNone(engine.graphdef)
+    self.assertIsNotNone(engine.model)
+    self.assertEqual(type(engine.model).__name__, "Transformer")
+
+  def test_basic_prefill_nnx(self):
+    """NNX prefill returns a Linen-shape result dict with finite values."""
+    cfg = self._init_nnx_pyconfig()
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    params_state = self._build_nnx_params(cfg, mesh)
+
+    input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0])
+    true_length = 4
+    engine = maxengine.MaxEngine(cfg, jax.devices())
+    params = engine.load_params(params=params_state)
+    prefill_result, first_token = engine.prefill(params=params, padded_tokens=input_tokens, true_length=true_length)
+
+    self.assertEqual(prefill_result["generated_tokens"], jnp.array([0]))
+    self.assertEqual(prefill_result["tokens"].size, 1)
+    self.assertTrue(jnp.array_equal(first_token.data.size, 3))
+    self.assertEqual(first_token.log_prob.shape, (1, 1))
+    self.assertIn("cache", prefill_result)
+    self.assertIsInstance(prefill_result["cache"], dict)
+    # Catch silent NaN/inf from a bad nnx.merge or cache round-trip.
+    self.assertTrue(jnp.all(jnp.isfinite(prefill_result["logits"])))
+    cache_leaves, _ = jax.tree.flatten(prefill_result["cache"])
+    for leaf in cache_leaves:
+      self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}")
+    # scan_layers=True (default in test config) ⇒ leading axis is num_decoder_layers.
+    for leaf in cache_leaves:
+      self.assertEqual(leaf.shape[0], cfg.num_decoder_layers, msg=f"layer-axis mismatch, got shape={leaf.shape}")
+
+  def test_basic_decode_nnx(self):
+    """NNX prefill → insert → 4 generate steps. Verifies next_pos advances and logits stay finite."""
+    cfg = self._init_nnx_pyconfig()
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    params_state = self._build_nnx_params(cfg, mesh)
+
+    input_tokens = jnp.array([1, 306, 5360, 304])
+    engine = maxengine.MaxEngine(cfg, jax.devices())
+    params = engine.load_params(params=params_state)
+    decode_state = engine.init_decode_state()
+    prefill_result, _ = engine.prefill(params=params, padded_tokens=input_tokens, true_length=4)
+    decode_state = engine.insert(prefill_result, decode_state, slot=0)
+
+    # 4 steps is enough to catch off-by-one cache pointer bugs.
+    initial_next_pos = int(decode_state["next_pos"][0, 0])
+    for step in range(4):
+      decode_state, result_token = engine.generate(params=params, decode_state=decode_state)
+      self.assertEqual(result_token.log_prob.ndim, 2)
+      self.assertEqual(result_token.log_prob.shape[1], 1)
+      self.assertEqual(result_token.data.ndim, 2)
+      self.assertEqual(result_token.data.shape[1], 3)
+      self.assertTrue(jnp.all(jnp.isfinite(decode_state["logits"])))
+      self.assertEqual(
+          int(decode_state["next_pos"][0, 0]),
+          initial_next_pos + step + 1,
+          msg=f"next_pos didn't advance at step {step}",
+      )
+
+  def test_quantize_passes_gate_for_nnx(self):
+    """pure_nnx + quantization (convert-on-load) reaches the actual machinery in train mode."""
+    # checkpoint_is_quantized defaults to False — full-precision on disk, AQT
+    # quantizes per-forward against the loaded kernel (train mode).
+    cfg = self._init_nnx_pyconfig(quantization="int8")
+    engine = maxengine.MaxEngine(cfg, jax.devices())
+    self.assertEqual(engine._nnx_quant_mode_str, "train")  # pylint: disable=protected-access
+    try:
+      engine.load_params(rng=self.rng)
+    except NotImplementedError as e:
+      self.fail(f"convert-on-load path should not raise NotImplementedError; got: {e}")
+    except Exception:  # pylint: disable=broad-except
+      pass  # any other failure (e.g. checkpoint not found) is fine for this test
+
+  def test_load_pre_quantized_nnx_passes_quant_gate(self):
+    """pure_nnx + quantization + checkpoint_is_quantized=True clears the load gate."""
+    cfg = self._init_nnx_pyconfig(quantization="int8", checkpoint_is_quantized=True)
+    engine = maxengine.MaxEngine(cfg, jax.devices())
+    self.assertEqual(engine._nnx_quant_mode_str, "serve")  # pylint: disable=protected-access
+    try:
+      engine.load_params(rng=self.rng)
+    except NotImplementedError as e:
+      self.fail(f"checkpoint_is_quantized=True path should not raise NotImplementedError; got: {e}")
+    except Exception:  # pylint: disable=broad-except
+      pass  # any other failure (e.g. checkpoint not found) is fine for this test
+
+  def test_quantized_prefill_nnx_train_mode(self):
+    """End-to-end: NNX prefill with quantization=int8 + checkpoint_is_quantized=False.
+
+    TRAIN-mode AQT layers quantize per-forward against the loaded full-precision
+    kernel; output must be finite and shape-valid. This is the real numerical
+    verification that the convert-on-load path produces a usable model.
+    """
+    cfg = self._init_nnx_pyconfig(quantization="int8")
+    self.assertFalse(cfg.checkpoint_is_quantized)
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    params_state = self._build_nnx_params(cfg, mesh)
+
+    engine = maxengine.MaxEngine(cfg, jax.devices())
+    self.assertEqual(engine._nnx_quant_mode_str, "train")  # pylint: disable=protected-access
+    params = engine.load_params(params=params_state)
+    input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0])
+    prefill_result, _ = engine.prefill(params=params, padded_tokens=input_tokens, true_length=4)
+    self.assertTrue(jnp.all(jnp.isfinite(prefill_result["logits"])))
+
+  def test_lora_load_single_adapter_reaches_loader_on_nnx(self):
+    """pure_nnx + LoRA: load_single_adapter dispatches to the NNX loader.
+
+    With a nonexistent path the loader raises FileNotFoundError (not
+    NotImplementedError, which would mean the dispatch never reached the loader).
+    """
+    cfg = self._init_nnx_pyconfig()
+    engine = maxengine.MaxEngine(cfg, jax.devices())
+    with self.assertRaises(FileNotFoundError):
+      engine.load_single_adapter("/nonexistent/adapter/path")
+
+  def _linen_params_to_nnx_state(self, linen_params, abstract_nnx_model):
+    """Convert Linen params → NNX nnx.Param state via linen_nnx_converter so both engines share weights."""
+    nnx_dict_wrapped = linen_nnx_converter.convert_linen_to_nnx({"params": linen_params}, scan_layers=True)["model"]
+    # pylint: disable=protected-access
+    nnx_pure = linen_nnx_converter._strip_value_wrappers(nnx_dict_wrapped)
+    _, params_state, _ = nnx.split(abstract_nnx_model, nnx.Param, ...)
+    nnx.replace_by_pure_dict(params_state, nnx_pure)
+    return params_state
+
+  def test_linen_nnx_parity_prefill(self):
+    """Same weights → same prefill output across Linen and NNX engines.
+
+    A failure here means the NNX forward pass diverges from Linen on identical
+    weights (cache plumbing, nnx.merge wiring, or Transformer.__call__).
+    """
+    cfg_linen = self.init_pyconfig()
+    devices_array = maxtext_utils.create_device_mesh(cfg_linen)
+    mesh = Mesh(devices_array, cfg_linen.mesh_axes)
+
+    # Linen: init params, run prefill.
+    quant = quantizations.configure_quantization(cfg_linen)
+    linen_model = models.transformer_as_linen(config=cfg_linen, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
+    ids, decoder_segment_ids, decoder_positions = self.get_data()
+    linen_vars = linen_model.init(
+        {"params": self.rng, "aqt": self.rng, "dropout": self.rng},
+        ids,
+        decoder_positions,
+        decoder_segment_ids,
+        enable_dropout=False,
+    )
+    # Linen.init wraps leaves in LogicallyPartitioned (which has a `.value`
+    # attribute); unbox so the converter's {value:} wrapper detector doesn't
+    # mistake them for already-wrapped NNX leaves.
+    linen_vars = max_utils.unbox_logicallypartioned(linen_vars)
+
+    input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0])
+    true_length = 4
+    linen_engine = maxengine.MaxEngine(cfg_linen, jax.devices())
+    linen_params = linen_engine.load_params(params=linen_vars)
+    linen_prefill, linen_first_token = linen_engine.prefill(
+        params=linen_params, padded_tokens=input_tokens, true_length=true_length
+    )
+
+    # NNX: bridge Linen weights, run prefill on the same prompt.
+    cfg_nnx = self._init_nnx_pyconfig()
+    nnx_engine = maxengine.MaxEngine(cfg_nnx, jax.devices())
+    nnx_params_state = self._linen_params_to_nnx_state(linen_vars["params"], nnx_engine.model)
+    nnx_params = nnx_engine.load_params(params=nnx_params_state)
+    nnx_prefill, nnx_first_token = nnx_engine.prefill(
+        params=nnx_params, padded_tokens=input_tokens, true_length=true_length
+    )
+
+    # Tolerance is loose because the test config uses bf16 compute, where
+    # accumulation order between Linen-scan and NNX-scan drifts by ~0.05.
+    # Greedy match below is the behavioral check that actually matters.
+    linen_logits = np.asarray(linen_prefill["logits"])
+    nnx_logits = np.asarray(nnx_prefill["logits"])
+    self.assertEqual(linen_logits.shape, nnx_logits.shape)
+    np.testing.assert_allclose(
+        linen_logits,
+        nnx_logits,
+        rtol=0.05,
+        atol=0.1,
+        err_msg="Linen vs NNX prefill logits diverge beyond bf16 tolerance.",
+    )
+    self.assertEqual(
+        int(linen_first_token.data[0, 0]),
+        int(nnx_first_token.data[0, 0]),
+        msg="Linen and NNX disagreed on greedy first token with identical weights.",
+    )
+    linen_cache_leaves, _ = jax.tree.flatten(linen_prefill["cache"])
+    nnx_cache_leaves, _ = jax.tree.flatten(nnx_prefill["cache"])
+    self.assertEqual(len(linen_cache_leaves), len(nnx_cache_leaves))
+
   @pytest.mark.skip(reason="Can only pass on CPU.")
   def test_chunked_prefill(self):
     """Test identical result between chunked prefill with single and multiple chunked.
diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py
index e24e662543..ca5dbe7a87 100644
--- a/tests/unit/maxtext_utils_test.py
+++ b/tests/unit/maxtext_utils_test.py
@@ -15,11 +15,13 @@
 """Tests for the common MaxText utilities"""
 
 import functools
-from typing import Any, Sequence
 from collections.abc import Callable
+from typing import Any, Sequence
 import unittest
 from unittest.mock import MagicMock, Mock, patch
 from dataclasses import dataclass, field
+import numpy as np
+import optax
 from flax import linen as nn
 from flax import nnx
@@ -29,6 +31,7 @@
 from jax import random, vmap
 import jax.numpy as jnp
 from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec
+from jax.experimental import mesh_utils
 from maxtext.configs import pyconfig
 from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode
 from maxtext.inference import inference_utils
@@ -39,8 +42,7 @@
 from maxtext.utils import sharding
 from maxtext.utils.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations
 from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides
-import numpy as np
-import optax
+from maxtext.utils import maxtext_utils_nnx
 
 Transformer = models.transformer_as_linen
@@ -179,11 +181,7 @@
         "decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}},
     }
     self.state = train_state.TrainState(
-        step=0,
-        apply_fn=self.model.apply,
-        params=self.initial_params,
-        tx=None,
-        opt_state={},
+        step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={}
     )
 
   def test_update_mode_add(self):
@@ -196,10 +194,10 @@
     self.assertTrue(jnp.allclose(actual, expected))
 
     # Other values are untouched
-    original_layer_0 = self.state.params["layers"]["layer_0"]["bias"]
+    original_layer_0 = self.state.params["layers"]["layer_0"]["bias"]  # pylint: disable=unsubscriptable-object
     new_layer_0 = new_state.params["layers"]["layer_0"]["bias"]
     self.assertTrue(jnp.array_equal(original_layer_0, new_layer_0))
-    original_layer_1 = self.state.params["layers"]["layer_1"]["bias"]
+    original_layer_1 = self.state.params["layers"]["layer_1"]["bias"]  # pylint: disable=unsubscriptable-object
     new_layer_1 = new_state.params["layers"]["layer_1"]["bias"]
     self.assertTrue(jnp.array_equal(original_layer_1, new_layer_1))
@@ -264,7 +262,7 @@
 
 @nnx.register_variable_name("special_variables")
-class SpecialVariables(nnx.Variable):
+class SpecialVariables(nnx.Variable):  # pylint: disable=abstract-method
   pass
@@ -281,7 +279,7 @@
     return x
 
 
-class TrainState(train_state.TrainState):
+class TrainState(train_state.TrainState):  # pylint: disable=abstract-method
   other_variables: nnx.State
@@ -993,49 +991,63 @@
     return train_step
 
+  def _make_mock_config(self, pure_nnx=False):
+    cfg = MagicMock()
+    cfg.pure_nnx = pure_nnx
+    return cfg
+
   def test_returns_five_tuple(self):
     step = self._make_mock_step()
     result = maxtext_utils.get_functional_train_with_signature(
-        step, "data_sharding", "state_shardings", "model", "config"
+        step, "data_sharding", "state_shardings", "model", self._make_mock_config()
     )
     self.assertEqual(len(result), 5)
 
   def test_functional_train_has_correct_name(self):
     step = self._make_mock_step()
     fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(
-        step, "data_sharding", "state_shardings", "model", "config"
+        step, "data_sharding", "state_shardings", "model", self._make_mock_config()
     )
     self.assertEqual(fn.__name__, "train_step")
 
-  def test_in_shardings_structure(self):
+  def test_linen_in_shardings_includes_rng(self):
+    """pure_nnx=False: in_shardings should be (state, batch, rng)."""
     step = self._make_mock_step()
     _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature(
-        step, "data_sharding", "state_shardings", "model", "config"
+        step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=False)
     )
-    # (state, batch, rng)
     self.assertEqual(len(in_shardings), 3)
     self.assertIsNone(in_shardings[2])  # rng sharding is None
 
+  def test_nnx_in_shardings_excludes_rng(self):
+    """pure_nnx=True: in_shardings should be (state, batch) — no rng slot."""
+    step = self._make_mock_step()
+    _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature(
+        step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=True)
+    )
+    self.assertEqual(len(in_shardings), 2)
+
   def test_donate_argnums_is_zero(self):
     step = self._make_mock_step()
     _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature(
-        step, "data_sharding", "state_shardings", "model", "config"
+        step, "data_sharding", "state_shardings", "model", self._make_mock_config()
     )
     self.assertEqual(donate_argnums, 0)
 
   def test_functional_train_is_partial(self):
     """functional_train should partially apply model and config."""
     received = {}
+    cfg = self._make_mock_config()
 
     def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None):
       received["model"] = model
       received["config"] = config
       return state, {}
 
-    fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", "my_config")
+    fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", cfg)
     fn("state", "batch")
     self.assertEqual(received["model"], "my_model")
-    self.assertEqual(received["config"], "my_config")
+    self.assertEqual(received["config"], cfg)
 
 
 class TestGetFunctionalEvalWithSignature(unittest.TestCase):
@@ -1047,26 +1059,51 @@
     return eval_step
 
+  def _make_mock_config(self, pure_nnx=False):
+    cfg = MagicMock()
+    cfg.pure_nnx = pure_nnx
+    return cfg
+
   def test_returns_five_tuple(self):
     step = self._make_mock_eval_step()
-    result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config")
+    result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config())
     self.assertEqual(len(result), 5)
 
   def test_functional_eval_has_correct_name(self):
     step = self._make_mock_eval_step()
-    fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config")
+    fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config())
     self.assertEqual(fn.__name__, "eval_step")
 
   def test_out_shardings_is_none(self):
     step = self._make_mock_eval_step()
-    _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config")
+    _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(
+        step, "ds", "ss", "model", self._make_mock_config()
+    )
     self.assertIsNone(out_shardings)
 
   def test_donate_argnums_is_empty(self):
     step = self._make_mock_eval_step()
-    _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config")
+    _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(
+        step, "ds", "ss", "model", self._make_mock_config()
+    )
     self.assertEqual(donate_argnums, ())
 
+  def test_nnx_in_shardings_excludes_rng(self):
+    """pure_nnx=True: in_shardings should be (state, batch) — no rng slot."""
+    step = self._make_mock_eval_step()
+    _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature(
+        step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=True)
+    )
+    self.assertEqual(len(in_shardings), 2)
+
+  def test_linen_in_shardings_includes_rng(self):
+    """pure_nnx=False: in_shardings should be (state, batch, rng)."""
+    step = self._make_mock_eval_step()
+    _, in_shardings, _, _, _ = 
maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=False) + ) + self.assertEqual(len(in_shardings), 3) + class TestGetShapedBatch(unittest.TestCase): """Tests for get_shaped_batch.""" @@ -1420,5 +1457,183 @@ def test_runs_without_logical_annotations(self): maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) +class TestNNXAbstractState(unittest.TestCase): + """Test the get_abstract_state_nnx func.""" + + @dataclass + class MockConfig: + init_weights_seed: int = 42 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + param_scan_axis: int = 0 + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + + class MockTrainState(nnx.Module): + """Simulates a TrainState with params and optimizer state.""" + + def __init__(self, rngs: nnx.Rngs): + # Model parameters + device_num = len(jax.local_devices()) + self.params = nnx.Linear( + 2, 4, kernel_init=nnx.with_partitioning(nnx.initializers.ones, sharding=("model",)), rngs=rngs + ) + # Simulated optimizer state + self.optimizer = nnx.Variable(jnp.zeros((device_num,)), sharding=("model",)) + + def setUp(self): + # Create a real 1D mesh on local devices + devices = jax.local_devices() + self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) + self.config = self.MockConfig() + + def nnx_init_trainstate_wrapper(self): + """Wrapper to initialize the mock NNX model.""" + rngs = maxtext_utils_nnx.create_nnx_rngs(self.config) + return self.MockTrainState(rngs) + + def test_basic_abstraction(self): + """Verifies the basic return structure and partition spec extraction.""" + abstract_state, annotations, shardings = maxtext_utils.get_abstract_state_nnx( + self.config, self.mesh, self.nnx_init_trainstate_wrapper + ) + + # Check return types + 
self.assertIsInstance(abstract_state, nnx.State) + self.assertIsInstance(annotations, nnx.State) + self.assertIsInstance(shardings, nnx.State) + + # Verify PartitionSpec was extracted correctly from the mock model's annotations + # Path: params -> kernel -> spec + self.assertEqual( + annotations.params.kernel.get_value(), + PartitionSpec( + "model", + ), + ) + + def test_shard_optimizer_over_data(self): + """Verifies that 'data' is added to optimizer sharding using the real utility.""" + self.config.shard_optimizer_over_data = True + + _, annotations, _ = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Original Pspec for optimizer was PartitionSpec(None). + # add_data_to_sharding should find that dim 0 is compatible with mesh 'data' + # and update it to PartitionSpec(('data',)). + opt_spec = annotations.optimizer.get_value() + + # Verify 'data' is now in the spec + self.assertEqual(opt_spec, PartitionSpec(("data", "model"))) + + def test_optimizer_host_offload(self): + """Verifies that optimizer memory is moved to host when configured.""" + self.config.optimizer_memory_host_offload = True + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Optimizer state should be pinned to host + opt_sharding = shardings.optimizer.get_value() + self.assertEqual(opt_sharding.memory_kind, "pinned_host") + + # Params should still be on default memory (usually device) + param_sharding = shardings.params.kernel.get_value() + self.assertNotEqual(param_sharding.memory_kind, "pinned_host") + + def test_parameter_host_offload(self): + """Verifies that parameter memory is moved to host when configured.""" + self.config.parameter_memory_host_offload = True + self.config.param_scan_axis = 0 + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Parameters should be pinned to host + param_sharding = 
shardings.params.kernel.get_value() + self.assertEqual(param_sharding.memory_kind, "pinned_host") + + def test_invalid_init_fn(self): + """Ensures function raises error if no init function is provided.""" + with self.assertRaises(AssertionError): + maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) + + +class TestGetNnxNamedShardingWithScanAxis(unittest.TestCase): + """Unit tests for get_nnx_named_sharding_with_scan_axis covering every branch. + + The helper resolves a NamedSharding for each NNX Variable and — unlike + flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at + `param_scan_axis` when scanned-layers metadata is present. + """ + + def setUp(self): + # Mesh needs to contain every axis name the tests reference in partition specs. + self.mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("fsdp", "layers")) + + def _build_state(self, **variables): + """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" + return nnx.State(variables) + + def _run(self, state): + return maxtext_utils.get_nnx_named_sharding_with_scan_axis(state, self.mesh) + + def test_scan_axis_inserted_at_param_scan_axis(self): + """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3, 4, 8)), + out_sharding=(None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # 'layers' must be inserted at position 1 (param_scan_axis=1). 
+ self.assertEqual(result_sharding.spec, PartitionSpec(None, "layers", "fsdp")) + + def test_scan_axis_not_inserted_when_already_present(self): + """Guard against double-insertion when partition_name is already in out_sharding.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((2, 2, 2)), + out_sharding=("layers", None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'layers' must appear exactly once — the same PartitionSpec we started with. + self.assertEqual(result_sharding.spec, PartitionSpec("layers", None, "fsdp")) + + def test_masked_node_preserved_as_is(self): + """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" + masked = nnx.Variable(optax.MaskedNode()) + state = self._build_state(masked=masked) + out = self._run(state) + # The leaf must be the original Variable, not a NamedSharding wrapper. + self.assertIs(out["masked"], masked) + + def test_empty_out_sharding_yields_empty_pspec(self): + """A Variable without any sharding metadata should resolve to PartitionSpec().""" + with jax.set_mesh(self.mesh): + # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() + v = nnx.Param(jnp.zeros((4,))) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + self.assertEqual(result_sharding.spec, PartitionSpec()) + + def test_string_out_sharding_is_wrapped_into_tuple(self): + """A single-string out_sharding value should still produce a valid PartitionSpec.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding="fsdp", + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # The single string 'fsdp' is turned into a list, and 'layers' is prepended. 
+ self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp")) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/muon_utils_test.py b/tests/unit/muon_utils_test.py new file mode 100644 index 0000000000..9570257eee --- /dev/null +++ b/tests/unit/muon_utils_test.py @@ -0,0 +1,224 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for muon_utils.py.""" + +# pylint: disable=protected-access + +import io +import contextlib +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax import nnx +from optax.contrib._muon import MuonDimensionNumbers as mdn + +from maxtext.utils import muon_utils + + +class TestIsPathContainAny(unittest.TestCase): + """Tests for _is_path_contain_any helper.""" + + def test_returns_true_when_any_element_in_path(self): + self.assertTrue(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "bias"))) + + def test_returns_false_when_no_element_in_path(self): + self.assertFalse(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "kernel"))) + + def test_empty_tuples_returns_false(self): + self.assertFalse(muon_utils._is_path_contain_any((), ("decoder", "kernel"))) + + +class TestTransformLogic(unittest.TestCase): + """Tests for transform_logic: covers every branch of the mapping.""" + + # --- 1. 
Exclusions --- + def test_scale_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "norm", "scale"))) + + def test_bias_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "dense", "bias"))) + + def test_embedding_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("token_embedder", "embedding"))) + + def test_logits_dense_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "logits_dense", "kernel"))) + + # --- 2.1 MoE --- + def test_moe_wi_0_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_0")), mdn((-2,), (-1,))) + + def test_moe_wi_1_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_1")), mdn((-2,), (-1,))) + + def test_moe_wo_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wo")), mdn((-2,), (-1,))) + + def test_moe_gate_falls_through_to_standard(self): + # 'gate' is inside MoeBlock_0 but not one of (wi_0, wi_1, wo) → standard. 
+ self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "gate", "kernel")), mdn((0,), (-1,))) + + # --- 2.2 Self-attention --- + def test_self_attention_out_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "out")), mdn((0, -2), (-1,))) + + def test_self_attention_query_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "query")), mdn((0,), (-2, -1))) + + def test_self_attention_key_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "key")), mdn((0,), (-2, -1))) + + def test_self_attention_value_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "value")), mdn((0,), (-2, -1))) + + def test_self_attention_wq_b_and_wkv_b(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_b")), mdn((0,), (-2, -1))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_b")), mdn((0,), (-2, -1))) + + def test_self_attention_mla_wq_a_is_excluded_from_special(self): + # wq_a / wkv_a are MLA down-projections; they fall through the self_attention branch + # without matching anything, so the function returns the default standard mdn((0,), (-1,)). + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_a")), mdn((0,), (-1,))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_a")), mdn((0,), (-1,))) + + # --- 3. 
Standard --- + def test_standard_weight(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "mlp", "kernel")), mdn((0,), (-1,))) + + +class TestGetTransformTree(unittest.TestCase): + """Tests for get_transform_tree: recursive dict walk that applies transform_logic.""" + + def test_nested_dict_is_walked(self): + tree = {"decoder": {"self_attention": {"out": 0}, "mlp": {"kernel": 0}}} + result = muon_utils.get_transform_tree(tree) + self.assertEqual(result["decoder"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertEqual(result["decoder"]["mlp"]["kernel"], mdn((0,), (-1,))) + + def test_excluded_leaves_become_none(self): + tree = {"decoder": {"norm": {"scale": 0}}} + self.assertIsNone(muon_utils.get_transform_tree(tree)["decoder"]["norm"]["scale"]) + + def test_non_dict_leaf_at_root_returns_transform(self): + # If the tree itself is a leaf, path=() and transform_logic returns the standard mdn. + self.assertEqual(muon_utils.get_transform_tree(0), mdn((0,), (-1,))) + + +class _MoeLikeNNXModel(nnx.Module): + """Small NNX model whose param paths exercise the NNX branch of get_muon_weight_dimension_numbers.""" + + def __init__(self, rngs): + # Names are chosen so transform_logic matches each of the three meaningful branches: + # - w_standard: default mdn + # - self_attention_out: attention-out mdn + # - scale: excluded (None) + self.w_standard = nnx.Param(jnp.ones((4, 8))) + self.self_attention_out = nnx.Param(jnp.ones((4, 8))) + self.scale = nnx.Param(jnp.ones((8,))) + + +class TestGetMuonWeightDimensionNumbersNNX(unittest.TestCase): + """Covers the NNX branch of get_muon_weight_dimension_numbers (isinstance(model, nnx.Module)).""" + + def setUp(self): + self.model = _MoeLikeNNXModel(rngs=nnx.Rngs(0)) + + def test_nnx_model_dispatches_to_tree_map_with_path(self): + """NNX branch should produce an nnx.State tree with transform_logic applied per leaf.""" + result = muon_utils.get_muon_weight_dimension_numbers(self.model, config=None) + + # Result is 
an nnx.State whose top-level keys mirror the model attributes. + self.assertIn("w_standard", result) + self.assertIn("self_attention_out", result) + self.assertIn("scale", result) + + # NNX Variables are walked by jax.tree_util.tree_map_with_path, so the returned + # tree replaces each Variable's value with transform_logic(path_strings). + # 'scale' matches the exclusion branch → value is None. + self.assertIsNone(result["scale"].get_value()) + # 'w_standard' does not trigger any special rule → standard mdn. + self.assertEqual(result["w_standard"].get_value(), mdn((0,), (-1,))) + + def test_nnx_verbose_path_executes_print_debug(self): + """verbose=True should also execute _print_structure_debug without raising.""" + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils.get_muon_weight_dimension_numbers(self.model, config=None, verbose=True) + self.assertIn("Model Structure", buf.getvalue()) + self.assertIn("Muon Dimension Numbers", buf.getvalue()) + + +class TestGetMuonWeightDimensionNumbersLinen(unittest.TestCase): + """Covers the Linen branch of get_muon_weight_dimension_numbers.""" + + def test_linen_branch_uses_get_abstract_param(self): + """Linen models dispatch to maxtext_utils.get_abstract_param + get_transform_tree.""" + # Build a Linen nn.Module so isinstance(model, nnx.Module) is False. + + class LinenStub(nn.Module): + + @nn.compact + def __call__(self, x): + return x + + model = LinenStub() + + # Mock the heavy get_abstract_param call with a pre-shaped dict that exercises + # both a standard weight path and an excluded path. 
+ fake_abstract_param = { + "params": { + "self_attention": {"out": object()}, + "norm": {"scale": object()}, + }, + } + + with mock.patch.object(muon_utils.maxtext_utils, "get_abstract_param", return_value=fake_abstract_param): + result = muon_utils.get_muon_weight_dimension_numbers(model, config=mock.MagicMock()) + + self.assertEqual(result["params"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertIsNone(result["params"]["norm"]["scale"]) + + +class TestPrintStructureDebug(unittest.TestCase): + """Covers both branches of get_leaf_info inside _print_structure_debug.""" + + def test_handles_logically_partitioned_leaf(self): + """Linen leaves are nn.LogicallyPartitioned; the helper should return {shape, names}.""" + leaf = nn.LogicallyPartitioned(value=jax.ShapeDtypeStruct((4, 8), jnp.float32), names=("embed", "mlp")) + tree = {"params": {"kernel": leaf}} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"params": {"kernel": mdn((0,), (-1,))}}) + out = buf.getvalue() + self.assertIn("(4, 8)", out) + self.assertIn("embed", out) + + def test_handles_shape_dtype_struct_leaf(self): + """NNX abstract leaves are ShapeDtypeStruct directly; the helper should return {shape}.""" + tree = {"kernel": jax.ShapeDtypeStruct((16, 32), jnp.float32)} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"kernel": mdn((0,), (-1,))}) + out = buf.getvalue() + self.assertIn("(16, 32)", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nnx_decoders_test.py b/tests/unit/nnx_decoders_test.py index acff8afe23..2525a181f1 100644 --- a/tests/unit/nnx_decoders_test.py +++ b/tests/unit/nnx_decoders_test.py @@ -31,7 +31,13 @@ from flax import nnx from jax.sharding import Mesh -from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, 
DecoderBlockType +from maxtext.common.common_types import ( + DECODING_ACTIVE_SEQUENCE_INDICATOR, + MODEL_MODE_PREFILL, + MODEL_MODE_TRAIN, + DecoderBlockType, + MultimodalInput, +) from maxtext.configs import pyconfig from maxtext.layers import linears from maxtext.layers.attentions import Attention @@ -573,6 +579,71 @@ def test_logits_are_finite(self): ) self.assertTrue(jnp.all(jnp.isfinite(logits))) + def test_multimodal_input_forwarded_to_apply_embedding(self): + """`multimodal_input` must reach `_apply_embedding` as the original struct. + + `NNXDecoder.__call__` takes a `MultimodalInput` struct and hands it to + `_apply_embedding`, which is the layer that actually unpacks the fields + and merges the embeddings. This test stubs `_apply_embedding` to capture + the forwarded struct without running the real embedding path (the test + config has `use_multimodal=False`). + """ + ids, segment_ids, positions = self._make_token_inputs() + + # Distinct sentinels so each field can be traced independently. 
+ sentinel_img_emb = jnp.full((1, 1), 11.0) + sentinel_img_mask = jnp.full((1, 1), 22.0) + sentinel_aud_emb = jnp.full((1, 1), 33.0) + sentinel_aud_mask = jnp.full((1, 1), 44.0) + sentinel_bidir = jnp.full((1, 1), 55.0) + + mm_input = MultimodalInput( + image_embeddings=sentinel_img_emb, + image_masks=sentinel_img_mask, + audio_embeddings=sentinel_aud_emb, + audio_masks=sentinel_aud_mask, + bidirectional_mask=sentinel_bidir, + ) + + captured = {} + + def fake_apply_embedding( + _shared_embedding, + _ids, + _positions, + _deterministic, + _model_mode, + multimodal_input=None, + ): + captured["multimodal_input"] = multimodal_input + batch = self.cfg.global_batch_size_to_train_on + seq_len = self.cfg.max_target_length + emb_dim = self.cfg.emb_dim + return jnp.zeros((batch, seq_len, emb_dim), dtype=self.cfg.dtype) + + self.decoder._apply_embedding = fake_apply_embedding # pylint: disable=protected-access + try: + self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + multimodal_input=mm_input, + ) + finally: + # NNX modules bind attributes statefully; remove the override to avoid leaking. 
+ del self.decoder._apply_embedding # pylint: disable=protected-access + + forwarded = captured["multimodal_input"] + self.assertIsNotNone(forwarded) + self.assertTrue(jnp.array_equal(forwarded.image_embeddings, sentinel_img_emb)) + self.assertTrue(jnp.array_equal(forwarded.image_masks, sentinel_img_mask)) + self.assertTrue(jnp.array_equal(forwarded.audio_embeddings, sentinel_aud_emb)) + self.assertTrue(jnp.array_equal(forwarded.audio_masks, sentinel_aud_mask)) + self.assertTrue(jnp.array_equal(forwarded.bidirectional_mask, sentinel_bidir)) + def test_different_random_seeds_produce_different_logits(self): """Two randomly-initialised decoders should not produce identical logits.""" cfg = self.cfg diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index 44623f24f3..5194719ce2 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -15,19 +15,19 @@ """ Unit tests for all optimizers. """ import re import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import jax import optax import jax.numpy as jnp import pytest from absl.testing import parameterized +from flax import nnx from optax.contrib import MuonDimensionNumbers as mdn from maxtext.configs import pyconfig from maxtext.optimizers import optimizers -from maxtext.utils import maxtext_utils -from maxtext.utils.muon_utils import get_model_mdn +from maxtext.utils import maxtext_utils, muon_utils from tests.utils.test_helpers import get_test_config_path from typing import NamedTuple @@ -49,6 +49,7 @@ DEEPSEEK2_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -57,6 +58,7 @@ }, **_DEEPSEEK2_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -73,8 +75,6 @@ }, **_DEEPSEEK2_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": 
{"embedding": None}, } @@ -99,6 +99,7 @@ DEEPSEEK3_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -107,6 +108,7 @@ }, **_DEEPSEEK3_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -123,8 +125,6 @@ }, **_DEEPSEEK3_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -243,7 +243,7 @@ def test_model_integration(self, model_name, expected_output): Initializes the specified MaxText model and asserts that the generated Muon dimension numbers match the hardcoded reference. """ - actual_output = get_model_mdn(model_name, scan_layers=True) + actual_output = muon_utils.get_model_mdn(model_name, scan_layers=True, pure_nnx=False) self.assertEqual(actual_output, expected_output) @@ -483,5 +483,105 @@ def test_no_skip_without_kwargs(self): self.assertEqual(opt_state["count"], 0) +class TestMuonLogic(unittest.TestCase): + """Tests the granular path transformation functions.""" + + def test_is_path_contain_any(self): + # pylint: disable=protected-access + self.assertTrue(muon_utils._is_path_contain_any(("a", "b"), ("x", "a", "z"))) + self.assertFalse(muon_utils._is_path_contain_any(("a", "b"), ("x", "y", "z"))) + + def test_transform_logic_exclusions(self): + self.assertIsNone(muon_utils.transform_logic(("layer_0", "bias"))) + self.assertIsNone(muon_utils.transform_logic(("layer_0", "scale"))) + self.assertIsNone(muon_utils.transform_logic(("embedding", "kernel"))) + + def test_transform_logic_moe(self): + path = ("layers_0", "MoeBlock_0", "wi_0") + result = muon_utils.transform_logic(path) + self.assertEqual(result.reduction_axis, (-2,)) + self.assertEqual(result.output_axis, (-1,)) + + def test_transform_logic_attention(self): + path_out = ("layers_0", "self_attention", "out", "kernel") + self.assertEqual(muon_utils.transform_logic(path_out), 
mdn((0, -2), (-1,))) + + path_q = ("layers_0", "self_attention", "query", "kernel") + self.assertEqual(muon_utils.transform_logic(path_q), mdn((0,), (-2, -1))) + + def test_get_transform_tree(self): + fake_tree = {"params": {"layer_0": {"kernel": "leaf", "bias": "leaf"}, "MoeBlock_0": {"wi_0": "leaf"}}} + result = muon_utils.get_transform_tree(fake_tree) + self.assertEqual(result["params"]["layer_0"]["kernel"], mdn((0,), (-1,))) + self.assertIsNone(result["params"]["layer_0"]["bias"]) + + def test_get_muon_weight_dimension_numbers_nnx(self): + """Verifies dimension extraction for stateful NNX modules.""" + + class MockNNXModel(nnx.Module): + """Mock NNX Module.""" + + def __init__(self, rngs: nnx.Rngs): + # 1. Standard layer + self.layer1 = nnx.Linear(2, 4, rngs=rngs) + + # 2. MoE specific naming to trigger transform logic. + # The logic expects "MoeBlock_0" AND "wi_0"/"wi_1"/"wo" in the path. + # We nest the linear layer to create the path: ('MoeBlock_0', 'wi_0', 'kernel') + self.MoeBlock_0 = nnx.Module() + self.MoeBlock_0.wi_0 = nnx.Linear(4, 2, rngs=rngs) + + # 3. Exclusion case (scaler/scale) + self.scale = nnx.Param(jnp.ones((1,))) + + # Use eval_shape to create an abstract version of the model. 
+ model = nnx.eval_shape(lambda: MockNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + + # Extract dimension numbers using the NNX path in muon_utils + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Verify standard weight path: ('layer1', 'kernel') -> default (0,) + self.assertEqual(result.layer1.kernel.value, mdn((0,), (-1,))) + + # Verify MoE weight path: ('MoeBlock_0', 'wi_0', 'kernel') -> (-2,) + self.assertEqual(result.MoeBlock_0.wi_0.kernel.value, mdn((-2,), (-1,))) + + # Verify exclusion (scalar/scale) + self.assertIsNone(result.scale.value) + + def test_verbose_output_nnx(self): + """Covers lines 128 and 135-154: _print_structure_debug via verbose=True with NNX model.""" + + class SimpleNNXModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + model = nnx.eval_shape(lambda: SimpleNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + muon_utils.get_muon_weight_dimension_numbers(model, config, verbose=True) + + def test_nnx_deepseek_attention_logic(self): + """Simulates a DeepSeek-like attention structure in NNX.""" + + class DeepSeekAttention(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.self_attention = nnx.Module() + self.self_attention.query = nnx.Linear(8, 8, rngs=rngs) + self.self_attention.out = nnx.Linear(8, 8, rngs=rngs) + + # Use eval_shape to create an abstract version of the model. 
+ model = nnx.eval_shape(lambda: DeepSeekAttention(nnx.Rngs(0))) + config = MagicMock() + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Check attention query: [0] -> [-2, -1] + self.assertEqual(result.self_attention.query.kernel.value, mdn((0,), (-2, -1))) + # Check attention out: [0, -2] -> [-1] + self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/qk_clip_test.py b/tests/unit/qk_clip_test.py index f21722c8a5..c52c8dbb5e 100644 --- a/tests/unit/qk_clip_test.py +++ b/tests/unit/qk_clip_test.py @@ -27,7 +27,7 @@ from maxtext.common.gcloud_stub import is_decoupled from maxtext.layers import attention_mla from maxtext.utils import maxtext_utils -from maxtext.utils.qk_clip_utils import apply_qk_clip, calculate_max_logit_metric +from maxtext.utils.qk_clip_utils import apply_qk_clip, apply_qk_clip_nnx, calculate_max_logit_metric from maxtext.configs import pyconfig from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides @@ -503,5 +503,179 @@ def replace_fn(params=None, **kwargs): ) +class _MockAttentionOp(nnx.Module): + """Holds the sowed `max_logits` intermediate at the same tree depth as production.""" + + def __init__(self, max_logits=None): + if max_logits is not None: + self.max_logits = nnx.Intermediate(max_logits) + + +class _MockMLAAttention(nnx.Module): + """`wq_b.kernel` + `wkv_b.kernel` as `nnx.Param`, plus an `attention_op` child.""" + + def __init__(self, wq_b_kernel, wkv_b_kernel, max_logits=None): + self.wq_b = nnx.Module() + self.wq_b.kernel = nnx.Param(wq_b_kernel) + self.wkv_b = nnx.Module() + self.wkv_b.kernel = nnx.Param(wkv_b_kernel) + self.attention_op = _MockAttentionOp(max_logits) + + +class _MockLayer(nnx.Module): + + def __init__(self, attn): + self.self_attention = attn + + +class _MockDecoder(nnx.Module): + + def __init__(self, layer): + self.layers_0 = layer + + +class 
_MockTransformer(nnx.Module): + + def __init__(self, decoder): + self.decoder = decoder + + +class _MockState: + """Stand-in for `TrainStateNNX`: only `apply_qk_clip_nnx` accesses `.model`.""" + + def __init__(self, model): + self.model = model + + +def _build_mock_nnx_state(wq_b, wkv_b, max_logits=None): + attn = _MockMLAAttention(wq_b, wkv_b, max_logits) + return _MockState(_MockTransformer(_MockDecoder(_MockLayer(attn)))) + + +def _read_kernels(state): + attn = state.model.decoder.layers_0.self_attention + return attn.wq_b.kernel.value, attn.wkv_b.kernel.value + + +class QKClipNNXTest(unittest.TestCase): + """Mirrors `QKClipTest` against the NNX path.""" + + def _make_config(self, threshold, nope_dim, attention_type="mla"): + Config = namedtuple("Config", ["qk_clip_threshold", "qk_nope_head_dim", "attention_type"]) + return Config(qk_clip_threshold=threshold, qk_nope_head_dim=nope_dim, attention_type=attention_type) + + def test_raises_error_for_non_mla(self): + state = _build_mock_nnx_state(jnp.zeros((1, 1, 2)), jnp.zeros((1, 1, 2))) + config = self._make_config(threshold=10.0, nope_dim=4, attention_type="dot_product") + with self.assertRaisesRegex(ValueError, "QK-Clip is only supported for MLA attention"): + apply_qk_clip_nnx(state, {}, config) + + def test_apply_qk_clip_logic(self): + rng = jax.random.PRNGKey(0) + rng_q, rng_kv = jax.random.split(rng) + wq_b = jax.random.normal(rng_q, (2, 2, 6)) + wkv_b = jax.random.normal(rng_kv, (2, 2, 6)) + state = _build_mock_nnx_state(wq_b, wkv_b) + config = self._make_config(threshold=10.0, nope_dim=4) + + # Head 0 logit 20.0 (>tau, scale=0.5); head 1 logit 5.0 ( scalar loss.""" + + def loss_fn(p, h): + local_model = nnx.merge(graphdef, p, rest, copy=True) + logits = local_model.logits_from_hidden_states(h, True, "train") + one_hot = jax.nn.one_hot(labels, cfg.vocab_size) + xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot, z_loss=cfg.z_loss_multiplier) + return jnp.sum(xent * (segmentation != 0)) + + return 
loss_fn + + def _tiled_loss_fn(self, cfg, graphdef, rest, hidden_states, labels, segmentation): + """vocab_tiling_nnx_loss closure (params, hidden_states) -> scalar loss.""" + # hidden_states unused at the closure boundary (it comes via h), but kept in the + # signature so the two closures are callable interchangeably. + del hidden_states + data = {"targets": labels, "targets_segmentation": segmentation} + + def loss_fn(p, h): + local_model = nnx.merge(graphdef, p, rest, copy=True) + total_loss, _ = vocab_tiling_nnx_loss(local_model, h, data, cfg, is_train=True) + return total_loss + + return loss_fn + + def _split_and_axes(self, cfg, model): + """Common boilerplate: split the model and bind the logical axis rules.""" + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + return graphdef, params, rest + + def _assert_pytrees_close(self, ref, tiled, msg, *, rtol=None, atol=None): + rtol = self.rtol if rtol is None else rtol + atol = self.atol if atol is None else atol + leaves_close = jax.tree_util.tree_map(lambda x, y: jnp.allclose(x, y, rtol=rtol, atol=atol), ref, tiled) + if not all(jax.tree_util.tree_leaves(leaves_close)): + raise AssertionError(msg) + + def _run_parity(self, *, logits_via_embedding): + """Compare full-vocab xent loss/grads against the tiled custom_vjp path.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4, logits_via_embedding=logits_via_embedding) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grads = jax.value_and_grad(ref_loss_fn)(params, hidden_states) + tile_loss, tile_grads = jax.value_and_grad(tile_loss_fn)(params, hidden_states) + + assert jnp.allclose( + ref_loss, 
tile_loss, rtol=self.rtol, atol=self.atol + ), f"Losses differ: ref={ref_loss} tiled={tile_loss}" + self._assert_pytrees_close(ref_grads, tile_grads, "Param gradients differ between full-vocab and tiled paths.") + + # ---------- Original parity tests (params gradient under both embedding modes) ---------- + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_non_tied_embedding(self): + """custom_vjp parity for non-tied embedding (separate logits_dense).""" + self._run_parity(logits_via_embedding=False) + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_tied_embedding(self): + """custom_vjp parity when logits share the input embedding table.""" + self._run_parity(logits_via_embedding=True) + + # ---------- Coverage expansion ---------- + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_total_z_loss_value_parity(self): + """The second tuple element (total_z_loss) must match the full-vocab reference.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + data = {"targets": labels, "targets_segmentation": segmentation} + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + local_model = nnx.merge(graphdef, params, rest, copy=True) + logits = local_model.logits_from_hidden_states(hidden_states, True, "train") + one_hot = jax.nn.one_hot(labels, cfg.vocab_size) + xent_ref, z_ref = max_utils.cross_entropy_with_logits(logits, one_hot, z_loss=cfg.z_loss_multiplier) + ref_total_loss = jnp.sum(xent_ref * (segmentation != 0)) + ref_total_z_loss = jnp.sum(z_ref * (segmentation != 0)) + + local_model_tile = nnx.merge(graphdef, params, rest, copy=True) + tile_total_loss, tile_total_z_loss = vocab_tiling_nnx_loss( + local_model_tile, hidden_states, data, cfg, is_train=True + ) + + assert jnp.allclose(ref_total_loss, tile_total_loss, rtol=self.rtol, atol=self.atol) + assert jnp.allclose( + ref_total_z_loss, tile_total_z_loss, 
rtol=self.rtol, atol=self.atol + ), f"total_z_loss differs: ref={ref_total_z_loss} tiled={tile_total_z_loss}" + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_padded_segmentation(self): + """Half-padded segmentation: mask actually changes the loss, and parity holds.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + + # Compare unpadded vs padded loss to confirm the mask is wired through. + hs, labels, full_seg = self._make_inputs(cfg, pad_half=False) + _, _, pad_seg = self._make_inputs(cfg, pad_half=True) + graphdef, params, rest = self._split_and_axes(cfg, model) + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + local_model_full = nnx.merge(graphdef, params, rest, copy=True) + full_loss, _ = vocab_tiling_nnx_loss( + local_model_full, hs, {"targets": labels, "targets_segmentation": full_seg}, cfg, is_train=True + ) + local_model_pad = nnx.merge(graphdef, params, rest, copy=True) + pad_loss, _ = vocab_tiling_nnx_loss( + local_model_pad, hs, {"targets": labels, "targets_segmentation": pad_seg}, cfg, is_train=True + ) + assert float(pad_loss) < float( + full_loss + ), f"Padded loss should be strictly smaller (fewer tokens contribute). full={full_loss} pad={pad_loss}" + + # Now check parity against the full-vocab reference using the padded mask. + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hs, labels, pad_seg) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hs, labels, pad_seg) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grads = jax.value_and_grad(ref_loss_fn)(params, hs) + tile_loss, tile_grads = jax.value_and_grad(tile_loss_fn)(params, hs) + assert jnp.allclose(ref_loss, tile_loss, rtol=self.rtol, atol=self.atol) + self._assert_pytrees_close(ref_grads, tile_grads, "Padded-segmentation gradients differ.") + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_grad_over_hidden_states(self): + """Differentiate w.r.t. 
hidden_states (argnums=1): the second-primal cotangent path of custom_vjp.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_grad_h = jax.grad(ref_loss_fn, argnums=1)(params, hidden_states) + tile_grad_h = jax.grad(tile_loss_fn, argnums=1)(params, hidden_states) + + assert ref_grad_h.shape == hidden_states.shape + assert tile_grad_h.shape == hidden_states.shape + assert ref_grad_h.dtype == hidden_states.dtype + assert tile_grad_h.dtype == hidden_states.dtype + assert jnp.allclose(ref_grad_h, tile_grad_h, rtol=self.rtol, atol=self.atol), "grad_hidden_states diverged" + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_bf16_hidden_states(self): + """bf16 hidden_states: the bwd dtype-cast (`y.astype(x.dtype)`) preserves parity at lower precision.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg, dtype=jnp.bfloat16) + graphdef, params, rest = self._split_and_axes(cfg, model) + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grad_h = jax.value_and_grad(ref_loss_fn, argnums=1)(params, hidden_states) + tile_loss, tile_grad_h = jax.value_and_grad(tile_loss_fn, argnums=1)(params, hidden_states) + + # bf16 has ~3 decimal digits — loosen tolerance. 
+ assert jnp.allclose(ref_loss, tile_loss, rtol=5e-2, atol=5e-2) + assert tile_grad_h.dtype == jnp.bfloat16, f"grad cast to primal dtype expected bf16, got {tile_grad_h.dtype}" + assert jnp.allclose(ref_grad_h, tile_grad_h, rtol=5e-2, atol=5e-2) + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_z_loss_zero(self): + """z_loss=0: total_z_loss is exactly zero; loss/grad parity still holds.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4, z_loss_multiplier=0.0) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + data = {"targets": labels, "targets_segmentation": segmentation} + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + local_model = nnx.merge(graphdef, params, rest, copy=True) + total_loss, total_z_loss = vocab_tiling_nnx_loss(local_model, hidden_states, data, cfg, is_train=True) + assert float(total_z_loss) == 0.0, f"z_loss=0 but tile path returned {total_z_loss}" + assert float(total_loss) > 0.0 # cross-entropy on random logits should be positive + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grads = jax.value_and_grad(ref_loss_fn)(params, hidden_states) + tile_loss, tile_grads = jax.value_and_grad(tile_loss_fn)(params, hidden_states) + assert jnp.allclose(ref_loss, tile_loss, rtol=self.rtol, atol=self.atol) + self._assert_pytrees_close(ref_grads, tile_grads, "z_loss=0 gradients differ.") + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_other_params_get_zero_grad(self): + """Output-head carve-out invariant: every non-head nnx.Param leaf gets exactly zero grad. + + The PR10.5 carve-out splits the model into head_params (used by + `logits_from_hidden_states`) vs. 
other_params (transformer layers, etc.), + threading other_params through the custom_vjp as a non-differentiated primal + whose bwd cotangent is `tree_map(jnp.zeros_like, ...)`. This test asserts the + contract: the gradient at every non-head path is exactly 0, and at least one + head path has a non-zero gradient (so it isn't trivially passing because some + bug zeroed everything). + """ + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + _, tile_grads = jax.value_and_grad(tile_loss_fn)(params, hidden_states) + + head_keywords = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense") + head_nonzero_seen = False + for path, leaf in jax.tree_util.tree_leaves_with_path(tile_grads): + path_str = jax.tree_util.keystr(path) + is_head = any(kw in path_str for kw in head_keywords) + if is_head: + if jnp.any(leaf != 0): + head_nonzero_seen = True + else: + assert jnp.all(leaf == 0), f"non-head leaf {path_str} has non-zero grad — carve-out is wrong" + assert head_nonzero_seen, "expected at least one head leaf with non-zero grad; got all zeros" + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_num_vocab_tiling_variants(self): + """Different num_vocab_tiling values (2, 4, 8) all produce identical loss + grads.""" + losses = [] + grads_list = [] + for n in (2, 4, 8): + cfg, model = self._build_cfg_and_model(num_vocab_tiling=n) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + loss, grads = jax.value_and_grad(tile_loss_fn)(params, hidden_states) + 
losses.append(loss) + grads_list.append(grads) + + base_loss = losses[0] + base_grads = grads_list[0] + for n, loss, grads in zip((2, 4, 8), losses, grads_list): + assert jnp.allclose( + loss, base_loss, rtol=self.rtol, atol=self.atol + ), f"num_vocab_tiling={n}: loss diverges from n=2 baseline ({loss} vs {base_loss})" + self._assert_pytrees_close(base_grads, grads, f"num_vocab_tiling={n}: grads diverge from n=2 baseline.") diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 6be542d1ff..3dbb7307ee 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -1013,10 +1013,6 @@ def test_muon(self, muon_consistent_rms): @pytest.mark.cpu_only def test_vocab_tiling_bf16(self): """test vocab_tiling when weight_dtype=bfloat16""" - cfg = pyconfig.initialize([None, get_test_config_path()]) - if getattr(cfg, "enable_nnx", False): - pytest.skip("Vocab tiling not supported on NNX.") - compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16.pickle" train_compile_main( ( @@ -1053,3 +1049,31 @@ def test_qwen3_5(self): "use_tokamax_splash=True", ) ) + + @pytest.mark.cpu_only + def test_vocab_tiling_bf16_nnx(self): + """AOT compile vocab tiling on the NNX path (vocab_tiling_nnx_loss + custom_vjp). + + Forward-compatibility for PR11: once NNX defaults flip, the existing + `test_vocab_tiling_bf16` will exercise this same path via defaults, but here we + set the flags explicitly so the NNX AOT path is covered today regardless of + default values. 
+ """ + compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16_nnx.pickle" + train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=1", + "base_num_decoder_layers=2", + "per_device_batch_size=2", + "max_target_length=1024", + "num_vocab_tiling=4", + "weight_dtype=bfloat16", + "pure_nnx=true", + "enable_nnx=true", + "pure_nnx_decoder=true", + ) + ) diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py new file mode 100644 index 0000000000..4340d4e22a --- /dev/null +++ b/tests/unit/train_nnx_test.py @@ -0,0 +1,222 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX paths of loss_fn / train_step / eval_step in pre_train.train. + +These tests exercise the NNX branches without standing up a real Transformer or +data pipeline. We use a tiny NNX module that mimics the call signature the +production loss_fn uses (decoder_input_tokens, decoder_positions, ...). 
+""" + +import unittest +from dataclasses import dataclass + +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.pre_train import train as pre_train + + +@dataclass +class _Cfg: + """Subset of HyperParameters used by loss_fn / train_step / eval_step.""" + + micro_batch_size_to_train_on: int = 2 + micro_batch_size_to_eval_on: int = 2 + vocab_size: int = 8 + z_loss_multiplier: float = 0.0 + enable_dropout: bool = False + use_multimodal: bool = False + use_indexer: bool = False + indexer_sparse_training: bool = False + indexer_loss_scaling_factor: float = 0.0 + num_vocab_tiling: int = 1 + num_experts: int = 1 + routed_bias: bool = False + routed_bias_update_rate: float = 0.0 + mtp_num_layers: int = 0 + mtp_eval_target_module: int = 0 + use_dpo: bool = False + use_qk_clip: bool = False + use_tunix_gradient_accumulation: bool = False + gradient_accumulation_steps: int = 1 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + gradient_clipping_threshold: float = 0.0 + grad_dtype: jnp.dtype = jnp.float32 + record_internal_nn_metrics: bool = False + skip_step_on_spikes: bool = False + shard_mode: int = 0 # ShardMode.AUTO + weight_sparsity_n: int = 0 + weight_sparsity_m: int = 0 + + +class _TinyDecoder(nnx.Module): + """Mimics NNXDecoder.__call__ enough for loss_fn to run end-to-end. + + Returns logits of shape [batch, seq_len, vocab_size]. Ignores all multimodal + / dropout / target arguments — they exist only to match the keyword signature. 
+ """ + + def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) + self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + encoder_images=None, + encoder_image_masks=None, + enable_dropout=False, + decoder_target_tokens=None, + decoder_target_mask=None, + ): + del decoder_positions, decoder_segment_ids, encoder_images, encoder_image_masks + del enable_dropout, decoder_target_tokens, decoder_target_mask + h = self.embed(decoder_input_tokens) + return self.proj(h) + + +def _make_data(batch=2, seq=4, vocab=8): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + + +def _build_state(): + cfg = _Cfg() + model = _TinyDecoder(cfg.vocab_size, hidden=4, rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + return cfg, ts + + +class TestLossFnNNX(unittest.TestCase): + """Cover the NNX branch of loss_fn (lines 178-213).""" + + def test_returns_loss_and_full_aux_dict(self): + cfg, ts = _build_state() + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + self.assertTrue(jnp.isfinite(loss)) + # Aux schema relied on by train_step / eval_step / GA. + for key in ( + "intermediate_outputs", + "xent_sum", + "z_loss", + "total_weights", + "moe_lb_loss", + "indexer_loss", + "moe_bias_updates", + "mtp_loss", + ): + self.assertIn(key, aux) + # NNX intermediates are captured into a pure-dict snapshot, then logits attached. 
+ self.assertIsInstance(aux["intermediate_outputs"], dict) + self.assertIn("logits", aux["intermediate_outputs"]) + + def test_eval_mode_truncates_to_eval_micro_batch(self): + cfg, ts = _build_state() + cfg.micro_batch_size_to_eval_on = 1 + data = _make_data(batch=2, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=False) + self.assertTrue(jnp.isfinite(loss)) + # eval truncated batch to 1 → total_weights = seq_len * 1 + self.assertEqual(int(aux["total_weights"]), data["targets_segmentation"].shape[1]) + + def test_indexer_dense_warmup_skips_xent(self): + cfg, ts = _build_state() + cfg.use_indexer = True + cfg.indexer_sparse_training = False + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + # When dense warm-up is active the loss_fn skips the main loss entirely. + self.assertEqual(float(aux["xent_sum"]), 0.0) + self.assertEqual(float(loss), 0.0) + + +class TestTrainStepNNX(unittest.TestCase): + """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" + + def test_train_step_returns_state_and_metrics(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + # NNX path returns nnx.State (via nnx.state(new_state)) and a metrics dict. 
+ self.assertIsInstance(new_state, nnx.State) + self.assertIn("scalar", metrics) + self.assertIn("learning/loss", metrics["scalar"]) + self.assertIn("learning/grad_norm", metrics["scalar"]) + self.assertIn("learning/param_norm", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + def test_train_step_increments_optimizer_step(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + pre_step = int(state_pure.optimizer.step.get_value()) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, _ = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertEqual(int(new_state.optimizer.step.get_value()), pre_step + 1) + + def test_train_step_with_gradient_clipping(self): + """The clipping branch (gradient_clipping_threshold > 0) must run without raising.""" + cfg, ts = _build_state() + cfg.gradient_clipping_threshold = 1.0 + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + +class TestEvalStepNNX(unittest.TestCase): + """Cover the NNX branch of eval_step (lines 568-570).""" + + def test_eval_step_returns_metrics(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_eval_on, vocab=cfg.vocab_size) + metrics = pre_train.eval_step(state_graphdef, cfg, state_pure, data) + self.assertIn("scalar", metrics) + for key in ( + "evaluation/loss", + "evaluation/total_loss", + "evaluation/total_weights", + "evaluation/moe_lb_loss", + ): + self.assertIn(key, metrics["scalar"]) + # NNX path must 
NOT include DPO eval metric. + self.assertNotIn("evaluation/dpo_reward_accuracy", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["evaluation/loss"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py new file mode 100644 index 0000000000..0f7dc22d68 --- /dev/null +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -0,0 +1,412 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TrainStateNNX checkpoint tests.""" + +import pathlib +import tempfile +import shutil +from types import SimpleNamespace +from unittest import mock + +import unittest +import jax +import jax.numpy as jnp +from flax import nnx, serialization +from flax import linen as nn +from flax.training import train_state +import optax +import orbax.checkpoint as ocp + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """A simple model for checkpoint testing.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class LinenMockModel(nn.Module): + """The Linen equivalent of the MockModel.""" + + @nn.compact + def __call__(self, x): + # We name the layer 'linear' to match the attribute name in the NNX MockModel + return nn.Dense(features=1, name="linear")(x) + + +def _replicate_for_orbax(pytree): + """Give every array a replicated NamedSharding so Orbax can save in multi-host CI. + + Orbax refuses arrays with the default SingleDeviceSharding when + jax.process_count() > 1. Putting each leaf on a NamedSharding over the local + mesh works in both single- and multi-host environments without changing + values. + """ + mesh = jax.sharding.Mesh(jax.devices(), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + return jax.tree.map(lambda x: jax.device_put(x, sharding) if isinstance(x, jax.Array) else x, pytree) + + +class TestTrainStateNNXCheckpoint(unittest.TestCase): + """Class to test NNX checkpoint.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + + # Setup a chained optimizer: Gradient Clipping -> Adam + # Note: optax.adam is also a chain (scale_by_adam + scale_by_learning_rate). 
+ # This creates a nested state structure: (EmptyState, (ScaleByAdamState, EmptyState)) + self.tx = optax.chain( + optax.clip_by_global_norm(max_norm=1.0), + optax.adam(1e-3), + ) + + def test_checkpoint_structure(self): + """Ensures the state object contains both model and optimizer keys.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # We use .to_pure_dict() to simulate the format stored in a checkpoint. + # This converts nnx.Variable/State objects into raw arrays and dictionaries. + full_state = nnx.state(state).to_pure_dict() + + # 1. Verify Top-level Keys + self.assertIn("model", full_state) + self.assertIn("optimizer", full_state) + + # 2. Verify Optimizer Internal Structure + opt_inner_state = full_state["optimizer"]["opt_state"] + + # Because we used optax.chain(clip, adam), index 0 is clip, index 1 is adam. + # Since adam is also a chain, index 1 is itself a dictionary/tuple representation. + # Adam's momentum (mu/nu) is in the first element of its own sub-chain. + adam_component = opt_inner_state[1][0] + + self.assertIn("mu", adam_component, "Adam 'mu' buffer not found in pure dict state.") + self.assertIn("nu", adam_component, "Adam 'nu' buffer not found in pure dict state.") + + # In a pure dict, these are nested dictionaries containing arrays, not NNX objects. + self.assertIsInstance(adam_component["mu"], dict) + self.assertIsInstance(adam_component["nu"], dict) + + # To verify a specific leaf, we navigate the dictionary hierarchy: + self.assertIsInstance(adam_component["mu"]["linear"]["kernel"], jax.Array) + + def test_checkpoint_and_restore(self): + """Verifies that the full state can be captured and restored into a new instance.""" + # 1. Initialize original state and optimizer + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state_original = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # 2. 
Perform a training step to modify weights and optimizer buffers + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state_original.model) + state_original.apply_gradients(grads) + + # Capture state after one step + original_kernel_val = state_original.model.linear.kernel.value + original_step_val = state_original.optimizer.step.value + self.assertEqual(original_step_val, 1) + + # 3. Capture the "Checkpoint" as a pure dictionary + checkpoint_state = nnx.state(state_original).to_pure_dict() + + # 4. Initialize a fresh, different instance + new_rngs = nnx.Rngs(1) + new_model = MockModel(rngs=new_rngs) + new_optimizer = nnx.Optimizer(new_model, self.tx, wrt=nnx.Param) + state_restored = train_state_nnx.TrainStateNNX(new_model, new_optimizer) + + # Check differences before restoration + self.assertEqual(state_restored.optimizer.step.value, 0) + self.assertFalse(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # 5. Restore the state into the new instance. + # nnx.update supports updating from a pure dictionary. + nnx.update(state_restored, checkpoint_state) + + # 6. Verify restoration + # Check step counter + self.assertEqual(state_restored.optimizer.step.value, original_step_val) + # Check model weights + self.assertTrue(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # Check that it can still be trained after restoration + new_grads = nnx.grad(loss_fn)(state_restored.model) + state_restored.apply_gradients(new_grads) + self.assertEqual(state_restored.optimizer.step.value, 2) + + def test_restore_from_linen_state(self): + """Verifies a multi-stage migration: Linen CKPT -> Migrate -> NNX CKPT -> Restore.""" + # 1. 
Setup Linen TrainState (Simulating original training) + linen_model = LinenMockModel() + dummy_input = jnp.ones((1, 2)) + variables = linen_model.init(jax.random.key(42), dummy_input) + + state_linen = train_state.TrainState.create(apply_fn=linen_model.apply, params=variables["params"], tx=self.tx) + + # Perform a step to populate optimizer buffers + grads = jax.tree.map(jnp.ones_like, state_linen.params) + state_linen = state_linen.apply_gradients(grads=grads) + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save Legacy Linen Checkpoint --- + linen_ckpt_dir = temp_dir / "linen_ckpt" + mngr_linen = ocp.CheckpointManager( + linen_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr_linen.save(0, args=ocp.args.StandardSave(_replicate_for_orbax(state_linen))) + mngr_linen.wait_until_finished() + + # --- PHASE 2: Read Linen CKPT and Convert to NNX Structure --- + # Load it back without knowing the blueprint (reading as a pure PyTree) + restored_linen_obj = mngr_linen.restore(0) + + # Convert the restored object to a pure dictionary structure. + restored_linen_dict = serialization.to_state_dict(restored_linen_obj) + + # Helper to recursively convert string keys back to integers + # and filter out None values. + def recursive_clean(obj): + if isinstance(obj, dict): + return {int(k) if k.isdigit() else k: recursive_clean(v) for k, v in obj.items() if v is not None} + return obj + + # Converted dict - simple PyTree mapping, no NNX Module initialization needed here. + # This simulates a situation where the conversion logic is blueprint-agnostic. 
+ linen_as_nnx_dict = { + "model": restored_linen_dict["params"], + "optimizer": { + "step": jnp.array(restored_linen_dict["step"]), + "opt_state": recursive_clean(restored_linen_dict["opt_state"]), + }, + } + + # --- PHASE 3: Save as Native NNX Checkpoint --- + nnx_ckpt_dir = temp_dir / "nnx_ckpt" + mngr_nnx = ocp.CheckpointManager( + nnx_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + # We save the raw dictionary directly to disk. + mngr_nnx.save(0, args=ocp.args.StandardSave(_replicate_for_orbax(linen_as_nnx_dict))) + mngr_nnx.wait_until_finished() + + # --- PHASE 4: Restore from NNX Checkpoint to target Model --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We now restore using the nnx.State as a blueprint. This ensures Orbax + # correctly maps the arrays on disk to the model's structural expectation. + blueprint = nnx.state(state_nnx).to_pure_dict() + restored_nnx_pytree = mngr_nnx.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + nnx.update(state_nnx, restored_nnx_pytree) + + # --- PHASE 5: Verification --- + # 1. Verify Step + self.assertEqual(state_nnx.optimizer.step.value, 1) + + # 2. Verify Weights + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, state_linen.params["linear"]["kernel"])) + + # 3. Verify Chained Optimizer State (Clip at index 0, Adam at index 1) + self.assertEqual(type(state_nnx.optimizer.opt_state[0]), type(state_linen.opt_state[0])) + + # state_linen.opt_state[1] is the Adam chain state. + # state_linen.opt_state[1][0] is the ScaleByAdamState containing 'mu'. 
+ self.assertTrue( + jnp.allclose( + state_nnx.optimizer.opt_state[1][0].mu["linear"]["kernel"], + state_linen.opt_state[1][0].mu["linear"]["kernel"], + ) + ) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + def test_restore_from_checkpoint_model_params(self): + """Verifies that model parameters can be restored from model params only.""" + # 1. Setup mocked parameters manually (no Linen model needed for setup) + # This structure matches the path model.linear.kernel/bias in the NNX MockModel. + mock_params = {"linear": {"kernel": jnp.ones((2, 1)) * 9.0, "bias": jnp.zeros((1,))}} + + # Simplified checkpoint dictionary using hardcoded mocked params as requested + checkpoint_dict = { + "model": mock_params, + } + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save the partial checkpoint --- + mngr = ocp.CheckpointManager( + temp_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr.save(0, args=ocp.args.StandardSave(_replicate_for_orbax(checkpoint_dict))) + mngr.wait_until_finished() + + # --- PHASE 2: Restore into a full TrainStateNNX --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We use nnx.state to get a full blueprint as a reference. + full_nnx_pure_dict = nnx.state(state_nnx).to_pure_dict() + blueprint = {"model": full_nnx_pure_dict["model"]} + + # If we don't know if the checkpoint on disk has 'optimizer' or not, we simulate + # schema-agnostic restoration by calling restore without a blueprint. + # This avoids Orbax structural mismatch errors while allowing us to see the data. + restored_pytree = mngr.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + + # Use nnx.update to apply the restored data to the stateful NNX object. 
+ # nnx.update is naturally partial: it will update 'model' from the restored dict + # and leave 'optimizer' untouched at its initialized value. + nnx.update(state_nnx, restored_pytree) + + # --- PHASE 3: Verification --- + # Check that weights were restored to the specific mock values + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, mock_params["linear"]["kernel"])) + # Step remains at its initialized value (0) because it was not in the checkpoint + self.assertEqual(state_nnx.optimizer.step.value, 0) + + # Verify that the optimizer state still exists in the object (initialized) + # even though it was not provided in the checkpoint. + # Adam's state is at index 1 of the chain, and it's a nested structure (tuple). + # We verify that index 0 (ScaleByAdamState) contains the 'mu' State container. + self.assertIsInstance(state_nnx.optimizer.opt_state[1][0].mu, nnx.State) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + +class TestMaybeSaveCheckpointStepAlignment(unittest.TestCase): + """Verify maybe_save_checkpoint's fallback step matches the last completed step. + + When the training loop's final save calls maybe_save_checkpoint without an + explicit `step`, it derives `actual_step` from the state: + - NNX: int(state.optimizer.step) - 1 + - Linen: int(state.step) - 1 + Both TrainStateNNX.apply_gradients (via nnx.Optimizer.update) and Linen + TrainState.apply_gradients increment the counter by 1 per call, so after N + gradient applications the counter is N and the "last completed step" is N-1. 
+ """ + + N_STEPS = 5 + + def setUp(self): + self.tx = optax.adam(1e-3) + + def _build_nnx_state(self, num_steps): + """Build an nnx.State flattened from TrainStateNNX after num_steps gradient applications.""" + model = MockModel(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(model, optimizer) + + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + for _ in range(num_steps): + grads = nnx.grad(loss_fn)(state.model) + state.apply_gradients(grads) + # maybe_save_checkpoint is called with a flat nnx.State in the NNX path + # (train_step returns nnx.state(new_state)). + return nnx.state(state) + + def _build_linen_state(self, num_steps): + """Build a Linen TrainState after num_steps gradient applications.""" + model = LinenMockModel() + variables = model.init(jax.random.key(0), jnp.ones((1, 2))) + state = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=self.tx) + grads = jax.tree.map(jnp.ones_like, state.params) + for _ in range(num_steps): + state = state.apply_gradients(grads=grads) + return state + + def _invoke_maybe_save(self, state, pure_nnx): + """Call maybe_save_checkpoint with save_checkpoint patched, return {step, state} captured.""" + # checkpoint_period=1 keeps force_ckpt_save False regardless of actual_step. 
+ config = SimpleNamespace(pure_nnx=pure_nnx, checkpoint_period=1, async_checkpointing=False) + mgr = mock.MagicMock() + mgr.reached_preemption.return_value = False + + captured = {} + + def fake_save_checkpoint(_mgr, step, state_arg, *_args, **_kwargs): + captured["step"] = step + captured["state"] = state_arg + return False # no save happened => print_save_message is skipped + + with mock.patch.object(checkpointing, "save_checkpoint", side_effect=fake_save_checkpoint): + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=None) + return captured + + def test_nnx_final_save_step_is_n_minus_1(self): + state = self._build_nnx_state(self.N_STEPS) + self.assertEqual(int(state.optimizer.step.value), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=True) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_linen_final_save_step_is_n_minus_1(self): + state = self._build_linen_state(self.N_STEPS) + self.assertEqual(int(state.step), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_nnx_and_linen_agree_on_actual_step(self): + """TrainStateNNX and Linen TrainState must yield the same fallback actual_step.""" + nnx_state = self._build_nnx_state(self.N_STEPS) + linen_state = self._build_linen_state(self.N_STEPS) + self.assertEqual( + self._invoke_maybe_save(nnx_state, pure_nnx=True)["step"], + self._invoke_maybe_save(linen_state, pure_nnx=False)["step"], + ) + + def test_nnx_state_is_converted_to_pure_dict_before_save(self): + """For pure_nnx=True, maybe_save_checkpoint must pass a plain dict to save_checkpoint, not an nnx.State.""" + state = self._build_nnx_state(self.N_STEPS) + self.assertIsInstance(state, nnx.State) # precondition: NNX train_step returns an nnx.State + + captured = self._invoke_maybe_save(state, pure_nnx=True) + + # save_checkpoint should have received a plain Python dict (the result of + # 
nnx.State.to_pure_dict()), not the original nnx.State. + self.assertIsInstance(captured["state"], dict) + self.assertNotIsInstance(captured["state"], nnx.State) + # Sanity: the converted dict still mirrors the TrainStateNNX structure. + self.assertIn("model", captured["state"]) + self.assertIn("optimizer", captured["state"]) + + def test_linen_state_is_passed_through_unchanged(self): + """For pure_nnx=False, maybe_save_checkpoint must pass the original TrainState object through.""" + state = self._build_linen_state(self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + # Linen path must not invoke to_pure_dict(); state is forwarded as-is. + self.assertIs(captured["state"], state) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_test.py b/tests/unit/train_state_nnx_test.py new file mode 100644 index 0000000000..03db77ff63 --- /dev/null +++ b/tests/unit/train_state_nnx_test.py @@ -0,0 +1,90 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""TrainStateNNX tests."""
+
+import unittest
+import jax.numpy as jnp
+from flax import nnx
+import optax
+
+from maxtext.layers import train_state_nnx
+
+
+class MockModel(nnx.Module):
+  """Mocked NNX model: a single Linear(2 -> 1) layer, sufficient to exercise TrainStateNNX."""
+
+  def __init__(self, rngs: nnx.Rngs):
+    self.linear = nnx.Linear(2, 1, rngs=rngs)
+
+  def __call__(self, x):
+    return self.linear(x)
+
+
+class TestTrainStateNNX(unittest.TestCase):
+  """TrainStateNNX tests."""
+
+  def setUp(self):
+    self.rngs = nnx.Rngs(0)
+    self.model = MockModel(rngs=self.rngs)
+    self.tx = optax.adam(1e-3)
+
+  def test_init_with_optimizer(self):
+    """Test init with optimizer."""
+    optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)
+    state = train_state_nnx.TrainStateNNX(self.model, optimizer)
+
+    self.assertEqual(state.model, self.model)
+    self.assertEqual(state.optimizer, optimizer)
+    # Access step directly from optimizer
+    self.assertEqual(state.optimizer.step.value, 0)
+
+  def test_init_without_optimizer(self):
+    """Test init without optimizer (inference-only construction)."""
+    state = train_state_nnx.TrainStateNNX(self.model, None)
+
+    self.assertEqual(state.model, self.model)
+    self.assertIsNone(state.optimizer)
+
+  def test_apply_gradients_success(self):
+    """Test apply gradients can be called successfully."""
+    optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)
+    state = train_state_nnx.TrainStateNNX(self.model, optimizer)
+
+    # Create dummy gradients matching the model state structure
+    def loss_fn(m):
+      return jnp.mean(m(jnp.ones((1, 2))) ** 2)
+
+    grads = nnx.grad(loss_fn)(state.model)
+
+    # Apply gradients
+    state.apply_gradients(grads)
+
+    # Verify step incremented (managed by nnx.Optimizer)
+    self.assertEqual(state.optimizer.step.value, 1)
+
+  def test_apply_gradients_raises_runtime_error(self):
+    """Test apply gradients without an optimizer."""
+    # Initialize without optimizer (inference mode)
+    state = train_state_nnx.TrainStateNNX(self.model, None)
+
+    dummy_grads = {}
+    with self.assertRaises(RuntimeError) as cm:
+      
state.apply_gradients(dummy_grads) + + self.assertIn("inference only", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_utils_nnx_test.py b/tests/unit/train_utils_nnx_test.py new file mode 100644 index 0000000000..2ff7276fd9 --- /dev/null +++ b/tests/unit/train_utils_nnx_test.py @@ -0,0 +1,149 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-specific helpers / patterns in train_utils.setup_train_loop. + +setup_train_loop itself is integration territory (it touches data iterators, +checkpoint managers, and a real mesh), so we cover the NNX-only pieces that +have unit-testable contracts: + + 1. The create_train_state_fn closure pattern: builds nnx.Optimizer + TrainStateNNX + from a zero-arg model factory and a transform. + 2. nnx.split(state.model, nnx.Param, ...) returns Param-only state used to + compute state_params / state_mesh_shardings_params. + 3. nnx.merge(state_graphdef, state) reconstitutes a TrainStateNNX from the + pure-state form returned by setup_training_state. 
+""" + +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +class TestCreateTrainStateFnClosure(unittest.TestCase): + """Exercise the closure pattern in setup_train_loop: + + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + """ + + def test_returns_train_state_nnx_with_optimizer(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def create_train_state_fn(): + model = _create_model() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + state = create_train_state_fn() + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + self.assertIsInstance(state.optimizer, nnx.Optimizer) + self.assertEqual(int(state.optimizer.step.get_value()), 0) + + def test_two_invocations_produce_independent_states(self): + """The lambda must call the factory each time (otherwise checkpoint init/restore would alias).""" + tx = optax.sgd(0.01) + counter = {"n": 0} + + def _create_model(): + counter["n"] += 1 + return _Model(rngs=nnx.Rngs(counter["n"])) + + def create_train_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + s1 = create_train_state_fn() + s2 = create_train_state_fn() + self.assertEqual(counter["n"], 2) + self.assertIsNot(s1.model, s2.model) + + +class TestSetupTrainLoopNNXTreeOps(unittest.TestCase): + """Cover the nnx.split(state.model, nnx.Param, ...) 
and nnx.merge round-trip + patterns that setup_train_loop uses to derive Param-only views and rebuild + the full TrainStateNNX before returning.""" + + def setUp(self): + self.tx = optax.sgd(0.01) + self.model = _Model(rngs=nnx.Rngs(0)) + self.state = train_state_nnx.TrainStateNNX(self.model, nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)) + + def test_nnx_split_yields_param_only_state(self): + """state_params used for assert_params_sufficiently_sharded must contain only nnx.Param leaves.""" + _, state_params, _ = nnx.split(self.state.model, nnx.Param, ...) + leaves = jax.tree.leaves(state_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertIsInstance(leaf, nnx.Param) + + def test_nnx_merge_reconstructs_train_state_nnx(self): + """setup_train_loop ends with nnx.merge(state_graphdef, state) — verify that round-trips.""" + state_graphdef, state_pure = nnx.split(self.state) + train_state = nnx.merge(state_graphdef, state_pure) + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Same numeric values. + self.assertTrue(jnp.allclose(train_state.model.linear.kernel.value, self.state.model.linear.kernel.value)) + + +class TestInitStateFnIsCallable(unittest.TestCase): + """For the Linen path setup_train_loop builds init_state_fn = partial(...). + + The NNX path uses a closure instead — confirm both forms have the + zero-argument call contract create_checkpoint_manager / setup_training_state expect. 
+ """ + + def test_nnx_init_state_fn_callable_with_no_args(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def init_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + state = init_state_fn() # must not raise / require args + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + + def test_linen_init_state_fn_is_partial_callable_with_no_args(self): + """Sanity: the Linen-side `partial(init_initial_state, model, tx, config, is_training, init_rng)` form.""" + + def init_initial_state(model, tx, config, is_training, init_rng): + del model, tx, config, is_training, init_rng + return "linen-state" + + init_state_fn = partial(init_initial_state, "model", "tx", "config", True, "rng") + self.assertEqual(init_state_fn(), "linen-state") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/run_sharding_dump.py b/tests/utils/run_sharding_dump.py index 7d3156fe00..62c71a9b5b 100644 --- a/tests/utils/run_sharding_dump.py +++ b/tests/utils/run_sharding_dump.py @@ -59,9 +59,12 @@ flags.DEFINE_string("topology", None, "Specific topology to dump.") flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.") flags.DEFINE_string("custom_mesh_and_rule", None, "Specific custom_mesh_and_rule to dump.") +flags.DEFINE_bool("pure_nnx", False, "Use pure NNX model.") -def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple) -> None: +def run_single_dump( + model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple, pure_nnx: bool = False +) -> None: """Generate sharding json file for one specific model, topology, slice and rule.""" args = [ "python3", @@ -79,6 +82,8 @@ def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_ args.append(f"custom_mesh_and_rule={custom_mesh_and_rule}") if overrides: 
args.extend(overrides) + if pure_nnx: + args.append("pure_nnx=true") subprocess.run(args, check=True) @@ -117,7 +122,7 @@ def main(argv: Sequence[str]) -> None: print(" -> Sharding files already exist. Regenerating to overwrite.") try: - run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides) + run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides, pure_nnx=FLAGS.pure_nnx) except subprocess.CalledProcessError: print(f"!!! FAILED: {model_name} {topology} {num_slice} {custom_mesh_and_rule} overrides={overrides}")