@@ -189,9 +189,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"

- metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
- metadata.max_len_tensor_cpu_decoder[1] = 0

forward_meta.attention_metadata = metadata

def forward_mixed(
@@ -237,6 +234,10 @@ def forward_mixed(
)

if forward_meta.max_len_tensor_cpu[1].item() > 0:

+ metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
+ metadata.max_len_tensor_cpu_decoder[1] = 0

(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
@@ -305,7 +306,7 @@ def forward_mixed(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
- self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
+ forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
@@ -166,9 +166,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"

- metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
- metadata.max_len_tensor_cpu_decoder[1] = 0

forward_meta.attention_metadata = metadata

def forward_mixed(
@@ -222,6 +219,10 @@ def forward_mixed(

# here we add five members,this is ugly, just for now.
if forward_meta.max_len_tensor_cpu[1].item() > 0:

+ metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
+ metadata.max_len_tensor_cpu_decoder[1] = 0

(
forward_meta.attn_cu_seqlens_k,
forward_meta.pre_cache_batch_ids,
@@ -293,7 +294,7 @@
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
- self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
+ forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
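Both attention backends above get the same two-part change: the `max_len_tensor_cpu_decoder` clone moves out of `init_attention_metadata` and into `forward_mixed`, where it is built only when the step actually contains prefill tokens, and the `zero_seq_enc_lens_for_decode`/`use_fa_do_prefill` branch collapses to plain `forward_meta.seq_lens_encoder`. A minimal sketch of the moved logic, assuming (as the guard suggests) that index 1 of `max_len_tensor_cpu` holds the step's max encoder length:

```python
# Sketch only: `metadata` and `forward_meta` stand in for the backends' real
# objects; the index-1 convention is an assumption taken from the guard above.
import paddle

def build_decoder_max_len_view(metadata, forward_meta):
    # Only steps with prefill work need the decoder-only copy.
    if forward_meta.max_len_tensor_cpu[1].item() > 0:
        # Clone before zeroing: the decode path sees "no encoder tokens"
        # while the prefill path keeps the original values.
        metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
        metadata.max_len_tensor_cpu_decoder[1] = 0
```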
21 changes: 4 additions & 17 deletions fastdeploy/model_executor/models/glm4_mtp.py
@@ -28,8 +28,6 @@
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
- from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
- from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.glm4_moe import Glm4MoeDecoderLayer
@@ -119,12 +117,8 @@ def __init__(
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.shared_head.norm",
)
- self.head = ParallelLMHead(
- fd_config,
- embedding_dim=fd_config.model_config.hidden_size,
- num_embeddings=fd_config.model_config.vocab_size,
- prefix=f"{prefix}.shared_head.head",
- )
+ if fd_config.speculative_config.sharing_model is not None:
+ self.head = fd_config.speculative_config.sharing_model.lm_head

def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
# NOTE(wangyanpeng04): Just for compute logits
@@ -216,15 +210,8 @@ def __init__(

assert self.num_mtp_layers == 1, f"Currently only supports single MTP layer, but got {self.num_mtp_layers}"

- self.embed_tokens = VocabParallelEmbedding(
- fd_config=fd_config,
- num_embeddings=fd_config.model_config.vocab_size,
- embedding_dim=fd_config.model_config.hidden_size,
- params_dtype=paddle.get_default_dtype(),
- prefix=(
- f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{self.mtp_start_layer_idx}.embed_tokens"
- ),
- )
+ if fd_config.speculative_config.sharing_model is not None:
+ self.embed_tokens = fd_config.speculative_config.sharing_model.model.embed_tokens

self.layers = nn.LayerDict(
{
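Net effect of the two hunks above: the MTP layer no longer allocates its own `ParallelLMHead` and `VocabParallelEmbedding`; it aliases the target model's modules through `speculative_config.sharing_model`, so the draft and target share one set of head and embedding weights. A toy illustration of the aliasing, with `SimpleNamespace` standing in for the real config and model classes:

```python
from types import SimpleNamespace

class MTPSharedHead:
    def __init__(self, fd_config):
        # Alias, not copy: the draft head is the same module object as the
        # target model's lm_head, so weight updates are visible to both.
        if fd_config.speculative_config.sharing_model is not None:
            self.head = fd_config.speculative_config.sharing_model.lm_head

target_model = SimpleNamespace(lm_head=object())
cfg = SimpleNamespace(speculative_config=SimpleNamespace(sharing_model=target_model))
assert MTPSharedHead(cfg).head is target_model.lm_head
```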
6 changes: 4 additions & 2 deletions fastdeploy/rl/rollout_config.py
@@ -46,7 +46,7 @@ def __init__(
enable_chunked_prefill: bool = False,
speculative_method: str = None,
speculative_max_draft_token_num: int = 1,
- speculative_model_name_or_path: str = "",
+ speculative_model_name_or_path: str = None,
speculative_model_quantization: str = "WINT8",
max_num_batched_tokens: int = 2048,
enable_prefix_caching: bool = False,
@@ -96,7 +96,9 @@ def __init__(
self.speculative_config = {}
self.speculative_config["method"] = speculative_method
self.speculative_config["max_draft_token_num"] = speculative_max_draft_token_num
self.speculative_config["model"] = speculative_model_name_or_path
self.speculative_config["model"] = (
speculative_model_name_or_path if speculative_model_name_or_path is not None else model_name_or_path
)
self.speculative_config["quantization"] = speculative_model_quantization
self.max_num_batched_tokens = max_num_batched_tokens
self.enable_prefix_caching = enable_prefix_caching
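Switching the default from `""` to `None` lets the config tell "not provided" apart from a real path, so the speculative model path can fall back to the main checkpoint, which is where shared MTP weights live. The resolution rule, as a standalone sketch (not the actual config API):

```python
def resolve_speculative_model(model_name_or_path, speculative_model_name_or_path=None):
    # None means "not provided", so fall back to the target model's checkpoint;
    # any explicit value is passed through unchanged.
    if speculative_model_name_or_path is not None:
        return speculative_model_name_or_path
    return model_name_or_path

assert resolve_speculative_model("glm4-moe-ckpt", None) == "glm4-moe-ckpt"
assert resolve_speculative_model("glm4-moe-ckpt", "draft-ckpt") == "draft-ckpt"
```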
143 changes: 143 additions & 0 deletions fastdeploy/rl/rollout_model.py
@@ -18,8 +18,10 @@
from typing import Dict

import paddle
+ import paddle.distributed as dist
from paddle import nn

+ from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_moe import (
@@ -34,6 +36,10 @@
Glm4MoeForCausalLM,
Glm4MoePretrainedModel,
)
+ from fastdeploy.model_executor.models.glm4_mtp import (
+ Glm4MTPForCausalLM,
+ Glm4MTPPretrainedModel,
+ )
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.models.qwen2 import (
Qwen2ForCausalLM,
@@ -698,12 +704,52 @@ def __init__(self, fd_config: FDConfig):
fd_config (FDConfig): Configurations for the LLM model.
"""
super(Glm4MoeForCausalLMRL, self).__init__(fd_config)
+ self.speculative_decoding = fd_config.speculative_config.method is not None
+ self.speculative_method = fd_config.speculative_config.method
+
+ if self.speculative_decoding and self.speculative_method == "mtp":
+ fd_config.parallel_config.tp_group = None
+ fd_config.parallel_config.ep_group = None
+ self.mtp_fd_config = copy.deepcopy(fd_config)
+ fd_config.parallel_config.tp_group = dist.get_group(
+ fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+ )
+ fd_config.parallel_config.ep_group = dist.get_group(
+ fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+ )
+ self.fd_config.parallel_config.tp_group = dist.get_group(
+ fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+ )
+ self.fd_config.parallel_config.ep_group = dist.get_group(
+ fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+ )
+ self.update_mtp_config(self.mtp_fd_config)
+ self.mtp_layers = Glm4MTPForCausalLMRL(self.mtp_fd_config)

@classmethod
def name(self) -> str:
"""name"""
return "Glm4MoeForCausalLMRL"

+ def update_mtp_config(self, mtp_fd_config):
+ mtp_fd_config.model_config.architectures[0] = mtp_fd_config.model_config.architectures[0].replace("Moe", "MTP")
+ mtp_fd_config.speculative_config.sharing_model = None
+ mtp_fd_config.model_config.num_hidden_layers = 1
+ mtp_fd_config.model_config.model = mtp_fd_config.speculative_config.model
+ if mtp_fd_config.speculative_config.quantization != "":
+ mtp_fd_config.model_config.quantization = mtp_fd_config.speculative_config.quantization
+ mtp_fd_config.model_config.start_layer_index = mtp_fd_config.model_config.num_hidden_layers
+ mtp_fd_config.speculative_config.model_type = "mtp"
+
+ def state_dict(self):
+ """state_dict"""
+ main_state_dict = super().state_dict()
+ state_dict = {k: v for k, v in main_state_dict.items() if not k.startswith("mtp_layers")}
+ if self.speculative_decoding and self.speculative_method == "mtp":
+ mtp_state_dict = self.mtp_layers.state_dict()
+ state_dict.update(mtp_state_dict)
+ return state_dict

def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
if self._mappings_built:
@@ -757,9 +803,106 @@ def _add_layer_mappings(layer_idx: int):
_add_layer_mappings(layer_idx)

self._complete_missing_mappings()

+ # extra for mtp
+ if self.speculative_decoding and self.speculative_method == "mtp":
+ mtp_infer_to_train_mapping = self.mtp_layers.get_name_mappings_to_training(trainer_degree)
+ self.infer_to_train_mapping.update(mtp_infer_to_train_mapping)
+
+ infer_to_train_mapping_copy = copy.deepcopy(self.infer_to_train_mapping)
+ for key in infer_to_train_mapping_copy.keys():
+ if "mlp.experts.gate_correction_bias" in key:
+ self.infer_to_train_mapping.pop(key)

return self.infer_to_train_mapping


+ class Glm4MTPForCausalLMRL(Glm4MTPForCausalLM, BaseRLModel):
+ """
+ Glm4MTPForCausalLMRL
+ """
+
+ _get_tensor_parallel_mappings = Glm4MTPPretrainedModel._get_tensor_parallel_mappings
+
+ def __init__(self, fd_config: FDConfig):
+ """
+ Args:
+ fd_config (FDConfig): Configurations for the LLM model.
+ """
+ super(Glm4MTPForCausalLMRL, self).__init__(fd_config)
+
+ @classmethod
+ def name(self) -> str:
+ """name"""
+ return "Glm4MTPForCausalLMRL"
+
+ def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
+ """Generate mapping between inference and training parameter for RL(donot delete!)."""
+ if self._mappings_built:
+ return self.infer_to_train_mapping
+
+ self.infer_to_train_mapping = {}
+ self._mappings_built = True
+
+ # Prepare placeholders
+ place_holders = ["weight"]
+
+ base_name = "model.layers"
+
+ # Helper function to add layer mappings
+ def _add_layer_mappings(layer_idx: int):
+ # MTP specific mappings
+ self.infer_to_train_mapping[f"{base_name}.{layer_idx}.shared_head.head.weight"] = (
+ f"{base_name}.{layer_idx}.shared_head.head.weight"
+ )
+ self.infer_to_train_mapping[f"{base_name}.{layer_idx}.shared_head.norm.weight"] = (
+ f"{base_name}.{layer_idx}.shared_head.norm.weight"
+ )
+ self.infer_to_train_mapping[f"{base_name}.{layer_idx}.eh_proj.weight"] = (
+ f"{base_name}.{layer_idx}.eh_proj.weight"
+ )
+ self.infer_to_train_mapping[f"{base_name}.{layer_idx}.enorm.weight"] = (
+ f"{base_name}.{layer_idx}.enorm.weight"
+ )
+ self.infer_to_train_mapping[f"{base_name}.{layer_idx}.hnorm.weight"] = (
+ f"{base_name}.{layer_idx}.hnorm.weight"
+ )
+
+ # MoE specific mappings
+ self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
+ f"{base_name}.{layer_idx}.mlp.gate.weight"
+ )
+
+ self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"] = (
+ f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"
+ )
+
+ # MoE experts mappings
+ for expert_idx in range(self.fd_config.model_config.n_routed_experts):
+ for ph in place_holders:
+ # up_gate_proj (up_gate_proj)
+ up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
+ if up_gate_proj_key not in self.infer_to_train_mapping:
+ self.infer_to_train_mapping[up_gate_proj_key] = []
+ self.infer_to_train_mapping[up_gate_proj_key].append(
+ f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
+ )
+
+ # down_proj (down_proj)
+ down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
+ if down_proj_key not in self.infer_to_train_mapping:
+ self.infer_to_train_mapping[down_proj_key] = []
+ self.infer_to_train_mapping[down_proj_key].append(
+ f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
+ )
+
+ # Process MoE layers
+ for layer_idx in range(
+ self.fd_config.model_config.start_layer_index,
+ self.fd_config.model_config.start_layer_index + self.fd_config.model_config.num_nextn_predict_layers,
+ ):
+ _add_layer_mappings(layer_idx)
+
+ self._complete_missing_mappings()
+
+ return self.infer_to_train_mapping
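Two patterns in the new RL wrappers deserve a note. `state_dict` on the MoE wrapper drops its own `mtp_layers.`-prefixed duplicates before splicing in the draft model's parameters, and the expert mappings fan many per-expert training keys into a single fused inference key. A condensed sketch of both, with plain dicts standing in for parameter tensors and the layer index chosen arbitrarily:

```python
def merged_state_dict(main_sd: dict, mtp_sd: dict) -> dict:
    # Drop the wrapper's "mtp_layers."-prefixed duplicates, then add the
    # draft model's own (unprefixed) parameters on top.
    out = {k: v for k, v in main_sd.items() if not k.startswith("mtp_layers")}
    out.update(mtp_sd)
    return out

# Fan-in: one fused inference key accumulates every expert's training key.
mapping: dict[str, list[str]] = {}
base = "model.layers.46.mlp.experts"  # hypothetical MTP layer index
for expert_idx in range(4):  # n_routed_experts in the real config
    mapping.setdefault(f"{base}.up_gate_proj_weight", []).append(
        f"{base}.{expert_idx}.up_gate_proj.weight"
    )
```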
17 changes: 11 additions & 6 deletions fastdeploy/worker/gpu_model_runner.py
@@ -1547,19 +1547,24 @@ def load_model(self) -> None:
model_loader = get_model_loader(load_config=self.fd_config.load_config)
self.model = model_loader.load_model(fd_config=self.fd_config)

- # 1.1 Load RL dynamic model
- if self.fd_config.load_config.dynamic_load_weight:
- from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
-
- self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
-
# 2. Load lora model

# 3. Load drafter model(for speculative decoding)

# 4. Init proposer for speculative method
self._init_speculative_proposer()

+ # Load RL dynamic model
+ if self.fd_config.load_config.dynamic_load_weight:
+ from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
+
+ if self.fd_config.speculative_config.method == "mtp":
+ self.dynamic_weight_manager = DynamicWeightManager(
+ self.fd_config, [self.model, self.proposer.model], self.local_rank
+ )
+ else:
+ self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)

def get_model(self) -> nn.Layer:
"""Get current model"""
return self.model
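The reordering here is the point: `DynamicWeightManager` is now constructed after `_init_speculative_proposer()`, so an MTP run can hand it both the target and the draft model. A sketch of the selection logic; the stub below only mirrors the call shape seen in the diff, not the real `DynamicWeightManager` API:

```python
class DynamicWeightManagerStub:
    # Stub mirroring the call shape in the diff: it accepts either a single
    # model or a [target, draft] pair.
    def __init__(self, fd_config, models, local_rank):
        self.models = models if isinstance(models, list) else [models]

def build_weight_manager(fd_config, model, proposer, local_rank):
    if fd_config.speculative_config.method == "mtp":
        # The proposer must already exist, hence the move below
        # _init_speculative_proposer() in load_model().
        return DynamicWeightManagerStub(fd_config, [model, proposer.model], local_rank)
    return DynamicWeightManagerStub(fd_config, model, local_rank)
```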