NNX migration prep (2/N): NNX utils and sharding utilities #3470

Merged
copybara-service[bot] merged 1 commit into main from feat/migrate-nnx-utils
Apr 17, 2026

@ecnal-cienet
Collaborator

@ecnal-cienet ecnal-cienet commented Mar 20, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. [This PR] NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. ✅ NNX sharding diagnostics and bidirectional Linen↔NNX checkpoint conversion utilities. (PR #3525)
  5. ❌ NNX post-training fixes: MultimodalInput unpacking, scalar LR guard, nested NNX transform workaround.
  6. ❌ Enable NNX by default; fix unit and integration test failures.
  7. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

Note: This is the second in a series of NNX migration PRs. Pure NNX training is not yet implemented — all NNX code paths currently raise NotImplementedError. This PR only introduces the structural scaffolding needed for subsequent patches to plug in NNX logic without modifying shared infrastructure.
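
The guard pattern described in this note can be sketched as follows. This is a minimal illustration, not the code in this PR: `SketchConfig`, `train_step`, `train_step_linen`, and `train_step_nnx` are hypothetical names, and only the `pure_nnx` flag comes from the description above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for the real config; only the flag matters here."""

    pure_nnx: bool = False


def train_step_linen(config: SketchConfig) -> str:
    """Placeholder for the existing Linen path, which stays unchanged."""
    return "linen"


def train_step_nnx(config: SketchConfig) -> str:
    """Placeholder for the NNX path, which is scaffolding only at this stage."""
    raise NotImplementedError("Pure NNX training is not implemented yet.")


def train_step(config: SketchConfig) -> str:
    """Dispatch on the flag so later patches only need to fill in the NNX branch."""
    if config.pure_nnx:
        return train_step_nnx(config)
    return train_step_linen(config)


print(train_step(SketchConfig(pure_nnx=False)))
```

With this shape, shared callers go through a single entry point, and enabling the NNX branch later does not require touching them.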

  • NNX sharding utilities (maxtext_utils_nnx.py) — Functions to manipulate NNX model shardings using abstract model state: get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, and memory movement helpers (move_memory_to_host / move_memory_to_device).
  • get_abstract_state NNX path — Added get_abstract_state_nnx to maxtext_utils.py, which uses nnx.get_abstract_model to return a flat nnx.State (rather than a full TrainStateNNX), and updated get_abstract_state to dispatch to it when pure_nnx=True.
  • maxtext_utils.get_mesh_from_config() — Extracted mesh creation into a standalone function with unit tests.
  • Unit tests — Added tests/unit/maxtext_utils_nnx_test.py and extended tests/unit/maxtext_utils_test.py to cover the new mesh and sharding utilities.
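
As a rough illustration of the ideas in the list above (building a mesh, deriving shardings from an abstract state, and converting between named shardings and partition specs), here is a sketch in plain JAX. It does not use the helpers added in this PR and does not operate on `nnx.State`. The names `make_mesh_sketch`, `init_params`, and `specs`, the mesh axis names, and the toy parameter tree are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def make_mesh_sketch(axis_names=("data", "model")):
    """Hypothetical helper: place all visible devices on the first mesh axis."""
    devices = np.asarray(jax.devices())
    shape = (devices.size,) + (1,) * (len(axis_names) - 1)
    return Mesh(devices.reshape(shape), axis_names)


def init_params(key):
    """Toy parameter tree standing in for a model state."""
    return {
        "dense": {
            "kernel": jax.random.normal(key, (8, 4)),
            "bias": jnp.zeros((4,)),
        }
    }


mesh = make_mesh_sketch()

# Abstract state: shapes and dtypes only; no parameter arrays are materialized.
abstract_params = jax.eval_shape(init_params, jax.random.PRNGKey(0))

# Illustrative partition specs, one per leaf (an assumption for this sketch).
specs = {"dense": {"kernel": P("data", None), "bias": P()}}

# Partition specs -> named shardings bound to the mesh.
named_shardings = jax.tree_util.tree_map(
    lambda spec: NamedSharding(mesh, spec),
    specs,
    is_leaf=lambda x: isinstance(x, P),
)

# Attach each sharding to the matching abstract leaf.
sharded_abstract = jax.tree_util.tree_map(
    lambda leaf, sharding: jax.ShapeDtypeStruct(
        leaf.shape, leaf.dtype, sharding=sharding
    ),
    abstract_params,
    named_shardings,
)

# Named shardings -> partition specs again.
partition_specs = jax.tree_util.tree_map(lambda s: s.spec, named_shardings)
```

Working on the abstract state keeps the sharding logic independent of device memory, which is why shardings can be inspected or rewritten before any parameter is allocated.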

Note on Flax deprecation warnings:
Flax v0.12 emits DeprecationWarning for .value access and VariableState. These are intentionally left unaddressed because post-training currently requires Flax v0.11 compatibility.
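
For readers who want quieter local test output while these warnings remain, one option is a targeted filter from the standard `warnings` module. The sketch below is not part of this PR: `legacy_value_access` is a hypothetical stand-in that emits a warning whose text mimics the `.value` deprecation message, so the example runs without Flax installed.

```python
import warnings


def legacy_value_access() -> int:
    """Hypothetical stand-in that warns the way a deprecated accessor might."""
    warnings.warn(
        "'.value' access is now deprecated.", DeprecationWarning, stacklevel=2
    )
    return 42


with warnings.catch_warnings(record=True) as caught:
    # Record everything by default inside this block.
    warnings.simplefilter("always")
    # Then ignore only the one known message; later filters take precedence.
    warnings.filterwarnings(
        "ignore",
        message=r".*'\.value' access is now deprecated.*",
        category=DeprecationWarning,
    )
    result = legacy_value_access()
    # An unrelated deprecation is still recorded.
    warnings.warn("some other API is deprecated", DeprecationWarning)

remaining = [str(w.message) for w in caught]
```

Matching on the message text keeps unrelated deprecations visible, so the filter does not hide new problems.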

Tests

pytest tests/unit/maxtext_utils_nnx_test.py tests/unit/maxtext_utils_test.py -v 
pytest tests/unit/maxtext_utils_nnx_test.py -v

Pre-train Test Result

View Result

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 97.14286% with 1 line in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/utils/maxtext_utils.py | 94.11% | 0 Missing and 1 partial ⚠️ |


@ecnal-cienet ecnal-cienet changed the title from "Feat/migrate nnx utils" to "NNX migration prep (1/N): Migrate MaxText Utils" Mar 20, 2026
@ecnal-cienet ecnal-cienet changed the title from "NNX migration prep (1/N): Migrate MaxText Utils" to "NNX migration prep (2/N): Migrate MaxText Utils" Mar 20, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/migrate-nnx-utils branch 2 times, most recently from 4fc37b6 to 722386f, March 21, 2026 00:57
@ecnal-cienet ecnal-cienet changed the title from "NNX migration prep (2/N): Migrate MaxText Utils" to "NNX migration prep (2/N): NNX utils and sharding utilities" Mar 21, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/migrate-nnx-utils branch 4 times, most recently from d8dd362 to b013c20, March 26, 2026 17:07
@ecnal-cienet ecnal-cienet force-pushed the feat/migrate-nnx-utils branch 3 times, most recently from 954fded to 9b2900b, March 31, 2026 13:58
@ecnal-cienet ecnal-cienet force-pushed the feat/migrate-nnx-utils branch 3 times, most recently from 7c4588a to 546580f, April 6, 2026 14:51
@ecnal-cienet ecnal-cienet marked this pull request as ready for review April 6, 2026 14:52
Comment thread src/maxtext/utils/model_creation_utils.py Outdated
Comment thread src/maxtext/utils/maxtext_utils_nnx.py Outdated
- Add utils to manipulate the NNX shardings with abstract state of a
  model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  - also add unit tests for this func

Note:
flax v0.12 has DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed, this is just
    an alias to 'Variable'. Plase use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change code for these warnings.
@ecnal-cienet ecnal-cienet force-pushed the feat/migrate-nnx-utils branch from dc4d588 to 6a0f895, April 16, 2026 21:42
@copybara-service copybara-service Bot merged commit a049f9a into main Apr 17, 2026
49 checks passed
@copybara-service copybara-service Bot deleted the feat/migrate-nnx-utils branch April 17, 2026 20:32
4 participants