Fix logical sharding resolution in NNX #4205

Open
xibinliu wants to merge 1 commit into main from xibin/nnx_sharding

Conversation

@xibinliu

Collaborator

Description

In pure NNX training runs, model variables retrieve physical PartitionSpecs via `get_nnx_named_sharding_with_scan_axis` in `maxtext_utils.py`. Previously, this helper used Flax core SPMD's `from_sharding_rules` to map logical names to physical axes. However, `from_sharding_rules` converts the rules list into a dictionary, so when several rules share a logical name the last one wins. As a result, fallback rules sharing a logical name (e.g. `'embed'`) overwrote the more specific rules listed before them, dropping essential axes like `fsdp_transpose` and triggering assertion errors on the percentage of unsharded parameters.
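A minimal sketch of the last-write-wins effect described above, in plain Python. The two rules are hypothetical and simplified to `(logical_name, physical_axis)` pairs; they are not MaxText's actual rule table, and the snippet does not call `from_sharding_rules` itself.

```python
# Hypothetical rules: a specific rule first, then a fallback for the same name.
rules = [
    ("embed", "fsdp_transpose"),  # specific rule, listed first
    ("embed", None),              # fallback rule, listed later
]

# Converting the list to a dict keeps only the LAST entry per logical name.
as_dict = dict(rules)

# Scanning the list in order keeps the FIRST entry per logical name.
first_match = next(axis for name, axis in rules if name == "embed")

print("dict lookup:", as_dict["embed"])
print("first match:", first_match)
```

With the dictionary lookup the `fsdp_transpose` axis is lost, which is the failure mode this PR addresses.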

Additionally, resolving the specification for each dimension independently, without tracking which axes were already assigned, could bind a single physical axis (like `fsdp_transpose`) to multiple positional dimensions of a tensor, causing `DuplicateSpecError`.
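The duplicate binding can be illustrated with a small sketch, assuming hypothetical rules in which two logical names map to the same physical axis. `resolve_independently` is an illustrative helper, not a function from MaxText or Flax.

```python
from collections import Counter

# Hypothetical rules: both logical names point at the same physical axis.
rules = [
    ("embed", "fsdp_transpose"),
    ("mlp", "fsdp_transpose"),
]


def resolve_independently(logical_axes, rules):
    """Resolve each dimension on its own, ignoring axes already used."""
    spec = []
    for name in logical_axes:
        axis = next((a for n, a in rules if n == name), None)
        spec.append(axis)
    return tuple(spec)


naive_spec = resolve_independently(("embed", "mlp"), rules)

# Physical axes that appear in more than one positional dimension.
duplicated = [a for a, c in Counter(naive_spec).items() if a is not None and c > 1]

print(naive_spec, duplicated)
```

A spec like this reuses one mesh axis in two positions, which is the condition the description associates with `DuplicateSpecError`.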

To fix this:

  1. Replaced `from_sharding_rules` with a rules-first resolution loop that matches rules sequentially (first-match-wins), matching Flax Linen's mapping behavior.
  2. Implemented an `assigned_axes` tracker within the loop to ensure physical mesh axes are bound to at most one dimension per tensor.
  3. Added unit tests covering sequential matching (first-match-wins) and duplicate physical axis prevention during resolution.
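The approach in the list above could look roughly like the following sketch. It is not the code from this PR: `resolve_logical_axes`, the rule format, and the example rules are assumptions made for illustration, and only the `assigned_axes` name is taken from the description.

```python
_UNASSIGNED = object()  # sentinel: dimension not yet resolved


def resolve_logical_axes(logical_axes, rules):
    """Rules-first, first-match-wins resolution with an assigned-axes tracker.

    `rules` is an ordered list of (logical_name, physical) pairs, where
    `physical` is a mesh axis name, a tuple of names, or None (assumed format).
    """
    result = [_UNASSIGNED] * len(logical_axes)
    assigned_axes = set()  # physical axes already bound to some dimension

    for rule_name, physical in rules:  # iterate rules in order
        if rule_name not in logical_axes:
            continue
        pos = logical_axes.index(rule_name)
        if result[pos] is not _UNASSIGNED:
            continue  # an earlier rule already won for this dimension
        if physical is None:
            axes = ()
        elif isinstance(physical, str):
            axes = (physical,)
        else:
            axes = tuple(physical)
        if any(a in assigned_axes for a in axes):
            continue  # axis already used; let a later rule try
        result[pos] = physical
        assigned_axes.update(axes)

    # Dimensions that no rule resolved stay unsharded (None).
    return tuple(None if r is _UNASSIGNED else r for r in result)


# Hypothetical rules, ordered from most specific to fallback.
example_rules = [
    ("embed", "fsdp_transpose"),
    ("mlp", "fsdp_transpose"),
    ("mlp", "tensor"),
    ("embed", None),
]
print(resolve_logical_axes(("embed", "mlp"), example_rules))
```

In this sketch the specific `embed` rule is kept because it appears first, and `mlp` falls through to `tensor` because `fsdp_transpose` is already assigned.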

Tests

Log with Gemma3-12B (2x v6e-256)

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 19, 2026


Codecov Report

❌ Patch coverage is 78.78788% with 7 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/maxtext_utils.py | 78.78% | 4 Missing and 3 partials ⚠️ |


@xibinliu force-pushed the xibin/nnx_sharding branch 2 times, most recently from a9aedd5 to e044d17 (June 19, 2026 00:47)
