Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints

import httpx
import numpy as np
Expand Down Expand Up @@ -68,6 +68,7 @@
is_transformers_version,
logging,
numpy_to_pil,
requires_backends,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
Expand Down Expand Up @@ -2248,6 +2249,61 @@ def _is_pipeline_device_mapped(self):

return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1

def enable_neuron_compile(
    self,
    model_names: Optional[List[str]] = None,
    cache_dir: Optional[str] = None,
    fullgraph: bool = True,
) -> None:
    """
    Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``,
    enabling whole-graph NEFF compilation for AWS Trainium/Inferentia.

    The first forward call per component triggers neuronx-cc compilation (slow).
    Use ``neuron_warmup()`` to trigger this explicitly before timed inference.

    Args:
        model_names (`List[str]`, *optional*):
            Component names to compile. Defaults to all nn.Module components. Names that do
            not resolve to an nn.Module are skipped with a warning.
        cache_dir (`str`, *optional*):
            Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``.
            Created if it does not exist. Skips recompilation on subsequent runs.
        fullgraph (`bool`, defaults to `True`):
            Disallow graph breaks (required for full-graph fusion).
    """
    requires_backends(self, "torch_neuronx")
    import torch_neuronx  # noqa: F401 — import registers the "neuron" torch.compile backend

    if cache_dir is not None:
        # Create the directory up front so neuronx-cc can write NEFFs into it on first use.
        os.makedirs(cache_dir, exist_ok=True)
        os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir

    if model_names is None:
        model_names = [
            name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module)
        ]

    for name in model_names:
        component = getattr(self, name, None)
        if not isinstance(component, torch.nn.Module):
            # Surface typos / non-module names instead of silently doing nothing.
            logger.warning(f"`{name}` is not an nn.Module component of this pipeline; skipping.")
            continue
        if is_compiled_module(component):
            logger.info(f"`{name}` is already compiled; skipping.")
            continue
        logger.info(f"Compiling {name} with backend='neuron'")
        setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph))

def neuron_warmup(self, *args, **kwargs) -> None:
    """
    Trigger neuronx-cc (NEFF) compilation for every compiled component by executing
    one throwaway forward pass through the pipeline.

    The call is identical to invoking the pipeline directly with the same arguments,
    except the result is discarded. After warmup, later calls with matching shapes
    reuse the cached NEFFs and run at full speed.

    Pass the exact arguments you plan to use for real inference (height, width,
    num_inference_steps, batch_size, etc.) — compilation is shape-specific.
    """
    logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...")
    with torch.no_grad():
        _ = self(*args, **kwargs)
    logger.info("Neuron warmup complete.")


class StableDiffusionMixin:
r"""
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
is_timm_available,
is_torch_available,
is_torch_mlu_available,
is_torch_neuronx_available,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
Expand Down Expand Up @@ -249,6 +250,10 @@ def is_torch_mlu_available():
return _torch_mlu_available


def is_torch_neuronx_available():
    """Return whether the `torch_neuronx` package (AWS Neuron SDK) is installed."""
    return _torch_neuronx_available


def is_flax_available():
return _flax_available

Expand Down
25 changes: 22 additions & 3 deletions src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,61 +21,74 @@
import os

from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
from .import_utils import (
is_torch_available,
is_torch_mlu_available,
is_torch_neuronx_available,
is_torch_npu_available,
is_torch_version,
)


if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift

BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True}
BACKEND_EMPTY_CACHE = {
"cuda": torch.cuda.empty_cache,
"xpu": torch.xpu.empty_cache,
"cpu": None,
"mps": torch.mps.empty_cache,
"neuron": None,
"default": None,
}
BACKEND_DEVICE_COUNT = {
"cuda": torch.cuda.device_count,
"xpu": torch.xpu.device_count,
"cpu": lambda: 0,
"mps": lambda: 0,
"neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(),
"default": 0,
}
BACKEND_MANUAL_SEED = {
"cuda": torch.cuda.manual_seed,
"xpu": torch.xpu.manual_seed,
"cpu": torch.manual_seed,
"mps": torch.mps.manual_seed,
"neuron": torch.manual_seed,
"default": torch.manual_seed,
}
BACKEND_RESET_PEAK_MEMORY_STATS = {
"cuda": torch.cuda.reset_peak_memory_stats,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"neuron": None,
"default": None,
}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.reset_max_memory_allocated,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"neuron": None,
"default": None,
}
BACKEND_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.max_memory_allocated,
"xpu": getattr(torch.xpu, "max_memory_allocated", None),
"cpu": 0,
"mps": 0,
"neuron": 0,
"default": 0,
}
BACKEND_SYNCHRONIZE = {
"cuda": torch.cuda.synchronize,
"xpu": getattr(torch.xpu, "synchronize", None),
"cpu": None,
"mps": None,
"neuron": None,
"default": None,
}
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -164,11 +177,15 @@ def randn_tensor(
layout = layout or torch.strided
device = device or torch.device("cpu")

# Neuron (XLA) does not support creating random tensors directly on device; always use CPU
if device.type == "neuron":
rand_device = torch.device("cpu")

if generator is not None:
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
if device.type not in ("mps", "neuron"):
logger.info(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
Expand Down Expand Up @@ -289,6 +306,8 @@ def get_device():
return "mps"
elif is_torch_mlu_available():
return "mlu"
elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
return "neuron"
else:
return "cpu"

Expand Down
Loading