Commit c1446ba

removed unused arg qk_norm
1 parent 04a31f0 commit c1446ba

2 files changed: 1 addition & 5 deletions

src/maxdiffusion/models/ltx2/transformer_ltx2.py (0 additions, 5 deletions)
@@ -72,7 +72,6 @@ def __init__(
     audio_attention_head_dim: int,
     audio_cross_attention_dim: int,
     activation_fn: str = "gelu",
-    qk_norm: str = "rms_norm_across_heads",
     attention_bias: bool = True,
     attention_out_bias: bool = True,
     norm_elementwise_affine: bool = False,

@@ -435,7 +434,6 @@ def __init__(
     audio_hop_length: int = 160,
     num_layers: int = 48,  # Shared arguments
     activation_fn: str = "gelu",
-    qk_norm: str = "rms_norm_across_heads",
     norm_elementwise_affine: bool = False,
     norm_eps: float = 1e-6,
     caption_channels: int = 3840,

@@ -481,7 +479,6 @@ def __init__(
     self.audio_hop_length = audio_hop_length
     self.num_layers = num_layers
     self.activation_fn = activation_fn
-    self.qk_norm = qk_norm
     self.norm_elementwise_affine = norm_elementwise_affine
     self.norm_eps = norm_eps
     self.caption_channels = caption_channels

@@ -644,7 +641,6 @@ def init_block(rngs):
       audio_attention_head_dim=self.audio_attention_head_dim,
       audio_cross_attention_dim=audio_inner_dim,
       activation_fn=self.activation_fn,
-      qk_norm=self.qk_norm,
       attention_bias=self.attention_bias,
       attention_out_bias=self.attention_out_bias,
       norm_elementwise_affine=self.norm_elementwise_affine,

@@ -676,7 +672,6 @@ def init_block(rngs):
       audio_attention_head_dim=self.audio_attention_head_dim,
       audio_cross_attention_dim=audio_inner_dim,
       activation_fn=self.activation_fn,
-      qk_norm=self.qk_norm,
       attention_bias=self.attention_bias,
       attention_out_bias=self.attention_out_bias,
       norm_elementwise_affine=self.norm_elementwise_affine,
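Background: a qk_norm value such as "rms_norm_across_heads" conventionally selects RMS normalization of the query and key projections before the attention dot product (QK-norm). The sketch below illustrates that general technique in JAX; it is an illustration only, not maxdiffusion's implementation, and the function names and shapes in it are assumptions.

import jax
import jax.numpy as jnp

def rms_norm(x, eps=1e-6):
  # Root-mean-square normalization over the last axis, with no learned scale.
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)

def qk_normed_scores(q, k):
  # q, k: [batch, heads, seq, head_dim]. Normalizing q and k before the dot
  # product bounds the attention logits and tends to stabilize training.
  q = rms_norm(q)
  k = rms_norm(k)
  return jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])

Since the commit message describes the argument as unused, dropping it from the constructors above should leave the module's behavior unchanged.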

src/maxdiffusion/tests/ltx2_parity_test.py (1 addition, 0 deletions)
@@ -86,6 +86,7 @@ def test_transformer_block_shapes(self):
       audio_cross_attention_dim=cross_dim,
       activation_fn="gelu",
       qk_norm="rms_norm_across_heads",
+      qk_norm="rms_norm_across_heads",
       mesh=self.mesh,
   )