Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
# Whether to use phi FP8 quantization; if 1, use the paddle default.
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
# intended for training alignment. Defaults to 0 (disabled).
"FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
# Whether to use phi MOE permute,if 1,use paddle op.
"FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
# Control whether class SiluAndMul uses the swiglu or fused_bias_act operator in the forward_cuda function
Expand All @@ -218,6 +221,8 @@
# has been observed to cause NaN computation errors.
# Set to 1 to enable the lock; defaults to 0 (disabled).
"FD_USE_KVCACHE_LOCK": lambda: bool(int(os.getenv("FD_USE_KVCACHE_LOCK", "0"))),
# Whether to probe MoE routing probabilities and use Fleet's fused SwiGLU kernel.
"FD_MOE_PROB_IN_ADVANCE": lambda: bool(int(os.getenv("FD_MOE_PROB_IN_ADVANCE", "0"))),
# Suspend rollouting routing replay
"FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))),
}
Expand Down
189 changes: 149 additions & 40 deletions fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

import fastdeploy
from fastdeploy.model_executor.layers.moe.ep import deep_ep
from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
deep_gemm,
paddlefleet_ops,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func
from fastdeploy.platforms import current_platform
Expand Down Expand Up @@ -84,6 +87,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
layer_added_scale_attrs_1: paddle.Tensor, # getattr(layer, self.added_scale_attrs[1])
quant_config_weight_block_size_0: int, # self.quant_config.weight_block_size[0]
disable_ue8m0_cast: bool,
dst_weights: paddle.Tensor,
):

# up_gate_proj
Expand All @@ -104,23 +108,30 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
)

# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)

# down_proj
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, quant_config_weight_block_size_0
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant(
ffn_out, dst_weights, using_pow2_scaling=True, use_ue8m0=not disable_ue8m0_cast
)

ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_in_x_scale_tensor = paddle.transpose(paddle.transpose(ffn_in_x_scale_tensor, [1, 0]).contiguous(), [1, 0])
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=not disable_ue8m0_cast,
using_ue8m0_scale=not disable_ue8m0_cast,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)

# down_proj
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, quant_config_weight_block_size_0, not disable_ue8m0_cast
)

ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=not disable_ue8m0_cast,
using_ue8m0_scale=not disable_ue8m0_cast,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]

