diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 0b5b9fa3efba..c25bd2b98688 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import BaseOutput, logging +from ...utils import BaseOutput, apply_lora_scale, logging from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -598,6 +598,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) + @apply_lora_scale("cross_attention_kwargs") def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 639a8ad7390a..2ad32ee0bdea 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -20,7 +20,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + BaseOutput, + apply_lora_scale, + logging, +) from ..attention import AttentionMixin from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed @@ -150,6 +154,7 @@ def from_transformer( return controlnet + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -197,20 +202,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) hidden_states = self.x_embedder(hidden_states) if self.input_hint_block is not None: @@ -323,10 +314,6 @@ def forward( None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples ) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (controlnet_block_samples, controlnet_single_block_samples) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index fa374285eec1..0a421b8394b0 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + BaseOutput, + apply_lora_scale, + deprecate, + logging, +) from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -123,6 +128,7 @@ def from_transformer( return controlnet + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -181,20 +187,6 @@ def forward( standard_warn=False, ) - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) hidden_states = self.img_in(hidden_states) # add @@ -256,10 +248,6 @@ def forward( controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return controlnet_block_samples diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py index c71a8b326635..a739f12080b6 100644 --- a/src/diffusers/models/controlnets/controlnet_sana.py +++ b/src/diffusers/models/controlnets/controlnet_sana.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import BaseOutput, apply_lora_scale, logging from ..attention import AttentionMixin from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput @@ -117,6 +117,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -129,21 +130,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
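Reviewer note: the `apply_lora_scale` helper imported from `...utils` is defined outside the hunks shown here. As a reading aid, here is a minimal sketch of the behavior it has to reproduce, inferred purely from the boilerplate these hunks delete (copy the named kwargs dict, pop `scale`, scale the PEFT LoRA layers around `forward`, unscale afterwards, and warn when the PEFT backend is unavailable). Everything below other than `USE_PEFT_BACKEND`, `scale_lora_layers`, `unscale_lora_layers`, and `logging` is illustrative, not the actual implementation.

```python
# Hypothetical sketch of apply_lora_scale -- not the diffusers implementation.
# It mirrors the inline blocks removed in this diff.
import functools
import inspect

from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)


def apply_lora_scale(kwargs_name: str):
    """Pop `scale` from the kwargs dict named `kwargs_name` and scale LoRA layers around forward()."""

    def decorator(forward):
        @functools.wraps(forward)
        def wrapper(self, *args, **kwargs):
            # Bind the call so the kwargs dict is found whether it was passed
            # positionally or by keyword.
            bound = inspect.signature(forward).bind(self, *args, **kwargs)
            bound.apply_defaults()
            attention_kwargs = bound.arguments.get(kwargs_name)

            scale_passed = attention_kwargs is not None and attention_kwargs.get("scale", None) is not None
            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
                # Mirror the removed inline code: downstream blocks see the copy without `scale`.
                bound.arguments[kwargs_name] = attention_kwargs
            else:
                lora_scale = 1.0

            if USE_PEFT_BACKEND:
                # weight the LoRA layers by setting `lora_scale` for each PEFT layer
                scale_lora_layers(self, lora_scale)
            elif scale_passed:
                logger.warning(
                    f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                )

            try:
                return forward(*bound.args, **bound.kwargs)
            finally:
                if USE_PEFT_BACKEND:
                    # remove `lora_scale` from each PEFT layer
                    unscale_lora_layers(self, lora_scale)

        return wrapper

    return decorator
```

Binding the signature, rather than reaching into `**kwargs` directly, keeps the helper agnostic to whether callers pass `cross_attention_kwargs`, `joint_attention_kwargs`, or `attention_kwargs` positionally or by keyword; the try/finally is a sketch-level choice so LoRA layers are unscaled even if `forward` raises, which the removed inline code did not guarantee.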
@@ -218,10 +204,6 @@ def forward( block_res_sample = controlnet_block(block_res_sample) controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] if not return_dict: diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index 08b86ff344eb..b1a8284fe2e1 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ..attention import AttentionMixin, JointTransformerBlock from ..attention_processor import Attention, FusedJointAttnProcessor2_0 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed @@ -269,6 +269,7 @@ def from_transformer( return controlnet + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -308,21 +309,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - if self.pos_embed is not None and hidden_states.ndim != 4: raise ValueError("hidden_states must be 4D when pos_embed is used") @@ -382,10 +368,6 @@ def forward( # 6. 
scaling controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (controlnet_block_res_samples,) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index e3732662e408..48578dc5d3ce 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin from ..attention_processor import ( @@ -397,6 +397,7 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.FloatTensor, @@ -405,21 +406,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - height, width = hidden_states.shape[-2:] # Apply patch embedding, timestep embedding, and project the caption embeddings. 
@@ -486,10 +472,6 @@ def forward( shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) ) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 14b38cd46c52..f48bbc4ef579 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, AttentionMixin, FeedForward from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 @@ -363,6 +363,7 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -374,21 +375,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - batch_size, num_frames, channels, height, width = hidden_states.shape # 1. 
Time embedding @@ -454,10 +440,6 @@ def forward( ) output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index be20b0a3eacf..9458712ddacb 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, AttentionMixin, FeedForward from ..attention_processor import CogVideoXAttnProcessor2_0 @@ -620,6 +620,7 @@ def _init_face_inputs(self): ] ) + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -632,21 +633,6 @@ def forward( id_vit_hidden: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - # fuse clip and insightface valid_face_emb = None if self.is_train_face: @@ -720,10 +706,6 @@ def forward( output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index d45badb1b121..2bee436daab5 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ..attention import AttentionMixin from ..attention_processor import ( Attention, @@ -414,6 +414,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -426,21 +427,6 @@ def forward( controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
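Reviewer note: caller-facing behavior is meant to be unchanged, i.e. `scale` still travels inside the kwargs dict and is consumed before the transformer blocks run. A toy sanity check of that call path, using either the real helper this diff imports from `diffusers.utils` or the sketch above (the module, shapes, and the assumption that `scale` is stripped before `forward` executes are all illustrative):

```python
# Toy call-path check; ToyBlock and the shapes are made up for illustration.
import torch
from torch import nn

from diffusers.utils import apply_lora_scale  # importable after this diff


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @apply_lora_scale("attention_kwargs")
    def forward(self, hidden_states, attention_kwargs=None, return_dict=True):
        # By the time we get here the LoRA layers (if any) have been scaled and,
        # assuming the helper mirrors the removed inline code, `scale` has been consumed.
        return self.proj(hidden_states)


block = ToyBlock()
out = block(torch.randn(2, 8), attention_kwargs={"scale": 0.5})
print(out.shape)  # torch.Size([2, 8])
```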
@@ -527,10 +513,6 @@ def forward( hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4) output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index d54679306e64..05c8a6f27b19 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -581,6 +581,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -621,20 +622,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) @@ -715,10 +702,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 09f79619320d..b07dc4099555 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -22,10 +22,8 @@ from ...models.modeling_utils import ModelMixin from ...models.transformers.transformer_bria import BriaAttnProcessor from ...utils import ( - USE_PEFT_BACKEND, + apply_lora_scale, logging, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward @@ -510,6 +508,7 @@ def __init__( ] self.caption_projection = nn.ModuleList(caption_projection) + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -545,20 +544,7 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. 
""" - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) @@ -645,10 +631,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 2ef3643dafbd..37b4e4e28499 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, deprecate, logging from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, FeedForward @@ -473,6 +473,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -511,20 +512,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) hidden_states = self.x_embedder(hidden_states) @@ -631,10 +618,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_chronoedit.py b/src/diffusers/models/transformers/transformer_chronoedit.py index 79828b6464f4..8742cf2951a6 100644 --- a/src/diffusers/models/transformers/transformer_chronoedit.py +++ b/src/diffusers/models/transformers/transformer_chronoedit.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, deprecate, logging from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -638,6 +638,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -647,21 +648,6 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t @@ -729,10 +715,6 @@ def forward( hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 64e9a538a7c2..c77494c2b19f 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention @@ -703,6 +703,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -718,21 +719,6 @@ def forward( Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] ] = None, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - batch_size, num_channels, height, width = hidden_states.shape # 1. 
RoPE @@ -779,10 +765,6 @@ def forward( hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 1a4464432425..f6bcaa6735a9 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -634,6 +634,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -675,20 +676,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) hidden_states = self.x_embedder(hidden_states) @@ -785,10 +772,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 9cadfcefc497..13be477e3485 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn @@ -774,6 +774,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -810,20 +811,6 @@ def forward( `tuple` where the first element is the sample tensor. """ # 0. 
Handle input arguments - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) num_txt_tokens = encoder_hidden_states.shape[1] @@ -908,10 +895,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 4a5aee29abc4..0e9c947b5161 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -8,7 +8,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, deprecate, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention from ..embeddings import TimestepEmbedding, Timesteps @@ -773,6 +773,7 @@ def patchify(self, hidden_states): return hidden_states, hidden_states_masks, img_sizes, img_ids + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -808,21 +809,6 @@ def forward( "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)" ) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - # spatial forward batch_size = hidden_states.shape[0] hidden_states_type = hidden_states.dtype @@ -933,10 +919,6 @@ def forward( if hidden_states_masks is not None: hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len] - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 4f0775ac9fa0..84dcb1fe407a 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ..attention import AttentionMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention @@ -989,6 +989,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -1000,21 +1001,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - batch_size, num_channels, num_frames, height, width = hidden_states.shape p, p_t = self.config.patch_size, self.config.patch_size_t post_patch_num_frames = num_frames // p_t @@ -1104,10 +1090,6 @@ def forward( hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (hidden_states,) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 293ba996ea98..8595595326d3 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ..attention import AttentionMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention @@ -620,6 +620,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -633,21 +634,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size post_patch_num_frames = num_frames // p_t @@ -783,10 +769,6 @@ def forward( hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (hidden_states,) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py index 601ba0f0b472..500cec89f895 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, get_logger from ..cache_utils import CacheMixin from ..embeddings import get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -198,6 +198,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -217,21 +218,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - batch_size, num_channels, num_frames, height, width = hidden_states.shape p, p_t = self.config.patch_size, self.config.patch_size_t post_patch_num_frames = num_frames // p_t @@ -337,10 +323,6 @@ def forward( hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (hidden_states,) return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/models/transformers/transformer_hunyuanimage.py b/src/diffusers/models/transformers/transformer_hunyuanimage.py index d626e322ad6f..dc4b22c32309 100644 --- a/src/diffusers/models/transformers/transformer_hunyuanimage.py +++ b/src/diffusers/models/transformers/transformer_hunyuanimage.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -742,6 +742,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -755,21 +756,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - if hidden_states.ndim == 4: batch_size, channels, height, width = hidden_states.shape sizes = (height, width) @@ -900,10 +886,6 @@ def forward( ] hidden_states = hidden_states.reshape(*final_dims) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (hidden_states,) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 685c73c07c75..4bb0eb9268d6 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, deprecate, is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -491,6 +491,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -505,21 +506,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords) # convert encoder_attention_mask to a bias the same way we do for attention_mask @@ -568,10 +554,6 @@ def forward( hidden_states = hidden_states * (1 + scale) + shift output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index b88f096e8033..62bb0dfb1ff3 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -22,14 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import ( - USE_PEFT_BACKEND, - BaseOutput, - is_torch_version, - logging, - scale_lora_layers, - unscale_lora_layers, -) +from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -1101,6 +1094,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -1171,21 +1165,6 @@ def forward( `tuple` is returned where the first element is the denoised video latent patch sequence and the second element is the denoised audio latent patch sequence. """ - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - # Determine timestep for audio. 
audio_timestep = audio_timestep if audio_timestep is not None else timestep @@ -1341,10 +1320,6 @@ def forward( audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift audio_output = self.audio_proj_out(audio_hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output, audio_output) return AudioVisualModelOutput(sample=output, audio_sample=audio_output) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 77121edb9fc9..ecc70b06ef06 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ..attention import LuminaFeedForward from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed @@ -455,6 +455,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -464,21 +465,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - # 1. 
Condition, positional & patch embedding batch_size, _, height, width = hidden_states.shape @@ -539,10 +525,6 @@ def forward( ) output = torch.stack(output, dim=0) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 63911fe7c10d..d0cf82dea55d 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 @@ -404,6 +404,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -413,21 +414,6 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - batch_size, num_channels, num_frames, height, width = hidden_states.shape p = self.config.patch_size @@ -479,10 +465,6 @@ def forward( hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) output = hidden_states.reshape(batch_size, -1, num_frames, height, width) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cf11d8e01fb4..b2d4bdcc8b25 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, deprecate, logging from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ -829,6 +829,7 @@ def __init__( self.gradient_checkpointing = False self.zero_cond_t = zero_cond_t + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -887,20 +888,6 @@ def forward( "The mask-based approach is more flexible and supports variable-length sequences.", standard_warn=False, ) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) hidden_states = self.img_in(hidden_states) @@ -981,10 +968,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index a4f90342631a..f6a3aff3a4c6 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ..attention import AttentionMixin from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention @@ -570,6 +570,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -582,21 +583,6 @@ def forward( controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
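Reviewer note: since the same mechanical edit is repeated across many files, a throwaway script along these lines (paths, glob, and heuristics are assumptions) can flag `forward` methods that accept one of the three kwargs names but did not pick up the decorator. It is a grep-style review aid, not a test, and it will also list models that intentionally skip LoRA scaling.

```python
# Hypothetical review aid: list forward() methods that take an attention-kwargs
# style argument but are not decorated with apply_lora_scale.
import ast
from pathlib import Path

KWARG_NAMES = {"attention_kwargs", "joint_attention_kwargs", "cross_attention_kwargs"}

for path in Path("src/diffusers/models").rglob("*.py"):
    tree = ast.parse(path.read_text())
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == "forward":
            arg_names = {a.arg for a in node.args.args + node.args.kwonlyargs}
            decorated = any(
                isinstance(d, ast.Call) and getattr(d.func, "id", "") == "apply_lora_scale"
                for d in node.decorator_list
            )
            if arg_names & KWARG_NAMES and not decorated:
                print(f"{path}: forward() takes {sorted(arg_names & KWARG_NAMES)} but is not decorated")
```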
@@ -695,10 +681,6 @@ def forward( hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 05391e047b7a..53d958784da9 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -18,7 +18,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, FeedForward, JointTransformerBlock from ..attention_processor import ( @@ -245,6 +245,7 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -284,20 +285,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) height, width = hidden_states.shape[-2:] @@ -352,10 +339,6 @@ def forward( shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) ) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 2b9fc5b8d9fb..0a2be5a311a3 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, deprecate, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -630,6 +630,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -641,21 +642,6 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py
index ae1c3095a17f..b062cdfaccc6 100644
--- a/src/diffusers/models/transformers/transformer_wan_animate.py
+++ b/src/diffusers/models/transformers/transformer_wan_animate.py
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -1141,6 +1141,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1179,21 +1180,6 @@ def forward(
 
                 Whether to return the output as a dict or tuple.
         """
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         # Check that shapes match up
         if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
             raise ValueError(
@@ -1294,10 +1280,6 @@ def forward(
         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
 
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
 
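One property worth calling out, since forwards like the one above pass `attention_kwargs` on to their blocks: the wrapper pops `scale` from a copy of the dict, so the caller's dict is left untouched and the attention processors never see a stray `scale` key. A small sketch, assuming this patch is applied (`Dummy` is illustrative only):

```python
import torch

from diffusers.utils import apply_lora_scale


class Dummy(torch.nn.Module):
    @apply_lora_scale("attention_kwargs")
    def forward(self, x, attention_kwargs=None):
        # By the time the body runs, `scale` has been popped from a *copy* of the dict.
        assert attention_kwargs is not None and "scale" not in attention_kwargs
        return x


kwargs = {"scale": 0.3, "shift": 1}
Dummy()(torch.zeros(1), attention_kwargs=kwargs)
print(kwargs)  # {'scale': 0.3, 'shift': 1} -- the caller's dict is left intact
```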
""" - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - # Check that shapes match up if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]: raise ValueError( @@ -1294,10 +1280,6 @@ def forward( hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 1be4f73e33e2..1c84b4628e10 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin from ..modeling_outputs import Transformer2DModelOutput @@ -261,6 +261,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -272,21 +273,6 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index e669aa51a54e..037f720acf0e 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -20,7 +20,12 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    BaseOutput,
+    apply_lora_scale,
+    deprecate,
+    logging,
+)
 from ..activations import get_activation
 from ..attention import AttentionMixin
 from ..attention_processor import (
@@ -974,6 +979,7 @@ def process_encoder_hidden_states(
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
         return encoder_hidden_states
 
+    @apply_lora_scale("cross_attention_kwargs")
     def forward(
         self,
         sample: torch.Tensor,
@@ -1112,18 +1118,6 @@ def forward(
             cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
 
         # 3. down
-        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
-        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
-        if cross_attention_kwargs is not None:
-            cross_attention_kwargs = cross_attention_kwargs.copy()
-            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-
         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
         # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
         is_adapter = down_intrablock_additional_residuals is not None
@@ -1239,10 +1233,6 @@ def forward(
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
 
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (sample,)
 
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 5a93541501d3..a96afa33937f 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, logging
+from ...utils import BaseOutput, apply_lora_scale, deprecate, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import AttentionMixin, BasicTransformerBlock
 from ..attention_processor import (
@@ -1875,6 +1875,7 @@ def unfuse_qkv_projections(self):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
+    @apply_lora_scale("cross_attention_kwargs")
     def forward(
         self,
         sample: torch.Tensor,
diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py
index 4c99ef88ca19..836d41a7f946 100644
--- a/src/diffusers/models/unets/uvit_2d.py
+++ b/src/diffusers/models/unets/uvit_2d.py
@@ -21,6 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
+from ...utils import apply_lora_scale
 from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -146,6 +147,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("cross_attention_kwargs")
     def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
         encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
         encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index e726bbb46913..2c9f4c995001 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -130,6 +130,7 @@
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
+    apply_lora_scale,
     check_peft_version,
     delete_adapter_layers,
     get_adapter_name,
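For the UNet-based pipelines the user-facing pattern is likewise unchanged: the LoRA strength still travels through `cross_attention_kwargs` and is now consumed by the decorator on `UNet2DConditionModel.forward`. A sketch with placeholder model and LoRA paths (not part of this diff):

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder paths -- substitute a real base checkpoint and a LoRA trained for it.
pipe = StableDiffusionPipeline.from_pretrained("path/to/sd15-base", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("path/to/lora")

image = pipe(
    "masterpiece, best quality, mountains at sunset",
    cross_attention_kwargs={"scale": 0.5},  # LoRA strength; popped before reaching the attention processors
).images[0]
```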
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 12066ee3f89b..d6f5bcde2d97 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -16,6 +16,7 @@
 """
 
 import collections
+import functools
 import importlib
 from typing import Optional
 
@@ -275,6 +276,56 @@ def get_module_weight(weight_for_adapter, module_name):
             module.set_scale(adapter_name, get_module_weight(weight, module_name))
 
 
+def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
+    """
+    Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
+
+    This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
+    pass, and ensures unscaling happens after, even if an exception occurs. The kwargs dict must be passed as a
+    keyword argument for the scale to be picked up.
+
+    Args:
+        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
+            The name of the keyword argument that contains the LoRA scale. Common values include
+            "joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
+    """
+
+    def decorator(forward_fn):
+        @functools.wraps(forward_fn)
+        def wrapper(self, *args, **kwargs):
+            # Imported lazily to avoid a circular import with `diffusers.utils`.
+            from . import USE_PEFT_BACKEND
+
+            lora_scale = 1.0
+            attention_kwargs = kwargs.get(kwargs_name)
+
+            if attention_kwargs is not None:
+                # Work on a copy so the caller's dict is not mutated and `scale` never reaches the processors.
+                attention_kwargs = attention_kwargs.copy()
+                kwargs[kwargs_name] = attention_kwargs
+                lora_scale = attention_kwargs.pop("scale", 1.0)
+
+            if USE_PEFT_BACKEND:
+                # weight the lora layers by setting `lora_scale` for each PEFT layer
+                scale_lora_layers(self, lora_scale)
+            elif lora_scale != 1.0:
+                logger.warning(
+                    f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
+                )
+
+            try:
+                # Execute the forward pass
+                return forward_fn(self, *args, **kwargs)
+            finally:
+                # Always unscale, even if the forward pass raised an exception
+                if USE_PEFT_BACKEND:
+                    unscale_lora_layers(self, lora_scale)
+
+        return wrapper
+
+    return decorator
+
+
 def check_peft_version(min_version: str) -> None:
     r"""
     Checks if the version of PEFT is compatible.
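Because the wrapper looks the dict up with `kwargs.get(kwargs_name)`, the decorated forward only sees the scale when the dict is passed as a keyword argument. A sketch of reusing the helper on a custom module, assuming this patch is applied (`MyTransformer` and its internals are illustrative only):

```python
from typing import Any, Dict, Optional

import torch

from diffusers.utils import apply_lora_scale


class MyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @apply_lora_scale("attention_kwargs")  # must match the keyword name used by callers
    def forward(self, hidden_states: torch.Tensor, attention_kwargs: Optional[Dict[str, Any]] = None):
        # Any PEFT LoRA layers on `self` are already scaled here; they are unscaled again on exit.
        return self.proj(hidden_states)


model = MyTransformer()
out = model(torch.randn(2, 8), attention_kwargs={"scale": 0.75})  # passed as a keyword so the wrapper can find it
```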