Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions lightllm/server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,12 @@ async def healthcheck(request: Request):

if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
return JSONResponse({"message": "Error"}, status_code=503)
from lightllm.utils.health_check import health_check, health_obj
from lightllm.utils.health_check import health_check

health_task = asyncio.create_task(health_check(g_objs.args, g_objs.httpserver_manager, None))
if not health_obj.is_health():
await health_task
is_healthy = health_check(g_objs.httpserver_manager.shm_req_manager)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the server is still starting up or fails to initialize, g_objs.httpserver_manager may be None. Calling g_objs.httpserver_manager.shm_req_manager directly will raise an AttributeError, resulting in a 500 Internal Server Error instead of a clean 503 Service Unavailable response.

Adding a defensive check to ensure g_objs.httpserver_manager and its shm_req_manager are fully initialized before invoking the health check prevents crashes during startup probes.

    if g_objs.httpserver_manager is None or getattr(g_objs.httpserver_manager, "shm_req_manager", None) is None:
        return JSONResponse({"message": "Error"}, status_code=503)

    is_healthy = health_check(g_objs.httpserver_manager.shm_req_manager)

return JSONResponse(
{"message": "Ok" if health_obj.is_health() else "Error"}, status_code=200 if health_obj.is_health() else 503
{"message": "Ok" if is_healthy else "Error"},
status_code=200 if is_healthy else 503,
)


Expand Down
4 changes: 4 additions & 0 deletions lightllm/server/core/objs/shm_req_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def release_req_index(self, req_index_in_mem):
async def async_release_req_index(self, req_index_in_mem):
return self.release_req_index(req_index_in_mem)

def is_idle(self) -> bool:
"""True when no request slot is currently allocated in shared memory."""
return int(np.sum(self.alloc_state_shm.arr)) == 0

# get_req_obj_by_index 和 put_back_req_obj 是 分配好后,进行对象获取和
# 管理的接口,主要是要进行引用计数的管理。
def get_req_obj_by_index(self, req_index_in_mem):
Expand Down
36 changes: 19 additions & 17 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def __init__(

self.multinode_req_manager = None
self.nnodes = args.nnodes
self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1)
self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 2)
self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0))
self._run_reqs_count_lock = AsyncLock(self._shm_lock_pool.get_lock_context(1))
self.node_rank = args.node_rank
self.disable_abort = args.nnodes > 1 and args.dp == 1 # mulitnode dp=1 mode, disable abort
self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
Expand Down Expand Up @@ -118,11 +119,13 @@ def __init__(
# 有的模型的vocab size 读取tokenizer和config.json中不一致
self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size)

# The timemark of the latest inference(prefill/decode) which is used to check the health status of the system.
# If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
# Timemark of the latest successful inference, used by passive /health checks.
self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
self.latest_success_infer_time_mark.set_value(int(time.time()))

self.run_reqs_count_mark = SharedInt(f"{get_unique_server_name()}_run_reqs_count_mark")
self.run_reqs_count_mark.set_value(0)

# 用于记录真实的--max_total_token_num 参数,当这个参数在启动参数中没有设置的时候,其是在推理进程中被分析出来的,
# 这个时候如果 --max_req_total_len > --max_total_token_num 时,如果httpserver放过一些非法的输入进入后续的模块可能
# 会触发整个系统崩溃,所以httpserver需要知道真实的 max_total_token_num的数据,用于提前拦截非法请求等参数。
Expand Down Expand Up @@ -283,12 +286,9 @@ async def generate_wrapper(results_generator):
asyncio.create_task(generate_wrapper(results_generator))
return

def alloc_req_id(self, sampling_params, is_health_req: bool = False):
def alloc_req_id(self, sampling_params):
# 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性
# 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置
# health 请求 request_id 为负数,直接返回
if is_health_req:
return sampling_params.group_request_id
if self.pd_mode.is_normal():
if not self.is_multinode_tp:
group_request_id = self.id_gen.generate_id()
Expand All @@ -312,7 +312,6 @@ async def generate(
sampling_params: SamplingParams,
multimodal_params: MultimodalParams,
request: Request,
is_health_req: bool = False,
# 该参数只会在 nixl pd mode 中使用,用于上报一些信息给 pd_master
nixl_pd_upload_websocket: ClientConnection = None,
# 用于等待 pd_master 下发的交换信息
Expand All @@ -321,7 +320,7 @@ async def generate(

start_time = time.time()
request_headers = request.headers if request is not None else {}
group_request_id = self.alloc_req_id(sampling_params, is_health_req)
group_request_id = self.alloc_req_id(sampling_params)
audio_count = len(multimodal_params.audios) if multimodal_params is not None else 0
image_count = len(multimodal_params.images) if multimodal_params is not None else 0
self._log_stage_timing(
Expand All @@ -332,6 +331,9 @@ async def generate(
image_count=image_count,
)

async with self._run_reqs_count_lock:
self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() + 1)

try:
original_multimodal_params = None
if self.is_multinode_tp_master:
Expand All @@ -358,17 +360,17 @@ async def generate(
prompt_tokens = len(prompt_ids)
prompt_ids = await self._check_and_repair_length(prompt_ids, sampling_params)
# 监控
if group_request_id > 0:
self.metric_client.counter_inc("lightllm_request_count")
self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens)
self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)
self.metric_client.counter_inc("lightllm_request_count")
self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens)
self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)

