Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ tiktoken
tokamax!=0.1.0
transformers>=5.8.0
uvloop
qwix>=0.1.6
qwix>=0.1.8
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
# See https://maxtext.readthedocs.io/en/latest/development/update_dependencies.html for details.

absl-py>=2.4.0
accelerate>=1.13.0
accelerate>=1.14.0
aiofiles>=25.1.0
aiohappyeyeballs>=2.6.2
aiohttp>=3.14.1
aiohttp-cors>=0.1.0
aiosignal>=1.4.0
annotated-doc>=0.0.4
annotated-types>=0.7.0
anthropic>=0.109.1
anthropic>=0.111.0
antlr4-python3-runtime>=4.9.3
anyio>=4.13.0
anyio>=4.14.0
apache-tvm-ffi>=0.1.12
appnope>=0.1.4 ; sys_platform == 'darwin'
aqtp>=0.9.0
Expand All @@ -24,7 +24,7 @@ astunparse>=1.6.3
attrs>=26.1.0
auditwheel>=6.6.0
black>=25.12.0
blake3>=1.0.8
blake3>=1.0.9
boto>=2.35.0
boto3>=1.34.0
botocore>=1.34.63
Expand All @@ -36,7 +36,7 @@ cffi>=2.0.0 ; implementation_name == 'pypy' or platform_python_implementation !=
cfgv>=3.5.0
charset-normalizer>=3.4.7
cheroot>=11.1.2
chex>=0.1.91
chex>=0.1.92
click>=8.4.1
cloud-accelerator-diagnostics>=0.1.1
cloudpickle>=3.1.2
Expand All @@ -47,7 +47,7 @@ colorful>=0.0.0
comm>=0.2.3
compressed-tensors>=0.17.0
contourpy>=1.3.3
cryptography>=48.0.1
cryptography>=49.0.0
cycler>=0.12.1
dataclasses>=0.5
dataclasses-json>=0.0.1
Expand All @@ -57,12 +57,12 @@ decorator>=5.3.1
depyf>=0.20.0
dill>=0.4.1
diskcache>=5.6.3
distlib>=0.4.2
distlib>=0.4.3
distro>=1.9.0
dm-tree>=0.1.10
dnspython>=2.0.0
docstring-parser>=0.18.0
drjax>=0.1.4
drjax>=0.2.0
editdistance>=0.8.1
einops>=0.8.2
einshape>=1.0
Expand All @@ -71,7 +71,7 @@ entrypoints>=0.4
etils>=1.14.0
execnet>=2.1.2
executing>=2.2.1
fastapi>=0.136.3
fastapi>=0.138.0
fastapi-cli>=0.0.8
fastapi-cloud-cli>=0.1.2
fastar>=0.9.0
Expand All @@ -88,28 +88,28 @@ gepa>=0.1.1
gguf>=0.19.0
google-api-core>=2.31.0
google-api-python-client>=2.197.0
google-auth>=2.53.0
google-auth>=2.55.0
google-auth-httplib2>=0.4.0
google-auth-oauthlib>=1.4.0
google-cloud-aiplatform>=1.157.0
google-cloud-aiplatform>=1.158.0
google-cloud-appengine-logging>=1.10.0
google-cloud-audit-log>=0.6.0
google-cloud-bigquery>=3.41.0
google-cloud-bigquery>=3.42.1
google-cloud-core>=2.6.0
google-cloud-logging>=3.16.0
google-cloud-mldiagnostics>=1.0.3
google-cloud-monitoring>=2.31.0
google-cloud-resource-manager>=1.17.0
google-cloud-storage>=3.11.0
google-cloud-resource-manager>=1.18.0
google-cloud-storage>=3.12.0
google-cloud-storage-control>=1.12.0
google-crc32c>=1.8.0
google-genai>=2.8.0
google-genai>=2.9.0
google-metrax>=0.2.3
google-pasta>=0.2.0
google-resumable-media>=2.10.0
google-tunix>=0.1.3
googleapis-common-protos>=1.75.0
grain>=0.2.17
grain>=0.2.18
grpc-google-iam-v1>=0.14.4
grpcio>=1.80.0
grpcio-status>=1.80.0
Expand All @@ -124,7 +124,7 @@ httplib2>=0.31.2
httptools>=0.8.0
httpx>=0.28.1
httpx-sse>=0.1.0
huggingface-hub>=1.18.0
huggingface-hub>=1.20.1
humanize>=4.15.0
hypothesis>=6.151.9
identify>=2.6.19
Expand All @@ -140,9 +140,9 @@ ipython-pygments-lexers>=1.1.1
ipywidgets>=8.1.8
isort>=8.0.1
jaraco-functools>=4.5.0
jax>=0.10.1
jax>=0.10.2
jaxlib>=0.10.1
jaxtyping>=0.3.10
jaxtyping>=0.3.11
jedi>=0.20.0
jinja2>=3.1.6
jiter>=0.15.0
Expand All @@ -154,7 +154,7 @@ jupyter-client>=8.9.1
jupyter-core>=5.9.1
jupyterlab-widgets>=3.0.16
kagglehub>=1.0.2
kagglesdk>=0.1.28
kagglesdk>=0.1.30
keras>=3.14.1
kiwisolver>=1.5.0
lark>=1.2.2
Expand All @@ -179,11 +179,11 @@ mdurl>=0.1.2
mistral-common>=1.11.3
ml-collections>=1.1.0
ml-dtypes>=0.5.4
ml-goodput-measurement>=0.0.16
model-hosting-container-standards>=0.1.15
ml-goodput-measurement>=0.2.0
model-hosting-container-standards>=0.1.16
more-itertools>=11.1.0
mpmath>=1.3.0
msgpack>=1.1.2
msgpack>=1.2.1
msgspec>=0.21.1
multidict>=6.7.1
multiprocess>=0.70.19
Expand All @@ -200,8 +200,8 @@ numba>=0.65.1
numpy>=2.1.3
numpy-typing-compat>=20251206.2.1
oauthlib>=3.3.1
omegaconf>=2.3.0
openai>=2.41.1
omegaconf>=2.3.1
openai>=2.43.0
openai-harmony>=0.0.8
opencensus>=0.0.1
opencv-python-headless>=4.13.0.90
Expand Down Expand Up @@ -229,7 +229,7 @@ parameterized>=0.9.0
parso>=0.8.7
partial-json-parser>=0.2.1.1.post7
pathspec>=1.1.1
pathwaysutils>=0.1.8
pathwaysutils>=0.1.9
peft>=0.19.1
perfetto>=0.56.0
pexpect>=4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
Expand All @@ -239,7 +239,7 @@ pluggy>=1.6.0
portpicker>=1.6.0
pre-commit>=4.6.0
prometheus-client>=0.25.0
prometheus-fastapi-instrumentator>=8.0.0
prometheus-fastapi-instrumentator>=8.0.1
promise>=2.3
prompt-toolkit>=3.0.52
propcache>=0.5.2
Expand All @@ -266,24 +266,24 @@ pyelftools>=0.32
pyglove>=0.4.5
pygments>=2.20.0
pyink>=25.12.0
pylint>=4.0.5
pylint>=4.0.6
pynvml>=13.0.1
pyopenssl>=26.2.0
pyopenssl>=26.3.0
pyparsing>=3.3.2
pyproject-hooks>=1.2.0
pytest>=8.4.2
pytest-mock>=3.15.1
pytest-xdist>=3.8.0
python-dateutil>=2.9.0.post0
python-discovery>=1.4.0
python-discovery>=1.4.2
python-dotenv>=1.2.2
python-json-logger>=4.1.0
python-multipart>=0.0.18
pytokens>=0.4.1
pytype>=2024.10.11
pyyaml>=6.0.3
pyzmq>=27.1.0
qwix>=0.1.6
qwix>=0.1.8
ray>=2.55.1
referencing>=0.37.0
regex>=2026.5.9
Expand Down Expand Up @@ -317,7 +317,7 @@ sniffio>=1.3.1
sortedcontainers>=2.4.0
sse-starlette>=0.1.0
stack-data>=0.6.3
starlette>=1.2.1
starlette>=1.3.1
supervisor>=4.3.0
sympy>=1.14.0
tabulate>=0.10.0
Expand All @@ -343,9 +343,9 @@ torchax>=0.0.11
torchvision==0.26.0+cpu
tornado>=6.5.7
tpu-info>=0.7.1
tqdm>=4.68.2
tqdm>=4.68.3
traitlets>=5.15.1
transformers>=5.11.0
transformers>=5.12.1
treescope>=0.1.10
triton>=3.6.0 ; sys_platform == 'linux'
typeguard>=2.13.3
Expand All @@ -357,7 +357,7 @@ uritemplate>=4.2.0
urllib3>=2.7.0
uvicorn>=0.49.0
uvloop>=0.22.1
virtualenv>=21.4.2
virtualenv>=21.5.1
wadler-lindig>=0.1.7
watchfiles>=1.2.0
wcwidth>=0.8.1
Expand All @@ -366,8 +366,8 @@ werkzeug>=3.1.8
wheel>=0.47.0
widgetsnbextension>=4.0.15
win32-setctime>=1.2.0 ; sys_platform == 'win32'
wrapt>=2.2.1
xgrammar>=0.2.1
wrapt>=2.2.2
xgrammar>=0.2.2
xprof>=2.22.3
xxhash>=3.7.0
yapf>=0.43.0
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/post_train/sft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ lora:
lora_rank: 0
lora_alpha: 0.0
lora_module_path: ""
# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size.
lora_weight_qtype: null
lora_tile_size: null
# Optional path to LoRA weights to load before training. Ignored if the current run is resumed.
lora_restore_path: ""