ffn_out = paddle.empty(
(permute_input.shape[0], layer_added_weight_attrs_1.shape[1]),
Expand All @@ -136,6 +147,66 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
return ffn_out


def moe_topk_select(
    gating_output: paddle.Tensor,
    n_group: int,
    topk_group: int,
    top_k: int,
    routed_scaling_factor: float,
    e_score_correction_bias: paddle.Tensor,
    renormalize: bool = False,
):
    """
    Select top-k experts per token with Paddle PHI ops (noaux_tc-style routing).

    Gate logits are squashed through a sigmoid; an optional correction bias is
    added only for the *selection* step, while the returned combine weights are
    always taken from the raw sigmoid scores. When expert grouping is active,
    whole groups are pruned first (each group scored by the sum of its two best
    experts) before the per-token top-k pick.

    Args:
        gating_output: gate output logits, shape [seq_len, n_experts]
        n_group: number of expert groups
        topk_group: number of top-k groups to select
        top_k: number of top experts per token
        routed_scaling_factor: scaling factor for routed experts
        e_score_correction_bias: bias for expert selection (may be None)
        renormalize: whether to renormalize topk probabilities

    Returns:
        topk_weights: normalized topk probabilities, shape [seq_len, top_k]
        topk_ids: topk expert indices, shape [seq_len, top_k]
    """
    # Routing scores come from a sigmoid gate.
    scores = paddle.nn.functional.sigmoid(gating_output)

    # Bias is applied only to the scores used for choosing experts.
    if e_score_correction_bias is not None:
        selection_scores = scores + e_score_correction_bias
    else:
        selection_scores = scores

    # Treat non-positive group settings as "no grouping".
    num_groups = n_group if n_group > 0 else 1
    num_topk_groups = topk_group if topk_group > 0 else 1

    if num_groups > 1 and num_topk_groups < num_groups:
        seq_len, num_experts = selection_scores.shape
        grouped = selection_scores.reshape([seq_len, num_groups, -1])
        # Score each group by the sum of its two strongest experts.
        group_scores = grouped.topk(2, axis=-1)[0].sum(axis=-1)  # [seq_len, num_groups]
        chosen_groups = paddle.topk(group_scores, k=num_topk_groups, axis=-1, sorted=True)[
            1
        ]  # [seq_len, num_topk_groups]
        one = paddle.to_tensor(1.0, dtype=group_scores.dtype)
        group_mask = paddle.zeros_like(group_scores).put_along_axis(chosen_groups, one, axis=-1)
        # Broadcast the group mask back to per-expert granularity.
        expert_mask = (
            group_mask.unsqueeze(-1)
            .expand([seq_len, num_groups, num_experts // num_groups])
            .reshape([seq_len, -1])
        )  # [seq_len, n_experts]
        # Experts in pruned groups can never be selected.
        selection_scores = selection_scores.masked_fill(~expert_mask.astype(paddle.bool), float("-inf"))

    _, topk_ids = paddle.topk(selection_scores, top_k, axis=-1)
    # Combine weights come from the unbiased sigmoid scores.
    topk_weights = paddle.take_along_axis(scores, topk_ids, axis=-1)

    # Optionally renormalize so the selected weights sum to 1 per token.
    if renormalize:
        denom = paddle.clip(topk_weights.sum(-1, keepdim=True), min=1e-12)
        topk_weights = topk_weights / denom

    # Apply the routed scaling factor when it is non-zero.
    if routed_scaling_factor:
        topk_weights = topk_weights * routed_scaling_factor

    return topk_weights, topk_ids


class DeepGemmFusedMoeMethod(MoEMethodBase):
"""
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
Expand Down Expand Up @@ -257,7 +328,22 @@ def apply_ep_prefill(
hidden_size = x.shape[1]

# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
if (
fastdeploy.envs.FD_USE_PHI_MOE_TOPK
and layer.redundant_table_manger is None
and layer.topk_method == "noaux_tc"
):
topk_weights, topk_idx = moe_topk_select(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
Expand Down Expand Up @@ -371,22 +457,31 @@ def apply_ep_prefill(
)
del permute_input

# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant(
ffn_out, dst_weights, using_pow2_scaling=True, use_ue8m0=self.quant_config.deepgemm_scale_ue8m0
)

# down_proj
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0]
ffn_in_x_scale_tensor = paddle.transpose(
paddle.transpose(ffn_in_x_scale_tensor, [1, 0]).contiguous(), [1, 0]
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)

# down_proj
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]

del ffn_out
ffn_out = paddle.empty(
Expand All @@ -408,7 +503,7 @@ def apply_ep_prefill(
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_local_experts,
using_weighted_combine=True,
using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE,
)

else:
Expand Down Expand Up @@ -523,15 +618,28 @@ def apply_tp(
gate_out = gate(x.cast("float32"))

if layer.topk_method == "noaux_tc":
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)

if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK:
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_weights, topk_ids = moe_topk_select(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)

else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
Expand Down Expand Up @@ -613,6 +721,7 @@ def apply_tp(
getattr(layer, self.added_scale_attrs[1]),
self.quant_config.weight_block_size[0],
disable_ue8m0_cast=not self.quant_config.deepgemm_scale_ue8m0,
dst_weights=dst_weights if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE else None,
)

# prmt back per rank
Expand All @@ -624,7 +733,7 @@ def apply_tp(
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_experts,
using_weighted_combine=True,
using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE,
)
else:
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
Expand Down
56 changes: 40 additions & 16 deletions fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
pass
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
fused_stack_transpose_quant,
quant_weight_ue8m0,
transform_scale_ue8m0,
)
Expand Down Expand Up @@ -1622,22 +1623,45 @@ def _process_quantize(weight_idx):
)
weight[expert_id].copy_(weight_quant, False)
else:
weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
scale_list = []

for expert_id in range(layer.num_local_experts):
w_q, s_fp32 = quant_weight_ue8m0(
weight_dequant=getattr(layer, unquantized_weight_name)[expert_id]
.transpose([1, 0])
.contiguous(),
weight_block_size=self.quant_config.weight_block_size,
)
s_ue8m0 = transform_scale_ue8m0(
s_fp32, mn=w_q.shape[-2], weight_block_size=self.quant_config.weight_block_size
)
weight[expert_id].copy_(w_q, False)
scale_list.append(s_ue8m0)
scale = paddle.to_tensor(scale_list)
if fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
num_expert = layer.num_local_experts
expert_weight_list = [getattr(layer, unquantized_weight_name)[i] for i in range(num_expert)]
weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
scale_list = []
chunk_size = 64

for start_idx in range(0, num_expert, chunk_size):
end_idx = min(start_idx + chunk_size, num_expert)
local_chunk_size = end_idx - start_idx
chunk_experts = [w.contiguous() for w in expert_weight_list[start_idx:end_idx]]

w1_t_quant, w1_t_scale = fused_stack_transpose_quant(
chunk_experts, use_ue8m0=self.quant_config.deepgemm_scale_ue8m0
)
w1_t_quant = w1_t_quant.reshape([local_chunk_size, -1, w1_t_quant.shape[-1]])
w1_t_scale = w1_t_scale.reshape([local_chunk_size, -1, w1_t_scale.shape[-1]])

weight[start_idx:end_idx].copy_(w1_t_quant, False)
scale_list.append(w1_t_scale)

scale = paddle.concat(scale_list, axis=0)
else:
weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
scale_list = []

for expert_id in range(layer.num_local_experts):
w_q, s_fp32 = quant_weight_ue8m0(
weight_dequant=getattr(layer, unquantized_weight_name)[expert_id]
.transpose([1, 0])
.contiguous(),
weight_block_size=self.quant_config.weight_block_size,
)
s_ue8m0 = transform_scale_ue8m0(
s_fp32, mn=w_q.shape[-2], weight_block_size=self.quant_config.weight_block_size
)
weight[expert_id].copy_(w_q, False)
scale_list.append(s_ue8m0)
scale = paddle.to_tensor(scale_list)

free_tensor(getattr(layer, unquantized_weight_name))
free_tensor(getattr(layer, weight_name))
Expand Down
Loading
Loading