Fix explicit-mesh sharding assert in deepseek batch-split scan #4208
Draft
ecnal-cienet wants to merge 1 commit into
loss_fn's NNX branch constrained xent/z_loss with raw nn.with_logical_constraint, which keeps the size-1 context axis, while the model shards activations via create_sharding (remove_size_one_mesh_axis drops context). Under an explicit mesh the mismatch is a hard assert. Use sharding.maybe_shard_with_logical, as the Linen branch already does: it builds the sharding via create_sharding (dropping size-1 context, so it matches the array) and reshards instead of asserting under explicit mode.
1beca40 to 21b80d3
Problem
Training `deepseek3-671b-batchsplit` (which sets `shard_mode=explicit` and `use_batch_split_schedule=true`) crashes at the first `train_step` compile:
Root cause
Under JAX's explicit mesh axes, `with_logical_constraint` (used in `loss_fn`) is a hard assert that the array's sharding matches exactly; it no longer silently reshards as it did under `auto` axes. `scan_batch_split_layers` reshards activations internally to `P(('data','fsdp','expert'), None, None)`, deliberately dropping the `context` axis (size 1 here) for the manual split/merge collectives, and returns the hidden states still sharded that way. `loss_fn` then constrains `xent`/`z_loss` via `activation_embed_and_logits_batch`, which maps to `('data','fsdp','expert','context')`. The two specs differ by the `context` axis, so the assert fires.
Fix
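Before the change itself, the mismatch described under Root cause can be modeled without JAX. This is a toy sketch, not MaxText code: plain tuples stand in for `PartitionSpec` entries, the mesh sizes other than `context == 1` are made up, both specs are given the same rank for illustration, and `drop_size_one_axes` and `shards_per_dim` are hypothetical helpers.

```python
# Toy model: tuples stand in for PartitionSpec entries.
# Mesh sizes are hypothetical, except that context has size 1 as in the PR.
MESH_SIZES = {"data": 2, "fsdp": 4, "expert": 2, "context": 1}


def _axes(dim):
    """Normalize one spec entry (None, a name, or a tuple of names) to a tuple."""
    if dim is None:
        return ()
    return (dim,) if isinstance(dim, str) else tuple(dim)


def shards_per_dim(spec, mesh_sizes):
    """Number of shards along each array dimension implied by a spec."""
    counts = []
    for dim in spec:
        n = 1
        for axis in _axes(dim):
            n *= mesh_sizes[axis]
        counts.append(n)
    return tuple(counts)


def drop_size_one_axes(spec, mesh_sizes):
    """Hypothetical helper: remove mesh axes of size 1 from every entry."""
    out = []
    for dim in spec:
        kept = tuple(a for a in _axes(dim) if mesh_sizes[a] > 1)
        out.append(kept or None)
    return tuple(out)


# Spec the batch-split scan leaves on its output (context dropped).
returned_by_scan = (("data", "fsdp", "expert"), None, None)
# Spec the downstream constraint expects (context kept); rank is illustrative.
expected_by_loss = (("data", "fsdp", "expert", "context"), None, None)

# An exact comparison sees two different specs ...
specs_equal = returned_by_scan == expected_by_loss
# ... even though both imply the same number of shards per dimension.
same_layout = (
    shards_per_dim(returned_by_scan, MESH_SIZES)
    == shards_per_dim(expected_by_loss, MESH_SIZES)
)
# Normalizing away size-1 axes makes the two specs line up.
normalized_equal = (
    drop_size_one_axes(returned_by_scan, MESH_SIZES)
    == drop_size_one_axes(expected_by_loss, MESH_SIZES)
)
```

In this model an exact-match check rejects the pair even though the shard counts agree, which is the situation the explicit-mode assert reports.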
Capture the incoming sharding on entry and `jax.reshard` the output back to it before returning, so the batch-split path is transparent to downstream constraints. This mirrors what the per-layer `DeepSeekDecoderLayer.__call__` path already does (`input_sharding = jax.typeof(inputs).sharding` → `jax.reshard(outputs, input_sharding)`); `scan_batch_split_layers` was simply missing the restore step. Since `context` has size 1 in this config, this is the same physical layout; only the spec is made to line up for the explicit mesh.
Tests
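Separately from the training runs listed here, the capture-and-restore step from the Fix section can be modeled in a few lines. This is a toy sketch under stated assumptions, not the PR's diff: `ToyArray`, `toy_reshard`, `strict_constraint`, and the two `scan_*` functions are hypothetical stand-ins that track only a spec tuple, with `toy_reshard` playing the role of `jax.reshard` and `strict_constraint` playing the role of an explicit-mode constraint.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyArray:
    """Stand-in for an activation; only the sharding spec is tracked."""
    name: str
    spec: tuple


def toy_reshard(x, spec):
    """Toy analogue of resharding: return x carrying the requested spec."""
    return replace(x, spec=spec)


def strict_constraint(x, spec):
    """Toy analogue of an explicit-mode constraint: a mismatch is an error."""
    if x.spec != spec:
        raise AssertionError(f"{x.name}: {x.spec} != {spec}")
    return x


# Internal layout used for the split/merge step (context axis dropped).
INNER_SPEC = (("data", "fsdp", "expert"), None, None)


def scan_without_restore(x):
    """Returns the output still carrying the internal spec."""
    return toy_reshard(x, INNER_SPEC)


def scan_with_restore(x):
    """Captures the incoming spec and restores it before returning."""
    input_spec = x.spec                  # capture on entry
    y = toy_reshard(x, INNER_SPEC)       # internal layout for the collectives
    return toy_reshard(y, input_spec)    # restore before returning


outer_spec = (("data", "fsdp", "expert", "context"), None, None)
hidden = ToyArray("hidden", outer_spec)

# With the restore step, the downstream constraint is satisfied.
fixed = strict_constraint(scan_with_restore(hidden), outer_spec)

# Without it, the same constraint raises.
try:
    strict_constraint(scan_without_restore(hidden), outer_spec)
    unfixed_failed = False
except AssertionError:
    unfixed_failed = True
```

The point of the pattern is that callers never observe the internal spec, so downstream constraints do not need to know which path produced the array.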
Before Fix (Linen, passed):
After Fix (Linen, passed):
Before Fix (NNX, failed): https://cloudlogging.app.goo.gl/XF3eVFANEPx43XBG7
After Fix (NNX, passed):
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.