diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py index a5e6d31ba..4597a8d1d 100644 --- a/google/cloud/bigtable/data/exceptions.py +++ b/google/cloud/bigtable/data/exceptions.py @@ -142,7 +142,7 @@ def __repr__(self): # TODO: When working on mutations batcher, rework exception handling to guarantee that -# MutationsExceptionGroup only stores FailedMutationEntryErrors. +# MutationsExceptionGroup only stores FailedMutationEntryErrors. class MutationsExceptionGroup(_BigtableExceptionGroup): """ Represents one or more exceptions that occur during a bulk mutation operation diff --git a/google/cloud/bigtable/helpers.py b/google/cloud/bigtable/helpers.py index 78af43089..faa305dc8 100644 --- a/google/cloud/bigtable/helpers.py +++ b/google/cloud/bigtable/helpers.py @@ -29,3 +29,30 @@ def batched(iterable: Iterable[T], n) -> Generator[Tuple[T, ...], None, None]: while batch: yield batch batch = tuple(islice(it, n)) + + +class _MappableAttributesMixin: + """ + Mixin for classes that need some of their attribute names remapped. + + This is for taking some of the classes from the data client row filters + and row range classes that are 1:1 with their legacy client counterparts but with + some of their attributes renamed. To use in a class, override the base class with this mixin + class and define a map _attribute_map from legacy client attributes to data client + attributes. + + Attributes are remapped and redefined in __init__ as well as getattr/setattr. + """ + + def __init__(self, *args, **kwargs): + new_kwargs = {self._attribute_map.get(k, k): v for (k, v) in kwargs.items()} + super(_MappableAttributesMixin, self).__init__(*args, **new_kwargs) + + def __getattr__(self, name): + if name not in self._attribute_map: + raise AttributeError + return getattr(self, self._attribute_map[name]) + + def __setattr__(self, name, value): + attribute = self._attribute_map.get(name, name) + super(_MappableAttributesMixin, self).__setattr__(attribute, value) diff --git a/google/cloud/bigtable/row_filters.py b/google/cloud/bigtable/row_filters.py index a7581e339..d304fdf5e 100644 --- a/google/cloud/bigtable/row_filters.py +++ b/google/cloud/bigtable/row_filters.py @@ -44,37 +44,11 @@ RowFilterUnion, ConditionalRowFilter as BaseConditionalRowFilter, ) +from google.cloud.bigtable.helpers import _MappableAttributesMixin _PACK_I64 = struct.Struct(">q").pack -class _MappableAttributesMixin: - """ - Mixin for classes that need some of their attribute names remapped. - - This is for taking some of the classes from the data client row filters - that are 1:1 with their legacy client counterparts but with some of their - attributes renamed. To use in a class, override the base class with this mixin - class and define a map _attribute_map from legacy client attributes to data client - attributes. - - Attributes are remapped and redefined in __init__ as well as getattr/setattr. - """ - - def __init__(self, *args, **kwargs): - new_kwargs = {self._attribute_map.get(k, k): v for (k, v) in kwargs.items()} - super(_MappableAttributesMixin, self).__init__(*args, **new_kwargs) - - def __getattr__(self, name): - if name not in self._attribute_map: - raise AttributeError - return getattr(self, self._attribute_map[name]) - - def __setattr__(self, name, value): - attribute = self._attribute_map.get(name, name) - super(_MappableAttributesMixin, self).__setattr__(attribute, value) - - # The classes defined below are to provide constructors and members # that have an interface that does not match the one used by the data # client, for backwards compatibility purposes. diff --git a/google/cloud/bigtable/row_set.py b/google/cloud/bigtable/row_set.py index 2bc436d54..ef6c711ab 100644 --- a/google/cloud/bigtable/row_set.py +++ b/google/cloud/bigtable/row_set.py @@ -15,7 +15,12 @@ """User-friendly container for Google Cloud Bigtable RowSet """ -from google.cloud._helpers import _to_bytes # type: ignore +from google.cloud._helpers import _to_bytes +from google.cloud.bigtable.data.read_rows_query import ( + RowRange as BaseRowRange, + ReadRowsQuery, +) +from google.cloud.bigtable.helpers import _MappableAttributesMixin class RowSet(object): @@ -26,30 +31,25 @@ class RowSet(object): """ def __init__(self): - self.row_keys = [] - self.row_ranges = [] + self._read_rows_query = ReadRowsQuery() def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented - if len(other.row_keys) != len(self.row_keys): - return False - - if len(other.row_ranges) != len(self.row_ranges): - return False - - if not set(other.row_keys) == set(self.row_keys): - return False - - if not set(other.row_ranges) == set(self.row_ranges): - return False - - return True + return self._read_rows_query == other._read_rows_query def __ne__(self, other): return not self == other + @property + def row_keys(self): + return self._read_rows_query.row_keys + + @property + def row_ranges(self): + return self._read_rows_query.row_ranges + def add_row_key(self, row_key): """Add row key to row_keys list. @@ -63,7 +63,7 @@ def add_row_key(self, row_key): :type row_key: bytes :param row_key: The key of a row to read """ - self.row_keys.append(row_key) + self._read_rows_query.add_key(row_key) def add_row_range(self, row_range): """Add row_range to row_ranges list. @@ -78,7 +78,7 @@ def add_row_range(self, row_range): :type row_range: class:`RowRange` :param row_range: The row range object having start and end key """ - self.row_ranges.append(row_range) + self._read_rows_query.add_range(row_range) def add_row_range_from_keys( self, start_key=None, end_key=None, start_inclusive=True, end_inclusive=False @@ -110,7 +110,7 @@ def add_row_range_from_keys( considered inclusive. The default is False (exclusive). """ row_range = RowRange(start_key, end_key, start_inclusive, end_inclusive) - self.row_ranges.append(row_range) + self._read_rows_query.add_range(row_range) def add_row_range_with_prefix(self, row_key_prefix): """Add row range to row_ranges list that start with the row_key_prefix from the row keys @@ -136,22 +136,21 @@ def _update_message_request(self, message): :type message: class:`data_messages_v2_pb2.ReadRowsRequest` :param message: The ``ReadRowsRequest`` protobuf """ - for each in self.row_keys: + for each in self._read_rows_query.row_keys: message.rows.row_keys._pb.append(_to_bytes(each)) - for each in self.row_ranges: - r_kwrags = each.get_range_kwargs() - message.rows.row_ranges.append(r_kwrags) + for each in self._read_rows_query.row_ranges: + message.rows.row_ranges.append(each._to_pb()) -class RowRange(object): +class RowRange(_MappableAttributesMixin, BaseRowRange): """Convenience wrapper of google.bigtable.v2.RowRange - :type start_key: bytes + :type start_key: str | bytes :param start_key: (Optional) Start key of the row range. If left empty, will be interpreted as the empty string. - :type end_key: bytes + :type end_key: str | bytes :param end_key: (Optional) End key of the row range. If left empty, will be interpreted as the empty string and range will be unbounded on the high end. @@ -165,49 +164,26 @@ class RowRange(object): considered inclusive. The default is False (exclusive). """ - def __init__( - self, start_key=None, end_key=None, start_inclusive=True, end_inclusive=False - ): - self.start_key = start_key - self.start_inclusive = start_inclusive - self.end_key = end_key - self.end_inclusive = end_inclusive + _attribute_map = { + "start_inclusive": "start_is_inclusive", + "end_inclusive": "end_is_inclusive", + } def _key(self): - """A tuple key that uniquely describes this field. - - Used to compute this instance's hashcode and evaluate equality. - - Returns: - Tuple[str]: The contents of this :class:`.RowRange`. - """ - return (self.start_key, self.start_inclusive, self.end_key, self.end_inclusive) + return ( + self.start_key, + self.end_key, + self.start_is_inclusive, + self.end_is_inclusive, + ) def __hash__(self): return hash(self._key()) - def __eq__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self._key() == other._key() - - def __ne__(self, other): - return not self == other - def get_range_kwargs(self): """Convert row range object to dict which can be passed to google.bigtable.v2.RowRange add method. """ - range_kwargs = {} - if self.start_key is not None: - start_key_key = "start_key_open" - if self.start_inclusive: - start_key_key = "start_key_closed" - range_kwargs[start_key_key] = _to_bytes(self.start_key) - - if self.end_key is not None: - end_key_key = "end_key_open" - if self.end_inclusive: - end_key_key = "end_key_closed" - range_kwargs[end_key_key] = _to_bytes(self.end_key) - return range_kwargs + return { + descriptor.name: value for descriptor, value in self._pb._pb.ListFields() + } diff --git a/tests/unit/v2_client/test_row_set.py b/tests/unit/v2_client/test_row_set.py index 1a33be720..1964922bd 100644 --- a/tests/unit/v2_client/test_row_set.py +++ b/tests/unit/v2_client/test_row_set.py @@ -49,39 +49,6 @@ def test_row_set__eq__type_differ(): assert not (row_set1 == row_set2) -def test_row_set__eq__len_row_keys_differ(): - from google.cloud.bigtable.row_set import RowSet - - row_key1 = b"row_key1" - row_key2 = b"row_key1" - - row_set1 = RowSet() - row_set2 = RowSet() - - row_set1.add_row_key(row_key1) - row_set1.add_row_key(row_key2) - row_set2.add_row_key(row_key2) - - assert not (row_set1 == row_set2) - - -def test_row_set__eq__len_row_ranges_differ(): - from google.cloud.bigtable.row_set import RowRange - from google.cloud.bigtable.row_set import RowSet - - row_range1 = RowRange(b"row_key4", b"row_key9") - row_range2 = RowRange(b"row_key4", b"row_key9") - - row_set1 = RowSet() - row_set2 = RowSet() - - row_set1.add_row_range(row_range1) - row_set1.add_row_range(row_range2) - row_set2.add_row_range(row_range2) - - assert not (row_set1 == row_set2) - - def test_row_set__eq__row_keys_differ(): from google.cloud.bigtable.row_set import RowSet @@ -229,8 +196,8 @@ def test_row_range_constructor(): start_key = "row_key1" end_key = "row_key9" row_range = RowRange(start_key, end_key) - assert start_key == row_range.start_key - assert end_key == row_range.end_key + assert start_key == row_range.start_key.decode() + assert end_key == row_range.end_key.decode() assert row_range.start_inclusive assert not row_range.end_inclusive diff --git a/tests/unit/v2_client/test_table.py b/tests/unit/v2_client/test_table.py index 902367202..1ddeba28d 100644 --- a/tests/unit/v2_client/test_table.py +++ b/tests/unit/v2_client/test_table.py @@ -1580,7 +1580,7 @@ def test__create_row_request_row_range_start_key(): from google.cloud.bigtable_v2.types import RowRange table_name = "table_name" - start_key = b"start_key" + start_key = b"begin_key" result = _create_row_request(table_name, start_key=start_key) expected_result = _ReadRowsRequestPB(table_name=table_name) row_range = RowRange(start_key_closed=start_key) @@ -1606,7 +1606,7 @@ def test__create_row_request_row_range_both_keys(): from google.cloud.bigtable_v2.types import RowRange table_name = "table_name" - start_key = b"start_key" + start_key = b"begin_key" end_key = b"end_key" result = _create_row_request(table_name, start_key=start_key, end_key=end_key) row_range = RowRange(start_key_closed=start_key, end_key_open=end_key) @@ -1620,7 +1620,7 @@ def test__create_row_request_row_range_both_keys_inclusive(): from google.cloud.bigtable_v2.types import RowRange table_name = "table_name" - start_key = b"start_key" + start_key = b"begin_key" end_key = b"end_key" result = _create_row_request( table_name, start_key=start_key, end_key=end_key, end_inclusive=True