Skip to content

[Feat]: add NPU fused operators (RMSNorm, RoPE, SwiGLU, SDPA)#194

Open
ys2025-AI wants to merge 5 commits into
modelscope:mainfrom
ys2025-AI:main
Open

[Feat]: add NPU fused operators (RMSNorm, RoPE, SwiGLU, SDPA)#194
ys2025-AI wants to merge 5 commits into
modelscope:mainfrom
ys2025-AI:main

Conversation

@ys2025-AI
Copy link
Copy Markdown
Collaborator

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Extends Twinkle's NPU support from a basic MoE GMM patch to a full fused-operator suite (RMSNorm, RoPE, SwiGLU, SDPA) for Ascend hardware.

Experiment results

Atlas 900 A2 (8× NPU) | Qwen3-30B-A3B-Instruct-2507 | LoRA r=8, batch=16, 188 steps | Dataset GSM8K_ZH

Metric Baseline This PR Delta
Total 544 s 503 s +7.5%
Training (step 10–180) 465 s 404 s +13.1%
Loss / GradNorm << 0.01

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces comprehensive NPU hardware acceleration support for Ascend devices by implementing fused operators (RMSNorm, RoPE, SwiGLU, and SDPA) and monkey-patching logic for specific model families like Qwen. It also refactors the NPU patching mechanism to be applied automatically when an NPU device is detected. Review feedback focuses on improving error handling by logging tracebacks for broad exception catches and restoring type hints and assertions that were removed during the refactoring of the MoE grouped matrix multiplication functions.

Comment thread src/twinkle/kernel/__init__.py Outdated
Comment thread src/twinkle/kernel/__init__.py Outdated
assert x.size(1) == weight_ekn.size(1), (
f'input dim mismatch: x.shape={tuple(x.shape)}, weight_ekn.shape={tuple(weight_ekn.shape)}')

def forward(ctx, x, group_list, weight_ekn):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hints and assertions from the previous version of GmmFunction.forward have been removed. These are valuable for static analysis, code clarity, and preventing runtime errors. Please consider restoring the type hints and relevant assertions for tensor shapes and dimensions to maintain code quality.

Suggested change
def forward(ctx, x, group_list, weight_ekn):
def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor):

assert weight_ekn.size(0) == offs.numel(), (
f'weight_ekn.size(0)={weight_ekn.size(0)} != offs.numel()={offs.numel()}')

def _grouped_mm_npu(input, weight_ekn, offs):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to GmmFunction, the type hints and assertions have been removed from this function. Restoring them would improve code quality and make it easier to understand the expected inputs and outputs.

Suggested change
def _grouped_mm_npu(input, weight_ekn, offs):
def _grouped_mm_npu(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:

ys2025-AI and others added 3 commits May 18, 2026 20:25
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

# Priority 2: Fallback to global NPU availability
try:
if hasattr(torch, 'npu') and torch.npu.is_available():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有个torch_util,建议复用代码

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants