[Feat]: add NPU fused operators (RMSNorm, RoPE, SwiGLU, SDPA)#194
[Feat]: add NPU fused operators (RMSNorm, RoPE, SwiGLU, SDPA)#194ys2025-AI wants to merge 5 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces comprehensive NPU hardware acceleration support for Ascend devices by implementing fused operators (RMSNorm, RoPE, SwiGLU, and SDPA) and monkey-patching logic for specific model families like Qwen. It also refactors the NPU patching mechanism to be applied automatically when an NPU device is detected. Review feedback focuses on improving error handling by logging tracebacks for broad exception catches and restoring type hints and assertions that were removed during the refactoring of the MoE grouped matrix multiplication functions.
| assert x.size(1) == weight_ekn.size(1), ( | ||
| f'input dim mismatch: x.shape={tuple(x.shape)}, weight_ekn.shape={tuple(weight_ekn.shape)}') | ||
|
|
||
| def forward(ctx, x, group_list, weight_ekn): |
There was a problem hiding this comment.
The type hints and assertions from the previous version of GmmFunction.forward have been removed. These are valuable for static analysis, code clarity, and preventing runtime errors. Please consider restoring the type hints and relevant assertions for tensor shapes and dimensions to maintain code quality.
| def forward(ctx, x, group_list, weight_ekn): | |
| def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor): |
| assert weight_ekn.size(0) == offs.numel(), ( | ||
| f'weight_ekn.size(0)={weight_ekn.size(0)} != offs.numel()={offs.numel()}') | ||
|
|
||
| def _grouped_mm_npu(input, weight_ekn, offs): |
There was a problem hiding this comment.
Similar to GmmFunction, the type hints and assertions have been removed from this function. Restoring them would improve code quality and make it easier to understand the expected inputs and outputs.
| def _grouped_mm_npu(input, weight_ekn, offs): | |
| def _grouped_mm_npu(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: |
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
|
|
||
| # Priority 2: Fallback to global NPU availability | ||
| try: | ||
| if hasattr(torch, 'npu') and torch.npu.is_available(): |
There was a problem hiding this comment.
有个torch_util,建议复用代码
PR type
PR information
Extends Twinkle's NPU support from a basic MoE GMM patch to a full fused-operator suite (RMSNorm, RoPE, SwiGLU, SDPA) for Ascend hardware.
Experiment results
Atlas 900 A2 (8× NPU) | Qwen3-30B-A3B-Instruct-2507 | LoRA r=8, batch=16, 188 steps | Dataset GSM8K_ZH