11import asyncio
22import base64
33import os
4- import threading
54from pathlib import Path
65from time import perf_counter
76
87import orjson
8+ from anyio import Semaphore
99from starlette .applications import Starlette
1010from starlette .responses import PlainTextResponse , Response
1111from starlette .routing import Route
1212
1313from huggingface_inference_toolkit import idle
14- from huggingface_inference_toolkit .async_utils import MAX_CONCURRENT_THREADS , MAX_THREADS_GUARD , async_handler_call
14+ from huggingface_inference_toolkit .async_utils import MAX_CONCURRENT_THREADS , MAX_THREADS_GUARD , async_call
1515from huggingface_inference_toolkit .const import (
1616 HF_FRAMEWORK ,
1717 HF_HUB_TOKEN ,
3232from huggingface_inference_toolkit .vertex_ai_utils import _load_repository_from_gcs
3333
3434INFERENCE_HANDLERS = {}
35- INFERENCE_HANDLERS_LOCK = threading . Lock ( )
35+ INFERENCE_HANDLERS_SEMAPHORE = Semaphore ( 1 )
3636MODEL_DOWNLOADED = False
37- MODEL_DL_LOCK = threading . Lock ( )
37+ MODEL_DL_SEMAPHORE = Semaphore ( 1 )
3838
3939
4040async def prepare_model_artifacts ():
@@ -43,7 +43,7 @@ async def prepare_model_artifacts():
4343 if idle .UNLOAD_IDLE :
4444 asyncio .create_task (idle .live_check_loop (), name = "live_check_loop" )
4545 else :
46- _eager_model_dl ( )
46+ await async_call ( _eager_model_dl )
4747 logger .info (f"Initializing model from directory:{ HF_MODEL_DIR } " )
4848 # 2. determine correct inference handler
4949 inference_handler = get_inference_handler_either_custom_or_default_handler (
@@ -54,7 +54,7 @@ async def prepare_model_artifacts():
5454
5555
5656def _eager_model_dl ():
57- logger .debug ("Model download" )
57+ logger .info ("Model download" )
5858 global MODEL_DOWNLOADED
5959 from huggingface_inference_toolkit .heavy_utils import load_repository_from_hf
6060 # 1. check if model artifacts available in HF_MODEL_DIR
@@ -83,7 +83,8 @@ def _eager_model_dl():
8383 HF_MODEL_DIR: { HF_MODEL_DIR } and HF_MODEL_ID:{ HF_MODEL_ID } """
8484 )
8585 else :
86- logger .debug ("Model already downloaded in %s" , HF_MODEL_DIR )
86+ logger .info ("Model already downloaded in %s" , HF_MODEL_DIR )
87+ logger .info ("Model successfully downloaded" )
8788 MODEL_DOWNLOADED = True
8889
8990
@@ -104,14 +105,19 @@ async def metrics(request):
104105
105106
106107async def predict (request ):
107- with idle .request_witnesses ():
108+ total_start_time = perf_counter ()
109+
110+ async with idle .request_witnesses ():
108111 logger .debug ("Received request, scope %s" , request .scope )
109112
110113 global INFERENCE_HANDLERS
111114
112115 if not MODEL_DOWNLOADED :
113- with MODEL_DL_LOCK :
114- await asyncio .to_thread (_eager_model_dl )
116+ async with MODEL_DL_SEMAPHORE :
117+ if not MODEL_DOWNLOADED :
118+ logger .info ("Model dl semaphore acquired" )
119+ await async_call (_eager_model_dl )
120+ logger .info ("Model dl semaphore released" )
115121 try :
116122 task = request .path_params .get ("task" , HF_TASK )
117123 # extracts content from request
@@ -152,28 +158,27 @@ async def predict(request):
152158 task = "sentence-embeddings"
153159 inference_handler = INFERENCE_HANDLERS .get (task )
154160 if not inference_handler :
155- with INFERENCE_HANDLERS_LOCK :
161+ async with INFERENCE_HANDLERS_SEMAPHORE :
156162 if task not in INFERENCE_HANDLERS :
157163 inference_handler = get_inference_handler_either_custom_or_default_handler (
158164 HF_MODEL_DIR , task = task )
159165 INFERENCE_HANDLERS [task ] = inference_handler
160166 else :
161167 inference_handler = INFERENCE_HANDLERS [task ]
162- # tracks request time
163- start_time = perf_counter ()
164168
165169 if should_discard_left () and isinstance (inference_handler , HuggingFaceHandler ):
166170 deserialized_body ['handler_params' ] = {
167171 'request' : request
168172 }
169173
170- logger .debug ("Calling inference handler prediction routine" )
174+ logger .info ("Calling inference handler prediction routine" )
171175 # run async not blocking call
172- pred = await async_handler_call (inference_handler , deserialized_body )
176+ pred = await async_call (inference_handler , deserialized_body )
173177
174178 # log request time
179+ end_time = perf_counter ()
175180 logger .info (
176- f"POST { request .url .path } | Duration : { (perf_counter () - start_time ) * 1000 :.2f} ms"
181+ f"POST { request .url .path } Total request duration : { (end_time - total_start_time ) * 1000 :.2f} ms"
177182 )
178183
179184 if should_discard_left () and pred is None :