581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

@@ -40,6 +40,7 @@
import os
import sys

from flax import nnx
import jax
from jax import random
from jax.sharding import Mesh
@@ -48,11 +49,15 @@
from maxtext.common import checkpointing
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.layers import quantizations
from maxtext.layers import train_state_nnx
from maxtext.models.models import transformer_as_linen
from maxtext.optimizers import optimizers
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import maxtext_utils_nnx
from maxtext.utils import model_creation_utils
from maxtext.utils import train_utils
import numpy as np
from psutil import Process
import tensorstore as ts
@@ -87,13 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = maxtext_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

quant = quantizations.configure_quantization(cfg)
if cfg.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
_, tx = train_utils.create_training_optimizer(cfg, model)
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)

def init_state_fn():
nnx_model = _create_model_partial()
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)

else:
quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
cfg.checkpoint_dir,
@@ -102,11 +117,6 @@
cfg.checkpoint_period,
)

if cfg.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")
@@ -191,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
}

state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}
if cfg.pure_nnx:
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
# model params -> ['model']<rest>.value
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
# step -> ['optimizer']['step'].value
# opt count -> ['optimizer']['opt_state']['count'].value
state_map = {
".optimizer.step.value": ("step", None),
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
}
else:
state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}

def get_layer_prefix(keystr_pax):
# different path format between decoder_layer variable
@@ -206,19 +227,27 @@ def get_layer_prefix(keystr_pax):
return prefix_pax_opt_state

for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
# model variable
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
# first momentum in optimizer state
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
# second momentum in optimizer state
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
if cfg.pure_nnx:
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
else:
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)

def verify_fn(key_path, _):
keystr = jax.tree_util.keystr(key_path)
@@ -270,10 +299,11 @@ def map_fn(key_path, value):
max_logging.log("converted state finished")
max_utils.print_mem_stats("converted state finished")

if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
max_logging.log(f"saved a checkpoint at step {converted_state.step}")
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
max_logging.log(f"saved a checkpoint at step {step_value}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(converted_state.step):
if checkpoint_manager.reached_preemption(step_value):
checkpoint_manager.wait_until_finished()
sys.exit()

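A minimal sketch (not MaxText code) of how the NNX key paths targeted by `state_map` above can be inspected: split a small stand-in object and print every leaf path. `TinyModel` and `StateHolder` are hypothetical stand-ins for the real Transformer and `TrainStateNNX`; the flax.nnx and optax calls are assumed from their current public APIs.

import jax
import optax
from flax import nnx

class TinyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.dense = nnx.Linear(4, 4, rngs=rngs)

class StateHolder(nnx.Module):  # hypothetical stand-in for TrainStateNNX(model, optimizer)
  def __init__(self, model, optimizer):
    self.model = model
    self.optimizer = optimizer

model = TinyModel(nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
_, state = nnx.split(StateHolder(model, optimizer))

# Each printed path is the keystr form the converter matches against state_map,
# e.g. entries under .model / .optimizer that end in .value.
for path, _ in jax.tree_util.tree_leaves_with_path(state):
  print(jax.tree_util.keystr(path))

Running the same loop over the abstract MaxText state (instead of this toy) is a quick way to confirm each `.model…` / `.optimizer.opt_state…` entry before kicking off a full conversion.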
32 changes: 27 additions & 5 deletions src/maxtext/common/checkpointing.py
@@ -20,6 +20,7 @@
from absl import flags
import datetime
from etils import epath
from flax import nnx
from flax.training import train_state
import jax
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -571,7 +572,7 @@ def load_state_if_possible(
load_parameters_from_path: str,
load_full_state_from_path: str,
checkpoint_storage_concurrent_gb: int,
abstract_unboxed_pre_state: train_state.TrainState,
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
enable_single_replica_ckpt_restoring: bool | None = False,
dataset_type: str | None = "tfds",
step: int = -1, # -1 means latest
@@ -639,9 +640,14 @@ def map_to_pspec(data):
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()

restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
checkpoint_args = ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state,
item=restore_target,
restore_args=restore_args,
partial_restore=True,
)
@@ -679,9 +685,14 @@ def map_to_pspec(data):
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)

if load_parameters_from_path != "":
if isinstance(abstract_unboxed_pre_state, nnx.State):
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
else:
params = abstract_unboxed_pre_state.params

restored_params = load_params_from_path(
load_parameters_from_path,
abstract_unboxed_pre_state.params,
params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
@@ -773,7 +784,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
# Determine the effective step for saving a checkpoint.
# If 'step' is not provided, this call is for a potential final checkpoint
# and uses the last completed step from the state.
actual_step = (int(state.step) - 1) if step is None else int(step)
if step is not None:
actual_step = int(step)
else:
if config.pure_nnx:
actual_step = int(state.optimizer.step) - 1
else:
# Linen TrainState has .step attribute
actual_step = int(state.step) - 1

if config.pure_nnx:
# Convert nnx.State to dict.
state = state.to_pure_dict()

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
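The NNX branches above flatten state with `to_pure_dict()` because Orbax saves and restores plain pytrees: NNX checkpoints are written as nested dicts and folded back into an `nnx.State` on load. A minimal sketch of that round trip, assuming flax's `State.to_pure_dict` / `State.replace_by_pure_dict` and a bare Orbax `PyTreeCheckpointer` rather than MaxText's checkpoint manager:

import orbax.checkpoint as ocp
from flax import nnx

def restore_nnx_state(ckpt_path: str, abstract_state: nnx.State) -> nnx.State:
  """Restore a checkpoint that was saved as a pure dict back into an nnx.State."""
  checkpointer = ocp.PyTreeCheckpointer()
  # Restore against the same pure-dict structure the checkpoint was saved with.
  restored = checkpointer.restore(ckpt_path, item=abstract_state.to_pure_dict())
  # Fold the restored arrays back into the abstract state in place.
  abstract_state.replace_by_pure_dict(restored)
  return abstract_state

The change to `load_state_if_possible` above only swaps the restore target; how the restored values are re-attached to the live model is left to MaxText's own setup path.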
7 changes: 7 additions & 0 deletions src/maxtext/configs/base.yml
@@ -560,6 +560,13 @@ logical_axis_rules: [
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
# ==========================================
# Pipeline Parallelism
# ==========================================
['layers_outside_pipeline', []],
['layers_per_stage', []],
['num_activations', []],
['circular_repeats', []],
# ==========================================
# Deprecated / Scheduled for Removal
# ==========================================
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
3 changes: 1 addition & 2 deletions src/maxtext/configs/pyconfig_deprecated.py
@@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
raise ValueError("We currently don't support vocab tiling on NNX module.")


def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
3 changes: 1 addition & 2 deletions src/maxtext/configs/types.py
@@ -2833,8 +2833,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if self.num_vocab_tiling > 1 and self.enable_nnx:
raise ValueError("We currently don't support vocab tiling on NNX module.")
# Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py.
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
if "gpu" not in self.hardware:
raise ValueError(