Skip to content

Hotfix for Maxtext regression with JAX 0.9 changes#629

Open
ipanfilo wants to merge 2 commits into
devfrom
ipanfilo/jax09_maxtext_hotfix
Open

Hotfix for Maxtext regression with JAX 0.9 changes#629
ipanfilo wants to merge 2 commits into
devfrom
ipanfilo/jax09_maxtext_hotfix

Conversation

@ipanfilo

@ipanfilo ipanfilo commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Description

https://github.com/ROCm/frameworks-internal/issues/16494
JAX 0.9 compatibility changes resulted in regression in maxxtext that calls with global_shard_guard() before activatibg JAX mesh context, i.e. when mesh is not available yet

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add mesh validation guard in global_shard_quard()
  • Restore mesh validation in global_mesh_resource()
  • Add flag to do mesh validation only of the first global_mesh_resource() invocation()
  • Add test

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions

Copy link
Copy Markdown

Claude Walkthrough

Intent. The JAX 0.9 compatibility refactor (#604) moved MeshResource validation out of global_mesh_resource() into global_shard_guard() so it would not run from inside jit(...).lower() where get_abstract_mesh() is empty. That broke Maxtext, which enters global_shard_guard() before activating its JAX mesh — so eager validation hit get_mesh_axis_size() on an empty mesh and asserted. This PR keeps the JAX 0.9 fix intact while restoring a Maxtext-compatible path.

Key changes.

  • New module-level flag _GLOBAL_MESH_RESOURCE_VALIDATED tracks whether the current resource has been validated (transformer_engine/jax/sharding.py:378).
  • global_shard_guard() now does eager validation only when is_mesh_available() is true, and saves/restores the validated flag for nesting (transformer_engine/jax/sharding.py:384).
  • global_mesh_resource() regains a lazy validation step — runs at most once per guard context, gated by both the flag and is_mesh_available() (transformer_engine/jax/sharding.py:418).

Walkthrough.

transformer_engine/jax/sharding.py — single-file change.

  • The previous behavior validated unconditionally on global_shard_guard entry whenever resource is not None. The hotfix wraps that call in if resource is not None and is_mesh_available():. When Maxtext enters the guard with no active JAX mesh, validation is deferred instead of asserting.
  • _GLOBAL_MESH_RESOURCE_VALIDATED is reset to False at every guard entry (a new resource always needs re-validation) and saved/restored in the try/finally so nested guards do not leak state from an inner context to an outer one.
  • global_mesh_resource() gets a "lazy validation" block: on each call, if the resource has not yet been validated and a mesh is now available, run _validate_mesh_resource_configuration and set the flag. Once validated, subsequent calls short-circuit on a single boolean check — cheaper than the pre-Ipanfilo/jax0.9 support #604 code that ran get_mesh_axis_size() on every access. Inside jit(...).lower(), is_mesh_available() is false (per the JAX 0.9 behavior the original fix targeted), so the lazy path stays safely skipped there.

The net effect: validation now happens exactly once per guard context, at the earliest moment a real mesh is observed — eager if the mesh exists at guard-entry, otherwise lazy on the first global_mesh_resource() call that sees one.

Testing. No tests were added. The fix is behavior-restoring for an out-of-tree caller (Maxtext) whose entry order is not covered by the in-repo JAX tests.

Notes for reviewers.

  • Backwards compatibility: callers that always enter the guard with an active mesh see identical eager-validation behavior; the only observable change is the addition of a fast-path lazy check in global_mesh_resource().
  • The old_validated save/restore in finally is load-bearing for nested global_shard_guard use — without it, exiting an inner guard would leave the outer guard's resource marked "validated" and skip its lazy check.
  • Validation correctness still depends on is_mesh_available() accurately reflecting whether get_mesh_axis_size() will succeed; both rely on _get_mesh() / get_abstract_mesh() semantics under JAX 0.9.

Generated by Claude. To request a code review, comment /claude review.

@ipanfilo

Copy link
Copy Markdown
Collaborator Author

Looks like labels update replaced real CI run with skipped dispatch. Real run link is https://github.com/ROCm/TransformerEngine/actions/runs/27632772340

Comment thread transformer_engine/jax/sharding.py
@github-actions

Copy link
Copy Markdown

Claude review — single-file hotfix to transformer_engine/jax/sharding.py. The eager-on-global_shard_guard-entry plus lazy-on-first-global_mesh_resource() validation correctly handles the maxtext pattern (guard entered before mesh activation) without re-introducing the JAX 0.9 custom_partitioning empty-mesh issue from #604. Save/restore of _GLOBAL_MESH_RESOURCE_VALIDATED in the finally block is correct, including for nested guards. Comments are clear about the why.

One inline note left on the lazy-validation block suggesting an optional regression test for the maxtext entry-before-mesh-activation path.

Copyright headers: OK.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 2 CI test level 2

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants