@@ -51,12 +51,38 @@ def setUp(self):
5151 self .config = config
5252 devices_array = create_device_mesh (config )
5353 self .mesh = Mesh (devices_array , config .mesh_axes )
54+
55+ # Common dimensions from ltx2_parity_test.py
56+ self .batch_size = 1
57+ self .num_frames = 4
58+ self .height = 32
59+ self .width = 32
60+ self .patch_size = 1
61+ self .patch_size_t = 1
62+
63+ self .in_channels = 8
64+ self .out_channels = 8
65+ self .audio_in_channels = 4
66+
67+ # Derived: number of video tokens after temporal and spatial patchification
68+ self .seq_len = (self .num_frames // self .patch_size_t ) * (self .height // self .patch_size ) * (self .width // self .patch_size )
69+
70+ # Transformer config (matching parity test)
71+ self .dim = 1024
72+ self .num_heads = 8
73+ self .head_dim = 128
74+ self .cross_dim = 1024 # context dim
75+
76+ self .audio_dim = 1024
77+ self .audio_num_heads = 8
78+ self .audio_head_dim = 128
79+ self .audio_cross_dim = 1024
5480
5581 def test_ltx2_rope (self ):
5682 """Tests LTX2RotaryPosEmbed output shapes and basic functionality."""
57- dim = 64
58- patch_size = 1
59- patch_size_t = 1
83+ dim = self . dim
84+ patch_size = self . patch_size
85+ patch_size_t = self . patch_size_t
6086 base_num_frames = 8
6187 base_height = 32
6288 base_width = 32
@@ -91,15 +117,14 @@ def test_ltx2_rope(self):
91117 cos , sin = rope (ids )
92118
93119 # Check output shape
94- # dim=64, so output should be (1, 10, 64)
95- self .assertEqual (cos .shape , (1 , 10 , 64 ))
96- self .assertEqual (sin .shape , (1 , 10 , 64 ))
120+ self .assertEqual (cos .shape , (1 , 10 , dim ))
121+ self .assertEqual (sin .shape , (1 , 10 , dim ))
97122
98123 def test_ltx2_ada_layer_norm_single (self ):
99124 """Tests LTX2AdaLayerNormSingle initialization and execution."""
100125 key = jax .random .key (0 )
101126 rngs = nnx .Rngs (key )
102- embedding_dim = 128
127+ embedding_dim = self . dim
103128
104129 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
105130 layer = LTX2AdaLayerNormSingle (
@@ -110,7 +135,7 @@ def test_ltx2_ada_layer_norm_single(self):
110135 )
111136
112137 timestep = jnp .array ([1.0 ])
113- batch_size = 1
138+ batch_size = self . batch_size
114139
115140 # Forward
116141 output , embedded_timestep = layer (timestep )
@@ -125,33 +150,33 @@ def test_ltx2_transformer_block(self):
125150 key = jax .random .key (0 )
126151 rngs = nnx .Rngs (key )
127152
128- dim = 64
129- audio_dim = 32
130- cross_attention_dim = 128
131- audio_cross_attention_dim = 128 # usually same as context
153+ dim = self . dim
154+ audio_dim = self . audio_dim
155+ cross_attention_dim = self . cross_dim
156+ audio_cross_attention_dim = self . audio_cross_dim
132157
133158 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
134159 block = LTX2VideoTransformerBlock (
135160 rngs = rngs ,
136161 dim = dim ,
137- num_attention_heads = 4 ,
138- attention_head_dim = 16 ,
162+ num_attention_heads = self . num_heads ,
163+ attention_head_dim = self . head_dim ,
139164 cross_attention_dim = cross_attention_dim ,
140165 audio_dim = audio_dim ,
141- audio_num_attention_heads = 4 ,
142- audio_attention_head_dim = 8 ,
166+ audio_num_attention_heads = self . audio_num_heads ,
167+ audio_attention_head_dim = self . audio_head_dim ,
143168 audio_cross_attention_dim = audio_cross_attention_dim ,
144169 mesh = self .mesh
145170 )
146171
147- batch_size = 1
148- seq_len = 8
149- audio_seq_len = 4
172+ batch_size = self . batch_size
173+ seq_len = self . seq_len
174+ audio_seq_len = 128 # Matching parity test
150175
151176 hidden_states = jnp .zeros ((batch_size , seq_len , dim ))
152177 audio_hidden_states = jnp .zeros ((batch_size , audio_seq_len , audio_dim ))
153- encoder_hidden_states = jnp .zeros ((batch_size , 10 , cross_attention_dim ))
154- audio_encoder_hidden_states = jnp .zeros ((batch_size , 10 , audio_cross_attention_dim ))
178+ encoder_hidden_states = jnp .zeros ((batch_size , 128 , cross_attention_dim ))
179+ audio_encoder_hidden_states = jnp .zeros ((batch_size , 128 , audio_cross_attention_dim ))
155180
156181 # Mock modulation parameters
157182 # sizes based on `transformer_ltx2.py` logic
@@ -185,54 +210,54 @@ def test_ltx2_transformer_model(self):
185210 key = jax .random .key (0 )
186211 rngs = nnx .Rngs (key )
187212
188- in_channels = 128
189- out_channels = 128
190- audio_in_channels = 64
213+ in_channels = self . in_channels
214+ out_channels = self . out_channels
215+ audio_in_channels = self . audio_in_channels
191216
192217 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
193218 model = LTX2VideoTransformer3DModel (
194219 rngs = rngs ,
195220 in_channels = in_channels ,
196221 out_channels = out_channels ,
197- patch_size = 1 ,
198- patch_size_t = 1 ,
199- num_attention_heads = 4 ,
200- attention_head_dim = 16 ,
201- cross_attention_dim = 64 ,
202- caption_channels = 32 ,
222+ patch_size = self . patch_size ,
223+ patch_size_t = self . patch_size_t ,
224+ num_attention_heads = self . num_heads ,
225+ attention_head_dim = self . head_dim ,
226+ cross_attention_dim = self . cross_dim ,
227+ caption_channels = 32 , # kept small for now, or match parity if needed
203228 audio_in_channels = audio_in_channels ,
204229 audio_out_channels = audio_in_channels ,
205- audio_num_attention_heads = 4 ,
206- audio_attention_head_dim = 16 ,
207- audio_cross_attention_dim = 64 ,
230+ audio_num_attention_heads = self . audio_num_heads ,
231+ audio_attention_head_dim = self . audio_head_dim ,
232+ audio_cross_attention_dim = self . audio_cross_dim ,
208233 num_layers = 1 ,
209234 mesh = self .mesh ,
210235 attention_kernel = "dot_product" # Force dot_product for test stability on CPU/small config
211236 )
212237
213- batch_size = 1
214- seq_len = 8 # Flattened spatial-temporal tokens
215- audio_seq_len = 4
238+ batch_size = self . batch_size
239+ seq_len = self . seq_len
240+ audio_seq_len = 128
216241
217242 hidden_states = jnp .zeros ((batch_size , seq_len , in_channels ))
218243 audio_hidden_states = jnp .zeros ((batch_size , audio_seq_len , audio_in_channels ))
219244
220245 timestep = jnp .array ([1.0 ])
221- encoder_hidden_states = jnp .zeros ((batch_size , 10 , 32 )) # (B, L, D) match caption_channels
222- audio_encoder_hidden_states = jnp .zeros ((batch_size , 10 , 32 ))
246+ encoder_hidden_states = jnp .zeros ((batch_size , 128 , 32 )) # (B, L, D) match caption_channels
247+ audio_encoder_hidden_states = jnp .zeros ((batch_size , 128 , 32 ))
223248
224- encoder_attention_mask = jnp .ones ((batch_size , 10 ))
225- audio_encoder_attention_mask = jnp .ones ((batch_size , 10 ))
249+ encoder_attention_mask = jnp .ones ((batch_size , 128 ))
250+ audio_encoder_attention_mask = jnp .ones ((batch_size , 128 ))
226251
227252 output = model (
228253 hidden_states = hidden_states ,
229254 audio_hidden_states = audio_hidden_states ,
230255 encoder_hidden_states = encoder_hidden_states ,
231256 audio_encoder_hidden_states = audio_encoder_hidden_states ,
232257 timestep = timestep ,
233- num_frames = 2 ,
234- height = 2 ,
235- width = 2 ,
258+ num_frames = self . num_frames ,
259+ height = self . height ,
260+ width = self . width ,
236261 audio_num_frames = audio_seq_len ,
237262 encoder_attention_mask = encoder_attention_mask ,
238263 audio_encoder_attention_mask = audio_encoder_attention_mask ,
0 commit comments