diff --git a/custom_ops/gpu_ops/moe/depermute_prefill_combine.cu b/custom_ops/gpu_ops/moe/depermute_prefill_combine.cu index 1d7ceffee04..22c19f39e58 100644 --- a/custom_ops/gpu_ops/moe/depermute_prefill_combine.cu +++ b/custom_ops/gpu_ops/moe/depermute_prefill_combine.cu @@ -172,17 +172,19 @@ std::vector DepermutePrefillCombine( case paddle::DataType::FLOAT8_E4M3FN: { switch (topk) { DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4) + DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 6) DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8) default: - PD_THROW("Unsupported topk value, must be 4 or 8"); + PD_THROW("Unsupported topk value, must be 4, 6 or 8"); } } case paddle::DataType::BFLOAT16: { switch (topk) { DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4) + DISPATCH_TOPK(paddle::DataType::BFLOAT16, 6) DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8) default: - PD_THROW("Unsupported topk value, must be 4 or 8"); + PD_THROW("Unsupported topk value, must be 4, 6 or 8"); } } default: diff --git a/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu b/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu index ca12a12687b..a5fb2494605 100644 --- a/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu +++ b/custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu @@ -217,10 +217,12 @@ std::vector PrefillPermuteToMaskedGemm( switch (topk) { DISPATCH_TOPK( paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4) + DISPATCH_TOPK( + paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 6) DISPATCH_TOPK( paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8) default: - PD_THROW("Unsupported topk value, must be 4 or 8"); + PD_THROW("Unsupported topk value, must be 4 or 6 or 8"); } } case paddle::DataType::INT32: { diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 72cd6dc7c48..3e5d25a6e3f 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -62,7 +62,7 @@ def _validate_split_kv_size(value: int) -> int: "FD_ATTENTION_BACKEND": lambda: 
os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"), # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently. "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"), - # Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass" and "flashinfer-trtllm" can be set currently. + # Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently. "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), # Set mxfp4 backend."flashinfer" can be set currently. "FD_MOE_MXFP4_BACKEND": lambda: os.getenv("FD_MOE_MXFP4_BACKEND", "flashinfer"), diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 66bb8bfeef5..243567a422f 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -163,6 +163,7 @@ def create_buffer(self): if self.deepep_buffer is not None: self.clear_buffer() + num_qps_per_rank = max(24, self.num_experts // self.ep_size) if self.splitwise_role == "mixed": logger.info("Initializing mixed mode buffer (low latency).") self.deepep_buffer = deep_ep.Buffer( @@ -170,7 +171,7 @@ def create_buffer(self): self.num_nvl_bytes, self.num_rdma_bytes, low_latency_mode=True, - num_qps_per_rank=24, + num_qps_per_rank=num_qps_per_rank, ) self.deepep_buffer.set_num_sms(14) # TODO: tune in future else: @@ -183,7 +184,7 @@ def create_buffer(self): self.num_nvl_bytes, self.num_rdma_bytes, low_latency_mode=True, - num_qps_per_rank=24, + num_qps_per_rank=num_qps_per_rank, ) else: raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}") @@ -199,7 +200,7 @@ def _create_low_latency_buffer(self): if self.ep_size // 8 > 1: num_qps_per_rank_now = self.ep_size // 8 else: - num_qps_per_rank_now = 1 + num_qps_per_rank_now = self.num_experts // self.ep_size self.deepep_buffer = deep_ep.Buffer( self.group, self.num_nvl_bytes, diff --git 
a/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py new file mode 100644 index 00000000000..6386a919a51 --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py @@ -0,0 +1,212 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Any, Optional + +import paddle +from flashinfer import ( + scaled_fp4_grouped_quantize, + silu_and_mul_scaled_nvfp4_experts_quantize, +) +from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked + + +def _dtype_str(dtype) -> str: + """Normalize dtype to string, handling both paddle and torch proxy dtypes.""" + return str(dtype).split(".")[-1] + + +def _is_dtype(tensor, *dtype_names: str) -> bool: + """Check tensor dtype by name, compatible with both paddle and torch proxy tensors.""" + return _dtype_str(tensor.dtype) in dtype_names + + +def _perm(tensor, *dims): + """Permute tensor dims, compatible with both paddle (transpose) and torch proxy (permute).""" + try: + return tensor.transpose(list(dims)) + except TypeError: + return tensor.permute(*dims) + + +def get_cute_dtype(input) -> str: + s = _dtype_str(input.dtype) + if s == "bfloat16": + return "bfloat16" + elif s == "float16": + return "float16" + elif s == "float32": + return "float32" + else: + raise ValueError(f"Unsupported cute dtype {input.dtype}") + + +def 
flashinfer_cutedsl_moe_masked( + hidden_states: tuple, + input_global_scale: paddle.Tensor, + w1: paddle.Tensor, + w1_blockscale: paddle.Tensor, + w1_alpha: paddle.Tensor, + w2: paddle.Tensor, + a2_global_scale: paddle.Tensor, + w2_blockscale: paddle.Tensor, + w2_alpha: paddle.Tensor, + masked_m: paddle.Tensor, + down_sm_count: Optional[int] = None, + down_signals: Optional[paddle.Tensor] = None, + down_start_event: Optional[Any] = None, +): + """ + Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL kernels. + + Args: + hidden_states: Either of the following: + * (paddle.Tensor, None): [num_experts, m, k] bf16 — not pre-quantized + * (paddle.Tensor, paddle.Tensor): [m, k//2, num_experts] uint8, + [m, k//16, num_experts] float8_e4m3fn — pre-quantized FP4 from dispatch + input_global_scale: (l,) float32, value is 1/input_scale per expert + w1: [l, 2*n, k//2] uint8, FP4-packed gate+up projection weights + w1_blockscale: float8_e4m3fn blockscale for w1 + w1_alpha: (l,) float32, = input_scale * w1_weight_scale_2 + w2: [l, k, n//2] uint8, FP4-packed down projection weights + a2_global_scale: (l,) float32, 1/input_scale for GEMM2 + w2_blockscale: float8_e4m3fn blockscale for w2 + w2_alpha: (l,) float32, = input_scale * w2_weight_scale_2 + masked_m: (l,) int32, valid token count per expert; max(masked_m) == m + + Returns: + paddle.Tensor: [num_experts, m, k] bf16 + """ + + # === Dtype assertions === + # Use string-based dtype check to be compatible with both paddle and torch proxy tensors + assert _is_dtype(w1, "uint8"), f"w1 must be uint8 (fp4 packed), got {w1.dtype}" + assert _is_dtype(w1_blockscale, "float8_e4m3fn"), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}" + assert _is_dtype(w1_alpha, "float32"), f"w1_alpha must be float32, got {w1_alpha.dtype}" + assert _is_dtype(w2, "uint8"), f"w2 must be uint8 (fp4 packed), got {w2.dtype}" + assert _is_dtype(a2_global_scale, "float32"), f"a2_global_scale must be float32, got 
{a2_global_scale.dtype}" + assert _is_dtype(w2_blockscale, "float8_e4m3fn"), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}" + assert _is_dtype(w2_alpha, "float32"), f"w2_alpha must be float32, got {w2_alpha.dtype}" + assert len(hidden_states) == 2, f"hidden_states must be a tuple of length 2, got {len(hidden_states)}" + + # intermediate_size derived from w2 last dimension + n = w2.shape[-1] * 2 + + if hidden_states[1] is not None: + # Pre-quantized path: tokens already FP4-packed by dispatch + # a_q: [m, k//2, num_experts] uint8 + # a_q_sf:[m, k//16, num_experts] float8_e4m3fn + a_q = hidden_states[0].view(paddle.uint8) + a_q_sf = hidden_states[1].view(paddle.float8_e4m3fn) + m, k_by_2, num_experts = a_q.shape + k = k_by_2 * 2 + else: + # Standard path: bf16 [num_experts, m, k], quantize to FP4 here + num_experts, m, k = hidden_states[0].shape + + assert _is_dtype( + input_global_scale, "float32" + ), f"input_global_scale must be float32, got {input_global_scale.dtype}" + assert list(input_global_scale.shape) == [ + num_experts + ], f"input_global_scale must be (l,), got {input_global_scale.shape}" + + a_q, a_q_sf = scaled_fp4_grouped_quantize( + hidden_states[0], + masked_m, + input_global_scale, + ) + + assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n={2*n}, got {w1.shape[-2]}" + assert w1.shape[-1] * 2 == k, f"w1 last dim * 2 must equal k={k}, got {w1.shape[-1] * 2}" + assert ( + w2.shape[-2] == k and w2.shape[-1] == n // 2 + ), f"w2 shape mismatch, got {list(w2.shape[-2:])}, expected [{k}, {n // 2}]" + assert list(w1_alpha.shape) == [num_experts], f"w1_alpha must be (l,), got {w1_alpha.shape}" + assert list(a2_global_scale.shape) == [num_experts], f"a2_global_scale must be (l,), got {a2_global_scale.shape}" + assert list(w2_alpha.shape) == [num_experts], f"w2_alpha must be (l,), got {w2_alpha.shape}" + + assert _is_dtype(a_q, "uint8") + assert _is_dtype(a_q_sf, "float8_e4m3fn") + + ab_dtype = "float4_e2m1fn" + sf_dtype = 
"float8_e4m3fn" + c_dtype = "bfloat16" + sf_vec_size = 16 + + # === GEMM1: gate+up projection === + # grouped_gemm_nt_masked requires output in [m, 2*n, l] layout + gateup_output = paddle.empty([num_experts, m, n * 2], dtype=paddle.bfloat16) + gateup_output = gateup_output.transpose([1, 2, 0]) # [m, 2*n, num_experts] + + # w1: [E, 2*n, k//2] → _perm(., 1, 2, 0) → [2*n, k//2, E] + # w1_blockscale:[E, 2*n, k//G] → _perm(., 1, 2, 0) → [2*n, k//G, E] + # Both must share the same expert-last layout for grouped_gemm_nt_masked. + grouped_gemm_nt_masked( + (a_q, a_q_sf), + (_perm(w1, 1, 2, 0), _perm(w1_blockscale, 1, 2, 0)), + gateup_output, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w1_alpha.reshape([1, 1, num_experts]), + alpha_dtype=get_cute_dtype(w1_alpha), + ) # fills gateup_output in logical [m, 2*n, l] + + # === SiLU + mul + quantize intermediate activations to FP4 === + # Input expected as [num_experts, m, 2*n] + diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize( + gateup_output.transpose([2, 0, 1]), # [num_experts, m, 2*n] + masked_m, + a2_global_scale, + ) + + if down_start_event is not None: + down_start_event.record() + + # === GEMM2: down projection === + # grouped_gemm_nt_masked requires output in [m, k, l] layout + out = paddle.empty([num_experts, m, k], dtype=paddle.bfloat16) + out = out.transpose([1, 2, 0]) # [m, k, num_experts] + + # w2: [E, k, n//2] → _perm(., 1, 2, 0) → [k, n//2, E] + # w2_blockscale:[E, k, n//G] → _perm(., 1, 2, 0) → [k, n//G, E] + # Both must share the same expert-last layout for grouped_gemm_nt_masked. 
+ grouped_gemm_nt_masked( + (diq, diq_sf), + (_perm(w2, 1, 2, 0), _perm(w2_blockscale, 1, 2, 0)), + out, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w2_alpha.reshape([1, 1, num_experts]), + alpha_dtype=get_cute_dtype(w2_alpha), + **( + dict( + sm_count=down_sm_count, + dst_signals=down_signals, + ) + if down_sm_count is not None or down_signals is not None + else {} + ), + ) # fills out in logical [m, k, l] + + # Return [num_experts, m, k] + return out.transpose([2, 0, 1]) diff --git a/fastdeploy/model_executor/layers/quantization/nvfp4.py b/fastdeploy/model_executor/layers/quantization/nvfp4.py index d15d3d9dbac..531bd1a7ba8 100644 --- a/fastdeploy/model_executor/layers/quantization/nvfp4.py +++ b/fastdeploy/model_executor/layers/quantization/nvfp4.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import os from typing import Callable, Optional import paddle @@ -35,6 +36,88 @@ paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +try: + # FlashInfer cutedsl blockscaled gemm kernels + + from fastdeploy.model_executor.layers.moe.ep import deep_ep + from fastdeploy.model_executor.ops.gpu import ( + depermute_prefill_combine, + prefill_permute_to_masked_gemm, + ) + + logger.info( + "FlashInfer cutedsl is slow to import because it triggers JIT compilation of " + "CUDA kernels via TVM/CODEGEN, and cuBLASLt initializes lookup tables and " + "compiles GEMM kernels during first load. This may take several minutes. " + "The wait is expected and only happens once per process." 
+ ) + from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import ( + flashinfer_cutedsl_moe_masked, + ) + +except ImportError: + raise ImportError("flashinfer_cutedsl_moe_masked not found, flashinfer kernel may not be enabled.") + +global_values = {} + + +def call_prefill_permute_to_masked_gemm( + x: paddle.Tensor, + scale: paddle.Tensor, + topk_ids: paddle.Tensor, + num_local_experts: int, + max_token_num: int, +): + """ + Permute input tokens and scales from token-major to expert-major layout + for MoE masked GEMM operations. + + Args: + x: Input hidden states [num_tokens, hidden]. + scale: Input scales [num_tokens, hidden_scale]. + topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32). + num_local_experts: Number of local experts on this device. + max_token_num: Maximum tokens per expert buffer. + + Returns: + tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert) + """ + if topk_ids.dtype != paddle.int64: + topk_ids = topk_ids.cast(paddle.int64) + + # NVFP4 dispatch returns plain BF16 (no fp8 scale); pass empty tensor so the + # C++ op can detect the no-scale path via tensor.numel() == 0. + if scale is None: + scale = paddle.empty([0], dtype=paddle.float32) + + results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num) + + return results[0], results[1], results[2], results[3] + + +def call_depermute_prefill_combine( + x: paddle.Tensor, + indice_map: paddle.Tensor, + topk_weights: paddle.Tensor, + num_worst_tokens: int, +): + """ + Depermute and combine expert outputs back to token-major layout. + + Args: + x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden]. + indice_map: Flat index tensor [num_worst_tokens, topk] (int32). + topk_weights: Combination weights [num_worst_tokens, topk] (float32). + num_worst_tokens: Number of output tokens to produce. + + Returns: + depermuted_x: Combined output [num_worst_tokens, hidden]. 
+ """ + results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens) + + return results + + def next_power_of_2(n: int): return 1 << (n - 1).bit_length() if n > 0 else 1 @@ -191,12 +274,12 @@ def create_weights( layer, **extra_weight_attrs, ): - # 因为模型存储是列存储的,所以这里需要not一下! + # Model storage is column-major, so we need to invert the output_dim flag extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"] K = layer.weight_shape[0] N = layer.weight_shape[1] - # 因为模型的存储时候权重是[N,K//2] - # 所以这里创建的权重是为了契合模型存储的权重! + # Model stored weights are in [N, K//2] format + # Create weight shape to match model storage format weight_shape = [N, K // 2] layer.weight_dtype = "uint8" @@ -209,12 +292,12 @@ def create_weights( self._create_weight_scales(layer, weight_scale_shape, weight_scale_2_shape, extra_weight_attrs) def _create_main_weight(self, layer, weight_shape, extra_weight_attrs): - """创建主权重参数 + """Create main weight parameter - 参数: - layer: 当前层对象 - weight_shape: 权重形状 - extra_weight_attrs: 额外权重属性 + Args: + layer: Current layer object + weight_shape: Weight shape + extra_weight_attrs: Extra weight attributes """ layer.weight = layer.create_parameter( shape=weight_shape, @@ -228,11 +311,11 @@ def _create_main_weight(self, layer, weight_shape, extra_weight_attrs): ) def _create_input_scale(self, layer, input_scale_shape): - """创建输入缩放参数 + """Create input scale parameter - 参数: - layer: 当前层对象 - input_scale_shape: 输入缩放形状 + Args: + layer: Current layer object + input_scale_shape: Input scale shape """ layer.input_scale = layer.create_parameter( shape=input_scale_shape, @@ -242,13 +325,13 @@ def _create_input_scale(self, layer, input_scale_shape): ) def _create_weight_scales(self, layer, weight_scale_shape, weight_scale_2_shape, extra_weight_attrs): - """创建权重缩放参数 + """Create weight scale parameters - 参数: - layer: 当前层对象 - weight_scale_shape: 权重缩放形状 - weight_scale_2_shape: 权重缩放2形状 - extra_weight_attrs: 额外权重属性 + Args: + layer: Current layer object + 
weight_scale_shape: Weight scale shape + weight_scale_2_shape: Secondary weight scale shape + extra_weight_attrs: Extra weight attributes """ layer.weight_scale = layer.create_parameter( shape=weight_scale_shape, @@ -339,9 +422,7 @@ def apply( else: raise ValueError(f"Unsupported backend: {self.backend}.") - # shape 恢复到[K//2,N] w = layer.weight.T - # shape 恢复到[K//group_size, N] w_scale_interleaved = layer.weight_scale_interleaved.T if backend == "cutlass": @@ -377,7 +458,6 @@ def __init__(self, quant_config: ModelOptNvFp4Config): "up_gate_proj_weight_scale", "down_proj_weight_scale", ] - self.quant_config = quant_config self.backend = "none" if envs.FD_MOE_BACKEND is None: @@ -508,10 +588,11 @@ def process_weights_after_loading(self, layer): # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 - [a, b] = layer.up_gate_proj_weight.split(2, axis=1) - layer.up_gate_proj_weight.set_value(paddle.concat([b, a], axis=1)) - [a, b] = layer.up_gate_proj_weight_scale.split(2, axis=1) - layer.up_gate_proj_weight_scale.set_value(paddle.concat([b, a], axis=1)) + if self.backend == "flashinfer-cutlass": + [a, b] = layer.up_gate_proj_weight.split(2, axis=1) + layer.up_gate_proj_weight.set_value(paddle.concat([b, a], axis=1)) + [a, b] = layer.up_gate_proj_weight_scale.split(2, axis=1) + layer.up_gate_proj_weight_scale.set_value(paddle.concat([b, a], axis=1)) up_gate_proj_weight_scale_2 = layer.up_gate_proj_weight_scale_2[:, 0] free_tensor(layer.up_gate_proj_weight_scale_2) @@ -541,12 +622,14 @@ def process_weights_after_loading(self, layer): ), f"{name} Weight Blockscale must be represented as FP8-E4M3" up_gate_proj_blockscale_swizzled = _process_scale_interleaved(layer.up_gate_proj_weight_scale) + # up_gate_proj_blockscale_swizzled = layer.up_gate_proj_weight_scale free_tensor(layer.up_gate_proj_weight_scale) layer.up_gate_proj_weight_scale = None create_parameter_and_copy( layer, name="up_gate_proj_blockscale_swizzled", weight=up_gate_proj_blockscale_swizzled ) 
down_proj_blockscale_swizzled = _process_scale_interleaved(layer.down_proj_weight_scale) + # down_proj_blockscale_swizzled = layer.down_proj_weight_scale free_tensor(layer.down_proj_weight_scale) layer.down_proj_weight_scale = None create_parameter_and_copy(layer, name="down_proj_blockscale_swizzled", weight=down_proj_blockscale_swizzled) @@ -557,8 +640,108 @@ def apply_ep_prefill( x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, + shared_experts: nn.Layer = None, ) -> paddle.Tensor: - pass + + logger.info("Running prefill") + # 1. top experts and weights + gate_out = gate(x.cast("float32")) + topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) + hidden_size = x.shape[1] + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + + event = deep_ep.Buffer.capture() + + # 2. ep dispatch + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_num_tokens_per_expert_list, + handle, + event, + ) = self.ep_prefill_runner.dispatch( + x, + topk_idx, + topk_weights, + expert_alignment=128, + previous_event=event, + ) + + if self.ep_prefill_runner.ep_engine.async_finish: + event.current_stream_wait() + + # nvfp4 dispatch returns a plain BF16 tensor (no fp8 scale), unlike deepgemm which returns (value, scale) tuple + recv_x_value = recv_x + recv_x_scale = None + + # 3. 
compute ffn + token_all_num = sum(recv_num_tokens_per_expert_list) + + if self.ep_prefill_runner.num_worst_tokens > 0: + token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1 + max_tokens_per_rank = ( + layer.fd_config.scheduler_config.max_num_batched_tokens + // layer.fd_config.parallel_config.tensor_parallel_size + // token_split_factor + ) + + # logger.debug(f"max_tokens_per_rank {max_tokens_per_rank}") + + permute_input, permute_scale, permuted_indice_map, token_nums_per_expert = ( + call_prefill_permute_to_masked_gemm( + x=recv_x_value, + scale=recv_x_scale, + topk_ids=recv_topk_idx, + num_local_experts=layer.num_local_experts, + max_token_num=layer.ep_size * max_tokens_per_rank, + ) + ) + max_token_num = layer.ep_size * max_tokens_per_rank + permute_input = permute_input.reshape([layer.num_local_experts, max_token_num, recv_x_value.shape[-1]]) + + # ffn_out: [num_local_experts, m, hidden_size] + # NVFP4 dispatch returns BF16 (no pre-quantized scale), so permute_scale is empty. + # Use per-expert 1/input_scale (up_gate_proj_input_scale_quant) as input_global_scale, + # consistent with apply_ep_decode which also uses this value directly. 
+ ffn_out = flashinfer_cutedsl_moe_masked( + hidden_states=(permute_input, None), + input_global_scale=layer.up_gate_proj_input_scale_quant.expand([layer.num_local_experts]), + w1=layer.up_gate_proj_weight, + w1_blockscale=layer.up_gate_proj_blockscale_swizzled, + w1_alpha=layer.g1_alphas, + w2=layer.down_proj_weight, + a2_global_scale=layer.down_proj_input_scale_quant.expand([layer.num_local_experts]), + w2_blockscale=layer.down_proj_blockscale_swizzled, + w2_alpha=layer.g2_alphas, + masked_m=token_nums_per_expert.squeeze(-1).cast(paddle.int32), + ) + + tmp_ffn_out = call_depermute_prefill_combine( + x=ffn_out, + indice_map=permuted_indice_map, + topk_weights=recv_topk_weights, + num_worst_tokens=recv_x_value.shape[0], + ) + + elif token_all_num > 0: + raise NotImplementedError( + "NVFP4 EP prefill contiguous path (num_worst_tokens <= 0, token_all_num > 0) is not yet implemented." + ) + else: + tmp_ffn_out = paddle.empty([0, hidden_size], dtype=paddle.bfloat16) + + # 4. EP combine + event = deep_ep.Buffer.capture() + + tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event) + + if self.ep_prefill_runner.ep_engine.async_finish: + event.current_stream_wait() + + return tmp_ffn_out def apply_ep_decode( self, @@ -566,8 +749,40 @@ def apply_ep_decode( x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, + shared_experts: nn.Layer = None, ) -> paddle.Tensor: - pass + + gate_out = gate(x.cast("float32")) + topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + + recv_x, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch( + x, + topk_idx, + topk_weights, + use_fp8=False, + ) + + # Compute FFN via CuteDSL masked grouped GEMM + num_experts = layer.num_local_experts + ffn_out = flashinfer_cutedsl_moe_masked( + hidden_states=(recv_x, None), + 
input_global_scale=layer.up_gate_proj_input_scale_quant.expand([num_experts]), + w1=layer.up_gate_proj_weight, + w1_blockscale=layer.up_gate_proj_blockscale_swizzled, + w1_alpha=layer.g1_alphas, + w2=layer.down_proj_weight, + a2_global_scale=layer.down_proj_input_scale_quant.expand([num_experts]), + w2_blockscale=layer.down_proj_blockscale_swizzled, + w2_alpha=layer.g2_alphas, + masked_m=token_nums_per_expert, + ) + + out = self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle) + + return out def apply_tp( self, @@ -575,37 +790,25 @@ def apply_tp( x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, - ) -> paddle.Tensor: - pass - - def apply( - self, - layer, - x, - gate, - topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None, - ): - """ - flashinfer nvfp4 fusedmoe for Model Optimizer - """ - gate_out = gate(x.cast("float32")) - topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( - gate_out, - layer.gate_correction_bias, - layer.top_k, - True, # apply_norm_weight, - False, - ) + ) -> paddle.Tensor: + if self.backend == "flashinfer-cutlass": + gate_out = gate(x.cast("float32")) + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight, + False, + ) - if topk_ids_hookfunc is not None: - topk_ids_hookfunc(topk_ids) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids) - output_dtype = x.dtype - x_sf = None - output = paddle.empty_like(x) + output_dtype = x.dtype + x_sf = None + output = paddle.empty_like(x) - if self.backend == "flashinfer-cutlass": # flashinfer cutlass from flashinfer.fused_moe import ( cutlass_fused_moe as flashinfer_cutlass_fused_moe, @@ -636,6 +839,5 @@ def apply( ) return output - - # flashinfer-trtllm - return output + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index 
aa5fd39de0f..f3444173e18 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -561,15 +561,31 @@ def get_sm_version(): def modules_to_convert(prefix: str, fd_config: FDConfig): import fnmatch + exclude_patterns = [] + if ( hasattr(fd_config.model_config, "quantization_config") and fd_config.model_config.quantization_config is not None ): if "modules_to_not_convert" in fd_config.model_config.quantization_config: - patterns = fd_config.model_config.quantization_config["modules_to_not_convert"] - for p in patterns: - if fnmatch.fnmatch(prefix, p) or fnmatch.fnmatch(prefix, p + ".*"): - return False - return True - else: - return True + exclude_patterns.extend(fd_config.model_config.quantization_config["modules_to_not_convert"]) + # 2. Check quantization_config["ignore"] (used by some models like NVFP4) + if "ignore" in fd_config.model_config.quantization_config: + exclude_patterns.extend(fd_config.model_config.quantization_config["ignore"]) + + # Get the model's actual prefix_name (e.g., "ernie" or "model") + prefix_name = "model" # default + if hasattr(fd_config, "model_config") and hasattr(fd_config.model_config, "pretrained_config"): + prefix_name = getattr(fd_config.model_config.pretrained_config, "prefix_name", "model") + + # Check if prefix matches any exclude pattern + for p in exclude_patterns: + # Direct match + if fnmatch.fnmatch(prefix, p) or fnmatch.fnmatch(prefix, p + ".*"): + return False + # Handle case where pattern uses "model" but actual prefix is "ernie" (or vice versa) + if p.startswith("model."): + adapted_pattern = prefix_name + "." 
+ p[6:] + if fnmatch.fnmatch(prefix, adapted_pattern) or fnmatch.fnmatch(prefix, adapted_pattern + ".*"): + return False + return True diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 4cc4306de5f..c37d4f080f3 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -491,6 +491,7 @@ def forward( ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, ): + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed: diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index b03fa480d3e..7c2de5990c8 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -538,6 +538,7 @@ def rename_offline_ckpt_suffix_to_fd_suffix( def fn(loaded_weight_name, is_moe): if fd_config.quant_config is None or fd_config.quant_config.is_checkpoint_bf16: return loaded_weight_name + fd_suffix_map = {} # Can be extended to other offline quantization suffixes if needed. 
if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"): fd_suffix_map = fp8_suffix_map diff --git a/tests/layers/test_fusedmoe.py b/tests/layers/test_nvfp4_fusedmoe.py similarity index 53% rename from tests/layers/test_fusedmoe.py rename to tests/layers/test_nvfp4_fusedmoe.py index d97363fe758..44aaa303484 100644 --- a/tests/layers/test_fusedmoe.py +++ b/tests/layers/test_nvfp4_fusedmoe.py @@ -36,9 +36,7 @@ ) from fastdeploy.model_executor.layers.linear import ReplicatedLinear from fastdeploy.model_executor.layers.moe.moe import FusedMoE -from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import ( - BlockWiseFP8Config, -) +from fastdeploy.model_executor.layers.quantization.nvfp4 import ModelOptNvFp4Config from fastdeploy.scheduler import SchedulerConfig from fastdeploy.worker.worker_process import init_distributed_environment @@ -446,9 +444,9 @@ class FuseMoEWrapper(paddle.nn.Layer): def __init__( self, model_config: ModelConfig, - tp_size: int = 1, + tp_size: int = 8, tp_rank: int = 0, - ep_size: int = 1, + ep_size: int = 8, ep_rank: int = 0, prefix: str = "layer0", nnodes: int = 1, @@ -471,7 +469,13 @@ def __init__( "data_parallel_size": self.ep_size, } ), - quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]), + # quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]), + quant_config=ModelOptNvFp4Config( + is_checkpoint_nvfp4_serialized=True, + kv_cache_quant_algo=None, + exclude_modules=[], + group_size=16, + ), # quant_config=WINT8Config({}), # quant_config=WINT4Config({}), scheduler_config=SchedulerConfig({}), @@ -487,8 +491,8 @@ def __init__( self.fd_config.parallel_config.expert_parallel_size = self.ep_size if self.ep_size > 1: self.fd_config.parallel_config.ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() - self.fd_config.scheduler_config.splitwise_role = "decode" - self.fd_config.model_config.moe_phase.phase = "decode" + 
self.fd_config.scheduler_config.splitwise_role = "prefill" + self.fd_config.model_config.moe_phase.phase = "prefill" weight_key_map = { "gate_weight_key": f"{self.prefix}.gate.weight", @@ -512,8 +516,8 @@ def __init__( moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size, num_experts=self.fd_config.model_config.moe_num_experts, top_k=self.fd_config.model_config.moe_k, - # avoiding invoke clean_low_latency_buffer in mixed ep. - layer_idx=666, + # Keep start-layer index so mixed-EP low-latency buffer cleanup can run. + layer_idx=0, weight_key_map=weight_key_map, topk_method="noaux_tc", topk_group=4, @@ -523,43 +527,44 @@ def __init__( ) moe_layer = self.fused_moe - up_gate_proj_weight_shape = [ - moe_layer.num_local_experts, - moe_layer.hidden_size, - moe_layer.moe_intermediate_size * 2, - ] - down_proj_weight_shape = [ - moe_layer.num_local_experts, - moe_layer.moe_intermediate_size, - moe_layer.hidden_size, - ] - - up_gate_proj_weight = paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16) - down_proj_weight = paddle.randn(down_proj_weight_shape, paddle.bfloat16) - - local_expert_ids = list( - range(moe_layer.expert_id_offset, moe_layer.expert_id_offset + moe_layer.num_local_experts) + up_gate_proj_weight = getattr(moe_layer, "up_gate_proj_weight") + down_proj_weight = getattr(moe_layer, "down_proj_weight") + up_gate_proj_weight.set_value( + paddle.randint(0, 255, up_gate_proj_weight.shape, dtype=paddle.int32).cast(paddle.uint8) + ) + down_proj_weight.set_value( + paddle.randint(0, 255, down_proj_weight.shape, dtype=paddle.int32).cast(paddle.uint8) ) - state_dict = {} - up_gate_proj_expert_weight_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_key") - down_proj_expert_weight_key = moe_layer.weight_key_map.get("down_proj_expert_weight_key") - for expert_idx in local_expert_ids: - down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx) - up_gate_proj_expert_weight_key_name = 
up_gate_proj_expert_weight_key.format(expert_idx) - state_dict[up_gate_proj_expert_weight_key_name] = up_gate_proj_weight[ - expert_idx - moe_layer.expert_id_offset - ] - state_dict[down_proj_expert_weight_key_name] = down_proj_weight[expert_idx - moe_layer.expert_id_offset] - moe_layer.load_state_dict(state_dict) + up_gate_proj_weight_scale = getattr(moe_layer, "up_gate_proj_weight_scale") + down_proj_weight_scale = getattr(moe_layer, "down_proj_weight_scale") + up_gate_proj_weight_scale.set_value( + paddle.ones(up_gate_proj_weight_scale.shape, dtype=paddle.float32).cast(up_gate_proj_weight_scale.dtype) + ) + down_proj_weight_scale.set_value( + paddle.ones(down_proj_weight_scale.shape, dtype=paddle.float32).cast(down_proj_weight_scale.dtype) + ) + moe_layer.up_gate_proj_weight_scale_2.set_value( + paddle.ones(moe_layer.up_gate_proj_weight_scale_2.shape, dtype=paddle.float32) + ) + moe_layer.down_proj_weight_scale_2.set_value( + paddle.ones(moe_layer.down_proj_weight_scale_2.shape, dtype=paddle.float32) + ) + moe_layer.up_gate_proj_input_scale.set_value( + paddle.ones(moe_layer.up_gate_proj_input_scale.shape, dtype=paddle.float32) + ) + moe_layer.down_proj_input_scale.set_value( + paddle.ones(moe_layer.down_proj_input_scale.shape, dtype=paddle.float32) + ) + moe_layer.quant_method.process_weights_after_loading(moe_layer) class TestFusedMoE(unittest.TestCase): def setUp(self) -> None: - self.architectures = ["Ernie4_5_MoeForCausalLM"] - self.hidden_size = 4096 - self.moe_intermediate_size = 2048 - self.moe_num_experts = 64 + self.architectures = ["Ernie5_MoeForCausalLM"] + self.hidden_size = 7168 + self.moe_intermediate_size = 3584 + self.moe_num_experts = 160 self.moe_k = 8 self.num_layers = 2 self.num_attention_heads = -1 @@ -570,7 +575,7 @@ def build_model_config(self) -> ModelConfig: return ModelConfig( { "model": model_name_or_path, - "max_model_len": 2048, + "max_model_len": 4096, } ) @@ -596,6 +601,7 @@ def test_fused_moe(self): init_distributed_environment() 
os.environ["FD_USE_DEEP_GEMM"] = "0" + os.environ["FD_MOE_BACKEND"] = "flashinfer-cutedsl" if int(os.getenv("USE_FUSEDMOE_TP", "0")) == 1: ep_size = 1 ep_rank = 0 @@ -612,6 +618,25 @@ def test_fused_moe(self): # 这行代码必须保留,否则影响均匀性! paddle.seed(ep_rank + 100) + # Compute the token list first so we can size the DeepEP low-latency buffer + # (num_max_dispatch_tokens_per_rank) correctly before creating FuseMoEWrapper. + token_list_env = os.getenv("NVFP4_TEST_TOKEN_LIST", "") + if token_list_env: + test_token_nums = [int(v.strip()) for v in token_list_env.split(",") if v.strip()] + else: + # Keep CI as a correctness/perf-smoke test by default. + test_token_nums = [60, 64, 1024] + + test_mode = os.getenv("NVFP4_TEST_MODE", "decode").lower() + # Default to decode for any unrecognised value (mirrors env-var default). + is_decoder = test_mode != "prefill" + + # For decode mode the DeepEP low-latency buffer must be pre-sized to hold + # at least max(test_token_nums) tokens. The framework default is 128, which + # is too small for the 1024-token case. + if is_decoder and ep_size > 1: + self.model_config.num_max_dispatch_tokens_per_rank = max(test_token_nums) + num_layers = self.num_layers real_weight_layers = num_layers // 2 fused_moe = [None] * real_weight_layers @@ -620,43 +645,52 @@ def test_fused_moe(self): moe_cuda_graphs = [None] * 100 cache_hidden_states = [None] * 100 - is_decoder = fused_moe[0].fd_config.model_config.moe_phase.phase == "decode" - test_token_nums = [4096 * i for i in [1, 2, 4, 8]] + + # For decode mode: set moe_phase to "decode" so apply_ep_decode is used, + # which is CUDA-graph-compatible (uses ep_decoder_runner / low-latency dispatch). + # For prefill mode: keep "prefill" so apply_ep_prefill uses ep_prefill_runner. 
if is_decoder: - test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256] - for idx, num_tokens in enumerate(test_token_nums): + for layer_wrapper in fused_moe: + if layer_wrapper is not None: + layer_wrapper.fd_config.model_config.moe_phase.phase = "decode" + + # Avoid per-iteration weight mutation in hot path. + for layer in fused_moe: + layer.gating.weight.set_value(paddle.rand(layer.gating.weight.shape, dtype=paddle.float32)) + enable_cuda_graph = False # grouped_gemm_nt_masked (CuteDSL) is not CUDA-graph-capturable + + for idx, num_tokens in enumerate(test_token_nums): cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16) def fake_model_run(): for j in range(num_layers): if int(os.getenv("DISABLE_CI_FUSEDMOE_EP", "0")) == 1: - out = cache_hidden_states + cache_hidden_states + out = cache_hidden_states[idx] + cache_hidden_states[idx] else: gating = fused_moe[j % real_weight_layers].gating - gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32)) out = fused_moe[j % real_weight_layers].fused_moe( cache_hidden_states[idx], gating, forward_meta=MockForwardMeta() ) return out - if is_decoder: + if enable_cuda_graph: moe_cuda_graphs[idx] = graphs.CUDAGraph() moe_cuda_graphs[idx].capture_begin() fake_model_run() - if is_decoder: + if enable_cuda_graph: moe_cuda_graphs[idx].capture_end() - num_tests = 20 + num_tests = max(2, int(os.getenv("NVFP4_TEST_ITERS", "6"))) start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] for i in range(num_tests): start_events[i].record() - if is_decoder: + if enable_cuda_graph: moe_cuda_graphs[idx].replay() else: fake_model_run() @@ -684,6 +718,172 @@ def fake_model_run(): shutil.rmtree(self.model_name_or_path) + def test_decode_correctness(self): + """ + Verify apply_ep_decode correctness on 8 GPUs. + + Strategy + -------- + 1. 
All ranks share the SAME input x and gate weights (same seed everywhere). + 2. Each rank runs apply_ep_decode(x) → ep_output. + 3. Each rank independently runs _run_cutedsl_grouped_masked on ALL tokens + for each of its local experts (no dispatch needed). + 4. All-gather those per-expert outputs → all_ref[E_total, N, H]. + 5. Re-run moe_select to recover the routing (topk_idx, topk_weights). + 6. Manually compute: ref[i] = sum_k(w_k * all_ref[expert_k, i, :]). + 7. Compare ep_output vs ref. Only the dispatch/combine protocol is tested. + """ + init_distributed_environment() + + os.environ["FD_USE_DEEP_GEMM"] = "0" + os.environ["FD_MOE_BACKEND"] = "flashinfer-cutedsl" + + ep_size = paddle.distributed.get_world_size() + ep_rank = paddle.distributed.get_rank() + + if ep_size <= 1: + print("test_decode_correctness requires ep_size > 1, skipping.") + return + + num_tokens = 128 # small batch – fast but non-trivial + num_local_experts = self.moe_num_experts // ep_size # 160 // 8 = 20 + H = self.hidden_size # 7168 + + # DeepEP low-latency buffer must hold at least num_tokens tokens per rank. + self.model_config.num_max_dispatch_tokens_per_rank = max(num_tokens, 4096) + + # ── Build EP model ──────────────────────────────────────────────────── + ep_moe = FuseMoEWrapper(self.model_config, 1, 0, ep_size, ep_rank, nnodes=1) + ep_moe.fd_config.model_config.moe_phase.phase = "decode" + moe_layer = ep_moe.fused_moe + + # ── Shared input: same seed → same x on every rank ─────────────────── + paddle.seed(444) + x = paddle.rand([num_tokens, H], dtype=paddle.bfloat16) + + # ── Same gate weights everywhere ───────────────────────────────────── + paddle.seed(4444) + ep_moe.gating.weight.set_value(paddle.rand(ep_moe.gating.weight.shape, dtype=paddle.float32)) + + # ── Deterministic per-expert quantized weights (global expert id seed) ─ + # FuseMoEWrapper.__init__ already called process_weights_after_loading with + # all-one scales, so g1_alphas / g2_alphas / blockscale_swizzled are correct.
+ # We only need to swap the raw packed-uint8 weight tensors. + up_gate_w = np.zeros(moe_layer.up_gate_proj_weight.shape, dtype=np.uint8) + down_w = np.zeros(moe_layer.down_proj_weight.shape, dtype=np.uint8) + + for li in range(num_local_experts): + ge = ep_rank * num_local_experts + li + rng = np.random.default_rng(ge + 1_000_000) + up_gate_w[li] = rng.integers(0, 256, up_gate_w[li].shape, dtype=np.uint8) + down_w[li] = rng.integers(0, 256, down_w[li].shape, dtype=np.uint8) + + moe_layer.up_gate_proj_weight.set_value(paddle.to_tensor(up_gate_w)) + moe_layer.down_proj_weight.set_value(paddle.to_tensor(down_w)) + # NOTE: do NOT re-call process_weights_after_loading — scales already processed. + + # ── Step 1: EP decode forward ───────────────────────────────────────── + ep_output = ep_moe.fused_moe(x, ep_moe.gating, forward_meta=MockForwardMeta()) + # ep_output: [num_tokens, H] + + # ── Diagnostic: all ranks must produce the same output (same x, same routing) ─ + # If this fails, dispatch/combine has a routing bug. 
+ ep_list = [paddle.zeros_like(ep_output) for _ in range(ep_size)] + # print(ep_list) + paddle.distributed.all_gather(ep_list, ep_output) + # print(ep_list) + if ep_rank == 0: + for r in range(1, ep_size): + d = (ep_list[0].cast(paddle.float32) - ep_list[r].cast(paddle.float32)).abs() + print(f"[ep_consistency] rank0 vs rank{r}: max_diff={float(d.max().numpy()):.5f}") + + # ── Step 2: Recover routing (deterministic for same x & gate weights) ─ + gate_out = ep_moe.gating(x.cast("float32")) + topk_idx, topk_weights = moe_layer.quant_method.ep_decoder_runner.moe_select(moe_layer, gate_out) + # topk_idx: [num_tokens, top_k] – global expert indices 0..159 + # topk_weights: [num_tokens, top_k] – softmax weights + + if ep_rank == 0: + print( + f"[test_decode_correctness] topk_idx range [{int(topk_idx.min().numpy())}, " + f"{int(topk_idx.max().numpy())}], " + f"topk_idx[0]={topk_idx.numpy()[0].tolist()}" + ) + + # ── Step 3: Direct GEMM reference for each local expert on all tokens ─ + # Feed every token to every local expert independently.
+ hidden_3d = x.unsqueeze(0).tile([num_local_experts, 1, 1]) + # hidden_3d: [E_local, num_tokens, H] + masked_m_ref = paddle.full([num_local_experts], num_tokens, dtype=paddle.int32) + + local_ref = moe_layer.quant_method._run_cutedsl_grouped_masked(moe_layer, hidden_3d, masked_m_ref) + # local_ref: [E_local, num_tokens, H] + + # ── Step 4: All-gather reference outputs across all ranks ───────────── + local_ref_list = [paddle.zeros_like(local_ref) for _ in range(ep_size)] + paddle.distributed.all_gather(local_ref_list, local_ref) + all_ref = paddle.concat(local_ref_list, axis=0) + # all_ref: [E_total=160, num_tokens, H] + # all_ref[global_e, i, :] == ffn_{global_e}(x[i]) + + # ── Step 5: Manual weighted sum (reference output) ─────────────────── + top_k = self.moe_k + ref_output = paddle.zeros([num_tokens, H], dtype=paddle.float32) + topk_idx_cpu = topk_idx.numpy() + topk_weights_f32 = topk_weights.cast(paddle.float32) + all_ref_f32 = all_ref.cast(paddle.float32) + + for i in range(num_tokens): + for k in range(top_k): + ge = int(topk_idx_cpu[i, k]) + w = float(topk_weights_f32[i, k].numpy()) + ref_output[i] += w * all_ref_f32[ge, i] + + if ep_rank == 0: + # Sanity: pick the first assigned expert for token 0 + ge0 = int(topk_idx_cpu[0, 0]) + print( + f"[test_decode_correctness] token0 k=0 expert ge={ge0}, " + f"all_ref[ge0,0,:3]={all_ref_f32[ge0, 0, :3].numpy().tolist()}" + )
+ output_scale = float(ref_output.abs().max().numpy()) + tol = max(100.0, output_scale * top_k * 2 ** (-6)) + + if ep_rank == 0: + rel_err = max_diff / (output_scale + 1e-9) + print(f"[Rank 0] ep_output[0,:4] = {ep_f32[0, :4].numpy().tolist()}") + print(f"[Rank 0] ref_output[0,:4] = {ref_output[0, :4].numpy().tolist()}") + print(f"[Rank 0] relative error: {rel_err*100:.4f}% (tol={tol:.0f})") + + print( + f"[Rank {ep_rank}] decode correctness: " + f"max_diff={max_diff:.1f} mean_diff={mean_diff:.5f} tol={tol:.1f}" + ) + + # A large relative error (>2%) almost certainly indicates a routing bug, + # not floating-point rounding. + assert max_diff < tol, ( + f"[Rank {ep_rank}] max_diff={max_diff:.1f} exceeds BF16 tolerance ({tol:.1f}). " + f"output_scale={output_scale:.1f}. " + "Likely a dispatch/combine routing error." + ) + print(f"[Rank {ep_rank}] PASS") + + shutil.rmtree(self.model_name_or_path) + if __name__ == "__main__": unittest.main()