Expand Down
12 changes: 11 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,9 +1301,19 @@ class LoRA(BaseModel):
lora_module_path: str = Field(
"",
description=(
"Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
"Regex identifying target NNX modules for LoRA. "
"Example for standard models: 'decoder/layers/.*(self_attention/(query|out)|mlp/(wi_0|wo))'. "
"Example for MoE: 'decoder/scanned_blocks/layers.*/.*(MoeBlock_0|shared_experts)/(wi_0|wo)'."
),
)
lora_weight_qtype: str | None = Field(
None,
description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
)
lora_tile_size: NonNegativeInt | None = Field(
None,
description=("Tile size for block-wise quantization. Typically 32 or 64."),
)
lora_restore_path: PathStr = Field(
"",
description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),
Expand Down
22 changes: 14 additions & 8 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@
simple_layer,
)
from maxtext.multimodal import utils as mm_utils
from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding
from maxtext.utils.maxtext_utils_nnx import nnx_ensure_scan_leading_axis
from maxtext.utils import max_logging, max_utils, maxtext_utils, maxtext_utils_nnx, sharding
from maxtext.utils.sharding import create_sharding
from maxtext.layers.pipeline import create_nnx_pipeline

Expand Down Expand Up @@ -993,6 +992,8 @@ def _extract_matching_state(template, full):
use_kv = kv_caches_stacked is not None

