diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py
index 5460f7ca7..7bd9f99b6 100644
--- a/google/cloud/bigquery/_pandas_helpers.py
+++ b/google/cloud/bigquery/_pandas_helpers.py
@@ -33,6 +33,7 @@
 from google.cloud.bigquery import _pyarrow_helpers
 from google.cloud.bigquery import _versions_helpers
+from google.cloud.bigquery import retry as bq_retry
 from google.cloud.bigquery import schema
@@ -740,7 +741,7 @@ def _row_iterator_page_to_arrow(page, column_names, arrow_types):
     return pyarrow.RecordBatch.from_arrays(arrays, names=column_names)


-def download_arrow_row_iterator(pages, bq_schema):
+def download_arrow_row_iterator(pages, bq_schema, timeout=None):
     """Use HTTP JSON RowIterator to construct an iterable of RecordBatches.

     Args:
@@ -751,6 +752,10 @@
             Mapping[str, Any] \
         ]]):
             A decription of the fields in result pages.
+        timeout (Optional[float]):
+            The number of seconds to wait for the underlying download to complete.
+            If ``None``, wait indefinitely.
+
     Yields:
         :class:`pyarrow.RecordBatch`
         The next page of records as a ``pyarrow`` record batch.
@@ -759,8 +764,16 @@
     column_names = bq_to_arrow_schema(bq_schema) or [field.name for field in bq_schema]
     arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema]

-    for page in pages:
-        yield _row_iterator_page_to_arrow(page, column_names, arrow_types)
+    if timeout is None:
+        for page in pages:
+            yield _row_iterator_page_to_arrow(page, column_names, arrow_types)
+    else:
+        start_time = time.monotonic()
+        for page in pages:
+            if time.monotonic() - start_time > timeout:
+                raise concurrent.futures.TimeoutError()
+
+            yield _row_iterator_page_to_arrow(page, column_names, arrow_types)


 def _row_iterator_page_to_dataframe(page, column_names, dtypes):
@@ -778,7 +791,7 @@
     return pandas.DataFrame(columns, columns=column_names)


-def download_dataframe_row_iterator(pages, bq_schema, dtypes):
+def download_dataframe_row_iterator(pages, bq_schema, dtypes, timeout=None):
     """Use HTTP JSON RowIterator to construct a DataFrame.

     Args:
@@ -792,14 +805,27 @@
         dtypes(Mapping[str, numpy.dtype]):
             The types of columns in result data to hint construction of the
             resulting DataFrame. Not all column types have to be specified.
+        timeout (Optional[float]):
+            The number of seconds to wait for the underlying download to complete.
+            If ``None``, wait indefinitely.
+
     Yields:
         :class:`pandas.DataFrame`
         The next page of records as a ``pandas.DataFrame`` record batch.
     """
     bq_schema = schema._to_schema_fields(bq_schema)
     column_names = [field.name for field in bq_schema]
-    for page in pages:
-        yield _row_iterator_page_to_dataframe(page, column_names, dtypes)
+
+    if timeout is None:
+        for page in pages:
+            yield _row_iterator_page_to_dataframe(page, column_names, dtypes)
+    else:
+        start_time = time.monotonic()
+        for page in pages:
+            if time.monotonic() - start_time > timeout:
+                raise concurrent.futures.TimeoutError()
+
+            yield _row_iterator_page_to_dataframe(page, column_names, dtypes)


 def _bqstorage_page_to_arrow(page):
@@ -928,6 +954,7 @@
     if "@" in table.table_id:
         raise ValueError("Reading from a specific snapshot is not currently supported.")

+    start_time = time.monotonic()
     requested_streams = determine_requested_streams(preserve_order, max_stream_count)

     requested_session = bigquery_storage.types.stream.ReadSession(
@@ -944,10 +971,16 @@
             ArrowSerializationOptions.CompressionCodec(1)
         )

+    retry_policy = (
+        bq_retry.DEFAULT_RETRY.with_deadline(timeout) if timeout is not None else None
+    )
+
     session = bqstorage_client.create_read_session(
         parent="projects/{}".format(project_id),
         read_session=requested_session,
         max_stream_count=requested_streams,
+        retry=retry_policy,
+        timeout=timeout,
     )

     _LOGGER.debug(
@@ -983,8 +1016,6 @@
     # Manually manage the pool to control shutdown behavior on timeout.
     pool = concurrent.futures.ThreadPoolExecutor(max_workers=max(1, total_streams))
     wait_on_shutdown = True
-    start_time = time.time()
-
     try:
         # Manually submit jobs and wait for download to complete rather
         # than using pool.map because pool.map continues running in the
@@ -1006,7 +1037,7 @@
         while not_done:
             # Check for timeout
             if timeout is not None:
-                elapsed = time.time() - start_time
+                elapsed = time.monotonic() - start_time
                 if elapsed > timeout:
                     wait_on_shutdown = False
                     raise concurrent.futures.TimeoutError(
diff --git a/google/cloud/bigquery/dbapi/cursor.py b/google/cloud/bigquery/dbapi/cursor.py
index 014a6825e..bffd7678f 100644
--- a/google/cloud/bigquery/dbapi/cursor.py
+++ b/google/cloud/bigquery/dbapi/cursor.py
@@ -323,6 +323,8 @@ def _bqstorage_fetch(self, bqstorage_client):
             read_session=requested_session,
             # a single stream only, as DB API is not well-suited for multithreading
             max_stream_count=1,
+            retry=None,
+            timeout=None,
         )

         if not read_session.streams:
diff --git a/google/cloud/bigquery/retry.py b/google/cloud/bigquery/retry.py
index 19012efd6..6fd458df5 100644
--- a/google/cloud/bigquery/retry.py
+++ b/google/cloud/bigquery/retry.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+
 from google.api_core import exceptions
 from google.api_core import retry
 import google.api_core.future.polling
 from google.auth import exceptions as auth_exceptions  # type: ignore
 import requests.exceptions

+_LOGGER = logging.getLogger(__name__)

 _RETRYABLE_REASONS = frozenset(
     ["rateLimitExceeded", "backendError", "internalError", "badGateway"]
 )
@@ -61,14 +64,17 @@
 def _should_retry(exc):
     """Predicate for determining when to retry.

-    We retry if and only if the 'reason' is 'backendError'
-    or 'rateLimitExceeded'.
+    We retry if and only if the 'reason' is in _RETRYABLE_REASONS or is
+    in _UNSTRUCTURED_RETRYABLE_TYPES.
     """
-    if not hasattr(exc, "errors") or len(exc.errors) == 0:
-        # Check for unstructured error returns, e.g. from GFE
+    try:
+        reason = exc.errors[0]["reason"]
+    except (AttributeError, IndexError, TypeError, KeyError):
+        # Fallback for when errors attribute is missing, empty, or not a dict
+        # or doesn't contain "reason" (e.g. gRPC exceptions).
+        _LOGGER.debug("Inspecting unstructured error for retry: %r", exc)
         return isinstance(exc, _UNSTRUCTURED_RETRYABLE_TYPES)

-    reason = exc.errors[0]["reason"]
     return reason in _RETRYABLE_REASONS
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index 195461006..88b673a8b 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -2152,7 +2152,10 @@ def to_arrow_iterable(
             timeout=timeout,
         )
         tabledata_list_download = functools.partial(
-            _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema
+            _pandas_helpers.download_arrow_row_iterator,
+            iter(self.pages),
+            self.schema,
+            timeout=timeout,
         )
         return self._to_page_iterable(
             bqstorage_download,
@@ -2366,6 +2369,7 @@ def to_dataframe_iterable(
             iter(self.pages),
             self.schema,
             dtypes,
+            timeout=timeout,
         )
         return self._to_page_iterable(
             bqstorage_download,
diff --git a/tests/unit/job/test_query_pandas.py b/tests/unit/job/test_query_pandas.py
index 4390309f1..e0e0438f5 100644
--- a/tests/unit/job/test_query_pandas.py
+++ b/tests/unit/job/test_query_pandas.py
@@ -179,6 +179,8 @@ def test_to_dataframe_bqstorage_preserve_order(query, table_read_options_kwarg):
         parent="projects/test-project",
         read_session=expected_session,
         max_stream_count=1,  # Use a single stream to preserve row order.
+        retry=None,
+        timeout=None,
     )


@@ -593,6 +595,8 @@ def test_to_dataframe_bqstorage(table_read_options_kwarg):
         parent="projects/bqstorage-billing-project",
         read_session=expected_session,
         max_stream_count=0,  # Use default number of streams for best performance.
+        retry=None,
+        timeout=None,
     )

     bqstorage_client.read_rows.assert_called_once_with(stream_id)
@@ -644,6 +648,8 @@ def test_to_dataframe_bqstorage_no_pyarrow_compression():
         parent="projects/bqstorage-billing-project",
         read_session=expected_session,
         max_stream_count=0,
+        retry=None,
+        timeout=None,
     )
diff --git a/tests/unit/test__pandas_helpers.py b/tests/unit/test__pandas_helpers.py
index a1cbb726b..6ec62c0b6 100644
--- a/tests/unit/test__pandas_helpers.py
+++ b/tests/unit/test__pandas_helpers.py
@@ -2252,3 +2252,134 @@ def fast_download_stream(

     results = list(result_gen)
     assert results == ["result_page"]
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
+@pytest.mark.parametrize(
+    "sleep_time, timeout, should_timeout",
+    [
+        (0.1, 0.05, True),  # Timeout case
+        (0, 10.0, False),  # Success case
+    ],
+)
+def test_download_arrow_row_iterator_with_timeout(
+    module_under_test, sleep_time, timeout, should_timeout
+):
+    bq_schema = [schema.SchemaField("name", "STRING")]
+
+    # Mock page with to_arrow method
+    mock_page = mock.Mock()
+    mock_page.to_arrow.return_value = pyarrow.RecordBatch.from_arrays(
+        [pyarrow.array(["foo"])],
+        names=["name"],
+    )
+    mock_page.__iter__ = lambda self: iter(["row1"])
+    mock_page._columns = [["foo"]]
+
+    def pages_gen():
+        # First page yields quickly
+        yield mock_page
+        if sleep_time > 0:
+            time.sleep(sleep_time)
+        yield mock_page
+
+    iterator = module_under_test.download_arrow_row_iterator(
+        pages_gen(), bq_schema, timeout=timeout
+    )
+
+    # First item should always succeed
+    next(iterator)
+
+    if should_timeout:
+        with pytest.raises(concurrent.futures.TimeoutError):
+            next(iterator)
+    else:
+        # Should succeed and complete
+        results = list(iterator)
+        assert len(results) == 1  # 1 remaining item
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
+@pytest.mark.parametrize(
+    "sleep_time, timeout, should_timeout",
+    [
+        (0.1, 0.05, True),  # Timeout case
+        (0, 10.0, False),  # Success case
+    ],
+)
+def test_download_dataframe_row_iterator_with_timeout(
+    module_under_test, sleep_time, timeout, should_timeout
+):
+    bq_schema = [schema.SchemaField("name", "STRING")]
+    dtypes = {}
+
+    # Mock page
+    mock_page = mock.Mock()
+    # Mock iterator for _row_iterator_page_to_dataframe checking next(iter(page))
+    mock_page.__iter__ = lambda self: iter(["row1"])
+    mock_page._columns = [["foo"]]
+
+    def pages_gen():
+        yield mock_page
+        if sleep_time > 0:
+            time.sleep(sleep_time)
+        yield mock_page
+
+    iterator = module_under_test.download_dataframe_row_iterator(
+        pages_gen(), bq_schema, dtypes, timeout=timeout
+    )
+
+    next(iterator)
+
+    if should_timeout:
+        with pytest.raises(concurrent.futures.TimeoutError):
+            next(iterator)
+    else:
+        results = list(iterator)
+        assert len(results) == 1
+
+
+@pytest.mark.skipif(
+    bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`"
+)
+def test_download_arrow_bqstorage_passes_timeout_to_create_read_session(
+    module_under_test,
+):
+    # Mock dependencies
+    project_id = "test-project"
+    table = mock.Mock()
+    table.table_id = "test_table"
+    table.to_bqstorage.return_value = "projects/test/datasets/test/tables/test"
+
+    bqstorage_client = mock.create_autospec(
+        bigquery_storage.BigQueryReadClient, instance=True
+    )
+    # Mock create_read_session to return a session with no streams so the function returns early
+    # (Checking start of loop logic vs empty streams return)
+    session = mock.Mock()
+    # If streams is empty, _download_table_bqstorage returns early, which is fine for this test
+    session.streams = []
+    bqstorage_client.create_read_session.return_value = session
+
+    # Call the function
+    timeout = 123.456
+    # download_arrow_bqstorage yields frames, so we need to iterate to trigger execution
+    list(
+        module_under_test.download_arrow_bqstorage(
+            project_id, table, bqstorage_client, timeout=timeout
+        )
+    )
+
+    # Verify timeout and retry were passed
+    bqstorage_client.create_read_session.assert_called_once()
+    _, kwargs = bqstorage_client.create_read_session.call_args
+    assert "timeout" in kwargs
+    assert kwargs["timeout"] == timeout
+
+    assert "retry" in kwargs
+    retry_policy = kwargs["retry"]
+    assert retry_policy is not None
+    # Check if deadline is set correctly in the retry policy
+    assert retry_policy._deadline == timeout
diff --git a/tests/unit/test_client_retry.py b/tests/unit/test_client_retry.py
index 6e49cc464..f0e7ac88f 100644
--- a/tests/unit/test_client_retry.py
+++ b/tests/unit/test_client_retry.py
@@ -23,6 +23,11 @@
 PROJECT = "test-project"


+# A deadline > 1.0s is required because the default retry (google.api_core.retry.Retry)
+# has an initial delay of 1.0s. If the deadline is <= 1.0s, the first retry attempt
+# (scheduled for now + 1.0s) will be rejected immediately as exceeding the deadline.
+_RETRY_DEADLINE = 10.0
+
 def _make_credentials():
     import google.auth.credentials
@@ -83,7 +88,7 @@ def test_call_api_applying_custom_retry_on_timeout(global_time_lock):
         "api_request",
         side_effect=[TimeoutError, "result"],
     )
-    retry = DEFAULT_RETRY.with_deadline(1).with_predicate(
+    retry = DEFAULT_RETRY.with_deadline(_RETRY_DEADLINE).with_predicate(
         lambda exc: isinstance(exc, TimeoutError)
     )
diff --git a/tests/unit/test_dbapi_cursor.py b/tests/unit/test_dbapi_cursor.py
index 6fca4cec0..c5cad8c91 100644
--- a/tests/unit/test_dbapi_cursor.py
+++ b/tests/unit/test_dbapi_cursor.py
@@ -480,7 +480,11 @@ def fake_ensure_bqstorage_client(bqstorage_client=None, **kwargs):
             data_format=bigquery_storage.DataFormat.ARROW,
         )
         mock_bqstorage_client.create_read_session.assert_called_once_with(
-            parent="projects/P", read_session=expected_session, max_stream_count=1
+            parent="projects/P",
+            read_session=expected_session,
+            max_stream_count=1,
+            retry=None,
+            timeout=None,
         )

         # Check the data returned.
diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py
index 97a1b4916..a8397247d 100644
--- a/tests/unit/test_table.py
+++ b/tests/unit/test_table.py
@@ -4125,6 +4125,10 @@ def test_to_dataframe_tqdm_error(self):
         # Warn that a progress bar was requested, but creating the tqdm
         # progress bar failed.
         for warning in warned:  # pragma: NO COVER
+            # Pyparsing warnings appear to be coming from a transitive
+            # dependency and are unrelated to the code under test.
+            if "Pyparsing" in warning.category.__name__:
+                continue
             self.assertIn(
                 warning.category,
                 [UserWarning, DeprecationWarning, tqdm.TqdmExperimentalWarning],
             )
@@ -6853,6 +6857,8 @@ def test_to_arrow_iterable_w_bqstorage_max_stream_count(preserve_order):
         parent=mock.ANY,
         read_session=mock.ANY,
         max_stream_count=max_stream_count if not preserve_order else 1,
+        retry=None,
+        timeout=None,
     )


@@ -6888,4 +6894,6 @@ def test_to_dataframe_iterable_w_bqstorage_max_stream_count(preserve_order):
         parent=mock.ANY,
         read_session=mock.ANY,
         max_stream_count=max_stream_count if not preserve_order else 1,
+        retry=None,
+        timeout=None,
     )
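
A rough usage sketch of the behavior this change wires through, not part of the patch itself: it assumes RowIterator.to_dataframe_iterable() exposes the timeout argument that table.py forwards to download_dataframe_row_iterator(); the query string and the 30-second deadline are placeholders.

# Illustrative sketch only -- assumes to_dataframe_iterable() accepts the
# `timeout` argument forwarded in the table.py change above.
import concurrent.futures

from google.cloud import bigquery

client = bigquery.Client()
rows = client.query("SELECT 1 AS x").result()

row_count = 0
try:
    # With this patch, the page loop compares time.monotonic() against the
    # deadline and raises concurrent.futures.TimeoutError once it is exceeded.
    for frame in rows.to_dataframe_iterable(timeout=30.0):
        row_count += len(frame.index)
except concurrent.futures.TimeoutError:
    # Pages yielded before the deadline were already consumed above.
    print(f"download timed out after 30s; processed {row_count} rows so far")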