@@ -168,6 +168,28 @@ def test_transformer_3d_model_instantiation_and_forward(self):
168168 """
169169 print ("\n === Testing LTX2VideoTransformer3DModel Integration ===" )
170170
171+ # NNX sharding context
172+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
173+ model = LTX2VideoTransformer3DModel (
174+ rngs = self .rngs ,
175+ in_channels = self .in_channels ,
176+ out_channels = self .out_channels ,
177+ patch_size = self .patch_size ,
178+ patch_size_t = self .patch_size_t ,
179+ num_attention_heads = 2 ,
180+ attention_head_dim = 8 ,
181+ num_layers = 1 , # 1 layer for speed
182+ caption_channels = 32 , # small for test
183+ cross_attention_dim = 32 ,
184+ audio_in_channels = self .audio_in_channels ,
185+ audio_out_channels = self .audio_in_channels ,
186+ audio_num_attention_heads = 2 ,
187+ audio_attention_head_dim = 8 ,
188+ audio_cross_attention_dim = 32
189+ )
190+
191+ # Inputs
192+ # hidden_states: (B, F, H, W, C) or (B, L, C)?
171193 # Diffusers' `forward` usually takes `hidden_states` as latents.
172194 # If it's 3D, it might expect (B, C, F, H, W) or (B, F, C, H, W)?
173195 # Checking `transformer_ltx2.py` `__call__` Line 680:
@@ -202,19 +224,20 @@ def test_transformer_3d_model_instantiation_and_forward(self):
202224 audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , 32 ))
203225
204226 # Forward
205- output = model (
206- hidden_states = hidden_states ,
207- audio_hidden_states = audio_hidden_states ,
208- encoder_hidden_states = encoder_hidden_states ,
209- audio_encoder_hidden_states = audio_encoder_hidden_states ,
210- timestep = timestep ,
211- num_frames = self .num_frames ,
212- height = self .height ,
213- width = self .width ,
214- audio_num_frames = 10 ,
215- fps = 24.0 ,
216- return_dict = True
217- )
227+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
228+ output = model (
229+ hidden_states = hidden_states ,
230+ audio_hidden_states = audio_hidden_states ,
231+ encoder_hidden_states = encoder_hidden_states ,
232+ audio_encoder_hidden_states = audio_encoder_hidden_states ,
233+ timestep = timestep ,
234+ num_frames = self .num_frames ,
235+ height = self .height ,
236+ width = self .width ,
237+ audio_num_frames = 10 ,
238+ fps = 24.0 ,
239+ return_dict = True
240+ )
218241
219242 sample = output ["sample" ]
220243 audio_sample = output ["audio_sample" ]
0 commit comments