Commit 6e66e84 (1 parent: 4aa2e35)

changed input dims in unit tests

1 file changed: src/maxdiffusion/tests/ltx_2_transformer_test.py (68 additions, 43 deletions)
```diff
@@ -51,12 +51,38 @@ def setUp(self):
     self.config = config
     devices_array = create_device_mesh(config)
     self.mesh = Mesh(devices_array, config.mesh_axes)
+
+    # Common dimensions from ltx2_parity_test.py
+    self.batch_size = 1
+    self.num_frames = 4
+    self.height = 32
+    self.width = 32
+    self.patch_size = 1
+    self.patch_size_t = 1
+
+    self.in_channels = 8
+    self.out_channels = 8
+    self.audio_in_channels = 4
+
+    # Derived
+    self.seq_len = (self.num_frames // self.patch_size_t) * (self.height // self.patch_size) * (self.width // self.patch_size)
+
+    # Transformer config (matching parity test)
+    self.dim = 1024
+    self.num_heads = 8
+    self.head_dim = 128
+    self.cross_dim = 1024  # context dim
+
+    self.audio_dim = 1024
+    self.audio_num_heads = 8
+    self.audio_head_dim = 128
+    self.audio_cross_dim = 1024
 
   def test_ltx2_rope(self):
     """Tests LTX2RotaryPosEmbed output shapes and basic functionality."""
-    dim = 64
-    patch_size = 1
-    patch_size_t = 1
+    dim = self.dim
+    patch_size = self.patch_size
+    patch_size_t = self.patch_size_t
     base_num_frames = 8
     base_height = 32
     base_width = 32
```
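
For reference, the derived `self.seq_len` above works out to 4096 tokens under the new dimensions, which is why the later tests stop hard-coding `seq_len = 8`. A minimal sketch re-running the same arithmetic outside the test class:

```python
# Re-deriving self.seq_len from the values assigned in setUp above.
num_frames, height, width = 4, 32, 32
patch_size, patch_size_t = 1, 1

seq_len = (num_frames // patch_size_t) * (height // patch_size) * (width // patch_size)
assert seq_len == 4 * 32 * 32 == 4096  # flattened spatio-temporal token count
```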
```diff
@@ -91,15 +117,14 @@ def test_ltx2_rope(self):
     cos, sin = rope(ids)
 
     # Check output shape
-    # dim=64, so output should be (1, 10, 64)
-    self.assertEqual(cos.shape, (1, 10, 64))
-    self.assertEqual(sin.shape, (1, 10, 64))
+    self.assertEqual(cos.shape, (1, 10, dim))
+    self.assertEqual(sin.shape, (1, 10, dim))
 
   def test_ltx2_ada_layer_norm_single(self):
     """Tests LTX2AdaLayerNormSingle initialization and execution."""
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
-    embedding_dim = 128
+    embedding_dim = self.dim
 
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       layer = LTX2AdaLayerNormSingle(
@@ -110,7 +135,7 @@ def test_ltx2_ada_layer_norm_single(self):
       )
 
     timestep = jnp.array([1.0])
-    batch_size = 1
+    batch_size = self.batch_size
 
     # Forward
     output, embedded_timestep = layer(timestep)
@@ -125,33 +150,33 @@ def test_ltx2_transformer_block(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
-    dim = 64
-    audio_dim = 32
-    cross_attention_dim = 128
-    audio_cross_attention_dim = 128  # usually same as context
+    dim = self.dim
+    audio_dim = self.audio_dim
+    cross_attention_dim = self.cross_dim
+    audio_cross_attention_dim = self.audio_cross_dim
 
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       block = LTX2VideoTransformerBlock(
           rngs=rngs,
           dim=dim,
-          num_attention_heads=4,
-          attention_head_dim=16,
+          num_attention_heads=self.num_heads,
+          attention_head_dim=self.head_dim,
           cross_attention_dim=cross_attention_dim,
           audio_dim=audio_dim,
-          audio_num_attention_heads=4,
-          audio_attention_head_dim=8,
+          audio_num_attention_heads=self.audio_num_heads,
+          audio_attention_head_dim=self.audio_head_dim,
           audio_cross_attention_dim=audio_cross_attention_dim,
           mesh=self.mesh
       )
 
-      batch_size = 1
-      seq_len = 8
-      audio_seq_len = 4
+      batch_size = self.batch_size
+      seq_len = self.seq_len
+      audio_seq_len = 128  # Matching parity test
 
       hidden_states = jnp.zeros((batch_size, seq_len, dim))
       audio_hidden_states = jnp.zeros((batch_size, audio_seq_len, audio_dim))
-      encoder_hidden_states = jnp.zeros((batch_size, 10, cross_attention_dim))
-      audio_encoder_hidden_states = jnp.zeros((batch_size, 10, audio_cross_attention_dim))
+      encoder_hidden_states = jnp.zeros((batch_size, 128, cross_attention_dim))
+      audio_encoder_hidden_states = jnp.zeros((batch_size, 128, audio_cross_attention_dim))
 
       # Mock modulation parameters
       # sizes based on `transformer_ltx2.py` logic
```
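
One property of the new block configuration worth noting: the inner attention width (heads × head_dim) equals the stream width for both the video and audio paths, just as the old values did (4 × 16 = 64 against dim=64, and 4 × 8 = 32 against audio_dim=32). A quick sanity check of the new numbers (an observation about this config, not necessarily a requirement of the layer):

```python
# New block config from setUp: inner attention width matches the stream width.
dim, num_heads, head_dim = 1024, 8, 128
audio_dim, audio_num_heads, audio_head_dim = 1024, 8, 128

assert num_heads * head_dim == dim                    # 8 * 128 == 1024
assert audio_num_heads * audio_head_dim == audio_dim  # 8 * 128 == 1024
```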
```diff
@@ -185,54 +210,54 @@ def test_ltx2_transformer_model(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
-    in_channels = 128
-    out_channels = 128
-    audio_in_channels = 64
+    in_channels = self.in_channels
+    out_channels = self.out_channels
+    audio_in_channels = self.audio_in_channels
 
     with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       model = LTX2VideoTransformer3DModel(
          rngs=rngs,
          in_channels=in_channels,
          out_channels=out_channels,
-         patch_size=1,
-         patch_size_t=1,
-         num_attention_heads=4,
-         attention_head_dim=16,
-         cross_attention_dim=64,
-         caption_channels=32,
+         patch_size=self.patch_size,
+         patch_size_t=self.patch_size_t,
+         num_attention_heads=self.num_heads,
+         attention_head_dim=self.head_dim,
+         cross_attention_dim=self.cross_dim,
+         caption_channels=32,  # kept small for now, or match parity if needed
          audio_in_channels=audio_in_channels,
          audio_out_channels=audio_in_channels,
-         audio_num_attention_heads=4,
-         audio_attention_head_dim=16,
-         audio_cross_attention_dim=64,
+         audio_num_attention_heads=self.audio_num_heads,
+         audio_attention_head_dim=self.audio_head_dim,
+         audio_cross_attention_dim=self.audio_cross_dim,
          num_layers=1,
          mesh=self.mesh,
          attention_kernel="dot_product"  # Force dot_product for test stability on CPU/small config
       )
 
-      batch_size = 1
-      seq_len = 8  # Flattened spatial-temporal tokens
-      audio_seq_len = 4
+      batch_size = self.batch_size
+      seq_len = self.seq_len
+      audio_seq_len = 128
 
       hidden_states = jnp.zeros((batch_size, seq_len, in_channels))
       audio_hidden_states = jnp.zeros((batch_size, audio_seq_len, audio_in_channels))
 
       timestep = jnp.array([1.0])
-      encoder_hidden_states = jnp.zeros((batch_size, 10, 32))  # (B, L, D) match caption_channels
-      audio_encoder_hidden_states = jnp.zeros((batch_size, 10, 32))
+      encoder_hidden_states = jnp.zeros((batch_size, 128, 32))  # (B, L, D) match caption_channels
+      audio_encoder_hidden_states = jnp.zeros((batch_size, 128, 32))
 
-      encoder_attention_mask = jnp.ones((batch_size, 10))
-      audio_encoder_attention_mask = jnp.ones((batch_size, 10))
+      encoder_attention_mask = jnp.ones((batch_size, 128))
+      audio_encoder_attention_mask = jnp.ones((batch_size, 128))
 
       output = model(
           hidden_states=hidden_states,
           audio_hidden_states=audio_hidden_states,
           encoder_hidden_states=encoder_hidden_states,
           audio_encoder_hidden_states=audio_encoder_hidden_states,
           timestep=timestep,
-          num_frames=2,
-          height=2,
-          width=2,
+          num_frames=self.num_frames,
+          height=self.height,
+          width=self.width,
           audio_num_frames=audio_seq_len,
           encoder_attention_mask=encoder_attention_mask,
           audio_encoder_attention_mask=audio_encoder_attention_mask,
```
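
Taken together, the model test now drives the forward pass with the following input shapes (a sketch restating the new setUp values; shapes only, no model required):

```python
import jax.numpy as jnp

batch_size, in_channels, audio_in_channels = 1, 8, 4
seq_len, audio_seq_len = 4096, 128  # 4096 = (4/1) * (32/1) * (32/1)

hidden_states = jnp.zeros((batch_size, seq_len, in_channels))                    # (1, 4096, 8)
audio_hidden_states = jnp.zeros((batch_size, audio_seq_len, audio_in_channels))  # (1, 128, 4)
encoder_hidden_states = jnp.zeros((batch_size, 128, 32))  # last dim matches caption_channels=32
encoder_attention_mask = jnp.ones((batch_size, 128))      # one mask entry per context token

print(hidden_states.shape, audio_hidden_states.shape)  # (1, 4096, 8) (1, 128, 4)
```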
