
Commit c7492cc: LAST CHECKPOINT
1 parent ea3c337

4 files changed: +132 -56 lines changed

src/databricks/sql/client.py
Lines changed: 21 additions & 30 deletions

@@ -31,6 +31,8 @@
     transform_paramstyle,
     ColumnTable,
     ColumnQueue,
+    concat_chunked_tables,
+    merge_columnar,
 )
 from databricks.sql.parameters.native import (
     DbsqlParameterBase,
@@ -1454,36 +1456,25 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
+        partial_result_chunks = [results]
 
+        TOTAL_SIZE = results.num_rows
         while (
             n_remaining_rows > 0
             and not self.has_been_closed_server_side
             and self.has_more_rows
         ):
+            print(f"TOTAL DATA ROWS {TOTAL_SIZE}")
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = pyarrow.concat_tables([results, partial_results])
+            partial_result_chunks.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
+            TOTAL_SIZE += partial_results.num_rows
 
-        return results
-
-    def merge_columnar(self, result1, result2):
-        """
-        Function to merge / combining the columnar results into a single result
-        :param result1:
-        :param result2:
-        :return:
-        """
-
-        if result1.column_names != result2.column_names:
-            raise ValueError("The columns in the results don't match")
-
-        merged_result = [
-            result1.column_table[i] + result2.column_table[i]
-            for i in range(result1.num_columns)
-        ]
-        return ColumnTable(merged_result, result1.column_names)
+        return concat_chunked_tables(partial_result_chunks)
+
+
 
     def fetchmany_columnar(self, size: int):
         """
@@ -1504,7 +1495,7 @@ def fetchmany_columnar(self, size: int):
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = self.merge_columnar(results, partial_results)
+            results = merge_columnar(results, partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
@@ -1514,20 +1505,20 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
-
-        print("Server side has more rows", self.has_more_rows)
 
+        partial_result_chunks = [results]
+        print("Server side has more rows", self.has_more_rows)
+        TOTAL_SIZE = results.num_rows
+
         while not self.has_been_closed_server_side and self.has_more_rows:
-            print(f"RESULT SIZE TOTAL {results.num_rows}")
+            print(f"TOTAL DATA ROWS {TOTAL_SIZE}")
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            if isinstance(results, ColumnTable) and isinstance(
-                partial_results, ColumnTable
-            ):
-                results = self.merge_columnar(results, partial_results)
-            else:
-                results = pyarrow.concat_tables([results, partial_results])
+            partial_result_chunks.append(partial_results)
             self._next_row_index += partial_results.num_rows
+            TOTAL_SIZE += partial_results.num_rows
+
+        results = concat_chunked_tables(partial_result_chunks)
 
         # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
         # Valid only for metadata commands result set
@@ -1547,7 +1538,7 @@ def fetchall_columnar(self):
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            results = self.merge_columnar(results, partial_results)
+            results = merge_columnar(results, partial_results)
             self._next_row_index += partial_results.num_rows
 
         return results
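
The fetch changes above all follow one pattern: instead of calling pyarrow.concat_tables pairwise inside the loop, which re-copies every previously fetched row on each iteration, partial results are buffered in a list and concatenated once at the end. A minimal sketch of the pattern in plain PyArrow; stream_of_batches is a hypothetical stand-in for the driver's _fill_results_buffer()/next_n_rows() loop, not part of the connector:

import pyarrow as pa

def stream_of_batches(n_chunks=3, rows_per_chunk=4):
    # Hypothetical stand-in for successive server-side result batches.
    for i in range(n_chunks):
        start = i * rows_per_chunk
        yield pa.table({"id": list(range(start, start + rows_per_chunk))})

# Anti-pattern: table = pa.concat_tables([table, batch]) inside the loop
# copies all rows collected so far on every iteration (quadratic cost).
# Pattern adopted by this commit: buffer chunks, concatenate once (linear).
chunks = []
for batch in stream_of_batches():
    chunks.append(batch)
result = pa.concat_tables(chunks)
print(result.num_rows)  # 12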

src/databricks/sql/cloudfetch/downloader.py
Lines changed: 69 additions & 20 deletions

@@ -9,6 +9,7 @@
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
+from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +71,7 @@ def __init__(
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self._http_client = DatabricksHttpClient.get_instance()
 
     def run(self) -> DownloadedFile:
         """
@@ -89,27 +91,20 @@ def run(self) -> DownloadedFile:
         ResultSetDownloadHandler._validate_link(
             self.link, self.settings.link_expiry_buffer_secs
         )
-
-        session = requests.Session()
-        session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
-        session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
-
-        try:
+
+        with self._http_client.execute(
+            method=HttpMethod.GET,
+            url=self.link.fileLink,
+            timeout=self.settings.download_timeout,
+            verify=self._ssl_options.tls_verify,
+            headers=self.link.httpHeaders
+        ) as response:
             print_text = [
 
             ]
-            start_time = time.time()
-            # Get the file via HTTP request
-            response = session.get(
-                self.link.fileLink,
-                timeout=self.settings.download_timeout,
-                verify=self._ssl_options.tls_verify,
-                headers=self.link.httpHeaders
-                # TODO: Pass cert from `self._ssl_options`
-            )
+
             response.raise_for_status()
-            end_time = time.time()
-            print_text.append(f"Downloaded file in {end_time - start_time} seconds")
+
             # Save (and decompress if needed) the downloaded file
             compressed_data = response.content
             decompressed_data = (
@@ -144,9 +139,63 @@ def run(self) -> DownloadedFile:
                 self.link.startRowOffset,
                 self.link.rowCount,
             )
-        finally:
-            if session:
-                session.close()
+        # session = requests.Session()
+        # session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
+        # session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
+
+        # try:
+        #     print_text = [
+
+        #     ]
+        #     start_time = time.time()
+        #     # Get the file via HTTP request
+        #     response = session.get(
+        #         self.link.fileLink,
+        #         timeout=self.settings.download_timeout,
+        #         verify=self._ssl_options.tls_verify,
+        #         headers=self.link.httpHeaders
+        #         # TODO: Pass cert from `self._ssl_options`
+        #     )
+        #     response.raise_for_status()
+        #     end_time = time.time()
+        #     print_text.append(f"Downloaded file in {end_time - start_time} seconds")
+        #     # Save (and decompress if needed) the downloaded file
+        #     compressed_data = response.content
+        #     decompressed_data = (
+        #         ResultSetDownloadHandler._decompress_data(compressed_data)
+        #         if self.settings.is_lz4_compressed
+        #         else compressed_data
+        #     )
+
+        #     # The size of the downloaded file should match the size specified from TSparkArrowResultLink
+        #     if len(decompressed_data) != self.link.bytesNum:
+        #         logger.debug(
+        #             "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format(
+        #                 len(decompressed_data), self.link.bytesNum
+        #             )
+        #         )
+
+        #     logger.debug(
+        #         "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format(
+        #             self.link.startRowOffset, self.link.rowCount
+        #         )
+        #     )
+
+        #     print_text.append(
+        #         f"Downloaded file startRowOffset - {self.link.startRowOffset} - rowCount - {self.link.rowCount}"
+        #     )
+
+        #     for text in print_text:
+        #         print(text)
+
+        #     return DownloadedFile(
+        #         decompressed_data,
+        #         self.link.startRowOffset,
+        #         self.link.rowCount,
+        #     )
+        # finally:
+        #     if session:
+        #         session.close()
 
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
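
The downloader now routes requests through a process-wide DatabricksHttpClient singleton instead of building a fresh requests.Session (with retry adapters) per download. The sketch below is not the driver's actual client, only a simplified stand-in (SharedHttpClient is a hypothetical name) showing the shared-session-plus-context-manager shape the new code path relies on:

import threading
from contextlib import contextmanager

import requests

class SharedHttpClient:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        # Lock so concurrent downloader threads share a single session.
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.session = requests.Session()

    @contextmanager
    def execute(self, method, url, **kwargs):
        # Reuses pooled connections instead of creating a Session per download.
        response = self.session.request(method, url, **kwargs)
        try:
            yield response
        finally:
            response.close()

# Usage mirroring the new downloader code path (url is a placeholder):
# with SharedHttpClient.get_instance().execute("GET", url, timeout=60) as resp:
#     resp.raise_for_status()
#     data = resp.content

One upshot of this design: because the session is shared, per-request state such as headers and TLS verification must be passed on every execute() call, which is exactly what the new run() does with self.link.httpHeaders and self._ssl_options.tls_verify.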

src/databricks/sql/common/http.py
Lines changed: 4 additions & 1 deletion

@@ -7,7 +7,7 @@
 from contextlib import contextmanager
 from typing import Generator
 import logging
-
+import time
 logger = logging.getLogger(__name__)
 
 
@@ -70,7 +70,10 @@ def execute(
         logger.info("Executing HTTP request: %s with url: %s", method.value, url)
         response = None
         try:
+            start_time = time.time()
             response = self.session.request(method.value, url, **kwargs)
+            end_time = time.time()
+            print(f"Downloaded file in {end_time - start_time} seconds")
             yield response
         except Exception as e:
             logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e)
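
The start_time/end_time instrumentation above prints straight to stdout, which reads as checkpoint scaffolding. The same measurement can instead be routed through the module's existing logger; a minimal sketch, where timed is a hypothetical helper and not part of the driver:

import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def timed(label):
    # Logs the wall-clock duration of the wrapped block at DEBUG level.
    start = time.time()
    try:
        yield
    finally:
        logger.debug("%s took %.3f seconds", label, time.time() - start)

# Usage (session and url are placeholders):
# with timed("download"):
#     response = session.request("GET", url)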

src/databricks/sql/utils.py
Lines changed: 38 additions & 5 deletions

@@ -137,6 +137,11 @@ def __eq__(self, other):
         )
 
 
+class ArrowStreamTable:
+    def __init__(self, arrow_stream, num_rows):
+        self.arrow_stream = arrow_stream
+        self.num_rows = num_rows
+
 class ColumnQueue(ResultSetQueue):
     def __init__(self, column_table: ColumnTable):
         self.column_table = column_table
@@ -263,11 +268,12 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             return self._create_empty_table()
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
         results = self.table.slice(0, 0)
+        partial_result_chunks = [results]
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
             table_slice = self.table.slice(self.table_row_index, length)
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
 
             # Replace current table with the next table if we are at the end of the current table
@@ -277,7 +283,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
             num_rows -= table_slice.num_rows
 
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
-        return results
+        return concat_chunked_tables(partial_result_chunks)
 
     def remaining_rows(self) -> "pyarrow.Table":
         """
@@ -290,19 +296,19 @@ def remaining_rows(self) -> "pyarrow.Table":
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
         results = self.table.slice(0, 0)
-
+        partial_result_chunks = [results]
         print("remaining_rows call")
         print(f"self.table.num_rows - {self.table.num_rows}")
         while self.table:
             table_slice = self.table.slice(
                 self.table_row_index, self.table.num_rows - self.table_row_index
             )
-            results = pyarrow.concat_tables([results, table_slice])
+            partial_result_chunks.append(table_slice)
             self.table_row_index += table_slice.num_rows
             self.table = self._create_next_table()
             self.table_row_index = 0
         print(f"results.num_rows - {results.num_rows}")
-        return results
+        return concat_chunked_tables(partial_result_chunks)
 
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         logger.debug(
@@ -771,3 +777,30 @@ def _create_python_tuple(t_col_value_wrapper):
             result[i] = None
 
     return tuple(result)
+
+
+def concat_chunked_tables(tables: List[Union["pyarrow.Table", ColumnTable]]) -> Union["pyarrow.Table", ColumnTable]:
+    if isinstance(tables[0], ColumnTable):
+        base_table = tables[0]
+        for table in tables[1:]:
+            base_table = merge_columnar(base_table, table)
+        return base_table
+    else:
+        return pyarrow.concat_tables(tables)
+
+
+def merge_columnar(result1: ColumnTable, result2: ColumnTable) -> ColumnTable:
+    """
+    Function to merge / combining the columnar results into a single result
+    :param result1:
+    :param result2:
+    :return:
+    """
+
+    if result1.column_names != result2.column_names:
+        raise ValueError("The columns in the results don't match")
+
+    merged_result = [
+        result1.column_table[i] + result2.column_table[i]
+        for i in range(result1.num_columns)
+    ]
+    return ColumnTable(merged_result, result1.column_names)
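
To make the relocated helper concrete, here is a self-contained sketch of merge_columnar at work. MiniColumnTable is a stripped-down, hypothetical stand-in for the driver's ColumnTable (a list of per-column value lists plus column names), used here only so the example runs on its own:

class MiniColumnTable:
    def __init__(self, column_table, column_names):
        self.column_table = column_table  # list of per-column value lists
        self.column_names = column_names

    @property
    def num_columns(self):
        return len(self.column_names)

def merge_columnar(result1, result2):
    # Same logic as the helper this commit moves into utils.py.
    if result1.column_names != result2.column_names:
        raise ValueError("The columns in the results don't match")
    merged = [
        result1.column_table[i] + result2.column_table[i]
        for i in range(result1.num_columns)
    ]
    return MiniColumnTable(merged, result1.column_names)

a = MiniColumnTable([[1, 2], ["x", "y"]], ["id", "val"])
b = MiniColumnTable([[3], ["z"]], ["id", "val"])
merged = merge_columnar(a, b)
print(merged.column_table)  # [[1, 2, 3], ['x', 'y', 'z']]

concat_chunked_tables simply dispatches on the first chunk's type: ColumnTable chunks are folded pairwise through merge_columnar, while PyArrow tables go through a single pyarrow.concat_tables call.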
