Skip to content

Commit 8b28c48

Browse files
AbeJLazaroAbraham Lazaro Martinez
and
Abraham Lazaro Martinez
authored
fix(py): linters and adding test for vertex-ai vector search plugin (#2889)
Co-authored-by: Abraham Lazaro Martinez <lazaromartinez@google.com>
1 parent 23a9565 commit 8b28c48

File tree

7 files changed

+604
-83
lines changed

7 files changed

+604
-83
lines changed

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py

+43-29
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import structlog
2323
from google.cloud import bigquery, firestore
24-
from google.cloud.aiplatform_v1 import FindNeighborsRequest, IndexDatapoint, Neighbor
24+
from google.cloud.aiplatform_v1 import FindNeighborsRequest, FindNeighborsResponse, IndexDatapoint
2525
from pydantic import BaseModel, Field, ValidationError
2626

2727
from genkit.ai import Genkit
@@ -49,6 +49,7 @@ class DocRetriever(ABC):
4949
embedder: The name of the embedder to use for generating embeddings.
5050
embedder_options: Options to pass to the embedder.
5151
"""
52+
5253
def __init__(
5354
self,
5455
ai: Genkit,
@@ -117,20 +118,25 @@ async def _get_closest_documents(
117118
A list of Document objects representing the closest documents.
118119
119120
Raises:
120-
AttributeError: If the request does not contain the necessary
121+
AttributeError: If the request does not contain the necessary
121122
index endpoint path in its metadata.
122123
"""
123124
metadata = request.query.metadata
124-
if not metadata or 'index_endpoint_path' not in metadata or 'api_endpoint' not in metadata:
125-
raise AttributeError('Request provides no data about index endpoint path')
125+
126+
required_keys = ['index_endpoint_path', 'api_endpoint', 'deployed_index_id']
127+
128+
if not metadata:
129+
raise AttributeError('Request metadata provides no data about index')
130+
131+
for rkey in required_keys:
132+
if rkey not in metadata:
133+
raise AttributeError(f'Request metadata provides no data for {rkey}')
126134

127135
api_endpoint = metadata['api_endpoint']
128136
index_endpoint_path = metadata['index_endpoint_path']
129137
deployed_index_id = metadata['deployed_index_id']
130138

131-
client_options = {
132-
"api_endpoint": api_endpoint
133-
}
139+
client_options = {'api_endpoint': api_endpoint}
134140

135141
vector_search_client = self._match_service_client_generator(
136142
client_options=client_options,
@@ -149,17 +155,17 @@ async def _get_closest_documents(
149155

150156
response = await vector_search_client.find_neighbors(request=nn_request)
151157

152-
return await self._retrieve_neighbours_data_from_db(neighbours=response.nearest_neighbors[0].neighbors)
158+
return await self._retrieve_neighbors_data_from_db(neighbors=response.nearest_neighbors[0].neighbors)
153159

154160
@abstractmethod
155-
async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]:
161+
async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]:
156162
"""Retrieves document data from the database based on neighbor information.
157163
158164
This method must be implemented by subclasses to define how document
159165
data is fetched from the database using the provided neighbor information.
160166
161167
Args:
162-
neighbours: A list of Neighbor objects representing the nearest neighbors
168+
neighbors: A list of Neighbor objects representing the nearest neighbors
163169
found in the vector search index.
164170
165171
Returns:
@@ -180,8 +186,14 @@ class BigQueryRetriever(DocRetriever):
180186
dataset_id: The ID of the BigQuery dataset.
181187
table_id: The ID of the BigQuery table.
182188
"""
189+
183190
def __init__(
184-
self, bq_client: bigquery.Client, dataset_id: str, table_id: str, *args, **kwargs,
191+
self,
192+
bq_client: bigquery.Client,
193+
dataset_id: str,
194+
table_id: str,
195+
*args,
196+
**kwargs,
185197
) -> None:
186198
"""Initializes the BigQueryRetriever.
187199
@@ -197,32 +209,26 @@ def __init__(
197209
self.dataset_id = dataset_id
198210
self.table_id = table_id
199211

200-
async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]:
212+
async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]:
201213
"""Retrieves document data from the BigQuery table for the given neighbors.
202214
203215
Constructs and executes a BigQuery query to fetch document data based on
204-
the IDs obtained. Handles potential errors during query execution and
216+
the IDs obtained. Handles potential errors during query execution and
205217
document parsing.
206218
207219
Args:
208-
neighbours: A list of Neighbor objects representing the nearest neighbors.
220+
neighbors: A list of Neighbor objects representing the nearest neighbors.
209221
Each neighbor should contain a datapoint with a datapoint_id.
210222
211223
Returns:
212224
A list of Document objects containing the retrieved document data.
213225
Returns an empty list if no IDs are found in the neighbors or if the
214226
query fails.
215227
"""
216-
ids = [
217-
n.datapoint.datapoint_id
218-
for n in neighbours
219-
if n.datapoint and n.datapoint.datapoint_id
220-
]
228+
ids = [n.datapoint.datapoint_id for n in neighbors if n.datapoint and n.datapoint.datapoint_id]
221229

222230
distance_by_id = {
223-
n.datapoint.datapoint_id: n.distance
224-
for n in neighbours
225-
if n.datapoint and n.datapoint.datapoint_id
231+
n.datapoint.datapoint_id: n.distance for n in neighbors if n.datapoint and n.datapoint.datapoint_id
226232
}
227233

228234
if not ids:
@@ -260,7 +266,7 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
260266
documents.append(Document.from_text(content, metadata))
261267
except (ValidationError, json.JSONDecodeError, Exception) as error:
262268
doc_id = row.get('id', '<unknown>')
263-
await logger.awarning(f'Failed to parse document data for document with ID {doc_id}: {error}')
269+
await logger.awarning('Failed to parse document data for document with ID %s: %s', doc_id, error)
264270

265271
return documents
266272

@@ -276,8 +282,13 @@ class FirestoreRetriever(DocRetriever):
276282
db: The Firestore client.
277283
collection_name: The name of the Firestore collection.
278284
"""
285+
279286
def __init__(
280-
self, firestore_client: firestore.AsyncClient, collection_name: str, *args, **kwargs,
287+
self,
288+
firestore_client: firestore.AsyncClient,
289+
collection_name: str,
290+
*args,
291+
**kwargs,
281292
) -> None:
282293
"""Initializes the FirestoreRetriever.
283294
@@ -291,14 +302,14 @@ def __init__(
291302
self.db = firestore_client
292303
self.collection_name = collection_name
293304

294-
async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]:
305+
async def _retrieve_neighbors_data_from_db(self, neighbors: list[FindNeighborsResponse.Neighbor]) -> list[Document]:
295306
"""Retrieves document data from the Firestore collection for the given neighbors.
296307
297308
Fetches document data from Firestore based on the IDs of the nearest neighbors.
298309
Handles potential errors during document retrieval and data parsing.
299310
300311
Args:
301-
neighbours: A list of Neighbor objects representing the nearest neighbors.
312+
neighbors: A list of Neighbor objects representing the nearest neighbors.
302313
Each neighbor should contain a datapoint with a datapoint_id.
303314
304315
Returns:
@@ -307,14 +318,14 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
307318
"""
308319
documents: list[Document] = []
309320

310-
for neighbor in neighbours:
321+
for neighbor in neighbors:
311322
doc_ref = self.db.collection(self.collection_name).document(document_id=neighbor.datapoint.datapoint_id)
312323
doc_snapshot = doc_ref.get()
313324

314325
if doc_snapshot.exists:
315326
doc_data = doc_snapshot.to_dict() or {}
316327

317-
content = doc_data.get('content')
328+
content = doc_data.get('content', '')
318329
content = json.dumps(content) if isinstance(content, dict) else str(content)
319330

320331
metadata = doc_data.get('metadata', {})
@@ -330,7 +341,9 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
330341
)
331342
except ValidationError as e:
332343
await logger.awarning(
333-
f'Failed to parse document data for ID {neighbor.datapoint.datapoint_id}: {e}'
344+
'Failed to parse document data for ID %s: %s',
345+
neighbor.datapoint.datapoint_id,
346+
e,
334347
)
335348

336349
return documents
@@ -342,4 +355,5 @@ class RetrieverOptionsSchema(BaseModel):
342355
Attributes:
343356
limit: Number of documents to retrieve.
344357
"""
358+
345359
limit: int | None = Field(title='Number of documents to retrieve', default=None)

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py

-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from functools import partial
1818
from typing import Any
1919

20-
import structlog
2120
from google.auth.credentials import Credentials
2221
from google.cloud import aiplatform_v1
2322

@@ -28,8 +27,6 @@
2827
RetrieverOptionsSchema,
2928
)
3029

31-
logger = structlog.get_logger(__name__)
32-
3330

3431
class VertexAIVectorSearch(Plugin):
3532
"""A plugin for integrating VertexAI Vector Search.

0 commit comments

Comments
 (0)