@@ -72,7 +72,6 @@ def __init__(
         audio_attention_head_dim: int,
         audio_cross_attention_dim: int,
         activation_fn: str = "gelu",
-        qk_norm: str = "rms_norm_across_heads",
         attention_bias: bool = True,
         attention_out_bias: bool = True,
         norm_elementwise_affine: bool = False,
@@ -435,7 +434,6 @@ def __init__(
         audio_hop_length: int = 160,
         num_layers: int = 48,  # Shared arguments
         activation_fn: str = "gelu",
-        qk_norm: str = "rms_norm_across_heads",
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
         caption_channels: int = 3840,
@@ -481,7 +479,6 @@ def __init__(
         self.audio_hop_length = audio_hop_length
         self.num_layers = num_layers
         self.activation_fn = activation_fn
-        self.qk_norm = qk_norm
         self.norm_elementwise_affine = norm_elementwise_affine
         self.norm_eps = norm_eps
         self.caption_channels = caption_channels
@@ -644,7 +641,6 @@ def init_block(rngs):
                 audio_attention_head_dim=self.audio_attention_head_dim,
                 audio_cross_attention_dim=audio_inner_dim,
                 activation_fn=self.activation_fn,
-                qk_norm=self.qk_norm,
                 attention_bias=self.attention_bias,
                 attention_out_bias=self.attention_out_bias,
                 norm_elementwise_affine=self.norm_elementwise_affine,
@@ -676,7 +672,6 @@ def init_block(rngs):
                 audio_attention_head_dim=self.audio_attention_head_dim,
                 audio_cross_attention_dim=audio_inner_dim,
                 activation_fn=self.activation_fn,
-                qk_norm=self.qk_norm,
                 attention_bias=self.attention_bias,
                 attention_out_bias=self.attention_out_bias,
                 norm_elementwise_affine=self.norm_elementwise_affine,
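
Note: the hunks above remove the `qk_norm` option (default `"rms_norm_across_heads"`) from the block and transformer constructors and from both block-construction call sites. For readers unfamiliar with the option, below is a minimal sketch of what RMS-normalizing query/key projections "across heads" typically means in JAX; the shapes and names here are illustrative assumptions, not this repository's implementation.

```python
# Illustrative sketch only -- assumed shapes/names, not this repo's code.
import jax
import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
    # RMSNorm: rescale by the root-mean-square over the last axis.
    ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(ms + eps) * scale

# "rms_norm_across_heads": normalize q/k over the concatenated
# (num_heads * head_dim) axis before splitting into heads, so a single
# learned scale is shared across all heads.
batch, seq, num_heads, head_dim = 2, 16, 8, 64
q = jnp.ones((batch, seq, num_heads * head_dim))
q_scale = jnp.ones((num_heads * head_dim,))
q = rms_norm(q, q_scale)
q = q.reshape(batch, seq, num_heads, head_dim)  # split heads after the norm
```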