Hotfix for Maxtext regression with JAX 0.9 changes#629
Conversation
Claude WalkthroughIntent. The JAX 0.9 compatibility refactor (#604) moved Key changes.
Walkthrough.
The net effect: validation now happens exactly once per guard context, at the earliest moment a real mesh is observed — eager if the mesh exists at guard-entry, otherwise lazy on the first Testing. No tests were added. The fix is behavior-restoring for an out-of-tree caller (Maxtext) whose entry order is not covered by the in-repo JAX tests. Notes for reviewers.
Generated by Claude. To request a code review, comment |
|
Looks like labels update replaced real CI run with skipped dispatch. Real run link is https://github.com/ROCm/TransformerEngine/actions/runs/27632772340 |
| # get_abstract_mesh() is empty there) → validation safely skipped. | ||
| if not _GLOBAL_MESH_RESOURCE_VALIDATED and is_mesh_available(): | ||
| _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) | ||
| _GLOBAL_MESH_RESOURCE_VALIDATED = True |
There was a problem hiding this comment.
Worth adding a small regression test in tests/jax/test_distributed_helper.py that mirrors the maxtext pattern — enter global_shard_guard(MeshResource(...)) outside any with mesh: context, then read global_mesh_resource() — to lock in the fix and prevent a future change to the validation cadence from silently re-introducing this assertion failure. Not blocking for a hotfix, but the existing test only covers the mesh-already-active path.
|
Claude review — single-file hotfix to One inline note left on the lazy-validation block suggesting an optional regression test for the maxtext entry-before-mesh-activation path. Copyright headers: OK. |
Description
JAX 0.9 compatibility changes resulted in regression in maxxtext that calls with global_shard_guard() before activatibg JAX mesh context, i.e. when mesh is not available yet
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: