
[ROCm]: fix: JAX/TE sharding compatibility and tmem reduction foundations (PR1) #4191

Open
cj401-amd wants to merge 5 commits into AI-Hypercomputer:main from cj401-amd:cj/tmem-fixes-clean-1-jax-sharding-compat


@cj401-amd

Collaborator

Summary

  • Sharding: filter logical_axis_rules to only include axes present in the mesh,
    preventing crashes when fsdp_transpose or other axes are absent
  • Sharding: add skip_trivial_specs parameter to maybe_shard_with_logical to
    skip no-op resharding constraints (all-None PartitionSpecs), reducing XLA overhead
  • RMSNorm: use reshard() for explicit shard mode; replace jnp.einsum scale
    application with direct multiply to avoid unnecessary XLA ops
  • Attention (TE): for synthetic data, use mask_type="causal" directly instead of
    materializing the full [seq, seq] attention mask; this avoids ~5 GiB of temporary
    memory caused by XLA loop_broadcast_fusion hoisting the mask into the pipeline scan carry
  • Attention (TE): remove deprecated scale_factor and context_parallel_strategy
    params from DotProductAttention for newer TransformerEngine compatibility
  • Train step: skip identity grad_dtype cast when grad_dtype == float32;
    set flax_always_shard_variable=False
  • train_compile.py: handle serialize() API change (tuple vs bytes return type);
    pyink formatting
  • Config: add pipeline_save_decoder_layer_input flag (used by PR 2)
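
The two sharding changes in the summary can be illustrated with a minimal pure-Python sketch. It is not the code from this PR: `filter_rules_to_mesh` and `is_trivial_spec` are hypothetical helper names, plain tuples stand in for `PartitionSpec` objects, and dropping a rule whose mesh axes are all absent is an assumption about the intended behavior.

```python
def filter_rules_to_mesh(rules, mesh_axis_names):
    """Keep only mesh axes that exist in the mesh (hypothetical helper).

    `rules` is a list of (logical_axis_name, mesh_axes) pairs, where
    mesh_axes is None, a single axis name, or a sequence of axis names.
    """
    present = set(mesh_axis_names)
    filtered = []
    for logical_name, mesh_axes in rules:
        if mesh_axes is None:
            # An explicit "replicate" rule references no mesh axis; keep it.
            filtered.append((logical_name, None))
        elif isinstance(mesh_axes, str):
            # Single axis: keep the rule only if the mesh has that axis.
            if mesh_axes in present:
                filtered.append((logical_name, mesh_axes))
        else:
            # Several axes: keep the ones the mesh has, drop the rule if none remain
            # (assumption: an empty rule is not useful).
            kept = tuple(axis for axis in mesh_axes if axis in present)
            if kept:
                filtered.append((logical_name, kept))
    return filtered


def is_trivial_spec(spec):
    """True when every entry is None, i.e. the constraint would be a no-op."""
    return all(entry is None for entry in spec)


# Example: the mesh has no "fsdp_transpose" axis.
rules = [
    ("activation_batch", ("data", "fsdp", "fsdp_transpose")),
    ("embed", "fsdp_transpose"),
    ("mlp", "tensor"),
    ("norm", None),
]
filtered = filter_rules_to_mesh(rules, mesh_axis_names=("data", "fsdp", "tensor"))
```

In this sketch, a caller honoring skip_trivial_specs would return its input unchanged when `is_trivial_spec` is true, rather than emitting a sharding constraint.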

Test plan

  • python3 -m pytest tests/unit/train_compile_test.py -v -k "test_save_compiled_v5e or test_save_compiled_v4"
  • Smoke-test training with pipeline parallelism config
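
The test_save_compiled tests above exercise saving a compiled executable, which is where the serialize() return-type change from the summary matters. Below is a hedged sketch of one way to accept both shapes. `extract_serialized_payload` is a hypothetical name, and the assumption that the tuple form carries the serialized bytes as its first element is not confirmed by this PR description.

```python
def extract_serialized_payload(result):
    """Return serialized bytes from either return shape (hypothetical helper).

    Accepts raw bytes, or a tuple whose first element is the bytes payload
    (assumed layout; any remaining tuple elements are ignored here).
    """
    if isinstance(result, (bytes, bytearray)):
        return bytes(result)
    if isinstance(result, tuple) and result and isinstance(result[0], (bytes, bytearray)):
        return bytes(result[0])
    raise TypeError(f"Unexpected serialize() return type: {type(result).__name__}")


# Stand-ins for the two API shapes; these are not real library calls.
def fake_serialize_bytes():
    return b"compiled-executable"


def fake_serialize_tuple():
    return (b"compiled-executable", "in_tree", "out_tree")


payload_a = extract_serialized_payload(fake_serialize_bytes())
payload_b = extract_serialized_payload(fake_serialize_tuple())
```

Raising on any other shape keeps a future API change from silently writing a malformed file.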

cj401-amd requested a review from NuojCheng June 17, 2026 22:51
cj401-amd changed the title from "[ROCm]: fix: JAX/TE sharding compatibility and tmem reduction foundations" to "[ROCm]: fix: JAX/TE sharding compatibility and tmem reduction foundations (PR1)" Jun 17, 2026
cj401-amd force-pushed the cj/tmem-fixes-clean-1-jax-sharding-compat branch from 4f116aa to 2f5fa82 June 18, 2026 22:39
@codecov

codecov Bot commented Jun 18, 2026


Codecov Report

❌ Patch coverage is 44.82759% with 16 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/attention_op.py | 0.00% | 7 Missing ⚠️ |
| src/maxtext/layers/normalizations.py | 54.54% | 4 Missing and 1 partial ⚠️ |
| src/maxtext/trainers/pre_train/train.py | 42.85% | 3 Missing and 1 partial ⚠️ |

