Skip to content

Commit e109fb9

Browse files
authored
[Metax][Fix] fix issues based on #6259 (#6338)
1 parent 90db0bd commit e109fb9

2 files changed

Lines changed: 162 additions & 324 deletions

File tree

fastdeploy/worker/input_batch.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
2121
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
2222
from fastdeploy.model_executor.logits_processor import build_logits_processors
23+
from fastdeploy.platforms import current_platform
2324

2425

2526
class InputBatch:
@@ -134,23 +135,29 @@ def init_share_inputs(self):
134135
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
135136
if self.enable_expert_parallel:
136137
self.seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
137-
self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
138138
self.seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
139139
self.seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
140140
self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
141141
self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
142142
self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64")
143143
self.step_idx = paddle.full([max_num_seqs, 1], 0, dtype="int64")
144-
self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
144+
if current_platform.is_maca():
145+
self.not_need_stop = paddle.full([1], False, dtype="bool").cpu()
146+
self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").cpu()
147+
self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
148+
self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").cpu()
149+
else:
150+
self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
151+
self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
152+
self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
153+
self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
145154
self.not_need_stop_device = paddle.full([1], False, dtype="bool")
146-
self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
147155
self.stop_flags = paddle.full([max_num_seqs, 1], True, dtype="bool")
148156

149157
self.bad_tokens = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
150158
self.bad_tokens_len = paddle.full([max_num_seqs], 1, dtype="int64")
151159
self.next_tokens = paddle.full([max_num_seqs, 1], -1, dtype="int64")
152160
self.is_block_step = paddle.full([max_num_seqs], False, dtype="bool")
153-
self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
154161
self.is_chunk_step = paddle.full([max_num_seqs], False, dtype="bool").cpu()
155162
self.encoder_block_lens = paddle.full([max_num_seqs], 0, dtype="int32")
156163
self.step_block_list = paddle.full([max_num_seqs], -1, dtype="int32")

0 commit comments

Comments (0)