diff --git a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
index 455599a30f60..724f1e0964cf 100644
--- a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
+++ b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
@@ -39,6 +39,19 @@ def _wn_conv_transpose1d(*args, **kwargs):
     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
 
 
+def _normalize_vae_strides(c_mults: list[int], strides: list[int] | None = None) -> list[int]:
+    default_strides = [2, 4, 4, 8, 8]
+    num_blocks = len(c_mults) - 1
+    if strides is None:
+        strides = default_strides
+    strides = list(strides)
+    if len(strides) < num_blocks:
+        strides.extend([strides[-1] if strides else 2] * (num_blocks - len(strides)))
+    else:
+        strides = strides[:num_blocks]
+    return strides
+
+
 class Snake1d(nn.Module):
     def __init__(self, channels: int, alpha_logscale: bool = True):
         super().__init__()
@@ -200,11 +213,7 @@ def __init__(
     ):
         super().__init__()
         c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
-        strides = list(strides or [2] * (len(c_mults) - 1))
-        if len(strides) < len(c_mults) - 1:
-            strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
-        else:
-            strides = strides[: len(c_mults) - 1]
+        strides = _normalize_vae_strides(c_mults, strides)
         channels_base = channels
         layers = [_wn_conv1d(in_channels, c_mults[0] * channels_base, kernel_size=7, padding=3)]
         for idx in range(len(c_mults) - 1):
@@ -249,11 +258,7 @@ def __init__(
     ):
         super().__init__()
         c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
-        strides = list(strides or [2] * (len(c_mults) - 1))
-        if len(strides) < len(c_mults) - 1:
-            strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
-        else:
-            strides = strides[: len(c_mults) - 1]
+        strides = _normalize_vae_strides(c_mults, strides)
         channels_base = channels
 
         self.shortcut = (
@@ -317,6 +322,18 @@ def __init__(
         scale: float = 0.71,
     ):
         super().__init__()
+        c_mults = c_mults or [1, 2, 4, 8, 16]
+        normalized_strides = _normalize_vae_strides([1] + c_mults, strides)
+        actual_downsampling_ratio = math.prod(normalized_strides)
+        if actual_downsampling_ratio != downsampling_ratio:
+            raise ValueError(
+                f"`downsampling_ratio` must match the product of normalized `strides`. Got "
+                f"`downsampling_ratio={downsampling_ratio}` but `strides={normalized_strides}` have product "
+                f"{actual_downsampling_ratio}."
+            )
+        self.register_to_config(
+            c_mults=c_mults, strides=normalized_strides, downsampling_ratio=actual_downsampling_ratio
+        )
         if act_fn is None:
             if use_snake is None:
                 act_fn = "snake"
@@ -326,7 +343,7 @@
             in_channels=in_channels,
             channels=channels,
             c_mults=c_mults,
-            strides=strides,
+            strides=normalized_strides,
             latent_dim=latent_dim,
             encoder_latent_dim=encoder_latent_dim,
             act_fn=act_fn,
@@ -337,7 +354,7 @@
             in_channels=in_channels,
             channels=channels,
             c_mults=c_mults,
-            strides=strides,
+            strides=normalized_strides,
             latent_dim=latent_dim,
             act_fn=act_fn,
             in_shortcut=in_shortcut,
diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py
index 2a5b169ad5ee..c35a672a9483 100644
--- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py
+++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py
@@ -25,7 +25,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import BaseOutput
 from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
-from ..attention import AttentionModuleMixin
+from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm
@@ -228,6 +228,12 @@
 
 
 class AudioDiTAttention(nn.Module, AttentionModuleMixin):
+    _default_processor_cls = AudioDiTSelfAttnProcessor
+    _available_processors = [
+        AudioDiTSelfAttnProcessor,
+    ]
+    _supports_qkv_fusion = False
+
     def __init__(
         self,
         q_dim: int,
@@ -238,12 +244,13 @@ def __init__(
         bias: bool = True,
         qk_norm: bool = False,
         eps: float = 1e-6,
-        processor: AttentionModuleMixin | None = None,
+        processor: "AudioDiTSelfAttnProcessor | AudioDiTCrossAttnProcessor | None" = None,
     ):
         super().__init__()
         kv_dim = q_dim if kv_dim is None else kv_dim
         self.heads = heads
         self.inner_dim = dim_head * heads
+        self.use_bias = bias
         self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias)
         self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias)
         self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias)
@@ -252,7 +259,9 @@ def __init__(
             self.q_norm = RMSNorm(self.inner_dim, eps=eps)
             self.k_norm = RMSNorm(self.inner_dim, eps=eps)
         self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)])
