25 changes: 21 additions & 4 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -37,6 +37,7 @@
 import sys
 import time
 import uuid
+import threading
 from json.decoder import JSONDecodeError
 from typing import Optional
 from typing import Sequence
@@ -66,6 +67,8 @@
 from apache_beam.typehints.typehints import Any
 from apache_beam.utils import retry
 from apache_beam.utils.histogram import LinearBucket
+from cachetools import TTLCache, cachedmethod, Cache
+from cachetools.keys import hashkey
 
 # Protect against environments where bigquery library is not available.
 try:
@@ -139,6 +142,12 @@ class ExportCompression(object):
   SNAPPY = 'SNAPPY'
   NONE = 'NONE'
 
+class _NonNoneTTLCache(TTLCache):
+  """TTLCache that does not store None values."""
+  def __setitem__(self, key, value, cache_setitem=Cache.__setitem__):
+    if value is not None:
+      super().__setitem__(key=key, value=value)
+
 
 def default_encoder(obj):
   if isinstance(obj, decimal.Decimal):
@@ -359,6 +368,9 @@ class BigQueryWrapper(object):
 
   HISTOGRAM_METRIC_LOGGER = MetricLogger()
 
+  _TABLE_CACHE = _NonNoneTTLCache(maxsize=1024, ttl=300)
+  _TABLE_CACHE_LOCK = threading.RLock()
+
   def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None):
     self.client = client or BigQueryWrapper._bigquery_client(PipelineOptions())
     self.gcp_bq_client = client or gcp_bigquery.Client(
@@ -788,11 +800,17 @@ def _insert_all_rows(
           int(time.time() * 1000) - started_millis)
       return not errors, errors
 
+  @cachedmethod(
+      cache=lambda self: self._TABLE_CACHE,
+      lock=lambda self: self._TABLE_CACHE_LOCK,
+      key=lambda self, project_id, dataset_id, table_id: hashkey(
+          project_id, dataset_id, table_id),
+  )
   @retry.with_exponential_backoff(
       num_retries=MAX_RETRIES,
       retry_filter=retry.retry_on_server_errors_timeout_or_quota_issues_filter)
   def get_table(self, project_id, dataset_id, table_id):
"""Lookup a table's metadata object.
"""Lookup a table's metadata object. (TTL cached at class level).

Args:
client: bigquery.BigqueryV2 instance
@@ -806,9 +824,8 @@ def get_table(self, project_id, dataset_id, table_id):
       HttpError: if lookup failed.
     """
     request = bigquery.BigqueryTablesGetRequest(
-        projectId=project_id, datasetId=dataset_id, tableId=table_id)
-    response = self.client.tables.Get(request)
-    return response
+        projectId=project_id, datasetId=dataset_id, tableId=table_id)
+    return self.client.tables.Get(request)
 
   def _create_table(
       self,
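For reviewers unfamiliar with the cachetools pattern used above: the cache and lock live on the class, so every `BigQueryWrapper` in the process shares one table-metadata cache, and the `_NonNoneTTLCache` subclass keeps `None` results from being pinned for the full TTL window. Below is a minimal, self-contained sketch of the same pattern, assuming cachetools >= 5 (where `cachedmethod` passes `self` to the `key` function, as the decorator above relies on); `Wrapper`, `FakeClient`, and `lookup` are illustrative names, not Beam code.

```python
import threading

from cachetools import Cache, TTLCache, cachedmethod
from cachetools.keys import hashkey


class NonNoneTTLCache(TTLCache):
  """TTLCache that silently drops None instead of storing it."""
  def __setitem__(self, key, value, cache_setitem=Cache.__setitem__):
    if value is not None:
      super().__setitem__(key, value)


class Wrapper(object):
  # Class attributes: one cache and one lock shared by all instances.
  _CACHE = NonNoneTTLCache(maxsize=16, ttl=300)
  _LOCK = threading.RLock()

  def __init__(self, client):
    self.client = client

  @cachedmethod(
      cache=lambda self: self._CACHE,
      lock=lambda self: self._LOCK,
      # Key on the lookup arguments only (not `self`), so entries are
      # shared across instances.
      key=lambda self, name: hashkey(name))
  def lookup(self, name):
    return self.client.fetch(name)  # only reached on a cache miss


class FakeClient(object):
  def __init__(self):
    self.calls = 0

  def fetch(self, name):
    self.calls += 1
    return None if name == 'missing' else 'meta:' + name


client = FakeClient()
w1, w2 = Wrapper(client), Wrapper(client)
assert w1.lookup('t') == 'meta:t' and client.calls == 1
assert w2.lookup('t') == 'meta:t' and client.calls == 1  # shared cache hit
assert w1.lookup('missing') is None and client.calls == 2
assert w1.lookup('missing') is None and client.calls == 3  # None not cached
```

The decorator ordering in the PR matters too: `@cachedmethod` sits outside `@retry.with_exponential_backoff`, so the retry loop only runs on a cache miss, and a successful response is what gets cached.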
74 changes: 74 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
@@ -292,6 +292,80 @@ def test_temporary_dataset_is_unique(self, patched_time_sleep):
     wrapper.create_temporary_dataset('project-id', 'location')
     self.assertTrue(client.datasets.Get.called)
 
+  def test_get_table_invokes_tables_get_and_caches_result(self):
+
+    from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
+
+    client = mock.Mock()
+    client.tables = mock.Mock()
+
+    returned_table = mock.Mock(name="BigQueryTable")
+    client.tables.Get = mock.Mock(return_value=returned_table)
+
+    wrapper = BigQueryWrapper(client=client)
+
+    project_id = "my-project"
+    dataset_id = "my_dataset"
+    table_id = "my_table"
+
+    table1 = wrapper.get_table(project_id, dataset_id, table_id)
+
+    assert table1 is returned_table
+    assert client.tables.Get.call_count == 1
+
+    (request,), _ = client.tables.Get.call_args
+    assert isinstance(request, bigquery.BigqueryTablesGetRequest)
+    assert request.projectId == project_id
+    assert request.datasetId == dataset_id
+    assert request.tableId == table_id
+
+    table2 = wrapper.get_table(project_id, dataset_id, table_id)
+
+    assert table2 is returned_table
+    assert client.tables.Get.call_count == 1  # still 1 => cached
+
+  def test_get_table_shared_cache_across_wrapper_instances(self):
+    from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
+
+    # ensure isolation -> clear the shared cache before the test
+    BigQueryWrapper._TABLE_CACHE.clear()
+
+    client = mock.Mock()
+    client.tables = mock.Mock()
+
+    returned_table = mock.Mock(name="BigQueryTable")
+    client.tables.Get = mock.Mock(return_value=returned_table)
+
+    project_id = "my-project"
+    dataset_id = "my_dataset"
+    table_id = "my_table"
+
+    w1 = BigQueryWrapper(client=client)
+    w2 = BigQueryWrapper(client=client)
+    w3 = BigQueryWrapper(client=client)
+
+    # first call -> populate cache
+    t1 = w1.get_table(project_id, dataset_id, table_id)
+    assert t1 is returned_table
+    assert client.tables.Get.call_count == 1
+
+    # verify request shape (from first call)
+    (request,), _ = client.tables.Get.call_args
+    assert isinstance(request, bigquery.BigqueryTablesGetRequest)
+    assert request.projectId == project_id
+    assert request.datasetId == dataset_id
+    assert request.tableId == table_id
+
+    # calls from DIFFERENT wrapper instances -> hit the SAME cache entry
+    t2 = w2.get_table(project_id, dataset_id, table_id)
+    t3 = w3.get_table(project_id, dataset_id, table_id)
+
+    assert t2 is returned_table
+    assert t3 is returned_table
+
+    # still 1 -> record cached across instances
+    assert client.tables.Get.call_count == 1
+
   def test_get_or_create_dataset_created(self):
     client = mock.Mock()
     client.datasets.Get.side_effect = HttpError(
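The tests above exercise cache hits and cross-instance sharing but not expiry. `TTLCache`'s documented `timer` parameter makes the 300-second TTL testable without sleeping; a hedged sketch with an illustrative fake clock (not part of this PR):

```python
from cachetools import TTLCache

clock = [0.0]
# TTLCache consults timer() to decide when entries expire; injecting a
# fake clock lets a test step time forward deterministically.
cache = TTLCache(maxsize=8, ttl=300, timer=lambda: clock[0])

cache[('proj', 'ds', 'tbl')] = 'metadata'
assert ('proj', 'ds', 'tbl') in cache      # fresh entry, inside the TTL
clock[0] += 301                            # step past the 300s TTL
assert ('proj', 'ds', 'tbl') not in cache  # expired; next lookup refetches
```

The PR itself uses the default `time.monotonic` timer, so a cached table's metadata is served for up to five minutes before `get_table` goes back to the service.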