diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index c106ca1cd9..668250fa5a 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -197,13 +197,12 @@ async def healthcheck(request: Request): if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": return JSONResponse({"message": "Error"}, status_code=503) - from lightllm.utils.health_check import health_check, health_obj + from lightllm.utils.health_check import health_check - health_task = asyncio.create_task(health_check(g_objs.args, g_objs.httpserver_manager, None)) - if not health_obj.is_health(): - await health_task + is_healthy = health_check(g_objs.httpserver_manager.shm_req_manager) return JSONResponse( - {"message": "Ok" if health_obj.is_health() else "Error"}, status_code=200 if health_obj.is_health() else 503 + {"message": "Ok" if is_healthy else "Error"}, + status_code=200 if is_healthy else 503, ) diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index aa3641afc2..fd9106d59c 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -114,6 +114,10 @@ def release_req_index(self, req_index_in_mem): async def async_release_req_index(self, req_index_in_mem): return self.release_req_index(req_index_in_mem) + def is_idle(self) -> bool: + """True when no request slot is currently allocated in shared memory.""" + return int(np.sum(self.alloc_state_shm.arr)) == 0 + # get_req_obj_by_index 和 put_back_req_obj 是 分配好后,进行对象获取和 # 管理的接口,主要是要进行引用计数的管理。 def get_req_obj_by_index(self, req_index_in_mem): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 6d9f4d54a2..0ac1ea2069 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -52,8 +52,9 @@ def __init__( self.multinode_req_manager = None self.nnodes = args.nnodes - self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) + self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 2) self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0)) + self._run_reqs_count_lock = AsyncLock(self._shm_lock_pool.get_lock_context(1)) self.node_rank = args.node_rank self.disable_abort = args.nnodes > 1 and args.dp == 1 # mulitnode dp=1 mode, disable abort self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 @@ -118,11 +119,13 @@ def __init__( # 有的模型的vocab size 读取tokenizer和config.json中不一致 self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size) - # The timemark of the latest inference(prefill/decode) which is used to check the health status of the system. - # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. + # Timemark of the latest successful inference, used by passive /health checks. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + self.run_reqs_count_mark = SharedInt(f"{get_unique_server_name()}_run_reqs_count_mark") + self.run_reqs_count_mark.set_value(0) + # 用于记录真实的--max_total_token_num 参数,当这个参数在启动参数中没有设置的时候,其是在推理进程中被分析出来的, # 这个时候如果 --max_req_total_len > --max_total_token_num 时,如果httpserver放过一些非法的输入进入后续的模块可能 # 会触发整个系统崩溃,所以httpserver需要知道真实的 max_total_token_num的数据,用于提前拦截非法请求等参数。 @@ -283,12 +286,9 @@ async def generate_wrapper(results_generator): asyncio.create_task(generate_wrapper(results_generator)) return - def alloc_req_id(self, sampling_params, is_health_req: bool = False): + def alloc_req_id(self, sampling_params): # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 - # health 请求 request_id 为负数,直接返回 - if is_health_req: - return sampling_params.group_request_id if self.pd_mode.is_normal(): if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() @@ -312,7 +312,6 @@ async def generate( sampling_params: SamplingParams, multimodal_params: MultimodalParams, request: Request, - is_health_req: bool = False, # 该参数只会在 nixl pd mode 中使用,用于上报一些信息给 pd_master nixl_pd_upload_websocket: ClientConnection = None, # 用于等待 pd_master 下发的交换信息 @@ -321,7 +320,7 @@ async def generate( start_time = time.time() request_headers = request.headers if request is not None else {} - group_request_id = self.alloc_req_id(sampling_params, is_health_req) + group_request_id = self.alloc_req_id(sampling_params) audio_count = len(multimodal_params.audios) if multimodal_params is not None else 0 image_count = len(multimodal_params.images) if multimodal_params is not None else 0 self._log_stage_timing( @@ -332,6 +331,9 @@ async def generate( image_count=image_count, ) + async with self._run_reqs_count_lock: + self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() + 1) + try: original_multimodal_params = None if self.is_multinode_tp_master: @@ -358,17 +360,17 @@ async def generate( prompt_tokens = len(prompt_ids) prompt_ids = await self._check_and_repair_length(prompt_ids, sampling_params) # 监控 - if group_request_id > 0: - self.metric_client.counter_inc("lightllm_request_count") - self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) - self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens) + self.metric_client.counter_inc("lightllm_request_count") + self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) + self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens) + self._log_stage_timing( group_request_id, start_time, "check_and_repair_length_done", ) - if nixl_pd_upload_websocket is not None and not is_health_req and self.pd_mode.is_NP(): + if nixl_pd_upload_websocket is not None and self.pd_mode.is_NP(): # 在 nixl pd 模式下的 p 节点, 为了更好的兼容多模态的推理流程,np 节点需要先上报其 encode 好的 prompt ids 信息,然后 # 再等待 pd_master 传输下来的对应的进行 decode 节点的decode信息,然后再执行后续的流程 logger.info( @@ -479,6 +481,9 @@ async def generate( await self._release_multimodal_resources(multimodal_params) await self.abort(group_request_id) raise e + finally: + async with self._run_reqs_count_lock: + self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() - 1) return def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: @@ -754,9 +759,6 @@ async def _wait_to_token_package( f"disk_prompt_cache_ratio:{disk_prompt_cache_ratio} " f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " ) - if group_request_id < 0: - # health 探测请求,不记录日志和监控 - return self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len) self.metric_client.histogram_observe("lightllm_cache_ratio", prompt_cache_ratio) self.metric_client.histogram_observe( diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index a646d4f4cc..e341da2a85 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -6,6 +6,9 @@ import httpx import base64 import weakref +import os +import signal +import sys from typing import Dict, Optional, Union, List from websockets import ClientConnection from lightllm.server.pd_io_struct import NodeRole, ObjType @@ -31,7 +34,12 @@ async def timer_log(manager: HttpServerManager): async def pd_handle_loop(manager: HttpServerManager): - assert manager.args.host not in ["127.0.0.1", "localhost"], "pd mode must specify host ip" + if manager.args.host in ["127.0.0.1", "localhost"]: + logger.error("pd mode must specify host ip, not use 127.0.0.1 or localhost") + # kill father process to trigger graceful exit, avoid orphan process + os.kill(os.getppid(), signal.SIGINT) + sys.exit(-1) + if manager.args.host in ["0.0.0.0"]: manager.host_ip = get_hostname_ip() else: @@ -213,11 +221,8 @@ async def _pd_process_generate( nixl_pd_upload_websocket=nixl_pd_upload_websocket, nixl_pd_event=nixl_pd_event, ): - # p d 模式下,将 token 数据放入到转发队列中, 请求id 小于0的请求是health探测请求,不用转发。 - is_health_check_req = sub_req_id < 0 - if not is_health_check_req: - metadata["node_mode"] = manager.args.run_mode - await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status)) + metadata["node_mode"] = manager.args.run_mode + await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status)) except NixlPrefillNodeStopGenToken as e: logger.info(f"nixl prefill node stop gen token for group_request_id {e.group_request_id}") except BaseException as e: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e47717747e..6fde73865a 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -451,9 +451,7 @@ def _read_reqs_buffer_and_init_reqs(self): else: assert False, f"error type {type(obj)}" if init_reqs: - req_ids = self._init_reqs(reqs=init_reqs) - if self.args.enable_cpu_cache and req_ids: - self._load_cpu_cache_to_reqs(req_ids=req_ids) + self._init_reqs(reqs=init_reqs) return def _read_nixl_trans_io_buffer_and_update_req_status(self): @@ -506,6 +504,10 @@ def _init_reqs(self, reqs: List[Tuple]): g_infer_context.add_reqs(reqs) g_infer_state_lock.release() req_ids = [e[0] for e in reqs] + + if self.args.enable_cpu_cache: + self._load_cpu_cache_to_reqs(req_ids=req_ids) + return req_ids def _load_cpu_cache_to_reqs(self, req_ids): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py index b367a66a75..d13987e23f 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py @@ -56,6 +56,9 @@ def _init_reqs(self, reqs: List[Tuple]): g_infer_state_lock.release() req_ids = [e[0] for e in reqs] + + # pd nccl 的 decode 节点模式下不支持 cpu cache + assert not self.args.enable_cpu_cache return req_ids def _post_init_reqs(self, uninit_reqs: List[InferReq]): diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c83e8cd4a5..f1896b201a 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -82,6 +82,10 @@ def _init_reqs(self, reqs: List[Tuple]): g_infer_state_lock.release() req_ids = [e[0] for e in current_dp_reqs] + + if self.args.enable_cpu_cache: + self._load_cpu_cache_to_reqs(req_ids=req_ids) + return req_ids def infer_loop(self): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index f1309ca9cc..481a3197d7 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -33,11 +33,14 @@ def _init_reqs(self, reqs: List[Tuple]): g_infer_state_lock.acquire() - uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False) + uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) # 匹配radix cache,并更新一些资源的管理。 self._post_init_reqs(uninit_reqs=uninit_reqs) - g_infer_state_lock.release() + + # pd nixl 的 decode 节点模式下当前不支持 cpu cache, 未来可能会支持。 + assert not self.args.enable_cpu_cache + req_ids = [e[0] for e in reqs] return req_ids @@ -50,26 +53,9 @@ def _post_init_reqs(self, uninit_reqs: List[InferReq]): for req_obj in uninit_reqs: req_obj: InferReq = req_obj # for easy typing - request_id = req_obj.req_id - if request_id > 0: - req_obj._match_radix_cache() - # 构建 chuncked trans task - self._decode_node_gen_trans_tasks(req_obj=req_obj) - else: - # 对于不合法的请求, 主要是health请求,直接模拟将其finished掉 - req_obj.cur_output_len += 1 - req_obj.set_next_gen_token_id(0, 0.0, 1) - req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP) - - if self.is_master_in_dp: - req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len - req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len - req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1 - req_obj.shm_req.finish_status.set_status(FinishStatus.FINISHED_STOP) - req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len - - req_id = req_obj.shm_req.request_id - logger.error(f"req_id: {req_id} forced to finished") + # 构建 chuncked trans task + self._decode_node_gen_trans_tasks(req_obj=req_obj) + return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py index 877c5c12db..f1b2244024 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py @@ -1,5 +1,6 @@ import inspect import pickle +import setproctitle import torch.multiprocessing as mp import time from typing import List, Dict, Optional, Tuple, Union, Callable @@ -10,6 +11,7 @@ from ..trans_process_obj import KVTransProcess from ..base_kv_move_manager import BaseKVMoveManager from lightllm.utils.error_utils import log_exception +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -29,6 +31,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_decode_kv_move_manager") from .up_status import start_up_kv_status_process diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index b04cbb900a..7913865406 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -2,6 +2,7 @@ import time import inspect import threading +import setproctitle import torch.multiprocessing as mp import collections import queue @@ -22,6 +23,7 @@ from lightllm.server.core.objs import StartArgs from ..nixl_kv_transporter import NixlKVTransporter from lightllm.utils.error_utils import log_exception +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -48,6 +50,7 @@ def _init_env( up_status_in_queue: Optional[mp.SimpleQueue], ): torch.backends.cudnn.enabled = False + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_decode_trans:Device{device_id}") try: torch.cuda.set_device(device_id) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py index f79fb4ea2c..bf70694672 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/up_status.py @@ -5,6 +5,7 @@ import websockets import inspect import pickle +import setproctitle from typing import Dict, Union from dataclasses import asdict @@ -13,6 +14,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.pd_io_struct import PD_Master_Obj import torch.multiprocessing as mp +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -108,6 +110,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj): def _init_env(args, task_in_queue: mp.SimpleQueue): graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_up_kv_status") up_kv_manager = UpStatusManager(args, task_in_queue) logger.info(f"up kv manager {str(up_kv_manager)} start ok") while True: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py index 6f5a6e17d8..20c487d56e 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py @@ -55,9 +55,6 @@ def _prefill_chuncked_handle_func( """ 在每一步chuncked prefill 后,尝试生成chuncked 传输任务,发个 kv_move_manager 进行处理。 """ - # 系统内部的 health 请求不创建 kv 传输任务。 - if req_obj.req_id < 0: - return assert req_obj.cur_kv_len <= req_obj.shm_req.input_len input_len = req_obj.shm_req.input_len diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py index ac8026e58e..fb95091158 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py @@ -1,4 +1,5 @@ import inspect +import setproctitle import torch.multiprocessing as mp import time from typing import List, Dict, Union, Callable @@ -9,6 +10,7 @@ from ..trans_process_obj import KVTransProcess from ..base_kv_move_manager import BaseKVMoveManager from lightllm.utils.error_utils import log_exception +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -28,6 +30,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_prefill_kv_move_manager") from .prefill_trans_process import start_prefill_trans_process diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 063ce5c6a9..7975a253f1 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -2,6 +2,7 @@ import time import inspect import threading +import setproctitle import torch.multiprocessing as mp import collections import queue @@ -15,6 +16,7 @@ from lightllm.server.core.objs import StartArgs from ..nixl_kv_transporter import NixlKVTransporter from lightllm.utils.error_utils import log_exception +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -41,6 +43,7 @@ def _init_env( task_out_queue: mp.Queue, ): torch.backends.cudnn.enabled = False + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_prefill_trans:Device{device_id}") try: torch.cuda.set_device(device_id) diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index ae3c4204db..3e74793634 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -4,12 +4,13 @@ import uuid import os import multiprocessing +import setproctitle from lightllm.utils.retry_utils import retry from rpyc.utils.factory import unix_connect from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name from .model_rpc_client import VisualModelRpcClient from .model_rpc import VisualModelRpcServer from ..objs import rpyc_config @@ -18,6 +19,7 @@ def _init_env(socket_path: str, success_event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_model_infer") import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/visualserver/proxy_manager.py b/lightllm/server/visualserver/proxy_manager.py index 2cf02d19e6..0c977b2aa9 100644 --- a/lightllm/server/visualserver/proxy_manager.py +++ b/lightllm/server/visualserver/proxy_manager.py @@ -211,7 +211,7 @@ def start_visual_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_proxy_server") start_parent_check_thread() try: visualserver = ProxyVisualManager(args=args) diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 15705a1140..b06713d87c 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -177,7 +177,7 @@ def start_visual_process(args: StartArgs, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_only_server") start_parent_check_thread() try: diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py index f6c52bdb38..d2a776b862 100644 --- a/lightllm/utils/health_check.py +++ b/lightllm/utils/health_check.py @@ -1,105 +1,52 @@ import os import time -import asyncio -import numpy as np from dataclasses import dataclass -from lightllm.server.core.objs import SamplingParams -from lightllm.server.multimodal_params import MultimodalParams -from lightllm.server.httpserver.manager import HttpServerManager +from typing import TYPE_CHECKING + from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from fastapi import Request from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_unique_server_name + +if TYPE_CHECKING: + from lightllm.server.core.objs.shm_req_manager import ShmReqManager logger = init_logger(__name__) @dataclass class HealthObj: - _is_health: bool = False - _is_health_checking: bool = False - _failure_count: int = 0 - _failure_threshold: int = int(os.getenv("HEALTH_FAILURE_THRESHOLD", 3)) - timeout: int = int(os.getenv("HEALTH_TIMEOUT", 100)) - dynamic_timeout: int = int(os.getenv("HEALTH_TIMEOUT", 100)) - latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") - - def begin_check(self): - self._is_health_checking = True - - def end_check(self): - self._is_health_checking = False - - def set_unhealth(self): - self._failure_count += 1 - self.dynamic_timeout += self.timeout - if self._failure_count > self._failure_threshold: - self._is_health = False - - def set_health(self): - self._is_health = True - self._failure_count = 0 - self.dynamic_timeout = self.timeout - - def is_health(self): - return self._is_health + grace_timeout: int = int(os.getenv("HEALTH_TIMEOUT", "200")) - def is_checking(self): - return self._is_health_checking + def __post_init__(self): + uid = get_unique_server_name() + self.latest_success_infer_time_mark = SharedInt(f"{uid}_latest_success_infer_time_mark") + self.run_reqs_count_mark = SharedInt(f"{uid}_run_reqs_count_mark") - def has_latest_inference(self): - last_timemark = self.latest_success_infer_time_mark.get_value() - time_diff = time.time() - last_timemark - return time_diff < self.timeout + def check(self, shm_req_manager: "ShmReqManager") -> bool: + """On-the-fly health check: recent success is ok; otherwise require no in-flight shm requests.""" + try: + now = time.time() + last_success_time = self.latest_success_infer_time_mark.get_value() + + # 如果最近一次成功推理的时间距离现在小于 grace_timeout,则认为系统健康 + if now - last_success_time <= self.grace_timeout: + return True + elif self.run_reqs_count_mark.get_value() == 0 and shm_req_manager.is_idle(): + # 如果最近一次成功推理的时间距离现在大于 grace_timeout,并且没有在推理的请求,则认为系统健康 + return True + else: + logger.warning( + "Health check failed: no success for %ss and in-flight shm requests remain", + int(now - last_success_time), + ) + return False + except Exception as e: + logger.exception(str(e)) + return False health_obj = HealthObj() -async def health_check(args, httpserver_manager: HttpServerManager, request: Request): - if health_obj.is_checking(): - return health_obj.is_health() - - if health_obj.is_health() and health_obj.has_latest_inference(): - return health_obj.is_health() - - health_obj.begin_check() - try: - request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}} - if args.run_mode in ["prefill", "nixl_prefill"]: - request_dict["parameters"]["max_new_tokens"] = 1 - prompt = request_dict.pop("inputs") - sample_params_dict = request_dict["parameters"] - sampling_params = SamplingParams() - sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) - sampling_params.verify() - - if get_env_start_args().run_mode == "pd_master": - # Since the id assigned by pd master needs to be passed to prefill and decode nodes for inference, - # a normal request id is required instead of a negative id. - sampling_params.group_request_id = httpserver_manager.id_gen.generate_id() - else: - sampling_params.group_request_id = -httpserver_manager.id_gen.generate_id() # health monitor 的 id 是负的 - multimodal_params_dict = request_dict.get("multimodal_params", {}) - multimodal_params = MultimodalParams(**multimodal_params_dict) - results_generator = httpserver_manager.generate( - prompt, sampling_params, multimodal_params, request, is_health_req=True - ) - - async def check_timeout(results_generator): - async for _, _, _, _ in results_generator: - pass - - try: - await asyncio.wait_for(check_timeout(results_generator), timeout=health_obj.dynamic_timeout) - health_obj.set_health() - except asyncio.TimeoutError: - health_obj.set_unhealth() - logger.warning(f"Health check timeout! The failure count is: {str(health_obj._failure_count)}") - return health_obj.is_health() - except Exception as e: - logger.exception(str(e)) - health_obj.set_unhealth() - return health_obj.is_health() - finally: - health_obj.end_check() +def health_check(shm_req_manager: "ShmReqManager") -> bool: + return health_obj.check(shm_req_manager)