
Commit 66e920c

fixes: coroutine and threading mix caused blocking bugs
Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
1 parent 54d2596 commit 66e920c
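
The class of bug this commit targets is easy to reproduce: a synchronous lock or blocking call executed directly inside a coroutine stalls the entire event loop. A minimal standalone demonstration, with hypothetical names, not toolkit code:

    import asyncio
    import threading
    import time

    LOCK = threading.Lock()  # a sync primitive, invisible to the event loop

    async def bad(n: int) -> None:
        with LOCK:         # held across blocking work: the loop cannot switch tasks
            time.sleep(1)  # blocks the only thread running the event loop
            print(f"request {n} done")

    async def main() -> None:
        t0 = time.perf_counter()
        await asyncio.gather(bad(1), bad(2))
        # ~2s total: the two "concurrent" requests ran strictly one after the other
        print(f"elapsed: {time.perf_counter() - t0:.1f}s")

    asyncio.run(main())

The remedy applied file by file below is to swap `threading.Lock` for `anyio.Semaphore` (awaitable, loop-friendly) and to push blocking work onto a worker thread via `anyio.to_thread.run_sync`.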

File tree: 4 files changed, +65 −35 lines

  src/huggingface_inference_toolkit/async_utils.py
  src/huggingface_inference_toolkit/handler.py
  src/huggingface_inference_toolkit/idle.py
  src/huggingface_inference_toolkit/webservice_starlette.py


src/huggingface_inference_toolkit/async_utils.py
Lines changed: 6 additions & 2 deletions

@@ -5,6 +5,8 @@
 from anyio import Semaphore
 from typing_extensions import ParamSpec
 
+from huggingface_inference_toolkit.logging import logger
+
 # To not have too many threads running (which could happen on too many concurrent
 # requests, we limit it with a semaphore.
 MAX_CONCURRENT_THREADS = 1
@@ -15,6 +17,8 @@
 
 # moves blocking call to asyncio threadpool limited to 1 to not overload the system
 # REF: https://stackoverflow.com/a/70929141
-async def async_handler_call(handler: Callable[P, T], body: Dict[str, Any]) -> T:
+async def async_call(handler: Callable[P, T], *args, **kwargs) -> T:
+    logger.info("Setting blocking call to async handler")
     async with MAX_THREADS_GUARD:
-        return await anyio.to_thread.run_sync(functools.partial(handler, body))
+        logger.info("Async call semaphore passed")
+        return await anyio.to_thread.run_sync(handler, *args, **kwargs)
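
Taken in isolation, the pattern `async_call` implements looks like the sketch below: the semaphore is awaited on the event loop, then the blocking handler runs on anyio's thread pool, so other coroutines keep making progress. `slow_handler` is an invented stand-in:

    import time
    import anyio
    from anyio import Semaphore

    MAX_THREADS_GUARD = Semaphore(1)  # at most one blocking call in flight

    def slow_handler(payload: dict) -> dict:
        time.sleep(1)  # stands in for blocking model inference
        return {"echo": payload}

    async def async_call(handler, *args):
        async with MAX_THREADS_GUARD:  # awaitable: never stalls the loop
            return await anyio.to_thread.run_sync(handler, *args)

    async def main():
        print(await async_call(slow_handler, {"inputs": "hi"}))

    anyio.run(main)

One caveat worth noting: `anyio.to_thread.run_sync` forwards only positional arguments to the target function, so the `**kwargs` accepted by the new signature would need something like `functools.partial(handler, **kwargs)` to actually reach the handler.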

src/huggingface_inference_toolkit/handler.py
Lines changed: 7 additions & 0 deletions

@@ -1,5 +1,6 @@
 import os
 from pathlib import Path
+from time import perf_counter
 from typing import Any, Dict, Literal, Optional, Union
 
 from huggingface_inference_toolkit import logging
@@ -37,7 +38,13 @@ def __call__(self, data: Dict[str, Any]):
         :data: (obj): the raw request body data.
         :return: prediction output
         """
+        start = perf_counter()
+        pred = self._timed_call(data)
+        end = perf_counter()
+        logger.info("Inference duration: %.2f ms", (end - start) * 1000)
+        return pred
 
+    def _timed_call(self, data: Dict[str, Any]):
         logger.debug("Calling HF default handler")
         # import as late as possible to reduce the footprint
         from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
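
The handler change is mechanical: the public `__call__` becomes a thin timing wrapper delegating to a new private `_timed_call`. The shape of the pattern, with hypothetical names:

    from time import perf_counter

    class TimedHandler:
        def __call__(self, data: dict):
            # public entry point: time the real work and report the duration
            start = perf_counter()
            pred = self._timed_call(data)
            print(f"Inference duration: {(perf_counter() - start) * 1000:.2f} ms")
            return pred

        def _timed_call(self, data: dict):
            # stands in for the actual prediction routine
            return {"ok": data}

    print(TimedHandler()({"inputs": "hi"}))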

src/huggingface_inference_toolkit/idle.py
Lines changed: 31 additions & 17 deletions

@@ -5,6 +5,8 @@
 import signal
 import time
 
+from anyio import Semaphore
+
 LOG = logging.getLogger(__name__)
 
 LAST_START = None
@@ -13,13 +15,16 @@
 UNLOAD_IDLE = os.getenv("UNLOAD_IDLE", "").lower() in ("1", "true")
 IDLE_TIMEOUT = int(os.getenv("IDLE_TIMEOUT", 15))
 
+MAX_REQUESTS = 1000
+REQUEST_COUNTER = Semaphore(MAX_REQUESTS)
+
 
 async def live_check_loop():
     global LAST_START, LAST_END
 
     pid = os.getpid()
 
-    LOG.debug("Starting live check loop")
+    LOG.info("Starting live check loop")
     sleep_time = max(int(IDLE_TIMEOUT // 5), 1)
 
     while True:
@@ -31,32 +36,41 @@ async def live_check_loop():
 
         LOG.debug("Checking pid %d activity", pid)
         if not last_start:
+            LOG.debug("No request yet, no need to unload")
+            continue
+
+        if REQUEST_COUNTER.value < MAX_REQUESTS:
+            LOG.info("idle checker: %s requests likely being processed for pid %d, it won't be killed",
+                     MAX_REQUESTS - REQUEST_COUNTER.value, pid)
             continue
         if not last_end or last_start >= last_end:
-            LOG.debug("Request likely being processed for pid %d", pid)
+            LOG.warning("This case should not be possible, semaphore inconsistency? "
+                        "Request likely being processed for pid %d", pid)
             continue
         now = time.time()
         last_request_age = now - last_end
         LOG.debug("Pid %d, last request age %s", pid, last_request_age)
         if last_request_age < IDLE_TIMEOUT:
             LOG.debug("Model recently active")
         else:
-            LOG.debug("Inactive for too long. Leaving live check loop")
+            LOG.info("Idle checker: worker inactive for too long. Leaving live check loop")
             break
-    LOG.debug("Aborting this worker")
+    LOG.info("Aborting this idle worker")
     os.kill(pid, signal.SIGTERM)
 
 
-@contextlib.contextmanager
-def request_witnesses():
-    global LAST_START, LAST_END
-    LOG.debug("Last request start was %s", LAST_START)
-    LOG.debug("Last request end was %s", LAST_END)
-    # Simple assignment, concurrency safe, no need for any lock
-    LAST_START = time.time()
-    LOG.debug("Current request start timestamp %s", LAST_START)
-    try:
-        yield
-    finally:
-        LAST_END = time.time()
-        LOG.debug("Current request end timestamp %s", LAST_END)
+@contextlib.asynccontextmanager
+async def request_witnesses():
+    async with REQUEST_COUNTER:
+        LOG.info("Current request count, %s", REQUEST_COUNTER.value)
+        global LAST_START, LAST_END
+        LOG.info("Last request start was %s", LAST_START)
+        LOG.info("Last request end was %s", LAST_END)
+        # Simple assignment, concurrency safe, no need for any lock
+        LAST_START = time.time()
+        LOG.info("Current request start timestamp %s", LAST_START)
+        try:
+            yield
+        finally:
+            LAST_END = time.time()
+            LOG.info("Current request end timestamp %s", LAST_END)

src/huggingface_inference_toolkit/webservice_starlette.py
Lines changed: 21 additions & 16 deletions

@@ -1,17 +1,17 @@
 import asyncio
 import base64
 import os
-import threading
 from pathlib import Path
 from time import perf_counter
 
 import orjson
+from anyio import Semaphore
 from starlette.applications import Starlette
 from starlette.responses import PlainTextResponse, Response
 from starlette.routing import Route
 
 from huggingface_inference_toolkit import idle
-from huggingface_inference_toolkit.async_utils import MAX_CONCURRENT_THREADS, MAX_THREADS_GUARD, async_handler_call
+from huggingface_inference_toolkit.async_utils import MAX_CONCURRENT_THREADS, MAX_THREADS_GUARD, async_call
 from huggingface_inference_toolkit.const import (
     HF_FRAMEWORK,
     HF_HUB_TOKEN,
@@ -32,9 +32,9 @@
     from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
 
 INFERENCE_HANDLERS = {}
-INFERENCE_HANDLERS_LOCK = threading.Lock()
+INFERENCE_HANDLERS_SEMAPHORE = Semaphore(1)
 MODEL_DOWNLOADED = False
-MODEL_DL_LOCK = threading.Lock()
+MODEL_DL_SEMAPHORE = Semaphore(1)
 
 
 async def prepare_model_artifacts():
@@ -43,7 +43,7 @@ async def prepare_model_artifacts():
     if idle.UNLOAD_IDLE:
         asyncio.create_task(idle.live_check_loop(), name="live_check_loop")
     else:
-        _eager_model_dl()
+        await async_call(_eager_model_dl)
     logger.info(f"Initializing model from directory:{HF_MODEL_DIR}")
     # 2. determine correct inference handler
     inference_handler = get_inference_handler_either_custom_or_default_handler(
@@ -54,7 +54,7 @@
 
 
 def _eager_model_dl():
-    logger.debug("Model download")
+    logger.info("Model download")
     global MODEL_DOWNLOADED
     from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf
     # 1. check if model artifacts available in HF_MODEL_DIR
@@ -83,7 +83,8 @@ def _eager_model_dl():
             HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}"""
         )
     else:
-        logger.debug("Model already downloaded in %s", HF_MODEL_DIR)
+        logger.info("Model already downloaded in %s", HF_MODEL_DIR)
+    logger.info("Model successfully downloaded")
     MODEL_DOWNLOADED = True
 
 
@@ -104,14 +105,19 @@ async def metrics(request):
 
 
 async def predict(request):
-    with idle.request_witnesses():
+    total_start_time = perf_counter()
+
+    async with idle.request_witnesses():
         logger.debug("Received request, scope %s", request.scope)
 
         global INFERENCE_HANDLERS
 
         if not MODEL_DOWNLOADED:
-            with MODEL_DL_LOCK:
-                await asyncio.to_thread(_eager_model_dl)
+            async with MODEL_DL_SEMAPHORE:
+                if not MODEL_DOWNLOADED:
+                    logger.info("Model dl semaphore acquired")
+                    await async_call(_eager_model_dl)
+            logger.info("Model dl semaphore released")
         try:
             task = request.path_params.get("task", HF_TASK)
             # extracts content from request
@@ -152,28 +158,27 @@ async def predict(request):
             task = "sentence-embeddings"
         inference_handler = INFERENCE_HANDLERS.get(task)
         if not inference_handler:
-            with INFERENCE_HANDLERS_LOCK:
+            async with INFERENCE_HANDLERS_SEMAPHORE:
                 if task not in INFERENCE_HANDLERS:
                     inference_handler = get_inference_handler_either_custom_or_default_handler(
                         HF_MODEL_DIR, task=task)
                     INFERENCE_HANDLERS[task] = inference_handler
                 else:
                     inference_handler = INFERENCE_HANDLERS[task]
-        # tracks request time
-        start_time = perf_counter()
 
         if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler):
             deserialized_body['handler_params'] = {
                 'request': request
             }
 
-        logger.debug("Calling inference handler prediction routine")
+        logger.info("Calling inference handler prediction routine")
         # run async not blocking call
-        pred = await async_handler_call(inference_handler, deserialized_body)
+        pred = await async_call(inference_handler, deserialized_body)
 
         # log request time
+        end_time = perf_counter()
         logger.info(
-            f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms"
+            f"POST {request.url.path} Total request duration: {(end_time-total_start_time) *1000:.2f} ms"
         )
 
         if should_discard_left() and pred is None:
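
The download path now follows double-checked initialization: test the flag, await the semaphore, re-test under it, then run the blocking download in a worker thread. The old version acquired a `threading.Lock` synchronously and held it across an `await`, which can block the event loop while waiting. A condensed sketch of the fixed pattern (the download body is a stub):

    import anyio
    from anyio import Semaphore

    MODEL_DOWNLOADED = False
    MODEL_DL_SEMAPHORE = Semaphore(1)

    def _eager_model_dl() -> None:
        # stand-in for the blocking repository download
        global MODEL_DOWNLOADED
        MODEL_DOWNLOADED = True

    async def ensure_model() -> None:
        if not MODEL_DOWNLOADED:                # fast path, no contention
            async with MODEL_DL_SEMAPHORE:      # awaitable, loop-friendly
                if not MODEL_DOWNLOADED:        # re-check: only one downloader runs
                    await anyio.to_thread.run_sync(_eager_model_dl)

    anyio.run(ensure_model)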
