Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def apply_tp(
gate_out = gate_out.cast("float32")
if fc1_latent_proj is not None:
x = fc1_latent_proj(x)
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out, _, __ = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
Expand All @@ -458,11 +458,6 @@ def apply_tp(
use_fused_cast=use_fused,
)

if layer.routed_scaling_factor_learnable:
safe_topk_indices = paddle.clip(topk_idx, min=0)
gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1)
topk_weights = topk_weights * gathered_scales

(
permute_input,
token_nums_per_expert,
Expand All @@ -484,6 +479,12 @@ def apply_tp(
self.moe_quant_type,
topk_only_mode=True,
)

if layer.routed_scaling_factor_learnable:
safe_topk_indices = paddle.clip(topk_idx, min=0)
gathered_scales = F.embedding(safe_topk_indices, layer.per_expert_scale.unsqueeze(1)).squeeze(-1)
topk_weights = topk_weights * gathered_scales

else:
gate_out = gate_out.cast("float32")
if fc1_latent_proj is not None:
Expand Down
12 changes: 8 additions & 4 deletions fastdeploy/worker/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def __init__(self, fd_config: FDConfig) -> None:
else:
self.max_chunk_tokens = self.fd_config.get_max_chunk_tokens(self.model_config.mm_max_tokens_per_item)

# NOTE (changwenbin):Supports neox_rotary_style.
rotary_percent = getattr(self.model_config, "rotary_percent", 1)
self.rotary_dim = int(rotary_percent * self.model_config.head_dim)

def init_share_inputs(self):
max_num_seqs = self.scheduler_config.max_num_seqs

Expand Down Expand Up @@ -229,7 +233,7 @@ def init_share_inputs(self):
# Initialize rotary position embedding
if not self.enable_mm:
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim,
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down Expand Up @@ -707,7 +711,7 @@ def reset_share_inputs(self):
else:
# Reset non-multimodal rope_emb
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim,
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down Expand Up @@ -813,7 +817,7 @@ def init_share_inputs(self):
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug ProposerInputBatch 继承自 InputBatch 但其 __init__ 未调用 super().__init__(),也未独立设置 self.rotary_dim。当投机解码场景调用 ProposerInputBatch.init_share_inputs() 时,此处访问 self.rotary_dim 将抛出 AttributeError

建议修复方式:在 ProposerInputBatch.__init__ 中补充 rotary_dim 的计算,与父类保持一致:

rotary_percent = getattr(self.model_config, "rotary_percent", 1)
self.rotary_dim = int(rotary_percent * self.model_config.head_dim)


self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down Expand Up @@ -1008,7 +1012,7 @@ def reset_model_inputs(self) -> None:
# Reset rope embedding by recreating with default position_ids
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down
Loading