Skip to content

Commit ea84281

Browse files
committed
Test added for the rope_type='split' parity path.
1 parent 1fd5338 commit ea84281

1 file changed

Lines changed: 237 additions & 0 deletions

File tree

src/maxdiffusion/tests/ltx2_parity_test.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,5 +596,242 @@ def convert_weight(pt_key_base, jax_key):
596596
print(f"Audio Std: {jnp.std(max_audio_sample)}")
597597

598598

599+
def test_import_parity_comparison_split(self):
    """
    Verifies that the LTX2VideoTransformer3DModel output matches the PyTorch implementation output
    exactly (to a high precision) given the same inputs and weights, with rope_type='split'.

    Requires a parity fixture file (``ltx2_parity_data_split.pt``) produced by the diffusers-side
    export script; the test is skipped when the fixture is absent.
    """
    from flax import traverse_util

    parity_file = "ltx2_parity_data_split.pt"
    if not os.path.exists(parity_file):
        # Use skipTest (not a bare return) so the runner reports SKIPPED rather than
        # a misleading PASS when the fixture is missing.
        self.skipTest(f"Skipping parity test: {parity_file} not found. Run diffusers test first.")

    print(f"Loading {parity_file}...")
    # map_location="cpu" lets the fixture load even if it was saved from a GPU run;
    # every tensor is explicitly moved via .cpu() before use anyway.
    parity_data = torch.load(parity_file, map_location="cpu")
    state_dict = parity_data["state_dict"]
    inputs = parity_data["inputs"]
    torch_outputs = parity_data["outputs"]
    config = parity_data["config"]

    # 1. Instantiate Model
    # Ensure config matches what was exported.
    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
        model = LTX2VideoTransformer3DModel(
            rngs=nnx.Rngs(0),
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            patch_size=config["patch_size"],
            patch_size_t=1,
            num_attention_heads=8,
            attention_head_dim=128,
            cross_attention_dim=1024,  # Parity config
            caption_channels=config["caption_channels"],
            audio_in_channels=4,
            audio_out_channels=4,
            audio_patch_size=1,
            audio_patch_size_t=1,
            audio_num_attention_heads=8,
            audio_attention_head_dim=128,
            audio_cross_attention_dim=1024,
            num_layers=1,
            mesh=self.mesh,
            attention_kernel="dot_product",
            rope_type="split",
        )

    # 2. Convert Weights (PyTorch -> Flax NNX)
    print("Converting weights (Split)...")

    # graph_def is not needed for nnx.update below, so discard it.
    _, state = nnx.split(model)
    flat_state = traverse_util.flatten_dict(state.to_pure_dict())
    new_flat_state = {}

    # Helper to convert/transpose weights.
    # Returns (jnp array or None, PT key actually attempted) so the caller can log failures.
    def convert_weight(pt_key_base, jax_key):
        # Try original key first
        pt_key = pt_key_base

        # Map JAX 'kernel' to PT 'weight'
        if "kernel" in str(jax_key):
            pt_key = pt_key.replace("kernel", "weight")

        # Fix scale logic (RMSNorm)
        # Only replace 'scale' if it's the parameter name (last part) to avoid breaking 'scale_shift'
        if jax_key[-1] == "scale" and "scale_shift" not in str(jax_key):
            pt_key = pt_key.replace("scale", "weight")

        # Fix transformer_blocks prefix
        # JAX: ('transformer_blocks', 'attn1', ...)
        # PT:  transformer_blocks.0.attn1...
        is_transformer_block = "transformer_blocks" in str(jax_key)
        if is_transformer_block:
            if "transformer_blocks" in pt_key and "transformer_blocks.0" not in pt_key:
                pt_key = pt_key.replace("transformer_blocks", "transformer_blocks.0")

        # Strip the scan/vmap 'layers.' segment that JAX inserts but PT keys never contain.
        if "layers" in pt_key:
            pt_key = pt_key.replace("layers.", "")

        # Fix to_out (Diffusers has to_out[0] as Linear)
        if "to_out" in pt_key and ("weight" in pt_key or "bias" in pt_key):
            pt_key = pt_key.replace("to_out.weight", "to_out.0.weight")
            pt_key = pt_key.replace("to_out.bias", "to_out.0.bias")

        # Fix FeedForward (net_0 -> net.0.proj, net_2 -> net.2)
        if "net_0" in pt_key:
            pt_key = pt_key.replace("net_0", "net.0.proj")
        if "net_2" in pt_key:
            pt_key = pt_key.replace("net_2", "net.2")

        if pt_key not in state_dict:
            # Try removing .0 if it was added erroneously
            candidates = [pt_key]
            if "transformer_blocks.0" in pt_key:
                candidates.append(pt_key.replace("transformer_blocks.0", "transformer_blocks"))

            # Special Case: scale_shift_table
            # Only allow global scale_shift_table fallback if NOT inside transformer block
            if "scale_shift_table" in str(jax_key) and not is_transformer_block:
                candidates.append("scale_shift_table")

            if "audio_scale_shift_table" in str(jax_key) and not is_transformer_block:
                candidates.append("audio_scale_shift_table")

            for c in candidates:
                if c in state_dict:
                    pt_key = c
                    break
            else:
                # If unmapped bias, maybe it's just missing in PT (e.g. RMSNorm without bias)
                if "bias" in str(jax_key):
                    print(f"Warning: Missing PT bias for {jax_key}. initializing to zeros.")
                    # Use shape from current flat_state param
                    return jnp.zeros(flat_state[jax_key].shape), pt_key

                return None, pt_key

        w = state_dict[pt_key].cpu().numpy()

        # Debug Special Parameters
        if "scale_shift_table" in str(jax_key):
            print(f"Mapping scale_shift_table for {jax_key} from {pt_key} with shape {w.shape}")

        # Handle vmap/scan dimension for transformer_blocks:
        # JAX expects (num_layers, ...) for these weights, PT has (...), so expand dims.
        if is_transformer_block:
            w = w[None, ...]

        # Handle Transforms.
        # Embedding projections are also 'kernel' in JAX (Linear).
        is_kernel = "kernel" in str(jax_key)
        if is_kernel:
            if w.ndim == 3:  # (1, out, in) -> (1, in, out)
                w = w.transpose(0, 2, 1)
            elif w.ndim == 2:  # (out, in) -> (in, out)
                w = w.T

        return jnp.array(w), pt_key

    total_count = len(flat_state)
    mapped_count = 0

    # Debug: Print available keys for audio_ff
    print("Debugging PT keys for mapping failure diagnosis (Split):")
    print("Available PT keys with 'ff':", [k for k in state_dict.keys() if "ff" in k])
    print("Available PT keys with 'norm':", [k for k in state_dict.keys() if "norm" in k])

    for key in flat_state.keys():
        # Construct base PT key from JAX key tuple
        pt_key_base = ".".join([str(k) for k in key if str(k) != "layers"])

        w, used_pt_key = convert_weight(pt_key_base, key)
        if w is not None:
            # Handle bias zero init which might return scalar 0 if shape was (1,) but it should be array
            # jnp.zeros(shape) returns array.
            new_flat_state[key] = w
            mapped_count += 1
        else:
            print(f"Warning: Could not map JAX key {key} (PT attempt: {used_pt_key})")
            if "audio_ff" in str(key):
                print("Available audio_ff keys:", [k for k in state_dict.keys() if "audio_ff" in k])
            if "norm_out" in str(key):
                print("Available norm_out keys:", [k for k in state_dict.keys() if "norm_out" in k])

    print(f"Mapped {mapped_count}/{total_count} params.")

    # Update model state
    new_state = traverse_util.unflatten_dict(new_flat_state)
    nnx.update(model, new_state)

    # 3. Prepare Inputs
    jax_inputs = {
        "hidden_states": jnp.array(inputs["hidden_states"].cpu().numpy()),
        "audio_hidden_states": jnp.array(inputs["audio_hidden_states"].cpu().numpy()),
        "encoder_hidden_states": jnp.array(inputs["encoder_hidden_states"].cpu().numpy()),
        "audio_encoder_hidden_states": jnp.array(inputs["audio_encoder_hidden_states"].cpu().numpy()),
        "timestep": jnp.array(inputs["timestep"].cpu().numpy()),
        "encoder_attention_mask": jnp.array(inputs["encoder_attention_mask"].cpu().numpy()),
        "audio_encoder_attention_mask": jnp.array(inputs["audio_encoder_attention_mask"].cpu().numpy()),
    }

    print("\n=== Input Verification (Split) ===")
    print(f"Hidden States Sum: {jnp.sum(jax_inputs['hidden_states'])}")
    print(f"Audio Hidden States Sum: {jnp.sum(jax_inputs['audio_hidden_states'])}")
    print(f"Encoder Hidden States Sum: {jnp.sum(jax_inputs['encoder_hidden_states'])}")
    print(f"Audio Encoder Hidden States Sum: {jnp.sum(jax_inputs['audio_encoder_hidden_states'])}")
    print(f"Timestep: {jax_inputs['timestep']}")
    print("==========================\n")

    # 4. Run Forward
    print("Running MaxDiffusion forward pass (Split)...")
    output = model(
        hidden_states=jax_inputs["hidden_states"],
        audio_hidden_states=jax_inputs["audio_hidden_states"],
        encoder_hidden_states=jax_inputs["encoder_hidden_states"],
        audio_encoder_hidden_states=jax_inputs["audio_encoder_hidden_states"],
        timestep=jax_inputs["timestep"],
        encoder_attention_mask=jax_inputs["encoder_attention_mask"],
        audio_encoder_attention_mask=jax_inputs["audio_encoder_attention_mask"],
        # .get() with defaults: older fixtures may predate these config entries.
        num_frames=config.get("num_frames", 4),
        height=config.get("height", 32),
        width=config.get("width", 32),
        audio_num_frames=128,
        fps=24.0,
        return_dict=True,
    )

    max_sample = output["sample"]
    max_audio_sample = output["audio_sample"]

    print("MAXDIFF Output Sample Stats (Split):")
    print(f"Sample Max: {jnp.max(max_sample)}")
    print(f"Sample Min: {jnp.min(max_sample)}")
    print(f"Sample Mean: {jnp.mean(max_sample)}")
    print(f"Sample Std: {jnp.std(max_sample)}")

    print("MAXDIFF Output Audio Sample Stats (Split):")
    print(f"Audio Max: {jnp.max(max_audio_sample)}")
    print(f"Audio Min: {jnp.min(max_audio_sample)}")
    print(f"Audio Mean: {jnp.mean(max_audio_sample)}")
    print(f"Audio Std: {jnp.std(max_audio_sample)}")

    # 5. Parity Check
    parity_sample = jnp.array(torch_outputs["sample"].cpu().numpy())
    parity_audio_sample = jnp.array(torch_outputs["audio_sample"].cpu().numpy())

    print("Checking Parity (Split)...")
    print(f"Max Diff Sample: {jnp.max(jnp.abs(max_sample - parity_sample))}")
    print(f"Max Diff Audio: {jnp.max(jnp.abs(max_audio_sample - parity_audio_sample))}")

    self.assertTrue(jnp.allclose(max_sample, parity_sample, atol=1e-3))
    self.assertTrue(jnp.allclose(max_audio_sample, parity_audio_sample, atol=1e-3))
    print("Parity check passed (Split)!")
834+
835+
599836
# Allow running this parity test module directly (e.g. `python ltx2_parity_test.py`).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)