def prefill_schedule(self):
    """
    Schedule one batch of prefill work for a prefill-only ("P") instance.

    Walks the running requests in order. For every request that is not
    already decoding, ``_get_num_new_tokens`` decides how many tokens fit in
    the remaining budget; those tokens are wrapped in a prefill task, added
    to the batch, charged against the budget and recorded on the request as
    computed. The budget starts at
    ``scheduler_config.max_num_batched_tokens`` and is spent on prefill only.

    Returns:
        tuple: ``(batch_request, error_requests)``. The second element is
        always an empty list; it is returned so the result unpacks the same
        way as the result of ``schedule()``.

    Raises:
        AssertionError: if ``self.waiting`` is not empty on entry.
    """
    with self.lock:
        # P instance has no decode — full budget for prefill
        batch_request = BatchRequest()
        token_budget = self.config.scheduler_config.max_num_batched_tokens

        # Invariant check: this scheduler only looks at ``self.running``, so a
        # request left in ``self.waiting`` would never be scheduled from here.
        assert len(self.waiting) == 0, "Prefill scheduler should not have waiting requests"

        # Prepare prefill tasks for all running requests.
        # Iterate over a snapshot so the loop is unaffected if a helper
        # touches ``self.running``.
        for request in list(self.running):
            # Once the budget is spent no later request can receive tokens;
            # stop instead of querying every remaining request with a zero
            # budget while the lock is held.
            if token_budget <= 0:
                break

            if self._is_decoding(request):
                continue

            num_new_tokens = self._get_num_new_tokens(request, token_budget)
            if num_new_tokens == 0:
                continue

            # Add requests into scheduled batch as blocks were preallocated
            llm_logger.debug(f"schedule prefill tasks {request.request_id} with {num_new_tokens} tokens")
            batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
            token_budget -= num_new_tokens
            request.num_computed_tokens += num_new_tokens

        self.update_metrics()
        return batch_request, []