From a1323d98d2f53cbb268eedb782a466e5798c43ce Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Wed, 24 Jun 2026 23:46:06 +0800 Subject: [PATCH 1/2] fix(kb): align CRUD API contract end-to-end Unify knowledge base create/update/list with kb_name, canonical_payload, optional list pagination with total, and matching OpenAPI plus dashboard types. --- astrbot/dashboard/api/knowledge_bases.py | 25 +- astrbot/dashboard/schemas.py | 34 ++- .../services/knowledge_base_service.py | 63 +++-- .../src/api/generated/openapi-v1/types.gen.ts | 22 +- dashboard/src/api/v1.ts | 10 +- dashboard/src/views/knowledge-base/KBList.vue | 9 +- openspec/openapi-v1.yaml | 39 ++- .../test_knowledge_base_service_contract.py | 266 ++++++++++++++++++ 8 files changed, 406 insertions(+), 62 deletions(-) create mode 100644 tests/unit/test_knowledge_base_service_contract.py diff --git a/astrbot/dashboard/api/knowledge_bases.py b/astrbot/dashboard/api/knowledge_bases.py index c6f62235dd..6d8084a19b 100644 --- a/astrbot/dashboard/api/knowledge_bases.py +++ b/astrbot/dashboard/api/knowledge_bases.py @@ -53,14 +53,6 @@ def _to_int(value: Any, default: int) -> int: return default -def _model_dict(payload) -> dict[str, Any]: - if payload is None: - return {} - if hasattr(payload, "model_dump"): - return payload.model_dump(exclude_none=True) - return payload if isinstance(payload, dict) else {} - - async def _run(operation, *, prefix: str): try: result = await run_maybe_async(operation) @@ -94,7 +86,11 @@ async def list_knowledge_bases( return await _run( lambda: service.list_kbs( page=_to_int(request.query_params.get("page"), 1), - page_size=_to_int(request.query_params.get("page_size"), 20), + page_size=( + _to_int(request.query_params.get("page_size"), 20) + if "page" in request.query_params or "page_size" in request.query_params + else None + ), ), prefix="获取知识库列表失败", ) @@ -107,7 +103,7 @@ async def create_knowledge_base( service: KnowledgeBaseService = Depends(get_service), ): return await _run( - lambda: service.create_kb(_model_dict(payload)), + lambda: service.create_kb(payload.canonical_payload()), prefix="创建知识库失败", ) @@ -140,9 +136,8 @@ async def update_knowledge_base( _auth: AuthContext = Depends(require_kb_scope), service: KnowledgeBaseService = Depends(get_service), ): - body = _model_dict(payload) return await _run( - lambda: service.update_kb({"kb_id": kb_id, **body}), + lambda: service.update_kb({**payload.canonical_payload(), "kb_id": kb_id}), prefix="更新知识库失败", ) @@ -322,7 +317,11 @@ async def dashboard_list_kbs( return await _run( lambda: service.list_kbs( page=_to_int(request.query_params.get("page"), 1), - page_size=_to_int(request.query_params.get("page_size"), 20), + page_size=( + _to_int(request.query_params.get("page_size"), 20) + if "page" in request.query_params or "page_size" in request.query_params + else None + ), ), prefix="获取知识库列表失败", ) diff --git a/astrbot/dashboard/schemas.py b/astrbot/dashboard/schemas.py index 773ed6cd87..f3d5f20b88 100644 --- a/astrbot/dashboard/schemas.py +++ b/astrbot/dashboard/schemas.py @@ -205,14 +205,42 @@ class ImMessageRequest(OpenModel): class KnowledgeBaseRequest(OpenModel): - kb_id: str | None = None - name: str | None = None + kb_name: str | None = None description: str | None = None + emoji: str | None = None embedding_provider_id: str | None = None rerank_provider_id: str | None = None chunk_size: int | None = None chunk_overlap: int | None = None - + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + + def canonical_payload(self) -> dict[str, Any]: + """Return the service-facing knowledge base payload. + + Returns: + Dictionary accepted by KnowledgeBaseService. + """ + data = self.model_dump( + exclude_unset=True, + include={ + "kb_name", + "description", + "emoji", + "embedding_provider_id", + "rerank_provider_id", + "chunk_size", + "chunk_overlap", + "top_k_dense", + "top_k_sparse", + "top_m_final", + }, + ) + legacy_name = getattr(self, "name", None) + if data.get("kb_name") is None and legacy_name is not None: + data["kb_name"] = legacy_name + return data class KnowledgeBaseImportRequest(OpenModel): documents: list[dict[str, Any]] | None = None diff --git a/astrbot/dashboard/services/knowledge_base_service.py b/astrbot/dashboard/services/knowledge_base_service.py index c7f9546418..5558391913 100644 --- a/astrbot/dashboard/services/knowledge_base_service.py +++ b/astrbot/dashboard/services/knowledge_base_service.py @@ -29,6 +29,22 @@ def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None: def _payload(data: object) -> dict[str, Any]: return data if isinstance(data, dict) else {} + @staticmethod + def _canonical_kb_payload(data: object) -> dict[str, Any]: + """Normalize knowledge base create/update payloads. + + Args: + data: Request payload from v1 or legacy Dashboard routes. + + Returns: + Payload using the service's canonical field names. + """ + payload = KnowledgeBaseService._payload(data).copy() + if payload.get("kb_name") is None and payload.get("name") is not None: + payload["kb_name"] = payload["name"] + payload.pop("name", None) + return payload + def get_kb_manager(self): return self.core_lifecycle.kb_manager @@ -263,19 +279,30 @@ async def background_import_task( logger.error(traceback.format_exc()) self.set_task_result(task_id, "failed", error=str(exc)) - async def list_kbs(self, *, page: int, page_size: int) -> dict[str, Any]: + async def list_kbs(self, *, page: int, page_size: int | None) -> dict[str, Any]: kb_manager = self.get_kb_manager() kbs = await kb_manager.list_kbs() kb_list = [] - for kb in kbs: + selected_kbs = kbs + if page_size is not None: + start = max(page - 1, 0) * page_size + end = start + page_size + selected_kbs = kbs[start:end] + + for kb in selected_kbs: kb_dict = kb.model_dump() kb_helper = await kb_manager.get_kb(kb.kb_id) if kb_helper and kb_helper.init_error: kb_dict["init_error"] = kb_helper.init_error kb_list.append(kb_dict) - return {"items": kb_list, "page": page, "page_size": page_size} + return { + "items": kb_list, + "page": page, + "page_size": page_size if page_size is not None else len(kbs), + "total": len(kbs), + } async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, Any]: return await self.list_kbs( @@ -285,7 +312,7 @@ async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, A async def create_kb(self, data: object) -> tuple[dict[str, Any], str]: kb_manager = self.get_kb_manager() - payload = self._payload(data) + payload = self._canonical_kb_payload(data) kb_name = payload.get("kb_name") if not kb_name: raise KnowledgeBaseServiceError("知识库名称不能为空") @@ -355,7 +382,7 @@ async def get_kb_from_dashboard_query(self, kb_id: str | None) -> dict[str, Any] return await self.get_kb(kb_id) async def update_kb(self, data: object) -> tuple[dict[str, Any], str]: - payload = self._payload(data) + payload = self._canonical_kb_payload(data) kb_id = payload.get("kb_id") if not kb_id: raise KnowledgeBaseServiceError("缺少参数 kb_id") @@ -372,28 +399,20 @@ async def update_kb(self, data: object) -> tuple[dict[str, Any], str]: "top_k_sparse", "top_m_final", ] - if all(payload.get(key) is None for key in update_keys): + provided_updates = {key: payload[key] for key in update_keys if key in payload} + if not provided_updates: raise KnowledgeBaseServiceError("至少需要提供一个更新字段") current_kb = await self.get_kb_manager().get_kb(kb_id) - kb_name = payload.get("kb_name") - if kb_name is None: - if not current_kb: - raise KnowledgeBaseServiceError("知识库不存在") - kb_name = current_kb.kb.kb_name + if not current_kb: + raise KnowledgeBaseServiceError("知识库不存在") + current = current_kb.kb + update_data = {key: getattr(current, key) for key in update_keys} + update_data.update(provided_updates) kb_helper = await self.get_kb_manager().update_kb( kb_id=kb_id, - kb_name=kb_name, - description=payload.get("description"), - emoji=payload.get("emoji"), - embedding_provider_id=payload.get("embedding_provider_id"), - rerank_provider_id=payload.get("rerank_provider_id"), - chunk_size=payload.get("chunk_size"), - chunk_overlap=payload.get("chunk_overlap"), - top_k_dense=payload.get("top_k_dense"), - top_k_sparse=payload.get("top_k_sparse"), - top_m_final=payload.get("top_m_final"), + **update_data, ) if not kb_helper: raise KnowledgeBaseServiceError("知识库不存在") @@ -738,11 +757,11 @@ async def retrieve(self, data: object) -> dict[str, Any]: if not query: raise KnowledgeBaseServiceError("缺少参数 query") + kb_manager = self.get_kb_manager() if not kb_names or not isinstance(kb_names, list): raise KnowledgeBaseServiceError("缺少参数 kb_names 或格式错误") top_k = payload.get("top_k", 5) - kb_manager = self.get_kb_manager() results = await kb_manager.retrieve( query=query, kb_names=kb_names, diff --git a/dashboard/src/api/generated/openapi-v1/types.gen.ts b/dashboard/src/api/generated/openapi-v1/types.gen.ts index b82f7f58d1..ad82c84ad4 100644 --- a/dashboard/src/api/generated/openapi-v1/types.gen.ts +++ b/dashboard/src/api/generated/openapi-v1/types.gen.ts @@ -255,13 +255,22 @@ export type JsonSchema = { [key: string]: unknown; }; +export type KnowledgeBaseCreateRequest = KnowledgeBaseRequest & { + kb_name: string; + embedding_provider_id: string; +}; + export type KnowledgeBaseRequest = { - name: string; + kb_name?: string; description?: string; - embedding_provider_id?: string; - rerank_provider_id?: string; - chunking?: DynamicConfig; - metadata?: DynamicConfig; + emoji?: string; + embedding_provider_id?: (string) | null; + rerank_provider_id?: (string) | null; + chunk_size?: number; + chunk_overlap?: number; + top_k_dense?: number; + top_k_sparse?: number; + top_m_final?: number; }; export type KnowledgeDocumentImportRequest = { @@ -271,7 +280,6 @@ export type KnowledgeDocumentImportRequest = { export type KnowledgeDocumentUploadRequest = { file: (Blob | File); - parser?: string; }; export type KnowledgeDocumentUrlImportRequest = { @@ -2569,7 +2577,7 @@ export type ListKnowledgeBasesResponse = (SuccessEnvelope); export type ListKnowledgeBasesError = unknown; export type CreateKnowledgeBaseData = { - body: KnowledgeBaseRequest; + body: KnowledgeBaseCreateRequest; }; export type CreateKnowledgeBaseResponse = (SuccessEnvelope); diff --git a/dashboard/src/api/v1.ts b/dashboard/src/api/v1.ts index 7641c6a1fe..e953d3e0d6 100644 --- a/dashboard/src/api/v1.ts +++ b/dashboard/src/api/v1.ts @@ -32,6 +32,8 @@ import { type DynamicConfig, type EnabledPatch, type GhproxyTestRequest, + type KnowledgeBaseCreateRequest, + type KnowledgeBaseRequest, type LoginRequest, type ListConversationsData, type McpServerConfig, @@ -1352,16 +1354,16 @@ export const knowledgeApi = { openApiV1.getKnowledgeBase({ path: { kb_id: kbId } }), ); }, - create(config: OpenConfig) { + create(config: KnowledgeBaseCreateRequest) { return typed( - openApiV1.createKnowledgeBase({ body: config as any }), + openApiV1.createKnowledgeBase({ body: config }), ); }, - update(kbId: string, config: OpenConfig) { + update(kbId: string, config: KnowledgeBaseRequest) { return typed( openApiV1.updateKnowledgeBase({ path: { kb_id: kbId }, - body: config as any, + body: config, }), ); }, diff --git a/dashboard/src/views/knowledge-base/KBList.vue b/dashboard/src/views/knowledge-base/KBList.vue index d25b8e458a..71055bb376 100644 --- a/dashboard/src/views/knowledge-base/KBList.vue +++ b/dashboard/src/views/knowledge-base/KBList.vue @@ -152,7 +152,9 @@ + :label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null" + :rules="[v => editingKB !== null || !!v || t('create.embeddingModelRequired')]" required + hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>