Reject requests on stale session or sleeping engine #4496
Changes from all commits: e0ceca3, 20883ea, 3fe7f01, d72fcf9, e29c31c, ea9aa7a, cbcdfa8, cf4a597
lmdeploy/serve/async_engine.py
```diff
@@ -46,7 +46,7 @@ class GenOut:
     history_token_len: int
     input_token_len: int
     generate_token_len: int
-    finish_reason: Literal['stop', 'length', 'error'] | None = None
+    finish_reason: Literal['stop', 'length', 'error', 'abort'] | None = None
     token_ids: list[int] | None = None
     logprobs: list[dict[int, float]] | None = None
     logits: Any = None
```
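Callers that branch on `finish_reason` need to handle the new `'abort'` value, which marks server-initiated cancellation (stale session, engine going to sleep) rather than model-side termination. A minimal consumer sketch; the `GenOut` stand-in below is hypothetical and carries only the fields the example uses:

```python
from dataclasses import dataclass
from typing import Literal, Optional

@dataclass
class GenOut:  # stand-in with only the fields used below
    response: str
    finish_reason: Optional[Literal['stop', 'length', 'error', 'abort']] = None

def classify(out: GenOut) -> str:
    # 'abort' means the server cancelled the request (stale session,
    # engine sleeping); it is distinct from a model-side 'error'.
    if out.finish_reason == 'abort':
        return 'cancelled by server'
    if out.finish_reason == 'error':
        return 'generation failed'
    return 'completed'

assert classify(GenOut(response='', finish_reason='abort')) == 'cancelled by server'
```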
```diff
@@ -201,6 +201,23 @@ def _build_stat_loggers(self):
         # set stats loggers of metrics processor
         metrics_processor.stat_loggers = self.stat_loggers

+    def _if_session_stale(self, session: Session,
+                          input_token_len: int) -> GenOut | None:
+        """If ``session.epoch`` was stamped by api_server and
+        ``stop_all_session`` ran since then (the engine epoch changed), drop
+        the session."""
+        epoch = session.epoch
+        if epoch is None or epoch == self.epoch:
+            return None
+        logger.info(f'[generate] drop stale session {session.session_id} '
+                    f'(session.epoch={epoch}, async_engine.epoch={self.epoch})')
+        return GenOut(response='',
+                      history_token_len=session.step,
+                      input_token_len=input_token_len,
+                      generate_token_len=0,
+                      finish_reason='abort',
+                      token_ids=[])
+
     async def get_schedule_metrics(self):
         result = self.engine.get_schedule_metrics()
         if asyncio.iscoroutine(result):
```
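The staleness rule reduces to one comparison: a session is stale only if it was stamped with an epoch and the engine epoch has moved since. A toy reproduction (the `Session` and `Engine` classes here are illustrative stand-ins, not the lmdeploy types):

```python
class Session:
    def __init__(self, epoch=None):
        self.epoch = epoch  # stamped by the api_server when the session is fetched

class Engine:
    def __init__(self):
        self.epoch = 0  # bumped by stop_all_session()

    def is_stale(self, session: Session) -> bool:
        # Stale only if stamped AND the engine epoch changed afterwards.
        return session.epoch is not None and session.epoch != self.epoch

engine = Engine()
stamped = Session(epoch=engine.epoch)  # admitted before the abort
engine.epoch += 1                      # stop_all_session() ran in between
assert engine.is_stale(stamped)        # must be rejected with 'abort'
assert not engine.is_stale(Session())  # unstamped sessions are never stale
```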
```diff
@@ -215,19 +232,24 @@ async def do_log_stats(self):
     async def stop_all_session(self):
         """Stop all running sessions."""
-        logger.info('stop all sessions')
+        logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}')
+        self.epoch += 1
         await self.session_mgr.async_abort_all()

-    def sleep(self, level: int = 1):
+    def prepare_sleep(self):
+        """Reject new inference requests before backend sleep starts."""
+        self.sleeping_tags = {'weights', 'kv_cache'}
+        self.is_sleeping = True
+
+    async def sleep(self, level: int = 1):
         """Sleep the model.

         Args:
             level (int): The sleep level. Level 1 sleep will offload the model
                 weights and discard the kv cache. Level 2 sleep will
                 discard both the model weights and the kv cache.
         """
-        self.engine.sleep(level)
+        await self.engine.sleep(level)
         self.sleeping_tags = {'weights', 'kv_cache'}
         self.is_sleeping = True
```
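Splitting `prepare_sleep` out of the now-awaitable `sleep` closes a race: the rejection flag is flipped synchronously before the slow backend offload starts, so requests arriving mid-sleep are rejected rather than racing the engine teardown. A toy asyncio sketch of the ordering (the `Engine` class is a stand-in, not the lmdeploy one):

```python
import asyncio

class Engine:
    def __init__(self):
        self.is_sleeping = False

    def prepare_sleep(self):
        self.is_sleeping = True  # reject new requests from this point on

    async def sleep(self):
        await asyncio.sleep(0.1)  # stands in for offloading weights / kv cache

    async def handle_request(self) -> str:
        return 'rejected' if self.is_sleeping else 'served'

async def main():
    eng = Engine()
    eng.prepare_sleep()                      # 1. flip the flag synchronously
    task = asyncio.create_task(eng.sleep())  # 2. start the slow backend sleep
    assert await eng.handle_request() == 'rejected'  # 3. safe even mid-sleep
    await task

asyncio.run(main())
```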
```diff
@@ -342,7 +364,8 @@ async def generate(
             do_preprocess (bool): whether pre-process the messages. Default to
                 True, which means chat_template will be applied.
         """
-        epoch = self.epoch
+        metrics_processor.increase_total_requests()

         if (messages is not None) ^ (input_ids is None):
             raise ValueError('You must specify exactly one of messages or input_ids')
         if isinstance(session_id, Session):
```
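Counting `increase_total_requests()` at the top of `generate` keeps the metrics consistent: every request is counted once on entry, and each early-return failure path below (see the two hunks that follow) now pairs it with an `increase_failed_requests(...)` call.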
```diff
@@ -389,6 +412,7 @@ async def generate(
         if gen_config.max_new_tokens == 0:
             logger.info(f'run out of tokens. session={session_id}.')
+            metrics_processor.increase_failed_requests('error')
             yield GenOut(response='',
                          history_token_len=session.step,
                          input_token_len=len(input_ids),
```
```diff
@@ -403,6 +427,7 @@ async def generate(
                 or gen_config.output_logits == 'all'):
             errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state '
                       'when prefix caching is ON')
+            metrics_processor.increase_failed_requests('error')
             yield GenOut(response=errmsg,
                          history_token_len=session.step,
                          input_token_len=len(input_ids),
```
```diff
@@ -424,10 +449,18 @@ def is_error(status):
         if not gen_config.ignore_eos:
             stop_ids = gen_config.stop_token_ids or []

-        metrics_processor.increase_total_requests()
-
+        stale = self._if_session_stale(session, len(prompt_input['input_ids']))
+        if stale is not None:
+            metrics_processor.increase_failed_requests('abort')
+            yield stale
+            if sequence_end:
+                self.session_mgr.remove(session)
+            return
+
         async with session.request_handle() as handle:
-            if epoch != self.epoch:
-                logger.info(f'[generate] session {session_id} got aborted before starting inference')
+            if session.epoch is not None and session.epoch != self.epoch:
+                logger.info(f'[generate] session {session_id} got aborted before starting inference, '
+                            f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}')
                 metrics_processor.increase_failed_requests('abort')
                 yield GenOut(response='',
                              history_token_len=0,
```
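Note that the epoch is checked twice by design: once via `_if_session_stale` before waiting on the session's request handle, and again after acquiring it, so a `stop_all_session` that lands during the `async with` wait still aborts the request before inference starts.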
lmdeploy/serve/openai/api_server.py
```diff
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import annotations
+
 # yapf: disable
 import asyncio
 import copy
```
```diff
@@ -10,7 +12,7 @@
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import Literal
+from typing import TYPE_CHECKING, Literal

 import uvicorn
 from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
```
```diff
@@ -76,10 +78,13 @@
 )
 from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
 from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
-from lmdeploy.serve.utils.server_utils import validate_json_request
+from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware, EngineSleepingMiddleware, validate_json_request
 from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 from lmdeploy.utils import get_logger

+if TYPE_CHECKING:
+    from lmdeploy.serve.managers import Session
+
 # yapf: enable

 logger = get_logger('lmdeploy')
```
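The `from __future__ import annotations` plus `TYPE_CHECKING` combination is the standard idiom for annotating with a type whose import would otherwise be circular or costly at runtime; shown in isolation:

```python
from __future__ import annotations  # annotations become lazily-evaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only type checkers execute this import; it never runs at runtime.
    from lmdeploy.serve.managers import Session

def get_session(session_id: int) -> Session:  # resolved as the string 'Session'
    ...
```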
```diff
@@ -100,12 +105,15 @@ class VariableInterface:
     enable_abort_handling: bool = False

     @staticmethod
-    def get_session(session_id: int) -> int:
+    def get_session(session_id: int) -> Session:
         session_mgr = VariableInterface.get_session_manager()
         if session_id == -1:
-            return session_mgr.get()
+            session = session_mgr.get()
         else:
-            return session_mgr.get(session_id)
+            session = session_mgr.get(session_id)
+        # Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``.
+        session.epoch = VariableInterface.async_engine.epoch
+        return session

     @staticmethod
     def get_session_manager():
```
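Stamping inside `get_session` covers both paths, auto-allocated (`session_id == -1`) and explicit-id sessions, so every request reaches `AsyncEngine.generate` carrying the epoch it was admitted under.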
```diff
@@ -769,7 +777,6 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None):
     error_check_ret = check_request(request)
     if error_check_ret is not None:
         return error_check_ret
-
     json_request = await raw_request.json()
     migration_request = json_request.pop('migration_request', None)
     with_cache = json_request.pop('with_cache', False)
```
```diff
@@ -963,6 +970,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
     error_check_ret = check_request(request)
     if error_check_ret is not None:
         return error_check_ret
+
     session = VariableInterface.get_session(request.session_id)

     prompt = request.prompt
```
```diff
@@ -1175,7 +1183,16 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
 @router.post('/sleep', dependencies=[Depends(validate_json_request)])
 async def sleep(raw_request: Request = None):
     level = raw_request.query_params.get('level', '1')
-    VariableInterface.async_engine.sleep(int(level))
+    try:
+        level = int(level)
+    except (TypeError, ValueError):
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be an integer.')
+    if level not in (1, 2):
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be 1 or 2.')
+    async_engine = VariableInterface.async_engine
+    async_engine.prepare_sleep()
+    await async_engine.stop_all_session()
+    await async_engine.sleep(level)

     return Response(status_code=200)
```
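Client-side, the hardened endpoint can be exercised as below (a sketch assuming a server on lmdeploy's default `localhost:23333`; the route and `level` parameter come from the diff above):

```python
import requests

base = 'http://localhost:23333'

# A valid level puts the engine to sleep and returns 200.
assert requests.post(f'{base}/sleep', params={'level': 2}).status_code == 200

# Non-integer or out-of-range levels are now rejected with 400
# instead of raising inside the handler.
assert requests.post(f'{base}/sleep', params={'level': 'abc'}).status_code == 400
assert requests.post(f'{base}/sleep', params={'level': 3}).status_code == 400
```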
```diff
@@ -1526,10 +1543,13 @@ def serve(model_path: str,
     )

     if api_keys is not None and (tokens := [key for key in api_keys if key]):
-        from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware
-
         app.add_middleware(AuthenticationMiddleware, tokens=tokens)

+    def is_engine_sleeping() -> bool:
+        eng = VariableInterface.async_engine
+        return eng is not None and eng.is_sleeping
+
+    app.add_middleware(EngineSleepingMiddleware, is_sleeping=is_engine_sleeping)
+
     # set the maximum number of concurrent requests
     if max_concurrent_requests is not None:
         app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)
```
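`EngineSleepingMiddleware` itself lives in `lmdeploy/serve/utils/server_utils.py` and is not part of this diff; a plausible minimal sketch of the contract it appears to implement (reject traffic while the engine sleeps, presumably keeping a wake-up route reachable) might look like:

```python
from typing import Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

class EngineSleepingMiddlewareSketch(BaseHTTPMiddleware):
    """Hypothetical reimplementation; the real class may differ."""

    def __init__(self, app, is_sleeping: Callable[[], bool]):
        super().__init__(app)
        self.is_sleeping = is_sleeping

    async def dispatch(self, request: Request, call_next):
        # Assumed: a wake-up endpoint must stay reachable while sleeping.
        if self.is_sleeping() and request.url.path != '/wakeup':
            return JSONResponse({'error': 'engine is sleeping'}, status_code=503)
        return await call_next(request)
```

Passing a callable rather than the engine object keeps the middleware decoupled from engine construction order, which is why `is_engine_sleeping` guards against `VariableInterface.async_engine` still being `None`.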