Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Sequence

from flax import nnx, linen as nn
from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules
from flax.core.spmd import get_logical_axis_rules
from flax.linen import partitioning as nn_partitioning
from flax.training.train_state import TrainState

Expand Down Expand Up @@ -1612,6 +1612,43 @@ def move(path, x):
)


def _resolve_logical_sharding(out_sharding, context_rules, local_rules) -> list:
  """Resolves logical sharding annotations into physical sharding specs.

  Rules are tried in order, variable-local rules before context rules. A rule
  is applied to the dimension holding its logical name only when that dimension
  is still unresolved and none of the rule's physical axes is already bound to
  another dimension. This keeps every physical mesh axis bound to at most one
  dimension per tensor, preventing JAX DuplicateSpecError. A rule that is
  skipped because of an axis conflict leaves its dimension open, so a later
  rule for the same logical name can still resolve it.

  Args:
    out_sharding: Iterable with one entry per tensor dimension (for example
      logical axis names or None). It is materialized once, so one-shot
      iterables are accepted.
    context_rules: Iterable of `(logical_name, physical)` pairs, or None.
      `physical` is None (replicate), a single axis name, or an iterable of
      axis names.
    local_rules: Same format as `context_rules`, or None. These are tried
      first and therefore take precedence.

  Returns:
    A new list with one entry per dimension. Resolved entries hold the matching
    rule's `physical` value as given; entries no rule resolved are returned
    unchanged.
  """
  merged_rules = [
      *(local_rules if local_rules is not None else ()),
      *(context_rules if context_rules is not None else ()),
  ]

  # Snapshot of the original entries: lookups must keep working for any
  # iterable input and must not be affected by the writes made to `resolved`.
  dim_names = list(out_sharding)
  resolved = list(dim_names)
  assigned_positions = set()
  assigned_axes = set()

  for rule_logical, rule_physical in merged_rules:
    if rule_logical not in dim_names:
      continue
    # Only the first dimension carrying this logical name is considered.
    pos = dim_names.index(rule_logical)
    if pos in assigned_positions:
      continue

    if rule_physical is None:
      # Explicit "replicate" rule: settles the dimension without using an axis.
      resolved[pos] = None
      assigned_positions.add(pos)
      continue

    physical_axes = [rule_physical] if isinstance(rule_physical, str) else list(rule_physical)
    if any(axis in assigned_axes for axis in physical_axes):
      # Some axis is already bound elsewhere; leave the dimension for a later rule.
      continue

    resolved[pos] = rule_physical
    assigned_positions.add(pos)
    assigned_axes.update(physical_axes)

  return resolved


def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State:
"""Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis.

Expand Down Expand Up @@ -1669,20 +1706,22 @@ def _make_named_sharding(v):
context_rules = get_logical_axis_rules()
local_rules = metadata.get("sharding_rules", ())
if context_rules or local_rules:
rules = composite_rules(context_rules, local_rules)
raw_sharding = from_sharding_rules(out_sharding, rules)
raw_sharding = _resolve_logical_sharding(out_sharding, context_rules, local_rules)
mesh_axis_names = mesh.axis_names if mesh is not None else ()

# from_sharding_rules leaves a logical name with no matching rule unchanged, so a
# name missing from logical_axis_rules (e.g. concat_embed on the MTP kernel)
# reaches NamedSharding and is rejected as an unknown mesh axis. Map any such
# leftover name to None (replicated), matching Linen, whose logical_to_mesh_axes
# replicates unmatched names.
# Map unmatched logical names to None (replicated), matching Linen's behavior.
# Also clean up tuples to only keep physical axes present in the active mesh.
def _sanitize(x):
  """Reduce one resolved sharding entry to axes present in the active mesh.

  Reads `mesh_axis_names` from the enclosing scope. Lists are treated as
  tuples. A string is kept only if it names a mesh axis; a tuple is filtered
  down to the mesh axes it contains. Anything that ends up empty or is of
  another type becomes None.
  """
  if isinstance(x, list):
    x = tuple(x)
  if x is None:
    return None
  if isinstance(x, str):
    # An unmatched logical name is not a mesh axis, so it maps to None.
    return x if x in mesh_axis_names else None
  if isinstance(x, tuple):
    # Only keep axes that actually exist in the physical mesh.
    sanitized_tuple = tuple(i for i in x if i in mesh_axis_names)
    return sanitized_tuple if sanitized_tuple else None
  return None

sanitized_sharding = [_sanitize(x) for x in raw_sharding]
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/maxtext_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import optax

from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax import nnx
from flax.core.scope import FrozenVariableDict
from flax.training import train_state
Expand Down Expand Up @@ -1690,6 +1691,60 @@ def test_string_out_sharding_is_wrapped_into_tuple(self):
# The single string 'fsdp' is turned into a list, and 'layers' is prepended.
self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp"))

def test_sequential_matching_first_match_wins(self):
  """When one logical axis has several rules, the earliest rule is the one applied."""
  # Two candidate rules for 'embed': 'fsdp' is listed first, 'layers' is the fallback.
  embed_rules = (("embed", "fsdp"), ("embed", "layers"))
  with nn_partitioning.axis_rules(embed_rules), jax.set_mesh(self.mesh):
    param = nnx.Param(jnp.zeros((3,)), out_sharding=("embed",))
    sharded_state = self._run(self._build_state(w=param))
  named_sharding = sharded_state["w"].get_value()
  # The first rule ('fsdp') must win over the later 'layers' rule.
  expected_spec = PartitionSpec("fsdp")
  self.assertEqual(named_sharding.spec, expected_spec)

def test_deduplicates_assigned_physical_axes(self):
  """A physical axis bound to one dimension must not be bound to a second one."""
  # 'embed' comes first and claims both 'fsdp' and 'layers'. The 'mlp' rule
  # then asks for 'fsdp' again, which is already taken, so 'mlp' stays
  # unresolved and ends up replicated.
  conflicting_rules = (
      ("embed", ("fsdp", "layers")),
      ("mlp", "fsdp"),
  )
  with nn_partitioning.axis_rules(conflicting_rules), jax.set_mesh(self.mesh):
    param = nnx.Param(jnp.zeros((3, 4)), out_sharding=("embed", "mlp"))
    sharded_state = self._run(self._build_state(w=param))
  named_sharding = sharded_state["w"].get_value()
  # Dimension 0 keeps ('fsdp', 'layers'); dimension 1 falls back to None.
  expected_spec = PartitionSpec(("fsdp", "layers"), None)
  self.assertEqual(named_sharding.spec, expected_spec)

def test_resolves_when_context_rules_is_none(self):
  """Rules attached to the variable alone are enough to resolve its sharding."""
  # No axis_rules context is entered here, so only the rules stored in the
  # variable's own metadata are available for resolution.
  variable_rules = (("embed", "fsdp"),)
  with jax.set_mesh(self.mesh):
    param = nnx.Param(
        jnp.zeros((3,)),
        out_sharding=("embed",),
        sharding_rules=variable_rules,
    )
    sharded_state = self._run(self._build_state(w=param))
  named_sharding = sharded_state["w"].get_value()
  # The local rule must map 'embed' to 'fsdp' without any context rules.
  expected_spec = PartitionSpec("fsdp")
  self.assertEqual(named_sharding.spec, expected_spec)


if __name__ == "__main__":
unittest.main()
Loading