Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
797b67e
feat: add NPU FastGELUMLP (task 1 of 4)
Apr 22, 2026
1a0615b
feat: add NPU RMSNorm wrapper (task 2 of 4)
Apr 22, 2026
b8dba48
feat: add MINDIE attention backend for NPU (task 3 of 4)
Apr 22, 2026
a98e665
feat: add NPU RoPE with use_real=True and use_real=False paths (task …
Apr 23, 2026
2c0a8fe
feat: add NPU AdaLayerNorm wrapper (task 5 of 4)
Apr 24, 2026
5a27364
fix: cache RMSNorm fallback instance and remove debug code
Apr 30, 2026
4e0e310
fix: unsqueeze gates to [B,1,D] for broadcast with 3D tensors
May 6, 2026
a3744c6
fix: align img modulation with txt under zero_cond_t to fix 2*B vs B …
May 6, 2026
a0a2e24
fix: only auto-switch to MINDIE on NPU when user doesn't specify attn…
May 6, 2026
a49c78c
test: add AdaLayerNorm edge case and NPU path mock tests
May 15, 2026
3cc53d1
test: add unit tests for NPU adaptation features
May 15, 2026
d0024c2
fix: prevent _fallback from appearing in RMSNorm state_dict
May 15, 2026
f8ca528
test: verify RMSNorm state_dict has no _fallback keys and supports st…
May 15, 2026
1b31ad3
test: add precision comparison tests for 4 NPU fused operators
May 19, 2026
6e85007
fix: move RMSNorm to NPU device and remove asymmetric kv_len test cases
May 19, 2026
2b9d2f9
feat: add env var switches to disable individual NPU fused ops
May 19, 2026
0fd27d1
fix: restore v1 per-token CFG for img modulation when zero_cond_t=True
May 19, 2026
0406753
Revert "feat: add env var switches to disable individual NPU fused ops"
May 19, 2026
e80b76d
DEBUG: force F.gelu fallback to measure npu_fast_gelu SSIM impact
May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions diffsynth_engine/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ def _parse_tuple(value: str) -> Tuple[int, int] | int:
raise ValueError(f"Cannot parse tuple: {value}, format should be '256,256' or '256'")


def _parse_attention_type(attn_type_str: str) -> AttentionType:
def _parse_attention_type(attn_type_str: str | None) -> AttentionType | None:
"""Convert string to AttentionType enum"""
if attn_type_str is None:
return None
return AttentionType[attn_type_str.upper()]


Expand Down Expand Up @@ -106,9 +108,9 @@ def parse_cli_args() -> Dict[str, Any]:
attn_group.add_argument(
"--attn-type",
type=str,
default="sdpa",
default=None,
choices=attn_type_choices,
help="Attention type (default: sdpa)",
help="Attention type (default: auto, SDPA on GPU, MINDIE on NPU)",
)
attn_group.add_argument(
"--sparge-topk",
Expand Down
2 changes: 1 addition & 1 deletion diffsynth_engine/configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class PipelineConfig:
vae_tile_stride: int | Tuple[int, int] = (192, 192)

# attention
attn_type: AttentionType = AttentionType.SDPA
attn_type: AttentionType | None = None # None = auto-detect
attn_params: Optional[AttentionParams] = None

# parallelism
Expand Down
1 change: 1 addition & 0 deletions diffsynth_engine/layers/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AttentionType(enum.Enum):
SAGE2 = enum.auto()
SAGE3 = enum.auto()
SPARGE = enum.auto()
MINDIE = enum.auto()

def __str__(self) -> str:
return self.name.lower()
Expand Down
79 changes: 79 additions & 0 deletions diffsynth_engine/layers/attention/backends/mindie_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch
from diffsynth_engine.layers.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionType,
)
from diffsynth_engine.utils.import_utils import is_npu_available


class MindieAttentionBackend(AttentionBackend):
@staticmethod
def check_availability() -> None:
if not is_npu_available():
raise RuntimeError("NPU is not available, cannot use MINDIE attention backend")

@staticmethod
def get_type() -> AttentionType:
return AttentionType.MINDIE

@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
return MindieAttentionImpl

@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return AttentionMetadata

@staticmethod
def get_builder_cls() -> type:
return None

@staticmethod
def get_supported_head_sizes() -> list[int]:
return []


class MindieAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float | None = None,
causal: bool = False,
num_kv_heads: int | None = None,
**extra_impl_args,
) -> None:
if num_kv_heads is None:
num_kv_heads = num_heads
self.num_kv_groups = num_heads // num_kv_heads
self.causal = causal
self.softmax_scale = softmax_scale
self.num_heads = num_heads
self.head_size = head_size

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
attn_metadata=None,
) -> torch.Tensor:
from mindiesd.layers.flash_attn.attention_forward import attention_forward

scale = self.softmax_scale
if scale is None:
scale = self.head_size ** -0.5

out = attention_forward(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
scale=scale,
fused=True,
head_first=False,
)
return out
9 changes: 6 additions & 3 deletions diffsynth_engine/layers/attention/selector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import cache

from diffsynth_engine.layers.attention.backends.abstract import AttentionBackend, AttentionType
from diffsynth_engine.utils.import_utils import LazyImport
from diffsynth_engine.utils.import_utils import LazyImport, is_npu_available

AiterBackend = LazyImport("diffsynth_engine.layers.attention.backends.aiter", "AiterBackend")
AiterFP8Backend = LazyImport("diffsynth_engine.layers.attention.backends.aiter", "AiterFP8Backend")
Expand All @@ -15,6 +15,7 @@
SageAttention3Backend = LazyImport("diffsynth_engine.layers.attention.backends.sage_attn_3", "SageAttention3Backend")
SDPABackend = LazyImport("diffsynth_engine.layers.attention.backends.sdpa", "SDPABackend")
SpargeAttentionBackend = LazyImport("diffsynth_engine.layers.attention.backends.sparge_attn", "SpargeAttentionBackend")
MindieAttentionBackend = LazyImport("diffsynth_engine.layers.attention.backends.mindie_attn", "MindieAttentionBackend")

