Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,9 @@ def update_params(self, request: Any):
"""Update params."""
self.executor.update_params(request)

async def sleep(self, level: int = 1):
    """Forward the sleep request to the executor.

    Args:
        level (int): Sleep level passed through to ``executor.sleep``.
            Defaults to 1.
    """
    await self.executor.sleep(level)

def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
Expand Down
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/executor/mp_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,14 @@ def warmup(self):
"""Build cache engine."""
self.collective_rpc('warmup')

async def sleep(self, level: int = 1):
    """Broadcast a 'sleep' RPC to all worker processes.

    Args:
        level (int): Sleep level forwarded to each worker's ``sleep``
            method. Defaults to 1.

    Note:
        ``return_mask=0`` presumably means no worker return value is
        collected — confirm against ``collective_rpc_async``.
    """
    await self.collective_rpc_async('sleep', args=(level, ), return_mask=0)

def wakeup(self, tags: list[str] | None = None):
    """Broadcast a 'wakeup' RPC to all worker processes.

    Args:
        tags (list[str] | None): Optional tags forwarded to each worker's
            ``wakeup``; ``None`` presumably wakes everything — confirm
            against the worker implementation.
    """
    self.collective_rpc('wakeup', args=(tags, ), return_mask=0)

async def _prefetch_outputs(self):
while True:
out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]
Expand Down
16 changes: 14 additions & 2 deletions lmdeploy/pytorch/engine/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,18 @@ def collective_rpc(self,
kwargs = dict()
return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout)

async def collective_rpc_async(self,
                               method: str,
                               args: tuple[Any, ...] | None = None,
                               kwargs: dict[str, Any] | None = None):
    """Invoke ``method`` on every Ray worker concurrently and await all results.

    Async counterpart of ``collective_rpc``: schedules the remote call on
    each worker, then gathers the results without blocking the event loop.

    Args:
        method (str): Name of the worker method to invoke remotely.
        args (tuple | None): Positional arguments forwarded to each worker.
        kwargs (dict | None): Keyword arguments forwarded to each worker.

    Returns:
        list: Per-worker results, in ``self.workers`` order.
    """
    # Normalize the None defaults (explicit `| None` annotations match the
    # file's convention, e.g. `tags: list[str] | None`); use immutable
    # literals instead of list()/dict() calls.
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    tasks = [getattr(worker, method).remote(*args, **kwargs) for worker in self.workers]
    return await asyncio.gather(*tasks)

def build_model(self):
"""Build model."""
self.collective_rpc('build_model')
Expand Down Expand Up @@ -353,9 +365,9 @@ def warmup(self):
"""Build cache engine."""
self.collective_rpc('warmup')

async def sleep(self, level: int = 1):
    """Asynchronously broadcast 'sleep' to all Ray workers.

    Args:
        level (int): Sleep level forwarded to each worker. Defaults to 1.
    """
    await self.collective_rpc_async('sleep', (level, ))

def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
Expand Down
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/executor/uni_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ async def get_output_async(self, dp_rank: int = 0):
assert dp_rank == 0
return await self.model_agent.get_output_async()

async def sleep(self, level: int = 1):
    """Delegate sleep to the single in-process model agent.

    Args:
        level (int): Sleep level forwarded to ``model_agent.sleep``.
            Defaults to 1.
    """
    await self.model_agent.sleep(level)

def wakeup(self, tags: list[str] | None = None):
    """Delegate wakeup to the single in-process model agent.

    Args:
        tags (list[str] | None): Optional tags forwarded to
            ``model_agent.wakeup``; semantics defined by the agent.
    """
    self.model_agent.wakeup(tags)

def get_input_processor(self):
"""Get input processor."""
return self.model_agent.get_input_processor()
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def end_session(self, session_id: int):
"""End session."""
return self._collective_rpc('end_session', session_id)

async def sleep(self, level: int):
    """Forward 'sleep' to the engine process over the async RPC channel.

    Args:
        level (int): Sleep level passed through to the remote engine.

    Returns:
        Whatever the remote ``sleep`` call returns.
    """
    return await self._collective_rpc_async('sleep', level)

