diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index d675f1de04a7..44fe8367636d 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -22,7 +22,7 @@
 import types
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints
+from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints
 
 import httpx
 import numpy as np
@@ -2248,6 +2249,69 @@ def _is_pipeline_device_mapped(self):
 
         return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
 
+    def enable_neuron_compile(
+        self,
+        model_names: Optional[List[str]] = None,
+        cache_dir: Optional[str] = None,
+        fullgraph: bool = True,
+    ) -> None:
+        """
+        Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``,
+        enabling whole-graph NEFF compilation for AWS Trainium/Inferentia.
+
+        The first forward call per component triggers neuronx-cc compilation (slow).
+        Use ``neuron_warmup()`` to trigger this explicitly before timed inference.
+
+        Args:
+            model_names (`List[str]`, *optional*):
+                Component names to compile. Defaults to all nn.Module components.
+            cache_dir (`str`, *optional*):
+                Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``.
+                Skips recompilation on subsequent runs.
+            fullgraph (`bool`, defaults to `True`):
+                Disallow graph breaks (required for full-graph fusion).
+        """
+        # Importing torch_neuronx registers the "neuron" torch.compile backend as a side
+        # effect. Guard it explicitly: "torch_neuronx" has no BACKENDS_MAPPING entry, so
+        # requires_backends() would raise KeyError instead of a helpful message.
+        try:
+            import torch_neuronx  # noqa: F401
+        except ImportError as e:
+            raise ImportError(
+                "`enable_neuron_compile` requires the `torch_neuronx` package. Install it with"
+                " `pip install torch-neuronx`."
+            ) from e
+
+        if cache_dir is not None:
+            os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir
+
+        if model_names is None:
+            model_names = [
+                name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module)
+            ]
+
+        for name in model_names:
+            component = getattr(self, name, None)
+            if isinstance(component, torch.nn.Module) and not is_compiled_module(component):
+                logger.info(f"Compiling {name} with backend='neuron'")
+                setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph))
+
+    def neuron_warmup(self, *args, **kwargs) -> None:
+        """
+        Runs a single dummy forward pass through the pipeline to trigger neuronx-cc
+        compilation for all components (static-shape NEFF compilation).
+
+        This is equivalent to calling ``__call__`` with the same shapes but discards
+        the output. After warmup, subsequent calls reuse the compiled NEFFs and run fast.
+
+        Pass the same arguments you would use for real inference (height, width,
+        num_inference_steps, batch_size, etc.) so that the compiled shapes match.
+        """
+        logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...")
+        with torch.no_grad():
+            self(*args, **kwargs)
+        logger.info("Neuron warmup complete.")
+
 
 class StableDiffusionMixin:
     r"""
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 23d7ac7c6c2d..8a86cf4f4151 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -110,6 +110,7 @@
     is_timm_available,
     is_torch_available,
     is_torch_mlu_available,
+    is_torch_neuronx_available,
     is_torch_npu_available,
     is_torch_version,
     is_torch_xla_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 551fa358a28d..e23fccc1a374 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b
 _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
 _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
 _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
+_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx")
 _transformers_available, _transformers_version = _is_package_available("transformers")
 _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
 _kernels_available, _kernels_version = _is_package_available("kernels")
@@ -249,6 +250,10 @@ def is_torch_mlu_available():
     return _torch_mlu_available
 
 
+def is_torch_neuronx_available():
+    return _torch_neuronx_available
+
+
 def is_flax_available():
     return _flax_available
 
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 7f4cb3e12766..88b53e2b5b16 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -21,19 +21,26 @@
 import os
 
 from . import logging
-from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
+from .import_utils import (
+    is_torch_available,
+    is_torch_mlu_available,
+    is_torch_neuronx_available,
+    is_torch_npu_available,
+    is_torch_version,
+)
 
 
 if is_torch_available():
     import torch
     from torch.fft import fftn, fftshift, ifftn, ifftshift
 
-    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True}
     BACKEND_EMPTY_CACHE = {
         "cuda": torch.cuda.empty_cache,
         "xpu": torch.xpu.empty_cache,
         "cpu": None,
         "mps": torch.mps.empty_cache,
+        "neuron": None,
         "default": None,
     }
     BACKEND_DEVICE_COUNT = {
@@ -41,6 +48,7 @@
         "xpu": torch.xpu.device_count,
         "cpu": lambda: 0,
         "mps": lambda: 0,
+        "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(),
         "default": 0,
     }
     BACKEND_MANUAL_SEED = {
@@ -48,6 +56,7 @@
         "xpu": torch.xpu.manual_seed,
         "cpu": torch.manual_seed,
         "mps": torch.mps.manual_seed,
+        "neuron": torch.manual_seed,
         "default": torch.manual_seed,
     }
     BACKEND_RESET_PEAK_MEMORY_STATS = {
@@ -55,6 +64,7 @@
         "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
         "cpu": None,
         "mps": None,
+        "neuron": None,
         "default": None,
     }
     BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
@@ -62,6 +72,7 @@
         "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
         "cpu": None,
         "mps": None,
+        "neuron": None,
         "default": None,
     }
     BACKEND_MAX_MEMORY_ALLOCATED = {
@@ -69,6 +80,7 @@
         "xpu": getattr(torch.xpu, "max_memory_allocated", None),
         "cpu": 0,
         "mps": 0,
+        "neuron": 0,
         "default": 0,
     }
     BACKEND_SYNCHRONIZE = {
@@ -76,6 +88,7 @@
         "xpu": getattr(torch.xpu, "synchronize", None),
         "cpu": None,
         "mps": None,
+        "neuron": None,
         "default": None,
     }
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -164,11 +177,15 @@ def randn_tensor(
     layout = layout or torch.strided
     device = device or torch.device("cpu")
 
+    # Neuron (XLA) does not support creating random tensors directly on device; always use CPU
+    if device.type == "neuron":
+        rand_device = torch.device("cpu")
+
     if generator is not None:
         gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
         if gen_device_type != device.type and gen_device_type == "cpu":
             rand_device = "cpu"
-            if device != "mps":
+            if device.type not in ("mps", "neuron"):
                 logger.info(
                     f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                     f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
@@ -289,6 +306,8 @@ def get_device():
         return "mps"
     elif is_torch_mlu_available():
         return "mlu"
+    elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
+        return "neuron"
     else:
         return "cpu"
 