diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index 02a4dedce..1b3cb6c52 100644 --- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -13,5 +13,5 @@ # limitations under the License. docker: image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - digest: sha256:240b5bcc2bafd450912d2da2be15e62bc6de2cf839823ae4bf94d4f392b451dc -# created: 2023-06-03T21:25:37.968717478Z + digest: sha256:ddf4551385d566771dc713090feb7b4c1164fb8a698fe52bbe7670b24236565b +# created: 2023-06-27T13:04:21.96690344Z diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index ac491adaf..201c73157 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -25,7 +25,7 @@ from google.cloud.bigtable.data._helpers import _attempt_timeout_generator # mutate_rows requests are limited to this number of mutations -from google.cloud.bigtable.data.mutations import MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT if TYPE_CHECKING: from google.cloud.bigtable_v2.services.bigtable.async_client import ( @@ -65,10 +65,10 @@ def __init__( """ # check that mutations are within limits total_mutations = sum(len(entry.mutations) for entry in mutation_entries) - if total_mutations > MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: raise ValueError( "mutate_rows requests can contain at most " - f"{MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across " + f"{_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across " f"all entries. Found {total_mutations}." ) # create partial function to pass to trigger rpc call diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 3a5831799..983c55a8d 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -73,7 +73,7 @@ from google.cloud.bigtable.data import ShardedQuery # used by read_rows_sharded to limit how many requests are attempted in parallel -CONCURRENCY_LIMIT = 10 +_CONCURRENCY_LIMIT = 10 # used to register instance data with the client for channel warming _WarmedInstanceKey = namedtuple( @@ -153,7 +153,7 @@ def __init__( self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[asyncio.Task[None]] = [] try: - self.start_background_channel_refresh() + self._start_background_channel_refresh() except RuntimeError: warnings.warn( f"{self.__class__.__name__} should be started in an " @@ -162,7 +162,7 @@ def __init__( stacklevel=2, ) - def start_background_channel_refresh(self) -> None: + def _start_background_channel_refresh(self) -> None: """ Starts a background task to ping and warm each channel in the pool Raises: @@ -308,7 +308,7 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: await self._ping_and_warm_instances(channel, instance_key) else: # refresh tasks aren't active. 
start them as background tasks - self.start_background_channel_refresh() + self._start_background_channel_refresh() async def _remove_instance_registration( self, instance_id: str, owner: TableAsync @@ -370,7 +370,7 @@ def get_table( ) async def __aenter__(self): - self.start_background_channel_refresh() + self._start_background_channel_refresh() return self async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -453,7 +453,7 @@ def __init__( async def read_rows_stream( self, - query: ReadRowsQuery | dict[str, Any], + query: ReadRowsQuery, *, operation_timeout: float | None = None, per_request_timeout: float | None = None, @@ -521,7 +521,7 @@ async def read_rows_stream( async def read_rows( self, - query: ReadRowsQuery | dict[str, Any], + query: ReadRowsQuery, *, operation_timeout: float | None = None, per_request_timeout: float | None = None, @@ -608,10 +608,10 @@ async def read_rows_sharded( timeout_generator = _attempt_timeout_generator( operation_timeout, operation_timeout ) - # submit shards in batches if the number of shards goes over CONCURRENCY_LIMIT + # submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT batched_queries = [ - sharded_query[i : i + CONCURRENCY_LIMIT] - for i in range(0, len(sharded_query), CONCURRENCY_LIMIT) + sharded_query[i : i + _CONCURRENCY_LIMIT] + for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT) ] # run batches and collect results results_list = [] @@ -942,7 +942,7 @@ async def bulk_mutate_rows( async def check_and_mutate_row( self, row_key: str | bytes, - predicate: RowFilter | dict[str, Any] | None, + predicate: RowFilter | None, *, true_case_mutations: Mutation | list[Mutation] | None = None, false_case_mutations: Mutation | list[Mutation] | None = None, @@ -994,12 +994,12 @@ async def check_and_mutate_row( ): false_case_mutations = [false_case_mutations] false_case_dict = [m._to_dict() for m in false_case_mutations or []] - if predicate is not None and not isinstance(predicate, dict): - predicate = predicate.to_dict() metadata = _make_metadata(self.table_name, self.app_profile_id) result = await self.client._gapic_client.check_and_mutate_row( request={ - "predicate_filter": predicate, + "predicate_filter": predicate._to_dict() + if predicate is not None + else None, "true_mutations": true_case_dict, "false_mutations": false_case_dict, "table_name": self.table_name, diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 25aafc2a1..b4c021ff7 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -26,7 +26,7 @@ from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import ( - MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data.mutations import Mutation @@ -143,7 +143,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] self._has_capacity(next_count, next_size) # make sure not to exceed per-request mutation count limits and (batch_mutation_count + next_count) - <= MUTATE_ROWS_REQUEST_MUTATION_LIMIT + <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT ): # room for new mutation; add to batch end_idx += 1 diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py index de1b3b137..06db21879 100644 --- a/google/cloud/bigtable/data/mutations.py +++ 
b/google/cloud/bigtable/data/mutations.py @@ -19,14 +19,14 @@ from abc import ABC, abstractmethod from sys import getsizeof -from google.cloud.bigtable.data.read_modify_write_rules import MAX_INCREMENT_VALUE -# special value for SetCell mutation timestamps. If set, server will assign a timestamp -SERVER_SIDE_TIMESTAMP = -1 +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +# special value for SetCell mutation timestamps. If set, server will assign a timestamp +_SERVER_SIDE_TIMESTAMP = -1 # mutation entries above this should be rejected -MUTATE_ROWS_REQUEST_MUTATION_LIMIT = 100_000 +_MUTATE_ROWS_REQUEST_MUTATION_LIMIT = 100_000 class Mutation(ABC): @@ -112,7 +112,7 @@ def __init__( if isinstance(new_value, str): new_value = new_value.encode() elif isinstance(new_value, int): - if abs(new_value) > MAX_INCREMENT_VALUE: + if abs(new_value) > _MAX_INCREMENT_VALUE: raise ValueError( "int values must be between -2**63 and 2**63 (64-bit signed int)" ) @@ -123,9 +123,9 @@ def __init__( # use current timestamp, with milisecond precision timestamp_micros = time.time_ns() // 1000 timestamp_micros = timestamp_micros - (timestamp_micros % 1000) - if timestamp_micros < SERVER_SIDE_TIMESTAMP: + if timestamp_micros < _SERVER_SIDE_TIMESTAMP: raise ValueError( - "timestamp_micros must be positive (or -1 for server-side timestamp)" + f"timestamp_micros must be positive (or {_SERVER_SIDE_TIMESTAMP} for server-side timestamp)" ) self.family = family self.qualifier = qualifier @@ -145,7 +145,7 @@ def _to_dict(self) -> dict[str, Any]: def is_idempotent(self) -> bool: """Check if the mutation is idempotent""" - return self.timestamp_micros != SERVER_SIDE_TIMESTAMP + return self.timestamp_micros != _SERVER_SIDE_TIMESTAMP @dataclass @@ -208,9 +208,9 @@ def __init__(self, row_key: bytes | str, mutations: Mutation | list[Mutation]): mutations = [mutations] if len(mutations) == 0: raise ValueError("mutations must not be empty") - elif len(mutations) > MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + elif len(mutations) > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: raise ValueError( - f"entries must have <= {MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations" + f"entries must have <= {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations" ) self.row_key = row_key self.mutations = tuple(mutations) diff --git a/google/cloud/bigtable/data/read_modify_write_rules.py b/google/cloud/bigtable/data/read_modify_write_rules.py index aa282b1a6..3a3eb3752 100644 --- a/google/cloud/bigtable/data/read_modify_write_rules.py +++ b/google/cloud/bigtable/data/read_modify_write_rules.py @@ -17,7 +17,7 @@ import abc # value must fit in 64-bit signed integer -MAX_INCREMENT_VALUE = (1 << 63) - 1 +_MAX_INCREMENT_VALUE = (1 << 63) - 1 class ReadModifyWriteRule(abc.ABC): @@ -37,7 +37,7 @@ class IncrementRule(ReadModifyWriteRule): def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1): if not isinstance(increment_amount, int): raise TypeError("increment_amount must be an integer") - if abs(increment_amount) > MAX_INCREMENT_VALUE: + if abs(increment_amount) > _MAX_INCREMENT_VALUE: raise ValueError( "increment_amount must be between -2**63 and 2**63 (64-bit signed int)" ) diff --git a/google/cloud/bigtable/data/read_rows_query.py b/google/cloud/bigtable/data/read_rows_query.py index 7d7e1f99f..cf3cd316c 100644 --- a/google/cloud/bigtable/data/read_rows_query.py +++ b/google/cloud/bigtable/data/read_rows_query.py @@ -35,11 +35,16 @@ class _RangePoint: def __hash__(self) -> int: return hash((self.key, 
self.is_inclusive)) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, _RangePoint): + return NotImplemented + return self.key == other.key and self.is_inclusive == other.is_inclusive + -@dataclass class RowRange: - start: _RangePoint | None - end: _RangePoint | None + """ + Represents a range of keys in a ReadRowsQuery + """ def __init__( self, @@ -48,11 +53,27 @@ def __init__( start_is_inclusive: bool | None = None, end_is_inclusive: bool | None = None, ): + """ + Args: + - start_key: The start key of the range. If empty, the range is unbounded on the left. + - end_key: The end key of the range. If empty, the range is unbounded on the right. + - start_is_inclusive: Whether the start key is inclusive. If None, the start key is + inclusive. + - end_is_inclusive: Whether the end key is inclusive. If None, the end key is not inclusive. + Raises: + - ValueError: if start_key is greater than end_key, or start_is_inclusive, + or end_is_inclusive is set when the corresponding key is None, + or start_key or end_key is not a string or bytes. + """ + # convert empty key inputs to None for consistency + start_key = None if not start_key else start_key + end_key = None if not end_key else end_key # check for invalid combinations of arguments if start_is_inclusive is None: start_is_inclusive = True elif start_key is None: raise ValueError("start_is_inclusive must be set with start_key") + if end_is_inclusive is None: end_is_inclusive = False elif end_key is None: @@ -66,29 +87,62 @@ def __init__( end_key = end_key.encode() elif end_key is not None and not isinstance(end_key, bytes): raise ValueError("end_key must be a string or bytes") + # ensure that start_key is less than or equal to end_key + if start_key is not None and end_key is not None and start_key > end_key: + raise ValueError("start_key must be less than or equal to end_key") - self.start = ( + self._start: _RangePoint | None = ( _RangePoint(start_key, start_is_inclusive) if start_key is not None else None ) - self.end = ( + self._end: _RangePoint | None = ( _RangePoint(end_key, end_is_inclusive) if end_key is not None else None ) + @property + def start_key(self) -> bytes | None: + """ + Returns the start key of the range. If None, the range is unbounded on the left. + """ + return self._start.key if self._start is not None else None + + @property + def end_key(self) -> bytes | None: + """ + Returns the end key of the range. If None, the range is unbounded on the right. + """ + return self._end.key if self._end is not None else None + + @property + def start_is_inclusive(self) -> bool: + """ + Returns whether the range is inclusive of the start key. + Returns True if the range is unbounded on the left. + """ + return self._start.is_inclusive if self._start is not None else True + + @property + def end_is_inclusive(self) -> bool: + """ + Returns whether the range is inclusive of the end key. + Returns True if the range is unbounded on the right. 
+ """ + return self._end.is_inclusive if self._end is not None else True + def _to_dict(self) -> dict[str, bytes]: """Converts this object to a dictionary""" output = {} - if self.start is not None: - key = "start_key_closed" if self.start.is_inclusive else "start_key_open" - output[key] = self.start.key - if self.end is not None: - key = "end_key_closed" if self.end.is_inclusive else "end_key_open" - output[key] = self.end.key + if self._start is not None: + key = "start_key_closed" if self.start_is_inclusive else "start_key_open" + output[key] = self._start.key + if self._end is not None: + key = "end_key_closed" if self.end_is_inclusive else "end_key_open" + output[key] = self._end.key return output def __hash__(self) -> int: - return hash((self.start, self.end)) + return hash((self._start, self._end)) @classmethod def _from_dict(cls, data: dict[str, bytes]) -> RowRange: @@ -123,7 +177,35 @@ def __bool__(self) -> bool: Empty RowRanges (representing a full table scan) are falsy, because they can be substituted with None. Non-empty RowRanges are truthy. """ - return self.start is not None or self.end is not None + return self._start is not None or self._end is not None + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, RowRange): + return NotImplemented + return self._start == other._start and self._end == other._end + + def __str__(self) -> str: + """ + Represent range as a string, e.g. "[b'a', b'z)" + Unbounded start or end keys are represented as "-inf" or "+inf" + """ + left = "[" if self.start_is_inclusive else "(" + right = "]" if self.end_is_inclusive else ")" + start = repr(self.start_key) if self.start_key is not None else "-inf" + end = repr(self.end_key) if self.end_key is not None else "+inf" + return f"{left}{start}, {end}{right}" + + def __repr__(self) -> str: + args_list = [] + args_list.append(f"start_key={self.start_key!r}") + args_list.append(f"end_key={self.end_key!r}") + if self.start_is_inclusive is False: + # only show start_is_inclusive if it is different from the default + args_list.append(f"start_is_inclusive={self.start_is_inclusive}") + if self.end_is_inclusive is True and self._end is not None: + # only show end_is_inclusive if it is different from the default + args_list.append(f"end_is_inclusive={self.end_is_inclusive}") + return f"RowRange({', '.join(args_list)})" class ReadRowsQuery: @@ -136,7 +218,7 @@ def __init__( row_keys: list[str | bytes] | str | bytes | None = None, row_ranges: list[RowRange] | RowRange | None = None, limit: int | None = None, - row_filter: RowFilter | dict[str, Any] | None = None, + row_filter: RowFilter | None = None, ): """ Create a new ReadRowsQuery @@ -162,7 +244,7 @@ def __init__( for k in row_keys: self.add_key(k) self.limit: int | None = limit - self.filter: RowFilter | dict[str, Any] | None = row_filter + self.filter: RowFilter | None = row_filter @property def limit(self) -> int | None: @@ -187,11 +269,11 @@ def limit(self, new_limit: int | None): self._limit = new_limit @property - def filter(self) -> RowFilter | dict[str, Any] | None: + def filter(self) -> RowFilter | None: return self._filter @filter.setter - def filter(self, row_filter: RowFilter | dict[str, Any] | None): + def filter(self, row_filter: RowFilter | None): """ Set a RowFilter to apply to this query @@ -310,24 +392,24 @@ def _shard_range( - a list of tuples, containing a segment index and a new sub-range. """ # 1. 
find the index of the segment the start key belongs to - if orig_range.start is None: + if orig_range._start is None: # if range is open on the left, include first segment start_segment = 0 else: # use binary search to find the segment the start key belongs to # bisect method determines how we break ties when the start key matches a split point # if inclusive, bisect_left to the left segment, otherwise bisect_right - bisect = bisect_left if orig_range.start.is_inclusive else bisect_right - start_segment = bisect(split_points, orig_range.start.key) + bisect = bisect_left if orig_range._start.is_inclusive else bisect_right + start_segment = bisect(split_points, orig_range._start.key) # 2. find the index of the segment the end key belongs to - if orig_range.end is None: + if orig_range._end is None: # if range is open on the right, include final segment end_segment = len(split_points) else: # use binary search to find the segment the end key belongs to. end_segment = bisect_left( - split_points, orig_range.end.key, lo=start_segment + split_points, orig_range._end.key, lo=start_segment ) # note: end_segment will always bisect_left, because split points represent inclusive ends # whether the end_key is includes the split point or not, the result is the same segment @@ -343,7 +425,7 @@ def _shard_range( # first range spans from start_key to the split_point representing the last key in the segment last_key_in_first_segment = split_points[start_segment] start_range = RowRange._from_points( - start=orig_range.start, + start=orig_range._start, end=_RangePoint(last_key_in_first_segment, is_inclusive=True), ) results.append((start_segment, start_range)) @@ -353,7 +435,7 @@ def _shard_range( last_key_before_segment = split_points[previous_segment] end_range = RowRange._from_points( start=_RangePoint(last_key_before_segment, is_inclusive=False), - end=orig_range.end, + end=orig_range._end, ) results.append((end_segment, end_range)) # 3c. 
add new spanning range to all segments other than the first and last @@ -386,7 +468,9 @@ def _to_dict(self) -> dict[str, Any]: "rows": row_set, } dict_filter = ( - self.filter.to_dict() if isinstance(self.filter, RowFilter) else self.filter + self.filter._to_dict() + if isinstance(self.filter, RowFilter) + else self.filter ) if dict_filter: final_dict["filter"] = dict_filter @@ -412,9 +496,15 @@ def __eq__(self, other): ) if this_range_empty and other_range_empty: return self.filter == other.filter and self.limit == other.limit + # otherwise, sets should have same sizes + if len(self.row_keys) != len(other.row_keys): + return False + if len(self.row_ranges) != len(other.row_ranges): + return False + ranges_match = all([row in other.row_ranges for row in self.row_ranges]) return ( self.row_keys == other.row_keys - and self.row_ranges == other.row_ranges + and ranges_match and self.filter == other.filter and self.limit == other.limit ) diff --git a/google/cloud/bigtable/data/row.py b/google/cloud/bigtable/data/row.py index 5fdc1b365..f562e96d6 100644 --- a/google/cloud/bigtable/data/row.py +++ b/google/cloud/bigtable/data/row.py @@ -153,7 +153,7 @@ def __str__(self) -> str: } """ output = ["{"] - for family, qualifier in self.get_column_components(): + for family, qualifier in self._get_column_components(): cell_list = self[family, qualifier] line = [f" (family={family!r}, qualifier={qualifier!r}): "] if len(cell_list) == 0: @@ -168,16 +168,16 @@ def __str__(self) -> str: def __repr__(self): cell_str_buffer = ["{"] - for family, qualifier in self.get_column_components(): + for family, qualifier in self._get_column_components(): cell_list = self[family, qualifier] - repr_list = [cell.to_dict() for cell in cell_list] + repr_list = [cell._to_dict() for cell in cell_list] cell_str_buffer.append(f" ('{family}', {qualifier!r}): {repr_list},") cell_str_buffer.append("}") cell_str = "\n".join(cell_str_buffer) output = f"Row(key={self.row_key!r}, cells={cell_str})" return output - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """ Returns a dictionary representation of the cell in the Bigtable Row proto format @@ -188,7 +188,7 @@ def to_dict(self) -> dict[str, Any]: for family_name, qualifier_dict in self._index.items(): qualifier_list = [] for qualifier_name, cell_list in qualifier_dict.items(): - cell_dicts = [cell.to_dict() for cell in cell_list] + cell_dicts = [cell._to_dict() for cell in cell_list] qualifier_list.append( {"qualifier": qualifier_name, "cells": cell_dicts} ) @@ -268,7 +268,7 @@ def __len__(self): """ return len(self.cells) - def get_column_components(self) -> list[tuple[str, bytes]]: + def _get_column_components(self) -> list[tuple[str, bytes]]: """ Returns a list of (family, qualifier) pairs associated with the cells @@ -288,8 +288,8 @@ def __eq__(self, other): return False if len(self.cells) != len(other.cells): return False - components = self.get_column_components() - other_components = other.get_column_components() + components = self._get_column_components() + other_components = other._get_column_components() if len(components) != len(other_components): return False if components != other_components: @@ -375,7 +375,7 @@ def __int__(self) -> int: """ return int.from_bytes(self.value, byteorder="big", signed=True) - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """ Returns a dictionary representation of the cell in the Bigtable Cell proto format diff --git a/google/cloud/bigtable/data/row_filters.py 
b/google/cloud/bigtable/data/row_filters.py index b2fae6971..9f09133d5 100644 --- a/google/cloud/bigtable/data/row_filters.py +++ b/google/cloud/bigtable/data/row_filters.py @@ -47,10 +47,10 @@ def _to_pb(self) -> data_v2_pb2.RowFilter: Returns: The converted current object. """ - return data_v2_pb2.RowFilter(**self.to_dict()) + return data_v2_pb2.RowFilter(**self._to_dict()) @abstractmethod - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" pass @@ -91,7 +91,7 @@ class SinkFilter(_BoolFilter): of a :class:`ConditionalRowFilter`. """ - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"sink": self.flag} @@ -105,7 +105,7 @@ class PassAllFilter(_BoolFilter): completeness. """ - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"pass_all_filter": self.flag} @@ -118,7 +118,7 @@ class BlockAllFilter(_BoolFilter): temporarily disabling just part of a filter. """ - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"block_all_filter": self.flag} @@ -175,7 +175,7 @@ class RowKeyRegexFilter(_RegexFilter): since the row key is already specified. """ - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"row_key_regex_filter": self.regex} @@ -199,7 +199,7 @@ def __eq__(self, other): def __ne__(self, other): return not self == other - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"row_sample_filter": self.sample} @@ -222,7 +222,7 @@ class FamilyNameRegexFilter(_RegexFilter): used as a literal. """ - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"family_name_regex_filter": self.regex} @@ -248,7 +248,7 @@ class ColumnQualifierRegexFilter(_RegexFilter): match this regex (irrespective of column family). """ - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"column_qualifier_regex_filter": self.regex} @@ -282,9 +282,9 @@ def _to_pb(self) -> data_v2_pb2.TimestampRange: Returns: The converted current object. """ - return data_v2_pb2.TimestampRange(**self.to_dict()) + return data_v2_pb2.TimestampRange(**self._to_dict()) - def to_dict(self) -> dict[str, int]: + def _to_dict(self) -> dict[str, int]: """Converts the timestamp range to a dict representation.""" timestamp_range_kwargs = {} if self.start is not None: @@ -330,9 +330,9 @@ def _to_pb(self) -> data_v2_pb2.RowFilter: """ return data_v2_pb2.RowFilter(timestamp_range_filter=self.range_._to_pb()) - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" - return {"timestamp_range_filter": self.range_.to_dict()} + return {"timestamp_range_filter": self.range_._to_dict()} def __repr__(self) -> str: return f"{self.__class__.__name__}(start={self.range_.start!r}, end={self.range_.end!r})" @@ -426,10 +426,10 @@ def _to_pb(self) -> data_v2_pb2.RowFilter: Returns: The converted current object. 
""" - column_range = data_v2_pb2.ColumnRange(**self.range_to_dict()) + column_range = data_v2_pb2.ColumnRange(**self._range_to_dict()) return data_v2_pb2.RowFilter(column_range_filter=column_range) - def range_to_dict(self) -> dict[str, str | bytes]: + def _range_to_dict(self) -> dict[str, str | bytes]: """Converts the column range range to a dict representation.""" column_range_kwargs: dict[str, str | bytes] = {} column_range_kwargs["family_name"] = self.family_id @@ -447,9 +447,9 @@ def range_to_dict(self) -> dict[str, str | bytes]: column_range_kwargs[key] = _to_bytes(self.end_qualifier) return column_range_kwargs - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" - return {"column_range_filter": self.range_to_dict()} + return {"column_range_filter": self._range_to_dict()} def __repr__(self) -> str: return f"{self.__class__.__name__}(family_id='{self.family_id}', start_qualifier={self.start_qualifier!r}, end_qualifier={self.end_qualifier!r}, inclusive_start={self.inclusive_start}, inclusive_end={self.inclusive_end})" @@ -476,7 +476,7 @@ class ValueRegexFilter(_RegexFilter): match this regex. String values will be encoded as ASCII. """ - def to_dict(self) -> dict[str, bytes]: + def _to_dict(self) -> dict[str, bytes]: """Converts the row filter to a dict representation.""" return {"value_regex_filter": self.regex} @@ -620,10 +620,10 @@ def _to_pb(self) -> data_v2_pb2.RowFilter: Returns: The converted current object. """ - value_range = data_v2_pb2.ValueRange(**self.range_to_dict()) + value_range = data_v2_pb2.ValueRange(**self._range_to_dict()) return data_v2_pb2.RowFilter(value_range_filter=value_range) - def range_to_dict(self) -> dict[str, bytes]: + def _range_to_dict(self) -> dict[str, bytes]: """Converts the value range range to a dict representation.""" value_range_kwargs = {} if self.start_value is not None: @@ -640,9 +640,9 @@ def range_to_dict(self) -> dict[str, bytes]: value_range_kwargs[key] = _to_bytes(self.end_value) return value_range_kwargs - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" - return {"value_range_filter": self.range_to_dict()} + return {"value_range_filter": self._range_to_dict()} def __repr__(self) -> str: return f"{self.__class__.__name__}(start_value={self.start_value!r}, end_value={self.end_value!r}, inclusive_start={self.inclusive_start}, inclusive_end={self.inclusive_end})" @@ -680,7 +680,7 @@ class CellsRowOffsetFilter(_CellCountFilter): :param num_cells: Skips the first N cells of the row. """ - def to_dict(self) -> dict[str, int]: + def _to_dict(self) -> dict[str, int]: """Converts the row filter to a dict representation.""" return {"cells_per_row_offset_filter": self.num_cells} @@ -692,7 +692,7 @@ class CellsRowLimitFilter(_CellCountFilter): :param num_cells: Matches only the first N cells of the row. """ - def to_dict(self) -> dict[str, int]: + def _to_dict(self) -> dict[str, int]: """Converts the row filter to a dict representation.""" return {"cells_per_row_limit_filter": self.num_cells} @@ -706,7 +706,7 @@ class CellsColumnLimitFilter(_CellCountFilter): timestamps of each cell. """ - def to_dict(self) -> dict[str, int]: + def _to_dict(self) -> dict[str, int]: """Converts the row filter to a dict representation.""" return {"cells_per_column_limit_filter": self.num_cells} @@ -720,7 +720,7 @@ class StripValueTransformerFilter(_BoolFilter): transformer than a generic query / filter. 
""" - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" return {"strip_value_transformer": self.flag} @@ -755,7 +755,7 @@ def __eq__(self, other): def __ne__(self, other): return not self == other - def to_dict(self) -> dict[str, str]: + def _to_dict(self) -> dict[str, str]: """Converts the row filter to a dict representation.""" return {"apply_label_transformer": self.label} @@ -841,9 +841,9 @@ def _to_pb(self) -> data_v2_pb2.RowFilter: ) return data_v2_pb2.RowFilter(chain=chain) - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" - return {"chain": {"filters": [f.to_dict() for f in self.filters]}} + return {"chain": {"filters": [f._to_dict() for f in self.filters]}} class RowFilterUnion(_FilterCombination): @@ -869,9 +869,9 @@ def _to_pb(self) -> data_v2_pb2.RowFilter: ) return data_v2_pb2.RowFilter(interleave=interleave) - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" - return {"interleave": {"filters": [f.to_dict() for f in self.filters]}} + return {"interleave": {"filters": [f._to_dict() for f in self.filters]}} class ConditionalRowFilter(RowFilter): @@ -939,18 +939,18 @@ def _to_pb(self) -> data_v2_pb2.RowFilter: condition = data_v2_pb2.RowFilter.Condition(**condition_kwargs) return data_v2_pb2.RowFilter(condition=condition) - def condition_to_dict(self) -> dict[str, Any]: + def _condition_to_dict(self) -> dict[str, Any]: """Converts the condition to a dict representation.""" - condition_kwargs = {"predicate_filter": self.predicate_filter.to_dict()} + condition_kwargs = {"predicate_filter": self.predicate_filter._to_dict()} if self.true_filter is not None: - condition_kwargs["true_filter"] = self.true_filter.to_dict() + condition_kwargs["true_filter"] = self.true_filter._to_dict() if self.false_filter is not None: - condition_kwargs["false_filter"] = self.false_filter.to_dict() + condition_kwargs["false_filter"] = self.false_filter._to_dict() return condition_kwargs - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """Converts the row filter to a dict representation.""" - return {"condition": self.condition_to_dict()} + return {"condition": self._condition_to_dict()} def __repr__(self) -> str: return f"{self.__class__.__name__}(predicate_filter={self.predicate_filter!r}, true_filter={self.true_filter!r}, false_filter={self.false_filter!r})" diff --git a/noxfile.py b/noxfile.py index 8499a610f..16447778e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -366,10 +366,9 @@ def docfx(session): session.install("-e", ".") session.install( - "sphinx==4.0.1", + "gcp-sphinx-docfx-yaml", "alabaster", "recommonmark", - "gcp-sphinx-docfx-yaml", ) shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 548433444..fe341e4a8 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -20,7 +20,7 @@ from google.api_core import retry from google.api_core.exceptions import ClientError -from google.cloud.bigtable.data.read_modify_write_rules import MAX_INCREMENT_VALUE +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" @@ -482,9 +482,9 @@ async def test_mutations_batcher_no_flush(client, table, temp_rows): (0, -100, 
-100), (0, 3000, 3000), (10, 4, 14), - (MAX_INCREMENT_VALUE, -MAX_INCREMENT_VALUE, 0), - (MAX_INCREMENT_VALUE, 2, -MAX_INCREMENT_VALUE), - (-MAX_INCREMENT_VALUE, -2, MAX_INCREMENT_VALUE), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), ], ) @pytest.mark.asyncio diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index f77455d60..212d0522e 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -119,14 +119,14 @@ def test_ctor_too_many_entries(self): should raise an error if an operation is created with more than 100,000 entries """ from google.cloud.bigtable.data._async._mutate_rows import ( - MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) - assert MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000 + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000 client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] * MUTATE_ROWS_REQUEST_MUTATION_LIMIT + entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT operation_timeout = 0.05 attempt_timeout = 0.01 # no errors if at limit diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 25006d725..5857ba98f 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -131,7 +131,7 @@ async def test_ctor_dict_options(self): assert called_options.api_endpoint == "foo.bar:1234" assert isinstance(called_options, ClientOptions) with mock.patch.object( - self._get_target_class(), "start_background_channel_refresh" + self._get_target_class(), "_start_background_channel_refresh" ) as start_background_refresh: client = self._make_one(client_options=client_options) start_background_refresh.assert_called_once() @@ -231,29 +231,29 @@ async def test_channel_pool_replace(self): await client.close() @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_start_background_channel_refresh_sync(self): + def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context client = self._make_one(project="project-id") with pytest.raises(RuntimeError): - client.start_background_channel_refresh() + client._start_background_channel_refresh() @pytest.mark.asyncio - async def test_start_background_channel_refresh_tasks_exist(self): + async def test__start_background_channel_refresh_tasks_exist(self): # if tasks exist, should do nothing client = self._make_one(project="project-id") with mock.patch.object(asyncio, "create_task") as create_task: - client.start_background_channel_refresh() + client._start_background_channel_refresh() create_task.assert_not_called() await client.close() @pytest.mark.asyncio @pytest.mark.parametrize("pool_size", [1, 3, 7]) - async def test_start_background_channel_refresh(self, pool_size): + async def test__start_background_channel_refresh(self, pool_size): # should create background tasks for each channel client = self._make_one(project="project-id", pool_size=pool_size) ping_and_warm = AsyncMock() client._ping_and_warm_instances = ping_and_warm - client.start_background_channel_refresh() + client._start_background_channel_refresh() assert len(client._channel_refresh_tasks) == pool_size for task in client._channel_refresh_tasks: assert isinstance(task, asyncio.Task) @@ -267,7 +267,7 @@ async def test_start_background_channel_refresh(self, pool_size): 
@pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" ) - async def test_start_background_channel_refresh_tasks_names(self): + async def test__start_background_channel_refresh_tasks_names(self): # if tasks exist, should do nothing pool_size = 3 client = self._make_one(project="project-id", pool_size=pool_size) @@ -569,7 +569,7 @@ async def test__register_instance(self): client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners client_mock._channel_refresh_tasks = [] - client_mock.start_background_channel_refresh.side_effect = ( + client_mock._start_background_channel_refresh.side_effect = ( lambda: client_mock._channel_refresh_tasks.append(mock.Mock) ) mock_channels = [mock.Mock() for i in range(5)] @@ -580,7 +580,7 @@ async def test__register_instance(self): client_mock, "instance-1", table_mock ) # first call should start background refresh - assert client_mock.start_background_channel_refresh.call_count == 1 + assert client_mock._start_background_channel_refresh.call_count == 1 # ensure active_instances and instance_owners were updated properly expected_key = ( "prefix/instance-1", @@ -593,12 +593,12 @@ async def test__register_instance(self): assert expected_key == tuple(list(instance_owners)[0]) # should be a new task set assert client_mock._channel_refresh_tasks - # # next call should not call start_background_channel_refresh again + # next call should not call _start_background_channel_refresh again table_mock2 = mock.Mock() await self._get_target_class()._register_instance( client_mock, "instance-2", table_mock2 ) - assert client_mock.start_background_channel_refresh.call_count == 1 + assert client_mock._start_background_channel_refresh.call_count == 1 # but it should call ping and warm with new instance key assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) for channel in mock_channels: @@ -655,7 +655,7 @@ async def test__register_instance_state( client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners client_mock._channel_refresh_tasks = [] - client_mock.start_background_channel_refresh.side_effect = ( + client_mock._start_background_channel_refresh.side_effect = ( lambda: client_mock._channel_refresh_tasks.append(mock.Mock) ) mock_channels = [mock.Mock() for i in range(5)] @@ -1181,7 +1181,7 @@ async def test_read_rows_query_matches_request(self, include_app_profile): read_rows = table.client._gapic_client.read_rows read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) row_keys = [b"test_1", "test_2"] - row_ranges = RowRange("start", "end") + row_ranges = RowRange("1start", "2end") filter_ = {"test": "filter"} limit = 99 query = ReadRowsQuery( @@ -1788,12 +1788,12 @@ async def test_read_rows_sharded_batching(self): operation timeout should change between batches """ from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import CONCURRENCY_LIMIT + from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT - assert CONCURRENCY_LIMIT == 10 # change this test if this changes + assert _CONCURRENCY_LIMIT == 10 # change this test if this changes n_queries = 90 - expected_num_batches = n_queries // CONCURRENCY_LIMIT + expected_num_batches = n_queries // _CONCURRENCY_LIMIT query_list = [ReadRowsQuery() for _ in range(n_queries)] table_mock = AsyncMock() @@ -1817,8 +1817,8 @@ async def test_read_rows_sharded_batching(self): for batch_idx in 
range(expected_num_batches): batch_kwargs = kwargs[ batch_idx - * CONCURRENCY_LIMIT : (batch_idx + 1) - * CONCURRENCY_LIMIT + * _CONCURRENCY_LIMIT : (batch_idx + 1) + * _CONCURRENCY_LIMIT ] for req_kwargs in batch_kwargs: # each batch should have the same operation_timeout, and it should decrease in each batch @@ -2688,12 +2688,12 @@ async def test_check_and_mutate_single_mutations(self): @pytest.mark.asyncio async def test_check_and_mutate_predicate_object(self): - """predicate object should be converted to dict""" + """predicate filter should be passed to gapic request""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse mock_predicate = mock.Mock() - fake_dict = {"fake": "dict"} - mock_predicate.to_dict.return_value = fake_dict + predicate_dict = {"predicate": "dict"} + mock_predicate._to_dict.return_value = predicate_dict async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( @@ -2708,8 +2708,8 @@ async def test_check_and_mutate_predicate_object(self): false_case_mutations=[mock.Mock()], ) kwargs = mock_gapic.call_args[1] - assert kwargs["request"]["predicate_filter"] == fake_dict - assert mock_predicate.to_dict.call_count == 1 + assert kwargs["request"]["predicate_filter"] == predicate_dict + assert mock_predicate._to_dict.call_count == 1 @pytest.mark.asyncio async def test_check_and_mutate_mutations_parsing(self): diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 1b14cc128..57af38010 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -240,7 +240,7 @@ async def test_add_to_flow_max_mutation_limits( Should submit request early, even if the flow control has room for more """ with mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", max_limit, ): mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] diff --git a/tests/unit/data/test_mutations.py b/tests/unit/data/test_mutations.py index 8365dbd02..8680a8da9 100644 --- a/tests/unit/data/test_mutations.py +++ b/tests/unit/data/test_mutations.py @@ -507,12 +507,12 @@ def test_ctor(self): def test_ctor_over_limit(self): """Should raise error if mutations exceed MAX_MUTATIONS_PER_ENTRY""" from google.cloud.bigtable.data.mutations import ( - MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) - assert MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000 + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000 # no errors at limit - expected_mutations = [None for _ in range(MUTATE_ROWS_REQUEST_MUTATION_LIMIT)] + expected_mutations = [None for _ in range(_MUTATE_ROWS_REQUEST_MUTATION_LIMIT)] self._make_one(b"row_key", expected_mutations) # error if over limit with pytest.raises(ValueError) as e: diff --git a/tests/unit/data/test_read_rows_query.py b/tests/unit/data/test_read_rows_query.py index 88fde2d24..1e4e27d36 100644 --- a/tests/unit/data/test_read_rows_query.py +++ b/tests/unit/data/test_read_rows_query.py @@ -32,34 +32,52 @@ def _make_one(self, *args, **kwargs): def test_ctor_start_end(self): row_range = self._make_one("test_row", "test_row2") - assert row_range.start.key == "test_row".encode() - assert row_range.end.key == "test_row2".encode() - assert row_range.start.is_inclusive is True - assert row_range.end.is_inclusive is False + 
assert row_range._start.key == "test_row".encode() + assert row_range._end.key == "test_row2".encode() + assert row_range._start.is_inclusive is True + assert row_range._end.is_inclusive is False + assert row_range.start_key == "test_row".encode() + assert row_range.end_key == "test_row2".encode() + assert row_range.start_is_inclusive is True + assert row_range.end_is_inclusive is False def test_ctor_start_only(self): row_range = self._make_one("test_row3") - assert row_range.start.key == "test_row3".encode() - assert row_range.start.is_inclusive is True - assert row_range.end is None + assert row_range.start_key == "test_row3".encode() + assert row_range.start_is_inclusive is True + assert row_range.end_key is None + assert row_range.end_is_inclusive is True def test_ctor_end_only(self): row_range = self._make_one(end_key="test_row4") - assert row_range.end.key == "test_row4".encode() - assert row_range.end.is_inclusive is False - assert row_range.start is None + assert row_range.end_key == "test_row4".encode() + assert row_range.end_is_inclusive is False + assert row_range.start_key is None + assert row_range.start_is_inclusive is True + + def test_ctor_empty_strings(self): + """ + empty strings should be treated as None + """ + row_range = self._make_one("", "") + assert row_range._start is None + assert row_range._end is None + assert row_range.start_key is None + assert row_range.end_key is None + assert row_range.start_is_inclusive is True + assert row_range.end_is_inclusive is True def test_ctor_inclusive_flags(self): row_range = self._make_one("test_row5", "test_row6", False, True) - assert row_range.start.key == "test_row5".encode() - assert row_range.end.key == "test_row6".encode() - assert row_range.start.is_inclusive is False - assert row_range.end.is_inclusive is True + assert row_range.start_key == "test_row5".encode() + assert row_range.end_key == "test_row6".encode() + assert row_range.start_is_inclusive is False + assert row_range.end_is_inclusive is True def test_ctor_defaults(self): row_range = self._make_one() - assert row_range.start is None - assert row_range.end is None + assert row_range.start_key is None + assert row_range.end_key is None def test_ctor_flags_only(self): with pytest.raises(ValueError) as exc: @@ -83,6 +101,9 @@ def test_ctor_invalid_keys(self): with pytest.raises(ValueError) as exc: self._make_one("1", 2) assert str(exc.value) == "end_key must be a string or bytes" + with pytest.raises(ValueError) as exc: + self._make_one("2", "1") + assert str(exc.value) == "start_key must be less than or equal to end_key" def test__to_dict_defaults(self): row_range = self._make_one("test_row", "test_row2") @@ -143,8 +164,8 @@ def test__from_dict( row_range = RowRange._from_dict(input_dict) assert row_range._to_dict().keys() == input_dict.keys() - found_start = row_range.start - found_end = row_range.end + found_start = row_range._start + found_end = row_range._end if expected_start is None: assert found_start is None assert start_is_inclusive is None @@ -176,7 +197,7 @@ def test__from_points(self, dict_repr): row_range_from_dict = RowRange._from_dict(dict_repr) row_range_from_points = RowRange._from_points( - row_range_from_dict.start, row_range_from_dict.end + row_range_from_dict._start, row_range_from_dict._end ) assert row_range_from_points._to_dict() == row_range_from_dict._to_dict() @@ -238,6 +259,86 @@ def test___bool__(self, dict_repr, expected): row_range = RowRange._from_dict(dict_repr) assert bool(row_range) is expected + def test__eq__(self): + """ + 
test that row ranges can be compared for equality + """ + from google.cloud.bigtable.data.read_rows_query import RowRange + + range1 = RowRange("1", "2") + range1_dup = RowRange("1", "2") + range2 = RowRange("1", "3") + range_w_empty = RowRange(None, "2") + assert range1 == range1_dup + assert range1 != range2 + assert range1 != range_w_empty + range_1_w_inclusive_start = RowRange("1", "2", start_is_inclusive=True) + range_1_w_exclusive_start = RowRange("1", "2", start_is_inclusive=False) + range_1_w_inclusive_end = RowRange("1", "2", end_is_inclusive=True) + range_1_w_exclusive_end = RowRange("1", "2", end_is_inclusive=False) + assert range1 == range_1_w_inclusive_start + assert range1 == range_1_w_exclusive_end + assert range1 != range_1_w_exclusive_start + assert range1 != range_1_w_inclusive_end + + @pytest.mark.parametrize( + "dict_repr,expected", + [ + ( + {"start_key_closed": "test_row", "end_key_open": "test_row2"}, + "[b'test_row', b'test_row2')", + ), + ( + {"start_key_open": "test_row", "end_key_closed": "test_row2"}, + "(b'test_row', b'test_row2']", + ), + ({"start_key_open": b"a"}, "(b'a', +inf]"), + ({"end_key_closed": b"b"}, "[-inf, b'b']"), + ({"end_key_open": b"b"}, "[-inf, b'b')"), + ({}, "[-inf, +inf]"), + ], + ) + def test___str__(self, dict_repr, expected): + """ + test string representations of row ranges + """ + from google.cloud.bigtable.data.read_rows_query import RowRange + + row_range = RowRange._from_dict(dict_repr) + assert str(row_range) == expected + + @pytest.mark.parametrize( + "dict_repr,expected", + [ + ( + {"start_key_closed": "test_row", "end_key_open": "test_row2"}, + "RowRange(start_key=b'test_row', end_key=b'test_row2')", + ), + ( + {"start_key_open": "test_row", "end_key_closed": "test_row2"}, + "RowRange(start_key=b'test_row', end_key=b'test_row2', start_is_inclusive=False, end_is_inclusive=True)", + ), + ( + {"start_key_open": b"a"}, + "RowRange(start_key=b'a', end_key=None, start_is_inclusive=False)", + ), + ( + {"end_key_closed": b"b"}, + "RowRange(start_key=None, end_key=b'b', end_is_inclusive=True)", + ), + ({"end_key_open": b"b"}, "RowRange(start_key=None, end_key=b'b')"), + ({}, "RowRange(start_key=None, end_key=None)"), + ], + ) + def test___repr__(self, dict_repr, expected): + """ + test repr representations of row ranges + """ + from google.cloud.bigtable.data.read_rows_query import RowRange + + row_range = RowRange._from_dict(dict_repr) + assert repr(row_range) == expected + class TestReadRowsQuery: @staticmethod @@ -299,24 +400,6 @@ def test_set_filter(self): query.filter = 1 assert str(exc.value) == "row_filter must be a RowFilter or dict" - def test_set_filter_dict(self): - from google.cloud.bigtable.data.row_filters import RowSampleFilter - from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest - - filter1 = RowSampleFilter(0.5) - filter1_dict = filter1.to_dict() - query = self._make_one() - assert query.filter is None - query.filter = filter1_dict - assert query.filter == filter1_dict - output = query._to_dict() - assert output["filter"] == filter1_dict - proto_output = ReadRowsRequest(**output) - assert proto_output.filter == filter1._to_pb() - - query.filter = None - assert query.filter is None - def test_set_limit(self): query = self._make_one() assert query.limit is None @@ -698,13 +781,18 @@ def test_shard_limit_exception(self): ((), ("a",), False), (("a",), (), False), (("a",), ("a",), True), + (("a",), (["a", b"a"],), True), # duplicate keys + ((["a"],), (["a", "b"],), False), + ((["a", "b"],), (["a", "b"],), 
True),
+        ((["a", b"b"],), ([b"a", "b"],), True),
         (("a",), (b"a",), True),
         (("a",), ("b",), False),
         (("a",), ("a", ["b"]), False),
