diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py index b023974a33dd..b79ee280ca34 100644 --- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -324,17 +324,18 @@ def generate_language_model( `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): The sequence of generated hidden-states. """ - cache_position_kwargs = {} - if is_transformers_version("<", "4.52.1"): - cache_position_kwargs["input_ids"] = inputs_embeds - else: - cache_position_kwargs["seq_length"] = inputs_embeds.shape[0] - cache_position_kwargs["device"] = ( - self.language_model.device if getattr(self, "language_model", None) is not None else self.device - ) - cache_position_kwargs["model_kwargs"] = model_kwargs max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens - model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs) + if hasattr(self.language_model, "_get_initial_cache_position"): + cache_position_kwargs = {} + if is_transformers_version("<", "4.52.1"): + cache_position_kwargs["input_ids"] = inputs_embeds + else: + cache_position_kwargs["seq_length"] = inputs_embeds.shape[0] + cache_position_kwargs["device"] = ( + self.language_model.device if getattr(self, "language_model", None) is not None else self.device + ) + cache_position_kwargs["model_kwargs"] = model_kwargs + model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs) for _ in range(max_new_tokens): # prepare model inputs