From c0a0e42b4ab52f147d6b8e828cd0cc79770868b8 Mon Sep 17 00:00:00 2001
From: Raphael Glon
Date: Fri, 11 Apr 2025 11:04:07 +0200
Subject: [PATCH 1/4] feat(hf-inference): fork for hf-inference optim (overcommit) and widget compat

* Env var settings: customize
  - default num inference steps env var
  - default content type env var
  - default accept env var
  Diffusers, txt2img (and img2img when supported): make sure guidance scale
  defaults to 0 when num steps <= 4

* Content-type / accept / serialization fixes:
  - ignore content-type case
  - fix: content-type and accept parsing, more flexibility than an exact
    string match since there can be some additional params
  - application/octet-stream support in content-type deserialization, no
    reason not to accept it
  - fix: avoid returning None as a serializer, return an error instead
  - fix: the de/serializer is not optional, do not support a content type
    which we do not know what to do with
  - fix: explicit error message when no content-type is provided

* HF inference specificities
  - multi-task support + /pipeline/ support for api-inference backward compat
  - api-inference compat responses
  - fix(api-inference): compat for text-classification and token-classification
  - fix: token-classification api-inference compat
  - fix: image segmentation on hf-inference
  - zero-shot classification: api-inference compat
  - substitute /pipeline/sentence-embeddings for /pipeline/feature-extraction
    for sentence transformers
  - fix(api-inference): feature-extraction, flatten the array, discard the
    batch size dim
  - feat(hf-inference): disable custom handler

* Build: add timm and hf_xet dependencies (for object detection and xethub
  support)
  Dockerfile refactor: split requirements and source code layers to optimize
  build time and enhance layer reuse

* Memory footprint + kick and respawn (primary memory gc)
  feat(memory): reduce memory footprint on idle service
  backported and adapted from
  https://github.com/huggingface/api-inference-community/blob/main/docker_images/diffusers/app/idle.py
  1. add gunicorn instead of uvicorn to allow wsgi/asgi workers to easily be
     suppressed when idle without stopping the entire service -> an easy way
     to release memory without digging into the depths of the imported modules
  2. memory-consuming libs lazy load (transformers, diffusers,
     sentence_transformers)
  3.
pipeline lazy load as well The first 'cold start' request tends to be a bit slower than others but the footprint is reduced to the minimum when idle --- dockerfiles/pytorch/Dockerfile | 8 +- requirements.txt | 22 ++ scripts/entrypoint.sh | 2 +- setup.py | 49 ++--- .../diffusers_utils.py | 11 + .../env_utils.py | 11 + src/huggingface_inference_toolkit/handler.py | 99 ++++++++- .../heavy_utils.py | 187 +++++++++++++++++ src/huggingface_inference_toolkit/idle.py | 58 ++++++ .../sentence_transformers_utils.py | 12 +- .../serialization/base.py | 44 ++-- src/huggingface_inference_toolkit/utils.py | 188 +----------------- .../webservice_starlette.py | 108 +++++++--- test-requirements.txt | 12 ++ tests/integ/conftest.py | 4 +- tests/integ/helpers.py | 12 +- tests/unit/test_diffusers.py | 8 +- tests/unit/test_handler.py | 20 +- tests/unit/test_optimum_utils.py | 8 +- tests/unit/test_sentence_transformers.py | 20 +- tests/unit/test_utils.py | 24 +-- 21 files changed, 584 insertions(+), 323 deletions(-) create mode 100644 requirements.txt create mode 100644 src/huggingface_inference_toolkit/heavy_utils.py create mode 100644 src/huggingface_inference_toolkit/idle.py create mode 100644 test-requirements.txt diff --git a/dockerfiles/pytorch/Dockerfile b/dockerfiles/pytorch/Dockerfile index 12cb541e..d63ab519 100644 --- a/dockerfiles/pytorch/Dockerfile +++ b/dockerfiles/pytorch/Dockerfile @@ -32,8 +32,7 @@ RUN apt-get update && \ && apt-get clean autoremove --yes \ && rm -rf /var/lib/{apt,dpkg,cache,log} -# Copying only necessary files as filtered by .dockerignore -COPY . . +RUN mkdir -p /var/lib/dpkg && touch /var/lib/dpkg/status # Set Python 3.11 as the default python version RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 && \ @@ -47,6 +46,11 @@ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ # Upgrade pip RUN pip install --no-cache-dir --upgrade pip +COPY requirements.txt . +RUN pip install -r requirements.txt && rm -rf /root/.cache + +# Copying only necessary files as filtered by .dockerignore +COPY . . # Install wheel and setuptools RUN pip install --no-cache-dir --upgrade pip ".[torch,st,diffusers]" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..8acfe504 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +kenlm@ git+https://github.com/kpu/kenlm@ba83eafdce6553addd885ed3da461bb0d60f8df7 +transformers[audio,sentencepiece,sklearn,vision]==4.51.3 +huggingface_hub[hf_transfer,hf_xet]==0.31.1 +Pillow +librosa +pyctcdecode>=0.3.0 +phonemizer +ffmpeg +starlette +uvicorn +gunicorn +pandas +orjson +einops +timm +sentence_transformers==4.0.2 +diffusers==0.33.1 +accelerate==1.6.0 +torch==2.5.1 +torchvision +torchaudio +peft==0.15.1 diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh index 68969353..20aedf9f 100755 --- a/scripts/entrypoint.sh +++ b/scripts/entrypoint.sh @@ -59,4 +59,4 @@ if [[ ! 
-z "${HF_MODEL_DIR}" ]]; then fi # Start the server -exec uvicorn webservice_starlette:app --host 0.0.0.0 --port ${PORT} +exec gunicorn webservice_starlette:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-1} --bind 0.0.0.0:${PORT} diff --git a/setup.py b/setup.py index 7ad7d6a0..3676b95e 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,19 @@ from __future__ import absolute_import - +import os from setuptools import find_packages, setup +lib_folder = os.path.dirname(os.path.realpath(__file__)) +requirements_path = f"{lib_folder}/requirements.txt" +install_requires = [] # Here we'll add: ["gunicorn", "docutils>=0.3", "lxml==0.5a7"] +if os.path.isfile(requirements_path): + with open(requirements_path) as f: + install_requires = f.read().splitlines() + +test_requirements_path = f"{lib_folder}/test-requirements.txt" +if os.path.isfile(test_requirements_path): + with open(test_requirements_path) as f: + test_requirements = f.read().splitlines() + # We don't declare our dependency on transformers here because we build with # different packages for different variants @@ -12,47 +24,14 @@ # ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg # libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg -install_requires = [ - # Due to an error affecting kenlm and cmake (see https://github.com/kpu/kenlm/pull/464) - # Also see the transformers patch for it https://github.com/huggingface/transformers/pull/37091 - "kenlm@git+https://github.com/kpu/kenlm@ba83eafdce6553addd885ed3da461bb0d60f8df7", - "transformers[sklearn,sentencepiece,audio,vision]==4.51.3", - "huggingface_hub[hf_transfer]==0.30.2", - # vision - "Pillow", - "librosa", - # speech + torchaudio - "pyctcdecode>=0.3.0", - "phonemizer", - "ffmpeg", - # web api - "starlette", - "uvicorn", - "pandas", - "orjson", - "einops", -] - extras = {} - extras["st"] = ["sentence_transformers==4.0.2"] extras["diffusers"] = ["diffusers==0.33.1", "accelerate==1.6.0"] # Includes `peft` as PEFT requires `torch` so having `peft` as a core dependency # means that `torch` will be installed even if the `torch` extra is not specified. extras["torch"] = ["torch==2.5.1", "torchvision", "torchaudio", "peft==0.15.1"] -extras["test"] = [ - "pytest==7.2.1", - "pytest-xdist", - "parameterized", - "psutil", - "datasets", - "pytest-sugar", - "mock==2.0.0", - "docker", - "requests", - "tenacity", -] extras["quality"] = ["isort", "ruff"] +extras["test"] = test_requirements extras["inf2"] = ["optimum-neuron"] extras["google"] = ["google-cloud-storage", "crcmod==1.7"] diff --git a/src/huggingface_inference_toolkit/diffusers_utils.py b/src/huggingface_inference_toolkit/diffusers_utils.py index 61886659..f10cadc0 100644 --- a/src/huggingface_inference_toolkit/diffusers_utils.py +++ b/src/huggingface_inference_toolkit/diffusers_utils.py @@ -1,4 +1,5 @@ import importlib.util +import os from typing import Union from transformers.utils.import_utils import is_torch_bf16_gpu_available @@ -63,6 +64,16 @@ def __call__( kwargs.pop("num_images_per_prompt") logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. 
Using default value 1.") + if "num_inference_steps" not in kwargs: + default_num_steps = os.environ.get("DEFAULT_NUM_INFERENCE_STEPS") + if default_num_steps: + kwargs["num_inference_steps"] = int(default_num_steps) + + if "guidance_scale" not in kwargs: + guidance_scale = os.environ.get("DEFAULT_GUIDANCE_SCALE") + if guidance_scale is not None: + kwargs["guidance_scale"] = float(guidance_scale) + if "target_size" in kwargs: kwargs["height"] = kwargs["target_size"].pop("height") kwargs["width"] = kwargs["target_size"].pop("width") diff --git a/src/huggingface_inference_toolkit/env_utils.py b/src/huggingface_inference_toolkit/env_utils.py index e582ec98..fd2b9da7 100644 --- a/src/huggingface_inference_toolkit/env_utils.py +++ b/src/huggingface_inference_toolkit/env_utils.py @@ -1,3 +1,6 @@ +import os + + def strtobool(val: str) -> bool: """Convert a string representation of truth to True or False booleans. True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values @@ -20,3 +23,11 @@ def strtobool(val: str) -> bool: raise ValueError( f"Invalid truth value, it should be a string but {val} was provided instead." ) + + +def api_inference_compat(): + return strtobool(os.getenv("API_INFERENCE_COMPAT", "false")) + + +def ignore_custom_handler(): + return strtobool(os.getenv("IGNORE_CUSTOM_HANDLER", "false")) diff --git a/src/huggingface_inference_toolkit/handler.py b/src/huggingface_inference_toolkit/handler.py index 43a979bd..3ff104d5 100644 --- a/src/huggingface_inference_toolkit/handler.py +++ b/src/huggingface_inference_toolkit/handler.py @@ -2,12 +2,10 @@ from pathlib import Path from typing import Any, Dict, Literal, Optional, Union +from huggingface_inference_toolkit import logging from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE -from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS -from huggingface_inference_toolkit.utils import ( - check_and_register_custom_pipeline_from_directory, - get_pipeline, -) +from huggingface_inference_toolkit.env_utils import api_inference_compat, ignore_custom_handler +from huggingface_inference_toolkit.utils import check_and_register_custom_pipeline_from_directory class HuggingFaceHandler: @@ -19,6 +17,7 @@ class HuggingFaceHandler: def __init__( self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt" ) -> None: + from huggingface_inference_toolkit.heavy_utils import get_pipeline self.pipeline = get_pipeline( model_dir=model_dir, # type: ignore task=task, # type: ignore @@ -26,13 +25,17 @@ def __init__( trust_remote_code=HF_TRUST_REMOTE_CODE, ) - def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, data: Dict[str, Any]): """ Handles an inference request with input data and makes a prediction. Args: :data: (obj): the raw request body data. :return: prediction output """ + + # import as late as possible to reduce the footprint + from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS + inputs = data.pop("inputs", data) parameters = data.pop("parameters", {}) @@ -101,9 +104,82 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: "or `candidateLabels`." 
) - return ( - self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters) # type: ignore - ) + if api_inference_compat(): + if self.pipeline.task == "text-classification" and isinstance(inputs, str): + inputs = [inputs] + parameters.setdefault("top_k", os.environ.get("DEFAULT_TOP_K", 5)) + if self.pipeline.task == "token-classification": + parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple")) + + resp = self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else \ + self.pipeline(inputs, **parameters) + + if api_inference_compat(): + if self.pipeline.task == "text-classification": + # We don't want to return {} but [{}] in any case + if isinstance(resp, list) and len(resp) > 0: + if not isinstance(resp[0], list): + return [resp] + return resp + if self.pipeline.task == "feature-extraction": + # If the library used is Transformers then the feature-extraction is returning the headless encoder + # outputs as embeddings. The shape is a 3D or 4D array + # [n_inputs, batch_size = 1, n_sentence_tokens, num_hidden_dim]. + # Let's just discard the batch size dim that always seems to be 1 and return a 2D/3D array + # https://github.com/huggingface/transformers/blob/5c47d08b0d6835b8d8fc1c06d9a1bc71f6e78ace/src/transformers/pipelines/feature_extraction.py#L27 + # for api inference (reason: mainly display) + new_resp = [] + if isinstance(inputs, list): + if isinstance(resp, list) and len(resp) == len(inputs): + for it in resp: + # Batch size dim is the first it level, discard it + if isinstance(it, list) and len(it) == 1: + new_resp.append(it[0]) + else: + logging.logger.warning("One of the output batch size differs from 1: %d", len(it)) + return resp + return new_resp + else: + logging.logger.warning("Inputs and resp len differ (or resp is not a list, type %s)", + type(resp)) + return resp + elif isinstance(inputs, str): + if isinstance(resp, list) and len(resp) == 1: + return resp[0] + else: + logging.logger.warning("The output batch size differs from 1: %d", len(resp)) + return resp + else: + logging.logger.warning("Output unexpected type %s", type(resp)) + return resp + if self.pipeline.task == "image-segmentation": + if isinstance(resp, list): + new_resp = [] + for el in resp: + if isinstance(el, dict) and el.get("score") is None: + el["score"] = 1 + new_resp.append(el) + resp = new_resp + if self.pipeline.task == "zero-shot-classification": + try: + if isinstance(resp, dict): + if 'labels' in resp and 'scores' in resp: + labels = resp['labels'] + scores = resp['scores'] + if len(labels) == len(scores): + new_resp = [] + for label, score in zip(labels, scores): + new_resp.append({"label": label, "score": score}) + resp = new_resp + else: + raise Exception("labels and scores do not have the same len, {} != {}".format( + len(labels), len(scores))) + else: + raise Exception("Missing labels or scores key in response dict {}".format(resp)) + except Exception as e: + logging.logger.warning("Unable to remap response for api inference compat") + logging.logger.exception(e) + return resp class VertexAIHandler(HuggingFaceHandler): @@ -149,7 +225,10 @@ def get_inference_handler_either_custom_or_default_handler(model_dir: Path, task Returns: InferenceHandler: The appropriate inference handler based on the given model directory and task. 
""" - custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir) + if ignore_custom_handler(): + custom_pipeline = None + else: + custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir) if custom_pipeline is not None: return custom_pipeline diff --git a/src/huggingface_inference_toolkit/heavy_utils.py b/src/huggingface_inference_toolkit/heavy_utils.py new file mode 100644 index 00000000..c144082e --- /dev/null +++ b/src/huggingface_inference_toolkit/heavy_utils.py @@ -0,0 +1,187 @@ +# Heavy because they consume a lot of memory and we want to import them as late as possible to reduce the footprint +# Transformers / Sentence transformers utils. This module should be imported as late as possible +# to reduce the memory footprint of a worker: we don't bother handling the uncaching/gc collecting because +# we want to combine it with idle unload: the gunicorn worker will just suppress itself when unused freeing the memory +# as wished +from pathlib import Path +from typing import Optional, Union + +from huggingface_hub import HfApi, login, snapshot_download +from transformers import WhisperForConditionalGeneration, pipeline +from transformers.file_utils import is_tf_available, is_torch_available +from transformers.pipelines import Pipeline + +from huggingface_inference_toolkit.diffusers_utils import ( + get_diffusers_pipeline, + is_diffusers_available, +) +from huggingface_inference_toolkit.logging import logger +from huggingface_inference_toolkit.optimum_utils import ( + get_optimum_neuron_pipeline, + is_optimum_neuron_available, +) +from huggingface_inference_toolkit.sentence_transformers_utils import ( + get_sentence_transformers_pipeline, + is_sentence_transformers_available, +) +from huggingface_inference_toolkit.utils import create_artifact_filter + + +def load_repository_from_hf( + repository_id: Optional[str] = None, + target_dir: Optional[Union[str, Path]] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + hf_hub_token: Optional[str] = None, +): + """ + Load a model from huggingface hub. + """ + + if hf_hub_token is not None: + login(token=hf_hub_token) + + if framework is None: + framework = _get_framework() + + if isinstance(target_dir, str): + target_dir = Path(target_dir) + + # create workdir + if not target_dir.exists(): + target_dir.mkdir(parents=True) + + # check if safetensors weights are available + if framework == "pytorch": + files = HfApi().model_info(repository_id).siblings + if any(f.rfilename.endswith("safetensors") for f in files): + framework = "safetensors" + + # create regex to only include the framework specific weights + ignore_regex = create_artifact_filter(framework) + logger.info(f"Ignore regex pattern for files, which are not downloaded: {', '.join(ignore_regex)}") + + # Download the repository to the workdir and filter out non-framework + # specific weights + snapshot_download( + repo_id=repository_id, + revision=revision, + local_dir=str(target_dir), + local_dir_use_symlinks=False, + ignore_patterns=ignore_regex, + ) + + return target_dir + + +def get_device(): + """ + The get device function will return the device for the DL Framework. + """ + gpu = _is_gpu_available() + + if gpu: + return 0 + else: + return -1 + + +if is_tf_available(): + import tensorflow as tf + + +if is_torch_available(): + import torch + + +def _is_gpu_available(): + """ + checks if a gpu is available. 
+ """ + if is_tf_available(): + return True if len(tf.config.list_physical_devices("GPU")) > 0 else False + elif is_torch_available(): + return torch.cuda.is_available() + else: + raise RuntimeError( + "At least one of TensorFlow 2.0 or PyTorch should be installed. " + "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " + "To install PyTorch, read the instructions at https://pytorch.org/." + ) + + +def _get_framework(): + """ + extracts which DL framework is used for inference, if both are installed use pytorch + """ + + if is_torch_available(): + return "pytorch" + elif is_tf_available(): + return "tensorflow" + else: + raise RuntimeError( + "At least one of TensorFlow 2.0 or PyTorch should be installed. " + "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " + "To install PyTorch, read the instructions at https://pytorch.org/." + ) + + +def get_pipeline( + task: Union[str, None], + model_dir: Path, + **kwargs, +) -> Pipeline: + """ + create pipeline class for a specific task based on local saved model + """ + if task is None: + raise EnvironmentError( + "The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined" + ) + + if task == "conversational": + task = "text-generation" + + if is_optimum_neuron_available(): + logger.info("Using device Neuron") + return get_optimum_neuron_pipeline(task=task, model_dir=model_dir) + + device = get_device() + logger.info(f"Using device {'GPU' if device == 0 else 'CPU'}") + + # define tokenizer or feature extractor as kwargs to load it the pipeline + # correctly + if task in { + "automatic-speech-recognition", + "image-segmentation", + "image-classification", + "audio-classification", + "object-detection", + "zero-shot-image-classification", + }: + kwargs["feature_extractor"] = model_dir + elif task not in {"image-text-to-text", "image-to-text", "text-to-image"}: + kwargs["tokenizer"] = model_dir + + if is_sentence_transformers_available() and task in [ + "sentence-similarity", + "sentence-embeddings", + "sentence-ranking", + "text-ranking", + ]: + if task == "text-ranking": + task = "sentence-ranking" + hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs) + elif is_diffusers_available() and task == "text-to-image": + hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs) + else: + hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs) + + if task == "automatic-speech-recognition" and isinstance(hf_pipeline.model, WhisperForConditionalGeneration): + # set chunk length to 30s for whisper to enable long audio files + hf_pipeline._preprocess_params["chunk_length_s"] = 30 + hf_pipeline.model.config.forced_decoder_ids = hf_pipeline.tokenizer.get_decoder_prompt_ids( + language="english", task="transcribe" + ) + return hf_pipeline # type: ignore \ No newline at end of file diff --git a/src/huggingface_inference_toolkit/idle.py b/src/huggingface_inference_toolkit/idle.py new file mode 100644 index 00000000..a861d527 --- /dev/null +++ b/src/huggingface_inference_toolkit/idle.py @@ -0,0 +1,58 @@ +import asyncio +import contextlib +import logging +import os +import signal +import time + +LOG = logging.getLogger(__name__) + +LAST_START = None +LAST_END = None + +UNLOAD_IDLE = os.getenv("UNLOAD_IDLE", "").lower() in ("1", "true") +IDLE_TIMEOUT = int(os.getenv("IDLE_TIMEOUT", 15)) + + +async def 
live_check_loop(): + global LAST_START, LAST_END + + pid = os.getpid() + + LOG.debug("Starting live check loop") + sleep_time = max(int(IDLE_TIMEOUT // 5), 1) + + while True: + await asyncio.sleep(sleep_time) + LOG.debug("Checking whether we should unload anything from gpu") + + last_start = LAST_START + last_end = LAST_END + + LOG.debug("Checking pid %d activity", pid) + if not last_start: + continue + if not last_end or last_start >= last_end: + LOG.debug("Request likely being processed for pid %d", pid) + continue + now = time.time() + last_request_age = now - last_end + LOG.debug("Pid %d, last request age %s", pid, last_request_age) + if last_request_age < IDLE_TIMEOUT: + LOG.debug("Model recently active") + else: + LOG.debug("Inactive for too long. Leaving live check loop") + break + LOG.debug("Aborting this worker") + os.kill(pid, signal.SIGTERM) + + +@contextlib.contextmanager +def request_witnesses(): + global LAST_START, LAST_END + # Simple assignment, concurrency safe, no need for any lock + LAST_START = time.time() + try: + yield + finally: + LAST_END = time.time() diff --git a/src/huggingface_inference_toolkit/sentence_transformers_utils.py b/src/huggingface_inference_toolkit/sentence_transformers_utils.py index 0d648420..5857a940 100644 --- a/src/huggingface_inference_toolkit/sentence_transformers_utils.py +++ b/src/huggingface_inference_toolkit/sentence_transformers_utils.py @@ -1,6 +1,8 @@ import importlib.util from typing import Any, Dict, List, Tuple, Union +from huggingface_inference_toolkit.env_utils import api_inference_compat + try: from typing import Literal except ImportError: @@ -26,7 +28,10 @@ def __call__(self, source_sentence: str, sentences: List[str]) -> Dict[str, floa embeddings1 = self.model.encode(source_sentence, convert_to_tensor=True) embeddings2 = self.model.encode(sentences, convert_to_tensor=True) similarities = util.pytorch_cos_sim(embeddings1, embeddings2).tolist()[0] - return {"similarities": similarities} + if api_inference_compat(): + return similarities + else: + return {"similarities": similarities} class SentenceEmbeddingPipeline: @@ -36,7 +41,10 @@ def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: An def __call__(self, sentences: Union[str, List[str]]) -> Dict[str, List[float]]: embeddings = self.model.encode(sentences).tolist() - return {"embeddings": embeddings} + if api_inference_compat(): + return embeddings + else: + return {"embeddings": embeddings} class SentenceRankingPipeline: diff --git a/src/huggingface_inference_toolkit/serialization/base.py b/src/huggingface_inference_toolkit/serialization/base.py index 8be40173..c6e58075 100644 --- a/src/huggingface_inference_toolkit/serialization/base.py +++ b/src/huggingface_inference_toolkit/serialization/base.py @@ -4,9 +4,6 @@ content_type_mapping = { "application/json": Jsoner, - "application/json; charset=UTF-8": Jsoner, - "text/csv": None, - "text/plain": None, # image types "image/png": Imager, "image/jpeg": Imager, @@ -39,25 +36,44 @@ class ContentType: @staticmethod - def get_deserializer(content_type): - if content_type in content_type_mapping: - return content_type_mapping[content_type] - else: + def get_deserializer(content_type: str, task: str): + if not content_type: + message = f"No content type provided and no default one configured." 
+ raise Exception(message) + if content_type.lower().startswith("application/octet-stream"): + if "audio" in task or "speech" in task: + return Audioer + elif "image" in task: + return Imager message = f""" - Content type "{content_type}" not supported. + Content type "{content_type}" not supported for task {task}. Supported content types are: {", ".join(list(content_type_mapping.keys()))} """ raise Exception(message) - @staticmethod - def get_serializer(accept): - if accept in content_type_mapping: - return content_type_mapping[accept] + # Extract media type from content type + extracted = content_type.split(";")[0] + if extracted in content_type_mapping: + return content_type_mapping[extracted] else: message = f""" - Accept type "{accept}" not supported. - Supported accept types are: + Content type "{content_type}" not supported. + Supported content types are: {", ".join(list(content_type_mapping.keys()))} """ raise Exception(message) + + @staticmethod + def get_serializer(accept: str): + extracts = accept.split(",") + for extract in extracts: + extracted = extract.split(";")[0] + if extracted in content_type_mapping: + return content_type_mapping[extracted] + message = f""" + Accept type "{accept}" not supported. + Supported accept types are: + {", ".join(list(content_type_mapping.keys()))} + """ + raise Exception(message) diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py index 2b4cd394..1e96b5e1 100644 --- a/src/huggingface_inference_toolkit/utils.py +++ b/src/huggingface_inference_toolkit/utils.py @@ -1,33 +1,9 @@ import importlib.util import sys from pathlib import Path -from typing import Optional, Union - -from huggingface_hub import HfApi, login, snapshot_download -from transformers import WhisperForConditionalGeneration, pipeline -from transformers.file_utils import is_tf_available, is_torch_available -from transformers.pipelines import Pipeline from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME -from huggingface_inference_toolkit.diffusers_utils import ( - get_diffusers_pipeline, - is_diffusers_available, -) from huggingface_inference_toolkit.logging import logger -from huggingface_inference_toolkit.optimum_utils import ( - get_optimum_neuron_pipeline, - is_optimum_neuron_available, -) -from huggingface_inference_toolkit.sentence_transformers_utils import ( - get_sentence_transformers_pipeline, - is_sentence_transformers_available, -) - -if is_tf_available(): - import tensorflow as tf - -if is_torch_available(): - import torch _optimum_available = importlib.util.find_spec("optimum") is not None @@ -57,7 +33,7 @@ def is_optimum_available(): def create_artifact_filter(framework): """ - Returns a list of regex pattern based on the DL Framework. which will be to used to ignore files when downloading + Returns a list of regex pattern based on the DL Framework. which will be used to ignore files when downloading """ ignore_regex_list = list(set(framework2weight.values())) @@ -69,86 +45,6 @@ def create_artifact_filter(framework): return [] -def _is_gpu_available(): - """ - checks if a gpu is available. - """ - if is_tf_available(): - return True if len(tf.config.list_physical_devices("GPU")) > 0 else False - elif is_torch_available(): - return torch.cuda.is_available() - else: - raise RuntimeError( - "At least one of TensorFlow 2.0 or PyTorch should be installed. 
" - "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " - "To install PyTorch, read the instructions at https://pytorch.org/." - ) - - -def _get_framework(): - """ - extracts which DL framework is used for inference, if both are installed use pytorch - """ - - if is_torch_available(): - return "pytorch" - elif is_tf_available(): - return "tensorflow" - else: - raise RuntimeError( - "At least one of TensorFlow 2.0 or PyTorch should be installed. " - "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " - "To install PyTorch, read the instructions at https://pytorch.org/." - ) - - -def _load_repository_from_hf( - repository_id: Optional[str] = None, - target_dir: Optional[Union[str, Path]] = None, - framework: Optional[str] = None, - revision: Optional[str] = None, - hf_hub_token: Optional[str] = None, -): - """ - Load a model from huggingface hub. - """ - - if hf_hub_token is not None: - login(token=hf_hub_token) - - if framework is None: - framework = _get_framework() - - if isinstance(target_dir, str): - target_dir = Path(target_dir) - - # create workdir - if not target_dir.exists(): - target_dir.mkdir(parents=True) - - # check if safetensors weights are available - if framework == "pytorch": - files = HfApi().model_info(repository_id).siblings - if any(f.rfilename.endswith("safetensors") for f in files): - framework = "safetensors" - - # create regex to only include the framework specific weights - ignore_regex = create_artifact_filter(framework) - logger.info(f"Ignore regex pattern for files, which are not downloaded: {', '.join(ignore_regex)}") - - # Download the repository to the workdir and filter out non-framework - # specific weights - snapshot_download( - repo_id=repository_id, - revision=revision, - local_dir=str(target_dir), - local_dir_use_symlinks=False, - ignore_patterns=ignore_regex, - ) - - return target_dir - - def check_and_register_custom_pipeline_from_directory(model_dir): """ Checks if a custom pipeline is available and registers it if so. @@ -156,11 +52,12 @@ def check_and_register_custom_pipeline_from_directory(model_dir): # path to custom handler custom_module = Path(model_dir).joinpath(HF_DEFAULT_PIPELINE_NAME) legacy_module = Path(model_dir).joinpath("pipeline.py") + custom_pipeline = None if custom_module.is_file(): logger.info(f"Found custom pipeline at {custom_module}") spec = importlib.util.spec_from_file_location(HF_MODULE_NAME, custom_module) if spec: - # add the whole directory to path for submodlues + # add the whole directory to path for submodules sys.path.insert(0, model_dir) # import custom handler handler = importlib.util.module_from_spec(spec) @@ -168,7 +65,8 @@ def check_and_register_custom_pipeline_from_directory(model_dir): spec.loader.exec_module(handler) # init custom handler with model_dir custom_pipeline = handler.EndpointHandler(model_dir) - + else: + logger.info(f"No spec from file location found for module %s, file %s", HF_MODULE_NAME, custom_module) elif legacy_module.is_file(): logger.warning( """You are using a legacy custom pipeline. 
@@ -177,7 +75,7 @@ def check_and_register_custom_pipeline_from_directory(model_dir): ) spec = importlib.util.spec_from_file_location("pipeline.PreTrainedPipeline", legacy_module) if spec: - # add the whole directory to path for submodlues + # add the whole directory to path for submodules sys.path.insert(0, model_dir) # import custom handler pipeline = importlib.util.module_from_spec(spec) @@ -187,80 +85,8 @@ def check_and_register_custom_pipeline_from_directory(model_dir): custom_pipeline = pipeline.PreTrainedPipeline(model_dir) else: logger.info(f"No custom pipeline found at {custom_module}") - custom_pipeline = None - return custom_pipeline - - -def get_device(): - """ - The get device function will return the device for the DL Framework. - """ - gpu = _is_gpu_available() - - if gpu: - return 0 - else: - return -1 - -def get_pipeline( - task: Union[str, None], - model_dir: Path, - **kwargs, -) -> Pipeline: - """ - create pipeline class for a specific task based on local saved model - """ - if task is None: - raise EnvironmentError( - "The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined" - ) - - if task == "conversational": - task = "text-generation" - - if is_optimum_neuron_available(): - logger.info("Using device Neuron") - return get_optimum_neuron_pipeline(task=task, model_dir=model_dir) - - device = get_device() - logger.info(f"Using device {'GPU' if device == 0 else 'CPU'}") - - # define tokenizer or feature extractor as kwargs to load it the pipeline - # correctly - if task in { - "automatic-speech-recognition", - "image-segmentation", - "image-classification", - "audio-classification", - "object-detection", - "zero-shot-image-classification", - }: - kwargs["feature_extractor"] = model_dir - elif task not in {"image-text-to-text", "image-to-text", "text-to-image"}: - kwargs["tokenizer"] = model_dir - - if is_sentence_transformers_available() and task in [ - "sentence-similarity", - "sentence-embeddings", - "sentence-ranking", - "text-ranking", - ]: - if task == "text-ranking": - task = "sentence-ranking" - hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs) - elif is_diffusers_available() and task == "text-to-image": - hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs) - else: - hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs) - - if task == "automatic-speech-recognition" and isinstance(hf_pipeline.model, WhisperForConditionalGeneration): - # set chunk length to 30s for whisper to enable long audio files - hf_pipeline._preprocess_params["chunk_length_s"] = 30 - hf_pipeline.model.config.forced_decoder_ids = hf_pipeline.tokenizer.get_decoder_prompt_ids( - language="english", task="transcribe" - ) - return hf_pipeline # type: ignore + return custom_pipeline def convert_params_to_int_or_bool(params): diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py index a0c3f5fd..5dd4f800 100644 --- a/src/huggingface_inference_toolkit/webservice_starlette.py +++ b/src/huggingface_inference_toolkit/webservice_starlette.py @@ -1,5 +1,7 @@ +import asyncio import base64 import os +import threading from pathlib import Path from time import perf_counter @@ -8,6 +10,7 @@ from starlette.responses import PlainTextResponse, Response from starlette.routing import Route +from huggingface_inference_toolkit import idle from 
huggingface_inference_toolkit.async_utils import MAX_CONCURRENT_THREADS, MAX_THREADS_GUARD, async_handler_call from huggingface_inference_toolkit.const import ( HF_FRAMEWORK, @@ -17,26 +20,46 @@ HF_REVISION, HF_TASK, ) +from huggingface_inference_toolkit.env_utils import api_inference_compat from huggingface_inference_toolkit.handler import ( get_inference_handler_either_custom_or_default_handler, ) from huggingface_inference_toolkit.logging import logger from huggingface_inference_toolkit.serialization.base import ContentType from huggingface_inference_toolkit.serialization.json_utils import Jsoner -from huggingface_inference_toolkit.utils import ( - _load_repository_from_hf, - convert_params_to_int_or_bool, -) +from huggingface_inference_toolkit.utils import convert_params_to_int_or_bool from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs +INFERENCE_HANDLERS = {} +INFERENCE_HANDLERS_LOCK = threading.Lock() +MODEL_DOWNLOADED = False +MODEL_DL_LOCK = threading.Lock() + async def prepare_model_artifacts(): - global inference_handler + global INFERENCE_HANDLERS + + if idle.UNLOAD_IDLE: + asyncio.create_task(idle.live_check_loop(), name="live_check_loop") + else: + _eager_model_dl() + logger.info(f"Initializing model from directory:{HF_MODEL_DIR}") + # 2. determine correct inference handler + inference_handler = get_inference_handler_either_custom_or_default_handler( + HF_MODEL_DIR, task=HF_TASK + ) + INFERENCE_HANDLERS[HF_TASK] = inference_handler + logger.info("Model initialized successfully") + + +def _eager_model_dl(): + global MODEL_DOWNLOADED + from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf # 1. check if model artifacts available in HF_MODEL_DIR if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0: # 2. if not available, try to load from HF_MODEL_ID if HF_MODEL_ID is not None: - _load_repository_from_hf( + load_repository_from_hf( repository_id=HF_MODEL_ID, target_dir=HF_MODEL_DIR, framework=HF_FRAMEWORK, @@ -53,17 +76,11 @@ async def prepare_model_artifacts(): else: raise ValueError( f"""Can't initialize model. - Please set env HF_MODEL_DIR or provider a HF_MODEL_ID. - Provided values are: - HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}""" + Please set env HF_MODEL_DIR or provider a HF_MODEL_ID. + Provided values are: + HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}""" ) - - logger.info(f"Initializing model from directory:{HF_MODEL_DIR}") - # 2. 
determine correct inference handler - inference_handler = get_inference_handler_either_custom_or_default_handler( - HF_MODEL_DIR, task=HF_TASK - ) - logger.info("Model initialized successfully") + MODEL_DOWNLOADED = True async def health(request): @@ -83,11 +100,16 @@ async def metrics(request): async def predict(request): + global INFERENCE_HANDLERS + if not MODEL_DOWNLOADED: + with MODEL_DL_LOCK: + _eager_model_dl() try: + task = request.path_params.get("task", HF_TASK) # extracts content from request - content_type = request.headers.get("content-Type", None) + content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE", "")).lower() # try to deserialize payload - deserialized_body = ContentType.get_deserializer(content_type).deserialize( + deserialized_body = ContentType.get_deserializer(content_type, task).deserialize( await request.body() ) # checks if input schema is correct @@ -112,26 +134,47 @@ async def predict(request): dict(request.query_params) ) + # We lazily load pipelines for alt tasks + + if task == "feature-extraction" and HF_TASK in [ + "sentence-similarity", + "sentence-embeddings", + "sentence-ranking", + ]: + task = "sentence-embeddings" + inference_handler = INFERENCE_HANDLERS.get(task) + if not inference_handler: + with INFERENCE_HANDLERS_LOCK: + if task not in INFERENCE_HANDLERS: + inference_handler = get_inference_handler_either_custom_or_default_handler( + HF_MODEL_DIR, task=task) + INFERENCE_HANDLERS[task] = inference_handler + else: + inference_handler = INFERENCE_HANDLERS[task] # tracks request time start_time = perf_counter() - # run async not blocking call - pred = await async_handler_call(inference_handler, deserialized_body) + + with idle.request_witnesses(): + # run async not blocking call + pred = await async_handler_call(inference_handler, deserialized_body) + # log request time logger.info( f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms" ) # response extracts content from request - accept = request.headers.get("accept", None) + accept = request.headers.get("accept") if accept is None or accept == "*/*": - accept = "application/json" + accept = os.environ.get("DEFAULT_ACCEPT", "application/json") + logger.info("Request accepts %s", accept) # deserialized and resonds with json serialized_response_body = ContentType.get_serializer(accept).serialize( pred, accept ) return Response(serialized_response_body, media_type=accept) except Exception as e: - logger.error(e) + logger.exception(e) return Response( Jsoner.serialize({"error": str(e)}), status_code=400, @@ -159,14 +202,19 @@ async def predict(request): on_startup=[prepare_model_artifacts], ) else: + routes = [ + Route("/", health, methods=["GET"]), + Route("/health", health, methods=["GET"]), + Route("/", predict, methods=["POST"]), + Route("/predict", predict, methods=["POST"]), + Route("/metrics", metrics, methods=["GET"]), + ] + if api_inference_compat(): + routes.append( + Route("/pipeline/{task:path}", predict, methods=["POST"]) + ) app = Starlette( debug=False, - routes=[ - Route("/", health, methods=["GET"]), - Route("/health", health, methods=["GET"]), - Route("/", predict, methods=["POST"]), - Route("/predict", predict, methods=["POST"]), - Route("/metrics", metrics, methods=["GET"]), - ], + routes=routes, on_startup=[prepare_model_artifacts], ) diff --git a/test-requirements.txt b/test-requirements.txt new file mode 100644 index 00000000..082fbf6d --- /dev/null +++ b/test-requirements.txt @@ -0,0 +1,12 @@ +isort +ruff +datasets 
+docker +mock==2.0.0 +parameterized +psutil +pytest-sugar +pytest-xdist +pytest==7.2.1 +requests +tenacity diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index d69a9f97..2b3f79e8 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -9,7 +9,7 @@ import tenacity from transformers.testing_utils import _run_slow_tests -from huggingface_inference_toolkit.utils import _load_repository_from_hf +from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf from tests.integ.config import task2model HF_HUB_CACHE = os.environ.get("HF_HUB_CACHE", "/home/ubuntu/.cache/huggingface/hub") @@ -124,7 +124,7 @@ def local_container(device, task, repository_id, framework): object_id = model.replace("/", "--") model_dir = f"{HF_HUB_CACHE}/{object_id}" - _storage_dir = _load_repository_from_hf( + _storage_dir = load_repository_from_hf( repository_id=model, target_dir=model_dir ) diff --git a/tests/integ/helpers.py b/tests/integ/helpers.py index 5591e2a3..7de85a40 100644 --- a/tests/integ/helpers.py +++ b/tests/integ/helpers.py @@ -10,7 +10,7 @@ from docker import DockerClient from transformers.testing_utils import _run_slow_tests, require_tf, require_torch -from huggingface_inference_toolkit.utils import _load_repository_from_hf +from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf from tests.integ.config import task2input, task2model, task2output, task2validation IS_GPU = _run_slow_tests @@ -207,7 +207,7 @@ def test_pt_container_local_model(task: str) -> None: make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - _load_repository_from_hf(model, tmpdirname, framework="pytorch") + load_repository_from_hf(model, tmpdirname, framework="pytorch") container = client.containers.run( container_image, name=container_name, @@ -238,7 +238,7 @@ def test_pt_container_custom_handler(repository_id) -> None: make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - _storage_dir = _load_repository_from_hf(repository_id, tmpdirname) + _storage_dir = load_repository_from_hf(repository_id, tmpdirname) container = client.containers.run( container_image, name=container_name, @@ -275,7 +275,7 @@ def test_pt_container_legacy_custom_pipeline(repository_id: str) -> None: make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - _storage_dir = _load_repository_from_hf(repository_id, tmpdirname) + _storage_dir = load_repository_from_hf(repository_id, tmpdirname) container = client.containers.run( container_image, name=container_name, @@ -393,7 +393,7 @@ def test_tf_container_local_model(task) -> None: make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - _storage_dir = _load_repository_from_hf(model, tmpdirname, framework=framework) + _storage_dir = load_repository_from_hf(model, tmpdirname, framework=framework) container = client.containers.run( container_image, name=container_name, @@ -421,7 +421,7 @@ def test_tf_container_local_model(task) -> None: # make_sure_other_containers_are_stopped(client, container_name) # with 
tempfile.TemporaryDirectory() as tmpdirname: # # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py -# storage_dir = _load_repository_from_hf("philschmid/custom-pipeline-text-classification", tmpdirname) +# storage_dir = load_repository_from_hf("philschmid/custom-pipeline-text-classification", tmpdirname) # container = client.containers.run( # container_image, # name=container_name, diff --git a/tests/unit/test_diffusers.py b/tests/unit/test_diffusers.py index b7f4a56d..90abe08c 100644 --- a/tests/unit/test_diffusers.py +++ b/tests/unit/test_diffusers.py @@ -5,14 +5,14 @@ from transformers.testing_utils import require_torch, slow from huggingface_inference_toolkit.diffusers_utils import IEAutoPipelineForText2Image -from huggingface_inference_toolkit.utils import _load_repository_from_hf, get_pipeline +from huggingface_inference_toolkit.heavy_utils import get_pipeline, load_repository_from_hf logging.basicConfig(level="DEBUG") @require_torch def test_get_diffusers_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "echarlaix/tiny-random-stable-diffusion-xl", tmpdirname, framework="pytorch" @@ -25,7 +25,7 @@ def test_get_diffusers_pipeline(): @require_torch def test_pipe_on_gpu(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "echarlaix/tiny-random-stable-diffusion-xl", tmpdirname, framework="pytorch" @@ -41,7 +41,7 @@ def test_pipe_on_gpu(): @require_torch def test_text_to_image_task(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "echarlaix/tiny-random-stable-diffusion-xl", tmpdirname, framework="pytorch" diff --git a/tests/unit/test_handler.py b/tests/unit/test_handler.py index 2935d6e7..47c6c25c 100644 --- a/tests/unit/test_handler.py +++ b/tests/unit/test_handler.py @@ -8,9 +8,9 @@ HuggingFaceHandler, get_inference_handler_either_custom_or_default_handler, ) -from huggingface_inference_toolkit.utils import ( +from huggingface_inference_toolkit.heavy_utils import ( _is_gpu_available, - _load_repository_from_hf, + load_repository_from_hf, ) TASK = "text-classification" @@ -29,7 +29,7 @@ def test_pt_get_device() -> None: with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") + storage_dir = load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") h = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) if torch.cuda.is_available(): assert h.pipeline.model.device == torch.device(type="cuda", index=0) @@ -41,7 +41,7 @@ def test_pt_get_device() -> None: def test_pt_predict_call(input_data: Dict[str, str]) -> None: with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") + storage_dir = load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") h = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) prediction = h(input_data) @@ -52,7 +52,7 @@ def test_pt_predict_call(input_data: Dict[str, str]) -> None: @require_torch def test_pt_custom_pipeline(input_data: Dict[str, str]) -> None: with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = 
load_repository_from_hf( "philschmid/custom-pipeline-text-classification", tmpdirname, framework="pytorch", @@ -64,7 +64,7 @@ def test_pt_custom_pipeline(input_data: Dict[str, str]) -> None: @require_torch def test_pt_sentence_transformers_pipeline(input_data: Dict[str, str]) -> None: with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="pytorch" ) h = get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="sentence-embeddings") @@ -76,7 +76,7 @@ def test_pt_sentence_transformers_pipeline(input_data: Dict[str, str]) -> None: def test_tf_get_device(): with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="tensorflow") + storage_dir = load_repository_from_hf(MODEL, tmpdirname, framework="tensorflow") h = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) if _is_gpu_available(): assert h.pipeline.device == 0 @@ -88,7 +88,7 @@ def test_tf_get_device(): def test_tf_predict_call(input_data: Dict[str, str]) -> None: with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="tensorflow") + storage_dir = load_repository_from_hf(MODEL, tmpdirname, framework="tensorflow") handler = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK, framework="tf") prediction = handler(input_data) @@ -99,7 +99,7 @@ def test_tf_predict_call(input_data: Dict[str, str]) -> None: @require_tf def test_tf_custom_pipeline(input_data: Dict[str, str]) -> None: with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "philschmid/custom-pipeline-text-classification", tmpdirname, framework="tensorflow", @@ -112,7 +112,7 @@ def test_tf_custom_pipeline(input_data: Dict[str, str]) -> None: def test_tf_sentence_transformers_pipeline(): # TODO should fail! 
because TF is not supported yet with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="tensorflow" ) with pytest.raises(Exception) as _exc_info: diff --git a/tests/unit/test_optimum_utils.py b/tests/unit/test_optimum_utils.py index 075fdf81..5faa9892 100644 --- a/tests/unit/test_optimum_utils.py +++ b/tests/unit/test_optimum_utils.py @@ -4,12 +4,12 @@ import pytest from transformers.testing_utils import require_torch +from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf from huggingface_inference_toolkit.optimum_utils import ( get_input_shapes, get_optimum_neuron_pipeline, is_optimum_neuron_available, ) -from huggingface_inference_toolkit.utils import _load_repository_from_hf require_inferentia = pytest.mark.skipif( not is_optimum_neuron_available(), @@ -34,7 +34,7 @@ def test_not_supported_task(): @require_inferentia def test_get_input_shapes_from_file(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_folder = _load_repository_from_hf( + storage_folder = load_repository_from_hf( repository_id=REMOTE_CONVERTED_MODEL, target_dir=tmpdirname, ) @@ -49,7 +49,7 @@ def test_get_input_shapes_from_env(): os.environ["HF_OPTIMUM_BATCH_SIZE"] = "4" os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32" with tempfile.TemporaryDirectory() as tmpdirname: - storage_folder = _load_repository_from_hf( + storage_folder = load_repository_from_hf( repository_id=REMOTE_NOT_CONVERTED_MODEL, target_dir=tmpdirname, ) @@ -77,7 +77,7 @@ def test_get_optimum_neuron_pipeline_from_converted_model(): def test_get_optimum_neuron_pipeline_from_non_converted_model(): os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32" with tempfile.TemporaryDirectory() as tmpdirname: - storage_folder = _load_repository_from_hf( + storage_folder = load_repository_from_hf( repository_id=REMOTE_NOT_CONVERTED_MODEL, target_dir=tmpdirname, ) diff --git a/tests/unit/test_sentence_transformers.py b/tests/unit/test_sentence_transformers.py index e48533bc..efe86b14 100644 --- a/tests/unit/test_sentence_transformers.py +++ b/tests/unit/test_sentence_transformers.py @@ -3,20 +3,20 @@ import pytest from transformers.testing_utils import require_torch +from huggingface_inference_toolkit.heavy_utils import ( + get_pipeline, + load_repository_from_hf, +) from huggingface_inference_toolkit.sentence_transformers_utils import ( SentenceEmbeddingPipeline, get_sentence_transformers_pipeline, ) -from huggingface_inference_toolkit.utils import ( - _load_repository_from_hf, - get_pipeline, -) @require_torch def test_get_sentence_transformers_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) + storage_dir = load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) pipe = get_pipeline("sentence-embeddings", storage_dir.as_posix()) assert isinstance(pipe, SentenceEmbeddingPipeline) @@ -24,7 +24,7 @@ def test_get_sentence_transformers_pipeline(): @require_torch def test_sentence_embedding_task(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) + storage_dir = load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) pipe = get_sentence_transformers_pipeline("sentence-embeddings", storage_dir.as_posix()) res = pipe(sentences="Lets create an embedding") assert 
isinstance(res["embeddings"], list) @@ -36,7 +36,7 @@ def test_sentence_embedding_task(): @require_torch def test_sentence_similarity(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) + storage_dir = load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) pipe = get_sentence_transformers_pipeline("sentence-similarity", storage_dir.as_posix()) res = pipe(source_sentence="Lets create an embedding", sentences=["Lets create an embedding"]) assert isinstance(res["similarities"], list) @@ -45,7 +45,7 @@ def test_sentence_similarity(): @require_torch def test_sentence_ranking(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname) + storage_dir = load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname) pipe = get_sentence_transformers_pipeline("sentence-ranking", storage_dir.as_posix()) res = pipe( sentences=[ @@ -61,7 +61,7 @@ def test_sentence_ranking(): @require_torch def test_sentence_ranking_tei(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname, framework="pytorch") + storage_dir = load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname, framework="pytorch") pipe = get_sentence_transformers_pipeline("sentence-ranking", storage_dir.as_posix()) res = pipe( query="Lets create an embedding", @@ -82,7 +82,7 @@ def test_sentence_ranking_tei(): @require_torch def test_sentence_ranking_validation_errors(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname, framework="pytorch") + storage_dir = load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname, framework="pytorch") pipe = get_sentence_transformers_pipeline("sentence-ranking", storage_dir.as_posix()) with pytest.raises( diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index c0f1fef8..3b99040a 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -7,13 +7,13 @@ from transformers.testing_utils import require_tf, require_torch, slow from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler -from huggingface_inference_toolkit.utils import ( +from huggingface_inference_toolkit.heavy_utils import ( _get_framework, _is_gpu_available, - _load_repository_from_hf, - check_and_register_custom_pipeline_from_directory, get_pipeline, + load_repository_from_hf, ) +from huggingface_inference_toolkit.utils import check_and_register_custom_pipeline_from_directory TASK_MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" @@ -22,7 +22,7 @@ def test_load_revision_repository_from_hf(): MODEL = "lysandre/tiny-bert-random" REVISION = "eb4c77816edd604d0318f8e748a1c606a2888493" with tempfile.TemporaryDirectory() as tmpdirname: - storage_folder = _load_repository_from_hf(MODEL, tmpdirname, revision=REVISION) + storage_folder = load_repository_from_hf(MODEL, tmpdirname, revision=REVISION) # folder contains all config files and pytorch_model.bin folder_contents = os.listdir(storage_folder) # revision doesn't have tokenizer @@ -36,7 +36,7 @@ def test_load_tensorflow_repository_from_hf(): tf_tmp = Path(tmpdirname).joinpath("tf") tf_tmp.mkdir(parents=True, exist_ok=True) - storage_folder = _load_repository_from_hf(MODEL, tf_tmp, 
framework="tensorflow") + storage_folder = load_repository_from_hf(MODEL, tf_tmp, framework="tensorflow") # folder contains all config files and pytorch_model.bin folder_contents = os.listdir(storage_folder) assert "pytorch_model.bin" not in folder_contents @@ -52,7 +52,7 @@ def test_load_onnx_repository_from_hf(): ox_tmp = Path(tmpdirname).joinpath("onnx") ox_tmp.mkdir(parents=True, exist_ok=True) - storage_folder = _load_repository_from_hf(MODEL, ox_tmp, framework="onnx") + storage_folder = load_repository_from_hf(MODEL, ox_tmp, framework="onnx") # folder contains all config files and pytorch_model.bin folder_contents = os.listdir(storage_folder) assert "pytorch_model.bin" not in folder_contents @@ -73,7 +73,7 @@ def test_load_pytorch_repository_from_hf(): pt_tmp = Path(tmpdirname).joinpath("pt") pt_tmp.mkdir(parents=True, exist_ok=True) - storage_folder = _load_repository_from_hf(MODEL, pt_tmp, framework="pytorch") + storage_folder = load_repository_from_hf(MODEL, pt_tmp, framework="pytorch") # folder contains all config files and pytorch_model.bin folder_contents = os.listdir(storage_folder) assert "pytorch_model.bin" in folder_contents @@ -109,7 +109,7 @@ def test_get_pipeline(): MODEL = "hf-internal-testing/tiny-random-BertForSequenceClassification" TASK = "text-classification" with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") + storage_dir = load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") pipe = get_pipeline( task = TASK, model_dir = storage_dir.as_posix(), @@ -121,7 +121,7 @@ def test_get_pipeline(): @require_torch def test_whisper_long_audio(cache_test_dir): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( repository_id = "openai/whisper-tiny", target_dir = tmpdirname, ) @@ -139,7 +139,7 @@ def test_whisper_long_audio(cache_test_dir): @require_torch def test_wrapped_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( repository_id = "microsoft/DialoGPT-small", target_dir = tmpdirname, framework="pytorch" @@ -175,7 +175,7 @@ def test_local_custom_pipeline(cache_test_dir): def test_remote_custom_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "philschmid/custom-pipeline-text-classification", tmpdirname, framework="pytorch" @@ -188,7 +188,7 @@ def test_remote_custom_pipeline(): def test_get_inference_handler_either_custom_or_default_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf( + storage_dir = load_repository_from_hf( "philschmid/custom-pipeline-text-classification", tmpdirname, framework="pytorch" From 52511c0c5aa00507c2c7ee5f6fd500c685a2932a Mon Sep 17 00:00:00 2001 From: Raphael Glon Date: Fri, 19 Sep 2025 09:53:22 +0200 Subject: [PATCH 2/4] feat(relieve): discard request if the caller is not waiting for the answer anymore* When behind a proxy this requires the proxy to close the connection to be effective though Signed-off-by: Raphael Glon --- requirements.txt | 1 + src/huggingface_inference_toolkit/handler.py | 21 +++++- .../heavy_utils.py | 2 +- .../serialization/base.py | 2 +- src/huggingface_inference_toolkit/utils.py | 67 ++++++++++++++++++- .../webservice_starlette.py | 12 +++- 6 files changed, 98 insertions(+), 7 deletions(-) diff --git 
a/requirements.txt b/requirements.txt index 8acfe504..ebd1ed1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ torch==2.5.1 torchvision torchaudio peft==0.15.1 +psutil>=6.0.0 diff --git a/src/huggingface_inference_toolkit/handler.py b/src/huggingface_inference_toolkit/handler.py index 3ff104d5..2ed8a757 100644 --- a/src/huggingface_inference_toolkit/handler.py +++ b/src/huggingface_inference_toolkit/handler.py @@ -5,7 +5,12 @@ from huggingface_inference_toolkit import logging from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE from huggingface_inference_toolkit.env_utils import api_inference_compat, ignore_custom_handler -from huggingface_inference_toolkit.utils import check_and_register_custom_pipeline_from_directory +from huggingface_inference_toolkit.logging import logger +from huggingface_inference_toolkit.utils import ( + already_left, + check_and_register_custom_pipeline_from_directory, + should_discard_left, +) class HuggingFaceHandler: @@ -39,7 +44,17 @@ def __call__(self, data: Dict[str, Any]): inputs = data.pop("inputs", data) parameters = data.pop("parameters", {}) - # diffusers and sentence transformers pipelines do not have the `task` arg + if "handler_params" in data: + handler_params = data.pop("handler_params") + if should_discard_left(): + request = handler_params.get("request") + if not request: + logger.warn("Cannot know if request caller already left, missing request handler param") + elif already_left(request): + logger.info("Discarding request as the caller already left") + return None + + # diffusers and sentence transformers pipelines do not have the `task` arg if not hasattr(self.pipeline, "task"): # sentence transformers parameters not supported yet if any(isinstance(self.pipeline, v) for v in SENTENCE_TRANSFORMERS_TASKS.values()): @@ -168,7 +183,7 @@ def __call__(self, data: Dict[str, Any]): scores = resp['scores'] if len(labels) == len(scores): new_resp = [] - for label, score in zip(labels, scores): + for label, score in zip(labels, scores, strict=True): new_resp.append({"label": label, "score": score}) resp = new_resp else: diff --git a/src/huggingface_inference_toolkit/heavy_utils.py b/src/huggingface_inference_toolkit/heavy_utils.py index c144082e..f12190f5 100644 --- a/src/huggingface_inference_toolkit/heavy_utils.py +++ b/src/huggingface_inference_toolkit/heavy_utils.py @@ -184,4 +184,4 @@ def get_pipeline( hf_pipeline.model.config.forced_decoder_ids = hf_pipeline.tokenizer.get_decoder_prompt_ids( language="english", task="transcribe" ) - return hf_pipeline # type: ignore \ No newline at end of file + return hf_pipeline # type: ignore diff --git a/src/huggingface_inference_toolkit/serialization/base.py b/src/huggingface_inference_toolkit/serialization/base.py index c6e58075..e949de4a 100644 --- a/src/huggingface_inference_toolkit/serialization/base.py +++ b/src/huggingface_inference_toolkit/serialization/base.py @@ -38,7 +38,7 @@ class ContentType: @staticmethod def get_deserializer(content_type: str, task: str): if not content_type: - message = f"No content type provided and no default one configured." + message = "No content type provided and no default one configured." 
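
For reference, a minimal sketch of the discard gate wired into the handler above, assuming the same DISCARD_LEFT convention and that the web layer passes the Starlette Request under data["handler_params"]["request"]; the liveness check itself is injected as a callable, so this is an illustration rather than the toolkit's exact code path.

import os
from typing import Any, Callable, Dict, Optional


def should_discard_left() -> bool:
    # Opt-in via env var, mirroring the helper used above.
    return os.getenv("DISCARD_LEFT", "0").lower() in ("true", "yes", "1")


def gated_call(
    pipeline: Callable[[Dict[str, Any]], Any],
    data: Dict[str, Any],
    already_left: Callable[[Any], bool],
) -> Optional[Any]:
    handler_params = data.pop("handler_params", {})
    if should_discard_left():
        request = handler_params.get("request")
        if request is not None and already_left(request):
            # Caller is gone: skip the expensive pipeline call and let the
            # web layer answer with 204 No Content.
            return None
    return pipeline(data)
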
raise Exception(message) if content_type.lower().startswith("application/octet-stream"): if "audio" in task or "speech" in task: diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py index 1e96b5e1..93798ec9 100644 --- a/src/huggingface_inference_toolkit/utils.py +++ b/src/huggingface_inference_toolkit/utils.py @@ -1,7 +1,12 @@ import importlib.util +import ipaddress +import os import sys from pathlib import Path +import psutil +from starlette.requests import Request + from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME from huggingface_inference_toolkit.logging import logger @@ -66,7 +71,7 @@ def check_and_register_custom_pipeline_from_directory(model_dir): # init custom handler with model_dir custom_pipeline = handler.EndpointHandler(model_dir) else: - logger.info(f"No spec from file location found for module %s, file %s", HF_MODULE_NAME, custom_module) + logger.info("No spec from file location found for module %s, file %s", HF_MODULE_NAME, custom_module) elif legacy_module.is_file(): logger.warning( """You are using a legacy custom pipeline. @@ -99,3 +104,63 @@ def convert_params_to_int_or_bool(params): if v == "true": params[k] = True return params + + +def should_discard_left() -> bool: + return os.getenv('DISCARD_LEFT', '0').lower() in ['true', 'yes', '1'] + + +def already_left(request: Request) -> bool: + """ + Check if the caller has already left without waiting for the answer to come. This can help during burst to relieve + the pressure on the worker by cancelling jobs whose results don't matter as they won't be fetched anyway + :param request: + :return: bool + """ + # NOTE: Starlette method request.is_disconnected is totally broken, consumes the payload, does not return + # the correct status. So we use the good old way to identify if the caller is still there. + # In any case, if we are not sure, we return False + logger.info("Checking if request caller already left") + try: + client = request.client + host = client.host + if not host: + return False + + port = int(client.port) + host = ipaddress.ip_address(host) + + if port <= 0 or port > 65535: + logger.warning("Unexpected source port format for caller %s", port) + return False + counter = 0 + for connection in psutil.net_connections(kind="tcp"): + counter += 1 + if connection.status != "ESTABLISHED": + continue + if not connection.raddr: + continue + if int(connection.raddr.port) != port: + continue + if ( + not connection.raddr.ip + or ipaddress.ip_address(connection.raddr.ip) != host + ): + continue + logger.info( + "Found caller connection still established, caller is most likely still there, %s", + connection, + ) + return False + except Exception as e: + logger.warning( + "Unexpected error while checking if caller already left, assuming still there" + ) + logger.exception(e) + return False + + logger.info( + "%d connections checked. 
No connection found matching to the caller, probably left", + counter, + ) + return True diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py index 5dd4f800..820a6e7e 100644 --- a/src/huggingface_inference_toolkit/webservice_starlette.py +++ b/src/huggingface_inference_toolkit/webservice_starlette.py @@ -22,12 +22,13 @@ ) from huggingface_inference_toolkit.env_utils import api_inference_compat from huggingface_inference_toolkit.handler import ( + HuggingFaceHandler, get_inference_handler_either_custom_or_default_handler, ) from huggingface_inference_toolkit.logging import logger from huggingface_inference_toolkit.serialization.base import ContentType from huggingface_inference_toolkit.serialization.json_utils import Jsoner -from huggingface_inference_toolkit.utils import convert_params_to_int_or_bool +from huggingface_inference_toolkit.utils import convert_params_to_int_or_bool, should_discard_left from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs INFERENCE_HANDLERS = {} @@ -101,6 +102,7 @@ async def metrics(request): async def predict(request): global INFERENCE_HANDLERS + if not MODEL_DOWNLOADED: with MODEL_DL_LOCK: _eager_model_dl() @@ -154,6 +156,10 @@ async def predict(request): # tracks request time start_time = perf_counter() + if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler): + deserialized_body['handler_params'] = { + 'request': request + } with idle.request_witnesses(): # run async not blocking call pred = await async_handler_call(inference_handler, deserialized_body) @@ -163,6 +169,10 @@ async def predict(request): f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms" ) + if should_discard_left() and pred is None: + logger.info("No content returned as caller already left") + return Response(status_code=204) + # response extracts content from request accept = request.headers.get("accept") if accept is None or accept == "*/*": From 54d2596560ac237b1292972259d97689ef27aecc Mon Sep 17 00:00:00 2001 From: Raphael Glon Date: Wed, 12 Nov 2025 15:27:15 +0100 Subject: [PATCH 3/4] feat: log level + fixes: async bug, idle bug * environment log level var * some long blocking sync calls should be wrapped in a thread (model download) * idle check should be done for the entire predict call otherwise in non idle mode the worker could be kicked in the middle of a request Signed-off-by: Raphael Glon --- scripts/entrypoint.sh | 2 +- src/huggingface_inference_toolkit/handler.py | 2 + src/huggingface_inference_toolkit/idle.py | 6 +- src/huggingface_inference_toolkit/logging.py | 3 +- .../webservice_starlette.py | 175 +++++++++--------- 5 files changed, 101 insertions(+), 87 deletions(-) diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh index 20aedf9f..fb0ebad5 100755 --- a/scripts/entrypoint.sh +++ b/scripts/entrypoint.sh @@ -59,4 +59,4 @@ if [[ ! 
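
As a standalone illustration of the liveness test above (a sketch assuming psutil is installed, not the toolkit's exact helper): a peer is treated as gone once no ESTABLISHED TCP connection to its address remains visible on the host, which behind a proxy requires the proxy to close its upstream socket.

import ipaddress

import psutil


def peer_still_connected(host: str, port: int) -> bool:
    """Return True if an ESTABLISHED TCP connection to (host, port) exists."""
    peer_ip = ipaddress.ip_address(host)
    for conn in psutil.net_connections(kind="tcp"):
        if conn.status != psutil.CONN_ESTABLISHED or not conn.raddr:
            continue
        if conn.raddr.port == port and ipaddress.ip_address(conn.raddr.ip) == peer_ip:
            return True
    return False
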
-z "${HF_MODEL_DIR}" ]]; then fi # Start the server -exec gunicorn webservice_starlette:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-1} --bind 0.0.0.0:${PORT} +exec gunicorn webservice_starlette:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-1} --bind 0.0.0.0:${PORT} --timeout 30 diff --git a/src/huggingface_inference_toolkit/handler.py b/src/huggingface_inference_toolkit/handler.py index 2ed8a757..5cb3b50b 100644 --- a/src/huggingface_inference_toolkit/handler.py +++ b/src/huggingface_inference_toolkit/handler.py @@ -38,6 +38,7 @@ def __call__(self, data: Dict[str, Any]): :return: prediction output """ + logger.debug("Calling HF default handler") # import as late as possible to reduce the footprint from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS @@ -126,6 +127,7 @@ def __call__(self, data: Dict[str, Any]): if self.pipeline.task == "token-classification": parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple")) + logger.debug("Performing inference") resp = self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else \ self.pipeline(inputs, **parameters) diff --git a/src/huggingface_inference_toolkit/idle.py b/src/huggingface_inference_toolkit/idle.py index a861d527..c094c420 100644 --- a/src/huggingface_inference_toolkit/idle.py +++ b/src/huggingface_inference_toolkit/idle.py @@ -24,7 +24,7 @@ async def live_check_loop(): while True: await asyncio.sleep(sleep_time) - LOG.debug("Checking whether we should unload anything from gpu") + LOG.debug("Checking whether we should unload anything from memory") last_start = LAST_START last_end = LAST_END @@ -50,9 +50,13 @@ async def live_check_loop(): @contextlib.contextmanager def request_witnesses(): global LAST_START, LAST_END + LOG.debug("Last request start was %s", LAST_START) + LOG.debug("Last request end was %s", LAST_END) # Simple assignment, concurrency safe, no need for any lock LAST_START = time.time() + LOG.debug("Current request start timestamp %s", LAST_START) try: yield finally: LAST_END = time.time() + LOG.debug("Current request end timestamp %s", LAST_END) diff --git a/src/huggingface_inference_toolkit/logging.py b/src/huggingface_inference_toolkit/logging.py index 513d94fe..5bf42bc0 100644 --- a/src/huggingface_inference_toolkit/logging.py +++ b/src/huggingface_inference_toolkit/logging.py @@ -1,4 +1,5 @@ import logging +import os import sys @@ -9,7 +10,7 @@ def setup_logging(): # Configure the root logger logging.basicConfig( - level=logging.INFO, + level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO")), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", stream=sys.stdout, diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py index 820a6e7e..5e635831 100644 --- a/src/huggingface_inference_toolkit/webservice_starlette.py +++ b/src/huggingface_inference_toolkit/webservice_starlette.py @@ -54,6 +54,7 @@ async def prepare_model_artifacts(): def _eager_model_dl(): + logger.debug("Model download") global MODEL_DOWNLOADED from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf # 1. 
check if model artifacts available in HF_MODEL_DIR @@ -81,6 +82,8 @@ def _eager_model_dl(): Provided values are: HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}""" ) + else: + logger.debug("Model already downloaded in %s", HF_MODEL_DIR) MODEL_DOWNLOADED = True @@ -101,95 +104,99 @@ async def metrics(request): async def predict(request): - global INFERENCE_HANDLERS - - if not MODEL_DOWNLOADED: - with MODEL_DL_LOCK: - _eager_model_dl() - try: - task = request.path_params.get("task", HF_TASK) - # extracts content from request - content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE", "")).lower() - # try to deserialize payload - deserialized_body = ContentType.get_deserializer(content_type, task).deserialize( - await request.body() - ) - # checks if input schema is correct - if "inputs" not in deserialized_body and "instances" not in deserialized_body: - raise ValueError( - f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}" - ) - - # Decode base64 audio inputs before running inference - if "parameters" in deserialized_body and HF_TASK in { - "automatic-speech-recognition", - "audio-classification", - }: - # Be more strict on base64 decoding, the provided string should valid base64 encoded data - deserialized_body["inputs"] = base64.b64decode( - deserialized_body["inputs"], validate=True - ) - - # check for query parameter and add them to the body - if request.query_params and "parameters" not in deserialized_body: - deserialized_body["parameters"] = convert_params_to_int_or_bool( - dict(request.query_params) + with idle.request_witnesses(): + logger.debug("Received request, scope %s", request.scope) + + global INFERENCE_HANDLERS + + if not MODEL_DOWNLOADED: + with MODEL_DL_LOCK: + await asyncio.to_thread(_eager_model_dl) + try: + task = request.path_params.get("task", HF_TASK) + # extracts content from request + content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE", "")).lower() + # try to deserialize payload + deserialized_body = ContentType.get_deserializer(content_type, task).deserialize( + await request.body() ) - - # We lazily load pipelines for alt tasks - - if task == "feature-extraction" and HF_TASK in [ - "sentence-similarity", - "sentence-embeddings", - "sentence-ranking", - ]: - task = "sentence-embeddings" - inference_handler = INFERENCE_HANDLERS.get(task) - if not inference_handler: - with INFERENCE_HANDLERS_LOCK: - if task not in INFERENCE_HANDLERS: - inference_handler = get_inference_handler_either_custom_or_default_handler( - HF_MODEL_DIR, task=task) - INFERENCE_HANDLERS[task] = inference_handler - else: - inference_handler = INFERENCE_HANDLERS[task] - # tracks request time - start_time = perf_counter() - - if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler): - deserialized_body['handler_params'] = { - 'request': request - } - with idle.request_witnesses(): + # checks if input schema is correct + if "inputs" not in deserialized_body and "instances" not in deserialized_body: + raise ValueError( + f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}" + ) + + # Decode base64 audio inputs before running inference + if "parameters" in deserialized_body and HF_TASK in { + "automatic-speech-recognition", + "audio-classification", + }: + # Be more strict on base64 decoding, the provided string should valid base64 encoded data + deserialized_body["inputs"] = base64.b64decode( + deserialized_body["inputs"], validate=True + 
) + + # check for query parameter and add them to the body + if request.query_params and "parameters" not in deserialized_body: + deserialized_body["parameters"] = convert_params_to_int_or_bool( + dict(request.query_params) + ) + + # We lazily load pipelines for alt tasks + + if task == "feature-extraction" and HF_TASK in [ + "sentence-similarity", + "sentence-embeddings", + "sentence-ranking", + ]: + task = "sentence-embeddings" + inference_handler = INFERENCE_HANDLERS.get(task) + if not inference_handler: + with INFERENCE_HANDLERS_LOCK: + if task not in INFERENCE_HANDLERS: + inference_handler = get_inference_handler_either_custom_or_default_handler( + HF_MODEL_DIR, task=task) + INFERENCE_HANDLERS[task] = inference_handler + else: + inference_handler = INFERENCE_HANDLERS[task] + # tracks request time + start_time = perf_counter() + + if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler): + deserialized_body['handler_params'] = { + 'request': request + } + + logger.debug("Calling inference handler prediction routine") # run async not blocking call pred = await async_handler_call(inference_handler, deserialized_body) - # log request time - logger.info( - f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms" - ) + # log request time + logger.info( + f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms" + ) - if should_discard_left() and pred is None: - logger.info("No content returned as caller already left") - return Response(status_code=204) - - # response extracts content from request - accept = request.headers.get("accept") - if accept is None or accept == "*/*": - accept = os.environ.get("DEFAULT_ACCEPT", "application/json") - logger.info("Request accepts %s", accept) - # deserialized and resonds with json - serialized_response_body = ContentType.get_serializer(accept).serialize( - pred, accept - ) - return Response(serialized_response_body, media_type=accept) - except Exception as e: - logger.exception(e) - return Response( - Jsoner.serialize({"error": str(e)}), - status_code=400, - media_type="application/json", - ) + if should_discard_left() and pred is None: + logger.info("No content returned as caller already left") + return Response(status_code=204) + + # response extracts content from request + accept = request.headers.get("accept") + if accept is None or accept == "*/*": + accept = os.environ.get("DEFAULT_ACCEPT", "application/json") + logger.info("Request accepts %s", accept) + # deserialized and resonds with json + serialized_response_body = ContentType.get_serializer(accept).serialize( + pred, accept + ) + return Response(serialized_response_body, media_type=accept) + except Exception as e: + logger.exception(e) + return Response( + Jsoner.serialize({"error": str(e)}), + status_code=400, + media_type="application/json", + ) # Create app based on which cloud environment is used From 2c1c2afe154b6c9a980154c8ababcffbda0884ba Mon Sep 17 00:00:00 2001 From: Raphael Glon Date: Mon, 17 Nov 2025 14:41:00 +0100 Subject: [PATCH 4/4] fixes: coroutine and threading mix caused blocking bugs Signed-off-by: Raphael Glon --- .../async_utils.py | 8 +++- src/huggingface_inference_toolkit/handler.py | 7 +++ src/huggingface_inference_toolkit/idle.py | 48 ++++++++++++------- .../webservice_starlette.py | 37 +++++++------- 4 files changed, 65 insertions(+), 35 deletions(-) diff --git a/src/huggingface_inference_toolkit/async_utils.py b/src/huggingface_inference_toolkit/async_utils.py index 5b6af3fd..e82aefd9 100644 
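
A minimal sketch of the per-task handler cache used in predict() above, assuming a get_handler(model_dir, task) factory (hypothetical name): the lock is only taken on a miss and the cache is re-checked under it, so concurrent first requests for a task build the handler exactly once.

import threading
from typing import Any, Callable, Dict

_HANDLERS: Dict[str, Any] = {}
_HANDLERS_LOCK = threading.Lock()


def handler_for(task: str, model_dir: str, get_handler: Callable[..., Any]) -> Any:
    handler = _HANDLERS.get(task)
    if handler is None:
        with _HANDLERS_LOCK:
            # Re-check under the lock: another request may have built it.
            handler = _HANDLERS.get(task)
            if handler is None:
                handler = get_handler(model_dir, task=task)
                _HANDLERS[task] = handler
    return handler
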
--- a/src/huggingface_inference_toolkit/async_utils.py +++ b/src/huggingface_inference_toolkit/async_utils.py @@ -5,6 +5,8 @@ from anyio import Semaphore from typing_extensions import ParamSpec +from huggingface_inference_toolkit.logging import logger + # To not have too many threads running (which could happen on too many concurrent # requests, we limit it with a semaphore. MAX_CONCURRENT_THREADS = 1 @@ -15,6 +17,8 @@ # moves blocking call to asyncio threadpool limited to 1 to not overload the system # REF: https://stackoverflow.com/a/70929141 -async def async_handler_call(handler: Callable[P, T], body: Dict[str, Any]) -> T: +async def async_call(handler: Callable[P, T], *args, **kwargs) -> T: + logger.info("Setting blocking call to async handler") async with MAX_THREADS_GUARD: - return await anyio.to_thread.run_sync(functools.partial(handler, body)) + logger.info("Async call semaphore passed") + return await anyio.to_thread.run_sync(handler, *args, **kwargs) diff --git a/src/huggingface_inference_toolkit/handler.py b/src/huggingface_inference_toolkit/handler.py index 5cb3b50b..d4950d68 100644 --- a/src/huggingface_inference_toolkit/handler.py +++ b/src/huggingface_inference_toolkit/handler.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from time import perf_counter from typing import Any, Dict, Literal, Optional, Union from huggingface_inference_toolkit import logging @@ -37,7 +38,13 @@ def __call__(self, data: Dict[str, Any]): :data: (obj): the raw request body data. :return: prediction output """ + start = perf_counter() + pred = self._timed_call(data) + end = perf_counter() + logger.info("Inference duration: %.2f ms", (end - start) * 1000) + return pred + def _timed_call(self, data: Dict[str, Any]): logger.debug("Calling HF default handler") # import as late as possible to reduce the footprint from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS diff --git a/src/huggingface_inference_toolkit/idle.py b/src/huggingface_inference_toolkit/idle.py index c094c420..462dbdfe 100644 --- a/src/huggingface_inference_toolkit/idle.py +++ b/src/huggingface_inference_toolkit/idle.py @@ -5,6 +5,8 @@ import signal import time +from anyio import Semaphore + LOG = logging.getLogger(__name__) LAST_START = None @@ -13,13 +15,16 @@ UNLOAD_IDLE = os.getenv("UNLOAD_IDLE", "").lower() in ("1", "true") IDLE_TIMEOUT = int(os.getenv("IDLE_TIMEOUT", 15)) +MAX_REQUESTS = 1000 +REQUEST_COUNTER = Semaphore(MAX_REQUESTS) + async def live_check_loop(): global LAST_START, LAST_END pid = os.getpid() - LOG.debug("Starting live check loop") + LOG.info("Starting live check loop") sleep_time = max(int(IDLE_TIMEOUT // 5), 1) while True: @@ -31,9 +36,16 @@ async def live_check_loop(): LOG.debug("Checking pid %d activity", pid) if not last_start: + LOG.debug("No request yet, no need to unload") + continue + + if REQUEST_COUNTER.value < MAX_REQUESTS: + LOG.info("idle checker: %s requests likely being processed for pid %d, it won't be killed", + MAX_REQUESTS - REQUEST_COUNTER.value, pid) continue if not last_end or last_start >= last_end: - LOG.debug("Request likely being processed for pid %d", pid) + LOG.warning("This case should not be possible, semaphore unconsistency ? " + "Request likely being processed for pid %d", pid) continue now = time.time() last_request_age = now - last_end @@ -41,22 +53,24 @@ async def live_check_loop(): if last_request_age < IDLE_TIMEOUT: LOG.debug("Model recently active") else: - LOG.debug("Inactive for too long. 
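
A minimal sketch of the async_call pattern above, assuming anyio: blocking pipeline calls are pushed to the thread pool, with a Semaphore capping how many run at once so a request burst cannot spawn unbounded threads. The sketch forwards positional arguments only, since anyio.to_thread.run_sync treats keyword arguments as its own options rather than passing them to the callable.

import anyio
from anyio import Semaphore

MAX_CONCURRENT_THREADS = 1
_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)


async def run_blocking(fn, *args):
    # Serialize heavy, blocking work (e.g. a pipeline call) so it never blocks
    # the event loop and never overcommits the CPU/GPU.
    async with _THREADS_GUARD:
        return await anyio.to_thread.run_sync(fn, *args)
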
Leaving live check loop") + LOG.info("Idle checker: worker inactive for too long. Leaving live check loop") break - LOG.debug("Aborting this worker") + LOG.info("Aborting this idle worker") os.kill(pid, signal.SIGTERM) -@contextlib.contextmanager -def request_witnesses(): - global LAST_START, LAST_END - LOG.debug("Last request start was %s", LAST_START) - LOG.debug("Last request end was %s", LAST_END) - # Simple assignment, concurrency safe, no need for any lock - LAST_START = time.time() - LOG.debug("Current request start timestamp %s", LAST_START) - try: - yield - finally: - LAST_END = time.time() - LOG.debug("Current request end timestamp %s", LAST_END) +@contextlib.asynccontextmanager +async def request_witnesses(): + async with REQUEST_COUNTER: + LOG.info("Current request count: %s", MAX_REQUESTS - REQUEST_COUNTER.value) + global LAST_START, LAST_END + LOG.debug("Last request start was %s", LAST_START) + LOG.debug("Last request end was %s", LAST_END) + # Simple assignment, concurrency safe, no need for any lock + LAST_START = time.time() + LOG.debug("Current request start timestamp %s", LAST_START) + try: + yield + finally: + LAST_END = time.time() + LOG.debug("Current request end timestamp %s", LAST_END) diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py index 5e635831..4c824542 100644 --- a/src/huggingface_inference_toolkit/webservice_starlette.py +++ b/src/huggingface_inference_toolkit/webservice_starlette.py @@ -1,17 +1,17 @@ import asyncio import base64 import os -import threading from pathlib import Path from time import perf_counter import orjson +from anyio import Semaphore from starlette.applications import Starlette from starlette.responses import PlainTextResponse, Response from starlette.routing import Route from huggingface_inference_toolkit import idle -from huggingface_inference_toolkit.async_utils import MAX_CONCURRENT_THREADS, MAX_THREADS_GUARD, async_handler_call +from huggingface_inference_toolkit.async_utils import MAX_CONCURRENT_THREADS, MAX_THREADS_GUARD, async_call from huggingface_inference_toolkit.const import ( HF_FRAMEWORK, HF_HUB_TOKEN, @@ -32,9 +32,9 @@ from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs INFERENCE_HANDLERS = {} -INFERENCE_HANDLERS_LOCK = threading.Lock() +INFERENCE_HANDLERS_SEMAPHORE = Semaphore(1) MODEL_DOWNLOADED = False -MODEL_DL_LOCK = threading.Lock() +MODEL_DL_SEMAPHORE = Semaphore(1) async def prepare_model_artifacts(): @@ -43,7 +43,7 @@ async def prepare_model_artifacts(): if idle.UNLOAD_IDLE: asyncio.create_task(idle.live_check_loop(), name="live_check_loop") else: - _eager_model_dl() + await async_call(_eager_model_dl) logger.info(f"Initializing model from directory:{HF_MODEL_DIR}") # 2. determine correct inference handler inference_handler = get_inference_handler_either_custom_or_default_handler( @@ -54,7 +54,7 @@ async def prepare_model_artifacts(): def _eager_model_dl(): - logger.debug("Model download") + logger.info("Model download") global MODEL_DOWNLOADED from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf # 1. 
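
A condensed sketch of the idle watchdog above, assuming the worker runs under a process manager such as gunicorn that respawns killed workers: once nothing is in flight and the last request finished more than IDLE_TIMEOUT seconds ago, the worker SIGTERMs itself to release its memory.

import asyncio
import os
import signal
import time

IDLE_TIMEOUT = int(os.getenv("IDLE_TIMEOUT", 15))

in_flight = 0           # incremented/decremented around each request
last_end = time.time()  # timestamp of the last completed request


async def idle_watchdog() -> None:
    while True:
        await asyncio.sleep(max(IDLE_TIMEOUT // 5, 1))
        if in_flight == 0 and time.time() - last_end > IDLE_TIMEOUT:
            # Nothing running and idle for too long: let the master respawn us.
            os.kill(os.getpid(), signal.SIGTERM)
            return
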
check if model artifacts available in HF_MODEL_DIR @@ -83,7 +83,8 @@ def _eager_model_dl(): HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}""" ) else: - logger.debug("Model already downloaded in %s", HF_MODEL_DIR) + logger.info("Model already downloaded in %s", HF_MODEL_DIR) + logger.info("Model successfully downloaded") MODEL_DOWNLOADED = True @@ -104,14 +105,19 @@ async def metrics(request): async def predict(request): - with idle.request_witnesses(): + total_start_time = perf_counter() + + async with idle.request_witnesses(): logger.debug("Received request, scope %s", request.scope) global INFERENCE_HANDLERS if not MODEL_DOWNLOADED: - with MODEL_DL_LOCK: - await asyncio.to_thread(_eager_model_dl) + async with MODEL_DL_SEMAPHORE: + if not MODEL_DOWNLOADED: + logger.info("Model dl semaphore acquired") + await async_call(_eager_model_dl) + logger.info("Model dl semaphore released") try: task = request.path_params.get("task", HF_TASK) # extracts content from request @@ -152,28 +158,27 @@ async def predict(request): task = "sentence-embeddings" inference_handler = INFERENCE_HANDLERS.get(task) if not inference_handler: - with INFERENCE_HANDLERS_LOCK: + async with INFERENCE_HANDLERS_SEMAPHORE: if task not in INFERENCE_HANDLERS: inference_handler = get_inference_handler_either_custom_or_default_handler( HF_MODEL_DIR, task=task) INFERENCE_HANDLERS[task] = inference_handler else: inference_handler = INFERENCE_HANDLERS[task] - # tracks request time - start_time = perf_counter() if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler): deserialized_body['handler_params'] = { 'request': request } - logger.debug("Calling inference handler prediction routine") + logger.info("Calling inference handler prediction routine") # run async not blocking call - pred = await async_handler_call(inference_handler, deserialized_body) + pred = await async_call(inference_handler, deserialized_body) # log request time + end_time = perf_counter() logger.info( - f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms" + f"POST {request.url.path} Total request duration: {(end_time-total_start_time) *1000:.2f} ms" ) if should_discard_left() and pred is None:
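
Finally, a minimal sketch of the guarded cold-start download in predict() above, assuming an anyio Semaphore and a blocking download_model() callable (hypothetical name): the flag is re-checked after acquiring the semaphore, so concurrent cold-start requests trigger a single download while the event loop stays responsive.

import anyio
from anyio import Semaphore

MODEL_READY = False
_DL_GUARD = Semaphore(1)


async def ensure_model(download_model) -> None:
    global MODEL_READY
    if MODEL_READY:
        return
    async with _DL_GUARD:
        if not MODEL_READY:
            # The blocking Hub download runs in a worker thread.
            await anyio.to_thread.run_sync(download_model)
            MODEL_READY = True
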