1+ import torch
2+ import jax
3+ import jax .numpy as jnp
4+ import numpy as np
5+
6+ # 1. ALIAS THE IMPORTS to prevent name collisions!
7+ from diffusers .pipelines .ltx2 .latent_upsampler import LTX2LatentUpsamplerModel as PT_Upsampler
8+ from maxdiffusion .models .ltx2 .latent_upsampler_ltx2 import LTX2LatentUpsamplerModel as JAX_Upsampler
9+ from maxdiffusion .models .ltx2 .ltx2_utils import load_upsampler_weights
10+
def test_side_by_side():
    """Numerically compare the PyTorch and JAX LTX-2 latent upsamplers.

    Loads the same pretrained checkpoint ("Lightricks/LTX-2",
    subfolder "latent_upsampler") into both implementations, runs an
    identical seeded random input through each, and prints the MSE and
    max absolute difference between the two outputs.

    Returns:
        float: the maximum absolute element-wise difference between the
        two model outputs, so callers/CI can check parity programmatically.
    """
    # --- Setup PyTorch ---
    print("Initializing PyTorch Model...")
    # Load the real pretrained weights.
    pt_model = PT_Upsampler.from_pretrained("Lightricks/LTX-2", subfolder="latent_upsampler")
    pt_model.eval()  # deterministic inference mode (no dropout/BN updates)

    # --- Setup JAX ---
    print("Initializing JAX Model...")
    jax_model = JAX_Upsampler()

    print("Loading JAX Weights from HuggingFace...")
    # Use the project's conversion path to load the exact same weights.
    flax_params = load_upsampler_weights(
        pretrained_model_name_or_path="Lightricks/LTX-2",
        eval_shapes=None,
        device="cpu",  # load on CPU so the comparison happens host-side
        subfolder="latent_upsampler",
    )

    # --- Generate Identical Dummy Data ---
    # Shape: Batch=1, Channels=128, Frames=8, Height=32, Width=32.
    # NOTE(review): assumes the upsampler's latent channel count is 128 — confirm.
    print("Generating identical random inputs...")
    torch.manual_seed(42)
    pt_input = torch.randn(1, 128, 8, 32, 32, dtype=torch.float32)

    # Convert PyTorch NCDHW -> JAX NDHWC:
    # permute (0, 2, 3, 4, 1) maps (B, C, F, H, W) -> (B, F, H, W, C).
    jax_input = jnp.asarray(pt_input.permute(0, 2, 3, 4, 1).numpy())

    # --- Run Forward Passes ---
    print("Running PyTorch pass...")
    with torch.no_grad():  # no autograd graph needed for inference
        pt_output = pt_model(pt_input)

    print("Running JAX pass...")
    jax_output = jax_model.apply({"params": flax_params}, jax_input)

    # --- Compare Results ---
    # Convert JAX output back to PyTorch layout: NDHWC -> NCDHW.
    # permute (0, 4, 1, 2, 3) maps (B, F, H, W, C) -> (B, C, F, H, W).
    # from_numpy shares the numpy buffer — avoids the extra copy that
    # torch.tensor(...) would make.
    jax_output_converted = torch.from_numpy(np.asarray(jax_output)).permute(0, 4, 1, 2, 3)

    # Mean squared error and worst-case element-wise deviation.
    mse = torch.nn.functional.mse_loss(pt_output, jax_output_converted)
    max_diff = (pt_output - jax_output_converted).abs().max()

    print("\n" + "=" * 30)
    print("      COMPARISON RESULTS      ")
    print("=" * 30)
    print(f"Mean Squared Error: {mse.item():.8f}")
    print(f"Max Absolute Error: {max_diff.item():.8f}")

    if max_diff.item() < 1e-3:
        print("\n✅ SUCCESS: The models are mathematically identical!")
    else:
        print("\n❌ FAILED: The models diverge. There is a bug in the math/weights.")

    return max_diff.item()
73+
# Allow running this comparison directly as a script.
if __name__ == "__main__":
    test_side_by_side()
# (web-extraction residue, preserved as a comment so the file parses: "0 commit comments")