self._log_stage_timing(
group_request_id,
start_time,
"check_and_repair_length_done",
)

if nixl_pd_upload_websocket is not None and not is_health_req and self.pd_mode.is_NP():
if nixl_pd_upload_websocket is not None and self.pd_mode.is_NP():
# 在 nixl pd 模式下的 p 节点, 为了更好的兼容多模态的推理流程,np 节点需要先上报其 encode 好的 prompt ids 信息,然后
# 再等待 pd_master 传输下来的对应的进行 decode 节点的decode信息,然后再执行后续的流程
logger.info(
Expand Down Expand Up @@ -479,6 +481,9 @@ async def generate(
await self._release_multimodal_resources(multimodal_params)
await self.abort(group_request_id)
raise e
finally:
async with self._run_reqs_count_lock:
self.run_reqs_count_mark.set_value(self.run_reqs_count_mark.get_value() - 1)
return

def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]:
Expand Down Expand Up @@ -754,9 +759,6 @@ async def _wait_to_token_package(
f"disk_prompt_cache_ratio:{disk_prompt_cache_ratio} "
f"mtp_avg_token_per_step:{mtp_avg_token_per_step} "
)
if group_request_id < 0:
# health 探测请求,不记录日志和监控
return
self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len)
self.metric_client.histogram_observe("lightllm_cache_ratio", prompt_cache_ratio)
self.metric_client.histogram_observe(
Expand Down
17 changes: 11 additions & 6 deletions lightllm/server/httpserver/pd_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import httpx
import base64
import weakref
import os
import signal
import sys
from typing import Dict, Optional, Union, List
from websockets import ClientConnection
from lightllm.server.pd_io_struct import NodeRole, ObjType
Expand All @@ -31,7 +34,12 @@ async def timer_log(manager: HttpServerManager):


async def pd_handle_loop(manager: HttpServerManager):
assert manager.args.host not in ["127.0.0.1", "localhost"], "pd mode must specify host ip"
if manager.args.host in ["127.0.0.1", "localhost"]:
logger.error("pd mode must specify host ip, not use 127.0.0.1 or localhost")
# kill father process to trigger graceful exit, avoid orphan process
os.kill(os.getppid(), signal.SIGINT)
sys.exit(-1)

if manager.args.host in ["0.0.0.0"]:
manager.host_ip = get_hostname_ip()
else:
Expand Down Expand Up @@ -213,11 +221,8 @@ async def _pd_process_generate(
nixl_pd_upload_websocket=nixl_pd_upload_websocket,
nixl_pd_event=nixl_pd_event,
):
# p d 模式下,将 token 数据放入到转发队列中, 请求id 小于0的请求是health探测请求,不用转发。
is_health_check_req = sub_req_id < 0
if not is_health_check_req:
metadata["node_mode"] = manager.args.run_mode
await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status))
metadata["node_mode"] = manager.args.run_mode
await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status))
except NixlPrefillNodeStopGenToken as e:
logger.info(f"nixl prefill node stop gen token for group_request_id {e.group_request_id}")
except BaseException as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,7 @@ def _read_reqs_buffer_and_init_reqs(self):
else:
assert False, f"error type {type(obj)}"
if init_reqs:
req_ids = self._init_reqs(reqs=init_reqs)
if self.args.enable_cpu_cache and req_ids:
self._load_cpu_cache_to_reqs(req_ids=req_ids)
self._init_reqs(reqs=init_reqs)
return

def _read_nixl_trans_io_buffer_and_update_req_status(self):
Expand Down Expand Up @@ -506,6 +504,10 @@ def _init_reqs(self, reqs: List[Tuple]):
g_infer_context.add_reqs(reqs)
g_infer_state_lock.release()
req_ids = [e[0] for e in reqs]

if self.args.enable_cpu_cache:
self._load_cpu_cache_to_reqs(req_ids=req_ids)

return req_ids

def _load_cpu_cache_to_reqs(self, req_ids):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def _init_reqs(self, reqs: List[Tuple]):

g_infer_state_lock.release()
req_ids = [e[0] for e in reqs]

# pd nccl 的 decode 节点模式下不支持 cpu cache
assert not self.args.enable_cpu_cache
return req_ids

def _post_init_reqs(self, uninit_reqs: List[InferReq]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _init_reqs(self, reqs: List[Tuple]):
g_infer_state_lock.release()

req_ids = [e[0] for e in current_dp_reqs]

if self.args.enable_cpu_cache:
self._load_cpu_cache_to_reqs(req_ids=req_ids)

return req_ids

def infer_loop(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@ def _init_reqs(self, reqs: List[Tuple]):

g_infer_state_lock.acquire()

uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False)
uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True)
# 匹配radix cache,并更新一些资源的管理。
self._post_init_reqs(uninit_reqs=uninit_reqs)