-        self.set_processor(processor or AudioDiTSelfAttnProcessor())
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
 
     def forward(
         self,
@@ -331,6 +340,9 @@ def __call__(
         return hidden_states
 
 
+AudioDiTAttention._available_processors = [AudioDiTSelfAttnProcessor, AudioDiTCrossAttnProcessor]
+
+
 class AudioDiTFeedForward(nn.Module):
     def __init__(self, dim: int, mult: float = 4.0, dropout: float = 0.0, bias: bool = True):
         super().__init__()
@@ -452,7 +464,7 @@ def forward(
         return hidden_states
 
 
-class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin):
+class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin, AttentionMixin):
     _supports_gradient_checkpointing = False
     _repeated_blocks = ["AudioDiTBlock"]
 
diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py
index e6478535b373..783a789b7f26 100644
--- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py
+++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py
@@ -16,7 +16,7 @@
 # https://github.com/meituan-longcat/LongCat-AudioDiT
 
 import re
-from typing import Callable
+from typing import Any, Callable
 
 import torch
 import torch.nn.functional as F
@@ -32,6 +32,8 @@
 
 logger = logging.get_logger(__name__)
 
+PipelineCallback = Callable[[Any, int, torch.Tensor, dict[str, torch.Tensor]], dict[str, torch.Tensor]]
+
 EXAMPLE_DOC_STRING = """
 Examples:
     ```py
@@ -148,8 +150,7 @@ def encode_prompt(self, prompt: str | list[str], device: torch.device) -> tuple
         )
         input_ids = text_inputs.input_ids.to(device)
         attention_mask = text_inputs.attention_mask.to(device)
-        with torch.no_grad():
-            output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
+        output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
         prompt_embeds = output.last_hidden_state
         if self.text_norm_feat:
             prompt_embeds = F.layer_norm(prompt_embeds, (prompt_embeds.shape[-1],), eps=1e-6)
@@ -229,7 +230,7 @@
         generator: torch.Generator | list[torch.Generator] | None = None,
         output_type: str = "np",
         return_dict: bool = True,
-        callback_on_step_end: Callable[[int, int], None] | None = None,
+        callback_on_step_end: PipelineCallback | None = None,
         callback_on_step_end_tensor_inputs: list[str] = ["latents"],
     ):
         r"""
@@ -296,9 +297,13 @@
                 negative_prompt_embeds_len, length=negative_prompt_embeds.shape[1]
             )
 
-        latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=prompt_embeds.dtype)
+        transformer_dtype = self.transformer.dtype
+        prompt_embeds = prompt_embeds.to(dtype=transformer_dtype)
+        negative_prompt_embeds = negative_prompt_embeds.to(dtype=transformer_dtype)
+
+        latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=transformer_dtype)
         latents = self.prepare_latents(
-            batch_size, duration, device, prompt_embeds.dtype, generator=generator, latents=latents
+            batch_size, duration, device, transformer_dtype, generator=generator, latents=latents
         )
         if num_inference_steps < 1:
             raise ValueError("num_inference_steps must be a positive integer.")
@@ -311,9 +316,7 @@
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                curr_t = (
-                    (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=prompt_embeds.dtype)
-                )
+                curr_t = (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=transformer_dtype)
                 pred = self.transformer(
                     hidden_states=latents,
                     encoder_hidden_states=prompt_embeds,
diff --git a/tests/models/autoencoders/test_models_autoencoder_longcat_audio_dit.py b/tests/models/autoencoders/test_models_autoencoder_longcat_audio_dit.py
new file mode 100644
index 000000000000..5beb451459f7
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_longcat_audio_dit.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2026 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from diffusers import LongCatAudioDiTVae
+
+
+def test_longcat_audio_vae_default_strides_match_downsampling_ratio():
+    vae = LongCatAudioDiTVae(channels=1, latent_dim=2, encoder_latent_dim=4)
+
+    assert vae.config.strides == [2, 4, 4, 8, 8]
+    assert vae.config.downsampling_ratio == 2048
+
+
+def test_longcat_audio_vae_raises_when_downsampling_ratio_mismatches_strides():
+    with pytest.raises(ValueError, match="downsampling_ratio"):
+        LongCatAudioDiTVae(
+            channels=1,
+            latent_dim=2,
+            encoder_latent_dim=4,
+            strides=[2, 2, 2, 2, 2],
+            downsampling_ratio=2048,
+        )
diff --git a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py
index b418a3068449..e6e820640e7c 100644
--- a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py
+++ b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py
@@ -102,10 +102,17 @@ class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterCo
 
 
 def test_longcat_audio_attention_uses_standard_self_attn_kwargs():
-    from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention
+    from diffusers.models.transformers.transformer_longcat_audio_dit import (
+        AudioDiTAttention,
+        AudioDiTSelfAttnProcessor,
+    )
 
     attn = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4, dropout=0.0, bias=False)
 
