4 changes: 4 additions & 0 deletions docs/usage/environment_variables.md
@@ -230,5 +230,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # Worker process health check timeout when waiting for responses in seconds (default: 30)
     "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
+
+    # Data type for the MoE gate weight: "float32" or "bfloat16" (default: "FLOAT32")
+    "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"),
+
 }
 ```
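To try the new switch, set the variable before FastDeploy reads it. A minimal usage sketch (only the variable name and values come from the docs above; the serving bootstrap is elided), noting that casing is forgiving because the consuming model code calls `.lower()`:

```python
# Usage sketch: opt the MoE gate weights into bfloat16.
# Set the variable before FastDeploy resolves it, e.g. at the top of a launch script.
import os

os.environ["FD_MOE_GATE_WEIGHT_DTYPE"] = "bfloat16"  # "BFLOAT16" works too; consumers lowercase it

# ... construct the engine / start serving as usual ...
```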
3 changes: 3 additions & 0 deletions docs/zh/usage/environment_variables.md
@@ -230,5 +230,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # Worker process health check timeout when waiting for responses, in seconds (default: 30)
     "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
+
+    # Data type for the MoE layer's gate weight: "float32" or "bfloat16" (default: "float32")
+    "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"),
 }
 ```
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
@@ -196,6 +196,8 @@
     ),
     # Timeout for worker process health check in seconds
     "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
+    # Data type for the MoE gate weight; can be set to "BFLOAT16" or "FLOAT32" (default: "FLOAT32").
+    "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"),
 }


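The entries in this registry are zero-argument callables. A minimal sketch of how such a module can expose them as plain attributes — an assumption about FastDeploy's wiring, consistent with the `envs.FD_MOE_GATE_WEIGHT_DTYPE` access used later in this PR:

```python
# Sketch of a lazy env-var registry resolved through a module-level __getattr__ (PEP 562).
# Assumption: FastDeploy's envs module uses a hook like this; only the dict entry
# is taken verbatim from the diff above.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"),
}

def __getattr__(name: str) -> Any:
    # Evaluated on each attribute access, so changes to os.environ are picked up.
    return environment_variables[name]()
```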
@@ -1674,7 +1674,12 @@ def apply(
         Triton compute Fused MoE.
         """
 
-        gate_out = gate(x.cast("float32"))
+        if gate.weight.dtype != paddle.float32:
+            gate_out = gate(x)
+            if gate_out.dtype != paddle.float32:
+                gate_out = gate_out.cast("float32")
+        else:
+            gate_out = gate(x.cast("float32"))
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
         moe_intermediate_size = layer.moe_intermediate_size
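The intent of this branch is to run the gate matmul in the weight's native dtype and upcast only the small routing-logits tensor to float32, rather than always upcasting the full activations. A standalone sketch of the same rule, with an illustrative softmax/top-k scorer appended (the scorer and `route_tokens` are assumptions for illustration, not FastDeploy's fused kernel):

```python
# Toy reproduction of the gate-dtype branch above (Paddle).
import paddle
import paddle.nn.functional as F

def route_tokens(gate: paddle.nn.Linear, x: paddle.Tensor, top_k: int):
    if gate.weight.dtype != paddle.float32:
        # bf16 gate (FD_MOE_GATE_WEIGHT_DTYPE=bfloat16): matmul in the native dtype,
        # then upcast only the [num_tokens, num_experts] logits.
        gate_out = gate(x)
        if gate_out.dtype != paddle.float32:
            gate_out = gate_out.cast("float32")
    else:
        # Default float32 gate: upcast the activations before the matmul.
        gate_out = gate(x.cast("float32"))
    probs = F.softmax(gate_out, axis=-1)         # routing math stays in float32
    return paddle.topk(probs, k=top_k, axis=-1)  # (topk_values, topk_ids)
```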
3 changes: 2 additions & 1 deletion fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -28,6 +28,7 @@
 from paddleformers.transformers.configuration_utils import PretrainedConfig
 from paddleformers.utils.log import logger
 
+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -182,7 +183,7 @@ def __init__(
             output_size=fd_config.model_config.moe_num_experts,
             with_bias=False,
             skip_quant=True,
-            weight_dtype="float32",
+            weight_dtype=envs.FD_MOE_GATE_WEIGHT_DTYPE.lower(),
         )
 
         self.experts = FusedMoE(
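End to end: with the variable unset, `weight_dtype` resolves to `"float32"` and behavior matches the pre-change code; with `FD_MOE_GATE_WEIGHT_DTYPE=bfloat16` the gate weight is created in bf16, which is exactly what flips the `gate.weight.dtype != paddle.float32` branch in `apply()` above. A small sketch (the tensor construction is a hypothetical stand-in for `ReplicatedLinear`'s weight creation, assuming a Paddle build with bfloat16 support):

```python
# Hypothetical stand-in showing how the resolved weight_dtype selects the branch.
import paddle

weight_dtype = "bfloat16"  # envs.FD_MOE_GATE_WEIGHT_DTYPE.lower() when the env var is set
gate_weight = paddle.zeros([8, 4], dtype=weight_dtype)  # [hidden_size, num_experts], toy shape
assert gate_weight.dtype != paddle.float32  # -> apply() takes the bf16 gate path
```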