From 0d85c114ff3018ef37a5581c42438329ffb53b72 Mon Sep 17 00:00:00 2001
From: AbgarSim
Date: Thu, 18 Dec 2025 18:41:04 +0200
Subject: [PATCH] [BEAM-34076] Add TTL caching for BigQuery table definitions

---
 .../apache_beam/io/gcp/bigquery_tools.py      | 25 ++++++-
 .../apache_beam/io/gcp/bigquery_tools_test.py | 74 +++++++++++++++++++
 2 files changed, 95 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
index ddab941f9278..3c3f954257a5 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -37,6 +37,7 @@
 import sys
+import threading
 import time
 import uuid
 from json.decoder import JSONDecodeError
 from typing import Optional
 from typing import Sequence
@@ -66,6 +67,8 @@
 from apache_beam.typehints.typehints import Any
 from apache_beam.utils import retry
 from apache_beam.utils.histogram import LinearBucket
+from cachetools import Cache, TTLCache, cachedmethod
+from cachetools.keys import hashkey
 
 # Protect against environments where bigquery library is not available.
 try:
@@ -139,6 +142,12 @@ class ExportCompression(object):
   SNAPPY = 'SNAPPY'
   NONE = 'NONE'
 
+class _NonNoneTTLCache(TTLCache):
+  """TTLCache that never stores None values."""
+  def __setitem__(self, key, value, cache_setitem=Cache.__setitem__):
+    if value is not None:
+      super().__setitem__(key, value)
+
 
 def default_encoder(obj):
   if isinstance(obj, decimal.Decimal):
@@ -359,6 +368,9 @@ class BigQueryWrapper(object):
 
   HISTOGRAM_METRIC_LOGGER = MetricLogger()
 
+  _TABLE_CACHE = _NonNoneTTLCache(maxsize=1024, ttl=300)
+  _TABLE_CACHE_LOCK = threading.RLock()
+
   def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None):
     self.client = client or BigQueryWrapper._bigquery_client(PipelineOptions())
     self.gcp_bq_client = client or gcp_bigquery.Client(
@@ -788,27 +800,32 @@ def _insert_all_rows(
           int(time.time() * 1000) - started_millis)
     return not errors, errors
 
+  @cachedmethod(
+      cache=lambda self: self._TABLE_CACHE,
+      lock=lambda self: self._TABLE_CACHE_LOCK,
+      key=lambda self, project_id, dataset_id, table_id: hashkey(
+          project_id, dataset_id, table_id),
+  )
   @retry.with_exponential_backoff(
       num_retries=MAX_RETRIES,
       retry_filter=retry.retry_on_server_errors_timeout_or_quota_issues_filter)
   def get_table(self, project_id, dataset_id, table_id):
-    """Lookup a table's metadata object.
+    """Look up a table's metadata object (TTL-cached at class level).
 
     Args:
       client: bigquery.BigqueryV2 instance
       project_id: table project id
       dataset_id: table dataset id
       table_id: table id
 
     Returns:
       bigquery.Table instance
     Raises:
       HttpError: if lookup failed.
""" request = bigquery.BigqueryTablesGetRequest( - projectId=project_id, datasetId=dataset_id, tableId=table_id) - response = self.client.tables.Get(request) - return response + projectId=project_id, datasetId=dataset_id, tableId=table_id) + return self.client.tables.Get(request) def _create_table( self, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py index 2594e6728e0e..67ed1867da1f 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py @@ -292,6 +292,80 @@ def test_temporary_dataset_is_unique(self, patched_time_sleep): wrapper.create_temporary_dataset('project-id', 'location') self.assertTrue(client.datasets.Get.called) + def test_get_table_invokes_tables_get_and_caches_result(self): + + from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper + + client = mock.Mock() + client.tables = mock.Mock() + + returned_table = mock.Mock(name="BigQueryTable") + client.tables.Get = mock.Mock(return_value=returned_table) + + wrapper = BigQueryWrapper(client=client) + + project_id = "my-project" + dataset_id = "my_dataset" + table_id = "my_table" + + table1 = wrapper.get_table(project_id, dataset_id, table_id) + + assert table1 is returned_table + assert client.tables.Get.call_count == 1 + + (request,), _ = client.tables.Get.call_args + assert isinstance(request, bigquery.BigqueryTablesGetRequest) + assert request.projectId == project_id + assert request.datasetId == dataset_id + assert request.tableId == table_id + + table2 = wrapper.get_table(project_id, dataset_id, table_id) + + assert table2 is returned_table + assert client.tables.Get.call_count == 1 # still 1 => cached + + def test_get_table_shared_cache_across_wrapper_instances(self): + from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper + + # ensure isolation -> clear the shared cache before the test + BigQueryWrapper._TABLE_CACHE.clear() + + client = mock.Mock() + client.tables = mock.Mock() + + returned_table = mock.Mock(name="BigQueryTable") + client.tables.Get = mock.Mock(return_value=returned_table) + + project_id = "my-project" + dataset_id = "my_dataset" + table_id = "my_table" + + w1 = BigQueryWrapper(client=client) + w2 = BigQueryWrapper(client=client) + w3 = BigQueryWrapper(client=client) + + # first call -> populate cache + t1 = w1.get_table(project_id, dataset_id, table_id) + assert t1 is returned_table + assert client.tables.Get.call_count == 1 + + # verify request shape (from first call) + (request,), _ = client.tables.Get.call_args + assert isinstance(request, bigquery.BigqueryTablesGetRequest) + assert request.projectId == project_id + assert request.datasetId == dataset_id + assert request.tableId == table_id + + # calls from DIFFERENT wrapper instances -> hit the SAME cache entry + t2 = w2.get_table(project_id, dataset_id, table_id) + t3 = w3.get_table(project_id, dataset_id, table_id) + + assert t2 is returned_table + assert t3 is returned_table + + # still 1 -> record cached across instances + assert client.tables.Get.call_count == 1 + def test_get_or_create_dataset_created(self): client = mock.Mock() client.datasets.Get.side_effect = HttpError(