def layer_fn(carry, scanned_vars):
# Ensure metadata rank matches the sliced values
scanned_vars = maxtext_utils_nnx.nnx_remove_scan_axis(scanned_vars, "layers")

# Unpack the sliced variables for THIS layer
if use_kv:
Expand Down Expand Up @@ -1065,16 +1066,21 @@ def layer_fn(carry, scanned_vars):
# inference with vLLM, parameters do not change and we don't need intermediates.
return current_carry, layers, None
else:
params = nnx_ensure_scan_leading_axis(params, length)
state = nnx_ensure_scan_leading_axis(state, length)
params = maxtext_utils_nnx.nnx_ensure_scan_leading_axis(params, length)
state = maxtext_utils_nnx.nnx_ensure_scan_leading_axis(state, length)

final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
returned_kv_stacked = None

if scan_axis != 0:
new_params, new_rest = scanned_state.split(nnx.Param, ...)
new_params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), new_params)
scanned_state = nnx.merge_state(new_params, new_rest)
# Ensure metadata rank matches the stacked values
scanned_state = maxtext_utils_nnx.nnx_add_scan_axis(scanned_state, "layers", 0)

if scan_axis != 0:
new_params, new_rest = scanned_state.split(nnx.Param, ...)
new_params = maxtext_utils_nnx.nnx_sync_moveaxis(new_params, 0, scan_axis)
scanned_state = nnx.merge_state(new_params, new_rest)

returned_kv_stacked = None

if dynamic_graph_init:
# If graph changed, we need to merge with the new graphdef.
Expand Down
Loading
Loading