21
21
22
22
import structlog
23
23
from google .cloud import bigquery , firestore
24
- from google .cloud .aiplatform_v1 import FindNeighborsRequest , IndexDatapoint , Neighbor
24
+ from google .cloud .aiplatform_v1 import FindNeighborsRequest , FindNeighborsResponse , IndexDatapoint
25
25
from pydantic import BaseModel , Field , ValidationError
26
26
27
27
from genkit .ai import Genkit
@@ -49,6 +49,7 @@ class DocRetriever(ABC):
49
49
embedder: The name of the embedder to use for generating embeddings.
50
50
embedder_options: Options to pass to the embedder.
51
51
"""
52
+
52
53
def __init__ (
53
54
self ,
54
55
ai : Genkit ,
@@ -117,20 +118,25 @@ async def _get_closest_documents(
117
118
A list of Document objects representing the closest documents.
118
119
119
120
Raises:
120
- AttributeError: If the request does not contain the necessary
121
+ AttributeError: If the request does not contain the necessary
121
122
index endpoint path in its metadata.
122
123
"""
123
124
metadata = request .query .metadata
124
- if not metadata or 'index_endpoint_path' not in metadata or 'api_endpoint' not in metadata :
125
- raise AttributeError ('Request provides no data about index endpoint path' )
125
+
126
+ required_keys = ['index_endpoint_path' , 'api_endpoint' , 'deployed_index_id' ]
127
+
128
+ if not metadata :
129
+ raise AttributeError ('Request metadata provides no data about index' )
130
+
131
+ for rkey in required_keys :
132
+ if rkey not in metadata :
133
+ raise AttributeError (f'Request metadata provides no data for { rkey } ' )
126
134
127
135
api_endpoint = metadata ['api_endpoint' ]
128
136
index_endpoint_path = metadata ['index_endpoint_path' ]
129
137
deployed_index_id = metadata ['deployed_index_id' ]
130
138
131
- client_options = {
132
- "api_endpoint" : api_endpoint
133
- }
139
+ client_options = {'api_endpoint' : api_endpoint }
134
140
135
141
vector_search_client = self ._match_service_client_generator (
136
142
client_options = client_options ,
@@ -149,17 +155,17 @@ async def _get_closest_documents(
149
155
150
156
response = await vector_search_client .find_neighbors (request = nn_request )
151
157
152
- return await self ._retrieve_neighbours_data_from_db ( neighbours = response .nearest_neighbors [0 ].neighbors )
158
+ return await self ._retrieve_neighbors_data_from_db ( neighbors = response .nearest_neighbors [0 ].neighbors )
153
159
154
160
@abstractmethod
155
- async def _retrieve_neighbours_data_from_db (self , neighbours : list [Neighbor ]) -> list [Document ]:
161
+ async def _retrieve_neighbors_data_from_db (self , neighbors : list [FindNeighborsResponse . Neighbor ]) -> list [Document ]:
156
162
"""Retrieves document data from the database based on neighbor information.
157
163
158
164
This method must be implemented by subclasses to define how document
159
165
data is fetched from the database using the provided neighbor information.
160
166
161
167
Args:
162
- neighbours : A list of Neighbor objects representing the nearest neighbors
168
+ neighbors : A list of Neighbor objects representing the nearest neighbors
163
169
found in the vector search index.
164
170
165
171
Returns:
@@ -180,8 +186,14 @@ class BigQueryRetriever(DocRetriever):
180
186
dataset_id: The ID of the BigQuery dataset.
181
187
table_id: The ID of the BigQuery table.
182
188
"""
189
+
183
190
def __init__ (
184
- self , bq_client : bigquery .Client , dataset_id : str , table_id : str , * args , ** kwargs ,
191
+ self ,
192
+ bq_client : bigquery .Client ,
193
+ dataset_id : str ,
194
+ table_id : str ,
195
+ * args ,
196
+ ** kwargs ,
185
197
) -> None :
186
198
"""Initializes the BigQueryRetriever.
187
199
@@ -197,32 +209,26 @@ def __init__(
197
209
self .dataset_id = dataset_id
198
210
self .table_id = table_id
199
211
200
- async def _retrieve_neighbours_data_from_db (self , neighbours : list [Neighbor ]) -> list [Document ]:
212
+ async def _retrieve_neighbors_data_from_db (self , neighbors : list [FindNeighborsResponse . Neighbor ]) -> list [Document ]:
201
213
"""Retrieves document data from the BigQuery table for the given neighbors.
202
214
203
215
Constructs and executes a BigQuery query to fetch document data based on
204
- the IDs obtained. Handles potential errors during query execution and
216
+ the IDs obtained. Handles potential errors during query execution and
205
217
document parsing.
206
218
207
219
Args:
208
- neighbours : A list of Neighbor objects representing the nearest neighbors.
220
+ neighbors : A list of Neighbor objects representing the nearest neighbors.
209
221
Each neighbor should contain a datapoint with a datapoint_id.
210
222
211
223
Returns:
212
224
A list of Document objects containing the retrieved document data.
213
225
Returns an empty list if no IDs are found in the neighbors or if the
214
226
query fails.
215
227
"""
216
- ids = [
217
- n .datapoint .datapoint_id
218
- for n in neighbours
219
- if n .datapoint and n .datapoint .datapoint_id
220
- ]
228
+ ids = [n .datapoint .datapoint_id for n in neighbors if n .datapoint and n .datapoint .datapoint_id ]
221
229
222
230
distance_by_id = {
223
- n .datapoint .datapoint_id : n .distance
224
- for n in neighbours
225
- if n .datapoint and n .datapoint .datapoint_id
231
+ n .datapoint .datapoint_id : n .distance for n in neighbors if n .datapoint and n .datapoint .datapoint_id
226
232
}
227
233
228
234
if not ids :
@@ -260,7 +266,7 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
260
266
documents .append (Document .from_text (content , metadata ))
261
267
except (ValidationError , json .JSONDecodeError , Exception ) as error :
262
268
doc_id = row .get ('id' , '<unknown>' )
263
- await logger .awarning (f 'Failed to parse document data for document with ID { doc_id } : { error } ' )
269
+ await logger .awarning ('Failed to parse document data for document with ID %s: %s' , doc_id , error )
264
270
265
271
return documents
266
272
@@ -276,8 +282,13 @@ class FirestoreRetriever(DocRetriever):
276
282
db: The Firestore client.
277
283
collection_name: The name of the Firestore collection.
278
284
"""
285
+
279
286
def __init__ (
280
- self , firestore_client : firestore .AsyncClient , collection_name : str , * args , ** kwargs ,
287
+ self ,
288
+ firestore_client : firestore .AsyncClient ,
289
+ collection_name : str ,
290
+ * args ,
291
+ ** kwargs ,
281
292
) -> None :
282
293
"""Initializes the FirestoreRetriever.
283
294
@@ -291,14 +302,14 @@ def __init__(
291
302
self .db = firestore_client
292
303
self .collection_name = collection_name
293
304
294
- async def _retrieve_neighbours_data_from_db (self , neighbours : list [Neighbor ]) -> list [Document ]:
305
+ async def _retrieve_neighbors_data_from_db (self , neighbors : list [FindNeighborsResponse . Neighbor ]) -> list [Document ]:
295
306
"""Retrieves document data from the Firestore collection for the given neighbors.
296
307
297
308
Fetches document data from Firestore based on the IDs of the nearest neighbors.
298
309
Handles potential errors during document retrieval and data parsing.
299
310
300
311
Args:
301
- neighbours : A list of Neighbor objects representing the nearest neighbors.
312
+ neighbors : A list of Neighbor objects representing the nearest neighbors.
302
313
Each neighbor should contain a datapoint with a datapoint_id.
303
314
304
315
Returns:
@@ -307,14 +318,14 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
307
318
"""
308
319
documents : list [Document ] = []
309
320
310
- for neighbor in neighbours :
321
+ for neighbor in neighbors :
311
322
doc_ref = self .db .collection (self .collection_name ).document (document_id = neighbor .datapoint .datapoint_id )
312
323
doc_snapshot = doc_ref .get ()
313
324
314
325
if doc_snapshot .exists :
315
326
doc_data = doc_snapshot .to_dict () or {}
316
327
317
- content = doc_data .get ('content' )
328
+ content = doc_data .get ('content' , '' )
318
329
content = json .dumps (content ) if isinstance (content , dict ) else str (content )
319
330
320
331
metadata = doc_data .get ('metadata' , {})
@@ -330,7 +341,9 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) ->
330
341
)
331
342
except ValidationError as e :
332
343
await logger .awarning (
333
- f'Failed to parse document data for ID { neighbor .datapoint .datapoint_id } : { e } '
344
+ 'Failed to parse document data for ID %s: %s' ,
345
+ neighbor .datapoint .datapoint_id ,
346
+ e ,
334
347
)
335
348
336
349
return documents
@@ -342,4 +355,5 @@ class RetrieverOptionsSchema(BaseModel):
342
355
Attributes:
343
356
limit: Number of documents to retrieve.
344
357
"""
358
+
345
359
limit : int | None = Field (title = 'Number of documents to retrieve' , default = None )
0 commit comments