-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy paththread_pool.py
More file actions
346 lines (280 loc) · 10.2 KB
/
thread_pool.py
File metadata and controls
346 lines (280 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
#
# Copyright 2025 Splunk Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A simple thread pool implementation."""
import multiprocessing
import queue
import threading
import traceback
from time import time
import logging
class ThreadPool:
    """A simple thread pool implementation.

    Callables are queued on a bounded work queue and executed by worker
    threads.  A separate admin thread wakes up roughly every
    ``_resize_window`` seconds and grows or shrinks the pool (within the
    configured min/max bounds) according to the queue backlog and the
    idle-thread ratio.
    """

    # Shrink the pool when at least this fraction of workers sits idle.
    _high_watermark = 0.2
    # Seconds between automatic load-based resize checks.
    _resize_window = 10

    def __init__(self, min_size=1, max_size=128, task_queue_size=1024, daemon=True):
        """
        :param min_size: initial/minimum worker count; falsy or non-positive
            values fall back to the CPU count.
        :param max_size: maximum worker count; falsy or non-positive values
            fall back to 8x the CPU count.
        :param task_queue_size: bound of the internal work queue; must be a
            positive integer.
        :param daemon: run worker and admin threads as daemon threads.
        :raises ValueError: if task_queue_size is not a positive integer.
        """
        # Explicit validation instead of `assert`, which is stripped when
        # Python runs with -O.
        if not task_queue_size or task_queue_size <= 0:
            raise ValueError("task_queue_size must be a positive integer")
        if not min_size or min_size <= 0:
            min_size = multiprocessing.cpu_count()
        if not max_size or max_size <= 0:
            max_size = multiprocessing.cpu_count() * 8

        self._min_size = min_size
        self._max_size = max_size
        self._daemon = daemon

        self._work_queue = queue.Queue(task_queue_size)
        # Worker threads are created up front but only started in start().
        self._thrs = []
        for _ in range(min_size):
            thr = threading.Thread(target=self._run)
            self._thrs.append(thr)
        self._admin_queue = queue.Queue()
        self._admin_thr = threading.Thread(target=self._do_admin)
        self._last_resize_time = time()
        self._last_size = min_size
        self._lock = threading.Lock()
        # Number of workers currently executing a job; guarded by _count_lock.
        self._occupied_threads = 0
        self._count_lock = threading.Lock()
        self._started = False

    def start(self):
        """Start threads in the pool.  Idempotent."""
        with self._lock:
            if self._started:
                return
            self._started = True

            for thr in self._thrs:
                thr.daemon = self._daemon
                thr.start()
            self._admin_thr.start()
            logging.info("ThreadPool started.")

    def tear_down(self):
        """Tear down thread pool.  Idempotent."""
        with self._lock:
            if not self._started:
                return
            self._started = False

            # One None sentinel per worker makes each worker loop exit.
            for thr in self._thrs:
                self._work_queue.put(None, block=True)
            # Sentinel for the admin thread.
            self._admin_queue.put(None)

            if not self._daemon:
                logging.info("Wait for threads to stop.")
                for thr in self._thrs:
                    thr.join()
                self._admin_thr.join()
        logging.info("ThreadPool stopped.")

    def enqueue_funcs(self, funcs, block=True):
        """Run jobs in a fire and forget way, no result will be handled over to
        clients.

        :param funcs: tuple/list-like or generator like object, func
            shall be callable
        :param block: block until queue space is available when True.
        """
        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return

        for func in funcs:
            self._work_queue.put(func, block)

    def apply_async(self, func, args=(), kwargs=None, callback=None):
        """Submit a job and return a handle for its eventual result.

        :param func: callable
        :param args: free params
        :param kwargs: named params
        :param callback: when func is done and without exception, call the
            callback
        :return: AsyncResult, clients can poll or wait the result through it;
            None if the pool is not running.
        """
        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return None

        res = AsyncResult(func, args, kwargs, callback)
        self._work_queue.put(res)
        return res

    def apply(self, func, args=(), kwargs=None):
        """Submit a job and block until it completes.

        :param func: callable
        :param args: free params
        :param kwargs: named params
        :return: whatever the func returns; None if the pool is not running.
        """
        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return None

        res = self.apply_async(func, args, kwargs)
        return res.get()

    def size(self):
        """Return the pool size as of the last (re)size operation."""
        return self._last_size

    def resize(self, new_size):
        """Resize the pool size, spawn or destroy threads if necessary.

        :param new_size: target worker count; non-positive values are ignored.
        """
        if new_size <= 0:
            return

        # Best-effort guard: skip resizing while tear_down holds the lock.
        if self._lock.locked() or not self._started:
            logging.info(
                "Try to resize thread pool during the tear down process, do nothing"
            )
            return

        with self._lock:
            self._remove_exited_threads_with_lock()
            size = self._last_size
            self._last_size = new_size
            if new_size > size:
                for _ in range(new_size - size):
                    thr = threading.Thread(target=self._run)
                    thr.daemon = self._daemon
                    thr.start()
                    self._thrs.append(thr)
            elif new_size < size:
                # Excess workers pick up a None sentinel and exit.
                for _ in range(size - new_size):
                    self._work_queue.put(None)
        logging.info("Finished ThreadPool resizing. New size=%d", new_size)

    def _remove_exited_threads_with_lock(self):
        """Join the exited threads last time when resize was called.

        Caller must hold self._lock.
        """
        joined_thrs = set()
        for thr in self._thrs:
            if not thr.is_alive():
                try:
                    if not thr.daemon:
                        thr.join(timeout=0.5)
                    joined_thrs.add(thr.ident)
                except RuntimeError:
                    # Thread was never started or is being joined elsewhere.
                    pass

        if joined_thrs:
            live_thrs = []
            for thr in self._thrs:
                if thr.ident not in joined_thrs:
                    live_thrs.append(thr)
            self._thrs = live_thrs

    def _do_resize_according_to_loads(self):
        """Grow/shrink the pool based on backlog and idle ratio.

        Runs on the admin thread; throttled to once per _resize_window.
        """
        if (
            self._last_resize_time
            and time() - self._last_resize_time < self._resize_window
        ):
            return

        thr_size = self._last_size
        free_thrs = thr_size - self._occupied_threads
        work_size = self._work_queue.qsize()

        logging.debug(
            "current_thr_size=%s, free_thrs=%s, work_size=%s",
            thr_size,
            free_thrs,
            work_size,
        )
        if work_size and work_size > free_thrs:
            # Backlog exceeds idle capacity: double (capped at max).
            if thr_size < self._max_size:
                thr_size = min(thr_size * 2, self._max_size)
                self.resize(thr_size)
        elif free_thrs > 0:
            free = free_thrs * 1.0
            if free / thr_size >= self._high_watermark and free_thrs >= 2:
                # 20 % thrs are idle, tear down half of the idle ones
                thr_size = thr_size - int(free_thrs // 2)
                if thr_size > self._min_size:
                    self.resize(thr_size)
        self._last_resize_time = time()

    def _do_admin(self):
        """Admin thread loop: periodically trigger load-based resizing."""
        admin_q = self._admin_queue
        resize_win = self._resize_window
        while 1:
            try:
                wakeup = admin_q.get(timeout=resize_win + 1)
            except queue.Empty:
                # Periodic timeout: evaluate the load and keep waiting.
                self._do_resize_according_to_loads()
                continue

            if wakeup is None:
                # Sentinel from tear_down().
                break
            else:
                self._do_resize_according_to_loads()
        logging.info(
            # .name instead of the deprecated getName()
            "ThreadPool admin thread=%s stopped.", threading.current_thread().name
        )

    def _run(self):
        """Threads callback func, run forever to handle jobs from the job
        queue."""
        work_queue = self._work_queue
        count_lock = self._count_lock

        while 1:
            logging.debug("Going to get job")
            func = work_queue.get()
            if func is None:
                # Sentinel from tear_down()/resize(): exit this worker.
                break

            if not self._started:
                break

            logging.debug("Going to exec job")
            with count_lock:
                self._occupied_threads += 1

            try:
                func()
            except Exception:
                # Jobs must never kill a worker; log and keep serving.
                logging.error(traceback.format_exc())

            with count_lock:
                self._occupied_threads -= 1

            logging.debug("Done with exec job")
            logging.info("Thread work_queue_size=%d", work_queue.qsize())
        # .name instead of the deprecated getName()
        logging.debug("Worker thread %s stopped.", threading.current_thread().name)
class AsyncResult:
    """Result handle returned by ThreadPool.apply_async.

    A worker thread invokes the instance (``__call__``); the outcome — the
    function's return value on success, or the raised exception object on
    failure — is published on an internal queue which clients consume via
    get()/wait()/ready()/successful().
    """

    def __init__(self, func, args, kwargs, callback):
        """
        :param func: callable to execute.
        :param args: positional arguments for func.
        :param kwargs: keyword arguments for func, or None.
        :param callback: zero-arg callable invoked after func succeeds,
            or None.
        """
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self._callback = callback
        self._q = queue.Queue()

    def __call__(self):
        """Run the wrapped function and publish its outcome.

        On success the return value is queued and the callback (if any) is
        invoked; on failure the exception object itself is queued and the
        callback is skipped.
        """
        try:
            res = self._func(*(self._args or ()), **(self._kwargs or {}))
        except Exception as e:
            self._q.put(e)
            return
        self._q.put(res)
        if self._callback is not None:
            self._callback()

    def get(self, timeout=None):
        """Return the result when it arrives.

        If timeout is not None and the result does not arrive within
        timeout seconds then multiprocessing.TimeoutError is raised. If
        the remote call raised an exception then that exception will be
        reraised by get().

        NOTE: the result is consumed from the internal queue, so get()
        should be called at most once per AsyncResult.
        """
        try:
            res = self._q.get(timeout=timeout)
        except queue.Empty:
            raise multiprocessing.TimeoutError("Timed out")
        if isinstance(res, Exception):
            raise res
        return res

    def wait(self, timeout=None):
        """Wait until the result is available or until timeout seconds pass.

        The result is put back on the queue, so a later get() still works.
        """
        try:
            res = self._q.get(timeout=timeout)
        except queue.Empty:
            pass
        else:
            self._q.put(res)

    def ready(self):
        """Return whether the call has completed."""
        # Bug fix: queue.Queue has no __len__, so the original
        # `len(self._q)` always raised TypeError; use qsize() instead.
        return self._q.qsize() != 0

    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise AssertionError if the result is not ready.
        """
        if not self.ready():
            raise AssertionError("Function is not ready")
        # Peek via get-then-put; safe only when a single client polls.
        res = self._q.get()
        self._q.put(res)
        return not isinstance(res, Exception)