-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag_core.py
More file actions
126 lines (100 loc) · 3.94 KB
/
rag_core.py
File metadata and controls
126 lines (100 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
from dotenv import load_dotenv
import chromadb
from chromadb.utils import embedding_functions
import google.generativeai as genai
# Load variables from a local .env file into the process environment,
# then fail fast if the Gemini API key is missing: every feature in this
# module (embeddings, chat) needs it.
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("Set GEMINI_API_KEY in your .env")
# Configure the google.generativeai client globally for this process.
genai.configure(api_key=GEMINI_API_KEY)
class GeminiEmbeddingFunction(embedding_functions.EmbeddingFunction):
    """
    Minimal wrapper so Chroma can call Gemini embeddings automatically.

    Chroma invokes ``__call__`` with the documents (or query texts) to
    embed and expects one list of floats back per input text.
    """

    def __init__(self, model_name: str = "text-embedding-004"):
        # Gemini embedding model identifier, e.g. "text-embedding-004".
        self.model_name = model_name

    def __call__(self, texts):
        # Accept a bare string for convenience; Chroma normally passes a list.
        if isinstance(texts, str):
            texts = [texts]
        # Log only the count — the previous version printed the full input
        # texts, which is noisy and can leak document contents to stdout.
        print(f"Embedding {len(texts)} text(s) via Gemini...")
        vectors = []
        for t in texts:
            res = genai.embed_content(model=self.model_name, content=t)
            emb = res.get("embedding")
            # Some client versions wrap the vector as {"values": [...]}.
            if isinstance(emb, dict) and "values" in emb:
                emb = emb["values"]
            if not isinstance(emb, list):
                raise RuntimeError("Unexpected embedding response from Gemini")
            vectors.append(emb)
        print(f"Embedded {len(texts)} texts via Gemini.")
        return vectors
# Module-level singletons shared by ensure_index_built() and rag_answer():
# one embedding function, one on-disk Chroma client, one collection.
gemini_ef = GeminiEmbeddingFunction(model_name="text-embedding-004")
# PersistentClient stores the index under ./chroma_persistent_storage,
# so embeddings survive process restarts.
chroma_client = chromadb.PersistentClient(path="chroma_persistent_storage")
collection_name = "my_collection"
# get_or_create is idempotent: reuses the collection if it already exists.
collection = chroma_client.get_or_create_collection(
    name=collection_name,
    embedding_function=gemini_ef
)
def _load_documents_from_directory(directory_path: str):
print("=== Loading documents from directory ===")
documents = []
for filename in os.listdir(directory_path):
if filename.endswith(".txt"):
with open(os.path.join(directory_path, filename), 'r', encoding='utf-8') as file:
documents.append({"id": filename, "text": file.read()})
return documents
def _split_text(text: str, chunk_size: int = 1000, overlap: int = 20):
chunks = []
start = 0
text_length = len(text)
while start < text_length:
end = min(start + chunk_size, text_length)
chunks.append(text[start:end])
start += chunk_size - overlap
return chunks
def ensure_index_built(data_dir: str = "./my_data") -> None:
    """Index local text files into Chroma if collection is empty."""
    # Treat a failed count() as an empty collection and (re)build.
    try:
        existing = collection.count()
    except Exception:
        existing = 0
    if existing > 0:
        print(f"Chroma already has {existing} items; skipping re-index.")
        return

    docs = _load_documents_from_directory(data_dir)
    print(f"Loaded {len(docs)} documents from {data_dir}")

    # Flatten every document into id/text chunk records in one pass.
    chunk_records = [
        {"id": f"{doc['id']}_chunk_{idx}", "text": piece}
        for doc in docs
        for idx, piece in enumerate(_split_text(doc["text"]))
    ]
    print(f"Split documents into {len(chunk_records)} chunks")

    if not chunk_records:
        print("No documents to index.")
        return
    # Chroma calls the collection's embedding function for each document.
    collection.add(
        ids=[record["id"] for record in chunk_records],
        documents=[record["text"] for record in chunk_records],
    )
    print("Indexed chunks into Chroma with Gemini embeddings.")
def rag_answer(question: str, top_k: int = 5, chat_model: str = "gemini-1.5-flash") -> str:
    """Answer *question* from the indexed documents via a Gemini chat model."""
    # Retrieval step: nearest top_k chunks for the question.
    results = collection.query(query_texts=[question], n_results=top_k)
    if results:
        retrieved = results.get("documents", [[]])[0]
    else:
        retrieved = []
    context = "\n\n".join(retrieved)

    # Generation step: instruct the model to stay inside the retrieved context.
    prompt = (
        "You are a concise RAG assistant. Use ONLY the provided context. "
        "If the answer is not in the context, say you don't know.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}"
    )
    return genai.GenerativeModel(chat_model).generate_content(prompt).text
# Build index on import (idempotent): ensure_index_built() returns early
# when the collection already holds items, so repeated imports are cheap.
# NOTE(review): this is an import-time side effect that needs the API key
# and disk access — consider having the caller invoke it explicitly.
ensure_index_built()