diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md
index 043ae0b378e..4b8dcda3636 100644
--- a/docs/usage/environment_variables.md
+++ b/docs/usage/environment_variables.md
@@ -230,5 +230,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Worker process health check timeout when waiting for responses in seconds (default: 30)
     "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
+
+    # Data type for the MoE gate weight; can be set to FLOAT32 or BFLOAT16 (lowercased before use), default FLOAT32
+    "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"),
+
 }
 ```
diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md
index 82b30040924..d7de2e370e2 100644
--- a/docs/zh/usage/environment_variables.md
+++ b/docs/zh/usage/environment_variables.md
@@ -230,5 +230,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Worker 进程响应等待时的健康检查超时时间(秒),默认 30 秒
     "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
+
+    # MoE 层 gate 权重的数据类型,可设置为 FLOAT32 或 BFLOAT16(使用前会转为小写),默认为 FLOAT32
+    "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"),
 }
 ```
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index f91a9d77b83..81f8a51e167 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -196,6 +196,8 @@
     ),
     # Timeout for worker process health check in seconds
    "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
+    # Data type for the MoE gate weight; can be set to "FLOAT32" or "BFLOAT16" (default: "FLOAT32").
+    "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"),
 }
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 37b77821052..8d28d1496bc 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -1674,7 +1674,14 @@ def apply(
         Triton compute Fused MoE.
         """
-        gate_out = gate(x.cast("float32"))
+        # Run the gate matmul in the gate weight's native dtype, then upcast
+        # the routing logits to float32 for the downstream top-k routing.
+        if gate.weight.dtype != paddle.float32:
+            gate_out = gate(x)
+            if gate_out.dtype != paddle.float32:
+                gate_out = gate_out.cast("float32")
+        else:
+            gate_out = gate(x.cast("float32"))
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
         moe_intermediate_size = layer.moe_intermediate_size
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 82c1b531502..dd3eb98d10a 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -28,6 +28,7 @@
 from paddleformers.transformers.configuration_utils import PretrainedConfig
 from paddleformers.utils.log import logger
 
+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -182,7 +183,8 @@ def __init__(
             output_size=fd_config.model_config.moe_num_experts,
             with_bias=False,
             skip_quant=True,
-            weight_dtype="float32",
+            # Gate weight dtype follows FD_MOE_GATE_WEIGHT_DTYPE (default "FLOAT32").
+            weight_dtype=envs.FD_MOE_GATE_WEIGHT_DTYPE.lower(),
         )
 
         self.experts = FusedMoE(
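
For reviewers who want to see the intended end-to-end flow, here is a minimal, self-contained sketch. It is not FastDeploy's actual wiring: a plain `paddle.nn.Linear` stands in for the gate layer that `ernie4_5_moe.py` constructs, and the activation dtype is assumed to be bfloat16 (adjust if your device lacks bfloat16 support).

```python
import os

import paddle

# Read the env var the same way fastdeploy/envs.py does (default "FLOAT32")
# and lowercase it, mirroring envs.FD_MOE_GATE_WEIGHT_DTYPE.lower().
gate_weight_dtype = os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32").lower()

hidden_size, num_experts = 64, 8

# Stand-in for the gate linear layer built in ernie4_5_moe.py (no bias,
# weight dtype taken from the env var).
gate = paddle.nn.Linear(hidden_size, num_experts, bias_attr=False)
gate.to(dtype=gate_weight_dtype)

# Assumed serving activation dtype.
x = paddle.randn([4, hidden_size]).cast("bfloat16")

# Mirrors the dispatch added in fused_moe_triton_backend.py: run the gate in
# its native dtype and only upcast the routing logits when necessary.
if gate.weight.dtype != paddle.float32:
    gate_out = gate(x)
    if gate_out.dtype != paddle.float32:
        gate_out = gate_out.cast("float32")
else:
    gate_out = gate(x.cast("float32"))

print(gate.weight.dtype, gate_out.dtype)  # logits end up float32 on both paths
```

On the default `FLOAT32` path this reduces to the previous `gate(x.cast("float32"))` behavior; with `FD_MOE_GATE_WEIGHT_DTYPE=BFLOAT16`, the float32 cast moves from the full hidden-state tensor to the much smaller per-token expert logits.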