diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 7d1bba33774..cc8fc0711ed 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -446,7 +446,7 @@ def apply_tp( gate_out = gate_out.cast("float32") if fc1_latent_proj is not None: x = fc1_latent_proj(x) - gate_out, topk_weights, topk_idx = get_moe_scores( + gate_out, _, __ = get_moe_scores( gate_out, layer.n_group, layer.topk_group, @@ -458,11 +458,6 @@ def apply_tp( use_fused_cast=use_fused, ) - if layer.routed_scaling_factor_learnable: - safe_topk_indices = paddle.clip(topk_idx, min=0) - gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1) - topk_weights = topk_weights * gathered_scales - ( permute_input, token_nums_per_expert, @@ -484,6 +479,12 @@ def apply_tp( self.moe_quant_type, topk_only_mode=True, ) + + if layer.routed_scaling_factor_learnable: + safe_topk_indices = paddle.clip(topk_idx, min=0) + gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1) + topk_weights = topk_weights * gathered_scales + else: gate_out = gate_out.cast("float32") if fc1_latent_proj is not None: diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 2991359ce0c..6ba2f13633d 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -108,6 +108,10 @@ def __init__(self, fd_config: FDConfig) -> None: else: self.max_chunk_tokens = self.fd_config.get_max_chunk_tokens(self.model_config.mm_max_tokens_per_item) + # NOTE (changwenbin):Supports neox_rotary_style. + rotary_percent = getattr(self.model_config, "rotary_percent", 1) + self.rotary_dim = int(rotary_percent * self.model_config.head_dim) + def init_share_inputs(self): max_num_seqs = self.scheduler_config.max_num_seqs @@ -229,7 +233,7 @@ def init_share_inputs(self): # Initialize rotary position embedding if not self.enable_mm: self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim, position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, @@ -707,7 +711,7 @@ def reset_share_inputs(self): else: # Reset non-multimodal rope_emb self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim, position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, @@ -813,7 +817,7 @@ def init_share_inputs(self): tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim, position_ids=tmp_position_ids, base=self.model_config.rope_theta, model_config=self.model_config, @@ -1008,7 +1012,7 @@ def reset_model_inputs(self) -> None: # Reset rope embedding by recreating with default position_ids tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim, position_ids=tmp_position_ids, base=self.model_config.rope_theta, model_config=self.model_config,