diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py
index 23e6644d95..2727fe9ea6 100644
--- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py
+++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py
@@ -21,7 +21,7 @@
 import structlog
 from google.cloud import bigquery, firestore
-from google.cloud.aiplatform_v1 import FindNeighborsRequest, IndexDatapoint, Neighbor
+from google.cloud.aiplatform_v1 import FindNeighborsRequest, FindNeighborsResponse, IndexDatapoint
 from pydantic import BaseModel, Field, ValidationError
 
 from genkit.ai import Genkit
@@ -49,6 +49,7 @@ class DocRetriever(ABC):
         embedder: The name of the embedder to use for generating embeddings.
         embedder_options: Options to pass to the embedder.
     """
+
     def __init__(
         self,
         ai: Genkit,
@@ -117,20 +118,25 @@ async def _get_closest_documents(
             A list of Document objects representing the closest documents.
 
         Raises:
-            AttributeError: If the request does not contain the necessary 
+            AttributeError: If the request does not contain the necessary
                 index endpoint path in its metadata.
         """
         metadata = request.query.metadata
-        if not metadata or 'index_endpoint_path' not in metadata or 'api_endpoint' not in metadata:
-            raise AttributeError('Request provides no data about index endpoint path')
+
+        required_keys = ['index_endpoint_path', 'api_endpoint', 'deployed_index_id']
+
+        if not metadata:
+            raise AttributeError('Request metadata provides no data about index')
+
+        for rkey in required_keys:
+            if rkey not in metadata:
+                raise AttributeError(f'Request metadata provides no data for {rkey}')
 
         api_endpoint = metadata['api_endpoint']
         index_endpoint_path = metadata['index_endpoint_path']
         deployed_index_id = metadata['deployed_index_id']
 
-        client_options = {
-            "api_endpoint": api_endpoint
-        }
+        client_options = {'api_endpoint': api_endpoint}
 
         vector_search_client = self._match_service_client_generator(
             client_options=client_options,
@@ -149,17 +155,17 @@
 
         response = await vector_search_client.find_neighbors(request=nn_request)
 
-        return await self._retrieve_neighbours_data_from_db(neighbours=response.nearest_neighbors[0].neighbors)
+        return await self._retrieve_neighbors_data_from_db(neighbors=response.nearest_neighbors[0].neighbors)
 
     @abstractmethod
-    async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]:
+    async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]:
         """Retrieves document data from the database based on neighbor information.
 
         This method must be implemented by subclasses to define how document data
         is fetched from the database using the provided neighbor information.
 
         Args:
-            neighbours: A list of Neighbor objects representing the nearest neighbors
+            neighbors: A list of Neighbor objects representing the nearest neighbors
                 found in the vector search index.
 
         Returns:
@@ -180,8 +186,14 @@ class BigQueryRetriever(DocRetriever):
         dataset_id: The ID of the BigQuery dataset.
         table_id: The ID of the BigQuery table.
     """
+
     def __init__(
-        self, bq_client: bigquery.Client, dataset_id: str, table_id: str, *args, **kwargs,
+        self,
+        bq_client: bigquery.Client,
+        dataset_id: str,
+        table_id: str,
+        *args,
+        **kwargs,
     ) -> None:
         """Initializes the BigQueryRetriever.
@@ -197,15 +209,15 @@ def __init__(
         self.dataset_id = dataset_id
         self.table_id = table_id
 
-    async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]:
+    async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]:
         """Retrieves document data from the BigQuery table for the given neighbors.
 
         Constructs and executes a BigQuery query to fetch document data based on
-        the IDs obtained. Handles potential errors during query execution and 
+        the IDs obtained. Handles potential errors during query execution and
         document parsing.
 
         Args:
-            neighbours: A list of Neighbor objects representing the nearest neighbors.
+            neighbors: A list of Neighbor objects representing the nearest neighbors.
                 Each neighbor should contain a datapoint with a datapoint_id.
 
         Returns:
@@ -213,16 +225,10 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
             Returns an empty list if no IDs are found in the neighbors
             or if the query fails.
         """
-        ids = [
-            n.datapoint.datapoint_id
-            for n in neighbours
-            if n.datapoint and n.datapoint.datapoint_id
-        ]
+        ids = [n.datapoint.datapoint_id for n in neighbors if n.datapoint and n.datapoint.datapoint_id]
 
         distance_by_id = {
-            n.datapoint.datapoint_id: n.distance
-            for n in neighbours
-            if n.datapoint and n.datapoint.datapoint_id
+            n.datapoint.datapoint_id: n.distance for n in neighbors if n.datapoint and n.datapoint.datapoint_id
         }
 
         if not ids:
@@ -260,7 +266,7 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
                 documents.append(Document.from_text(content, metadata))
             except (ValidationError, json.JSONDecodeError, Exception) as error:
                 doc_id = row.get('id', '')
-                await logger.awarning(f'Failed to parse document data for document with ID {doc_id}: {error}')
+                await logger.awarning('Failed to parse document data for document with ID %s: %s', doc_id, error)
 
         return documents
 
@@ -276,8 +282,13 @@ class FirestoreRetriever(DocRetriever):
         db: The Firestore client.
         collection_name: The name of the Firestore collection.
     """
+
     def __init__(
-        self, firestore_client: firestore.AsyncClient, collection_name: str, *args, **kwargs,
+        self,
+        firestore_client: firestore.AsyncClient,
+        collection_name: str,
+        *args,
+        **kwargs,
     ) -> None:
         """Initializes the FirestoreRetriever.
 
@@ -291,14 +302,14 @@ def __init__(
         self.db = firestore_client
         self.collection_name = collection_name
 
-    async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]:
+    async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]:
         """Retrieves document data from the Firestore collection for the given neighbors.
 
         Fetches document data from Firestore based on the IDs of the nearest
         neighbors. Handles potential errors during document retrieval and data parsing.
 
         Args:
-            neighbours: A list of Neighbor objects representing the nearest neighbors.
+            neighbors: A list of Neighbor objects representing the nearest neighbors.
                 Each neighbor should contain a datapoint with a datapoint_id.
 
         Returns:
@@ -307,14 +318,14 @@ async def _retrieve_neighbours_data_from_db(neighbours: list[Neighbor]) ->
         """
         documents: list[Document] = []
 
-        for neighbor in neighbours:
+        for neighbor in neighbors:
             doc_ref = self.db.collection(self.collection_name).document(document_id=neighbor.datapoint.datapoint_id)
             doc_snapshot = doc_ref.get()
 
             if doc_snapshot.exists:
                 doc_data = doc_snapshot.to_dict() or {}
 
-                content = doc_data.get('content')
+                content = doc_data.get('content', '')
                 content = json.dumps(content) if isinstance(content, dict) else str(content)
 
                 metadata = doc_data.get('metadata', {})
@@ -330,7 +341,9 @@
                 )
             except ValidationError as e:
                 await logger.awarning(
-                    f'Failed to parse document data for ID {neighbor.datapoint.datapoint_id}: {e}'
+                    'Failed to parse document data for ID %s: %s',
+                    neighbor.datapoint.datapoint_id,
+                    e,
                 )
 
         return documents
@@ -342,4 +355,5 @@ class RetrieverOptionsSchema(BaseModel):
     Attributes:
         limit: Number of documents to retrieve.
     """
+
     limit: int | None = Field(title='Number of documents to retrieve', default=None)
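A note on the stricter validation introduced above: a retrieve request must now carry all three keys (`index_endpoint_path`, `api_endpoint`, `deployed_index_id`) in the query document's metadata, or `_get_closest_documents` raises `AttributeError` naming the first missing key. A minimal sketch of a conforming request, with placeholder endpoint values (the resource paths are illustrative, not real):

```python
# Sketch of a conforming request; endpoint values are placeholders. The three
# metadata keys below are the contract enforced by _get_closest_documents.
from genkit.blocks.document import DocumentData
from genkit.types import RetrieverRequest, TextPart

request = RetrieverRequest(
    query=DocumentData(
        content=[TextPart(text='What is in document 1?')],
        metadata={
            'index_endpoint_path': 'projects/p/locations/l/indexEndpoints/123',
            'api_endpoint': 'us-central1-aiplatform.googleapis.com',
            'deployed_index_id': 'my_deployed_index',
        },
    ),
    options={'limit': 5},  # becomes top_k via RetrieverOptionsSchema.limit
)
```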
+""" + +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +from google.cloud import bigquery +from google.cloud.aiplatform_v1 import ( + FindNeighborsRequest, + FindNeighborsResponse, + IndexDatapoint, + MatchServiceAsyncClient, + types, +) + +from genkit.ai import Genkit +from genkit.blocks.document import Document, DocumentData +from genkit.core.typing import Embedding +from genkit.plugins.vertex_ai.models.retriever import ( + BigQueryRetriever, + FirestoreRetriever, +) +from genkit.types import ( + ActionRunContext, + RetrieverRequest, + TextPart, +) + + +@pytest.fixture +def bq_retriever_instance(): + """Common initialization of bq retriever.""" + return BigQueryRetriever( + ai=MagicMock(), + name='test', + match_service_client_generator=MagicMock(), + embedder='embedder', + embedder_options=None, + bq_client=MagicMock(), + dataset_id='dataset_id', + table_id='table_id', + ) + + +def test_bigquery_retriever__init__(bq_retriever_instance): + """Init test.""" + bq_retriever = bq_retriever_instance + + assert bq_retriever is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'options, top_k', + [ + ( + {'limit': 10}, + 10, + ), + ( + {}, + 3, + ), + ( + None, + 3, + ), + ], +) +async def test_bigquery_retriever_retrieve( + bq_retriever_instance, + options, + top_k, +): + """Test retrieve method bq retriever.""" + # Mock query embedder + mock_embedding = MagicMock() + mock_embedding.embeddings = [ + Embedding( + embedding=[0.1, 0.2, 0.3], + ), + ] + + mock_genkit = MagicMock(spec=Genkit) + mock_genkit.embed.return_value = mock_embedding + + bq_retriever_instance.ai = mock_genkit + + # Mock _get_closest_documents + mock__get_closest_documents_result = [ + Document.from_text( + text='1', + metadata={'distance': 0.0, 'id': 1}, + ), + Document.from_text( + text='2', + metadata={'distance': 0.0, 'id': 2}, + ), + ] + + bq_retriever_instance._get_closest_documents = AsyncMock( + return_value=mock__get_closest_documents_result, + ) + + # Executes + await bq_retriever_instance.retrieve( + RetrieverRequest( + query=DocumentData( + content=[ + TextPart(text='test-1'), + ], + ), + options=options, + ), + MagicMock(spec=ActionRunContext), + ) + + # Assert mocks + bq_retriever_instance.ai.embed.assert_called_once_with( + embedder='embedder', + documents=[ + Document( + content=[ + TextPart(text='test-1'), + ], + ), + ], + options={}, + ) + + bq_retriever_instance._get_closest_documents.assert_awaited_once_with( + request=RetrieverRequest( + query=DocumentData( + content=[ + TextPart(text='test-1'), + ], + ), + options=options, + ), + top_k=top_k, + query_embeddings=Embedding( + embedding=[0.1, 0.2, 0.3], + ), + ) + + +@pytest.mark.asyncio +async def test_bigquery__get_closest_documents(bq_retriever_instance): + """Test bigquery retriever _get_closest_documents.""" + # Mock find_neighbors method + mock_vector_search_client = MagicMock(spec=MatchServiceAsyncClient) + + # find_neighbors response + mock_nn = MagicMock() + mock_nn.neighbors = [] + + mock_nn_response = MagicMock(spec=FindNeighborsResponse) + mock_nn_response.nearest_neighbors = [ + mock_nn, + ] + + mock_vector_search_client.find_neighbors = AsyncMock( + return_value=mock_nn_response, + ) + + # find_neighbors call + bq_retriever_instance._match_service_client_generator.return_value = mock_vector_search_client + + # Mock _retrieve_neighbors_data_from_db method + mock__retrieve_neighbors_data_from_db_result = [ + Document.from_text(text='1', metadata={'distance': 0.0, 'id': 1}), + 
diff --git a/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py
new file mode 100644
index 0000000000..915d166f4d
--- /dev/null
+++ b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py
@@ -0,0 +1,474 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for VertexAI Vector Search retrievers.
+
+Defines tests for all the methods of the DocRetriever
+implementations like BigQueryRetriever and FirestoreRetriever.
+"""
+
+import json
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from google.cloud import bigquery
+from google.cloud.aiplatform_v1 import (
+    FindNeighborsRequest,
+    FindNeighborsResponse,
+    IndexDatapoint,
+    MatchServiceAsyncClient,
+    types,
+)
+
+from genkit.ai import Genkit
+from genkit.blocks.document import Document, DocumentData
+from genkit.core.typing import Embedding
+from genkit.plugins.vertex_ai.models.retriever import (
+    BigQueryRetriever,
+    FirestoreRetriever,
+)
+from genkit.types import (
+    ActionRunContext,
+    RetrieverRequest,
+    TextPart,
+)
+
+
+@pytest.fixture
+def bq_retriever_instance():
+    """Common initialization of the BigQuery retriever."""
+    return BigQueryRetriever(
+        ai=MagicMock(),
+        name='test',
+        match_service_client_generator=MagicMock(),
+        embedder='embedder',
+        embedder_options=None,
+        bq_client=MagicMock(),
+        dataset_id='dataset_id',
+        table_id='table_id',
+    )
+
+
+def test_bigquery_retriever__init__(bq_retriever_instance):
+    """Init test."""
+    bq_retriever = bq_retriever_instance
+
+    assert bq_retriever is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    'options, top_k',
+    [
+        (
+            {'limit': 10},
+            10,
+        ),
+        (
+            {},
+            3,
+        ),
+        (
+            None,
+            3,
+        ),
+    ],
+)
+async def test_bigquery_retriever_retrieve(
+    bq_retriever_instance,
+    options,
+    top_k,
+):
+    """Test the retrieve method of the BigQuery retriever."""
+    # Mock query embedder
+    mock_embedding = MagicMock()
+    mock_embedding.embeddings = [
+        Embedding(
+            embedding=[0.1, 0.2, 0.3],
+        ),
+    ]
+
+    mock_genkit = MagicMock(spec=Genkit)
+    mock_genkit.embed.return_value = mock_embedding
+
+    bq_retriever_instance.ai = mock_genkit
+
+    # Mock _get_closest_documents
+    mock__get_closest_documents_result = [
+        Document.from_text(
+            text='1',
+            metadata={'distance': 0.0, 'id': 1},
+        ),
+        Document.from_text(
+            text='2',
+            metadata={'distance': 0.0, 'id': 2},
+        ),
+    ]
+
+    bq_retriever_instance._get_closest_documents = AsyncMock(
+        return_value=mock__get_closest_documents_result,
+    )
+
+    # Executes
+    await bq_retriever_instance.retrieve(
+        RetrieverRequest(
+            query=DocumentData(
+                content=[
+                    TextPart(text='test-1'),
+                ],
+            ),
+            options=options,
+        ),
+        MagicMock(spec=ActionRunContext),
+    )
+
+    # Assert mocks
+    bq_retriever_instance.ai.embed.assert_called_once_with(
+        embedder='embedder',
+        documents=[
+            Document(
+                content=[
+                    TextPart(text='test-1'),
+                ],
+            ),
+        ],
+        options={},
+    )
+
+    bq_retriever_instance._get_closest_documents.assert_awaited_once_with(
+        request=RetrieverRequest(
+            query=DocumentData(
+                content=[
+                    TextPart(text='test-1'),
+                ],
+            ),
+            options=options,
+        ),
+        top_k=top_k,
+        query_embeddings=Embedding(
+            embedding=[0.1, 0.2, 0.3],
+        ),
+    )
+
+
+@pytest.mark.asyncio
+async def test_bigquery__get_closest_documents(bq_retriever_instance):
+    """Test BigQuery retriever _get_closest_documents."""
+    # Mock find_neighbors method
+    mock_vector_search_client = MagicMock(spec=MatchServiceAsyncClient)
+
+    # find_neighbors response
+    mock_nn = MagicMock()
+    mock_nn.neighbors = []
+
+    mock_nn_response = MagicMock(spec=FindNeighborsResponse)
+    mock_nn_response.nearest_neighbors = [
+        mock_nn,
+    ]
+
+    mock_vector_search_client.find_neighbors = AsyncMock(
+        return_value=mock_nn_response,
+    )
+
+    # find_neighbors call
+    bq_retriever_instance._match_service_client_generator.return_value = mock_vector_search_client
+
+    # Mock _retrieve_neighbors_data_from_db method
+    mock__retrieve_neighbors_data_from_db_result = [
+        Document.from_text(text='1', metadata={'distance': 0.0, 'id': 1}),
+        Document.from_text(text='2', metadata={'distance': 0.0, 'id': 2}),
+    ]
+
+    bq_retriever_instance._retrieve_neighbors_data_from_db = AsyncMock(
+        return_value=mock__retrieve_neighbors_data_from_db_result,
+    )
+
+    await bq_retriever_instance._get_closest_documents(
+        request=RetrieverRequest(
+            query=DocumentData(
+                content=[TextPart(text='test-1')],
+                metadata={
+                    'index_endpoint_path': 'index_endpoint_path',
+                    'api_endpoint': 'api_endpoint',
+                    'deployed_index_id': 'deployed_index_id',
+                },
+            ),
+            options={
+                'limit': 10,
+            },
+        ),
+        top_k=10,
+        query_embeddings=Embedding(
+            embedding=[0.1, 0.2, 0.3],
+        ),
+    )
+
+    # Assert calls
+    mock_vector_search_client.find_neighbors.assert_awaited_once_with(
+        request=FindNeighborsRequest(
+            index_endpoint='index_endpoint_path',
+            deployed_index_id='deployed_index_id',
+            queries=[
+                FindNeighborsRequest.Query(
+                    datapoint=IndexDatapoint(feature_vector=[0.1, 0.2, 0.3]),
+                    neighbor_count=10,
+                )
+            ],
+        )
+    )
+
+    bq_retriever_instance._retrieve_neighbors_data_from_db.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    'metadata',
+    [
+        {
+            'index_endpoint_path': 'index_endpoint_path',
+        },
+        {
+            'api_endpoint': 'api_endpoint',
+        },
+        {
+            'deployed_index_id': 'deployed_index_id',
+        },
+        {
+            'index_endpoint_path': 'index_endpoint_path',
+            'api_endpoint': 'api_endpoint',
+        },
+        {
+            'index_endpoint_path': 'index_endpoint_path',
+            'deployed_index_id': 'deployed_index_id',
+        },
+        {
+            'api_endpoint': 'api_endpoint',
+            'deployed_index_id': 'deployed_index_id',
+        },
+    ],
+)
+async def test_bigquery__get_closest_documents_fail(
+    bq_retriever_instance,
+    metadata,
+):
+    """Test failure cases of BigQuery retriever _get_closest_documents."""
+    with pytest.raises(AttributeError):
+        await bq_retriever_instance._get_closest_documents(
+            request=RetrieverRequest(
+                query=DocumentData(
+                    content=[TextPart(text='test-1')],
+                    metadata=metadata,
+                ),
+                options={
+                    'limit': 10,
+                },
+            ),
+            top_k=10,
+            query_embeddings=Embedding(
+                embedding=[0.1, 0.2, 0.3],
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_bigquery__retrieve_neighbors_data_from_db(
+    bq_retriever_instance,
+):
+    """Test BigQuery retriever _retrieve_neighbors_data_from_db."""
+    # Mock query job result from bigquery query
+    mock_bq_query_job = MagicMock()
+    mock_bq_query_job.result.return_value = [
+        {
+            'id': 'doc1',
+            'content': {'body': 'text for document 1'},
+        },
+        {'id': 'doc2', 'content': json.dumps({'body': 'text for document 2'}), 'metadata': {'date': 'today'}},
+        {},  # malformed row: parsing fails here without discarding the two valid rows
+    ]
+
+    bq_retriever_instance.bq_client.query.return_value = mock_bq_query_job
+
+    # call the method
+    result = await bq_retriever_instance._retrieve_neighbors_data_from_db(
+        neighbors=[
+            FindNeighborsResponse.Neighbor(
+                datapoint=types.index.IndexDatapoint(datapoint_id='doc1'),
+                distance=0.0,
+                sparse_distance=0.0,
+            ),
+            FindNeighborsResponse.Neighbor(
+                datapoint=types.index.IndexDatapoint(datapoint_id='doc2'),
+                distance=0.0,
+                sparse_distance=0.0,
+            ),
+        ]
+    )
+
+    # Assert results and calls
+    expected = [
+        Document.from_text(
+            text=json.dumps(
+                {
+                    'body': 'text for document 1',
+                },
+            ),
+            metadata={'id': 'doc1', 'distance': 0.0},
+        ),
+        Document.from_text(
+            text=json.dumps(
+                {
+                    'body': 'text for document 2',
+                },
+            ),
+            metadata={'id': 'doc2', 'distance': 0.0, 'date': 'today'},
+        ),
+    ]
+
+    assert result == expected
+
+    bq_retriever_instance.bq_client.query.assert_called_once()
+
+    mock_bq_query_job.result.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_bigquery_retrieve_neighbors_data_from_db_fail(
+    bq_retriever_instance,
+):
+    """Test BigQuery retriever _retrieve_neighbors_data_from_db when the query fails."""
+    # Make the mocked bigquery query raise
+    bq_retriever_instance.bq_client.query.side_effect = AttributeError
+
+    # call the method
+    result = await bq_retriever_instance._retrieve_neighbors_data_from_db(
+        neighbors=[
+            FindNeighborsResponse.Neighbor(
+                datapoint=types.index.IndexDatapoint(datapoint_id='doc1'),
+                distance=0.0,
+                sparse_distance=0.0,
+            ),
+            FindNeighborsResponse.Neighbor(
+                datapoint=types.index.IndexDatapoint(datapoint_id='doc2'),
+                distance=0.0,
+                sparse_distance=0.0,
+            ),
+        ]
+    )
+
+    assert len(result) == 0
+
+    bq_retriever_instance.bq_client.query.assert_called_once()
+
+
+@pytest.fixture
+def fs_retriever_instance():
+    """Common initialization of the Firestore retriever."""
+    return FirestoreRetriever(
+        ai=MagicMock(),
+        name='test',
+        match_service_client_generator=MagicMock(),
+        embedder='embedder',
+        embedder_options=None,
+        firestore_client=MagicMock(),
+        collection_name='collection_name',
+    )
+
+
+def test_firestore_retriever__init__(fs_retriever_instance):
+    """Init test."""
+    assert fs_retriever_instance is not None
+
+
+@pytest.mark.asyncio
+async def test_firestore__retrieve_neighbors_data_from_db(
+    fs_retriever_instance,
+):
+    """Test _retrieve_neighbors_data_from_db for the Firestore retriever."""
+    # Mock storage of firestore
+    storage = {
+        'doc1': {
+            'content': {'body': 'text for document 1'},
+        },
+        'doc2': {'content': json.dumps({'body': 'text for document 2'}), 'metadata': {'date': 'today'}},
+        'doc3': {},
+    }
+
+    # Mock get from firestore
+    class MockCollection:
+        def document(self, document_id):
+            doc_ref = MagicMock()
+            doc_snapshot = MagicMock()
+
+            doc_ref.get.return_value = doc_snapshot
+            if storage.get(document_id) is not None:
+                doc_snapshot.exists = True
+                doc_snapshot.to_dict.return_value = storage.get(document_id)
+            else:
+                doc_snapshot.exists = False
+
+            return doc_ref
+
+    fs_retriever_instance.db.collection.return_value = MockCollection()
+
+    # call the method
+    result = await fs_retriever_instance._retrieve_neighbors_data_from_db(
+        neighbors=[
+            FindNeighborsResponse.Neighbor(
+                datapoint=types.index.IndexDatapoint(datapoint_id='doc1'),
+                distance=0.0,
+                sparse_distance=0.0,
+            ),
+            FindNeighborsResponse.Neighbor(
+                datapoint=types.index.IndexDatapoint(datapoint_id='doc2'),
+                distance=0.0,
+                sparse_distance=0.0,
+            ),
+            FindNeighborsResponse.Neighbor(
+                datapoint=types.index.IndexDatapoint(datapoint_id='doc3'),
+                distance=0.0,
+                sparse_distance=0.0,
+            ),
+        ]
+    )
+
+    # Assert results and calls
+    expected = [
+        Document.from_text(
+            text=json.dumps(
+                {
+                    'body': 'text for document 1',
+                },
+            ),
+            metadata={'id': 'doc1', 'distance': 0.0},
+        ),
+        Document.from_text(
+            text=json.dumps(
+                {
+                    'body': 'text for document 2',
+                },
+            ),
+            metadata={'id': 'doc2', 'distance': 0.0, 'date': 'today'},
+        ),
+        Document.from_text(
+            text='',
+            metadata={
+                'id': 'doc3',
+                'distance': 0.0,
+            },
+        ),
+    ]
+
+    assert result == expected
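The failure-path test above (`test_bigquery_retrieve_neighbors_data_from_db_fail`) drives the error branch by assigning `side_effect` on the mocked query, the standard `unittest.mock` mechanism for making a mocked call raise. A self-contained illustration of those semantics:

```python
# Self-contained illustration of unittest.mock side_effect semantics:
# assigning an exception class (or instance) makes the mocked call raise,
# while return_value alone would just return a value.
from unittest.mock import MagicMock

mock_client = MagicMock()
mock_client.query.side_effect = AttributeError('boom')

try:
    mock_client.query('SELECT 1')
except AttributeError as err:
    print(f'query raised as configured: {err}')

mock_client.query.assert_called_once()  # the call is still recorded
```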
diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py
new file mode 100644
index 0000000000..4536b87a6e
--- /dev/null
+++ b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py
@@ -0,0 +1,34 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for the VertexAIVectorSearch plugin."""
+
+from unittest.mock import MagicMock
+
+from genkit.ai import Genkit
+from genkit.plugins.vertex_ai import VertexAIVectorSearch
+
+
+def test_initialize_plugin():
+    """Test plugin initialization."""
+    plugin = VertexAIVectorSearch(
+        retriever=MagicMock(),
+        embedder='embedder',
+    )
+
+    result = plugin.initialize(ai=MagicMock(spec=Genkit))
+
+    assert result is not None
diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py
index 39c994a261..5f04d979c9 100644
--- a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py
+++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py
@@ -71,12 +71,14 @@
 class QueryFlowInputSchema(BaseModel):
     """Input schema."""
+
     query: str
     k: int
 
 
 class QueryFlowOutputSchema(BaseModel):
     """Output schema."""
+
     result: list[dict]
     length: int
     time: int
@@ -128,7 +130,7 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema:
 async def main() -> None:
     """Main function."""
     query_input = QueryFlowInputSchema(
-        query="Content for doc",
+        query='Content for doc',
     k=3,
     )
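The sample's flow takes the typed input schema above. A hypothetical driver, mirroring what `main()` does in the sample; the import path and the direct call to the flow function are assumptions:

```python
# Hypothetical driver for the sample flow. The import path is assumed; the
# input values match the sample's main().
import asyncio

from sample import QueryFlowInputSchema, query_flow  # import path assumed

async def run_once() -> None:
    output = await query_flow(QueryFlowInputSchema(query='Content for doc', k=3))
    print(output.length, output.time)  # document count and elapsed time
    for doc in output.result:
        print(doc)

if __name__ == '__main__':
    asyncio.run(run_once())
```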
diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py
index 7fc6710c9b..7a166f4719 100644
--- a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py
+++ b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py
@@ -14,6 +14,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+"""Example of using Genkit to fill a VertexAI Vector Search index with BigQuery data."""
+
 import json
 import os
 
@@ -57,7 +59,7 @@
         },
         embedder=EMBEDDING_MODEL,
         embedder_options={'task': 'RETRIEVAL_DOCUMENT'},
-    )
+    ),
     ]
 )
@@ -71,24 +73,28 @@ async def generate_embeddings():
     """
     toy_documents = [
         {
-            "id": "doc1",
-            "content": {"title": "Document 1", "body": "This is the content of document 1."},
-            "metadata": {"author": "Alice", "date": "2024-01-15"},
+            'id': 'doc1',
+            'content': {'title': 'Document 1', 'body': 'This is the content of document 1.'},
+            'metadata': {'author': 'Alice', 'date': '2024-01-15'},
         },
         {
-            "id": "doc2",
-            "content": {"title": "Document 2", "body": "This is the content of document 2."},
-            "metadata": {"author": "Bob", "date": "2024-02-20"},
+            'id': 'doc2',
+            'content': {'title': 'Document 2', 'body': 'This is the content of document 2.'},
+            'metadata': {'author': 'Bob', 'date': '2024-02-20'},
         },
         {
-            "id": "doc3",
-            "content": {"title": "Document 3", "body": "Content for doc 3"},
-            "metadata": {"author": "Charlie", "date": "2024-03-01"},
+            'id': 'doc3',
+            'content': {'title': 'Document 3', 'body': 'Content for doc 3'},
+            'metadata': {'author': 'Charlie', 'date': '2024-03-01'},
         },
     ]
 
     create_bigquery_dataset_and_table(
-        PROJECT_ID, LOCATION, BIGQUERY_DATASET_NAME, BIGQUERY_TABLE_NAME, toy_documents,
+        PROJECT_ID,
+        LOCATION,
+        BIGQUERY_DATASET_NAME,
+        BIGQUERY_TABLE_NAME,
+        toy_documents,
     )
 
     results_dict = get_data_from_bigquery(
@@ -98,10 +104,7 @@ async def generate_embeddings():
         table_id=BIGQUERY_TABLE_NAME,
     )
 
-    genkit_documents = [
-        types.Document(content=[types.TextPart(text=text)])
-        for text in results_dict.values()
-    ]
+    genkit_documents = [types.Document(content=[types.TextPart(text=text)]) for text in results_dict.values()]
 
     embed_response = await ai.embed(
         embedder=vertexai_name(EMBEDDING_MODEL),
@@ -110,9 +113,8 @@ async def generate_embeddings():
     )
 
     embeddings = [emb.embedding for emb in embed_response.embeddings]
-    logger.debug(f'Generated {len(embeddings)} embeddings, dimension: {len(embeddings[0])}')
 
-    ids = list(results_dict.keys())[:len(embeddings)]
+    ids = list(results_dict.keys())[: len(embeddings)]
     data_embeddings = list(zip(ids, embeddings, strict=True))
 
     upsert_data = [(str(id), embedding) for id, embedding in data_embeddings]
@@ -144,41 +146,46 @@ def create_bigquery_dataset_and_table(
 
     try:
         dataset = client.create_dataset(dataset, exists_ok=True)
-        logger.debug(f"Dataset {client.project}.{dataset.dataset_id} created.")
+        logger.debug('Dataset %s.%s created.', client.project, dataset.dataset_id)
     except Exception as e:
-        logger.exception(f"Error creating dataset: {e}")
+        logger.exception('Error creating dataset: %s', e)
         raise e
 
     schema = [
-        bigquery.SchemaField("id", "STRING", mode="REQUIRED"),
-        bigquery.SchemaField("content", "JSON"),
-        bigquery.SchemaField("metadata", "JSON"),
+        bigquery.SchemaField('id', 'STRING', mode='REQUIRED'),
+        bigquery.SchemaField('content', 'JSON'),
+        bigquery.SchemaField('metadata', 'JSON'),
     ]
 
     table_ref = dataset_ref.table(table_id)
     table = bigquery.Table(table_ref, schema=schema)
     try:
         table = client.create_table(table, exists_ok=True)
-        logger.debug(f"Table {table.project}.{table.dataset_id}.{table.table_id} created.")
+        logger.debug(
+            'Table %s.%s.%s created.',
+            table.project,
+            table.dataset_id,
+            table.table_id,
+        )
     except Exception as e:
-        logger.exception(f"Error creating table: {e}")
+        logger.exception('Error creating table: %s', e)
         raise e
 
     rows_to_insert = [
         {
-            "id": doc["id"],
-            "content": json.dumps(doc["content"]),
-            "metadata": json.dumps(doc["metadata"]),
+            'id': doc['id'],
+            'content': json.dumps(doc['content']),
+            'metadata': json.dumps(doc['metadata']),
         }
         for doc in documents
     ]
 
     errors = client.insert_rows_json(table, rows_to_insert)
     if errors:
-        logger.error(f"Errors inserting rows: {errors}")
-        raise Exception(f"Failed to insert rows: {errors}")
+        logger.error('Errors inserting rows: %s', errors)
+        raise Exception(f'Failed to insert rows: {errors}')
     else:
-        logger.debug(f"Inserted {len(rows_to_insert)} rows into BigQuery.")
+        logger.debug('Inserted %s rows into BigQuery.', len(rows_to_insert))
 
 
 def get_data_from_bigquery(
@@ -199,15 +206,13 @@ def get_data_from_bigquery(
         A dictionary where keys are document IDs and values are JSON strings
         representing the document content.
""" - table_ref = bigquery.TableReference.from_string( - f"{project_id}.{dataset_id}.{table_id}" - ) - query = f"SELECT id, content FROM `{table_ref}`" + table_ref = bigquery.TableReference.from_string(f'{project_id}.{dataset_id}.{table_id}') + query = f'SELECT id, content FROM `{table_ref}`' query_job = bq_client.query(query) rows = query_job.result() results = {row['id']: json.dumps(row['content']) for row in rows} - logger.debug(f'Found {len(results)} rows with different ids into BigQuery.') + logger.debug('Found %s rows with different ids into BigQuery.', len(results)) return results @@ -230,26 +235,19 @@ def upsert_index( aiplatform.init(project=project_id, location=region) index_client = aiplatform_v1.IndexServiceClient( - client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"} + client_options={'api_endpoint': f'{region}-aiplatform.googleapis.com'} ) - index_path = index_client.index_path( - project=project_id, location=region, index=index_name - ) + index_path = index_client.index_path(project=project_id, location=region, index=index_name) - datapoints = [ - aiplatform_v1.IndexDatapoint(datapoint_id=id, feature_vector=embedding) - for id, embedding in data - ] + datapoints = [aiplatform_v1.IndexDatapoint(datapoint_id=id, feature_vector=embedding) for id, embedding in data] - logger.debug(f'Attempting to insert {len(datapoints)} rows into Index {index_path}') + logger.debug('Attempting to insert %s rows into Index %s', len(datapoints), index_path) - upsert_request = aiplatform_v1.UpsertDatapointsRequest( - index=index_path, datapoints=datapoints - ) + upsert_request = aiplatform_v1.UpsertDatapointsRequest(index=index_path, datapoints=datapoints) - response = index_client.upsert_datapoints(request=upsert_request) - logger.info(f"Upserted {len(datapoints)} datapoints. Response: {response}") + index_client.upsert_datapoints(request=upsert_request) + logger.info('Upserted %s datapoints.', len(datapoints)) async def main() -> None: diff --git a/py/samples/vertex-ai-vector-search-firestore/src/sample.py b/py/samples/vertex-ai-vector-search-firestore/src/sample.py index a8bd67f563..d5dd7fd05a 100644 --- a/py/samples/vertex-ai-vector-search-firestore/src/sample.py +++ b/py/samples/vertex-ai-vector-search-firestore/src/sample.py @@ -67,12 +67,14 @@ class QueryFlowInputSchema(BaseModel): """Input schema.""" + query: str k: int class QueryFlowOutputSchema(BaseModel): """Output schema.""" + result: list[dict] length: int time: int @@ -124,7 +126,7 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: async def main() -> None: """Main function.""" query_input = QueryFlowInputSchema( - query="Content for doc", + query='Content for doc', k=3, )