-        (("a", ["b"]), ("a", ["b"]), True),
+        (("a", "b"), ("a", ["b"]), True),
         (("a", ["b"]), ("a", ["b", "c"]), False),
         (("a", ["b", "c"]), ("a", [b"b", "c"]), True),
         (("a", ["b", "c"], 1), ("a", ["b", b"c"], 1), True),
+        (("a", ["b"], 1), ("a", ["b", b"b", "b"], 1), True),  # duplicate ranges
         (("a", ["b"], 1), ("a", ["b"], 2), False),
         (("a", ["b"], 1, {"a": "b"}), ("a", ["b"], 1, {"a": "b"}), True),
         (("a", ["b"], 1, {"a": "b"}), ("a", ["b"], 1), False),
diff --git a/tests/unit/data/test_row.py b/tests/unit/data/test_row.py
index c9c797b61..df2fc72c0 100644
--- a/tests/unit/data/test_row.py
+++ b/tests/unit/data/test_row.py
@@ -176,7 +176,7 @@ def test_to_dict(self):
         cell2 = self._make_cell()
         cell2.value = b"other"
         row = self._make_one(TEST_ROW_KEY, [cell1, cell2])
-        row_dict = row.to_dict()
+        row_dict = row._to_dict()
         expected_dict = {
             "key": TEST_ROW_KEY,
             "families": [
@@ -465,20 +465,20 @@ def test_get_column_components(self):
         )
         row_response = self._make_one(TEST_ROW_KEY, [cell, cell2, cell3])
-        self.assertEqual(len(row_response.get_column_components()), 2)
+        self.assertEqual(len(row_response._get_column_components()), 2)
         self.assertEqual(
-            row_response.get_column_components(),
+            row_response._get_column_components(),
             [(TEST_FAMILY_ID, TEST_QUALIFIER), (new_family_id, new_qualifier)],
         )
         row_response = self._make_one(TEST_ROW_KEY, [])
-        self.assertEqual(len(row_response.get_column_components()), 0)
-        self.assertEqual(row_response.get_column_components(), [])
+        self.assertEqual(len(row_response._get_column_components()), 0)
+        self.assertEqual(row_response._get_column_components(), [])
         row_response = self._make_one(TEST_ROW_KEY, [cell])
-        self.assertEqual(len(row_response.get_column_components()), 1)
+        self.assertEqual(len(row_response._get_column_components()), 1)
         self.assertEqual(
-            row_response.get_column_components(), [(TEST_FAMILY_ID, TEST_QUALIFIER)]
+            row_response._get_column_components(), [(TEST_FAMILY_ID, TEST_QUALIFIER)]
         )
     def test_index_of(self):
@@ -535,7 +535,7 @@ def test_to_dict(self):
         from google.cloud.bigtable_v2.types import Cell
         cell = self._make_one()
-        cell_dict = cell.to_dict()
+        cell_dict = cell._to_dict()
         expected_dict = {
             "value": TEST_VALUE,
             "timestamp_micros": TEST_TIMESTAMP,
@@ -561,7 +561,7 @@ def test_to_dict_no_labels(self):
             TEST_TIMESTAMP,
             None,
         )
-        cell_dict = cell_no_labels.to_dict()
+        cell_dict = cell_no_labels._to_dict()
         expected_dict = {
             "value": TEST_VALUE,
             "timestamp_micros": TEST_TIMESTAMP,
diff --git a/tests/unit/data/test_row_filters.py b/tests/unit/data/test_row_filters.py
index a3e275e70..e90b6f270 100644
--- a/tests/unit/data/test_row_filters.py
+++ b/tests/unit/data/test_row_filters.py
@@ -80,7 +80,7 @@ def test_sink_filter_to_dict():
     flag = True
     row_filter = SinkFilter(flag)
     expected_dict = {"sink": flag}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -112,7 +112,7 @@ def test_pass_all_filter_to_dict():
     flag = True
     row_filter = PassAllFilter(flag)
    expected_dict = {"pass_all_filter": flag}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -144,7 +144,7 @@ def test_block_all_filter_to_dict():
     flag = True
     row_filter = BlockAllFilter(flag)
     expected_dict = {"block_all_filter": flag}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -214,7 +214,7 @@ def test_row_key_regex_filter_to_dict():
     regex = b"row-key-regex"
     row_filter = RowKeyRegexFilter(regex)
     expected_dict = {"row_key_regex_filter": regex}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -302,7 +302,7 @@ def test_family_name_regex_filter_to_dict():
     regex = "family-regex"
     row_filter = FamilyNameRegexFilter(regex)
     expected_dict = {"family_name_regex_filter": regex.encode()}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -335,7 +335,7 @@ def test_column_qualifier_regex_filter_to_dict():
     regex = b"column-regex"
     row_filter = ColumnQualifierRegexFilter(regex)
     expected_dict = {"column_qualifier_regex_filter": regex}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -432,7 +432,7 @@ def test_timestamp_range_to_dict():
         "start_timestamp_micros": 1546300800000000,
         "end_timestamp_micros": 1546387200000000,
     }
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value
@@ -454,7 +454,7 @@ def test_timestamp_range_to_dict_start_only():
     row_filter = TimestampRange(start=datetime.datetime(2019, 1, 1))
     expected_dict = {"start_timestamp_micros": 1546300800000000}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value
@@ -476,7 +476,7 @@ def test_timestamp_range_to_dict_end_only():
     row_filter = TimestampRange(end=datetime.datetime(2019, 1, 2))
     expected_dict = {"end_timestamp_micros": 1546387200000000}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value
@@ -543,7 +543,7 @@ def test_timestamp_range_filter_to_dict():
             "end_timestamp_micros": 1546387200000000,
         }
     }
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -554,7 +554,7 @@ def test_timestamp_range_filter_empty_to_dict():
     row_filter = TimestampRangeFilter()
     expected_dict = {"timestamp_range_filter": {}}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -701,7 +701,7 @@ def test_column_range_filter_to_dict():
     family_id = "column-family-id"
     row_filter = ColumnRangeFilter(family_id)
     expected_dict = {"column_range_filter": {"family_name": family_id}}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -782,7 +782,7 @@ def test_value_regex_filter_to_dict_w_bytes():
     value = regex = b"value-regex"
     row_filter = ValueRegexFilter(value)
     expected_dict = {"value_regex_filter": regex}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -806,7 +806,7 @@ def test_value_regex_filter_to_dict_w_str():
     regex = value.encode("ascii")
     row_filter = ValueRegexFilter(value)
     expected_dict = {"value_regex_filter": regex}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -839,7 +839,7 @@ def test_literal_value_filter_to_dict_w_bytes():
     value = regex = b"value_regex"
     row_filter = LiteralValueFilter(value)
     expected_dict = {"value_regex_filter": regex}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -863,7 +863,7 @@ def test_literal_value_filter_to_dict_w_str():
     regex = value.encode("ascii")
     row_filter = LiteralValueFilter(value)
     expected_dict = {"value_regex_filter": regex}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -896,7 +896,7 @@ def test_literal_value_filter_w_int(value, expected_byte_string):
     assert pb_val == expected_pb
     # test dict
     expected_dict = {"value_regex_filter": expected_byte_string}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     assert data_v2_pb2.RowFilter(**expected_dict) == pb_val
@@ -1042,7 +1042,7 @@ def test_value_range_filter_to_dict():
     row_filter = ValueRangeFilter()
     expected_dict = {"value_range_filter": {}}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -1149,7 +1149,7 @@ def test_cells_row_offset_filter_to_dict():
     num_cells = 76
     row_filter = CellsRowOffsetFilter(num_cells)
     expected_dict = {"cells_per_row_offset_filter": num_cells}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -1182,7 +1182,7 @@ def test_cells_row_limit_filter_to_dict():
     num_cells = 189
     row_filter = CellsRowLimitFilter(num_cells)
     expected_dict = {"cells_per_row_limit_filter": num_cells}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -1215,7 +1215,7 @@ def test_cells_column_limit_filter_to_dict():
     num_cells = 10
     row_filter = CellsColumnLimitFilter(num_cells)
     expected_dict = {"cells_per_column_limit_filter": num_cells}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -1248,7 +1248,7 @@ def test_strip_value_transformer_filter_to_dict():
     flag = True
     row_filter = StripValueTransformerFilter(flag)
     expected_dict = {"strip_value_transformer": flag}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -1317,7 +1317,7 @@ def test_apply_label_filter_to_dict():
     label = "label"
     row_filter = ApplyLabelFilter(label)
     expected_dict = {"apply_label_transformer": label}
-    assert row_filter.to_dict() == expected_dict
+    assert row_filter._to_dict() == expected_dict
     expected_pb_value = row_filter._to_pb()
     assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value
@@ -1437,13 +1437,13 @@ def test_row_filter_chain_to_dict():
     from google.cloud.bigtable_v2.types import data as data_v2_pb2
     row_filter1 = StripValueTransformerFilter(True)
-    row_filter1_dict = row_filter1.to_dict()
+    row_filter1_dict = row_filter1._to_dict()
     row_filter2 = RowSampleFilter(0.25)
-    row_filter2_dict = row_filter2.to_dict()
+    row_filter2_dict = row_filter2._to_dict()
     row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
-    filter_dict = row_filter3.to_dict()
+    filter_dict = row_filter3._to_dict()
     expected_dict = {"chain": {"filters": [row_filter1_dict, row_filter2_dict]}}
     assert filter_dict == expected_dict
@@ -1487,13 +1487,13 @@ def test_row_filter_chain_to_dict_nested():
     row_filter2 = RowSampleFilter(0.25)
     row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2])
-    row_filter3_dict = row_filter3.to_dict()
+    row_filter3_dict = row_filter3._to_dict()
     row_filter4 = CellsRowLimitFilter(11)
-    row_filter4_dict = row_filter4.to_dict()
+    row_filter4_dict = row_filter4._to_dict()
     row_filter5 = RowFilterChain(filters=[row_filter3, row_filter4])
-    filter_dict = row_filter5.to_dict()
+    filter_dict = row_filter5._to_dict()
     expected_dict = {"chain": {"filters": [row_filter3_dict, row_filter4_dict]}}
     assert filter_dict == expected_dict
@@ -1559,13 +1559,13 @@ def test_row_filter_union_to_dict():
     from google.cloud.bigtable_v2.types import data as data_v2_pb2
     row_filter1 = StripValueTransformerFilter(True)
-    row_filter1_dict = row_filter1.to_dict()
+    row_filter1_dict = row_filter1._to_dict()
     row_filter2 = RowSampleFilter(0.25)
-    row_filter2_dict = row_filter2.to_dict()
+    row_filter2_dict = row_filter2._to_dict()
     row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
-    filter_dict = row_filter3.to_dict()
+    filter_dict = row_filter3._to_dict()
     expected_dict = {"interleave": {"filters": [row_filter1_dict, row_filter2_dict]}}
     assert filter_dict == expected_dict
@@ -1609,13 +1609,13 @@ def test_row_filter_union_to_dict_nested():
     row_filter2 = RowSampleFilter(0.25)
     row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2])
-    row_filter3_dict = row_filter3.to_dict()
+    row_filter3_dict = row_filter3._to_dict()
     row_filter4 = CellsRowLimitFilter(11)
-    row_filter4_dict = row_filter4.to_dict()
+    row_filter4_dict = row_filter4._to_dict()
     row_filter5 = RowFilterUnion(filters=[row_filter3, row_filter4])
-    filter_dict = row_filter5.to_dict()
+    filter_dict = row_filter5._to_dict()
     expected_dict = {"interleave": {"filters": [row_filter3_dict, row_filter4_dict]}}
     assert filter_dict == expected_dict
@@ -1750,18 +1750,18 @@ def test_conditional_row_filter_to_dict():
     from google.cloud.bigtable_v2.types import data as data_v2_pb2
     row_filter1 = StripValueTransformerFilter(True)
-    row_filter1_dict = row_filter1.to_dict()
+    row_filter1_dict = row_filter1._to_dict()
     row_filter2 = RowSampleFilter(0.25)
-    row_filter2_dict = row_filter2.to_dict()
+    row_filter2_dict = row_filter2._to_dict()
     row_filter3 = CellsRowOffsetFilter(11)
-    row_filter3_dict = row_filter3.to_dict()
+    row_filter3_dict = row_filter3._to_dict()
     row_filter4 = ConditionalRowFilter(
         row_filter1, true_filter=row_filter2, false_filter=row_filter3
     )
-    filter_dict = row_filter4.to_dict()
+    filter_dict = row_filter4._to_dict()
     expected_dict = {
         "condition": {
@@ -1804,13 +1804,13 @@ def test_conditional_row_filter_to_dict_true_only():
     from google.cloud.bigtable_v2.types import data as data_v2_pb2
     row_filter1 = StripValueTransformerFilter(True)
-    row_filter1_dict = row_filter1.to_dict()
+    row_filter1_dict = row_filter1._to_dict()
     row_filter2 = RowSampleFilter(0.25)
-    row_filter2_dict = row_filter2.to_dict()
+    row_filter2_dict = row_filter2._to_dict()
     row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2)
-    filter_dict = row_filter3.to_dict()
+    filter_dict = row_filter3._to_dict()
     expected_dict = {
         "condition": {
@@ -1852,13 +1852,13 @@ def test_conditional_row_filter_to_dict_false_only():
     from google.cloud.bigtable_v2.types import data as data_v2_pb2
     row_filter1 = StripValueTransformerFilter(True)
-    row_filter1_dict = row_filter1.to_dict()
+    row_filter1_dict = row_filter1._to_dict()
     row_filter2 = RowSampleFilter(0.25)
-    row_filter2_dict = row_filter2.to_dict()
+    row_filter2_dict = row_filter2._to_dict()
     row_filter3 = ConditionalRowFilter(row_filter1, false_filter=row_filter2)
-    filter_dict = row_filter3.to_dict()
+    filter_dict = row_filter3._to_dict()
     expected_dict = {
         "condition": {