Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions astrbot/dashboard/api/knowledge_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@ def _to_int(value: Any, default: int) -> int:
return default


def _model_dict(payload) -> dict[str, Any]:
if payload is None:
return {}
if hasattr(payload, "model_dump"):
return payload.model_dump(exclude_none=True)
return payload if isinstance(payload, dict) else {}


async def _run(operation, *, prefix: str):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The helper function _model_dict was removed from this file. However, it is still referenced in multiple other endpoints within the same file:

  • import_knowledge_base_documents (line 210)
  • import_knowledge_base_document_url (line 224)
  • retrieve_knowledge_base (line 304)

Removing _model_dict will cause a NameError when any of these endpoints are called. Please restore _model_dict or refactor those endpoints to use payload.model_dump(exclude_none=True) directly.

def _model_dict(payload) -> dict[str, Any]:
    if payload is None:
        return {}
    if hasattr(payload, "model_dump"):
        return payload.model_dump(exclude_none=True)
    return payload if isinstance(payload, dict) else {}


async def _run(operation, *, prefix: str):

try:
result = await run_maybe_async(operation)
Expand Down Expand Up @@ -94,7 +86,11 @@ async def list_knowledge_bases(
return await _run(
lambda: service.list_kbs(
page=_to_int(request.query_params.get("page"), 1),
page_size=_to_int(request.query_params.get("page_size"), 20),
page_size=(
_to_int(request.query_params.get("page_size"), 20)
if "page" in request.query_params or "page_size" in request.query_params
else None
),
),
prefix="获取知识库列表失败",
)
Expand All @@ -107,7 +103,7 @@ async def create_knowledge_base(
service: KnowledgeBaseService = Depends(get_service),
):
return await _run(
lambda: service.create_kb(_model_dict(payload)),
lambda: service.create_kb(payload.canonical_payload()),
prefix="创建知识库失败",
)

Expand Down Expand Up @@ -140,9 +136,8 @@ async def update_knowledge_base(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
return await _run(
lambda: service.update_kb({"kb_id": kb_id, **body}),
lambda: service.update_kb({**payload.canonical_payload(), "kb_id": kb_id}),
prefix="更新知识库失败",
)

Expand Down Expand Up @@ -212,7 +207,7 @@ async def import_knowledge_base_documents(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
body = payload.model_dump(exclude_none=True)
return await _run(
lambda: service.import_documents({"kb_id": kb_id, **body}),
prefix="导入文档失败",
Expand All @@ -226,7 +221,7 @@ async def import_knowledge_base_document_url(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
body = payload.model_dump(exclude_none=True)
return await _run(
lambda: service.upload_document_from_url({"kb_id": kb_id, **body}),
prefix="从URL上传文档失败",
Expand Down Expand Up @@ -306,7 +301,7 @@ async def retrieve_knowledge_base(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
body = payload.model_dump(exclude_none=True)
return await _run(
lambda: service.retrieve({"kb_id": kb_id, **body}),
prefix="检索失败",
Expand All @@ -322,7 +317,11 @@ async def dashboard_list_kbs(
return await _run(
lambda: service.list_kbs(
page=_to_int(request.query_params.get("page"), 1),
page_size=_to_int(request.query_params.get("page_size"), 20),
page_size=(
_to_int(request.query_params.get("page_size"), 20)
if "page" in request.query_params or "page_size" in request.query_params
else None
),
),
prefix="获取知识库列表失败",
)
Expand Down
37 changes: 35 additions & 2 deletions astrbot/dashboard/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,46 @@ class ImMessageRequest(OpenModel):


class KnowledgeBaseRequest(OpenModel):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
kb_id: str | None = None
name: str | None = None
kb_name: str | None = Field(None, alias="name")
description: str | None = None
emoji: str | None = None
embedding_provider_id: str | None = None
rerank_provider_id: str | None = None
chunk_size: int | None = None
chunk_overlap: int | None = None
top_k_dense: int | None = None
top_k_sparse: int | None = None
top_m_final: int | None = None

model_config = ConfigDict(populate_by_name=True, extra="allow")

def canonical_payload(self) -> dict[str, Any]:
"""Return the service-facing knowledge base payload.

Returns:
Dictionary accepted by KnowledgeBaseService.
"""
return self.model_dump(
exclude_unset=True,
include={
"kb_name",
"description",
"emoji",
"embedding_provider_id",
"rerank_provider_id",
"chunk_size",
"chunk_overlap",
"top_k_dense",
"top_k_sparse",
"top_m_final",
},
by_alias=False,
)


class KnowledgeBaseCreateRequest(KnowledgeBaseRequest):
kb_name: str = Field(..., alias="name")
embedding_provider_id: str


class KnowledgeBaseImportRequest(OpenModel):
Comment on lines 249 to 250

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The PR description mentions syncing KnowledgeBaseCreateRequest Pydantic models, and the OpenAPI spec/frontend types expect it. However, KnowledgeBaseCreateRequest is not defined in astrbot/dashboard/schemas.py.

Please define KnowledgeBaseCreateRequest inheriting from KnowledgeBaseRequest with kb_name and embedding_provider_id as required fields.

Suggested change
class KnowledgeBaseImportRequest(OpenModel):
class KnowledgeBaseCreateRequest(KnowledgeBaseRequest):
kb_name: str
embedding_provider_id: str
class KnowledgeBaseImportRequest(OpenModel):

Expand Down
61 changes: 39 additions & 22 deletions astrbot/dashboard/services/knowledge_base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.dashboard.schemas import KnowledgeBaseRequest
from astrbot.dashboard.utils import generate_tsne_visualization


Expand All @@ -29,6 +30,19 @@ def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
def _payload(data: object) -> dict[str, Any]:
return data if isinstance(data, dict) else {}

@staticmethod
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
def _canonical_kb_payload(data: object) -> dict[str, Any]:
"""Normalize knowledge base create/update payloads.

Uses KnowledgeBaseRequest to handle the legacy ``name`` →
``kb_name`` migration while preserving operational fields
like ``kb_id``.
"""
raw = KnowledgeBaseService._payload(data)
canonical = KnowledgeBaseRequest(**raw).canonical_payload()
raw.update(canonical)
return raw

def get_kb_manager(self):
return self.core_lifecycle.kb_manager

Expand Down Expand Up @@ -263,19 +277,30 @@ async def background_import_task(
logger.error(traceback.format_exc())
self.set_task_result(task_id, "failed", error=str(exc))

async def list_kbs(self, *, page: int, page_size: int) -> dict[str, Any]:
async def list_kbs(self, *, page: int, page_size: int | None) -> dict[str, Any]:
kb_manager = self.get_kb_manager()
kbs = await kb_manager.list_kbs()

kb_list = []
for kb in kbs:
selected_kbs = kbs
if page_size is not None:
start = max(page - 1, 0) * page_size
end = start + page_size
selected_kbs = kbs[start:end]

for kb in selected_kbs:
kb_dict = kb.model_dump()
kb_helper = await kb_manager.get_kb(kb.kb_id)
if kb_helper and kb_helper.init_error:
kb_dict["init_error"] = kb_helper.init_error
kb_list.append(kb_dict)

return {"items": kb_list, "page": page, "page_size": page_size}
return {
"items": kb_list,
"page": page,
"page_size": page_size if page_size is not None else len(kbs),
"total": len(kbs),
}

async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, Any]:
return await self.list_kbs(
Expand All @@ -285,7 +310,7 @@ async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, A

async def create_kb(self, data: object) -> tuple[dict[str, Any], str]:
kb_manager = self.get_kb_manager()
payload = self._payload(data)
payload = self._canonical_kb_payload(data)
kb_name = payload.get("kb_name")
if not kb_name:
raise KnowledgeBaseServiceError("知识库名称不能为空")
Expand Down Expand Up @@ -355,7 +380,7 @@ async def get_kb_from_dashboard_query(self, kb_id: str | None) -> dict[str, Any]
return await self.get_kb(kb_id)

async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
payload = self._payload(data)
payload = self._canonical_kb_payload(data)
kb_id = payload.get("kb_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
Expand All @@ -372,28 +397,20 @@ async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
"top_k_sparse",
"top_m_final",
]
if all(payload.get(key) is None for key in update_keys):
provided_updates = {key: payload[key] for key in update_keys if key in payload}
if not provided_updates:
raise KnowledgeBaseServiceError("至少需要提供一个更新字段")

current_kb = await self.get_kb_manager().get_kb(kb_id)
kb_name = payload.get("kb_name")
if kb_name is None:
if not current_kb:
raise KnowledgeBaseServiceError("知识库不存在")
kb_name = current_kb.kb.kb_name
if not current_kb:
raise KnowledgeBaseServiceError("知识库不存在")
current = current_kb.kb
update_data = {key: getattr(current, key, None) for key in update_keys}
update_data.update(provided_updates)

kb_helper = await self.get_kb_manager().update_kb(
kb_id=kb_id,
kb_name=kb_name,
description=payload.get("description"),
emoji=payload.get("emoji"),
embedding_provider_id=payload.get("embedding_provider_id"),
rerank_provider_id=payload.get("rerank_provider_id"),
chunk_size=payload.get("chunk_size"),
chunk_overlap=payload.get("chunk_overlap"),
top_k_dense=payload.get("top_k_dense"),
top_k_sparse=payload.get("top_k_sparse"),
top_m_final=payload.get("top_m_final"),
**update_data,
)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
Expand Down Expand Up @@ -738,11 +755,11 @@ async def retrieve(self, data: object) -> dict[str, Any]:

if not query:
raise KnowledgeBaseServiceError("缺少参数 query")
kb_manager = self.get_kb_manager()
if not kb_names or not isinstance(kb_names, list):
raise KnowledgeBaseServiceError("缺少参数 kb_names 或格式错误")

top_k = payload.get("top_k", 5)
kb_manager = self.get_kb_manager()
results = await kb_manager.retrieve(
query=query,
kb_names=kb_names,
Expand Down
22 changes: 15 additions & 7 deletions dashboard/src/api/generated/openapi-v1/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,22 @@ export type JsonSchema = {
[key: string]: unknown;
};

export type KnowledgeBaseCreateRequest = KnowledgeBaseRequest & {
kb_name: string;
embedding_provider_id: string;
};

export type KnowledgeBaseRequest = {
name: string;
kb_name?: string;
description?: string;
embedding_provider_id?: string;
rerank_provider_id?: string;
chunking?: DynamicConfig;
metadata?: DynamicConfig;
emoji?: string;
embedding_provider_id?: (string) | null;
rerank_provider_id?: (string) | null;
chunk_size?: number;
chunk_overlap?: number;
top_k_dense?: number;
top_k_sparse?: number;
top_m_final?: number;
};

export type KnowledgeDocumentImportRequest = {
Expand All @@ -271,7 +280,6 @@ export type KnowledgeDocumentImportRequest = {

export type KnowledgeDocumentUploadRequest = {
file: (Blob | File);
parser?: string;
};

export type KnowledgeDocumentUrlImportRequest = {
Expand Down Expand Up @@ -2569,7 +2577,7 @@ export type ListKnowledgeBasesResponse = (SuccessEnvelope);
export type ListKnowledgeBasesError = unknown;

export type CreateKnowledgeBaseData = {
body: KnowledgeBaseRequest;
body: KnowledgeBaseCreateRequest;
};

export type CreateKnowledgeBaseResponse = (SuccessEnvelope);
Expand Down
10 changes: 6 additions & 4 deletions dashboard/src/api/v1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import {
type DynamicConfig,
type EnabledPatch,
type GhproxyTestRequest,
type KnowledgeBaseCreateRequest,
type KnowledgeBaseRequest,
type LoginRequest,
type ListConversationsData,
type McpServerConfig,
Expand Down Expand Up @@ -1352,16 +1354,16 @@ export const knowledgeApi = {
openApiV1.getKnowledgeBase({ path: { kb_id: kbId } }),
);
},
create(config: OpenConfig) {
create(config: KnowledgeBaseCreateRequest) {
return typed<OpenConfig>(
openApiV1.createKnowledgeBase({ body: config as any }),
openApiV1.createKnowledgeBase({ body: config }),
);
},
update(kbId: string, config: OpenConfig) {
update(kbId: string, config: KnowledgeBaseRequest) {
return typed<OpenConfig>(
openApiV1.updateKnowledgeBase({
path: { kb_id: kbId },
body: config as any,
body: config,
}),
);
},
Expand Down
9 changes: 7 additions & 2 deletions dashboard/src/views/knowledge-base/KBList.vue
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@

<v-select v-model="formData.embedding_provider_id" :items="embeddingProviders"
:item-title="item => item.embedding_model || item.id" :item-value="'id'"
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null" hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null"
:rules="[v => editingKB !== null || !!v || t('create.embeddingModelRequired')]" required
hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>
<template #item="{ props, item }">
<v-list-item v-bind="props">
<template #subtitle>
Expand Down Expand Up @@ -441,7 +443,10 @@ const submitForm = async () => {
if (editingKB.value) {
response = await knowledgeApi.update(editingKB.value.kb_id, payload)
} else {
response = await knowledgeApi.create(payload)
response = await knowledgeApi.create({
...payload,
embedding_provider_id: formData.value.embedding_provider_id!
})
}

if (response.data.status === 'ok') {
Expand Down
Loading
Loading