2 changes: 1 addition & 1 deletion docs/tutorials/posttraining/lora.md
@@ -60,7 +60,7 @@ export DATASET_NAME=<DATASET_NAME> # e.g., openai/gsm8k
export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train
export HF_DATA_DIR=<DATASET_PATH> # e.g., main
export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., ['question','answer']
export CHAT_TEMPLATE_PATH=<TEMPLATE_PATH> # e.g., maxtext/examples/chat_templates/math_qa.json
export CHAT_TEMPLATE_PATH=<TEMPLATE_PATH> # e.g., src/maxtext/examples/chat_templates/math_qa.json (use gemma4_math_qa.json for Gemma 4 models)

# -- LoRA Conversion configuration (Optional) --
export HF_LORA_ADAPTER_PATH=<HF_LORA_ADAPTER_PATH> # e.g., 'username/adapter-name'
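For orientation, a minimal sketch of how the exported `CHAT_TEMPLATE_PATH` and its template fields could be consumed for one GSM8K-style row. The example row, the fallback path, and the direct `json`/`str.format` calls are illustrative assumptions, not the actual MaxText data pipeline.

```python
import json
import os

# Hypothetical fallback path; in practice CHAT_TEMPLATE_PATH is exported as above.
template_path = os.environ.get(
    "CHAT_TEMPLATE_PATH", "src/maxtext/examples/chat_templates/gemma4_math_qa.json"
)
with open(template_path, encoding="utf-8") as f:
    tmpl = json.load(f)

# Hypothetical GSM8K-style row: the raw answer mixes reasoning and the final
# number, separated by REASONING_ANSWER_SEPARATOR ("####").
row = {"question": "What is 2 + 3?", "answer": "2 + 3 = 5\n#### 5"}
reasoning, answer = row["answer"].split(tmpl["REASONING_ANSWER_SEPARATOR"])

prompt = tmpl["PROMPT_TEMPLATE"].format(question=row["question"])
completion = tmpl["COMPLETION_TEMPLATE"].format(
    reasoning=reasoning.strip(), answer=answer.strip()
)
print(prompt)
print(completion)
```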
1 change: 1 addition & 0 deletions src/maxtext/configs/post_train/lora_module_path.yml
@@ -21,6 +21,7 @@ mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"

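The gemma4 entry can be sanity-checked against representative parameter paths with plain `re.search`; the path strings below are hypothetical illustrations of the NNX state layout, not captured from a real checkpoint.

```python
import re

gemma4_pattern = (
    "decoder/(scanned_blocks|layers_remainder)/layers.*/"
    ".*(self_attention/(query|key|value|out)"
    "|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
)

# Hypothetical paths: attention projection, MoE expert weight, shared expert,
# and a norm parameter that should stay outside the LoRA target set.
paths = [
    "decoder/scanned_blocks/layers_0/self_attention/query/kernel",
    "decoder/layers_remainder/layers_0/mlp/MoeBlock_0/wi_0/kernel",
    "decoder/scanned_blocks/layers_0/mlp/shared_experts/wo/kernel",
    "decoder/scanned_blocks/layers_0/pre_self_attention_norm/scale",
]
for p in paths:
    print(p, "->", bool(re.search(gemma4_pattern, p)))
```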
3 changes: 3 additions & 0 deletions src/maxtext/configs/post_train/sft.yml
@@ -27,6 +27,9 @@ lora:
lora_rank: 0
lora_alpha: 0.0
lora_module_path: ""
# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size.
lora_weight_qtype: null
lora_tile_size: null
# Optional path to LoRA weights to load before training. Ignored if the current run is resumed.
lora_restore_path: ""

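To make the two new knobs concrete: `lora_weight_qtype` selects how the frozen base weights are quantized and `lora_tile_size` sets the block size used for the per-block scales. The sketch below uses a simplified signed absmax grid rather than the real NF4 codebook, purely to show what the tile size controls.

```python
import jax.numpy as jnp

def blockwise_absmax_quantize(w, tile_size=64, bits=4):
    """Simplified block-wise quantization (illustrative, not the NF4 codebook).

    Assumes w.size is divisible by tile_size; each tile gets its own absmax
    scale, which is the role lora_tile_size plays for QLoRA base weights.
    """
    flat = w.reshape(-1, tile_size)
    scales = jnp.max(jnp.abs(flat), axis=1, keepdims=True)
    levels = 2 ** (bits - 1) - 1  # 7 for a signed 4-bit grid
    q = jnp.round(flat / scales * levels).astype(jnp.int8)
    return q, scales

w = jnp.linspace(-1.0, 1.0, 256).reshape(16, 16)
q, scales = blockwise_absmax_quantize(w, tile_size=64)
dequant = (q / 7.0 * scales).reshape(16, 16)  # frozen base weights as seen by LoRA
print(jnp.max(jnp.abs(dequant - w)))  # small per-tile rounding error
```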
12 changes: 11 additions & 1 deletion src/maxtext/configs/types.py
@@ -1238,9 +1238,19 @@ class LoRA(BaseModel):
lora_module_path: str = Field(
"",
description=(
"Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
"Regex identifying target NNX modules for LoRA. "
"Example for standard models: 'decoder/layers/.*(self_attention/(query|out)|mlp/(wi_0|wo))'. "
"Example for MoE: 'decoder/scanned_blocks/layers.*/.*(MoeBlock_0|shared_experts)/(wi_0|wo)'."
),
)
lora_weight_qtype: str | None = Field(
None,
description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
)
lora_tile_size: NonNegativeInt | None = Field(
None,
description=("Tile size for block-wise quantization. Typically 32 or 64."),
)
lora_restore_path: PathStr = Field(
"",
description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),
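As a self-contained illustration of the field semantics, here is a stand-alone pydantic sketch that mirrors the new fields (names and defaults copied from the diff, the class name is made up); it only shows that the QLoRA path is keyed off `lora_weight_qtype` being set, without importing `maxtext.configs.types`.

```python
from pydantic import BaseModel, Field, NonNegativeInt

class LoRASketch(BaseModel):
    # Mirrors the LoRA config fields from the diff; defaults as in sft.yml.
    lora_rank: NonNegativeInt = 0
    lora_alpha: float = 0.0
    lora_module_path: str = ""
    lora_weight_qtype: str | None = Field(None, description="e.g. 'nf4'; enables QLoRA when set")
    lora_tile_size: NonNegativeInt | None = Field(None, description="block size, typically 32 or 64")

cfg = LoRASketch(lora_rank=16, lora_alpha=32.0, lora_weight_qtype="nf4", lora_tile_size=64)
print(cfg.lora_weight_qtype is not None)  # True -> the QLoRA path is taken
```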
6 changes: 6 additions & 0 deletions src/maxtext/examples/chat_templates/gemma4_math_qa.json
@@ -0,0 +1,6 @@
{
"PROMPT_TEMPLATE": "You are given a mathematical problem. You must solve the problem and provide your reasoning. Place your entire thought process and steps between <reasoning> and </reasoning>. After your reasoning, provide only the final numerical answer, extracted from your reasoning steps, between <answer> and </answer>. The user's problem is:\n{question}",
"COMPLETION_TEMPLATE": "<reasoning>\n{reasoning}\n</reasoning>\n<answer>{answer}</answer>",
"REASONING_ANSWER_SEPARATOR": "####",
"chat_template": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|turn>system\n' + message['content'] + '<turn|>\n' }}{% elif message['role'] == 'user' %}{{ '<|turn>user\n' + message['content'] + '<turn|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|turn>model\n' + message['content'] + '<turn|>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|turn>model\n' }}{% endif %}"
}
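For reference, a trimmed render of the `chat_template` above using plain `jinja2` (system turns omitted for brevity); the `<|turn>`/`<turn|>` markers and the `<bos>` token follow the JSON verbatim, not a verified Gemma tokenizer config.

```python
from jinja2 import Template

# Trimmed copy of the chat_template: only the user and model branches.
chat_template = (
    "{{ bos_token }}{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ '<|turn>user\n' + message['content'] + '<turn|>\n' }}"
    "{% elif message['role'] == 'assistant' %}{{ '<|turn>model\n' + message['content'] + '<turn|>\n' }}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|turn>model\n' }}{% endif %}"
)

messages = [{"role": "user", "content": "What is 2 + 3?"}]
print(Template(chat_template).render(
    bos_token="<bos>", messages=messages, add_generation_prompt=True
))
```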
23 changes: 17 additions & 6 deletions src/maxtext/layers/nnx_decoders.py
@@ -65,7 +65,12 @@
)
from maxtext.multimodal import utils as mm_utils
from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding
from maxtext.utils.maxtext_utils_nnx import nnx_ensure_scan_leading_axis
from maxtext.utils.maxtext_utils_nnx import (
nnx_add_scan_axis,
nnx_ensure_scan_leading_axis,
nnx_remove_scan_axis,
nnx_sync_moveaxis,
)
from maxtext.utils.sharding import create_sharding

# ------------------------------------------------------------------------------
@@ -595,6 +600,8 @@ def _extract_matching_state(template, full):
use_kv = kv_caches_stacked is not None

def layer_fn(carry, scanned_vars):
# Ensure metadata rank matches the sliced values
scanned_vars = nnx_remove_scan_axis(scanned_vars, "layers")

# Unpack the sliced variables for THIS layer
if use_kv:
@@ -668,12 +675,16 @@ def layer_fn(carry, scanned_vars):
state = nnx_ensure_scan_leading_axis(state, length)

final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
returned_kv_stacked = None

if scan_axis != 0:
new_params, new_rest = scanned_state.split(nnx.Param, ...)
new_params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), new_params)
scanned_state = nnx.merge_state(new_params, new_rest)
# Ensure metadata rank matches the stacked values
scanned_state = nnx_add_scan_axis(scanned_state, "layers", 0)

if scan_axis != 0:
new_params, new_rest = scanned_state.split(nnx.Param, ...)
new_params = nnx_sync_moveaxis(new_params, 0, scan_axis)
scanned_state = nnx.merge_state(new_params, new_rest)

returned_kv_stacked = None

if dynamic_graph_init:
# If graph changed, we need to merge with the new graphdef.
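As background for the two metadata fix-ups around the scan: `jax.lax.scan` hands the body one slice with the leading layer axis removed and re-stacks the outputs onto a new leading axis, so any per-variable sharding metadata has to lose and regain a "layers" entry to keep its rank in sync with the values. A toy example, unrelated to the real decoder classes:

```python
import jax
import jax.numpy as jnp

num_layers, d = 4, 8
stacked_w = jnp.ones((num_layers, d, d))  # per-layer weights stacked on axis 0

def layer_fn(carry, w):
    # Inside the body, w has shape (d, d): the leading "layers" axis is gone,
    # so rank-matched sharding metadata for this slice must drop that axis.
    return carry @ w, w

x = jnp.ones((2, d))
final_x, restacked_w = jax.lax.scan(layer_fn, x, stacked_w)
print(restacked_w.shape)  # (4, 8, 8): outputs are re-stacked on a new leading axis
```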
6 changes: 3 additions & 3 deletions src/maxtext/utils/lora_utils.py
@@ -19,7 +19,7 @@
import re
from typing import Any, Optional

from flax import nnx
from flax import nnx, linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
import jax
@@ -513,8 +513,8 @@ def apply_lora_to_model(

# Use logical_to_mesh_sharding to correctly map logical axes like 'embed'
# to physical mesh axes.
dst_shardings = sharding.logical_to_mesh_sharding(
nnx.get_partition_spec(state), mesh, rules=mt_config.logical_axis_rules
dst_shardings = nn.logical_to_mesh_sharding(
nnx.get_partition_spec(state), mesh, mt_config.logical_axis_rules
)

from tunix.rl import reshard # pylint: disable=import-outside-toplevel
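For context on the call being swapped in, `flax.linen.logical_to_mesh_sharding` turns a tree of logical `PartitionSpec`s into `NamedSharding`s by applying logical-to-mesh axis rules. A minimal sketch on a 1x1 mesh; the axis names and rules here are made up for illustration, not MaxText's actual `logical_axis_rules`.

```python
import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, PartitionSpec as P

# Single-device 1x1 mesh so the snippet runs anywhere.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Hypothetical rules mapping logical axis names to physical mesh axes.
rules = (("embed", "model"), ("batch", "data"))
logical_specs = {"kernel": P("embed", None), "bias": P(None)}

shardings = nn.logical_to_mesh_sharding(logical_specs, mesh, rules)
print(shardings["kernel"])  # NamedSharding over PartitionSpec('model', None)
```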
86 changes: 85 additions & 1 deletion src/maxtext/utils/maxtext_utils_nnx.py
@@ -18,7 +18,7 @@

from flax import nnx
import jax
from jax.sharding import Mesh, NamedSharding
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from maxtext.utils import max_logging
from maxtext.configs import pyconfig
@@ -187,3 +187,87 @@ def _op(x):
return x

return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable))


# ------------------------------------------------------------------------------
# Metadata Synchronization Helpers for NNX Variables
# ------------------------------------------------------------------------------

def nnx_update_sharding_meta(variable, transform_fn):
"""Generic helper to apply a list transformation to all sharding-related metadata."""
if not (hasattr(variable, "get_metadata") and hasattr(variable, "replace")):
return variable

meta = variable.get_metadata()
updates = {}

for key in ["sharding", "out_sharding", "sharding_names"]:
if (val := meta.get(key)) and isinstance(val, (P, tuple, list)):
new_list = list(val)
transformed = transform_fn(new_list)
updates[key] = P(*transformed) if isinstance(val, P) else tuple(transformed)

if updates:
return variable.replace(**updates)
return variable

def nnx_sync_moveaxis(tree, from_axis, to_axis):
"""Moves an axis in both values and sharding metadata of nnx.Variables."""
if from_axis == to_axis:
return tree

import jax.numpy as jnp
def _op(x):
is_var = isinstance(x, nnx.Variable)
val = x.get_value() if is_var else x
if not hasattr(val, "shape"):
return x

new_val = jnp.moveaxis(val, from_axis, to_axis)
if not is_var:
return new_val

def move_fn(l):
if len(l) > max(from_axis, to_axis):
l.insert(to_axis, l.pop(from_axis))
return l

return nnx_update_sharding_meta(x.replace(value=new_val), move_fn)

return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable) or hasattr(x, "shape"))

def nnx_remove_scan_axis(tree, name="layers"):
"""Removes the given scan axis from the PartitionSpec."""

def _op(x):
if not isinstance(x, nnx.Variable):
return x

def remove_fn(l):
if name in l:
l.remove(name)
while len(l) > x.get_value().ndim:
l.pop(0)
return l

return nnx_update_sharding_meta(x, remove_fn)

return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable))

def nnx_add_scan_axis(tree, name="layers", pos=0):
"""Adds the given scan axis to the PartitionSpec at the specified position."""

def _op(x):
if not isinstance(x, nnx.Variable):
return x

def add_fn(l):
if name not in l:
l.insert(pos, name)
while len(l) < x.get_value().ndim:
l.insert(pos, None)
return l

return nnx_update_sharding_meta(x, add_fn)

return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable))
70 changes: 70 additions & 0 deletions tests/utils/test_maxtext_utils_nnx.py
@@ -0,0 +1,70 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Tests for NNX utilities. """

import unittest
from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from maxtext.utils import maxtext_utils_nnx

class TestMaxtextUtilsNnx(unittest.TestCase):

def test_nnx_sync_moveaxis(self):
val = jnp.zeros((10, 20, 30))
sharding = P("layers", "embed", None)

var = nnx.Param(val)
var.set_metadata(sharding=sharding, out_sharding=sharding)

moved_var = maxtext_utils_nnx.nnx_sync_moveaxis(var, 0, 2)

self.assertEqual(moved_var.get_value().shape, (20, 30, 10))
self.assertEqual(moved_var.get_metadata().get("sharding"), P("embed", None, "layers"))
self.assertEqual(moved_var.get_metadata().get("out_sharding"), P("embed", None, "layers"))

def test_nnx_remove_scan_axis(self):
val = jnp.zeros((10, 20))
sharding = P("layers", "embed")
var = nnx.Param(val)
var.set_metadata(sharding=sharding)

# Simulate slicing
sliced_val = val[0]
var_with_sliced_val = var.replace(value=sliced_val)

fixed_var = maxtext_utils_nnx.nnx_remove_scan_axis(var_with_sliced_val, "layers")

self.assertEqual(fixed_var.get_value().ndim, 1)
self.assertEqual(fixed_var.get_metadata().get("sharding"), P("embed"))

def test_nnx_add_scan_axis(self):
val = jnp.zeros((20,))
sharding = P("embed")
var = nnx.Param(val)
var.set_metadata(sharding=sharding)

# Simulate stacking
stacked_val = jnp.stack([val] * 10, axis=0)
var_with_stacked_val = var.replace(value=stacked_val)

fixed_var = maxtext_utils_nnx.nnx_add_scan_axis(var_with_stacked_val, "layers", 0)

self.assertEqual(fixed_var.get_value().ndim, 2)
self.assertEqual(fixed_var.get_metadata().get("sharding"), P("layers", "embed"))

if __name__ == '__main__':
unittest.main()