_attention_backends = {
AttentionType.AITER: AiterBackend,
Expand All @@ -27,14 +28,16 @@
AttentionType.SAGE3: SageAttention3Backend,
AttentionType.SDPA: SDPABackend,
AttentionType.SPARGE: SpargeAttentionBackend,
AttentionType.MINDIE: MindieAttentionBackend,
}


@cache
def get_attn_backend(head_size: int, attn_type: AttentionType | None = None) -> type["AttentionBackend"]:
# use SDPA as default
if attn_type is None:
attn_type = AttentionType.SDPA
# Auto-detect: NPU → MINDIE, otherwise → SDPA
attn_type = AttentionType.MINDIE if is_npu_available() else AttentionType.SDPA

selected_backend = _attention_backends[attn_type]
selected_backend.check_availability()
if not selected_backend.supports_head_size(head_size):
Expand Down
73 changes: 73 additions & 0 deletions diffsynth_engine/layers/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch.nn as nn
import torch.nn.functional as F
from diffsynth_engine.utils.import_utils import is_npu_available

try:
import torch_npu
except ImportError:
torch_npu = None


class _GELUProj(nn.Module):
"""Wrapper to match diffusers FeedForward GELU structure with internal proj.

This wrapper holds the first Linear layer as .proj to match checkpoint keys.
"""

def __init__(self, dim, inner_dim):
super().__init__()
self.proj = nn.Linear(dim, inner_dim)

def forward(self, x):
return F.gelu(x, approximate="tanh")


class FastGELUMLP(nn.Module):
"""MLP with npu_fast_gelu on NPU, fallback to F.gelu on other devices.

Functionally equivalent to diffusers.models.attention.FeedForward(
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)
"""

def __init__(self, dim, dim_out=None, mult=4):
"""Initialize MLP.

Args:
dim: Input and output dimension
dim_out: Output dimension, defaults to dim
mult: inner_dim = dim * mult, defaults to 4
"""
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out or dim

# Match diffusers FeedForward structure: net[0]=GELU(proj), net[2]=output
# net[1] is Dropout which is skipped in inference
self.net = nn.ModuleList([
_GELUProj(dim, inner_dim),
nn.Dropout(0.0),
nn.Linear(inner_dim, dim_out),
])

def forward(self, hidden_states):
"""Forward pass.

Args:
hidden_states: Input tensor, shape [B, S, dim]

Returns:
Output tensor, shape [B, S, dim_out]
"""
# net[0] = _GELUProj with internal proj (dim → inner_dim)
hidden_states = self.net[0].proj(hidden_states)

# TODO: temporarily force F.gelu to measure SSIM impact of npu_fast_gelu
# if is_npu_available() and torch_npu is not None:
# hidden_states = torch_npu.npu_fast_gelu(hidden_states)
# else:
hidden_states = F.gelu(hidden_states, approximate="tanh")

# net[2] = output Linear (inner_dim → dim_out)
hidden_states = self.net[2](hidden_states)
return hidden_states
82 changes: 82 additions & 0 deletions diffsynth_engine/layers/norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
from diffsynth_engine.utils.import_utils import is_npu_available

try:
import torch_npu
except ImportError:
torch_npu = None

try:
from mindiesd.layers import layernorm_scale_shift
except ImportError:
layernorm_scale_shift = None


class RMSNorm(nn.Module):
"""NPU-optimized RMSNorm wrapper with fallback to diffusers implementation."""

def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
# Cache the fallback instance so forward() reuses the same weight
# tensor. register_parameter is reference assignment (no copy), so
# self.weight and self._fallback.weight share the same storage.
# When a checkpoint writes to "weight", both paths see the update.
fallback = DiffusersRMSNorm(hidden_size, eps)
self.register_parameter("weight", fallback.weight)
# Use object.__setattr__ to avoid registering _fallback as an
# nn.Module submodule, which would add spurious keys to state_dict()
# and break strict checkpoint loading.
object.__setattr__(self, "_fallback", fallback)

def forward(self, hidden_states):
if is_npu_available() and torch_npu is not None:
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
else:
return self._fallback(hidden_states)


class AdaLayerNorm(nn.Module):
"""NPU-optimized AdaLayerNorm with fallback to original implementation.

Performs: output = layernorm(x) * (1 + scale) + shift

Args:
layernorm: The underlying nn.LayerNorm module (elementwise_affine=False)
"""

def __init__(self, layernorm: nn.LayerNorm):
super().__init__()
self.layernorm = layernorm

def forward(self, hidden_states: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states: Input tensor, shape [B, S, H]
scale: Scale parameter, shape [B, H] or [B, 1, H]
shift: Shift parameter, shape [B, H] or [B, 1, H]

Returns:
layernorm(x) * (1 + scale) + shift
"""
if is_npu_available() and layernorm_scale_shift is not None:
# NPU path: use MindIE-SD fused operator
return layernorm_scale_shift(
layernorm=self.layernorm,
x=hidden_states,
scale=scale,
shift=shift,
fused=True
)
else:
# Fallback: original Python implementation
normed = self.layernorm(hidden_states)
# Handle [B, 1, H] -> [B, H] dimension
if scale.dim() == 2:
scale = scale.unsqueeze(1)
if shift.dim() == 2:
shift = shift.unsqueeze(1)
return normed * (1 + scale) + shift
Loading