22
33from concurrent .futures import ThreadPoolExecutor , Future
44import threading
5- from typing import Callable , List , Optional , Union , Generic , TypeVar
5+ from typing import Callable , List , Optional , Union , Generic , TypeVar , Tuple , Optional
66
77from databricks .sql .cloudfetch .downloader import (
88 ResultSetDownloadHandler ,
1111)
1212from databricks .sql .exc import Error
1313from databricks .sql .types import SSLOptions
14-
14+ from databricks . sql . telemetry . models . event import StatementType
1515from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
1616
1717logger = logging .getLogger (__name__ )
@@ -46,17 +46,22 @@ def __init__(
4646 lz4_compressed : bool ,
4747 ssl_options : SSLOptions ,
4848 expiry_callback : Optional [Callable [[TSparkArrowResultLink ], None ]] = None ,
49+ session_id_hex : Optional [str ],
50+ statement_id : str ,
51+ chunk_id : int ,
4952 ):
50- self ._pending_links : List [TSparkArrowResultLink ] = []
51- for link in links :
53+ self ._pending_links : List [Tuple [int , TSparkArrowResultLink ]] = []
54+ self .chunk_id = chunk_id
55+ for i , link in enumerate (links , start = chunk_id ):
5256 if link .rowCount <= 0 :
5357 continue
5458 logger .debug (
55- "ResultFileDownloadManager: adding file link, start offset {}, row count: {}" .format (
56- link .startRowOffset , link .rowCount
59+ "ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}" .format (
60+ i , link .startRowOffset , link .rowCount
5761 )
5862 )
59- self ._pending_links .append (link )
63+ self ._pending_links .append ((i , link ))
64+ self .chunk_id += len (links )
6065
6166 self ._max_download_threads : int = max_download_threads
6267
@@ -67,6 +72,8 @@ def __init__(
6772 self ._downloadable_result_settings = DownloadableResultSettings (lz4_compressed )
6873 self ._ssl_options = ssl_options
6974 self ._expiry_callback = expiry_callback
75+ self .session_id_hex = session_id_hex
76+ self .statement_id = statement_id
7077
7178 def get_next_downloaded_file (self , next_row_offset : int ) -> DownloadedFile :
7279 """
@@ -149,15 +156,20 @@ def _schedule_downloads(self):
149156 while (len (self ._download_tasks ) < self ._max_download_threads ) and (
150157 len (self ._pending_links ) > 0
151158 ):
152- link = self ._pending_links .pop (0 )
159+ chunk_id , link = self ._pending_links .pop (0 )
153160 logger .debug (
154- "- start: {}, row count: {}" .format (link .startRowOffset , link .rowCount )
161+ "- chunk: {}, start: {}, row count: {}" .format (
162+ chunk_id , link .startRowOffset , link .rowCount
163+ )
155164 )
156165 handler = ResultSetDownloadHandler (
157166 settings = self ._downloadable_result_settings ,
158167 link = link ,
159168 ssl_options = self ._ssl_options ,
160169 expiry_callback = self ._expiry_callback ,
170+ chunk_id = chunk_id ,
171+ session_id_hex = self .session_id_hex ,
172+ statement_id = self .statement_id ,
161173 )
162174 future = self ._thread_pool .submit (handler .run )
163175 task = TaskWithMetadata (future , link )
@@ -181,7 +193,8 @@ def add_links(self, links: List[TSparkArrowResultLink]):
181193 link .startRowOffset , link .rowCount
182194 )
183195 )
184- self ._pending_links .append (link )
196+ self ._pending_links .append ((self .chunk_id , link ))
197+ self .chunk_id += 1
185198
186199 self ._schedule_downloads ()
187200
@@ -190,4 +203,5 @@ def _shutdown_manager(self):
190203 self ._pending_links = []
191204 self ._download_tasks = []
192205 self ._thread_pool .shutdown (wait = False )
193- self ._download_condition .notify_all ()
206+ with self ._download_condition :
207+ self ._download_condition .notify_all ()
0 commit comments