Skip to content

Commit bb7e8f0

Browse files
committed
fix
1 parent 8722641 commit bb7e8f0

1 file changed

Lines changed: 36 additions & 13 deletions

File tree

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 36 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,28 @@ def test_transformer_3d_model_instantiation_and_forward(self):
168168
"""
169169
print("\n=== Testing LTX2VideoTransformer3DModel Integration ===")
170170

171+
# NNX sharding context
172+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
173+
model = LTX2VideoTransformer3DModel(
174+
rngs=self.rngs,
175+
in_channels=self.in_channels,
176+
out_channels=self.out_channels,
177+
patch_size=self.patch_size,
178+
patch_size_t=self.patch_size_t,
179+
num_attention_heads=2,
180+
attention_head_dim=8,
181+
num_layers=1, # 1 layer for speed
182+
caption_channels=32, # small for test
183+
cross_attention_dim=32,
184+
audio_in_channels=self.audio_in_channels,
185+
audio_out_channels= self.audio_in_channels,
186+
audio_num_attention_heads=2,
187+
audio_attention_head_dim=8,
188+
audio_cross_attention_dim=32
189+
)
190+
191+
# Inputs
192+
# hidden_states: (B, F, H, W, C) or (B, L, C)?
171193
# Diffusers `forward` takes `hidden_states` usually as latents.
172194
# If it's 3D, it might expect (B, C, F, H, W) or (B, F, C, H, W)?
173195
# Checking `transformer_ltx2.py` `__call__` Line 680:
@@ -202,19 +224,20 @@ def test_transformer_3d_model_instantiation_and_forward(self):
202224
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 5, 32))
203225

204226
# Forward
205-
output = model(
206-
hidden_states=hidden_states,
207-
audio_hidden_states=audio_hidden_states,
208-
encoder_hidden_states=encoder_hidden_states,
209-
audio_encoder_hidden_states=audio_encoder_hidden_states,
210-
timestep=timestep,
211-
num_frames=self.num_frames,
212-
height=self.height,
213-
width=self.width,
214-
audio_num_frames=10,
215-
fps=24.0,
216-
return_dict=True
217-
)
227+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
228+
output = model(
229+
hidden_states=hidden_states,
230+
audio_hidden_states=audio_hidden_states,
231+
encoder_hidden_states=encoder_hidden_states,
232+
audio_encoder_hidden_states=audio_encoder_hidden_states,
233+
timestep=timestep,
234+
num_frames=self.num_frames,
235+
height=self.height,
236+
width=self.width,
237+
audio_num_frames=10,
238+
fps=24.0,
239+
return_dict=True
240+
)
218241

219242
sample = output["sample"]
220243
audio_sample = output["audio_sample"]

0 commit comments

Comments
 (0)