From 21b80d39b3e52efd47fb89ec5a0afbee37088910 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jun 2026 20:14:04 +0000 Subject: [PATCH] Reshard loss under explicit mesh on the NNX path loss_fn's NNX branch constrained xent/z_loss with raw nn.with_logical_constraint, which keeps the size-1 context axis, while the model shards activations via create_sharding (remove_size_one_mesh_axis drops context). Under an explicit mesh the mismatch is a hard assert. Use sharding.maybe_shard_with_logical, as the Linen branch already does: it builds the sharding via create_sharding (dropping size-1 context, so it matches the array) and reshards instead of asserting under explicit mode. --- src/maxtext/trainers/pre_train/train.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..e6f73928a0 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -221,8 +221,20 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) - xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length")) - z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length")) + xent = sharding.maybe_shard_with_logical( + xent, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) + z_loss = sharding.maybe_shard_with_logical( + z_loss, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) # Mask out paddings at the end of each example. xent = xent * (data["targets_segmentation"] != 0)