Hotfix for Maxtext regression with JAX 0.9 changes #629
Claude Walkthrough
Intent. The JAX 0.9 compatibility refactor (#604) moved
Key changes.
Walkthrough.
The net effect is that validation now happens exactly once per guard context, at the earliest moment a real mesh is observed: eager if the mesh exists at guard entry, otherwise lazy on the first
Testing. No tests were added. The fix is behavior-restoring for an out-of-tree caller (Maxtext) whose entry order is not covered by the in-repo JAX tests.
Notes for reviewers.
Generated by Claude. To request a code review, comment
It looks like the labels update replaced the real CI run with a skipped dispatch. The real run is at https://github.com/ROCm/TransformerEngine/actions/runs/27632772340
Claude review: single-file hotfix to
One inline note was left on the lazy-validation block suggesting an optional regression test for the Maxtext entry-before-mesh-activation path. Copyright headers: OK.
Description
https://github.com/ROCm/frameworks-internal/issues/16494
The JAX 0.9 compatibility changes caused a regression in Maxtext, which enters with global_shard_guard() before activating the JAX mesh context, i.e. when the mesh is not yet available.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: