Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
395357a
flashinfer-cutedsl-nvfp4-groupgemm
mpgemm Mar 18, 2026
07e6fae
flashinfer-cutedsl-nvfp4-groupgemm
mpgemm Mar 18, 2026
0140f51
flashinfer-cutedsl-nvfp4-groupgemm
mpgemm Mar 18, 2026
6d65bc8
flashinfer-cutedsl-nvfp4-groupgemm
mpgemm Mar 18, 2026
400c6f9
fix prefill
mpgemm Mar 19, 2026
98e44bc
delete cutedslmoe
mpgemm Mar 19, 2026
0bc8a0c
prefill error
mpgemm Mar 19, 2026
bbcf766
support eb5
mpgemm Mar 19, 2026
3d8ecc3
fix prefill
mpgemm Mar 19, 2026
3cbba91
fix
mpgemm Mar 19, 2026
f2db042
first commit
mpgemm Mar 20, 2026
f126507
nvfp4
mpgemm Mar 20, 2026
e31cec8
fix
mpgemm Mar 20, 2026
0d2e5e1
fp4-moe
mpgemm Mar 21, 2026
8a894de
增加单测
mpgemm Mar 21, 2026
ad7f62a
fix formatting and unused variables
mpgemm Mar 22, 2026
087c7e1
修正合并错误
mpgemm Mar 22, 2026
2518e43
修正合并错误
mpgemm Mar 22, 2026
29a91c7
修正合并错误
mpgemm Mar 22, 2026
b3b8dff
修正合并错误
mpgemm Mar 22, 2026
2b36459
修复合并错误
mpgemm Mar 23, 2026
345a38c
Merge branch 'develop' into fp4-moe
mpgemm Mar 23, 2026
c4dcca3
修正合并错误
mpgemm Mar 23, 2026
62145bd
Merge remote-tracking branch 'upstream/develop' into fp4-moe
mpgemm Mar 23, 2026
8763d89
Merge remote-tracking branch 'upstream/develop' into fp4-moe
mpgemm Mar 24, 2026
b49582a
import flashinfer 放函数里面,实际推理需要放到头文件处
mpgemm Mar 24, 2026
1bde5b0
修正合并错误
mpgemm Mar 24, 2026
e3b3db6
support eb5
mpgemm Mar 24, 2026
381ac1f
fix
mpgemm Mar 24, 2026
af7b6ad
保留原有权重处理
mpgemm Mar 24, 2026
c4abb42
support decode
mpgemm Mar 24, 2026
f4f87d8
update prefill
Mar 25, 2026
7c43e58
support eb5
Mar 25, 2026
7aac2d0
merge develop
Mar 25, 2026
18a73f2
support eb5
Mar 25, 2026
27c1073
fix eb5-qijian
Mar 25, 2026
02711e0
delete print
Mar 25, 2026
31c40f6
解决冲突
mpgemm Mar 26, 2026
bd3555b
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
mpgemm Mar 26, 2026
e48a5f4
nvfp4 load blockscale
mpgemm Mar 26, 2026
1fd5d66
解决重复
mpgemm Mar 26, 2026
e066dc0
删除空行
mpgemm Mar 26, 2026
95dd02c
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
mpgemm Mar 26, 2026
b441aeb
导入放函数里面
mpgemm Mar 26, 2026
c9fc7bf
修复合并错误
mpgemm Mar 27, 2026
6aacc3b
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
mpgemm Mar 27, 2026
338ea07
删除logger
mpgemm Mar 27, 2026
7891d24
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
mpgemm Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions custom_ops/gpu_ops/moe/depermute_prefill_combine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,19 @@ std::vector<paddle::Tensor> DepermutePrefillCombine(
case paddle::DataType::FLOAT8_E4M3FN: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 6)
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
PD_THROW("Unsupported topk value, must be 4, 6 or 8");
}
}
case paddle::DataType::BFLOAT16: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 6)
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
PD_THROW("Unsupported topk value, must be 4, 6 or 8");
}
}
default:
Expand Down
4 changes: 3 additions & 1 deletion custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,12 @@ std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 6)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
PD_THROW("Unsupported topk value, must be 4 or 6 or 8");
}
}
case paddle::DataType::INT32: {
Expand Down
4 changes: 3 additions & 1 deletion fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def _validate_split_kv_size(value: int) -> int:
"FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
# Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently.
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
# Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass" and "flashinfer-trtllm" can be set currently.
# Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently.
"FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
# Whether to load the nvfp4 weight block scale in interleaved layout ("0" disables it).
"FD_NVFP4_LOAD_BLOCKSCALE_LEAVE": lambda: os.getenv("FD_NVFP4_LOAD_BLOCKSCALE_LEAVE", "0"),
# Set mxfp4 backend."flashinfer" can be set currently.
"FD_MOE_MXFP4_BACKEND": lambda: os.getenv("FD_MOE_MXFP4_BACKEND", "flashinfer"),
# Whether to use Machete for wint4 dense gemm.
Expand Down
7 changes: 4 additions & 3 deletions fastdeploy/model_executor/layers/moe/ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,15 @@ def create_buffer(self):
if self.deepep_buffer is not None:
self.clear_buffer()

num_qps_per_rank = max(24, self.num_experts // self.ep_size)
if self.splitwise_role == "mixed":
logger.info("Initializing mixed mode buffer (low latency).")
self.deepep_buffer = deep_ep.Buffer(
self.group,
self.num_nvl_bytes,
self.num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=24,
num_qps_per_rank=num_qps_per_rank,
)
self.deepep_buffer.set_num_sms(14) # TODO: tune in future
else:
Expand All @@ -183,7 +184,7 @@ def create_buffer(self):
self.num_nvl_bytes,
self.num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=24,
num_qps_per_rank=num_qps_per_rank,
)
else:
raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")
Expand All @@ -199,7 +200,7 @@ def _create_low_latency_buffer(self):
if self.ep_size // 8 > 1:
num_qps_per_rank_now = self.ep_size // 8
else:
num_qps_per_rank_now = 1
num_qps_per_rank_now = self.num_experts // self.ep_size
self.deepep_buffer = deep_ep.Buffer(
self.group,
self.num_nvl_bytes,
Expand Down
214 changes: 214 additions & 0 deletions fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Any, Optional

import paddle

paddle.compat.enable_torch_proxy(scope={"flashinfer"})


def _dtype_str(dtype) -> str:
"""Normalize dtype to string, handling both paddle and torch proxy dtypes."""
return str(dtype).split(".")[-1]


def _is_dtype(tensor, *dtype_names: str) -> bool:
"""Check tensor dtype by name, compatible with both paddle and torch proxy tensors."""
return _dtype_str(tensor.dtype) in dtype_names


def _perm(tensor, *dims):
"""Permute tensor dims, compatible with both paddle (transpose) and torch proxy (permute)."""
try:
return tensor.transpose(list(dims))
except TypeError:
return tensor.permute(*dims)


def get_cute_dtype(input) -> str:
    """Return the cute dtype string for *input*'s dtype.

    Only bfloat16, float16 and float32 are supported; anything else raises
    ValueError. The dtype is matched by bare name so both paddle and
    torch-proxy tensors are accepted.
    """
    name = str(input.dtype).split(".")[-1]
    if name in ("bfloat16", "float16", "float32"):
        return name
    raise ValueError(f"Unsupported cute dtype {input.dtype}")


def flashinfer_cutedsl_moe_masked(
    hidden_states: tuple,
    input_global_scale: paddle.Tensor,
    w1: paddle.Tensor,
    w1_blockscale: paddle.Tensor,
    w1_alpha: paddle.Tensor,
    w2: paddle.Tensor,
    a2_global_scale: paddle.Tensor,
    w2_blockscale: paddle.Tensor,
    w2_alpha: paddle.Tensor,
    masked_m: paddle.Tensor,
    down_sm_count: Optional[int] = None,
    down_signals: Optional[paddle.Tensor] = None,
    down_start_event: Optional[Any] = None,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL kernels.

    Pipeline: (optional FP4 quantize) -> grouped GEMM1 (gate+up) -> SiLU-and-mul
    with FP4 re-quantization -> grouped GEMM2 (down projection).

    Args:
        hidden_states: Either of the following:
            * (paddle.Tensor, None): [num_experts, m, k] bf16 — not pre-quantized
            * (paddle.Tensor, paddle.Tensor): [m, k//2, num_experts] uint8,
              [m, k//16, num_experts] float8_e4m3fn — pre-quantized FP4 from dispatch
        input_global_scale: (l,) float32, value is 1/input_scale per expert
        w1: [l, 2*n, k//2] uint8, FP4-packed gate+up projection weights
        w1_blockscale: float8_e4m3fn blockscale for w1
        w1_alpha: (l,) float32, = input_scale * w1_weight_scale_2
        w2: [l, k, n//2] uint8, FP4-packed down projection weights
        a2_global_scale: (l,) float32, 1/input_scale for GEMM2
        w2_blockscale: float8_e4m3fn blockscale for w2
        w2_alpha: (l,) float32, = input_scale * w2_weight_scale_2
        masked_m: (l,) int32, valid token count per expert; max(masked_m) == m
        down_sm_count: optional; forwarded as ``sm_count`` to the down-projection
            GEMM — presumably caps the SMs the kernel uses, verify against
            flashinfer's grouped_gemm_nt_masked docs
        down_signals: optional; forwarded as ``dst_signals`` to the down-projection
            GEMM — NOTE(review): looks like a completion-signal buffer for
            kernel overlap, confirm semantics with flashinfer
        down_start_event: optional event-like object; ``.record()`` is called
            immediately before the down-projection GEMM is launched

    Returns:
        paddle.Tensor: [num_experts, m, k] bf16
    """
    # Imported lazily so the module can be loaded without flashinfer installed.
    from flashinfer import (
        scaled_fp4_grouped_quantize,
        silu_and_mul_scaled_nvfp4_experts_quantize,
    )
    from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked

    # === Dtype assertions ===
    # Use string-based dtype check to be compatible with both paddle and torch proxy tensors
    assert _is_dtype(w1, "uint8"), f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
    assert _is_dtype(w1_blockscale, "float8_e4m3fn"), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    assert _is_dtype(w1_alpha, "float32"), f"w1_alpha must be float32, got {w1_alpha.dtype}"
    assert _is_dtype(w2, "uint8"), f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
    assert _is_dtype(a2_global_scale, "float32"), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    assert _is_dtype(w2_blockscale, "float8_e4m3fn"), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    assert _is_dtype(w2_alpha, "float32"), f"w2_alpha must be float32, got {w2_alpha.dtype}"
    assert len(hidden_states) == 2, f"hidden_states must be a tuple of length 2, got {len(hidden_states)}"

    # intermediate_size derived from w2 last dimension
    n = w2.shape[-1] * 2

    if hidden_states[1] is not None:
        # Pre-quantized path: tokens already FP4-packed by dispatch
        # a_q: [m, k//2, num_experts] uint8
        # a_q_sf:[m, k//16, num_experts] float8_e4m3fn
        a_q = hidden_states[0].view(paddle.uint8)
        a_q_sf = hidden_states[1].view(paddle.float8_e4m3fn)
        m, k_by_2, num_experts = a_q.shape
        k = k_by_2 * 2
    else:
        # Standard path: bf16 [num_experts, m, k], quantize to FP4 here
        num_experts, m, k = hidden_states[0].shape

        # input_global_scale is only consumed on this path (the pre-quantized
        # path already applied it during dispatch), so validate it here.
        assert _is_dtype(
            input_global_scale, "float32"
        ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
        assert list(input_global_scale.shape) == [
            num_experts
        ], f"input_global_scale must be (l,), got {input_global_scale.shape}"

        a_q, a_q_sf = scaled_fp4_grouped_quantize(
            hidden_states[0],
            masked_m,
            input_global_scale,
        )

    # Shape sanity checks tying the activations to both weight tensors.
    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n={2*n}, got {w1.shape[-2]}"
    assert w1.shape[-1] * 2 == k, f"w1 last dim * 2 must equal k={k}, got {w1.shape[-1] * 2}"
    assert (
        w2.shape[-2] == k and w2.shape[-1] == n // 2
    ), f"w2 shape mismatch, got {list(w2.shape[-2:])}, expected [{k}, {n // 2}]"
    assert list(w1_alpha.shape) == [num_experts], f"w1_alpha must be (l,), got {w1_alpha.shape}"
    assert list(a2_global_scale.shape) == [num_experts], f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    assert list(w2_alpha.shape) == [num_experts], f"w2_alpha must be (l,), got {w2_alpha.shape}"

    assert _is_dtype(a_q, "uint8")
    assert _is_dtype(a_q_sf, "float8_e4m3fn")

    # Kernel dtype configuration: FP4 (e2m1) inputs, FP8 (e4m3) blockscales,
    # bf16 output, one scale factor per 16 elements.
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = "bfloat16"
    sf_vec_size = 16

    # === GEMM1: gate+up projection ===
    # grouped_gemm_nt_masked requires output in [m, 2*n, l] layout
    gateup_output = paddle.empty([num_experts, m, n * 2], dtype=paddle.bfloat16)
    gateup_output = gateup_output.transpose([1, 2, 0])  # [m, 2*n, num_experts]

    # w1:           [E, 2*n, k//2] → _perm(., 1, 2, 0) → [2*n, k//2, E]
    # w1_blockscale:[E, 2*n, k//G] → _perm(., 1, 2, 0) → [2*n, k//G, E]
    # Both must share the same expert-last layout for grouped_gemm_nt_masked.
    grouped_gemm_nt_masked(
        (a_q, a_q_sf),
        (_perm(w1, 1, 2, 0), _perm(w1_blockscale, 1, 2, 0)),
        gateup_output,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.reshape([1, 1, num_experts]),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # fills gateup_output in logical [m, 2*n, l]

    # === SiLU + mul + quantize intermediate activations to FP4 ===
    # Input expected as [num_experts, m, 2*n]
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        gateup_output.transpose([2, 0, 1]),  # [num_experts, m, 2*n]
        masked_m,
        a2_global_scale,
    )

    if down_start_event is not None:
        down_start_event.record()

    # === GEMM2: down projection ===
    # grouped_gemm_nt_masked requires output in [m, k, l] layout
    out = paddle.empty([num_experts, m, k], dtype=paddle.bfloat16)
    out = out.transpose([1, 2, 0])  # [m, k, num_experts]

    # w2:           [E, k, n//2] → _perm(., 1, 2, 0) → [k, n//2, E]
    # w2_blockscale:[E, k, n//G] → _perm(., 1, 2, 0) → [k, n//G, E]
    # Both must share the same expert-last layout for grouped_gemm_nt_masked.
    # sm_count/dst_signals are only forwarded when the caller supplied them,
    # so the kernel's own defaults apply otherwise.
    grouped_gemm_nt_masked(
        (diq, diq_sf),
        (_perm(w2, 1, 2, 0), _perm(w2_blockscale, 1, 2, 0)),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.reshape([1, 1, num_experts]),
        alpha_dtype=get_cute_dtype(w2_alpha),
        **(
            dict(
                sm_count=down_sm_count,
                dst_signals=down_signals,
            )
            if down_sm_count is not None or down_signals is not None
            else {}
        ),
    )  # fills out in logical [m, k, l]

    # Return [num_experts, m, k]
    return out.transpose([2, 0, 1])
Loading
Loading