diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..e6f73928a0 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -221,8 +221,20 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) - xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length")) - z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length")) + xent = sharding.maybe_shard_with_logical( + xent, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) + z_loss = sharding.maybe_shard_with_logical( + z_loss, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) # Mask out paddings at the end of each example. xent = xent * (data["targets_segmentation"] != 0)