diff --git a/packages/google-cloud-bigtable/google/cloud/bigtable/data/_async/client.py b/packages/google-cloud-bigtable/google/cloud/bigtable/data/_async/client.py index b2c13521240f..a47eabeb9994 100644 --- a/packages/google-cloud-bigtable/google/cloud/bigtable/data/_async/client.py +++ b/packages/google-cloud-bigtable/google/cloud/bigtable/data/_async/client.py @@ -83,7 +83,7 @@ from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery, RowRange from google.cloud.bigtable.data.row import Row from google.cloud.bigtable.data.row_filters import ( CellsRowLimitFilter, @@ -1389,6 +1389,7 @@ async def row_exists( async def sample_row_keys( self, *, + row_range: RowRange | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, retryable_errors: Sequence[type[Exception]] @@ -1406,6 +1407,8 @@ async def sample_row_keys( row_keys, along with offset positions in the table Args: + row_range: the range of rows to sample. If not provided, samples the + entire table. operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget.i Defaults to the Table's default_operation_timeout @@ -1443,7 +1446,9 @@ async def sample_row_keys( async def execute_rpc(): results = await self.client._gapic_client.sample_row_keys( request=SampleRowKeysRequest( - app_profile_id=self.app_profile_id, **self._request_path + app_profile_id=self.app_profile_id, + row_range=row_range._to_pb() if row_range is not None else None, + **self._request_path, ), timeout=next(attempt_timeout_gen), retry=None, diff --git a/packages/google-cloud-bigtable/google/cloud/bigtable/data/_sync_autogen/client.py b/packages/google-cloud-bigtable/google/cloud/bigtable/data/_sync_autogen/client.py index 9dc118de0289..f89373718cc9 100644 --- a/packages/google-cloud-bigtable/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/packages/google-cloud-bigtable/google/cloud/bigtable/data/_sync_autogen/client.py @@ -84,7 +84,7 @@ from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery, RowRange from google.cloud.bigtable.data.row import Row from google.cloud.bigtable.data.row_filters import ( CellsRowLimitFilter, @@ -1139,6 +1139,7 @@ def row_exists( def sample_row_keys( self, *, + row_range: RowRange | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, retryable_errors: Sequence[type[Exception]] @@ -1155,6 +1156,8 @@ def sample_row_keys( row_keys, along with offset positions in the table Args: + row_range: the range of rows to sample. If not provided, samples the + entire table. operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget.i Defaults to the Table's default_operation_timeout @@ -1187,7 +1190,9 @@ def sample_row_keys( def execute_rpc(): results = self.client._gapic_client.sample_row_keys( request=SampleRowKeysRequest( - app_profile_id=self.app_profile_id, **self._request_path + app_profile_id=self.app_profile_id, + row_range=row_range._to_pb() if row_range is not None else None, + **self._request_path, ), timeout=next(attempt_timeout_gen), retry=None, diff --git a/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_async.py b/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_async.py index 246b7fcd70cc..38084e991514 100644 --- a/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_async.py +++ b/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_async.py @@ -250,7 +250,11 @@ async def SampleRowKeys(self, request, **kwargs): kwargs["operation_timeout"] = ( kwargs.get("operation_timeout", self.per_operation_timeout) or 20 ) - result = CrossSync.rm_aio(await table.sample_row_keys(**kwargs)) + row_range = None + if "row_range" in request: + from google.cloud.bigtable.data.read_rows_query import RowRange + row_range = RowRange._from_dict(request["row_range"]) + result = CrossSync.rm_aio(await table.sample_row_keys(row_range=row_range, **kwargs)) return result @error_safe diff --git a/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_sync_autogen.py b/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_sync_autogen.py index b2864db94b21..869014be0598 100644 --- a/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_sync_autogen.py +++ b/packages/google-cloud-bigtable/test_proxy/handlers/client_handler_data_sync_autogen.py @@ -187,7 +187,12 @@ async def SampleRowKeys(self, request, **kwargs): kwargs["operation_timeout"] = ( kwargs.get("operation_timeout", self.per_operation_timeout) or 20 ) - result = table.sample_row_keys(**kwargs) + row_range = None + if "row_range" in request: + from google.cloud.bigtable.data.read_rows_query import RowRange + + row_range = RowRange._from_dict(request["row_range"]) + result = table.sample_row_keys(row_range=row_range, **kwargs) return result @error_safe diff --git a/packages/google-cloud-bigtable/tests/unit/data/_async/test_client.py b/packages/google-cloud-bigtable/tests/unit/data/_async/test_client.py index 6c6719615c40..2dfe50444263 100644 --- a/packages/google-cloud-bigtable/tests/unit/data/_async/test_client.py +++ b/packages/google-cloud-bigtable/tests/unit/data/_async/test_client.py @@ -2392,6 +2392,32 @@ async def test_sample_row_keys(self): assert result[1] == samples[1] assert result[2] == samples[2] + @CrossSync.pytest + async def test_sample_row_keys_w_row_range(self): + """ + Test that method returns the expected key samples when row_range is provided + """ + samples = [ + (b"a_key1", 100), + (b"b", 200), + ] + from google.cloud.bigtable.data import RowRange + + row_range = RowRange(start_key=b"a", end_key=b"b") + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = await table.sample_row_keys(row_range=row_range) + assert len(result) == 2 + assert result[0] == samples[0] + assert result[1] == samples[1] + sample_row_keys.assert_called_once() + called_request = sample_row_keys.call_args[1]["request"] + assert called_request.row_range == row_range._to_pb() + @CrossSync.pytest async def test_sample_row_keys_bad_timeout(self): """ diff --git a/packages/google-cloud-bigtable/tests/unit/data/_sync_autogen/test_client.py b/packages/google-cloud-bigtable/tests/unit/data/_sync_autogen/test_client.py index 79ad903b6191..6d061d16147b 100644 --- a/packages/google-cloud-bigtable/tests/unit/data/_sync_autogen/test_client.py +++ b/packages/google-cloud-bigtable/tests/unit/data/_sync_autogen/test_client.py @@ -1998,6 +1998,28 @@ def test_sample_row_keys(self): assert result[1] == samples[1] assert result[2] == samples[2] + def test_sample_row_keys_w_row_range(self): + """Test that method returns the expected key samples when row_range is provided""" + samples = [(b"a_key1", 100), (b"b", 200)] + from google.cloud.bigtable.data import RowRange + + row_range = RowRange(start_key=b"a", end_key=b"b") + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = table.sample_row_keys(row_range=row_range) + assert len(result) == 2 + assert result[0] == samples[0] + assert result[1] == samples[1] + sample_row_keys.assert_called_once() + called_request = sample_row_keys.call_args[1]["request"] + assert called_request.row_range == row_range._to_pb() + def test_sample_row_keys_bad_timeout(self): """should raise error if timeout is negative""" with self._make_client() as client: