From 5bed92c27df7c187a602f72d02293362aa298159 Mon Sep 17 00:00:00 2001 From: Michal Zukowski Date: Thu, 4 Jun 2020 12:21:04 +0200 Subject: [PATCH] Limit number of download workers --- b2sdk/api.py | 14 ++- b2sdk/transfer/inbound/download_manager.py | 24 ++++- b2sdk/transfer/inbound/downloader/parallel.py | 93 +++++++++++++------ test/v0/test_bucket.py | 2 + test/v1/test_bucket.py | 2 + 5 files changed, 104 insertions(+), 31 deletions(-) diff --git a/b2sdk/api.py b/b2sdk/api.py index 8ffde1df8..b1c6afeb7 100644 --- a/b2sdk/api.py +++ b/b2sdk/api.py @@ -43,17 +43,21 @@ def url_for_api(info, api_name): class Services(object): """ Gathers objects that provide high level logic over raw api usage. """ - def __init__(self, session, max_upload_workers=10, max_copy_workers=10): + def __init__( + self, session, max_upload_workers=10, max_copy_workers=10, max_download_workers=None + ): """ Initialize Services object using given session. :param b2sdk.v1.Session session: :param int max_upload_workers: a number of upload threads :param int max_copy_workers: a number of copy threads + :param int max_download_workers: a maximum number of download threads. + If ``None`` then :class:`~b2sdk.v1.DownloadManager` ``4 * DEFAULT_MAX_STREAMS`` is used. """ self.session = session self.large_file = LargeFileServices(self) - self.download_manager = DownloadManager(self) + self.download_manager = DownloadManager(self, max_download_workers=max_download_workers) self.upload_manager = UploadManager(self, max_upload_workers=max_upload_workers) self.copy_manager = CopyManager(self, max_copy_workers=max_copy_workers) self.emerger = Emerger(self) @@ -89,7 +93,8 @@ def __init__( cache=None, raw_api=None, max_upload_workers=10, - max_copy_workers=10 + max_copy_workers=10, + max_download_workers=None, ): """ Initialize the API using the given account info. 
@@ -116,12 +121,15 @@ def __init__( :param int max_upload_workers: a number of upload threads, default is 10 :param int max_copy_workers: a number of copy threads, default is 10 + :param int max_download_workers: a maximum number of download threads. + If ``None`` then :class:`~b2sdk.v1.DownloadManager` ``4 * DEFAULT_MAX_STREAMS`` is used. """ self.session = B2Session(account_info=account_info, cache=cache, raw_api=raw_api) self.services = Services( self.session, max_upload_workers=max_upload_workers, max_copy_workers=max_copy_workers, + max_download_workers=max_download_workers, ) @property diff --git a/b2sdk/transfer/inbound/download_manager.py b/b2sdk/transfer/inbound/download_manager.py index 22cc00538..69ce99d18 100644 --- a/b2sdk/transfer/inbound/download_manager.py +++ b/b2sdk/transfer/inbound/download_manager.py @@ -10,6 +10,9 @@ import logging import six +import threading + +from contextlib import contextmanager from b2sdk.download_dest import DownloadDestProgressWrapper from b2sdk.progress import DoNothingProgressListener @@ -30,6 +33,17 @@ logger = logging.getLogger(__name__) +class ProtectedSemaphore(object): + def __init__(self, semaphore): + self._lock = threading.RLock() + self._semaphore = semaphore + + @contextmanager + def get_semaphore(self): + with self._lock: + yield self._semaphore + + @six.add_metaclass(B2TraceMetaAbstract) class DownloadManager(object): """ @@ -46,18 +60,26 @@ class DownloadManager(object): MIN_CHUNK_SIZE = 8192 # ~1MB file will show ~1% progress increment MAX_CHUNK_SIZE = 1024**2 - def __init__(self, services): + def __init__(self, services, max_download_workers=None): """ Initialize the DownloadManager using the given services object. :param b2sdk.v1.Services services: + :param int max_download_workers: a maximum number of download threads. + If ``None`` then ``4 * DEFAULT_MAX_STREAMS`` is used. 
""" self.services = services + self.max_download_workers = max_download_workers or 4 * self.DEFAULT_MAX_STREAMS + self.max_workers_semaphore = ProtectedSemaphore( + threading.BoundedSemaphore(self.max_download_workers) + ) + self.strategies = [ ParallelDownloader( max_streams=self.DEFAULT_MAX_STREAMS, min_part_size=self.DEFAULT_MIN_PART_SIZE, + protected_semaphore=self.max_workers_semaphore, min_chunk_size=self.MIN_CHUNK_SIZE, max_chunk_size=self.MAX_CHUNK_SIZE, ), diff --git a/b2sdk/transfer/inbound/downloader/parallel.py b/b2sdk/transfer/inbound/downloader/parallel.py index 9adfffb31..f62a77f4e 100644 --- a/b2sdk/transfer/inbound/downloader/parallel.py +++ b/b2sdk/transfer/inbound/downloader/parallel.py @@ -43,13 +43,14 @@ class ParallelDownloader(AbstractDownloader): # FINISH_HASHING_BUFFER_SIZE = 1024**2 - def __init__(self, max_streams, min_part_size, *args, **kwargs): + def __init__(self, max_streams, min_part_size, protected_semaphore, *args, **kwargs): """ :param max_streams: maximum number of simultaneous streams :param min_part_size: minimum amount of data a single stream will retrieve, in bytes """ self.max_streams = max_streams self.min_part_size = min_part_size + self.protected_semaphore = protected_semaphore super(ParallelDownloader, self).__init__(*args, **kwargs) def is_suitable(self, metadata, progress_listener): @@ -132,31 +133,59 @@ def _finish_hashing(self, first_part, file, hasher, content_length): def _get_parts( self, response, session, writer, hasher, first_part, parts_to_download, chunk_size ): - stream = FirstPartDownloaderThread( - response, - hasher, - session, - writer, - first_part, - chunk_size, - ) - stream.start() - streams = [stream] - - for part in parts_to_download: - stream = NonHashingDownloaderThread( - response.request.url, - session, - writer, - part, - chunk_size, - ) - stream.start() - streams.append(stream) + with self.protected_semaphore.get_semaphore() as semaphore: + semaphore.acquire() + try: + stream = 
FirstPartDownloaderThread( + response, + hasher, + session, + writer, + first_part, + chunk_size, + semaphore, + ) + stream.start() + except Exception: + semaphore.release() + raise + + streams = [stream] + + for part in parts_to_download: + semaphore.acquire() + try: + stream = NonHashingDownloaderThread( + response.request.url, + session, + writer, + part, + chunk_size, + semaphore, + ) + stream.start() + except Exception: + semaphore.release() + raise + streams.append(stream) for stream in streams: stream.join() +class ClosableQueue(queue.Queue): + def __init__(self, *args, **kwargs): + super(ClosableQueue, self).__init__(*args, **kwargs) + self._closed = False + + def put(self, *args, **kwargs): + if self._closed: + raise RuntimeError('queue closed') + return super(ClosableQueue, self).put(*args, **kwargs) + + def close(self): + self._closed = True + + class WriterThread(threading.Thread): """ A thread responsible for keeping a queue of data chunks to write to a file-like object and for actually writing them down. 
@@ -183,7 +212,7 @@ class WriterThread(threading.Thread): def __init__(self, file, max_queue_depth): self.file = file - self.queue = queue.Queue(max_queue_depth) + self.queue = ClosableQueue(max_queue_depth) self.total = 0 super(WriterThread, self).__init__() @@ -204,25 +233,35 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.queue.put((True, None, None)) + # any thread trying to put something on the queue would fail with RuntimeError + self.queue.close() self.join() class AbstractDownloaderThread(threading.Thread): - def __init__(self, session, writer, part_to_download, chunk_size): + def __init__(self, session, writer, part_to_download, chunk_size, semaphore): """ :param session: raw_api wrapper :param writer: where to write data :param part_to_download: PartToDownload object :param chunk_size: internal buffer size to use for writing and hashing + :param semaphore: already acquired semaphore that downloader thread has to release on finish """ self.session = session self.writer = writer self.part_to_download = part_to_download self.chunk_size = chunk_size + self.semaphore = semaphore super(AbstractDownloaderThread, self).__init__() - @abstractmethod def run(self): + try: + self.run_download() + finally: + self.semaphore.release() + + @abstractmethod + def run_download(self): pass @@ -236,7 +275,7 @@ def __init__(self, response, hasher, *args, **kwargs): self.hasher = hasher super(FirstPartDownloaderThread, self).__init__(*args, **kwargs) - def run(self): + def run_download(self): writer_queue_put = self.writer.queue.put hasher_update = self.hasher.update first_offset = self.part_to_download.local_range.start @@ -291,7 +330,7 @@ def __init__(self, url, *args, **kwargs): self.url = url super(NonHashingDownloaderThread, self).__init__(*args, **kwargs) - def run(self): + def run_download(self): writer_queue_put = self.writer.queue.put start_range = self.part_to_download.local_range.start actual_part_size = self.part_to_download.local_range.size() 
diff --git a/test/v0/test_bucket.py b/test/v0/test_bucket.py index 0dc923c04..4584fae7e 100644 --- a/test/v0/test_bucket.py +++ b/test/v0/test_bucket.py @@ -743,6 +743,7 @@ def setUp(self): force_chunk_size=2, max_streams=999, min_part_size=2, + protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore, ) ] @@ -792,5 +793,6 @@ def setUp(self): force_chunk_size=3, max_streams=2, min_part_size=2, + protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore, ) ] diff --git a/test/v1/test_bucket.py b/test/v1/test_bucket.py index 8f71157a4..e8e1b840f 100644 --- a/test/v1/test_bucket.py +++ b/test/v1/test_bucket.py @@ -843,6 +843,7 @@ def setUp(self): force_chunk_size=2, max_streams=999, min_part_size=2, + protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore, ) ] @@ -892,5 +893,6 @@ def setUp(self): force_chunk_size=3, max_streams=2, min_part_size=2, + protected_semaphore=self.bucket.api.services.download_manager.max_workers_semaphore, ) ]