Skip to content

Commit df8a5fc

Browse files
committed
testing with rope_type = split
1 parent c4dcd64 commit df8a5fc

File tree

1 file changed

+1
-245
lines changed

1 file changed

+1
-245
lines changed

src/maxdiffusion/tests/ltx2_parity_test.py

Lines changed: 1 addition & 245 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ def test_import_parity_comparison(self):
415415
num_layers=1,
416416
mesh=self.mesh,
417417
attention_kernel="dot_product",
418+
rope_type="split"
418419
)
419420

420421
# 2. Convert Weights (PyTorch -> Flax NNX)
@@ -597,250 +598,5 @@ def convert_weight(pt_key_base, jax_key):
597598
print(f"Audio Std: {jnp.std(max_audio_sample)}")
598599

599600

600-
def test_import_parity_comparison_split(self):
    """
    Verifies that the LTX2VideoTransformer3DModel output matches the PyTorch
    implementation output exactly (to a high precision) given the same inputs
    and weights, with rope_type='split'.

    Requires `torch` and a pre-exported parity file
    ``ltx2_parity_data_split.pt`` (produced by the diffusers-side test);
    otherwise the test is reported as skipped.
    """
    print("\n=== Parity Comparison Test (Split) ===")
    try:
        import torch
    except ImportError:
        # Use skipTest so the runner reports SKIPPED instead of a silent PASS.
        self.skipTest("torch not installed.")

    import os
    from flax import traverse_util

    parity_file = "ltx2_parity_data_split.pt"
    if not os.path.exists(parity_file):
        self.skipTest(f"{parity_file} not found. Run diffusers test first.")

    print(f"Loading {parity_file}...")
    # map_location="cpu" so a CUDA-saved export still loads on a CPU-only
    # host; every tensor is moved to CPU before use below anyway.
    parity_data = torch.load(parity_file, map_location="cpu")
    state_dict = parity_data["state_dict"]
    inputs = parity_data["inputs"]
    torch_outputs = parity_data["outputs"]
    config = parity_data["config"]

    # 1. Instantiate Model
    # Ensure config matches what was exported.
    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
        model = LTX2VideoTransformer3DModel(
            rngs=nnx.Rngs(0),
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            patch_size=config["patch_size"],
            patch_size_t=1,
            num_attention_heads=8,
            attention_head_dim=128,
            cross_attention_dim=1024,  # Parity config
            caption_channels=config["caption_channels"],
            audio_in_channels=4,
            audio_out_channels=4,
            audio_patch_size=1,
            audio_patch_size_t=1,
            audio_num_attention_heads=8,
            audio_attention_head_dim=128,
            audio_cross_attention_dim=1024,
            num_layers=1,
            mesh=self.mesh,
            attention_kernel="dot_product",
            rope_type="split",
        )

    # 2. Convert Weights (PyTorch -> Flax NNX)
    print("Converting weights (Split)...")

    _, state = nnx.split(model)  # graph definition is not needed here
    flat_state = traverse_util.flatten_dict(state.to_pure_dict())
    new_flat_state = {}

    def convert_weight(pt_key_base, jax_key):
        """Map one flattened JAX param key to the matching PyTorch tensor.

        Returns (converted_array | None, pt_key_attempted). None means the
        parameter could not be mapped.
        """
        pt_key = pt_key_base

        # Map JAX 'kernel' to PT 'weight'.
        if "kernel" in str(jax_key):
            pt_key = pt_key.replace("kernel", "weight")

        # RMSNorm: JAX calls the learned gain 'scale', PT calls it 'weight'.
        # Only replace when 'scale' is the parameter name (last path element)
        # to avoid mangling 'scale_shift' tables.
        if jax_key[-1] == "scale" and "scale_shift" not in str(jax_key):
            pt_key = pt_key.replace("scale", "weight")

        # Transformer-block prefix:
        #   JAX: ('transformer_blocks', 'attn1', ...)
        #   PT:  transformer_blocks.0.attn1...
        is_transformer_block = "transformer_blocks" in str(jax_key)
        if is_transformer_block:
            if "transformer_blocks" in pt_key and "transformer_blocks.0" not in pt_key:
                pt_key = pt_key.replace("transformer_blocks", "transformer_blocks.0")

        # Drop the scan/vmap 'layers.' segment inserted by the JAX module tree.
        if "layers" in pt_key:
            pt_key = pt_key.replace("layers.", "")

        # Diffusers wraps the attention output Linear as to_out[0].
        if "to_out" in pt_key and ("weight" in pt_key or "bias" in pt_key):
            pt_key = pt_key.replace("to_out.weight", "to_out.0.weight")
            pt_key = pt_key.replace("to_out.bias", "to_out.0.bias")

        # FeedForward naming: net_0 -> net.0.proj, net_2 -> net.2.
        if "net_0" in pt_key:
            pt_key = pt_key.replace("net_0", "net.0.proj")
        if "net_2" in pt_key:
            pt_key = pt_key.replace("net_2", "net.2")

        if pt_key not in state_dict:
            # Fallback candidates: undo the .0 insertion, or fall back to the
            # top-level scale-shift tables.
            candidates = [pt_key]
            if "transformer_blocks.0" in pt_key:
                candidates.append(pt_key.replace("transformer_blocks.0", "transformer_blocks"))

            # Only allow the global scale_shift_table fallback outside
            # transformer blocks.
            # NOTE(review): "scale_shift_table" is a substring of
            # "audio_scale_shift_table", so an audio table key appends BOTH
            # candidates and the video table wins the first-match loop below —
            # confirm this ordering is intended.
            if "scale_shift_table" in str(jax_key) and not is_transformer_block:
                candidates.append("scale_shift_table")
            if "audio_scale_shift_table" in str(jax_key) and not is_transformer_block:
                candidates.append("audio_scale_shift_table")

            for c in candidates:
                if c in state_dict:
                    pt_key = c
                    break
            else:
                # An unmapped bias may simply not exist in PT (e.g. RMSNorm
                # without bias) — initialize it to zeros of the JAX shape.
                if "bias" in str(jax_key):
                    print(f"Warning: Missing PT bias for {jax_key}. initializing to zeros.")
                    return jnp.zeros(flat_state[jax_key].shape), pt_key
                return None, pt_key

        w = state_dict[pt_key].cpu().numpy()

        # Debug special parameters.
        if "scale_shift_table" in str(jax_key):
            print(f"Mapping scale_shift_table for {jax_key} from {pt_key} with shape {w.shape}")

        # Scan/vmap dimension: JAX expects (num_layers, ...) for block
        # weights, PT stores (...) — so prepend a leading axis.
        if is_transformer_block:
            w = w[None, ...]

        # Linear kernels: PT is (out, in); JAX expects (in, out).
        if "kernel" in str(jax_key):
            if w.ndim == 3:  # (1, out, in) -> (1, in, out)
                w = w.transpose(0, 2, 1)
            elif w.ndim == 2:  # (out, in) -> (in, out)
                w = w.T

        return jnp.array(w), pt_key

    total_count = len(flat_state)
    mapped_count = 0

    # Debug: print available keys for mapping-failure diagnosis.
    print("Debugging PT keys for mapping failure diagnosis (Split):")
    print("Available PT keys with 'ff':", [k for k in state_dict.keys() if "ff" in k])
    print("Available PT keys with 'norm':", [k for k in state_dict.keys() if "norm" in k])

    for key in flat_state.keys():
        # Construct the base PT key from the JAX key tuple, dropping the
        # synthetic 'layers' scan segment.
        pt_key_base = ".".join([str(k) for k in key if str(k) != "layers"])

        w, used_pt_key = convert_weight(pt_key_base, key)
        if w is not None:
            new_flat_state[key] = w
            mapped_count += 1
        else:
            print(f"Warning: Could not map JAX key {key} (PT attempt: {used_pt_key})")
            if "audio_ff" in str(key):
                print("Available audio_ff keys:", [k for k in state_dict.keys() if "audio_ff" in k])
            if "norm_out" in str(key):
                print("Available norm_out keys:", [k for k in state_dict.keys() if "norm_out" in k])

    print(f"Mapped {mapped_count}/{total_count} params.")

    # Update model state with the converted weights.
    new_state = traverse_util.unflatten_dict(new_flat_state)
    nnx.update(model, new_state)

    # 3. Prepare Inputs
    jax_inputs = {
        "hidden_states": jnp.array(inputs["hidden_states"].cpu().numpy()),
        "audio_hidden_states": jnp.array(inputs["audio_hidden_states"].cpu().numpy()),
        "encoder_hidden_states": jnp.array(inputs["encoder_hidden_states"].cpu().numpy()),
        "audio_encoder_hidden_states": jnp.array(inputs["audio_encoder_hidden_states"].cpu().numpy()),
        "timestep": jnp.array(inputs["timestep"].cpu().numpy()),
        "encoder_attention_mask": jnp.array(inputs["encoder_attention_mask"].cpu().numpy()),
        "audio_encoder_attention_mask": jnp.array(inputs["audio_encoder_attention_mask"].cpu().numpy()),
    }

    print("\n=== Input Verification (Split) ===")
    print(f"Hidden States Sum: {jnp.sum(jax_inputs['hidden_states'])}")
    print(f"Audio Hidden States Sum: {jnp.sum(jax_inputs['audio_hidden_states'])}")
    print(f"Encoder Hidden States Sum: {jnp.sum(jax_inputs['encoder_hidden_states'])}")
    print(f"Audio Encoder Hidden States Sum: {jnp.sum(jax_inputs['audio_encoder_hidden_states'])}")
    print(f"Timestep: {jax_inputs['timestep']}")
    print("==========================\n")

    # 4. Run Forward
    print("Running MaxDiffusion forward pass (Split)...")
    output = model(
        hidden_states=jax_inputs["hidden_states"],
        audio_hidden_states=jax_inputs["audio_hidden_states"],
        encoder_hidden_states=jax_inputs["encoder_hidden_states"],
        audio_encoder_hidden_states=jax_inputs["audio_encoder_hidden_states"],
        timestep=jax_inputs["timestep"],
        encoder_attention_mask=jax_inputs["encoder_attention_mask"],
        audio_encoder_attention_mask=jax_inputs["audio_encoder_attention_mask"],
        num_frames=config.get("num_frames", 4),
        height=config.get("height", 32),
        width=config.get("width", 32),
        audio_num_frames=128,
        fps=24.0,
        return_dict=True,
    )

    max_sample = output["sample"]
    max_audio_sample = output["audio_sample"]

    print("MAXDIFF Output Sample Stats (Split):")
    print(f"Sample Max: {jnp.max(max_sample)}")
    print(f"Sample Min: {jnp.min(max_sample)}")
    print(f"Sample Mean: {jnp.mean(max_sample)}")
    print(f"Sample Std: {jnp.std(max_sample)}")

    print("MAXDIFF Output Audio Sample Stats (Split):")
    print(f"Audio Max: {jnp.max(max_audio_sample)}")
    print(f"Audio Min: {jnp.min(max_audio_sample)}")
    print(f"Audio Mean: {jnp.mean(max_audio_sample)}")
    print(f"Audio Std: {jnp.std(max_audio_sample)}")

    # 5. Parity Check
    parity_sample = jnp.array(torch_outputs["sample"].cpu().numpy())
    parity_audio_sample = jnp.array(torch_outputs["audio_sample"].cpu().numpy())

    print("Checking Parity (Split)...")
    print(f"Max Diff Sample: {jnp.max(jnp.abs(max_sample - parity_sample))}")
    print(f"Max Diff Audio: {jnp.max(jnp.abs(max_audio_sample - parity_audio_sample))}")

    self.assertTrue(jnp.allclose(max_sample, parity_sample, atol=1e-3))
    self.assertTrue(jnp.allclose(max_audio_sample, parity_audio_sample, atol=1e-3))
    print("Parity check passed (Split)!")
845601
if __name__ == "__main__":
846602
unittest.main()

0 commit comments

Comments
 (0)