+    assert attn._default_processor_cls is AudioDiTSelfAttnProcessor
+    assert AudioDiTSelfAttnProcessor in attn._available_processors
+    assert attn.use_bias is False
+
     eye = torch.eye(4)
     with torch.no_grad():
         attn.to_q.weight.copy_(eye)
@@ -119,3 +126,30 @@ def test_longcat_audio_attention_uses_standard_self_attn_kwargs():
     output = attn(hidden_states=hidden_states, attention_mask=attention_mask)
 
     assert torch.allclose(output[:, 1], torch.zeros_like(output[:, 1]))
+
+
+def test_longcat_audio_attention_direct_fuse_projections_noops():
+    from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention
+
+    attn = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4)
+
+    attn.fuse_projections()
+
+    assert not attn.fused_projections
+    assert not hasattr(attn, "to_qkv")
+
+
+def test_longcat_audio_transformer_exposes_attention_processors():
+    model = LongCatAudioDiTTransformer(
+        dit_dim=64,
+        dit_depth=2,
+        dit_heads=4,
+        dit_text_dim=32,
+        latent_dim=8,
+        text_conv=False,
+    )
+
+    processors = model.attn_processors
+
+    assert len(processors) == 4
+    model.set_attn_processor(dict(processors))
diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
index c4e1aeeda67c..40e44319967f 100644
--- a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
+++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
@@ -67,7 +67,7 @@ def get_dummy_components(self):
             strides=[2],
             latent_dim=8,
             encoder_latent_dim=16,
-            downsampling_ratio=2,
+            downsampling_ratio=4,
             sample_rate=24000,
         )
 
@@ -158,6 +158,60 @@ def test_num_images_per_prompt(self):
     def test_encode_prompt_works_in_isolation(self):
         self.skipTest("LongCatAudioDiTPipeline.encode_prompt has a custom signature.")
 
+    def test_encode_prompt_returns_grad_bearing_embeds(self):
+        device = "cpu"
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe.to(device)
+
+        with torch.enable_grad():
+            prompt_embeds, _ = pipe.encode_prompt("soft ocean ambience", torch.device(device))
+            loss = prompt_embeds.float().sum()
+
+        self.assertTrue(prompt_embeds.requires_grad)
+        loss.backward()
+        self.assertTrue(any(param.grad is not None for param in pipe.text_encoder.parameters()))
+
+    def test_transformer_inputs_use_transformer_dtype(self):
+        device = "cpu"
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe.to(device)
+        pipe.transformer.to(dtype=torch.bfloat16)
+
+        observed_dtypes = []
+
+        def record_transformer_inputs(module, args, kwargs):
+            observed_dtypes.append(
+                {
+                    "hidden_states": kwargs["hidden_states"].dtype,
+                    "encoder_hidden_states": kwargs["encoder_hidden_states"].dtype,
+                    "timestep": kwargs["timestep"].dtype,
+                    "latent_cond": kwargs["latent_cond"].dtype,
+                }
+            )
+
+        hook = pipe.transformer.register_forward_pre_hook(record_transformer_inputs, with_kwargs=True)
+        inputs = self.get_dummy_inputs(device)
+        inputs.update(
+            {
+                "negative_prompt": "noise",
+                "guidance_scale": 4.0,
+                "output_type": "latent",
+            }
+        )
+
+        try:
+            output = pipe(**inputs).audios
+        finally:
+            hook.remove()
+
+        self.assertEqual(output.dtype, torch.bfloat16)
+        self.assertGreaterEqual(len(observed_dtypes), 2)
+        for dtypes in observed_dtypes:
+            self.assertEqual(dtypes["hidden_states"], torch.bfloat16)
+            self.assertEqual(dtypes["encoder_hidden_states"], torch.bfloat16)
+            self.assertEqual(dtypes["timestep"], torch.bfloat16)
+            self.assertEqual(dtypes["latent_cond"], torch.bfloat16)
+
     def test_uniform_flow_match_scheduler_grid_matches_manual_updates(self):
         num_inference_steps = 6
         scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True)
@@ -203,9 +257,10 @@ def test_longcat_audio_pipeline_from_pretrained_real_local_weights(self):
         if not tokenizer_path.exists():
             raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}")
 
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
         pipe = LongCatAudioDiTPipeline.from_pretrained(
             model_path,
-            tokenizer=tokenizer_path,
+            tokenizer=tokenizer,
             torch_dtype=torch.float16,
             local_files_only=True,
         )