def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/mp_engine/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
"""
return self.engine.p2p_drop_connect(drop_conn_request)

async def sleep(self, level: int = 1):
    """Delegate sleep to the wrapped engine.

    Args:
        level (int): Sleep level forwarded to ``engine.sleep``.
            Defaults to 1.

    Returns:
        Whatever ``engine.sleep`` returns.
    """
    return await self.engine.sleep(level)

def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
Expand Down
49 changes: 41 additions & 8 deletions lmdeploy/serve/core/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GenOut:
history_token_len: int
input_token_len: int
generate_token_len: int
finish_reason: Literal['stop', 'length', 'error'] | None = None
finish_reason: Literal['stop', 'length', 'error', 'abort'] | None = None
token_ids: list[int] | None = None
logprobs: list[dict[int, float]] | None = None
logits: Any = None
Expand Down Expand Up @@ -201,6 +201,23 @@ def _build_stat_loggers(self):
# set stats loggers of metrics processor
metrics_processor.stat_loggers = self.stat_loggers

def _if_session_stale(self, session: Session,
                      input_token_len: int) -> GenOut | None:
    """Detect a session made stale by an engine-epoch bump.

    ``session.epoch`` is stamped by api_server when a request binds the
    session; if ``stop_all_session`` has run since then (so the engine
    epoch changed), the session must be dropped.

    Args:
        session (Session): The session bound to the incoming request.
        input_token_len (int): Prompt length, echoed into the GenOut.

    Returns:
        GenOut | None: An 'abort' GenOut when the session is stale,
            otherwise ``None``.
    """
    stamped = session.epoch
    is_fresh = stamped is None or stamped == self.epoch
    if is_fresh:
        return None
    logger.info(f'[generate] drop stale session {session.session_id} '
                f'(session.epoch={stamped}, async_engine.epoch={self.epoch})')
    abort_out = GenOut(response='',
                       history_token_len=session.step,
                       input_token_len=input_token_len,
                       generate_token_len=0,
                       finish_reason='abort',
                       token_ids=[])
    return abort_out

async def get_schedule_metrics(self):
result = self.engine.get_schedule_metrics()
if asyncio.iscoroutine(result):
Expand All @@ -215,19 +232,24 @@ async def do_log_stats(self):

async def stop_all_session(self):
    """Abort every running session and bump the engine epoch.

    The epoch bump lets ``generate`` detect requests whose session was
    bound before this call and drop them as stale.
    """
    current = self.epoch
    logger.info(f'stop all sessions, epoch {current} -> {current + 1}')
    self.epoch = current + 1
    await self.session_mgr.async_abort_all()

def prepare_sleep(self):
    """Mark the engine as sleeping so new inference requests are rejected
    before the backend sleep actually starts."""
    self.is_sleeping = True
    self.sleeping_tags = {'weights', 'kv_cache'}

async def sleep(self, level: int = 1):
    """Put the backend engine to sleep.

    Args:
        level (int): The sleep level. Level 1 offloads the model weights
            and discards the kv cache; level 2 discards both the model
            weights and the kv cache. Defaults to 1.
    """
    # Flags are (re)stamped only after the backend acknowledges the sleep.
    await self.engine.sleep(level)
    self.is_sleeping = True
    self.sleeping_tags = {'weights', 'kv_cache'}

Expand Down Expand Up @@ -342,7 +364,8 @@ async def generate(
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
epoch = self.epoch
metrics_processor.increase_total_requests()

if (messages is not None) ^ (input_ids is None):
raise ValueError('You must specify exactly one of messages or input_ids')
if isinstance(session_id, Session):
Comment on lines +367 to 371
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metrics_processor.increase_total_requests() is now called before the input validation that can raise ValueError (e.g. the messages/input_ids XOR check). If a caller triggers these errors, total requests will be incremented without a corresponding failed-request metric, skewing metrics. Consider moving increase_total_requests() after validation (or ensuring validation errors are counted as failures).

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -389,6 +412,7 @@ async def generate(

if gen_config.max_new_tokens == 0:
logger.info(f'run out of tokens. session={session_id}.')
metrics_processor.increase_failed_requests('error')
yield GenOut(response='',
history_token_len=session.step,
input_token_len=len(input_ids),
Expand All @@ -403,6 +427,7 @@ async def generate(
or gen_config.output_logits == 'all'):
errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state '
'when prefix caching is ON')
metrics_processor.increase_failed_requests('error')
yield GenOut(response=errmsg,
history_token_len=session.step,
input_token_len=len(input_ids),
Expand All @@ -424,10 +449,18 @@ def is_error(status):
if not gen_config.ignore_eos:
stop_ids = gen_config.stop_token_ids or []

metrics_processor.increase_total_requests()

stale = self._if_session_stale(session, len(prompt_input['input_ids']))
if stale is not None:
metrics_processor.increase_failed_requests('abort')
yield stale
if sequence_end:
self.session_mgr.remove(session)
return
async with session.request_handle() as handle:
if epoch != self.epoch:
logger.info(f'[generate] session {session_id} got aborted before starting inference')
if session.epoch is not None and session.epoch != self.epoch:
logger.info(f'[generate] session {session_id} got aborted before starting inference, '
f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}')
metrics_processor.increase_failed_requests('abort')
yield GenOut(response='',
history_token_len=0,
Expand Down
9 changes: 7 additions & 2 deletions lmdeploy/serve/managers/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs):
self.history: list[tuple[Any, str]] = []
self.gen_config: GenerationConfig | None = None
self.step: int = 0
# Set by api_server to AsyncEngine.epoch when a request binds a session;
# generate() drops work if stop_all_session() bumped epoch after bind.
self.epoch: int | None = None
# event to wait for the session to be active
self._active: asyncio.Event | None = None
self._handle = None # inference instance
Expand Down Expand Up @@ -64,6 +67,7 @@ def reset(self):
self.history = []
self.gen_config = None
self.step = 0
self.epoch = None
self._active = None
self._handle = None
self._session_mgr = None
Expand Down Expand Up @@ -101,7 +105,7 @@ async def request_handle(self):

async def async_abort(self):
    """Abort this session, cancelling its in-flight request if any."""
    logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}')
    handle = self._handle
    if handle is not None:
        await handle.async_cancel(self.session_id)

Expand Down Expand Up @@ -205,13 +209,14 @@ def get(self, session_id: int | None = None, **kwargs) -> Session:
session.update(**kwargs)
return session
else:
logger.info(f'[SessionManager] session {session_id} not found. Creating...')
logger.debug(f'[SessionManager] session {session_id} not found. Creating...')
session = Session(session_id, self, **kwargs)
self.sessions[session_id] = session
return session

async def async_abort_all(self):
"""Abort all sessions."""
logger.info(f'[SessionManager] aborting all {len(self.sessions)} sessions')
tasks = []
for session in list(self.sessions.values()):
tasks.append(session.async_abort())
Expand Down
38 changes: 29 additions & 9 deletions lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import annotations

# yapf: disable
import asyncio
import copy
Expand All @@ -10,7 +12,7 @@
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import Literal
from typing import TYPE_CHECKING, Literal

import uvicorn
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
Expand Down Expand Up @@ -76,10 +78,13 @@
)
from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
from lmdeploy.serve.utils.server_utils import validate_json_request
from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware, EngineSleepingMiddleware, validate_json_request
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
from lmdeploy.utils import get_logger

if TYPE_CHECKING:
from lmdeploy.serve.managers import Session

# yapf: enable

logger = get_logger('lmdeploy')
Expand All @@ -100,12 +105,15 @@ class VariableInterface:
enable_abort_handling: bool = False

@staticmethod
def get_session(session_id: int) -> Session:
    """Fetch (or create) a session and stamp it with the current epoch.

    A ``session_id`` of -1 asks the manager for a fresh auto-allocated
    session; any other id fetches (or creates) that specific session.

    Returns:
        Session: The bound session, with ``epoch`` stamped.
    """
    session_mgr = VariableInterface.get_session_manager()
    session = session_mgr.get() if session_id == -1 else session_mgr.get(session_id)
    # Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``.
    session.epoch = VariableInterface.async_engine.epoch
    return session

@staticmethod
def get_session_manager():
Expand Down Expand Up @@ -769,7 +777,6 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
error_check_ret = check_request(request)
if error_check_ret is not None:
return error_check_ret

json_request = await raw_request.json()
migration_request = json_request.pop('migration_request', None)
with_cache = json_request.pop('with_cache', False)
Expand Down Expand Up @@ -963,6 +970,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
error_check_ret = check_request(request)
if error_check_ret is not None:
return error_check_ret

session = VariableInterface.get_session(request.session_id)

prompt = request.prompt
Expand Down Expand Up @@ -1175,7 +1183,16 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
@router.post('/sleep', dependencies=[Depends(validate_json_request)])
async def sleep(raw_request: Request = None):
    """Put the serving engine to sleep after draining all sessions.

    Validates the ``level`` query parameter (must be 1 or 2), rejects new
    requests, aborts in-flight sessions, then sleeps the engine.
    """
    raw_level = raw_request.query_params.get('level', '1')
    try:
        level = int(raw_level)
    except (TypeError, ValueError):
        return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be an integer.')
    if level not in (1, 2):
        return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be 1 or 2.')
    engine = VariableInterface.async_engine
    # Order matters: gate new requests first, then abort running sessions,
    # then let the backend actually sleep.
    engine.prepare_sleep()
    await engine.stop_all_session()
    await engine.sleep(level)
    return Response(status_code=200)


Expand Down Expand Up @@ -1526,10 +1543,13 @@ def serve(model_path: str,
)

if api_keys is not None and (tokens := [key for key in api_keys if key]):
from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware

app.add_middleware(AuthenticationMiddleware, tokens=tokens)

def is_engine_sleeping() -> bool:
eng = VariableInterface.async_engine
return eng is not None and eng.is_sleeping
app.add_middleware(EngineSleepingMiddleware, is_sleeping=is_engine_sleeping)

# set the maximum number of concurrent requests
if max_concurrent_requests is not None:
app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)
Comment on lines +1548 to 1555
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Middleware ordering: EngineSleepingMiddleware is added before ConcurrencyLimitMiddleware, and Starlette middleware stacking makes the last added middleware outermost. That means sleeping inference requests still acquire a concurrency semaphore slot before being rejected with 503, which can unnecessarily block other endpoints (including /wakeup) under load. Consider adding EngineSleepingMiddleware after the concurrency limiter (or implementing the sleep gate inside the concurrency middleware) so rejections happen before acquiring the semaphore.

Copilot uses AI. Check for mistakes.
Expand Down
Loading
Loading