diff --git a/astrbot/dashboard/api/knowledge_bases.py b/astrbot/dashboard/api/knowledge_bases.py index 595f6ff911..ed0498098c 100644 --- a/astrbot/dashboard/api/knowledge_bases.py +++ b/astrbot/dashboard/api/knowledge_bases.py @@ -9,6 +9,7 @@ from astrbot.dashboard.async_utils import run_maybe_async from astrbot.dashboard.responses import error, ok from astrbot.dashboard.schemas import ( + KnowledgeBaseCreateRequest, KnowledgeBaseImportRequest, KnowledgeBaseRequest, KnowledgeBaseRetrieveRequest, @@ -53,14 +54,6 @@ def _to_int(value: Any, default: int) -> int: return default -def _model_dict(payload) -> dict[str, Any]: - if payload is None: - return {} - if hasattr(payload, "model_dump"): - return payload.model_dump(exclude_none=True) - return payload if isinstance(payload, dict) else {} - - async def _run(operation, *, prefix: str): try: result = await run_maybe_async(operation) @@ -102,12 +95,12 @@ async def list_knowledge_bases( @router.post("/knowledge-bases") async def create_knowledge_base( - payload: KnowledgeBaseRequest, + payload: KnowledgeBaseCreateRequest, _auth: AuthContext = Depends(require_kb_scope), service: KnowledgeBaseService = Depends(get_service), ): return await _run( - lambda: service.create_kb(_model_dict(payload)), + lambda: service.create_kb(payload.canonical_payload()), prefix="创建知识库失败", ) @@ -140,9 +133,8 @@ async def update_knowledge_base( _auth: AuthContext = Depends(require_kb_scope), service: KnowledgeBaseService = Depends(get_service), ): - body = _model_dict(payload) return await _run( - lambda: service.update_kb({"kb_id": kb_id, **body}), + lambda: service.update_kb({**payload.canonical_payload(), "kb_id": kb_id}), prefix="更新知识库失败", ) @@ -213,7 +205,7 @@ async def import_knowledge_base_documents( _auth: AuthContext = Depends(require_kb_scope), service: KnowledgeBaseService = Depends(get_service), ): - body = _model_dict(payload) + body = payload.model_dump(exclude_none=True) return await _run( lambda: service.import_documents({"kb_id": kb_id, **body}), prefix="导入文档失败", @@ -227,7 +219,7 @@ async def import_knowledge_base_document_url( _auth: AuthContext = Depends(require_kb_scope), service: KnowledgeBaseService = Depends(get_service), ): - body = _model_dict(payload) + body = payload.model_dump(exclude_none=True) return await _run( lambda: service.upload_document_from_url({"kb_id": kb_id, **body}), prefix="从URL上传文档失败", @@ -307,7 +299,7 @@ async def retrieve_knowledge_base( _auth: AuthContext = Depends(require_kb_scope), service: KnowledgeBaseService = Depends(get_service), ): - body = _model_dict(payload) + body = payload.model_dump(exclude_none=True) return await _run( lambda: service.retrieve({"kb_id": kb_id, **body}), prefix="检索失败", diff --git a/astrbot/dashboard/schemas.py b/astrbot/dashboard/schemas.py index f37449c532..c7b5d00ba8 100644 --- a/astrbot/dashboard/schemas.py +++ b/astrbot/dashboard/schemas.py @@ -205,13 +205,46 @@ class ImMessageRequest(OpenModel): class KnowledgeBaseRequest(OpenModel): - kb_id: str | None = None - name: str | None = None + kb_name: str | None = Field(None, alias="name") description: str | None = None + emoji: str | None = None embedding_provider_id: str | None = None rerank_provider_id: str | None = None chunk_size: int | None = None chunk_overlap: int | None = None + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + + model_config = ConfigDict(populate_by_name=True, extra="allow") + + def canonical_payload(self) -> dict[str, Any]: + """Return the service-facing knowledge base payload. + + Returns: + Dictionary accepted by KnowledgeBaseService. + """ + return self.model_dump( + exclude_unset=True, + include={ + "kb_name", + "description", + "emoji", + "embedding_provider_id", + "rerank_provider_id", + "chunk_size", + "chunk_overlap", + "top_k_dense", + "top_k_sparse", + "top_m_final", + }, + by_alias=False, + ) + + +class KnowledgeBaseCreateRequest(KnowledgeBaseRequest): + kb_name: str = Field(..., alias="name") + embedding_provider_id: str class KnowledgeBaseImportRequest(OpenModel): diff --git a/astrbot/dashboard/services/knowledge_base_service.py b/astrbot/dashboard/services/knowledge_base_service.py index ec162aa299..76ad91359d 100644 --- a/astrbot/dashboard/services/knowledge_base_service.py +++ b/astrbot/dashboard/services/knowledge_base_service.py @@ -12,6 +12,7 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.dashboard.schemas import KnowledgeBaseRequest from astrbot.dashboard.utils import generate_tsne_visualization @@ -29,6 +30,19 @@ def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None: def _payload(data: object) -> dict[str, Any]: return data if isinstance(data, dict) else {} + @staticmethod + def _canonical_kb_payload(data: object) -> dict[str, Any]: + """Normalize knowledge base create/update payloads. + + Uses KnowledgeBaseRequest to handle the legacy ``name`` → + ``kb_name`` migration while preserving operational fields + like ``kb_id``. + """ + raw = KnowledgeBaseService._payload(data) + canonical = KnowledgeBaseRequest(**raw).canonical_payload() + raw.update(canonical) + return raw + def get_kb_manager(self): return self.core_lifecycle.kb_manager @@ -293,7 +307,7 @@ async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, A async def create_kb(self, data: object) -> tuple[dict[str, Any], str]: kb_manager = self.get_kb_manager() - payload = self._payload(data) + payload = self._canonical_kb_payload(data) kb_name = payload.get("kb_name") if not kb_name: raise KnowledgeBaseServiceError("知识库名称不能为空") @@ -363,7 +377,7 @@ async def get_kb_from_dashboard_query(self, kb_id: str | None) -> dict[str, Any] return await self.get_kb(kb_id) async def update_kb(self, data: object) -> tuple[dict[str, Any], str]: - payload = self._payload(data) + payload = self._canonical_kb_payload(data) kb_id = payload.get("kb_id") if not kb_id: raise KnowledgeBaseServiceError("缺少参数 kb_id") @@ -380,28 +394,20 @@ async def update_kb(self, data: object) -> tuple[dict[str, Any], str]: "top_k_sparse", "top_m_final", ] - if all(payload.get(key) is None for key in update_keys): + provided_updates = {key: payload[key] for key in update_keys if key in payload} + if not provided_updates: raise KnowledgeBaseServiceError("至少需要提供一个更新字段") current_kb = await self.get_kb_manager().get_kb(kb_id) - kb_name = payload.get("kb_name") - if kb_name is None: - if not current_kb: - raise KnowledgeBaseServiceError("知识库不存在") - kb_name = current_kb.kb.kb_name + if not current_kb: + raise KnowledgeBaseServiceError("知识库不存在") + current = current_kb.kb + update_data = {key: getattr(current, key, None) for key in update_keys} + update_data.update(provided_updates) kb_helper = await self.get_kb_manager().update_kb( kb_id=kb_id, - kb_name=kb_name, - description=payload.get("description"), - emoji=payload.get("emoji"), - embedding_provider_id=payload.get("embedding_provider_id"), - rerank_provider_id=payload.get("rerank_provider_id"), - chunk_size=payload.get("chunk_size"), - chunk_overlap=payload.get("chunk_overlap"), - top_k_dense=payload.get("top_k_dense"), - top_k_sparse=payload.get("top_k_sparse"), - top_m_final=payload.get("top_m_final"), + **update_data, ) if not kb_helper: raise KnowledgeBaseServiceError("知识库不存在") @@ -762,11 +768,11 @@ async def retrieve(self, data: object) -> dict[str, Any]: if not query: raise KnowledgeBaseServiceError("缺少参数 query") + kb_manager = self.get_kb_manager() if not kb_names or not isinstance(kb_names, list): raise KnowledgeBaseServiceError("缺少参数 kb_names 或格式错误") top_k = payload.get("top_k", 5) - kb_manager = self.get_kb_manager() results = await kb_manager.retrieve( query=query, kb_names=kb_names, diff --git a/dashboard/src/api/generated/openapi-v1/types.gen.ts b/dashboard/src/api/generated/openapi-v1/types.gen.ts index 49e94b2bb1..ffe7c5e394 100644 --- a/dashboard/src/api/generated/openapi-v1/types.gen.ts +++ b/dashboard/src/api/generated/openapi-v1/types.gen.ts @@ -255,13 +255,22 @@ export type JsonSchema = { [key: string]: unknown; }; +export type KnowledgeBaseCreateRequest = KnowledgeBaseRequest & { + kb_name: string; + embedding_provider_id: string; +}; + export type KnowledgeBaseRequest = { - name: string; + kb_name?: string; description?: string; - embedding_provider_id?: string; - rerank_provider_id?: string; - chunking?: DynamicConfig; - metadata?: DynamicConfig; + emoji?: string; + embedding_provider_id?: (string) | null; + rerank_provider_id?: (string) | null; + chunk_size?: number; + chunk_overlap?: number; + top_k_dense?: number; + top_k_sparse?: number; + top_m_final?: number; }; export type KnowledgeDocumentImportRequest = { @@ -271,7 +280,6 @@ export type KnowledgeDocumentImportRequest = { export type KnowledgeDocumentUploadRequest = { file: (Blob | File); - parser?: string; }; export type KnowledgeDocumentUrlImportRequest = { @@ -2606,7 +2614,7 @@ export type ListKnowledgeBasesResponse = (SuccessEnvelope); export type ListKnowledgeBasesError = unknown; export type CreateKnowledgeBaseData = { - body: KnowledgeBaseRequest; + body: KnowledgeBaseCreateRequest; }; export type CreateKnowledgeBaseResponse = (SuccessEnvelope); diff --git a/dashboard/src/api/v1.ts b/dashboard/src/api/v1.ts index 29c365caa5..d10f532c22 100644 --- a/dashboard/src/api/v1.ts +++ b/dashboard/src/api/v1.ts @@ -32,6 +32,8 @@ import { type DynamicConfig, type EnabledPatch, type GhproxyTestRequest, + type KnowledgeBaseCreateRequest, + type KnowledgeBaseRequest, type LoginRequest, type ListConversationsData, type McpServerConfig, @@ -1366,16 +1368,16 @@ export const knowledgeApi = { openApiV1.getKnowledgeBase({ path: { kb_id: kbId } }), ); }, - create(config: OpenConfig) { + create(config: KnowledgeBaseCreateRequest) { return typed( - openApiV1.createKnowledgeBase({ body: config as any }), + openApiV1.createKnowledgeBase({ body: config }), ); }, - update(kbId: string, config: OpenConfig) { + update(kbId: string, config: KnowledgeBaseRequest) { return typed( openApiV1.updateKnowledgeBase({ path: { kb_id: kbId }, - body: config as any, + body: config, }), ); }, diff --git a/dashboard/src/views/knowledge-base/KBList.vue b/dashboard/src/views/knowledge-base/KBList.vue index 4ca8b593d9..d9fd6772d9 100644 --- a/dashboard/src/views/knowledge-base/KBList.vue +++ b/dashboard/src/views/knowledge-base/KBList.vue @@ -161,7 +161,9 @@ + :label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null" + :rules="[v => editingKB !== null || !!v || t('create.embeddingModelRequired')]" required + hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>