Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 68 additions & 7 deletions astrbot/core/db/vec_db/faiss_impl/document_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,31 @@ def stopwords(self) -> set[str]:
self._stopwords = load_stopwords(stopwords_path)
return self._stopwords

@staticmethod
def _get_sparse_index_text(
content: str,
metadata: dict | str | None,
) -> str:
"""Resolve the text used by FTS and BM25 sparse retrieval.

Args:
content: Stored chunk text returned to retrieval callers.
metadata: Chunk metadata as a mapping or serialized JSON string.

Returns:
The table index text when present, otherwise the stored content.
"""
if isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
return content
if isinstance(metadata, dict):
index_text = metadata.get("index_text")
if isinstance(index_text, str) and index_text.strip():
return index_text
return content

async def get_documents(
self,
metadata_filters: dict,
Expand Down Expand Up @@ -301,7 +326,11 @@ async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
session.add(document)
await session.flush() # Flush to get the ID
if document.id is not None:
await self._insert_fts_row(session, int(document.id), text)
await self._insert_fts_row(
session,
int(document.id),
self._get_sparse_index_text(text, metadata),
)
return document.id # type: ignore

async def insert_documents_batch(
Expand Down Expand Up @@ -339,7 +368,14 @@ async def insert_documents_batch(
session.add(document)

await session.flush() # Flush to get all IDs
await self._insert_fts_rows_batch(session, documents, texts)
await self._insert_fts_rows_batch(
session,
documents,
[
self._get_sparse_index_text(content, metadata)
for content, metadata in zip(texts, metadatas)
],
)
return [doc.id for doc in documents] # type: ignore

async def delete_document_by_doc_id(self, doc_id: str) -> None:
Expand All @@ -358,7 +394,14 @@ async def delete_document_by_doc_id(self, doc_id: str) -> None:

if document:
if document.id is not None:
await self._delete_fts_row(session, int(document.id), document.text)
await self._delete_fts_row(
session,
int(document.id),
self._get_sparse_index_text(
document.text,
document.metadata_,
),
)
await session.delete(document)

async def get_document_by_doc_id(self, doc_id: str):
Expand Down Expand Up @@ -399,12 +442,24 @@ async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None:

if document:
if document.id is not None:
await self._delete_fts_row(session, int(document.id), document.text)
sparse_index_text = self._get_sparse_index_text(
document.text,
document.metadata_,
)
await self._delete_fts_row(
session,
int(document.id),
sparse_index_text,
)
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
if document.id is not None:
await self._insert_fts_row(session, int(document.id), new_text)
await self._insert_fts_row(
session,
int(document.id),
self._get_sparse_index_text(new_text, document.metadata_),
)

async def delete_documents(self, metadata_filters: dict) -> None:
"""Delete documents by their metadata filters.
Expand Down Expand Up @@ -513,7 +568,10 @@ async def rebuild_fts_index(self) -> None:
await self._insert_fts_rows_batch(
session,
documents,
[doc.text for doc in documents],
[
self._get_sparse_index_text(doc.text, doc.metadata_)
for doc in documents
],
)
last_id = int(documents[-1].id or last_id)

Expand Down Expand Up @@ -700,7 +758,10 @@ async def _delete_fts_rows_batch(
fts_params = [
{
"rowid": int(doc.id),
"search_text": to_fts5_search_text(doc.text, self.stopwords),
"search_text": to_fts5_search_text(
self._get_sparse_index_text(doc.text, doc.metadata_),
self.stopwords,
),
}
for doc in docs_with_ids
if doc.id is not None and int(doc.id) in existing_rowids
Expand Down
20 changes: 19 additions & 1 deletion astrbot/core/db/vec_db/faiss_impl/vec_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,15 @@ async def insert_batch(
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
embedding_texts: list[str] | None = None,
) -> list[int]:
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。

Args:
progress_callback: 进度回调函数,接收参数 (current, total)
embedding_texts: 可选的向量化文本,用于将"用于语义匹配的文本"与
"用于存储/检索返回的文本(contents)"解耦。表格知识库使用索引列
文本进行向量化,但存储并返回整行文本。缺省时回退为 contents。

"""
metadatas = metadatas or [{} for _ in contents]
Expand All @@ -81,6 +85,20 @@ async def insert_batch(
)
return []

texts_to_embed = embedding_texts if embedding_texts is not None else contents
if len(texts_to_embed) != len(contents):
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=(
f"向量化失败:用于向量化的文本数量与文本分块数量不一致"
f"(期望 {len(contents)},实际 {len(texts_to_embed)})。"
),
details={
"expected_contents": len(contents),
"actual_embedding_texts": len(texts_to_embed),
},
)

content_count = len(contents)
if len(metadatas) != content_count:
raise KnowledgeBaseUploadError(
Expand Down Expand Up @@ -110,7 +128,7 @@ async def insert_batch(
start = time.time()
logger.debug(f"Generating embeddings for {len(contents)} contents...")
vectors = await self.embedding_provider.get_embeddings_batch(
contents,
texts_to_embed,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
Expand Down
40 changes: 40 additions & 0 deletions astrbot/core/knowledge_base/kb_db_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,46 @@ async def migrate_to_v1(self) -> None:

await session.commit()

async def migrate_to_v2(self) -> None:
    """Apply the v2 knowledge base schema migration.

    Brings databases created before the table knowledge base feature up to
    date by adding ``knowledge_bases.kb_type`` and
    ``kb_documents.table_schema``. SQLite has no
    ``ADD COLUMN IF NOT EXISTS``, so each table's current columns are read
    with ``PRAGMA table_info`` and the ``ALTER TABLE`` is issued only for a
    column that is missing.
    """
    # Each entry: (column listing query, required column, statement adding it).
    column_additions = (
        (
            "PRAGMA table_info(knowledge_bases)",
            "kb_type",
            "ALTER TABLE knowledge_bases "
            "ADD COLUMN kb_type VARCHAR(20) NOT NULL DEFAULT 'text'",
        ),
        (
            "PRAGMA table_info(kb_documents)",
            "table_schema",
            "ALTER TABLE kb_documents ADD COLUMN table_schema TEXT",
        ),
    )

    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            for pragma_sql, column_name, alter_sql in column_additions:
                listing = await session.execute(text(pragma_sql))
                # Index 1 of every PRAGMA table_info row is the column name.
                present = {info[1] for info in listing.fetchall()}
                if column_name in present:
                    continue
                await session.execute(text(alter_sql))

            await session.commit()

async def close(self) -> None:
    """Close the database connection by disposing of the engine."""
    engine = self.engine
    await engine.dispose()
Expand Down
Loading