Skip to content

Source sync: Add Flax configuration updates and improve WAN trainer profiling#326

Open
michelle-yooh wants to merge 1 commit intomainfrom
yooh/copybara-sync
Open

Source sync: Add Flax configuration updates and improve WAN trainer profiling#326
michelle-yooh wants to merge 1 commit intomainfrom
yooh/copybara-sync

Conversation

@michelle-yooh
Copy link
Collaborator

This PR contains source synchronization from internal MaxDiffusion updates with the following improvements:

Changes:

Flax Configuration - Add automatic Flax config update to disable variable sharding by default in the main init.py, ensuring consistent behavior across the codebase.

Configuration Loading Fix - Improve model configuration loading in configuration_utils.py to prioritize subfolder-specific configs when available. This ensures models with subfolder structures load their configs correctly before falling back to the root directory.

WAN Trainer Profiling Enhancements - Add jax.named_scope() annotations in wan_trainer.py to improve performance profiling:

  • Wrap the forward pass in a named scope for better profiling visibility
  • Wrap loss computation in a named scope to separately track loss calculation overhead
  • Enables clearer performance metrics and debugging in XLA compilation

These changes improve model loading robustness, provide better profiling granularity for performance analysis, and ensure consistent Flax behavior across the framework.

PiperOrigin-RevId: 866156279

PiperOrigin-RevId: 866156279
@michelle-yooh michelle-yooh requested a review from entrpn as a code owner February 6, 2026 01:42
@github-actions
Copy link

github-actions bot commented Feb 6, 2026

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants