From 75a431b4d24ac8efd9be13f220250c48d01c32f3 Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Thu, 8 May 2025 18:56:30 +0000 Subject: [PATCH 1/5] fix: first test --- .../plugins/vertex_ai/models/retriever.py | 7 +- .../tests/vector_search/test_vector_search.py | 328 ++++++++++++++++++ 2 files changed, 333 insertions(+), 2 deletions(-) create mode 100644 py/plugins/vertex-ai/tests/vector_search/test_vector_search.py diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index 23e6644d95..ff68dcdbee 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -121,8 +121,11 @@ async def _get_closest_documents( index endpoint path in its metadata. """ metadata = request.query.metadata - if not metadata or 'index_endpoint_path' not in metadata or 'api_endpoint' not in metadata: - raise AttributeError('Request provides no data about index endpoint path') + + required_keys = ['index_endpoint_path', 'api_endpoint', 'deployed_index_id'] + + if not metadata or not all(key in metadata for key in required_keys): + raise AttributeError('Request provides no enough data about index') api_endpoint = metadata['api_endpoint'] index_endpoint_path = metadata['index_endpoint_path'] diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py b/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py new file mode 100644 index 0000000000..2a2f888982 --- /dev/null +++ b/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py @@ -0,0 +1,328 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from functools import partial +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from google.cloud.aiplatform_v1 import ( + FindNeighborsRequest, + FindNeighborsResponse, + IndexDatapoint, + MatchServiceAsyncClient, + NearestNeighbors, + Neighbor, +) + +from genkit.ai import Genkit +from genkit.blocks.document import Document, DocumentData, DocumentPart +from genkit.core.typing import Embedding +from genkit.plugins.vertex_ai.models.retriever import ( + BigQueryRetriever, + FirestoreRetriever, +) +from genkit.types import ( + ActionRunContext, + EmbedRequest, + EmbedResponse, + RetrieverRequest, + RetrieverResponse, + TextPart, +) + + +@pytest.fixture +def bq_retriever_instance(): + """Common initialization of bq retriever.""" + return BigQueryRetriever( + ai=MagicMock(), + name='test', + match_service_client_generator=MagicMock(), + embedder='embedder', + embedder_options=None, + bq_client=MagicMock(), + dataset_id='dataset_id', + table_id='table_id', + ) + + +def test_bigquery_retriever__init__(bq_retriever_instance): + """Init test.""" + bq_retriever = bq_retriever_instance + + assert bq_retriever is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "options, top_k", + [ + ( + {'limit': 10}, + 10, + ), + ( + {}, + 3, + ), + ( + None, + 3, + ) + ] +) +async def test_bigquery_retriever_retrieve( + bq_retriever_instance, + options, + top_k, +): + """Test retrieve method bq retriever.""" + # Mock query embedder + mock_embedding = MagicMock() + mock_embedding.embeddings = [ + Embedding( + embedding=[0.1, 0.2, 0.3], + ), + ] + + mock_genkit = MagicMock(spec=Genkit) + mock_genkit.embed.return_value = mock_embedding + + bq_retriever_instance.ai = mock_genkit + + # Mock _get_closest_documents + mock__get_closest_documents_result = [ + Document.from_text( + text='1', + metadata={ + 'distance': 0.0, + 'id': 1 + }, + ), + Document.from_text( + text='2', + metadata={ + 'distance': 0.0, + 'id': 2 + }, + ), + ] + + bq_retriever_instance._get_closest_documents = AsyncMock( + return_value=mock__get_closest_documents_result, + ) + + # Executes + await bq_retriever_instance.retrieve( + RetrieverRequest( + query=DocumentData( + content=[ + TextPart( + text='test-1' + ), + ], + ), + options=options, + ), + MagicMock(spec=ActionRunContext), + ) + + # Assert mocks + bq_retriever_instance.ai.embed.assert_called_once_with( + embedder='embedder', + documents=[Document( + content=[ + TextPart( + text='test-1' + ), + ], + ), + ], + options={}, + ) + + bq_retriever_instance._get_closest_documents.assert_awaited_once_with( + request=RetrieverRequest( + query=DocumentData( + content=[ + TextPart( + text='test-1' + ), + ], + ), + options=options, + ), + top_k=top_k, + query_embeddings=Embedding( + embedding=[0.1, 0.2, 0.3], + ), + ) + + +@pytest.mark.asyncio +async def test_bigquery__get_closest_documents(bq_retriever_instance): + """Test bigquery retriever _get_closest_documents.""" + # Mock find_neighbors method + mock_vector_search_client = MagicMock(spec=MatchServiceAsyncClient) + + # find_neighbors response + mock_nn = MagicMock() + mock_nn.neighbors = [] + + mock_nn_response = MagicMock(spec=FindNeighborsResponse) + mock_nn_response.nearest_neighbors = [ + mock_nn, + ] + + mock_vector_search_client.find_neighbors = AsyncMock( + return_value=mock_nn_response, + ) + + # find_neighbors call + bq_retriever_instance._match_service_client_generator.return_value = mock_vector_search_client + + # Mock 
_retrieve_neighbours_data_from_db method + mock__retrieve_neighbours_data_from_db_result = [ + Document.from_text( + text='1', + metadata={ + 'distance': 0.0, + 'id': 1 + } + ), + Document.from_text( + text='2', + metadata={ + 'distance': 0.0, + 'id': 2 + } + ), + ] + + bq_retriever_instance._retrieve_neighbours_data_from_db = AsyncMock( + return_value=mock__retrieve_neighbours_data_from_db_result, + ) + + await bq_retriever_instance._get_closest_documents( + request=RetrieverRequest( + query=DocumentData( + content=[ + TextPart( + text='test-1' + ) + ], + metadata={ + 'index_endpoint_path': 'index_endpoint_path', + 'api_endpoint': 'api_endpoint', + 'deployed_index_id': 'deployed_index_id', + } + ), + options={ + 'limit': 10, + } + ), + top_k=10, + query_embeddings=Embedding( + embedding=[0.1, 0.2, 0.3], + ) + ) + + # Assert calls + mock_vector_search_client.find_neighbors.assert_awaited_once_with( + request=FindNeighborsRequest( + index_endpoint="index_endpoint_path", + deployed_index_id="deployed_index_id", + queries=[ + FindNeighborsRequest.Query( + datapoint=IndexDatapoint(feature_vector=[0.1, 0.2, 0.3]), + neighbor_count=10, + ) + ], + ) + ) + + bq_retriever_instance._retrieve_neighbours_data_from_db.assert_awaited_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "metadata", + [ + { + 'index_endpoint_path': 'index_endpoint_path', + }, + { + 'api_endpoint': 'api_endpoint', + }, + { + 'deployed_index_id': 'deployed_index_id', + }, + { + 'index_endpoint_path': 'index_endpoint_path', + 'api_endpoint': 'api_endpoint', + }, + { + 'index_endpoint_path': 'index_endpoint_path', + 'deployed_index_id': 'deployed_index_id', + }, + { + 'api_endpoint': 'api_endpoint', + 'deployed_index_id': 'deployed_index_id', + } + ] +) +async def test_bigquery__get_closest_documents_fail( + bq_retriever_instance, + metadata, +): + """Test failures bigquery retriever _get_closest_documents.""" + with pytest.raises(AttributeError): + await bq_retriever_instance._get_closest_documents( + request=RetrieverRequest( + query=DocumentData( + content=[ + TextPart( + text='test-1' + ) + ], + metadata=metadata, + ), + options={ + 'limit': 10, + } + ), + top_k=10, + query_embeddings=Embedding( + embedding=[0.1, 0.2, 0.3], + ) + ) + + +def test_firestore_retriever__init__(): + """Init test.""" + fs_retriever = FirestoreRetriever( + ai=MagicMock(), + name='test', + match_service_client_generator=MagicMock(), + embedder='embedder', + embedder_options=None, + firestore_client=MagicMock(), + collection_name='collection_name', + ) + + assert fs_retriever is not None From a299af68324dee046eb4d5ee277b59a5b718e1d7 Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Thu, 8 May 2025 18:58:59 +0000 Subject: [PATCH 2/5] fix: better attribute request failure message --- .../src/genkit/plugins/vertex_ai/models/retriever.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index ff68dcdbee..eb54a909b6 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -124,8 +124,12 @@ async def _get_closest_documents( required_keys = ['index_endpoint_path', 'api_endpoint', 'deployed_index_id'] - if not metadata or not all(key in metadata for key in required_keys): - raise AttributeError('Request provides no enough data about index') + if not metadata: + raise 
AttributeError('Request metadata provides no data about index') + + for rkey in required_keys: + if rkey not in metadata: + raise AttributeError(f'Request metadata provides no data for {rkey}') api_endpoint = metadata['api_endpoint'] index_endpoint_path = metadata['index_endpoint_path'] From 99dae220c492b8dde23ffcfa817cdfeac519f94a Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Thu, 8 May 2025 20:15:15 +0000 Subject: [PATCH 3/5] fix: linters and finish tests --- .../plugins/vertex_ai/models/retriever.py | 51 ++-- .../tests/vector_search/test_vector_search.py | 278 +++++++++++++----- .../src/sample.py | 4 +- .../src/setup_env.py | 82 +++--- .../src/sample.py | 4 +- 5 files changed, 283 insertions(+), 136 deletions(-) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index eb54a909b6..9a288ac08f 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -21,7 +21,7 @@ import structlog from google.cloud import bigquery, firestore -from google.cloud.aiplatform_v1 import FindNeighborsRequest, IndexDatapoint, Neighbor +from google.cloud.aiplatform_v1 import FindNeighborsRequest, FindNeighborsResponse, IndexDatapoint from pydantic import BaseModel, Field, ValidationError from genkit.ai import Genkit @@ -49,6 +49,7 @@ class DocRetriever(ABC): embedder: The name of the embedder to use for generating embeddings. embedder_options: Options to pass to the embedder. """ + def __init__( self, ai: Genkit, @@ -117,7 +118,7 @@ async def _get_closest_documents( A list of Document objects representing the closest documents. Raises: - AttributeError: If the request does not contain the necessary + AttributeError: If the request does not contain the necessary index endpoint path in its metadata. """ metadata = request.query.metadata @@ -135,9 +136,7 @@ async def _get_closest_documents( index_endpoint_path = metadata['index_endpoint_path'] deployed_index_id = metadata['deployed_index_id'] - client_options = { - "api_endpoint": api_endpoint - } + client_options = {'api_endpoint': api_endpoint} vector_search_client = self._match_service_client_generator( client_options=client_options, @@ -159,7 +158,9 @@ async def _get_closest_documents( return await self._retrieve_neighbours_data_from_db(neighbours=response.nearest_neighbors[0].neighbors) @abstractmethod - async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]: + async def _retrieve_neighbours_data_from_db( + self, neighbours: list[FindNeighborsResponse.Neighbor] + ) -> list[Document]: """Retrieves document data from the database based on neighbor information. This method must be implemented by subclasses to define how document @@ -187,8 +188,14 @@ class BigQueryRetriever(DocRetriever): dataset_id: The ID of the BigQuery dataset. table_id: The ID of the BigQuery table. """ + def __init__( - self, bq_client: bigquery.Client, dataset_id: str, table_id: str, *args, **kwargs, + self, + bq_client: bigquery.Client, + dataset_id: str, + table_id: str, + *args, + **kwargs, ) -> None: """Initializes the BigQueryRetriever. 
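# --- Reviewer sketch (not part of the patch) --------------------------------
# A standalone illustration of the per-key metadata validation that patch 2
# introduces above, assuming the same three required keys. The endpoint and
# index values below are hypothetical placeholders, not taken from this repo.
REQUIRED_KEYS = ('index_endpoint_path', 'api_endpoint', 'deployed_index_id')


def validate_index_metadata(metadata: dict | None) -> None:
    """Mirrors the check in DocRetriever._get_closest_documents."""
    if not metadata:
        raise AttributeError('Request metadata provides no data about index')
    for rkey in REQUIRED_KEYS:
        if rkey not in metadata:
            raise AttributeError(f'Request metadata provides no data for {rkey}')


# Passes with all three keys present; raises AttributeError naming the first
# missing key otherwise.
validate_index_metadata({
    'index_endpoint_path': 'projects/p/locations/us-central1/indexEndpoints/1',
    'api_endpoint': 'us-central1-aiplatform.googleapis.com',
    'deployed_index_id': 'my_deployed_index',
})
# -----------------------------------------------------------------------------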
@@ -204,11 +211,13 @@ def __init__( self.dataset_id = dataset_id self.table_id = table_id - async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]: + async def _retrieve_neighbours_data_from_db( + self, neighbours: list[FindNeighborsResponse.Neighbor] + ) -> list[Document]: """Retrieves document data from the BigQuery table for the given neighbors. Constructs and executes a BigQuery query to fetch document data based on - the IDs obtained. Handles potential errors during query execution and + the IDs obtained. Handles potential errors during query execution and document parsing. Args: @@ -220,16 +229,10 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> Returns an empty list if no IDs are found in the neighbors or if the query fails. """ - ids = [ - n.datapoint.datapoint_id - for n in neighbours - if n.datapoint and n.datapoint.datapoint_id - ] + ids = [n.datapoint.datapoint_id for n in neighbours if n.datapoint and n.datapoint.datapoint_id] distance_by_id = { - n.datapoint.datapoint_id: n.distance - for n in neighbours - if n.datapoint and n.datapoint.datapoint_id + n.datapoint.datapoint_id: n.distance for n in neighbours if n.datapoint and n.datapoint.datapoint_id } if not ids: @@ -283,8 +286,13 @@ class FirestoreRetriever(DocRetriever): db: The Firestore client. collection_name: The name of the Firestore collection. """ + def __init__( - self, firestore_client: firestore.AsyncClient, collection_name: str, *args, **kwargs, + self, + firestore_client: firestore.AsyncClient, + collection_name: str, + *args, + **kwargs, ) -> None: """Initializes the FirestoreRetriever. @@ -298,7 +306,9 @@ def __init__( self.db = firestore_client self.collection_name = collection_name - async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]: + async def _retrieve_neighbours_data_from_db( + self, neighbours: list[FindNeighborsResponse.Neighbor] + ) -> list[Document]: """Retrieves document data from the Firestore collection for the given neighbors. Fetches document data from Firestore based on the IDs of the nearest neighbors. @@ -321,7 +331,7 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> if doc_snapshot.exists: doc_data = doc_snapshot.to_dict() or {} - content = doc_data.get('content') + content = doc_data.get('content', '') content = json.dumps(content) if isinstance(content, dict) else str(content) metadata = doc_data.get('metadata', {}) @@ -349,4 +359,5 @@ class RetrieverOptionsSchema(BaseModel): Attributes: limit: Number of documents to retrieve. 
""" + limit: int | None = Field(title='Number of documents to retrieve', default=None) diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py b/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py index 2a2f888982..0e6b87ac35 100644 --- a/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py +++ b/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py @@ -14,23 +14,21 @@ # # SPDX-License-Identifier: Apache-2.0 -import unittest -from functools import partial -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +import json +from unittest.mock import AsyncMock, MagicMock import pytest +from google.cloud import bigquery from google.cloud.aiplatform_v1 import ( FindNeighborsRequest, FindNeighborsResponse, IndexDatapoint, MatchServiceAsyncClient, - NearestNeighbors, - Neighbor, + types, ) from genkit.ai import Genkit -from genkit.blocks.document import Document, DocumentData, DocumentPart +from genkit.blocks.document import Document, DocumentData from genkit.core.typing import Embedding from genkit.plugins.vertex_ai.models.retriever import ( BigQueryRetriever, @@ -38,10 +36,7 @@ ) from genkit.types import ( ActionRunContext, - EmbedRequest, - EmbedResponse, RetrieverRequest, - RetrieverResponse, TextPart, ) @@ -70,7 +65,7 @@ def test_bigquery_retriever__init__(bq_retriever_instance): @pytest.mark.asyncio @pytest.mark.parametrize( - "options, top_k", + 'options, top_k', [ ( {'limit': 10}, @@ -83,8 +78,8 @@ def test_bigquery_retriever__init__(bq_retriever_instance): ( None, 3, - ) - ] + ), + ], ) async def test_bigquery_retriever_retrieve( bq_retriever_instance, @@ -109,17 +104,11 @@ async def test_bigquery_retriever_retrieve( mock__get_closest_documents_result = [ Document.from_text( text='1', - metadata={ - 'distance': 0.0, - 'id': 1 - }, + metadata={'distance': 0.0, 'id': 1}, ), Document.from_text( text='2', - metadata={ - 'distance': 0.0, - 'id': 2 - }, + metadata={'distance': 0.0, 'id': 2}, ), ] @@ -132,9 +121,7 @@ async def test_bigquery_retriever_retrieve( RetrieverRequest( query=DocumentData( content=[ - TextPart( - text='test-1' - ), + TextPart(text='test-1'), ], ), options=options, @@ -145,11 +132,10 @@ async def test_bigquery_retriever_retrieve( # Assert mocks bq_retriever_instance.ai.embed.assert_called_once_with( embedder='embedder', - documents=[Document( + documents=[ + Document( content=[ - TextPart( - text='test-1' - ), + TextPart(text='test-1'), ], ), ], @@ -160,9 +146,7 @@ async def test_bigquery_retriever_retrieve( request=RetrieverRequest( query=DocumentData( content=[ - TextPart( - text='test-1' - ), + TextPart(text='test-1'), ], ), options=options, @@ -198,20 +182,8 @@ async def test_bigquery__get_closest_documents(bq_retriever_instance): # Mock _retrieve_neighbours_data_from_db method mock__retrieve_neighbours_data_from_db_result = [ - Document.from_text( - text='1', - metadata={ - 'distance': 0.0, - 'id': 1 - } - ), - Document.from_text( - text='2', - metadata={ - 'distance': 0.0, - 'id': 2 - } - ), + Document.from_text(text='1', metadata={'distance': 0.0, 'id': 1}), + Document.from_text(text='2', metadata={'distance': 0.0, 'id': 2}), ] bq_retriever_instance._retrieve_neighbours_data_from_db = AsyncMock( @@ -221,32 +193,28 @@ async def test_bigquery__get_closest_documents(bq_retriever_instance): await bq_retriever_instance._get_closest_documents( request=RetrieverRequest( query=DocumentData( - content=[ - TextPart( - text='test-1' - ) - ], + content=[TextPart(text='test-1')], metadata={ 
'index_endpoint_path': 'index_endpoint_path', 'api_endpoint': 'api_endpoint', 'deployed_index_id': 'deployed_index_id', - } + }, ), options={ 'limit': 10, - } + }, ), top_k=10, query_embeddings=Embedding( embedding=[0.1, 0.2, 0.3], - ) + ), ) # Assert calls mock_vector_search_client.find_neighbors.assert_awaited_once_with( request=FindNeighborsRequest( - index_endpoint="index_endpoint_path", - deployed_index_id="deployed_index_id", + index_endpoint='index_endpoint_path', + deployed_index_id='deployed_index_id', queries=[ FindNeighborsRequest.Query( datapoint=IndexDatapoint(feature_vector=[0.1, 0.2, 0.3]), @@ -261,7 +229,7 @@ async def test_bigquery__get_closest_documents(bq_retriever_instance): @pytest.mark.asyncio @pytest.mark.parametrize( - "metadata", + 'metadata', [ { 'index_endpoint_path': 'index_endpoint_path', @@ -283,8 +251,8 @@ async def test_bigquery__get_closest_documents(bq_retriever_instance): { 'api_endpoint': 'api_endpoint', 'deployed_index_id': 'deployed_index_id', - } - ] + }, + ], ) async def test_bigquery__get_closest_documents_fail( bq_retriever_instance, @@ -295,27 +263,114 @@ async def test_bigquery__get_closest_documents_fail( await bq_retriever_instance._get_closest_documents( request=RetrieverRequest( query=DocumentData( - content=[ - TextPart( - text='test-1' - ) - ], + content=[TextPart(text='test-1')], metadata=metadata, ), options={ 'limit': 10, - } + }, ), top_k=10, query_embeddings=Embedding( embedding=[0.1, 0.2, 0.3], - ) + ), ) -def test_firestore_retriever__init__(): - """Init test.""" - fs_retriever = FirestoreRetriever( +@pytest.mark.asyncio +async def test_bigquery__retrieve_neighbours_data_from_db( + bq_retriever_instance, +): + """Test bigquery retriver _retrieve_neighbours_data_from_db.""" + # Mock query job result from bigquery query + mock_bq_query_job = MagicMock() + mock_bq_query_job.result.return_value = [ + { + 'id': 'doc1', + 'content': {'body': 'text for document 1'}, + }, + {'id': 'doc2', 'content': json.dumps({'body': 'text for document 2'}), 'metadata': {'date': 'today'}}, + {}, # should error without skipping first two rows + ] + + bq_retriever_instance.bq_client.query.return_value = mock_bq_query_job + + # call the method + result = await bq_retriever_instance._retrieve_neighbours_data_from_db( + neighbours=[ + FindNeighborsResponse.Neighbor( + datapoint=types.index.IndexDatapoint(datapoint_id='doc1'), + distance=0.0, + sparse_distance=0.0, + ), + FindNeighborsResponse.Neighbor( + datapoint=types.index.IndexDatapoint(datapoint_id='doc2'), + distance=0.0, + sparse_distance=0.0, + ), + ] + ) + + # Assert results and calls + expected = [ + Document.from_text( + text=json.dumps( + { + 'body': 'text for document 1', + }, + ), + metadata={'id': 'doc1', 'distance': 0.0}, + ), + Document.from_text( + text=json.dumps( + { + 'body': 'text for document 2', + }, + ), + metadata={'id': 'doc2', 'distance': 0.0, 'date': 'today'}, + ), + ] + + assert result == expected + + bq_retriever_instance.bq_client.query.assert_called_once() + + mock_bq_query_job.result.assert_called_once() + + +@pytest.mark.asyncio +async def test_bigquery_retrieve_neighbours_data_from_db_fail( + bq_retriever_instance, +): + """Test bigquery retriver _retrieve_neighbours_data_from_db when fails.""" + # Mock exception from bigquery query + bq_retriever_instance.bq_client.query.raises = AttributeError + + # call the method + result = await bq_retriever_instance._retrieve_neighbours_data_from_db( + neighbours=[ + FindNeighborsResponse.Neighbor( + 
datapoint=types.index.IndexDatapoint(datapoint_id='doc1'), + distance=0.0, + sparse_distance=0.0, + ), + FindNeighborsResponse.Neighbor( + datapoint=types.index.IndexDatapoint(datapoint_id='doc2'), + distance=0.0, + sparse_distance=0.0, + ), + ] + ) + + assert len(result) == 0 + + bq_retriever_instance.bq_client.query.assert_called_once() + + +@pytest.fixture +def fs_retriever_instance(): + """Common initialization of bq retriever.""" + return FirestoreRetriever( ai=MagicMock(), name='test', match_service_client_generator=MagicMock(), @@ -325,4 +380,89 @@ def test_firestore_retriever__init__(): collection_name='collection_name', ) - assert fs_retriever is not None + +def test_firestore_retriever__init__(fs_retriever_instance): + """Init test.""" + assert fs_retriever_instance is not None + + +@pytest.mark.asyncio +async def test_firesstore__retrieve_neighbours_data_from_db( + fs_retriever_instance, +): + """Test _retrieve_neighbours_data_from_db for firestore retriever.""" + # Mock storage of firestore + storage = { + 'doc1': { + 'content': {'body': 'text for document 1'}, + }, + 'doc2': {'content': json.dumps({'body': 'text for document 2'}), 'metadata': {'date': 'today'}}, + 'doc3': {}, + } + + # Mock get from firestore + class MockCollection: + def document(self, document_id): + doc_ref = MagicMock() + doc_snapshot = MagicMock() + + doc_ref.get.return_value = doc_snapshot + if storage.get(document_id) is not None: + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = storage.get(document_id) + else: + doc_snapshot.exists = False + + return doc_ref + + fs_retriever_instance.db.collection.return_value = MockCollection() + + # call the method + result = await fs_retriever_instance._retrieve_neighbours_data_from_db( + neighbours=[ + FindNeighborsResponse.Neighbor( + datapoint=types.index.IndexDatapoint(datapoint_id='doc1'), + distance=0.0, + sparse_distance=0.0, + ), + FindNeighborsResponse.Neighbor( + datapoint=types.index.IndexDatapoint(datapoint_id='doc2'), + distance=0.0, + sparse_distance=0.0, + ), + FindNeighborsResponse.Neighbor( + datapoint=types.index.IndexDatapoint(datapoint_id='doc3'), + distance=0.0, + sparse_distance=0.0, + ), + ] + ) + + # Assert results and calls + expected = [ + Document.from_text( + text=json.dumps( + { + 'body': 'text for document 1', + }, + ), + metadata={'id': 'doc1', 'distance': 0.0}, + ), + Document.from_text( + text=json.dumps( + { + 'body': 'text for document 2', + }, + ), + metadata={'id': 'doc2', 'distance': 0.0, 'date': 'today'}, + ), + Document.from_text( + text='', + metadata={ + 'id': 'doc3', + 'distance': 0.0, + }, + ), + ] + + assert result == expected diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py index 39c994a261..5f04d979c9 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py @@ -71,12 +71,14 @@ class QueryFlowInputSchema(BaseModel): """Input schema.""" + query: str k: int class QueryFlowOutputSchema(BaseModel): """Output schema.""" + result: list[dict] length: int time: int @@ -128,7 +130,7 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: async def main() -> None: """Main function.""" query_input = QueryFlowInputSchema( - query="Content for doc", + query='Content for doc', k=3, ) diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py index 
7fc6710c9b..b179fa2b16 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py @@ -57,7 +57,7 @@ }, embedder=EMBEDDING_MODEL, embedder_options={'task': 'RETRIEVAL_DOCUMENT'}, - ) + ), ] ) @@ -71,24 +71,28 @@ async def generate_embeddings(): """ toy_documents = [ { - "id": "doc1", - "content": {"title": "Document 1", "body": "This is the content of document 1."}, - "metadata": {"author": "Alice", "date": "2024-01-15"}, + 'id': 'doc1', + 'content': {'title': 'Document 1', 'body': 'This is the content of document 1.'}, + 'metadata': {'author': 'Alice', 'date': '2024-01-15'}, }, { - "id": "doc2", - "content": {"title": "Document 2", "body": "This is the content of document 2."}, - "metadata": {"author": "Bob", "date": "2024-02-20"}, + 'id': 'doc2', + 'content': {'title': 'Document 2', 'body': 'This is the content of document 2.'}, + 'metadata': {'author': 'Bob', 'date': '2024-02-20'}, }, { - "id": "doc3", - "content": {"title": "Document 3", "body": "Content for doc 3"}, - "metadata": {"author": "Charlie", "date": "2024-03-01"}, + 'id': 'doc3', + 'content': {'title': 'Document 3', 'body': 'Content for doc 3'}, + 'metadata': {'author': 'Charlie', 'date': '2024-03-01'}, }, ] create_bigquery_dataset_and_table( - PROJECT_ID, LOCATION, BIGQUERY_DATASET_NAME, BIGQUERY_TABLE_NAME, toy_documents, + PROJECT_ID, + LOCATION, + BIGQUERY_DATASET_NAME, + BIGQUERY_TABLE_NAME, + toy_documents, ) results_dict = get_data_from_bigquery( @@ -98,10 +102,7 @@ async def generate_embeddings(): table_id=BIGQUERY_TABLE_NAME, ) - genkit_documents = [ - types.Document(content=[types.TextPart(text=text)]) - for text in results_dict.values() - ] + genkit_documents = [types.Document(content=[types.TextPart(text=text)]) for text in results_dict.values()] embed_response = await ai.embed( embedder=vertexai_name(EMBEDDING_MODEL), @@ -112,7 +113,7 @@ async def generate_embeddings(): embeddings = [emb.embedding for emb in embed_response.embeddings] logger.debug(f'Generated {len(embeddings)} embeddings, dimension: {len(embeddings[0])}') - ids = list(results_dict.keys())[:len(embeddings)] + ids = list(results_dict.keys())[: len(embeddings)] data_embeddings = list(zip(ids, embeddings, strict=True)) upsert_data = [(str(id), embedding) for id, embedding in data_embeddings] @@ -144,41 +145,41 @@ def create_bigquery_dataset_and_table( try: dataset = client.create_dataset(dataset, exists_ok=True) - logger.debug(f"Dataset {client.project}.{dataset.dataset_id} created.") + logger.debug(f'Dataset {client.project}.{dataset.dataset_id} created.') except Exception as e: - logger.exception(f"Error creating dataset: {e}") + logger.exception(f'Error creating dataset: {e}') raise e schema = [ - bigquery.SchemaField("id", "STRING", mode="REQUIRED"), - bigquery.SchemaField("content", "JSON"), - bigquery.SchemaField("metadata", "JSON"), + bigquery.SchemaField('id', 'STRING', mode='REQUIRED'), + bigquery.SchemaField('content', 'JSON'), + bigquery.SchemaField('metadata', 'JSON'), ] table_ref = dataset_ref.table(table_id) table = bigquery.Table(table_ref, schema=schema) try: table = client.create_table(table, exists_ok=True) - logger.debug(f"Table {table.project}.{table.dataset_id}.{table.table_id} created.") + logger.debug(f'Table {table.project}.{table.dataset_id}.{table.table_id} created.') except Exception as e: - logger.exception(f"Error creating table: {e}") + logger.exception(f'Error creating table: {e}') raise e rows_to_insert = [ { - "id": doc["id"], - 
"content": json.dumps(doc["content"]), - "metadata": json.dumps(doc["metadata"]), + 'id': doc['id'], + 'content': json.dumps(doc['content']), + 'metadata': json.dumps(doc['metadata']), } for doc in documents ] errors = client.insert_rows_json(table, rows_to_insert) if errors: - logger.error(f"Errors inserting rows: {errors}") - raise Exception(f"Failed to insert rows: {errors}") + logger.error(f'Errors inserting rows: {errors}') + raise Exception(f'Failed to insert rows: {errors}') else: - logger.debug(f"Inserted {len(rows_to_insert)} rows into BigQuery.") + logger.debug(f'Inserted {len(rows_to_insert)} rows into BigQuery.') def get_data_from_bigquery( @@ -199,10 +200,8 @@ def get_data_from_bigquery( A dictionary where keys are document IDs and values are JSON strings representing the document content. """ - table_ref = bigquery.TableReference.from_string( - f"{project_id}.{dataset_id}.{table_id}" - ) - query = f"SELECT id, content FROM `{table_ref}`" + table_ref = bigquery.TableReference.from_string(f'{project_id}.{dataset_id}.{table_id}') + query = f'SELECT id, content FROM `{table_ref}`' query_job = bq_client.query(query) rows = query_job.result() @@ -230,26 +229,19 @@ def upsert_index( aiplatform.init(project=project_id, location=region) index_client = aiplatform_v1.IndexServiceClient( - client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"} + client_options={'api_endpoint': f'{region}-aiplatform.googleapis.com'} ) - index_path = index_client.index_path( - project=project_id, location=region, index=index_name - ) + index_path = index_client.index_path(project=project_id, location=region, index=index_name) - datapoints = [ - aiplatform_v1.IndexDatapoint(datapoint_id=id, feature_vector=embedding) - for id, embedding in data - ] + datapoints = [aiplatform_v1.IndexDatapoint(datapoint_id=id, feature_vector=embedding) for id, embedding in data] logger.debug(f'Attempting to insert {len(datapoints)} rows into Index {index_path}') - upsert_request = aiplatform_v1.UpsertDatapointsRequest( - index=index_path, datapoints=datapoints - ) + upsert_request = aiplatform_v1.UpsertDatapointsRequest(index=index_path, datapoints=datapoints) response = index_client.upsert_datapoints(request=upsert_request) - logger.info(f"Upserted {len(datapoints)} datapoints. Response: {response}") + logger.info(f'Upserted {len(datapoints)} datapoints. 
Response: {response}') async def main() -> None: diff --git a/py/samples/vertex-ai-vector-search-firestore/src/sample.py b/py/samples/vertex-ai-vector-search-firestore/src/sample.py index a8bd67f563..d5dd7fd05a 100644 --- a/py/samples/vertex-ai-vector-search-firestore/src/sample.py +++ b/py/samples/vertex-ai-vector-search-firestore/src/sample.py @@ -67,12 +67,14 @@ class QueryFlowInputSchema(BaseModel): """Input schema.""" + query: str k: int class QueryFlowOutputSchema(BaseModel): """Output schema.""" + result: list[dict] length: int time: int @@ -124,7 +126,7 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: async def main() -> None: """Main function.""" query_input = QueryFlowInputSchema( - query="Content for doc", + query='Content for doc', k=3, ) From eceb50b46560f25a7ef82ba8d2bbadb2a78dedaf Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Thu, 8 May 2025 20:28:28 +0000 Subject: [PATCH 4/5] fix: test for plugin --- ...st_vector_search.py => test_retrievers.py} | 0 .../test_vector_search_plugin.py | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+) rename py/plugins/vertex-ai/tests/vector_search/{test_vector_search.py => test_retrievers.py} (100%) create mode 100644 py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search.py b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py similarity index 100% rename from py/plugins/vertex-ai/tests/vector_search/test_vector_search.py rename to py/plugins/vertex-ai/tests/vector_search/test_retrievers.py diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py new file mode 100644 index 0000000000..1eeb15d7b8 --- /dev/null +++ b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py @@ -0,0 +1,32 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +from genkit.ai import Genkit +from genkit.plugins.vertex_ai import VertexAIVectorSearch + + +def test_initialize_plugin(): + """Test plugin initialization.""" + plugin = VertexAIVectorSearch( + retriever=MagicMock(), + embedder='embedder', + ) + + result = plugin.initialize(ai=MagicMock(spec=Genkit)) + + assert result is not None From 05eb989387013df27fe3cb6c9818019320d0d1c4 Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Fri, 9 May 2025 15:14:37 +0000 Subject: [PATCH 5/5] fix comments --- .../plugins/vertex_ai/models/retriever.py | 32 +++++++-------- .../vertex_ai/vector_search/vector_search.py | 3 -- .../tests/vector_search/test_retrievers.py | 40 +++++++++++-------- .../test_vector_search_plugin.py | 2 + .../src/setup_env.py | 28 ++++++++----- 5 files changed, 56 insertions(+), 49 deletions(-) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index 9a288ac08f..2727fe9ea6 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -155,19 +155,17 @@ async def _get_closest_documents( response = await vector_search_client.find_neighbors(request=nn_request) - return await self._retrieve_neighbours_data_from_db(neighbours=response.nearest_neighbors[0].neighbors) + return await self._retrieve_neighbors_data_from_db(neighbors=response.nearest_neighbors[0].neighbors) @abstractmethod - async def _retrieve_neighbours_data_from_db( - self, neighbours: list[FindNeighborsResponse.Neighbor] - ) -> list[Document]: + async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]: """Retrieves document data from the database based on neighbor information. This method must be implemented by subclasses to define how document data is fetched from the database using the provided neighbor information. Args: - neighbours: A list of Neighbor objects representing the nearest neighbors + neighbors: A list of Neighbor objects representing the nearest neighbors found in the vector search index. Returns: @@ -211,9 +209,7 @@ def __init__( self.dataset_id = dataset_id self.table_id = table_id - async def _retrieve_neighbours_data_from_db( - self, neighbours: list[FindNeighborsResponse.Neighbor] - ) -> list[Document]: + async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]: """Retrieves document data from the BigQuery table for the given neighbors. Constructs and executes a BigQuery query to fetch document data based on @@ -221,7 +217,7 @@ async def _retrieve_neighbours_data_from_db( document parsing. Args: - neighbours: A list of Neighbor objects representing the nearest neighbors. + neighbors: A list of Neighbor objects representing the nearest neighbors. Each neighbor should contain a datapoint with a datapoint_id. Returns: @@ -229,10 +225,10 @@ async def _retrieve_neighbours_data_from_db( Returns an empty list if no IDs are found in the neighbors or if the query fails. 
""" - ids = [n.datapoint.datapoint_id for n in neighbours if n.datapoint and n.datapoint.datapoint_id] + ids = [n.datapoint.datapoint_id for n in neighbors if n.datapoint and n.datapoint.datapoint_id] distance_by_id = { - n.datapoint.datapoint_id: n.distance for n in neighbours if n.datapoint and n.datapoint.datapoint_id + n.datapoint.datapoint_id: n.distance for n in neighbors if n.datapoint and n.datapoint.datapoint_id } if not ids: @@ -270,7 +266,7 @@ async def _retrieve_neighbours_data_from_db( documents.append(Document.from_text(content, metadata)) except (ValidationError, json.JSONDecodeError, Exception) as error: doc_id = row.get('id', '') - await logger.awarning(f'Failed to parse document data for document with ID {doc_id}: {error}') + await logger.awarning('Failed to parse document data for document with ID %s: %s', doc_id, error) return documents @@ -306,16 +302,14 @@ def __init__( self.db = firestore_client self.collection_name = collection_name - async def _retrieve_neighbours_data_from_db( - self, neighbours: list[FindNeighborsResponse.Neighbor] - ) -> list[Document]: + async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]: """Retrieves document data from the Firestore collection for the given neighbors. Fetches document data from Firestore based on the IDs of the nearest neighbors. Handles potential errors during document retrieval and data parsing. Args: - neighbours: A list of Neighbor objects representing the nearest neighbors. + neighbors: A list of Neighbor objects representing the nearest neighbors. Each neighbor should contain a datapoint with a datapoint_id. Returns: @@ -324,7 +318,7 @@ async def _retrieve_neighbours_data_from_db( """ documents: list[Document] = [] - for neighbor in neighbours: + for neighbor in neighbors: doc_ref = self.db.collection(self.collection_name).document(document_id=neighbor.datapoint.datapoint_id) doc_snapshot = doc_ref.get() @@ -347,7 +341,9 @@ async def _retrieve_neighbours_data_from_db( ) except ValidationError as e: await logger.awarning( - f'Failed to parse document data for ID {neighbor.datapoint.datapoint_id}: {e}' + 'Failed to parse document data for ID %s: %s', + neighbor.datapoint.datapoint_id, + e, ) return documents diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py index 7126ca65b4..0525d8c89c 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py @@ -17,7 +17,6 @@ from functools import partial from typing import Any -import structlog from google.auth.credentials import Credentials from google.cloud import aiplatform_v1 @@ -28,8 +27,6 @@ RetrieverOptionsSchema, ) -logger = structlog.get_logger(__name__) - class VertexAIVectorSearch(Plugin): """A plugin for integrating VertexAI Vector Search. diff --git a/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py index 0e6b87ac35..915d166f4d 100644 --- a/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py +++ b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py @@ -14,6 +14,12 @@ # # SPDX-License-Identifier: Apache-2.0 +"""Unittests for VertexAI Vector Search retrievers. + +Defines tests for all the methods of the DocRetriever +implementations like BigQueryRetriever and FirestoreRetriever. 
+""" + import json from unittest.mock import AsyncMock, MagicMock @@ -180,14 +186,14 @@ async def test_bigquery__get_closest_documents(bq_retriever_instance): # find_neighbors call bq_retriever_instance._match_service_client_generator.return_value = mock_vector_search_client - # Mock _retrieve_neighbours_data_from_db method - mock__retrieve_neighbours_data_from_db_result = [ + # Mock _retrieve_neighbors_data_from_db method + mock__retrieve_neighbors_data_from_db_result = [ Document.from_text(text='1', metadata={'distance': 0.0, 'id': 1}), Document.from_text(text='2', metadata={'distance': 0.0, 'id': 2}), ] - bq_retriever_instance._retrieve_neighbours_data_from_db = AsyncMock( - return_value=mock__retrieve_neighbours_data_from_db_result, + bq_retriever_instance._retrieve_neighbors_data_from_db = AsyncMock( + return_value=mock__retrieve_neighbors_data_from_db_result, ) await bq_retriever_instance._get_closest_documents( @@ -224,7 +230,7 @@ async def test_bigquery__get_closest_documents(bq_retriever_instance): ) ) - bq_retriever_instance._retrieve_neighbours_data_from_db.assert_awaited_once() + bq_retriever_instance._retrieve_neighbors_data_from_db.assert_awaited_once() @pytest.mark.asyncio @@ -278,10 +284,10 @@ async def test_bigquery__get_closest_documents_fail( @pytest.mark.asyncio -async def test_bigquery__retrieve_neighbours_data_from_db( +async def test_bigquery__retrieve_neighbors_data_from_db( bq_retriever_instance, ): - """Test bigquery retriver _retrieve_neighbours_data_from_db.""" + """Test bigquery retriver _retrieve_neighbors_data_from_db.""" # Mock query job result from bigquery query mock_bq_query_job = MagicMock() mock_bq_query_job.result.return_value = [ @@ -296,8 +302,8 @@ async def test_bigquery__retrieve_neighbours_data_from_db( bq_retriever_instance.bq_client.query.return_value = mock_bq_query_job # call the method - result = await bq_retriever_instance._retrieve_neighbours_data_from_db( - neighbours=[ + result = await bq_retriever_instance._retrieve_neighbors_data_from_db( + neighbors=[ FindNeighborsResponse.Neighbor( datapoint=types.index.IndexDatapoint(datapoint_id='doc1'), distance=0.0, @@ -339,16 +345,16 @@ async def test_bigquery__retrieve_neighbours_data_from_db( @pytest.mark.asyncio -async def test_bigquery_retrieve_neighbours_data_from_db_fail( +async def test_bigquery_retrieve_neighbors_data_from_db_fail( bq_retriever_instance, ): - """Test bigquery retriver _retrieve_neighbours_data_from_db when fails.""" + """Test bigquery retriver _retrieve_neighbors_data_from_db when fails.""" # Mock exception from bigquery query bq_retriever_instance.bq_client.query.raises = AttributeError # call the method - result = await bq_retriever_instance._retrieve_neighbours_data_from_db( - neighbours=[ + result = await bq_retriever_instance._retrieve_neighbors_data_from_db( + neighbors=[ FindNeighborsResponse.Neighbor( datapoint=types.index.IndexDatapoint(datapoint_id='doc1'), distance=0.0, @@ -387,10 +393,10 @@ def test_firestore_retriever__init__(fs_retriever_instance): @pytest.mark.asyncio -async def test_firesstore__retrieve_neighbours_data_from_db( +async def test_firesstore__retrieve_neighbors_data_from_db( fs_retriever_instance, ): - """Test _retrieve_neighbours_data_from_db for firestore retriever.""" + """Test _retrieve_neighbors_data_from_db for firestore retriever.""" # Mock storage of firestore storage = { 'doc1': { @@ -418,8 +424,8 @@ def document(self, document_id): fs_retriever_instance.db.collection.return_value = MockCollection() # call the method - result = await 
fs_retriever_instance._retrieve_neighbours_data_from_db( - neighbours=[ + result = await fs_retriever_instance._retrieve_neighbors_data_from_db( + neighbors=[ FindNeighborsResponse.Neighbor( datapoint=types.index.IndexDatapoint(datapoint_id='doc1'), distance=0.0, diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py index 1eeb15d7b8..4536b87a6e 100644 --- a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py +++ b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py @@ -14,6 +14,8 @@ # # SPDX-License-Identifier: Apache-2.0 +"""Unittest for VertexAIVectorSearch plugin.""" + from unittest.mock import MagicMock from genkit.ai import Genkit diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py index b179fa2b16..7a166f4719 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py @@ -14,6 +14,8 @@ # # SPDX-License-Identifier: Apache-2.0 +"""Example of using Genkit to fill VertexAI Index for Vector Search with BigQuery.""" + import json import os @@ -111,7 +113,6 @@ async def generate_embeddings(): ) embeddings = [emb.embedding for emb in embed_response.embeddings] - logger.debug(f'Generated {len(embeddings)} embeddings, dimension: {len(embeddings[0])}') ids = list(results_dict.keys())[: len(embeddings)] data_embeddings = list(zip(ids, embeddings, strict=True)) @@ -145,9 +146,9 @@ def create_bigquery_dataset_and_table( try: dataset = client.create_dataset(dataset, exists_ok=True) - logger.debug(f'Dataset {client.project}.{dataset.dataset_id} created.') + logger.debug('Dataset %s.%s created.', client.project, dataset.dataset_id) except Exception as e: - logger.exception(f'Error creating dataset: {e}') + logger.exception('Error creating dataset: %s', e) raise e schema = [ @@ -160,9 +161,14 @@ def create_bigquery_dataset_and_table( table = bigquery.Table(table_ref, schema=schema) try: table = client.create_table(table, exists_ok=True) - logger.debug(f'Table {table.project}.{table.dataset_id}.{table.table_id} created.') + logger.debug( + 'Table %s.%s.%s created.', + table.project, + table.dataset_id, + table.table_id, + ) except Exception as e: - logger.exception(f'Error creating table: {e}') + logger.exception('Error creating table: %s', e) raise e rows_to_insert = [ @@ -176,10 +182,10 @@ def create_bigquery_dataset_and_table( errors = client.insert_rows_json(table, rows_to_insert) if errors: - logger.error(f'Errors inserting rows: {errors}') + logger.error('Errors inserting rows: %s', errors) raise Exception(f'Failed to insert rows: {errors}') else: - logger.debug(f'Inserted {len(rows_to_insert)} rows into BigQuery.') + logger.debug('Inserted %s rows into BigQuery.', len(rows_to_insert)) def get_data_from_bigquery( @@ -206,7 +212,7 @@ def get_data_from_bigquery( rows = query_job.result() results = {row['id']: json.dumps(row['content']) for row in rows} - logger.debug(f'Found {len(results)} rows with different ids into BigQuery.') + logger.debug('Found %s rows with different ids into BigQuery.', len(results)) return results @@ -236,12 +242,12 @@ def upsert_index( datapoints = [aiplatform_v1.IndexDatapoint(datapoint_id=id, feature_vector=embedding) for id, embedding in data] - logger.debug(f'Attempting to insert {len(datapoints)} rows into Index {index_path}') + logger.debug('Attempting to insert %s 
rows into Index %s', len(datapoints), index_path) upsert_request = aiplatform_v1.UpsertDatapointsRequest(index=index_path, datapoints=datapoints) - response = index_client.upsert_datapoints(request=upsert_request) - logger.info(f'Upserted {len(datapoints)} datapoints. Response: {response}') + index_client.upsert_datapoints(request=upsert_request) + logger.info('Upserted %s datapoints.', len(datapoints)) async def main() -> None:
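
Closing note for reviewers: taken together, the tests in this series pin down the exact request shape the retrievers send to Vertex AI Vector Search. Below is a minimal sketch of that request, assuming placeholder endpoint and index identifiers (illustrative values, not taken from this repository); neighbor_count carries the top_k derived from options['limit'], which the parametrized retrieve test shows defaulting to 3 when no limit is provided.

from google.cloud.aiplatform_v1 import FindNeighborsRequest, IndexDatapoint

# The request shape asserted by test_bigquery__get_closest_documents: one
# Query per embedding vector, with neighbor_count as the requested top_k.
request = FindNeighborsRequest(
    index_endpoint='projects/p/locations/us-central1/indexEndpoints/1',  # placeholder
    deployed_index_id='my_deployed_index',  # placeholder
    queries=[
        FindNeighborsRequest.Query(
            datapoint=IndexDatapoint(feature_vector=[0.1, 0.2, 0.3]),
            neighbor_count=10,
        )
    ],
)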