@@ -20,6 +20,7 @@
 from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.logits_processor import build_logits_processors
+from fastdeploy.platforms import current_platform
 
 
 class InputBatch:
@@ -134,23 +135,29 @@ def init_share_inputs(self): |
         self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         if self.enable_expert_parallel:
             self.seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
         self.seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.step_idx = paddle.full([max_num_seqs, 1], 0, dtype="int64")
-        self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
+        if current_platform.is_maca():
+            self.not_need_stop = paddle.full([1], False, dtype="bool").cpu()
+            self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").cpu()
+            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
+            self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").cpu()
+        else:
+            self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
+            self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
+            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
+            self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
         self.not_need_stop_device = paddle.full([1], False, dtype="bool")
-        self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
         self.stop_flags = paddle.full([max_num_seqs, 1], True, dtype="bool")
 
         self.bad_tokens = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
         self.bad_tokens_len = paddle.full([max_num_seqs], 1, dtype="int64")
         self.next_tokens = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.is_block_step = paddle.full([max_num_seqs], False, dtype="bool")
-        self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
         self.is_chunk_step = paddle.full([max_num_seqs], False, dtype="bool").cpu()
         self.encoder_block_lens = paddle.full([max_num_seqs], 0, dtype="int32")
         self.step_block_list = paddle.full([max_num_seqs], -1, dtype="int32")
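Review note: the diff repeats the same `is_maca()` check for four host-side buffers, switching between pinned and pageable CPU memory. Below is a minimal sketch of how that choice could be factored into one helper; it assumes only what the diff shows (`current_platform.is_maca()`, and that `.cpu()` / `.pin_memory()` are chainable on the result of `paddle.full`). The name `host_tensor` is hypothetical, not part of FastDeploy.

```python
import paddle

from fastdeploy.platforms import current_platform


def host_tensor(shape, fill_value, dtype):
    """Allocate a host-side tensor, pinning it only where the platform allows.

    Hypothetical helper: on MACA, pinned memory is assumed unavailable, so the
    tensor stays in pageable CPU memory; elsewhere it is pinned for faster
    host/device transfers.
    """
    t = paddle.full(shape, fill_value, dtype=dtype)
    return t.cpu() if current_platform.is_maca() else t.pin_memory()


# Usage mirroring the buffers touched in this diff:
# self.not_need_stop = host_tensor([1], False, "bool")
# self.sampled_token_ids = host_tensor([max_num_seqs, 1], -1, "int64")
# self.seq_lens_this_time_cpu = host_tensor([max_num_seqs, 1], 0, "int32")
# self.is_block_step_cpu = host_tensor([max_num_seqs], False, "bool")
```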