From 708422610b8c9ca4ff2b772adfee34a4c7c363cd Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Wed, 28 Jan 2026 18:59:44 +0800 Subject: [PATCH 1/6] optim grid --- fastdeploy/envs.py | 2 ++ .../model_executor/layers/moe/fused_moe_triton_backend.py | 7 ++++++- fastdeploy/model_executor/models/ernie4_5_moe.py | 3 ++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index f91a9d77b83..3bcf8739026 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -196,6 +196,8 @@ ), # Timeout for worker process health check in seconds "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), + # Data type for MoE gate weight, can set "BFLOAT16" or "FLOAT32". + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), } diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 37b77821052..8d28d1496bc 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -1674,7 +1674,12 @@ def apply( Triton compute Fused MoE. """ - gate_out = gate(x.cast("float32")) + if gate.weight.dtype != paddle.float32: + gate_out = gate(x) + if gate_out.dtype != paddle.float32: + gate_out = gate_out.cast("float32") + else: + gate_out = gate(x.cast("float32")) top_k = layer.top_k num_local_experts = layer.num_local_experts moe_intermediate_size = layer.moe_intermediate_size diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 82c1b531502..dd3eb98d10a 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -28,6 +28,7 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig from paddleformers.utils.log import logger +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( @@ -182,7 +183,7 @@ def __init__( output_size=fd_config.model_config.moe_num_experts, with_bias=False, skip_quant=True, - weight_dtype="float32", + weight_dtype=envs.FD_MOE_GATE_WEIGHT_DTYPE.lower(), ) self.experts = FusedMoE( From fb14ca0267fc038a53ddc114345429c9eea2f6b9 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Wed, 28 Jan 2026 20:39:42 +0800 Subject: [PATCH 2/6] add document --- docs/usage/environment_variables.md | 3 +++ docs/zh/usage/environment_variables.md | 3 +++ 2 files changed, 6 insertions(+) diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index 043ae0b378e..584342eede8 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -230,5 +230,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Worker process health check timeout when waiting for responses in seconds (default: 30) "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), + + # Data type for MoE gate weight, can set "BFLOAT16" or "FLOAT32". + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), } ``` diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 82b30040924..1d8e2fd02e9 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -230,5 +230,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Worker 进程响应等待时的健康检查超时时间(秒),默认 30 秒 "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), + + # moe层gate的权重类型,默认为float32 + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), } ``` From 6e247659a4ef46eca1d7847185f097c0a0b9b6b3 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Thu, 29 Jan 2026 17:35:55 +0800 Subject: [PATCH 3/6] support --- docs/usage/environment_variables.md | 5 +-- docs/zh/usage/environment_variables.md | 4 +-- fastdeploy/envs.py | 4 +-- .../model_executor/models/ernie4_5_moe.py | 20 ++++++++++++ fastdeploy/model_executor/utils.py | 31 +++++++++++++++++++ 5 files changed, 58 insertions(+), 6 deletions(-) diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index 584342eede8..02c8b9e3d6b 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -231,7 +231,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Worker process health check timeout when waiting for responses in seconds (default: 30) "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), - # Data type for MoE gate weight, can set "BFLOAT16" or "FLOAT32". - "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), + # Data type for MoE gate weight, can be set to float32, bfloat16, default is empty + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", ""), + } ``` diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 1d8e2fd02e9..4e85cfdd36d 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -231,7 +231,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # Worker 进程响应等待时的健康检查超时时间(秒),默认 30 秒 "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), - # moe层gate的权重类型,默认为float32 - "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), + # moe层gate的权重类型,可以设置float32,bfloat16,默认为空 + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", ""), } ``` diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 3bcf8739026..0ebb4184b87 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -196,8 +196,8 @@ ), # Timeout for worker process health check in seconds "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), - # Data type for MoE gate weight, can set "BFLOAT16" or "FLOAT32". - "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), + # Data type for MoE gate weight, can set "BFLOAT16" or "FLOAT32" . + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", ""), } diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index dd3eb98d10a..dd2faa551ab 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -54,6 +54,7 @@ from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid from fastdeploy.model_executor.models.utils import WeightMeta +from fastdeploy.model_executor.utils import adapt_moe_gate_parameter_dtype from fastdeploy.platforms import current_platform from fastdeploy.worker.experts_manager import RedundantExpertManger @@ -628,6 +629,25 @@ def load_weights(self, weights_iterator) -> None: continue param = params_dict[model_param_name] + if not envs.FD_MOE_GATE_WEIGHT_DTYPE: + param = adapt_moe_gate_parameter_dtype(self, model_param_name, param, loaded_weight, params_dict) + # Get weight loader from parameter and set weight + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + sig = inspect.signature(weight_loader) + if "expert_id" in sig.parameters: + weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + else: + weight_loader(param, loaded_weight, shard_id) + + model_sublayer_name = re.sub( + r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name + ) + process_weights_after_loading_fn(model_sublayer_name, param) + if getattr(self, "tie_word_embeddings", False): + self.lm_head.linear.weight.set_value( + self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) + ) + # Get weight loader from parameter and set weight weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) sig = inspect.signature(weight_loader) diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 59941d94595..5357d0e377a 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -524,3 +524,34 @@ def fn(loaded_weight_name, is_moe): return loaded_weight_name return fn + + +def adapt_moe_gate_parameter_dtype(model, model_param_name, param, loaded_weight, params_dict): + if not model_param_name.endswith(".gate.weight"): + return param + + if param.dtype != loaded_weight.dtype: + new_param = paddle.create_parameter( + shape=param.shape, + dtype=loaded_weight.dtype, + is_bias=getattr(param, "is_bias", False), + default_initializer=paddle.nn.initializer.Constant(0), + ) + + for k, v in param.__dict__.items(): + if not k.startswith("_"): + setattr(new_param, k, v) + + parts = model_param_name.split(".") + obj = model + for part in parts[:-1]: + if part.isdigit() and isinstance(obj, (paddle.nn.LayerList, list)): + obj = obj[int(part)] + else: + obj = getattr(obj, part) + setattr(obj, parts[-1], new_param) + + params_dict[model_param_name] = new_param + return new_param + + return param From 96d4abaed342f1b1ae1bf4be79b2d25278f8d130 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Thu, 29 Jan 2026 18:10:36 +0800 Subject: [PATCH 4/6] fix --- fastdeploy/envs.py | 2 +- .../model_executor/models/ernie4_5_moe.py | 3 -- fastdeploy/model_executor/models/glm4_moe.py | 2 +- fastdeploy/model_executor/utils.py | 31 ------------------- 4 files changed, 2 insertions(+), 36 deletions(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 0ebb4184b87..81f8a51e167 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -197,7 +197,7 @@ # Timeout for worker process health check in seconds "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), # Data type for MoE gate weight, can set "BFLOAT16" or "FLOAT32" . - "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", ""), + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), } diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index dd2faa551ab..c2c77edd029 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -54,7 +54,6 @@ from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid from fastdeploy.model_executor.models.utils import WeightMeta -from fastdeploy.model_executor.utils import adapt_moe_gate_parameter_dtype from fastdeploy.platforms import current_platform from fastdeploy.worker.experts_manager import RedundantExpertManger @@ -629,8 +628,6 @@ def load_weights(self, weights_iterator) -> None: continue param = params_dict[model_param_name] - if not envs.FD_MOE_GATE_WEIGHT_DTYPE: - param = adapt_moe_gate_parameter_dtype(self, model_param_name, param, loaded_weight, params_dict) # Get weight loader from parameter and set weight weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) sig = inspect.signature(weight_loader) diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index c1c86d399ea..fd57c9c813b 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -128,7 +128,7 @@ def __init__( output_size=fd_config.model_config.n_routed_experts, with_bias=False, skip_quant=True, - weight_dtype="float32", + weight_dtype="bfloat16", ) self.gate.e_score_correction_bias = self.create_parameter( shape=[1, fd_config.model_config.n_routed_experts], diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 5357d0e377a..59941d94595 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -524,34 +524,3 @@ def fn(loaded_weight_name, is_moe): return loaded_weight_name return fn - - -def adapt_moe_gate_parameter_dtype(model, model_param_name, param, loaded_weight, params_dict): - if not model_param_name.endswith(".gate.weight"): - return param - - if param.dtype != loaded_weight.dtype: - new_param = paddle.create_parameter( - shape=param.shape, - dtype=loaded_weight.dtype, - is_bias=getattr(param, "is_bias", False), - default_initializer=paddle.nn.initializer.Constant(0), - ) - - for k, v in param.__dict__.items(): - if not k.startswith("_"): - setattr(new_param, k, v) - - parts = model_param_name.split(".") - obj = model - for part in parts[:-1]: - if part.isdigit() and isinstance(obj, (paddle.nn.LayerList, list)): - obj = obj[int(part)] - else: - obj = getattr(obj, part) - setattr(obj, parts[-1], new_param) - - params_dict[model_param_name] = new_param - return new_param - - return param From db7052ffbd7fb47014a3b21beb24d3b9cd75e94a Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Thu, 29 Jan 2026 18:13:22 +0800 Subject: [PATCH 5/6] fix --- docs/usage/environment_variables.md | 4 ++-- docs/zh/usage/environment_variables.md | 4 ++-- .../model_executor/models/ernie4_5_moe.py | 17 ----------------- 3 files changed, 4 insertions(+), 21 deletions(-) diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index 02c8b9e3d6b..4b8dcda3636 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -231,8 +231,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Worker process health check timeout when waiting for responses in seconds (default: 30) "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), - # Data type for MoE gate weight, can be set to float32, bfloat16, default is empty - "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", ""), + # Data type for MoE gate weight, can be set to float32, bfloat16, default is FLOAT32 + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), } ``` diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 4e85cfdd36d..d7de2e370e2 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -231,7 +231,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # Worker 进程响应等待时的健康检查超时时间(秒),默认 30 秒 "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), - # moe层gate的权重类型,可以设置float32,bfloat16,默认为空 - "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", ""), + # moe层gate的权重类型,可以设置float32,bfloat16,默认为float32 + "FD_MOE_GATE_WEIGHT_DTYPE": lambda: os.getenv("FD_MOE_GATE_WEIGHT_DTYPE", "FLOAT32"), } ``` diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index c2c77edd029..dd3eb98d10a 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -645,23 +645,6 @@ def load_weights(self, weights_iterator) -> None: self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) ) - # Get weight loader from parameter and set weight - weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) - sig = inspect.signature(weight_loader) - if "expert_id" in sig.parameters: - weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) - else: - weight_loader(param, loaded_weight, shard_id) - - model_sublayer_name = re.sub( - r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name - ) - process_weights_after_loading_fn(model_sublayer_name, param) - if getattr(self, "tie_word_embeddings", False): - self.lm_head.linear.weight.set_value( - self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) - ) - def compute_logits(self, hidden_states: paddle.Tensor): logits = self.lm_head(hidden_states) logits = logits.astype(paddle.float32) From 8216e46481f49f610e01dd00ae702914aac2720c Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Thu, 29 Jan 2026 20:12:29 +0800 Subject: [PATCH 6/6] fix --- fastdeploy/model_executor/models/glm4_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index fd57c9c813b..c1c86d399ea 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -128,7 +128,7 @@ def __init__( output_size=fd_config.model_config.n_routed_experts, with_bias=False, skip_quant=True, - weight_dtype="bfloat16", + weight_dtype="float32", ) self.gate.e_score_correction_bias = self.create_parameter( shape=[1, fd_config.model_config.n_routed_experts],