Skip to content

Commit ba47dd4

Browse files
committed
Upsampler pipeline
1 parent a0f171d commit ba47dd4

File tree

7 files changed

+836
-25
lines changed

7 files changed

+836
-25
lines changed

src/maxdiffusion/checkpointing/ltx2_checkpointer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,15 @@
1919
import numpy as np
2020
from typing import Optional, Tuple
2121
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
22-
from maxdiffusion import max_logging
22+
from maxdiffusion.models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
23+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
24+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio
25+
from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
26+
from maxdiffusion.models.ltx2.vocoder_ltx2 import LTX2Vocoder
27+
from maxdiffusion.schedulers.scheduling_flow_match_flax import FlaxFlowMatchScheduler
28+
from maxdiffusion.models.ltx2.ltx2_utils import load_upsampler_weights
29+
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
30+
from maxdiffusion import max_logging, max_utils
2331
from maxdiffusion.checkpointing.checkpointing_utils import create_orbax_checkpoint_manager
2432
import orbax.checkpoint as ocp
2533
from etils import epath
@@ -95,6 +103,23 @@ def load_checkpoint(
95103

96104
return pipeline, opt_state, step
97105

106+
def load_upsampler(self, upsampler_model_path: str, eval_shapes: Optional[dict] = None) -> dict:
  """Loads the latent-upsampler weights via the central utils helper.

  Args:
    upsampler_model_path: HF repo id or local path containing a
      ``latent_upsampler`` subfolder (e.g. ``model.safetensors`` inside it).
    eval_shapes: Optional pytree of abstract shapes used when restoring the
      params; ``None`` loads without shape constraints.

  Returns:
    A dict (Flax param pytree) holding the upsampler weights.
  """
  max_logging.log("Loading Latent Upsampler from checkpoint...")

  # NOTE(review): this uses `jax.devices()` but the diff's added imports do
  # not show `import jax` — confirm it is imported at module top.
  # Target the platform of the first visible device (e.g. "tpu", "cpu").
  flax_params = load_upsampler_weights(
      pretrained_model_name_or_path=upsampler_model_path,
      eval_shapes=eval_shapes,
      device=jax.devices()[0].platform,
      subfolder="latent_upsampler",
  )
  return flax_params
121+
122+
98123
def save_checkpoint(self, train_step, pipeline: LTX2Pipeline, train_states: dict):
99124
"""Saves the training state and model configurations."""
100125

src/maxdiffusion/compare.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import jax
3+
import jax.numpy as jnp
4+
import numpy as np
5+
6+
# 1. ALIAS THE IMPORTS to prevent name collisions!
7+
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel as PT_Upsampler
8+
from maxdiffusion.models.ltx2.latent_upsampler_ltx2 import LTX2LatentUpsamplerModel as JAX_Upsampler
9+
from maxdiffusion.models.ltx2.ltx2_utils import load_upsampler_weights
10+
11+
def test_side_by_side():
    """Numerically compares the PyTorch and JAX latent-upsampler models.

    Loads the same pretrained checkpoint into both implementations, feeds
    them an identical random input, and checks that the outputs agree to
    within 1e-3 max absolute error. Previously this only *printed* the
    verdict, so pytest reported a pass even when the models diverged; it
    now asserts on divergence.
    """
    # --- Setup PyTorch ---
    print("Initializing PyTorch Model...")
    # Load the real pretrained weights.
    pt_model = PT_Upsampler.from_pretrained("Lightricks/LTX-2", subfolder="latent_upsampler")
    pt_model.eval()

    # --- Setup JAX ---
    print("Initializing JAX Model...")
    jax_model = JAX_Upsampler()

    print("Loading JAX Weights from HuggingFace...")
    # Load the exact same checkpoint through the project's conversion path.
    flax_params = load_upsampler_weights(
        pretrained_model_name_or_path="Lightricks/LTX-2",
        eval_shapes=None,
        device="cpu",  # load onto CPU for the comparison
        subfolder="latent_upsampler",
    )

    # --- Generate Identical Dummy Data ---
    # Shape: Batch=1, Channels=128, Frames=8, Height=32, Width=32
    # (assumes 128 matches the upsampler's input channels — TODO confirm)
    print("Generating identical random inputs...")
    torch.manual_seed(42)
    pt_input = torch.randn(1, 128, 8, 32, 32, dtype=torch.float32)

    # PyTorch NCDHW -> JAX NDHWC: permute (B, C, F, H, W) -> (B, F, H, W, C).
    jax_input = jnp.array(pt_input.permute(0, 2, 3, 4, 1).numpy())

    # --- Run Forward Passes ---
    print("Running PyTorch pass...")
    with torch.no_grad():
        pt_output = pt_model(pt_input)

    print("Running JAX pass...")
    jax_output = jax_model.apply({'params': flax_params}, jax_input)

    # --- Compare Results ---
    # JAX NDHWC -> PyTorch NCDHW: permute (B, F, H, W, C) -> (B, C, F, H, W).
    # from_numpy avoids the extra copy that torch.tensor(np.array(...)) made.
    jax_output_converted = torch.from_numpy(np.asarray(jax_output)).permute(0, 4, 1, 2, 3)

    # Mean squared error and worst-case elementwise difference.
    mse = torch.nn.functional.mse_loss(pt_output, jax_output_converted)
    max_diff = (pt_output - jax_output_converted).abs().max()

    print("\n" + "=" * 30)
    print(" COMPARISON RESULTS ")
    print("=" * 30)
    print(f"Mean Squared Error: {mse.item():.8f}")
    print(f"Max Absolute Error: {max_diff.item():.8f}")

    if max_diff.item() < 1e-3:
        print("\n✅ SUCCESS: The models are mathematically identical!")
    else:
        print("\n❌ FAILED: The models diverge. There is a bug in the math/weights.")

    # Fail the test on divergence (the fix: printing alone let CI pass).
    assert max_diff.item() < 1e-3, f"PyTorch/JAX outputs diverge: max abs diff {max_diff.item():.8f}"
74+
# Allow running the comparison as a standalone script; the test_ prefix also
# makes it discoverable by pytest.
if __name__ == "__main__":
    test_side_by_side()

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ names_which_can_be_saved: []
99
names_which_can_be_offloaded: []
1010
remat_policy: "NONE"
1111

# Keep the JAX compilation-cache dir unset in the shared default config;
# point it at a local disk via a per-run override instead of committing a
# machine-specific path (the diff committed '/mnt/disks/mehdy-disk1/...').
jax_cache_dir: ''
1313
weights_dtype: 'bfloat16'
1414
activations_dtype: 'bfloat16'
1515

@@ -92,3 +92,12 @@ jit_initializers: True
9292
enable_single_replica_ckpt_restoring: False
9393
seed: 0
9494
audio_format: "s16"
# LTX-2 Latent Upsampler
# Toggle for running the latent-upsampler stage of the pipeline.
run_latent_upsampler: False
# HF repo id or local path containing the `latent_upsampler` subfolder.
upsampler_model_path: "Lightricks/LTX-2"
upsampler_spatial_patch_size: 1
upsampler_temporal_patch_size: 1
# 0.0 presumably disables AdaIN blending / tone-map compression — confirm
# against the upsampler implementation's semantics.
upsampler_adain_factor: 0.0
upsampler_tone_map_compression_ratio: 0.0
# Spatial upscale factor applied by the upsampler (2.0 = double resolution
# — assumption from the name; verify against the model code).
upsampler_rational_spatial_scale: 2.0

0 commit comments

Comments
 (0)