g_infer_state_lock.release()

# pd nixl 的 decode 节点模式下当前不支持 cpu cache, 未来可能会支持。
assert not self.args.enable_cpu_cache

req_ids = [e[0] for e in reqs]
return req_ids

Expand All @@ -50,26 +53,9 @@ def _post_init_reqs(self, uninit_reqs: List[InferReq]):

for req_obj in uninit_reqs:
req_obj: InferReq = req_obj # for easy typing
request_id = req_obj.req_id
if request_id > 0:
req_obj._match_radix_cache()
# 构建 chuncked trans task
self._decode_node_gen_trans_tasks(req_obj=req_obj)
else:
# 对于不合法的请求, 主要是health请求,直接模拟将其finished掉
req_obj.cur_output_len += 1
req_obj.set_next_gen_token_id(0, 0.0, 1)
req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP)

if self.is_master_in_dp:
req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len
req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len
req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1
req_obj.shm_req.finish_status.set_status(FinishStatus.FINISHED_STOP)
req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len

req_id = req_obj.shm_req.request_id
logger.error(f"req_id: {req_id} forced to finished")
# 构建 chuncked trans task
self._decode_node_gen_trans_tasks(req_obj=req_obj)

return

def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import pickle
import setproctitle
import torch.multiprocessing as mp
import time
from typing import List, Dict, Optional, Tuple, Union, Callable
Expand All @@ -10,6 +11,7 @@
from ..trans_process_obj import KVTransProcess
from ..base_kv_move_manager import BaseKVMoveManager
from lightllm.utils.error_utils import log_exception
from lightllm.utils.envs_utils import get_unique_server_name

logger = init_logger(__name__)

Expand All @@ -29,6 +31,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event):

# 注册graceful 退出的处理
graceful_registry(inspect.currentframe().f_code.co_name)
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_decode_kv_move_manager")

from .up_status import start_up_kv_status_process

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import inspect
import threading
import setproctitle
import torch.multiprocessing as mp
import collections
import queue
Expand All @@ -22,6 +23,7 @@
from lightllm.server.core.objs import StartArgs
from ..nixl_kv_transporter import NixlKVTransporter
from lightllm.utils.error_utils import log_exception
from lightllm.utils.envs_utils import get_unique_server_name

logger = init_logger(__name__)

Expand All @@ -48,6 +50,7 @@ def _init_env(
up_status_in_queue: Optional[mp.SimpleQueue],
):
torch.backends.cudnn.enabled = False
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_decode_trans:Device{device_id}")

try:
torch.cuda.set_device(device_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import websockets
import inspect
import pickle
import setproctitle

from typing import Dict, Union
from dataclasses import asdict
Expand All @@ -13,6 +14,7 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.server.pd_io_struct import PD_Master_Obj
import torch.multiprocessing as mp
from lightllm.utils.envs_utils import get_unique_server_name

logger = init_logger(__name__)

Expand Down Expand Up @@ -108,6 +110,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj):

def _init_env(args, task_in_queue: mp.SimpleQueue):
graceful_registry(inspect.currentframe().f_code.co_name)
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_up_kv_status")
up_kv_manager = UpStatusManager(args, task_in_queue)
logger.info(f"up kv manager {str(up_kv_manager)} start ok")
while True:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def _prefill_chuncked_handle_func(
"""
在每一步chuncked prefill 后,尝试生成chuncked 传输任务,发个 kv_move_manager 进行处理。
"""
# 系统内部的 health 请求不创建 kv 传输任务。
if req_obj.req_id < 0:
return

assert req_obj.cur_kv_len <= req_obj.shm_req.input_len
input_len = req_obj.shm_req.input_len
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import setproctitle
import torch.multiprocessing as mp
import time
from typing import List, Dict, Union, Callable
Expand All @@ -9,6 +10,7 @@
from ..trans_process_obj import KVTransProcess
from ..base_kv_move_manager import BaseKVMoveManager
from lightllm.utils.error_utils import log_exception
from lightllm.utils.envs_utils import get_unique_server_name

logger = init_logger(__name__)

Expand All @@ -28,6 +30,7 @@ def _init_env(args, info_queue: mp.Queue, event: mp.Event):

# 注册graceful 退出的处理
graceful_registry(inspect.currentframe().f_code.co_name)
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_prefill_kv_move_manager")

from .prefill_trans_process import start_prefill_trans_process

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import inspect
import threading
import setproctitle
import torch.multiprocessing as mp
import collections
import queue
Expand All @@ -15,6 +16,7 @@
from lightllm.server.core.objs import StartArgs
from ..nixl_kv_transporter import NixlKVTransporter
from lightllm.utils.error_utils import log_exception
from lightllm.utils.envs_utils import get_unique_server_name


logger = init_logger(__name__)
Expand All @@ -41,6 +43,7 @@ def _init_env(
task_out_queue: mp.Queue,
):
torch.backends.cudnn.enabled = False
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_prefill_trans:Device{device_id}")

try:
torch.cuda.set_device(device_id)
Expand Down
Loading
Loading