@@ -596,5 +596,242 @@ def convert_weight(pt_key_base, jax_key):
596596 print (f"Audio Std: { jnp .std (max_audio_sample )} " )
597597
598598
599+ def test_import_parity_comparison_split (self ):
600+ """
601+ Verifies that the LTX2VideoTransformer3DModel output matches the PyTorch implementation output
602+ exactly (to a high precision) given the same inputs and weights, with rope_type='split'.
603+ """
604+ from flax import traverse_util
605+
606+ parity_file = "ltx2_parity_data_split.pt"
607+ if not os .path .exists (parity_file ):
608+ print (f"Skipping parity test: { parity_file } not found. Run diffusers test first." )
609+ return
610+
611+ print (f"Loading { parity_file } ..." )
612+ parity_data = torch .load (parity_file )
613+ state_dict = parity_data ["state_dict" ]
614+ inputs = parity_data ["inputs" ]
615+ torch_outputs = parity_data ["outputs" ]
616+ config = parity_data ["config" ]
617+
618+ # 1. Instantiate Model
619+ # Ensure config matches what was exported
620+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
621+ model = LTX2VideoTransformer3DModel (
622+ rngs = nnx .Rngs (0 ),
623+ in_channels = config ["in_channels" ],
624+ out_channels = config ["out_channels" ],
625+ patch_size = config ["patch_size" ],
626+ patch_size_t = 1 ,
627+ num_attention_heads = 8 ,
628+ attention_head_dim = 128 ,
629+ cross_attention_dim = 1024 , # Parity config
630+ caption_channels = config ["caption_channels" ],
631+ audio_in_channels = 4 ,
632+ audio_out_channels = 4 ,
633+ audio_patch_size = 1 ,
634+ audio_patch_size_t = 1 ,
635+ audio_num_attention_heads = 8 ,
636+ audio_attention_head_dim = 128 ,
637+ audio_cross_attention_dim = 1024 ,
638+ num_layers = 1 ,
639+ mesh = self .mesh ,
640+ attention_kernel = "dot_product" ,
641+ rope_type = "split" ,
642+ )
643+
644+ # 2. Convert Weights (PyTorch -> Flax NNX)
645+ print ("Converting weights (Split)..." )
646+
647+ graph_def , state = nnx .split (model )
648+ flat_state = traverse_util .flatten_dict (state .to_pure_dict ())
649+ new_flat_state = {}
650+
651+ # Helper to convert/transpose weights
652+ def convert_weight (pt_key_base , jax_key ):
653+ # Try original key first
654+ pt_key = pt_key_base
655+
656+ # Map JAX 'kernel' to PT 'weight'
657+ if "kernel" in str (jax_key ):
658+ pt_key = pt_key .replace ("kernel" , "weight" )
659+
660+ # Fix scale logic (RMSNorm)
661+ # Only replace 'scale' if it's the parameter name (last part) to avoid breaking 'scale_shift'
662+ if jax_key [- 1 ] == "scale" and "scale_shift" not in str (jax_key ):
663+ pt_key = pt_key .replace ("scale" , "weight" )
664+
665+ # Fix transformer_blocks prefix
666+ # JAX: ('transformer_blocks', 'attn1', ...)
667+ # PT: transformer_blocks.0.attn1...
668+ is_transformer_block = "transformer_blocks" in str (jax_key )
669+ if is_transformer_block :
670+ if "transformer_blocks" in pt_key and "transformer_blocks.0" not in pt_key :
671+ pt_key = pt_key .replace ("transformer_blocks" , "transformer_blocks.0" )
672+
673+ # Fix `layers` keyword in JAX key usually implies `layers.0` if it was there?
674+ if "layers" in pt_key :
675+ pt_key = pt_key .replace ("layers." , "" )
676+
677+ # Fix to_out (Diffusers has to_out[0] as Linear)
678+ if "to_out" in pt_key and ("weight" in pt_key or "bias" in pt_key ):
679+ pt_key = pt_key .replace ("to_out.weight" , "to_out.0.weight" )
680+ pt_key = pt_key .replace ("to_out.bias" , "to_out.0.bias" )
681+
682+ # Fix FeedForward (net_0 -> net.0.proj, net_2 -> net.2)
683+ if "net_0" in pt_key :
684+ pt_key = pt_key .replace ("net_0" , "net.0.proj" )
685+ if "net_2" in pt_key :
686+ pt_key = pt_key .replace ("net_2" , "net.2" )
687+
688+ if pt_key not in state_dict :
689+ # Try removing .0 if it was added erroneously
690+ candidates = [pt_key ]
691+ if "transformer_blocks.0" in pt_key :
692+ candidates .append (pt_key .replace ("transformer_blocks.0" , "transformer_blocks" ))
693+
694+ # Special Case: scale_shift_table
695+ # Only allow global scale_shift_table fallback if NOT inside transformer block
696+ if "scale_shift_table" in str (jax_key ) and not is_transformer_block :
697+ candidates .append ("scale_shift_table" )
698+
699+ if "audio_scale_shift_table" in str (jax_key ) and not is_transformer_block :
700+ candidates .append ("audio_scale_shift_table" )
701+
702+ for c in candidates :
703+ if c in state_dict :
704+ pt_key = c
705+ break
706+ else :
707+ # If unmapped bias, maybe it's just missing in PT (e.g. RMSNorm without bias)
708+ if "bias" in str (jax_key ):
709+ # Initialize to zeros?
710+ print (f"Warning: Missing PT bias for { jax_key } . initializing to zeros." )
711+ # Use shape from current flat_state param
712+ return jnp .zeros (flat_state [jax_key ].shape ), pt_key
713+
714+ return None , pt_key
715+
716+ w = state_dict [pt_key ].cpu ().numpy ()
717+
718+ # Debug Special Parameters
719+ if "scale_shift_table" in str (jax_key ):
720+ print (f"Mapping scale_shift_table for { jax_key } from { pt_key } with shape { w .shape } " )
721+
722+ # Handle vmap/scan dimension for transformer_blocks
723+ if is_transformer_block :
724+ # JAX expects (num_layers, ...) for these weights
725+ # PT has (...)
726+ # So expand dims
727+ w = w [None , ...]
728+
729+ # Handle Transforms
730+ is_kernel = "kernel" in str (jax_key )
731+ # Embedding projections are also 'kernel' in JAX (Linear)
732+ if is_kernel :
733+ if w .ndim == 3 : # (1, out, in) -> (1, in, out)
734+ w = w .transpose (0 , 2 , 1 )
735+ elif w .ndim == 2 : # (out, in) -> (in, out)
736+ w = w .T
737+
738+ return jnp .array (w ), pt_key
739+
740+ total_count = len (flat_state )
741+ mapped_count = 0
742+
743+ # Debug: Print available keys for audio_ff
744+ print ("Debugging PT keys for mapping failure diagnosis (Split):" )
745+ print ("Available PT keys with 'ff':" , [k for k in state_dict .keys () if "ff" in k ])
746+ print ("Available PT keys with 'norm':" , [k for k in state_dict .keys () if "norm" in k ])
747+
748+ for key in flat_state .keys ():
749+ # Construct base PT key from JAX key tuple
750+ pt_key_base = "." .join ([str (k ) for k in key if str (k ) != "layers" ])
751+
752+ w , used_pt_key = convert_weight (pt_key_base , key )
753+ if w is not None :
754+ # Handle bias zero init which might return scalar 0 if shape was (1,) but it should be array
755+ # jnp.zeros(shape) returns array.
756+ new_flat_state [key ] = w
757+ mapped_count += 1
758+ else :
759+ print (f"Warning: Could not map JAX key { key } (PT attempt: { used_pt_key } )" )
760+ if "audio_ff" in str (key ):
761+ print ("Available audio_ff keys:" , [k for k in state_dict .keys () if "audio_ff" in k ])
762+ if "norm_out" in str (key ):
763+ print ("Available norm_out keys:" , [k for k in state_dict .keys () if "norm_out" in k ])
764+
765+ print (f"Mapped { mapped_count } /{ total_count } params." )
766+
767+ # Update model state
768+ new_state = traverse_util .unflatten_dict (new_flat_state )
769+ nnx .update (model , new_state )
770+
771+ # 3. Prepare Inputs
772+ jax_inputs = {
773+ "hidden_states" : jnp .array (inputs ["hidden_states" ].cpu ().numpy ()),
774+ "audio_hidden_states" : jnp .array (inputs ["audio_hidden_states" ].cpu ().numpy ()),
775+ "encoder_hidden_states" : jnp .array (inputs ["encoder_hidden_states" ].cpu ().numpy ()),
776+ "audio_encoder_hidden_states" : jnp .array (inputs ["audio_encoder_hidden_states" ].cpu ().numpy ()),
777+ "timestep" : jnp .array (inputs ["timestep" ].cpu ().numpy ()),
778+ "encoder_attention_mask" : jnp .array (inputs ["encoder_attention_mask" ].cpu ().numpy ()),
779+ "audio_encoder_attention_mask" : jnp .array (inputs ["audio_encoder_attention_mask" ].cpu ().numpy ()),
780+ }
781+
782+ print ("\n === Input Verification (Split) ===" )
783+ print (f"Hidden States Sum: { jnp .sum (jax_inputs ['hidden_states' ])} " )
784+ print (f"Audio Hidden States Sum: { jnp .sum (jax_inputs ['audio_hidden_states' ])} " )
785+ print (f"Encoder Hidden States Sum: { jnp .sum (jax_inputs ['encoder_hidden_states' ])} " )
786+ print (f"Audio Encoder Hidden States Sum: { jnp .sum (jax_inputs ['audio_encoder_hidden_states' ])} " )
787+ print (f"Timestep: { jax_inputs ['timestep' ]} " )
788+ print ("==========================\n " )
789+
790+ # 4. Run Forward
791+ print ("Running MaxDiffusion forward pass (Split)..." )
792+ output = model (
793+ hidden_states = jax_inputs ["hidden_states" ],
794+ audio_hidden_states = jax_inputs ["audio_hidden_states" ],
795+ encoder_hidden_states = jax_inputs ["encoder_hidden_states" ],
796+ audio_encoder_hidden_states = jax_inputs ["audio_encoder_hidden_states" ],
797+ timestep = jax_inputs ["timestep" ],
798+ encoder_attention_mask = jax_inputs ["encoder_attention_mask" ],
799+ audio_encoder_attention_mask = jax_inputs ["audio_encoder_attention_mask" ],
800+ num_frames = config ["num_frames" ] if "num_frames" in config else 4 ,
801+ height = config ["height" ] if "height" in config else 32 ,
802+ width = config ["width" ] if "width" in config else 32 ,
803+ audio_num_frames = 128 ,
804+ fps = 24.0 ,
805+ return_dict = True ,
806+ )
807+
808+ max_sample = output ["sample" ]
809+ max_audio_sample = output ["audio_sample" ]
810+
811+ print ("MAXDIFF Output Sample Stats (Split):" )
812+ print (f"Sample Max: { jnp .max (max_sample )} " )
813+ print (f"Sample Min: { jnp .min (max_sample )} " )
814+ print (f"Sample Mean: { jnp .mean (max_sample )} " )
815+ print (f"Sample Std: { jnp .std (max_sample )} " )
816+
817+ print ("MAXDIFF Output Audio Sample Stats (Split):" )
818+ print (f"Audio Max: { jnp .max (max_audio_sample )} " )
819+ print (f"Audio Min: { jnp .min (max_audio_sample )} " )
820+ print (f"Audio Mean: { jnp .mean (max_audio_sample )} " )
821+ print (f"Audio Std: { jnp .std (max_audio_sample )} " )
822+
823+ # 5. Parity Check
824+ parity_sample = jnp .array (torch_outputs ["sample" ].cpu ().numpy ())
825+ parity_audio_sample = jnp .array (torch_outputs ["audio_sample" ].cpu ().numpy ())
826+
827+ print ("Checking Parity (Split)..." )
828+ print (f"Max Diff Sample: { jnp .max (jnp .abs (max_sample - parity_sample ))} " )
829+ print (f"Max Diff Audio: { jnp .max (jnp .abs (max_audio_sample - parity_audio_sample ))} " )
830+
831+ self .assertTrue (jnp .allclose (max_sample , parity_sample , atol = 1e-3 ))
832+ self .assertTrue (jnp .allclose (max_audio_sample , parity_audio_sample , atol = 1e-3 ))
833+ print ("Parity check passed (Split)!" )
834+
835+
# Allow running this test module directly (e.g. `python this_file.py`);
# unittest.main() discovers and runs all TestCase methods in the module.
if __name__ == "__main__":
  unittest.main()
0 commit comments