@@ -415,6 +415,7 @@ def test_import_parity_comparison(self):
415415 num_layers = 1 ,
416416 mesh = self .mesh ,
417417 attention_kernel = "dot_product" ,
418+ rope_type = "split"
418419 )
419420
420421 # 2. Convert Weights (PyTorch -> Flax NNX)
@@ -597,250 +598,5 @@ def convert_weight(pt_key_base, jax_key):
597598 print (f"Audio Std: { jnp .std (max_audio_sample )} " )
598599
599600
600- def test_import_parity_comparison_split (self ):
601- """
602- Verifies that the LTX2VideoTransformer3DModel output matches the PyTorch implementation output
603- exactly (to a high precision) given the same inputs and weights, with rope_type='split'.
604- """
605- print ("\n === Parity Comparison Test (Split) ===" )
606- try :
607- import torch
608- except ImportError :
609- print ("Skipping parity test: torch not installed." )
610- return
611-
612- import os
613- from flax import traverse_util
614-
615- parity_file = "ltx2_parity_data_split.pt"
616- if not os .path .exists (parity_file ):
617- print (f"Skipping parity test: { parity_file } not found. Run diffusers test first." )
618- return
619-
620- print (f"Loading { parity_file } ..." )
621- parity_data = torch .load (parity_file )
622- state_dict = parity_data ["state_dict" ]
623- inputs = parity_data ["inputs" ]
624- torch_outputs = parity_data ["outputs" ]
625- config = parity_data ["config" ]
626-
627- # 1. Instantiate Model
628- # Ensure config matches what was exported
629- with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
630- model = LTX2VideoTransformer3DModel (
631- rngs = nnx .Rngs (0 ),
632- in_channels = config ["in_channels" ],
633- out_channels = config ["out_channels" ],
634- patch_size = config ["patch_size" ],
635- patch_size_t = 1 ,
636- num_attention_heads = 8 ,
637- attention_head_dim = 128 ,
638- cross_attention_dim = 1024 , # Parity config
639- caption_channels = config ["caption_channels" ],
640- audio_in_channels = 4 ,
641- audio_out_channels = 4 ,
642- audio_patch_size = 1 ,
643- audio_patch_size_t = 1 ,
644- audio_num_attention_heads = 8 ,
645- audio_attention_head_dim = 128 ,
646- audio_cross_attention_dim = 1024 ,
647- num_layers = 1 ,
648- mesh = self .mesh ,
649- attention_kernel = "dot_product" ,
650- rope_type = "split" ,
651- )
652-
653- # 2. Convert Weights (PyTorch -> Flax NNX)
654- print ("Converting weights (Split)..." )
655-
656- graph_def , state = nnx .split (model )
657- flat_state = traverse_util .flatten_dict (state .to_pure_dict ())
658- new_flat_state = {}
659-
660- # Helper to convert/transpose weights
661- def convert_weight (pt_key_base , jax_key ):
662- # Try original key first
663- pt_key = pt_key_base
664-
665- # Map JAX 'kernel' to PT 'weight'
666- if "kernel" in str (jax_key ):
667- pt_key = pt_key .replace ("kernel" , "weight" )
668-
669- # Fix scale logic (RMSNorm)
670- # Only replace 'scale' if it's the parameter name (last part) to avoid breaking 'scale_shift'
671- if jax_key [- 1 ] == "scale" and "scale_shift" not in str (jax_key ):
672- pt_key = pt_key .replace ("scale" , "weight" )
673-
674- # Fix transformer_blocks prefix
675- # JAX: ('transformer_blocks', 'attn1', ...)
676- # PT: transformer_blocks.0.attn1...
677- is_transformer_block = "transformer_blocks" in str (jax_key )
678- if is_transformer_block :
679- if "transformer_blocks" in pt_key and "transformer_blocks.0" not in pt_key :
680- pt_key = pt_key .replace ("transformer_blocks" , "transformer_blocks.0" )
681-
682- # Fix `layers` keyword in JAX key usually implies `layers.0` if it was there?
683- if "layers" in pt_key :
684- pt_key = pt_key .replace ("layers." , "" )
685-
686- # Fix to_out (Diffusers has to_out[0] as Linear)
687- if "to_out" in pt_key and ("weight" in pt_key or "bias" in pt_key ):
688- pt_key = pt_key .replace ("to_out.weight" , "to_out.0.weight" )
689- pt_key = pt_key .replace ("to_out.bias" , "to_out.0.bias" )
690-
691- # Fix FeedForward (net_0 -> net.0.proj, net_2 -> net.2)
692- if "net_0" in pt_key :
693- pt_key = pt_key .replace ("net_0" , "net.0.proj" )
694- if "net_2" in pt_key :
695- pt_key = pt_key .replace ("net_2" , "net.2" )
696-
697- if pt_key not in state_dict :
698- # Try removing .0 if it was added erroneously
699- candidates = [pt_key ]
700- if "transformer_blocks.0" in pt_key :
701- candidates .append (pt_key .replace ("transformer_blocks.0" , "transformer_blocks" ))
702-
703- # Special Case: scale_shift_table
704- # Only allow global scale_shift_table fallback if NOT inside transformer block
705- if "scale_shift_table" in str (jax_key ) and not is_transformer_block :
706- candidates .append ("scale_shift_table" )
707-
708- if "audio_scale_shift_table" in str (jax_key ) and not is_transformer_block :
709- candidates .append ("audio_scale_shift_table" )
710-
711- for c in candidates :
712- if c in state_dict :
713- pt_key = c
714- break
715- else :
716- # If unmapped bias, maybe it's just missing in PT (e.g. RMSNorm without bias)
717- if "bias" in str (jax_key ):
718- # Initialize to zeros?
719- print (f"Warning: Missing PT bias for { jax_key } . initializing to zeros." )
720- # Use shape from current flat_state param
721- return jnp .zeros (flat_state [jax_key ].shape ), pt_key
722-
723- return None , pt_key
724-
725- w = state_dict [pt_key ].cpu ().numpy ()
726-
727- # Debug Special Parameters
728- if "scale_shift_table" in str (jax_key ):
729- print (f"Mapping scale_shift_table for { jax_key } from { pt_key } with shape { w .shape } " )
730-
731- # Handle vmap/scan dimension for transformer_blocks
732- if is_transformer_block :
733- # JAX expects (num_layers, ...) for these weights
734- # PT has (...)
735- # So expand dims
736- w = w [None , ...]
737-
738- # Handle Transforms
739- is_kernel = "kernel" in str (jax_key )
740- # Embedding projections are also 'kernel' in JAX (Linear)
741- if is_kernel :
742- if w .ndim == 3 : # (1, out, in) -> (1, in, out)
743- w = w .transpose (0 , 2 , 1 )
744- elif w .ndim == 2 : # (out, in) -> (in, out)
745- w = w .T
746-
747- return jnp .array (w ), pt_key
748-
749- total_count = len (flat_state )
750- mapped_count = 0
751-
752- # Debug: Print available keys for audio_ff
753- print ("Debugging PT keys for mapping failure diagnosis (Split):" )
754- print ("Available PT keys with 'ff':" , [k for k in state_dict .keys () if "ff" in k ])
755- print ("Available PT keys with 'norm':" , [k for k in state_dict .keys () if "norm" in k ])
756-
757- for key in flat_state .keys ():
758- # Construct base PT key from JAX key tuple
759- pt_key_base = "." .join ([str (k ) for k in key if str (k ) != "layers" ])
760-
761- w , used_pt_key = convert_weight (pt_key_base , key )
762- if w is not None :
763- # Handle bias zero init which might return scalar 0 if shape was (1,) but it should be array
764- # jnp.zeros(shape) returns array.
765- new_flat_state [key ] = w
766- mapped_count += 1
767- else :
768- print (f"Warning: Could not map JAX key { key } (PT attempt: { used_pt_key } )" )
769- if "audio_ff" in str (key ):
770- print ("Available audio_ff keys:" , [k for k in state_dict .keys () if "audio_ff" in k ])
771- if "norm_out" in str (key ):
772- print ("Available norm_out keys:" , [k for k in state_dict .keys () if "norm_out" in k ])
773-
774- print (f"Mapped { mapped_count } /{ total_count } params." )
775-
776- # Update model state
777- new_state = traverse_util .unflatten_dict (new_flat_state )
778- nnx .update (model , new_state )
779-
780- # 3. Prepare Inputs
781- jax_inputs = {
782- "hidden_states" : jnp .array (inputs ["hidden_states" ].cpu ().numpy ()),
783- "audio_hidden_states" : jnp .array (inputs ["audio_hidden_states" ].cpu ().numpy ()),
784- "encoder_hidden_states" : jnp .array (inputs ["encoder_hidden_states" ].cpu ().numpy ()),
785- "audio_encoder_hidden_states" : jnp .array (inputs ["audio_encoder_hidden_states" ].cpu ().numpy ()),
786- "timestep" : jnp .array (inputs ["timestep" ].cpu ().numpy ()),
787- "encoder_attention_mask" : jnp .array (inputs ["encoder_attention_mask" ].cpu ().numpy ()),
788- "audio_encoder_attention_mask" : jnp .array (inputs ["audio_encoder_attention_mask" ].cpu ().numpy ()),
789- }
790-
791- print ("\n === Input Verification (Split) ===" )
792- print (f"Hidden States Sum: { jnp .sum (jax_inputs ['hidden_states' ])} " )
793- print (f"Audio Hidden States Sum: { jnp .sum (jax_inputs ['audio_hidden_states' ])} " )
794- print (f"Encoder Hidden States Sum: { jnp .sum (jax_inputs ['encoder_hidden_states' ])} " )
795- print (f"Audio Encoder Hidden States Sum: { jnp .sum (jax_inputs ['audio_encoder_hidden_states' ])} " )
796- print (f"Timestep: { jax_inputs ['timestep' ]} " )
797- print ("==========================\n " )
798-
799- # 4. Run Forward
800- print ("Running MaxDiffusion forward pass (Split)..." )
801- output = model (
802- hidden_states = jax_inputs ["hidden_states" ],
803- audio_hidden_states = jax_inputs ["audio_hidden_states" ],
804- encoder_hidden_states = jax_inputs ["encoder_hidden_states" ],
805- audio_encoder_hidden_states = jax_inputs ["audio_encoder_hidden_states" ],
806- timestep = jax_inputs ["timestep" ],
807- encoder_attention_mask = jax_inputs ["encoder_attention_mask" ],
808- audio_encoder_attention_mask = jax_inputs ["audio_encoder_attention_mask" ],
809- num_frames = config ["num_frames" ] if "num_frames" in config else 4 ,
810- height = config ["height" ] if "height" in config else 32 ,
811- width = config ["width" ] if "width" in config else 32 ,
812- audio_num_frames = 128 ,
813- fps = 24.0 ,
814- return_dict = True ,
815- )
816-
817- max_sample = output ["sample" ]
818- max_audio_sample = output ["audio_sample" ]
819-
820- print ("MAXDIFF Output Sample Stats (Split):" )
821- print (f"Sample Max: { jnp .max (max_sample )} " )
822- print (f"Sample Min: { jnp .min (max_sample )} " )
823- print (f"Sample Mean: { jnp .mean (max_sample )} " )
824- print (f"Sample Std: { jnp .std (max_sample )} " )
825-
826- print ("MAXDIFF Output Audio Sample Stats (Split):" )
827- print (f"Audio Max: { jnp .max (max_audio_sample )} " )
828- print (f"Audio Min: { jnp .min (max_audio_sample )} " )
829- print (f"Audio Mean: { jnp .mean (max_audio_sample )} " )
830- print (f"Audio Std: { jnp .std (max_audio_sample )} " )
831-
832- # 5. Parity Check
833- parity_sample = jnp .array (torch_outputs ["sample" ].cpu ().numpy ())
834- parity_audio_sample = jnp .array (torch_outputs ["audio_sample" ].cpu ().numpy ())
835-
836- print ("Checking Parity (Split)..." )
837- print (f"Max Diff Sample: { jnp .max (jnp .abs (max_sample - parity_sample ))} " )
838- print (f"Max Diff Audio: { jnp .max (jnp .abs (max_audio_sample - parity_audio_sample ))} " )
839-
840- self .assertTrue (jnp .allclose (max_sample , parity_sample , atol = 1e-3 ))
841- self .assertTrue (jnp .allclose (max_audio_sample , parity_audio_sample , atol = 1e-3 ))
842- print ("Parity check passed (Split)!" )
843-
844-
845601if __name__ == "__main__" :
846602 unittest .main ()
0 commit comments