diff --git a/src/dependencies/requirements/base_requirements/requirements.txt b/src/dependencies/requirements/base_requirements/requirements.txt index 5ba8ee5093..85fd71c938 100644 --- a/src/dependencies/requirements/base_requirements/requirements.txt +++ b/src/dependencies/requirements/base_requirements/requirements.txt @@ -45,4 +45,4 @@ tiktoken tokamax!=0.1.0 transformers>=5.8.0 uvloop -qwix>=0.1.6 +qwix>=0.1.8 diff --git a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt index a436b1e86d..10475fec41 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt @@ -2,7 +2,7 @@ # See https://maxtext.readthedocs.io/en/latest/development/update_dependencies.html for details. absl-py>=2.4.0 -accelerate>=1.13.0 +accelerate>=1.14.0 aiofiles>=25.1.0 aiohappyeyeballs>=2.6.2 aiohttp>=3.14.1 @@ -10,9 +10,9 @@ aiohttp-cors>=0.1.0 aiosignal>=1.4.0 annotated-doc>=0.0.4 annotated-types>=0.7.0 -anthropic>=0.109.1 +anthropic>=0.111.0 antlr4-python3-runtime>=4.9.3 -anyio>=4.13.0 +anyio>=4.14.0 apache-tvm-ffi>=0.1.12 appnope>=0.1.4 ; sys_platform == 'darwin' aqtp>=0.9.0 @@ -24,7 +24,7 @@ astunparse>=1.6.3 attrs>=26.1.0 auditwheel>=6.6.0 black>=25.12.0 -blake3>=1.0.8 +blake3>=1.0.9 boto>=2.35.0 boto3>=1.34.0 botocore>=1.34.63 @@ -36,7 +36,7 @@ cffi>=2.0.0 ; implementation_name == 'pypy' or platform_python_implementation != cfgv>=3.5.0 charset-normalizer>=3.4.7 cheroot>=11.1.2 -chex>=0.1.91 +chex>=0.1.92 click>=8.4.1 cloud-accelerator-diagnostics>=0.1.1 cloudpickle>=3.1.2 @@ -47,7 +47,7 @@ colorful>=0.0.0 comm>=0.2.3 compressed-tensors>=0.17.0 contourpy>=1.3.3 -cryptography>=48.0.1 +cryptography>=49.0.0 cycler>=0.12.1 dataclasses>=0.5 dataclasses-json>=0.0.1 @@ -57,12 +57,12 @@ decorator>=5.3.1 depyf>=0.20.0 dill>=0.4.1 diskcache>=5.6.3 -distlib>=0.4.2 +distlib>=0.4.3 distro>=1.9.0 dm-tree>=0.1.10 dnspython>=2.0.0 docstring-parser>=0.18.0 -drjax>=0.1.4 +drjax>=0.2.0 editdistance>=0.8.1 einops>=0.8.2 einshape>=1.0 @@ -71,7 +71,7 @@ entrypoints>=0.4 etils>=1.14.0 execnet>=2.1.2 executing>=2.2.1 -fastapi>=0.136.3 +fastapi>=0.138.0 fastapi-cli>=0.0.8 fastapi-cloud-cli>=0.1.2 fastar>=0.9.0 @@ -88,28 +88,28 @@ gepa>=0.1.1 gguf>=0.19.0 google-api-core>=2.31.0 google-api-python-client>=2.197.0 -google-auth>=2.53.0 +google-auth>=2.55.0 google-auth-httplib2>=0.4.0 google-auth-oauthlib>=1.4.0 -google-cloud-aiplatform>=1.157.0 +google-cloud-aiplatform>=1.158.0 google-cloud-appengine-logging>=1.10.0 google-cloud-audit-log>=0.6.0 -google-cloud-bigquery>=3.41.0 +google-cloud-bigquery>=3.42.1 google-cloud-core>=2.6.0 google-cloud-logging>=3.16.0 google-cloud-mldiagnostics>=1.0.3 google-cloud-monitoring>=2.31.0 -google-cloud-resource-manager>=1.17.0 -google-cloud-storage>=3.11.0 +google-cloud-resource-manager>=1.18.0 +google-cloud-storage>=3.12.0 google-cloud-storage-control>=1.12.0 google-crc32c>=1.8.0 -google-genai>=2.8.0 +google-genai>=2.9.0 google-metrax>=0.2.3 google-pasta>=0.2.0 google-resumable-media>=2.10.0 google-tunix>=0.1.3 googleapis-common-protos>=1.75.0 -grain>=0.2.17 +grain>=0.2.18 grpc-google-iam-v1>=0.14.4 grpcio>=1.80.0 grpcio-status>=1.80.0 @@ -124,7 +124,7 @@ httplib2>=0.31.2 httptools>=0.8.0 httpx>=0.28.1 httpx-sse>=0.1.0 -huggingface-hub>=1.18.0 +huggingface-hub>=1.20.1 humanize>=4.15.0 hypothesis>=6.151.9 identify>=2.6.19 @@ -140,9 +140,9 @@ ipython-pygments-lexers>=1.1.1 ipywidgets>=8.1.8 isort>=8.0.1 jaraco-functools>=4.5.0 -jax>=0.10.1 +jax>=0.10.2 jaxlib>=0.10.1 -jaxtyping>=0.3.10 +jaxtyping>=0.3.11 jedi>=0.20.0 jinja2>=3.1.6 jiter>=0.15.0 @@ -154,7 +154,7 @@ jupyter-client>=8.9.1 jupyter-core>=5.9.1 jupyterlab-widgets>=3.0.16 kagglehub>=1.0.2 -kagglesdk>=0.1.28 +kagglesdk>=0.1.30 keras>=3.14.1 kiwisolver>=1.5.0 lark>=1.2.2 @@ -179,11 +179,11 @@ mdurl>=0.1.2 mistral-common>=1.11.3 ml-collections>=1.1.0 ml-dtypes>=0.5.4 -ml-goodput-measurement>=0.0.16 -model-hosting-container-standards>=0.1.15 +ml-goodput-measurement>=0.2.0 +model-hosting-container-standards>=0.1.16 more-itertools>=11.1.0 mpmath>=1.3.0 -msgpack>=1.1.2 +msgpack>=1.2.1 msgspec>=0.21.1 multidict>=6.7.1 multiprocess>=0.70.19 @@ -200,8 +200,8 @@ numba>=0.65.1 numpy>=2.1.3 numpy-typing-compat>=20251206.2.1 oauthlib>=3.3.1 -omegaconf>=2.3.0 -openai>=2.41.1 +omegaconf>=2.3.1 +openai>=2.43.0 openai-harmony>=0.0.8 opencensus>=0.0.1 opencv-python-headless>=4.13.0.90 @@ -229,7 +229,7 @@ parameterized>=0.9.0 parso>=0.8.7 partial-json-parser>=0.2.1.1.post7 pathspec>=1.1.1 -pathwaysutils>=0.1.8 +pathwaysutils>=0.1.9 peft>=0.19.1 perfetto>=0.56.0 pexpect>=4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' @@ -239,7 +239,7 @@ pluggy>=1.6.0 portpicker>=1.6.0 pre-commit>=4.6.0 prometheus-client>=0.25.0 -prometheus-fastapi-instrumentator>=8.0.0 +prometheus-fastapi-instrumentator>=8.0.1 promise>=2.3 prompt-toolkit>=3.0.52 propcache>=0.5.2 @@ -266,16 +266,16 @@ pyelftools>=0.32 pyglove>=0.4.5 pygments>=2.20.0 pyink>=25.12.0 -pylint>=4.0.5 +pylint>=4.0.6 pynvml>=13.0.1 -pyopenssl>=26.2.0 +pyopenssl>=26.3.0 pyparsing>=3.3.2 pyproject-hooks>=1.2.0 pytest>=8.4.2 pytest-mock>=3.15.1 pytest-xdist>=3.8.0 python-dateutil>=2.9.0.post0 -python-discovery>=1.4.0 +python-discovery>=1.4.2 python-dotenv>=1.2.2 python-json-logger>=4.1.0 python-multipart>=0.0.18 @@ -283,7 +283,7 @@ pytokens>=0.4.1 pytype>=2024.10.11 pyyaml>=6.0.3 pyzmq>=27.1.0 -qwix>=0.1.6 +qwix>=0.1.8 ray>=2.55.1 referencing>=0.37.0 regex>=2026.5.9 @@ -317,7 +317,7 @@ sniffio>=1.3.1 sortedcontainers>=2.4.0 sse-starlette>=0.1.0 stack-data>=0.6.3 -starlette>=1.2.1 +starlette>=1.3.1 supervisor>=4.3.0 sympy>=1.14.0 tabulate>=0.10.0 @@ -343,9 +343,9 @@ torchax>=0.0.11 torchvision==0.26.0+cpu tornado>=6.5.7 tpu-info>=0.7.1 -tqdm>=4.68.2 +tqdm>=4.68.3 traitlets>=5.15.1 -transformers>=5.11.0 +transformers>=5.12.1 treescope>=0.1.10 triton>=3.6.0 ; sys_platform == 'linux' typeguard>=2.13.3 @@ -357,7 +357,7 @@ uritemplate>=4.2.0 urllib3>=2.7.0 uvicorn>=0.49.0 uvloop>=0.22.1 -virtualenv>=21.4.2 +virtualenv>=21.5.1 wadler-lindig>=0.1.7 watchfiles>=1.2.0 wcwidth>=0.8.1 @@ -366,8 +366,8 @@ werkzeug>=3.1.8 wheel>=0.47.0 widgetsnbextension>=4.0.15 win32-setctime>=1.2.0 ; sys_platform == 'win32' -wrapt>=2.2.1 -xgrammar>=0.2.1 +wrapt>=2.2.2 +xgrammar>=0.2.2 xprof>=2.22.3 xxhash>=3.7.0 yapf>=0.43.0 diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml index 1188e5ea84..6692fc8358 100644 --- a/src/maxtext/configs/post_train/sft.yml +++ b/src/maxtext/configs/post_train/sft.yml @@ -27,6 +27,9 @@ lora: lora_rank: 0 lora_alpha: 0.0 lora_module_path: "" + # For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size. + lora_weight_qtype: null + lora_tile_size: null # Optional path to LoRA weights to load before training. Ignored if the current run is resumed. lora_restore_path: "" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0d64347d60..807182bd11 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1301,9 +1301,19 @@ class LoRA(BaseModel): lora_module_path: str = Field( "", description=( - "Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'." + "Regex identifying target NNX modules for LoRA. " + "Example for standard models: 'decoder/layers/.*(self_attention/(query|out)|mlp/(wi_0|wo))'. " + "Example for MoE: 'decoder/scanned_blocks/layers.*/.*(MoeBlock_0|shared_experts)/(wi_0|wo)'." ), ) + lora_weight_qtype: str | None = Field( + None, + description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."), + ) + lora_tile_size: NonNegativeInt | None = Field( + None, + description=("Tile size for block-wise quantization. Typically 32 or 64."), + ) lora_restore_path: PathStr = Field( "", description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."), diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 5d7262e5a4..6810f9a86c 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -66,8 +66,7 @@ simple_layer, ) from maxtext.multimodal import utils as mm_utils -from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding -from maxtext.utils.maxtext_utils_nnx import nnx_ensure_scan_leading_axis +from maxtext.utils import max_logging, max_utils, maxtext_utils, maxtext_utils_nnx, sharding from maxtext.utils.sharding import create_sharding from maxtext.layers.pipeline import create_nnx_pipeline @@ -993,6 +992,8 @@ def _extract_matching_state(template, full): use_kv = kv_caches_stacked is not None def layer_fn(carry, scanned_vars): + # Ensure metadata rank matches the sliced values + scanned_vars = maxtext_utils_nnx.nnx_remove_scan_axis(scanned_vars, "layers") # Unpack the sliced variables for THIS layer if use_kv: @@ -1065,16 +1066,21 @@ def layer_fn(carry, scanned_vars): # inference with vLLM, parameters do not change and we don't need intermediates. return current_carry, layers, None else: - params = nnx_ensure_scan_leading_axis(params, length) - state = nnx_ensure_scan_leading_axis(state, length) + params = maxtext_utils_nnx.nnx_ensure_scan_leading_axis(params, length) + state = maxtext_utils_nnx.nnx_ensure_scan_leading_axis(state, length) final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) returned_kv_stacked = None - if scan_axis != 0: - new_params, new_rest = scanned_state.split(nnx.Param, ...) - new_params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), new_params) - scanned_state = nnx.merge_state(new_params, new_rest) + # Ensure metadata rank matches the stacked values + scanned_state = maxtext_utils_nnx.nnx_add_scan_axis(scanned_state, "layers", 0) + + if scan_axis != 0: + new_params, new_rest = scanned_state.split(nnx.Param, ...) + new_params = maxtext_utils_nnx.nnx_sync_moveaxis(new_params, 0, scan_axis) + scanned_state = nnx.merge_state(new_params, new_rest) + + returned_kv_stacked = None if dynamic_graph_init: # If graph changed, we need to merge with the new graphdef. diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index c433a0e5fb..6b4410f209 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -21,7 +21,7 @@ import re from typing import Any, Optional -from flax import nnx +from flax import nnx, linen as nn from flax.linen import partitioning as nn_partitioning from flax.training import train_state import jax @@ -35,7 +35,6 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils -from maxtext.utils import sharding from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR # NNX-only imports (train_state_nnx, model_creation_utils) are loaded lazily @@ -451,11 +450,18 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid "rank": mt_config.lora.lora_rank, "alpha": mt_config.lora.lora_alpha, "dropout": 0.0, + "weight_qtype": mt_config.lora.lora_weight_qtype, + "tile_size": mt_config.lora.lora_tile_size, } + # Distinguish between standard LoRA and QLoRA in logs + lora_type = "QLoRA" if mt_config.lora.lora_weight_qtype else "LoRA" + max_logging.log( - f"LoRA configured: module_path={lora_module_path} " - f"rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha}" + f"{lora_type} configured: rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha} " + f"qtype={mt_config.lora.lora_weight_qtype} tile_size={mt_config.lora.lora_tile_size}" ) + + max_logging.log(f"Using lora_module_path: {lora_module_path}") return qwix.LoraProvider(**lora_kwargs) @@ -553,13 +559,22 @@ def apply_lora_to_model( # Use logical_to_mesh_sharding to correctly map logical axes like 'embed' # to physical mesh axes. - dst_shardings = sharding.logical_to_mesh_sharding( - nnx.get_partition_spec(state), mesh, rules=mt_config.logical_axis_rules - ) - - from tunix.rl import reshard # pylint: disable=import-outside-toplevel + dst_shardings = nn.logical_to_mesh_sharding(nnx.get_partition_spec(state), mesh, mt_config.logical_axis_rules) + + def _safe_reshard(var, sharding_spec): + if not isinstance(var, nnx.Variable) or not isinstance(sharding_spec, jax.sharding.Sharding): + return var + val = var.get_value() + if not isinstance(val, jax.Array): + return var + # make_array_from_callback natively constructs a globally sharded array + # from the local host arrays, bypassing backend-specific device_put issues + # on both Pathways and McJAX. + resharded_val = jax.make_array_from_callback(val.shape, sharding_spec, lambda idx: val[idx]) + return var.replace(value=resharded_val) + + state = jax.tree_util.tree_map(_safe_reshard, state, dst_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) - state = reshard.reshard_pytree(state, dst_shardings) lora_model = nnx.merge(graph_def, state) _verify_lora_parameters(lora_model, mt_config) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py index 5b645b85ca..d1827edfe8 100644 --- a/src/maxtext/utils/maxtext_utils_nnx.py +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -18,7 +18,8 @@ from flax import nnx import jax -from jax.sharding import Mesh, NamedSharding +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from maxtext.utils import max_logging from maxtext.configs import pyconfig @@ -187,3 +188,90 @@ def _op(x): return x return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +# ------------------------------------------------------------------------------ +# Metadata Synchronization Helpers for NNX Variables +# ------------------------------------------------------------------------------ + + +def nnx_update_sharding_meta(variable, transform_fn): + """Generic helper to apply a list transformation to all sharding-related metadata.""" + if not (hasattr(variable, "get_metadata") and hasattr(variable, "replace")): + return variable + + meta = variable.get_metadata() + updates = {} + + for key in ["sharding", "out_sharding", "sharding_names"]: + if (val := meta.get(key)) and isinstance(val, (P, tuple, list)): + new_list = list(val) + transformed = transform_fn(new_list) + updates[key] = P(*transformed) if isinstance(val, P) else tuple(transformed) + + if updates: + return variable.replace(**updates) + return variable + + +def nnx_sync_moveaxis(tree, from_axis, to_axis): + """Moves an axis in both values and sharding metadata of nnx.Variables.""" + if from_axis == to_axis: + return tree + + def _op(x): + is_var = isinstance(x, nnx.Variable) + val = x.get_value() if is_var else x + if not hasattr(val, "shape"): + return x + + new_val = jnp.moveaxis(val, from_axis, to_axis) + if not is_var: + return new_val + + def move_fn(l): + if len(l) > max(from_axis, to_axis): + l.insert(to_axis, l.pop(from_axis)) + return l + + return nnx_update_sharding_meta(x.replace(value=new_val), move_fn) + + return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable) or hasattr(x, "shape")) + + +def nnx_remove_scan_axis(tree, name="layers"): + """Removes the given scan axis from the PartitionSpec.""" + + def _op(x): + if not isinstance(x, nnx.Variable): + return x + + def remove_fn(l): + if name in l: + l.remove(name) + while len(l) > x.get_value().ndim: + l.pop(0) + return l + + return nnx_update_sharding_meta(x, remove_fn) + + return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +def nnx_add_scan_axis(tree, name="layers", pos=0): + """Adds the given scan axis to the PartitionSpec at the specified position.""" + + def _op(x): + if not isinstance(x, nnx.Variable): + return x + + def add_fn(l): + if name not in l: + l.insert(pos, name) + while len(l) < x.get_value().ndim: + l.insert(pos, None) + return l + + return nnx_update_sharding_meta(x, add_fn) + + return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable)) diff --git a/tests/post_training/unit/lora_utils_test.py b/tests/post_training/unit/lora_utils_test.py index 4e616ff2af..4499bfb6b5 100644 --- a/tests/post_training/unit/lora_utils_test.py +++ b/tests/post_training/unit/lora_utils_test.py @@ -29,6 +29,8 @@ from maxtext.utils import lora_utils from maxtext.utils import model_creation_utils from maxtext.configs import pyconfig +from maxtext.utils import maxtext_utils +from jax.sharding import Mesh from tests.utils.test_helpers import get_test_config_path # --------------------------------------------------------------------------- @@ -104,10 +106,14 @@ def test_build_lora_provider(self): mock_config.lora.lora_module_path = "custom/path" mock_config.lora.lora_rank = 8 mock_config.lora.lora_alpha = 16.0 + mock_config.lora.lora_weight_qtype = "int8" + mock_config.lora.lora_tile_size = 32 with mock.patch("qwix.LoraProvider") as mock_provider: lora_utils._build_lora_provider(mock_config) - mock_provider.assert_called_once_with(module_path="custom/path", rank=8, alpha=16.0, dropout=0.0) + mock_provider.assert_called_once_with( + module_path="custom/path", rank=8, alpha=16.0, dropout=0.0, weight_qtype="int8", tile_size=32 + ) def test_prepare_dummy_inputs(self): """Test preparation of dummy inputs for LoRA verification.""" @@ -158,8 +164,8 @@ def test_apply_lora_to_model_adapters_loaded(self): # If we skip Qwix, it should stay False. self.assertFalse(lora_utils.is_lora_enabled(result)) - def _run_apply_lora_test(self, scan_layers: bool): - """Helper to run LoRA application test with/without scanned layers.""" + def _run_apply_lora_test(self, scan_layers: bool, weight_qtype=None, tile_size=None, mock_multihost: bool = False): + """Helper to run LoRA application test with/without scanned layers and optional QLoRA.""" # Passing nested dict as 'lora' kwarg to _make_config cfg = _make_config( lora={ @@ -167,18 +173,27 @@ def _run_apply_lora_test(self, scan_layers: bool): "lora_rank": 4, "lora_alpha": 8.0, "lora_module_path": ".*mlp/wi_.*", + "lora_weight_qtype": weight_qtype, + "lora_tile_size": tile_size, }, scan_layers=scan_layers, ) # Create a real small model using standard creation utils - model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) + model, mesh = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) # Verify model is NOT lora enabled initially self.assertFalse(lora_utils.is_lora_enabled(model)) - # Apply LoRA - lora_model = lora_utils.apply_lora_to_model(model, model.mesh, cfg) + if mock_multihost: + devices_array = maxtext_utils.create_device_mesh(cfg) + dummy_mesh = Mesh(devices_array, cfg.mesh_axes) + + # Just verify that apply_lora_to_model runs successfully with the dummy mesh + lora_model = lora_utils.apply_lora_to_model(model, dummy_mesh, cfg) + else: + # Apply LoRA + lora_model = lora_utils.apply_lora_to_model(model, mesh, cfg) # Verify we can find LoRAParam in the state _, state = nnx.split(lora_model) @@ -200,13 +215,25 @@ def _run_apply_lora_test(self, scan_layers: bool): self.assertGreater(len(jax.tree_util.tree_leaves(opt_state)), 0) def test_apply_lora_to_model_scan_layers_false(self): - """Test applying LoRA to model with scan_layers=False.""" + """Test applying standard LoRA to model with scan_layers=False.""" self._run_apply_lora_test(scan_layers=False) def test_apply_lora_to_model_scan_layers_true(self): - """Test applying LoRA to model with scan_layers=True.""" + """Test applying standard LoRA to model with scan_layers=True.""" self._run_apply_lora_test(scan_layers=True) + def test_apply_qlora_to_model_scan_layers_false(self): + """Test applying QLoRA to model with scan_layers=False.""" + self._run_apply_lora_test(scan_layers=False, weight_qtype="int8", tile_size=32) + + def test_apply_qlora_to_model_scan_layers_true(self): + """Test applying QLoRA to model with scan_layers=True.""" + self._run_apply_lora_test(scan_layers=True, weight_qtype="int8", tile_size=32) + + def test_apply_lora_multihost_mock(self): + """Test applying LoRA with a dummy mesh to trigger the multi-host reshard callback.""" + self._run_apply_lora_test(scan_layers=False, mock_multihost=True) + def test_restore_lora_from_path(self): """Test restoration of LoRA parameters from a path.""" cfg = _make_config( diff --git a/tests/utils/test_maxtext_utils_nnx.py b/tests/utils/test_maxtext_utils_nnx.py new file mode 100644 index 0000000000..f804c76d05 --- /dev/null +++ b/tests/utils/test_maxtext_utils_nnx.py @@ -0,0 +1,72 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Tests for NNX utilities. """ + +import unittest +from flax import nnx +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +from maxtext.utils import maxtext_utils_nnx # pylint: disable=no-name-in-module + + +class TestMaxtextUtilsNnx(unittest.TestCase): + """Tests for maxtext_utils_nnx.""" + + def test_nnx_sync_moveaxis(self): + val = jnp.zeros((10, 20, 30)) + sharding = P("layers", "embed", None) + + var = nnx.Param(val) + var.set_metadata(sharding=sharding, out_sharding=sharding) + + moved_var = maxtext_utils_nnx.nnx_sync_moveaxis(var, 0, 2) + + self.assertEqual(moved_var.get_value().shape, (20, 30, 10)) + self.assertEqual(moved_var.get_metadata().get("sharding"), P("embed", None, "layers")) + self.assertEqual(moved_var.get_metadata().get("out_sharding"), P("embed", None, "layers")) + + def test_nnx_remove_scan_axis(self): + val = jnp.zeros((10, 20)) + sharding = P("layers", "embed") + var = nnx.Param(val) + var.set_metadata(sharding=sharding) + + # Simulate slicing + sliced_val = val[0] + var_with_sliced_val = var.replace(value=sliced_val) + + fixed_var = maxtext_utils_nnx.nnx_remove_scan_axis(var_with_sliced_val, "layers") + + self.assertEqual(fixed_var.get_value().ndim, 1) + self.assertEqual(fixed_var.get_metadata().get("sharding"), P("embed")) + + def test_nnx_add_scan_axis(self): + val = jnp.zeros((20,)) + sharding = P("embed") + var = nnx.Param(val) + var.set_metadata(sharding=sharding) + + # Simulate stacking + stacked_val = jnp.stack([val] * 10, axis=0) + var_with_stacked_val = var.replace(value=stacked_val) + + fixed_var = maxtext_utils_nnx.nnx_add_scan_axis(var_with_stacked_val, "layers", 0) + + self.assertEqual(fixed_var.get_value().ndim, 2) + self.assertEqual(fixed_var.get_metadata().get("sharding"), P("layers", "embed")) + + +if __name__ == "__main__": + unittest.main()