From ee70de669d3a3940816d8c8d033ed20aa5e35028 Mon Sep 17 00:00:00 2001 From: dengjingren Date: Mon, 22 Jun 2026 22:32:44 +0800 Subject: [PATCH 1/4] Add Text2SQL natural language query support - Add text2sql engine: guardrails, schema metadata, readonly DB access - Add dataset/table/example CRUD, schema, service layer - Wire text2sql capability and v1 router endpoint - Add web search (Exa/Tavily) and Text2SQL settings in plugin.toml - Add guardrails tests --- .env.example | 15 + .gitignore | 2 + api/router.py | 2 + api/v1/text2sql.py | 353 +++++++++++++++++++++++ capabilities/text2sql.py | 106 +++++++ chat/pipeline.py | 2 + crud/crud_text2sql_dataset.py | 109 +++++++ crud/crud_text2sql_example.py | 102 +++++++ crud/crud_text2sql_table.py | 139 +++++++++ model/__init__.py | 4 + model/text2sql.py | 62 ++++ plugin.toml | 17 +- schema/chat.py | 1 + schema/text2sql.py | 151 ++++++++++ service/text2sql_service.py | 412 +++++++++++++++++++++++++++ sql/mysql/destroy.sql | 11 +- sql/mysql/destroy_snowflake.sql | 11 +- sql/mysql/init.sql | 10 +- sql/mysql/init_snowflake.sql | 8 + sql/postgresql/destroy.sql | 11 +- sql/postgresql/destroy_snowflake.sql | 11 +- sql/postgresql/init.sql | 10 +- sql/postgresql/init_snowflake.sql | 8 + tests/__init__.py | 1 + tests/test_guardrails.py | 110 +++++++ text2sql/__init__.py | 1 + text2sql/engine.py | 307 ++++++++++++++++++++ text2sql/exceptions.py | 24 ++ text2sql/guardrails.py | 88 ++++++ text2sql/readonly_db.py | 90 ++++++ text2sql/schema_meta.py | 120 ++++++++ 31 files changed, 2291 insertions(+), 7 deletions(-) create mode 100644 api/v1/text2sql.py create mode 100644 capabilities/text2sql.py create mode 100644 crud/crud_text2sql_dataset.py create mode 100644 crud/crud_text2sql_example.py create mode 100644 crud/crud_text2sql_table.py create mode 100644 model/text2sql.py create mode 100644 schema/text2sql.py create mode 100644 service/text2sql_service.py create mode 100644 tests/__init__.py create mode 100644 tests/test_guardrails.py create mode 100644 text2sql/__init__.py create mode 100644 text2sql/engine.py create mode 100644 text2sql/exceptions.py create mode 100644 text2sql/guardrails.py create mode 100644 text2sql/readonly_db.py create mode 100644 text2sql/schema_meta.py diff --git a/.env.example b/.env.example index e9fdbea..fee537e 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,18 @@ # [ Plugin ] ai AI_EXA_API_KEY='' AI_TAVILY_API_KEY='' + +# [ Plugin ] ai · Text2SQL +AI_TEXT2SQL_ENABLED=false +AI_TEXT2SQL_SCHEMA=fba +AI_TEXT2SQL_MAX_ROWS=200 +AI_TEXT2SQL_TIMEOUT=15 +AI_TEXT2SQL_MAX_RETRIES=2 +# 默认模型(providers 表 id + 模型 id;留空则取首个启用的 OpenAI 兼容供应商+模型) +AI_TEXT2SQL_PROVIDER_ID=0 +AI_TEXT2SQL_MODEL_ID= +# 只读数据库账号(强烈建议配置仅 SELECT 权限账号;留空则回退主库并强制护栏) +AI_TEXT2SQL_READONLY_HOST= +AI_TEXT2SQL_READONLY_PORT=0 +AI_TEXT2SQL_READONLY_USER= +AI_TEXT2SQL_READONLY_PASSWORD= diff --git a/.gitignore b/.gitignore index d2c0607..57da058 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ __pycache__/ .cursor/ .claude/ .DS_Store + +.omc/ diff --git a/api/router.py b/api/router.py index 00480f9..e89b11d 100644 --- a/api/router.py +++ b/api/router.py @@ -8,6 +8,7 @@ from backend.plugin.ai.api.v1.model import router as model_router from backend.plugin.ai.api.v1.provider import router as provider_router from backend.plugin.ai.api.v1.quick_phrase import router as quick_phrase_router +from backend.plugin.ai.api.v1.text2sql import router as text2sql_router v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH) @@ -18,3 +19,4 @@ v1.include_router(model_router, prefix='/models', tags=['AI 模型管理']) v1.include_router(provider_router, prefix='/providers', tags=['AI 供应商管理']) v1.include_router(mcp_router, prefix='/mcps', tags=['AI MCP 管理']) +v1.include_router(text2sql_router, prefix='/text2sql', tags=['AI Text2SQL']) diff --git a/api/v1/text2sql.py b/api/v1/text2sql.py new file mode 100644 index 0000000..5ca229e --- /dev/null +++ b/api/v1/text2sql.py @@ -0,0 +1,353 @@ +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Path, Query, Request + +from backend.common.pagination import DependsPagination, PageData +from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base +from backend.common.security.jwt import DependsJwtAuth +from backend.common.security.permission import RequestPermission +from backend.common.security.rbac import DependsRBAC +from backend.database.db import CurrentSession, CurrentSessionTransaction +from backend.plugin.ai.schema.text2sql import ( + CreateText2SqlDatasetParam, + CreateText2SqlExampleParam, + CreateText2SqlTableParam, + GetText2SqlDatasetDetail, + GetText2SqlExampleDetail, + GetText2SqlTableDetail, + Text2SqlDatasetEnabled, + Text2SqlQueryParam, + Text2SqlQueryResult, + Text2SqlTableSelectable, + UpdateText2SqlDatasetParam, + UpdateText2SqlExampleParam, + UpdateText2SqlTableParam, +) +from backend.plugin.ai.service.text2sql_service import text2sql_service +from backend.plugin.ai.text2sql.engine import run_query + +router = APIRouter() + + +# ---------------- 数据集 ---------------- + + +@router.get('/datasets/enabled', summary='获取启用的数据集(chat 选择器)', dependencies=[DependsJwtAuth]) +async def get_enabled_datasets(db: CurrentSession) -> ResponseSchemaModel[list[Text2SqlDatasetEnabled]]: + data = await text2sql_service.get_enabled_datasets(db=db) + return response_base.success(data=data) + + +@router.get('/datasets/all', summary='获取全部数据集', dependencies=[DependsJwtAuth]) +async def get_all_datasets(db: CurrentSession) -> ResponseSchemaModel[list[GetText2SqlDatasetDetail]]: + data = await text2sql_service.get_all_datasets(db=db) + return response_base.success(data=data) + + +@router.get('/datasets/{pk}', summary='获取数据集详情', dependencies=[DependsJwtAuth]) +async def get_dataset( + db: CurrentSession, pk: Annotated[int, Path(description='数据集 ID')] +) -> ResponseSchemaModel[GetText2SqlDatasetDetail]: + data = await text2sql_service.get_dataset(db=db, pk=pk) + return response_base.success(data=data) + + +@router.get( + '/datasets', + summary='分页获取数据集', + dependencies=[ + DependsJwtAuth, + DependsPagination, + ], +) +async def get_datasets( + db: CurrentSession, + name: Annotated[str | None, Query(description='数据集名称(模糊)')] = None, + enabled: Annotated[int | None, Query(description='是否启用(0停用 1启用)')] = None, +) -> ResponseSchemaModel[PageData[GetText2SqlDatasetDetail]]: + page_data = await text2sql_service.get_dataset_list(db=db, name=name, enabled=enabled) + return response_base.success(data=page_data) + + +@router.post( + '/datasets', + summary='新增数据集', + dependencies=[ + Depends(RequestPermission('ai:text2sql:dataset:add')), + DependsRBAC, + ], +) +async def create_dataset(db: CurrentSessionTransaction, obj: CreateText2SqlDatasetParam) -> ResponseModel: + await text2sql_service.create_dataset(db=db, obj=obj) + return response_base.success() + + +@router.put( + '/datasets/{pk}', + summary='更新数据集', + dependencies=[ + Depends(RequestPermission('ai:text2sql:dataset:edit')), + DependsRBAC, + ], +) +async def update_dataset( + db: CurrentSessionTransaction, + pk: Annotated[int, Path(description='数据集 ID')], + obj: UpdateText2SqlDatasetParam, +) -> ResponseModel: + count = await text2sql_service.update_dataset(db=db, pk=pk, obj=obj) + if count > 0: + return response_base.success() + return response_base.fail() + + +@router.delete( + '/datasets/{pk}', + summary='删除数据集', + dependencies=[ + Depends(RequestPermission('ai:text2sql:dataset:del')), + DependsRBAC, + ], +) +async def delete_dataset( + db: CurrentSessionTransaction, pk: Annotated[int, Path(description='数据集 ID')] +) -> ResponseModel: + count = await text2sql_service.delete_dataset(db=db, pk=pk) + if count > 0: + return response_base.success() + return response_base.fail() + + +# ---------------- 自然语言查询 ---------------- + + +@router.post( + '/queries', + summary='自然语言查询(Text2SQL)', + dependencies=[ + Depends(RequestPermission('ai:text2sql:query')), + DependsRBAC, + ], +) +async def text2sql_query( + request: Request, + db: CurrentSession, + obj: Text2SqlQueryParam, +) -> ResponseSchemaModel[Text2SqlQueryResult]: + tables = await text2sql_service.get_enabled(db=db, dataset_id=obj.dataset_id) + examples = await text2sql_service.get_examples_for( + db=db, + tables={table.table_name for table in tables}, + dataset_id=obj.dataset_id, + ) + data = await run_query( + db=db, + question=obj.question, + user_id=request.user.id, + selected_tables=tables, + examples=examples, + ) + return response_base.success(data=data) + + +# ---------------- 数据源管理(已选表) ---------------- + + +@router.get('/tables', summary='获取可挑选的数据库表', dependencies=[DependsJwtAuth]) +async def get_selectable_tables( + db: CurrentSession, + dataset_id: Annotated[int, Query(description='所属数据集 ID')], + table_schema: Annotated[str | None, Query(description='库名/schema,缺省取 AI_TEXT2SQL_SCHEMA')] = None, +) -> ResponseSchemaModel[list[Text2SqlTableSelectable]]: + data = await text2sql_service.list_selectable_tables(db=db, dataset_id=dataset_id, table_schema=table_schema) + return response_base.success(data=data) + + +@router.get('/tables/{table_name}/columns', summary='获取表列信息', dependencies=[DependsJwtAuth]) +async def get_table_columns( + db: CurrentSession, + table_name: Annotated[str, Path(description='表名')], + table_schema: Annotated[str | None, Query(description='库名/schema')] = None, +) -> ResponseSchemaModel[list[dict[str, Any]]]: + data = await text2sql_service.get_table_columns(db=db, table_name=table_name, table_schema=table_schema) + return response_base.success(data=data) + + +@router.get('/selected-tables/all', summary='获取全部已选表', dependencies=[DependsJwtAuth]) +async def get_all_selected_tables(db: CurrentSession) -> ResponseSchemaModel[list[GetText2SqlTableDetail]]: + data = await text2sql_service.get_all_selected(db=db) + return response_base.success(data=data) + + +@router.get('/selected-tables/{pk}', summary='获取已选表详情', dependencies=[DependsJwtAuth]) +async def get_selected_table( + db: CurrentSession, pk: Annotated[int, Path(description='已选表 ID')] +) -> ResponseSchemaModel[GetText2SqlTableDetail]: + data = await text2sql_service.get_selected(db=db, pk=pk) + return response_base.success(data=data) + + +@router.get( + '/selected-tables', + summary='分页获取已选表', + dependencies=[ + DependsJwtAuth, + DependsPagination, + ], +) +async def get_selected_tables( + db: CurrentSession, + dataset_id: Annotated[int | None, Query(description='所属数据集 ID')] = None, + schema_name: Annotated[str | None, Query(description='库名/schema')] = None, + table_name: Annotated[str | None, Query(description='表名')] = None, + enabled: Annotated[int | None, Query(description='是否启用(0停用 1启用)')] = None, +) -> ResponseSchemaModel[PageData[GetText2SqlTableDetail]]: + page_data = await text2sql_service.get_selected_list( + db=db, + dataset_id=dataset_id, + schema_name=schema_name, + table_name=table_name, + enabled=enabled, + ) + return response_base.success(data=page_data) + + +@router.post( + '/selected-tables', + summary='挑选表', + dependencies=[ + Depends(RequestPermission('ai:text2sql:table:add')), + DependsRBAC, + ], +) +async def select_table(db: CurrentSessionTransaction, obj: CreateText2SqlTableParam) -> ResponseModel: + await text2sql_service.select_table(db=db, obj=obj) + return response_base.success() + + +@router.put( + '/selected-tables/{pk}', + summary='更新已选表', + dependencies=[ + Depends(RequestPermission('ai:text2sql:table:edit')), + DependsRBAC, + ], +) +async def update_selected_table( + db: CurrentSessionTransaction, + pk: Annotated[int, Path(description='已选表 ID')], + obj: UpdateText2SqlTableParam, +) -> ResponseModel: + count = await text2sql_service.update_selected(db=db, pk=pk, obj=obj) + if count > 0: + return response_base.success() + return response_base.fail() + + +@router.delete( + '/selected-tables/{pk}', + summary='取消挑选', + dependencies=[ + Depends(RequestPermission('ai:text2sql:table:del')), + DependsRBAC, + ], +) +async def unselect_table( + db: CurrentSessionTransaction, pk: Annotated[int, Path(description='已选表 ID')] +) -> ResponseModel: + count = await text2sql_service.unselect_table(db=db, pk=pk) + if count > 0: + return response_base.success() + return response_base.fail() + + +# ---------------- Few-shot 样例 ---------------- + + +@router.get('/examples/all', summary='获取全部启用样例', dependencies=[DependsJwtAuth]) +async def get_all_examples( + db: CurrentSession, + dataset_id: Annotated[int | None, Query(description='所属数据集 ID')] = None, +) -> ResponseSchemaModel[list[GetText2SqlExampleDetail]]: + data = await text2sql_service.get_all_examples(db=db, dataset_id=dataset_id) + return response_base.success(data=data) + + +@router.get('/examples/{pk}', summary='获取样例详情', dependencies=[DependsJwtAuth]) +async def get_example( + db: CurrentSession, pk: Annotated[int, Path(description='样例 ID')] +) -> ResponseSchemaModel[GetText2SqlExampleDetail]: + data = await text2sql_service.get_example(db=db, pk=pk) + return response_base.success(data=data) + + +@router.get( + '/examples', + summary='分页获取样例', + dependencies=[ + DependsJwtAuth, + DependsPagination, + ], +) +async def get_examples( + db: CurrentSession, + dataset_id: Annotated[int | None, Query(description='所属数据集 ID')] = None, + question: Annotated[str | None, Query(description='自然语言问题(模糊)')] = None, + enabled: Annotated[int | None, Query(description='是否启用(0停用 1启用)')] = None, +) -> ResponseSchemaModel[PageData[GetText2SqlExampleDetail]]: + page_data = await text2sql_service.get_example_list( + db=db, + dataset_id=dataset_id, + question=question, + enabled=enabled, + ) + return response_base.success(data=page_data) + + +@router.post( + '/examples', + summary='新增样例', + dependencies=[ + Depends(RequestPermission('ai:text2sql:example:add')), + DependsRBAC, + ], +) +async def create_example(db: CurrentSessionTransaction, obj: CreateText2SqlExampleParam) -> ResponseModel: + await text2sql_service.create_example(db=db, obj=obj) + return response_base.success() + + +@router.put( + '/examples/{pk}', + summary='更新样例', + dependencies=[ + Depends(RequestPermission('ai:text2sql:example:edit')), + DependsRBAC, + ], +) +async def update_example( + db: CurrentSessionTransaction, + pk: Annotated[int, Path(description='样例 ID')], + obj: UpdateText2SqlExampleParam, +) -> ResponseModel: + count = await text2sql_service.update_example(db=db, pk=pk, obj=obj) + if count > 0: + return response_base.success() + return response_base.fail() + + +@router.delete( + '/examples/{pk}', + summary='删除样例', + dependencies=[ + Depends(RequestPermission('ai:text2sql:example:del')), + DependsRBAC, + ], +) +async def delete_example( + db: CurrentSessionTransaction, pk: Annotated[int, Path(description='样例 ID')] +) -> ResponseModel: + count = await text2sql_service.delete_example(db=db, pk=pk) + if count > 0: + return response_base.success() + return response_base.fail() diff --git a/capabilities/text2sql.py b/capabilities/text2sql.py new file mode 100644 index 0000000..6c36eb8 --- /dev/null +++ b/capabilities/text2sql.py @@ -0,0 +1,106 @@ +import json + +from pydantic_ai.capabilities import AbstractCapability, Toolset +from pydantic_ai.tools import RunContext +from pydantic_ai.toolsets import FunctionToolset + +from backend.core.conf import settings +from backend.database.db import async_db_session +from backend.plugin.ai.capabilities.base import function_tools_allowed +from backend.plugin.ai.dataclasses import CapabilityContext, CapabilityResult, ChatAgentDeps +from backend.plugin.ai.enums import AIChatGenerationType +from backend.plugin.ai.service.text2sql_service import text2sql_service + + +async def build_text2sql_capability(ctx: CapabilityContext) -> CapabilityResult: # noqa: RUF029 + """ + 构建 Text2SQL 能力:让聊天 Agent 按所选数据集自助取数 + + 当 text2sql_dataset_id 指定了一个数据集、生成类型为文本、且当前组合允许函数工具时, + 注入 text2sql_query 工具,聊天即可回答统计/明细类数据问题。 + 复用会话当前选择的供应商与模型,工具仅可见该数据集下启用的表与样例。 + + :param ctx: 能力构建上下文 + :return: + """ + dataset_id = ctx.forwarded_props.text2sql_dataset_id + if not dataset_id: + return CapabilityResult(capability=None) + if ctx.forwarded_props.generation_type != AIChatGenerationType.text: + return CapabilityResult(capability=None) + if not function_tools_allowed( + adapter=ctx.adapter, + supports_tools=ctx.supports_tools, + has_builtin_tools=ctx.has_builtin_tools, + ): + return CapabilityResult(capability=None) + return CapabilityResult( + capability=_build_text2sql_toolset( + provider_id=ctx.forwarded_props.provider_id, + model_id=ctx.forwarded_props.model_id, + dataset_id=dataset_id, + ), + introduces_function_tool_source=True, + ) + + +def _build_text2sql_toolset(*, provider_id: int, model_id: str, dataset_id: int) -> AbstractCapability[ChatAgentDeps]: + """ + 装配 text2sql_query 函数工具(复用 run_query,全程只读、过护栏) + + 闭包捕获会话当前选择的 provider_id / model_id 与所选数据集 dataset_id; + 取数时仅按 dataset_id 圈定可见表与召回样例。 + run_query 采用函数内懒加载,避免与 chat.session/pipeline 形成模块加载期循环导入。 + + :param provider_id: 会话当前选择的供应商 ID + :param model_id: 会话当前选择的模型 ID + :param dataset_id: 所选数据集 ID + :return: + """ + toolset = FunctionToolset[ChatAgentDeps]() + + @toolset.tool + async def text2sql_query(ctx: RunContext[ChatAgentDeps], question: str) -> str: + """ + 用自然语言查询当前所选数据集中的 FBA 业务数据(Text2SQL)。 + + 仅对该数据集中启用的表执行只读查询,全程受安全护栏保护。 + 适合回答统计、聚合、明细类数据问题,例如: + 「最近 7 天订单总金额是多少」「每个供应商有多少条记录」「按金额倒序的前 10 笔订单」。 + + :param question: 用户的自然语言数据问题 + :return: + """ + # 懒加载:engine.run_query → chat.session → chat.pipeline,避免模块加载期循环导入 + from backend.plugin.ai.text2sql.engine import run_query + + async with async_db_session() as db: + tables = await text2sql_service.get_enabled(db=db, dataset_id=dataset_id) + if not tables: + return '当前数据集尚未挑选任何数据表,请先在「数据集」管理中挑选可查询的表。' + examples = await text2sql_service.get_examples_for( + db=db, + tables={table.table_name for table in tables}, + dataset_id=dataset_id, + ) + result = await run_query( + db=db, + question=question, + user_id=ctx.deps.user_id, + selected_tables=tables, + examples=examples, + provider_id=provider_id, + model_id=model_id, + ) + + max_rows = int(settings.AI_TEXT2SQL_MAX_ROWS) + preview = (result.get('rows') or [])[:max_rows] + return ( + f"SQL:\n{result.get('sql') or ''}\n\n" + f"摘要:{result.get('summary') or ''}\n" + f"命中行数:{result.get('row_count', 0)}\n" + f"列:{result.get('columns') or []}\n" + f"数据预览:\n{json.dumps(preview, ensure_ascii=False, default=str)}" + ) + + return Toolset(toolset) diff --git a/chat/pipeline.py b/chat/pipeline.py index 1c46ca3..e230e43 100644 --- a/chat/pipeline.py +++ b/chat/pipeline.py @@ -13,6 +13,7 @@ from backend.plugin.ai.capabilities.image import build_image_generation_capability from backend.plugin.ai.capabilities.mcp import build_mcp_capability from backend.plugin.ai.capabilities.search import build_search_capabilities +from backend.plugin.ai.capabilities.text2sql import build_text2sql_capability from backend.plugin.ai.capabilities.thinking import build_thinking_capability from backend.plugin.ai.dataclasses import CapabilityContext from backend.plugin.ai.enums import AIProviderType @@ -23,6 +24,7 @@ build_thinking_capability, build_mcp_capability, build_search_capabilities, + build_text2sql_capability, build_code_execution_capability, build_image_generation_capability, build_builtin_toolset_capability, diff --git a/crud/crud_text2sql_dataset.py b/crud/crud_text2sql_dataset.py new file mode 100644 index 0000000..f04fed3 --- /dev/null +++ b/crud/crud_text2sql_dataset.py @@ -0,0 +1,109 @@ +from collections.abc import Sequence + +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy_crud_plus import CRUDPlus + +from backend.plugin.ai.model import AIText2SqlDataset +from backend.plugin.ai.schema.text2sql import CreateText2SqlDatasetParam, UpdateText2SqlDatasetParam +from backend.utils.timezone import timezone + + +class CRUDText2SqlDataset(CRUDPlus[AIText2SqlDataset]): + """Text2SQL 数据集数据库操作类""" + + async def get(self, db: AsyncSession, pk: int) -> AIText2SqlDataset | None: + """ + 获取数据集 + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + return await self.select_model(db, pk, deleted=0) + + async def get_all(self, db: AsyncSession) -> Sequence[AIText2SqlDataset]: + """ + 获取全部数据集 + + :param db: 数据库会话 + :return: + """ + return await self.select_models(db, deleted=0) + + async def get_enabled(self, db: AsyncSession) -> Sequence[AIText2SqlDataset]: + """ + 获取全部启用的数据集(供 chat 选择器) + + :param db: 数据库会话 + :return: + """ + return await self.select_models(db, enabled=1, deleted=0) + + async def get_by_name(self, db: AsyncSession, name: str) -> AIText2SqlDataset | None: + """ + 通过名称获取数据集 + + :param db: 数据库会话 + :param name: 数据集名称 + :return: + """ + return await self.select_model_by_column(db, name=name, deleted=0) + + async def get_select(self, name: str | None, enabled: int | None) -> Select: + """ + 获取数据集分页查询 + + :param name: 数据集名称(模糊) + :param enabled: 是否启用 + :return: + """ + filters = {'deleted': 0} + if name is not None: + filters.update(name__like=f'%{name}%') + if enabled is not None: + filters.update(enabled=enabled) + return await self.select_order('sort', 'asc', **filters) + + async def create(self, db: AsyncSession, obj: CreateText2SqlDatasetParam) -> None: + """ + 创建数据集 + + :param db: 数据库会话 + :param obj: 创建参数 + :return: + """ + await self.create_model(db, obj) + + async def update(self, db: AsyncSession, pk: int, obj: UpdateText2SqlDatasetParam) -> int: + """ + 更新数据集 + + :param db: 数据库会话 + :param pk: 记录 ID + :param obj: 更新参数 + :return: + """ + return await self.update_model_by_column(db, obj, id=pk, deleted=0) + + async def delete(self, db: AsyncSession, pk: int) -> int: + """ + 删除数据集(逻辑删除) + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + return await self.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='deleted', + deleted_flag_value=self.model.id, + deleted_at_column='deleted_time', + deleted_at_factory=timezone.now(), + id=pk, + deleted=0, + ) + + +text2sql_dataset_dao: CRUDText2SqlDataset = CRUDText2SqlDataset(AIText2SqlDataset) diff --git a/crud/crud_text2sql_example.py b/crud/crud_text2sql_example.py new file mode 100644 index 0000000..a5bc27c --- /dev/null +++ b/crud/crud_text2sql_example.py @@ -0,0 +1,102 @@ +from collections.abc import Sequence + +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy_crud_plus import CRUDPlus + +from backend.plugin.ai.model import AIText2SqlExample +from backend.plugin.ai.schema.text2sql import CreateText2SqlExampleParam, UpdateText2SqlExampleParam +from backend.utils.timezone import timezone + + +class CRUDText2SqlExample(CRUDPlus[AIText2SqlExample]): + """Text2SQL Few-shot 样例数据库操作类""" + + async def get(self, db: AsyncSession, pk: int) -> AIText2SqlExample | None: + """ + 获取样例 + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + return await self.select_model(db, pk, deleted=0) + + async def get_all_enabled(self, db: AsyncSession, dataset_id: int | None = None) -> Sequence[AIText2SqlExample]: + """ + 获取全部启用的样例(供 Agent 召回) + + :param db: 数据库会话 + :param dataset_id: 数据集 ID;传入则仅返回该数据集的样例,不传则返回全部 + :return: + """ + filters: dict = {'enabled': 1, 'deleted': 0} + if dataset_id is not None: + filters['dataset_id'] = dataset_id + return await self.select_models(db, **filters) + + async def get_select( + self, + question: str | None, + enabled: int | None, + dataset_id: int | None = None, + ) -> Select: + """ + 获取样例分页查询 + + :param question: 自然语言问题(模糊) + :param enabled: 是否启用 + :param dataset_id: 所属数据集 ID + :return: + """ + filters = {'deleted': 0} + if dataset_id is not None: + filters.update(dataset_id=dataset_id) + if question is not None: + filters.update(question__like=f'%{question}%') + if enabled is not None: + filters.update(enabled=enabled) + return await self.select_order('sort', 'asc', **filters) + + async def create(self, db: AsyncSession, obj: CreateText2SqlExampleParam) -> None: + """ + 创建样例 + + :param db: 数据库会话 + :param obj: 创建参数 + :return: + """ + await self.create_model(db, obj) + + async def update(self, db: AsyncSession, pk: int, obj: UpdateText2SqlExampleParam) -> int: + """ + 更新样例 + + :param db: 数据库会话 + :param pk: 记录 ID + :param obj: 更新参数 + :return: + """ + return await self.update_model_by_column(db, obj, id=pk, deleted=0) + + async def delete(self, db: AsyncSession, pk: int) -> int: + """ + 删除样例(逻辑删除) + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + return await self.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='deleted', + deleted_flag_value=self.model.id, + deleted_at_column='deleted_time', + deleted_at_factory=timezone.now(), + id=pk, + deleted=0, + ) + + +text2sql_example_dao: CRUDText2SqlExample = CRUDText2SqlExample(AIText2SqlExample) diff --git a/crud/crud_text2sql_table.py b/crud/crud_text2sql_table.py new file mode 100644 index 0000000..37361ed --- /dev/null +++ b/crud/crud_text2sql_table.py @@ -0,0 +1,139 @@ +from collections.abc import Sequence + +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy_crud_plus import CRUDPlus + +from backend.plugin.ai.model import AIText2SqlTable +from backend.plugin.ai.schema.text2sql import CreateText2SqlTableParam, UpdateText2SqlTableParam +from backend.utils.timezone import timezone + + +class CRUDText2SqlTable(CRUDPlus[AIText2SqlTable]): + """Text2SQL 已选数据表数据库操作类""" + + async def get(self, db: AsyncSession, pk: int) -> AIText2SqlTable | None: + """ + 获取已选表 + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + return await self.select_model(db, pk, deleted=0) + + async def get_by_name( + self, + db: AsyncSession, + schema_name: str, + table_name: str, + dataset_id: int, + ) -> AIText2SqlTable | None: + """ + 通过数据集+库名+表名获取已选表(同一数据集内表名唯一) + + :param db: 数据库会话 + :param schema_name: 库名/schema + :param table_name: 表名 + :param dataset_id: 所属数据集 ID + :return: + """ + return await self.select_model_by_column( + db, + schema_name=schema_name, + table_name=table_name, + dataset_id=dataset_id, + deleted=0, + ) + + async def get_enabled(self, db: AsyncSession, dataset_id: int | None = None) -> Sequence[AIText2SqlTable]: + """ + 获取全部启用的已选表(供 Agent 作为可见表集合) + + :param db: 数据库会话 + :param dataset_id: 数据集 ID;传入则仅返回该数据集的表,不传则返回全部 + :return: + """ + filters: dict = {'enabled': 1, 'deleted': 0} + if dataset_id is not None: + filters['dataset_id'] = dataset_id + return await self.select_models(db, **filters) + + async def get_all(self, db: AsyncSession) -> Sequence[AIText2SqlTable]: + """ + 获取全部已选表 + + :param db: 数据库会话 + :return: + """ + return await self.select_models(db, deleted=0) + + async def get_select( + self, + schema_name: str | None, + table_name: str | None, + enabled: int | None, + dataset_id: int | None = None, + ) -> Select: + """ + 获取已选表分页查询 + + :param schema_name: 库名/schema + :param table_name: 表名 + :param enabled: 是否启用 + :param dataset_id: 所属数据集 ID + :return: + """ + filters = {'deleted': 0} + if dataset_id is not None: + filters.update(dataset_id=dataset_id) + if schema_name is not None: + filters.update(schema_name__like=f'%{schema_name}%') + if table_name is not None: + filters.update(table_name__like=f'%{table_name}%') + if enabled is not None: + filters.update(enabled=enabled) + return await self.select_order('sort', 'asc', **filters) + + async def create(self, db: AsyncSession, obj: CreateText2SqlTableParam) -> None: + """ + 创建已选表 + + :param db: 数据库会话 + :param obj: 创建参数 + :return: + """ + await self.create_model(db, obj) + + async def update(self, db: AsyncSession, pk: int, obj: UpdateText2SqlTableParam) -> int: + """ + 更新已选表 + + :param db: 数据库会话 + :param pk: 记录 ID + :param obj: 更新参数 + :return: + """ + return await self.update_model_by_column(db, obj, id=pk, deleted=0) + + async def delete(self, db: AsyncSession, pk: int) -> int: + """ + 删除已选表(逻辑删除) + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + return await self.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='deleted', + deleted_flag_value=self.model.id, + deleted_at_column='deleted_time', + deleted_at_factory=timezone.now(), + id=pk, + deleted=0, + ) + + +text2sql_table_dao: CRUDText2SqlTable = CRUDText2SqlTable(AIText2SqlTable) diff --git a/model/__init__.py b/model/__init__.py index a91bd7d..37e07fb 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -4,3 +4,7 @@ from backend.plugin.ai.model.model import AIModel as AIModel from backend.plugin.ai.model.provider import AIProvider as AIProvider from backend.plugin.ai.model.quick_phrase import AIQuickPhrase as AIQuickPhrase +from backend.plugin.ai.model.text2sql import AIText2SqlDataset as AIText2SqlDataset +from backend.plugin.ai.model.text2sql import AIText2SqlExample as AIText2SqlExample +from backend.plugin.ai.model.text2sql import AIText2SqlHistory as AIText2SqlHistory +from backend.plugin.ai.model.text2sql import AIText2SqlTable as AIText2SqlTable diff --git a/model/text2sql.py b/model/text2sql.py new file mode 100644 index 0000000..3740421 --- /dev/null +++ b/model/text2sql.py @@ -0,0 +1,62 @@ +import sqlalchemy as sa + +from sqlalchemy.orm import Mapped, mapped_column + +from backend.common.model import Base, UniversalText, id_key + + +class AIText2SqlDataset(Base): + """AI Text2SQL 数据集(表与样例的容器,chat 按数据集圈定可见范围)""" + + __tablename__ = 'ai_text2sql_dataset' + + id: Mapped[id_key] = mapped_column(init=False) + name: Mapped[str] = mapped_column(sa.String(128), comment='数据集名称') + description: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='描述') + enabled: Mapped[int] = mapped_column(default=1, comment='是否启用(0停用 1启用)') + sort: Mapped[int] = mapped_column(default=0, comment='排序') + + +class AIText2SqlTable(Base): + """AI Text2SQL 已选数据表(数据源管理)""" + + __tablename__ = 'ai_text2sql_table' + + id: Mapped[id_key] = mapped_column(init=False) + dataset_id: Mapped[int] = mapped_column(sa.BigInteger, index=True, comment='所属数据集 ID') + table_name: Mapped[str] = mapped_column(sa.String(128), comment='表名') + schema_name: Mapped[str] = mapped_column(sa.String(64), default='fba', comment='库名/schema') + table_comment: Mapped[str | None] = mapped_column(sa.String(256), default=None, comment='表注释') + custom_desc: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='自定义语义描述') + enabled: Mapped[int] = mapped_column(default=1, comment='是否启用(0停用 1启用)') + sort: Mapped[int] = mapped_column(default=0, comment='排序') + + +class AIText2SqlExample(Base): + """AI Text2SQL Few-shot 样例""" + + __tablename__ = 'ai_text2sql_example' + + id: Mapped[id_key] = mapped_column(init=False) + dataset_id: Mapped[int] = mapped_column(sa.BigInteger, index=True, comment='所属数据集 ID') + question: Mapped[str] = mapped_column(UniversalText, comment='自然语言问题') + sql: Mapped[str] = mapped_column(UniversalText, comment='示范 SQL') + related_tables: Mapped[str | None] = mapped_column(sa.String(512), default=None, comment='相关表(逗号分隔)') + note: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='备注') + enabled: Mapped[int] = mapped_column(default=1, comment='是否启用(0停用 1启用)') + sort: Mapped[int] = mapped_column(default=0, comment='排序') + + +class AIText2SqlHistory(Base): + """AI Text2SQL 查询历史""" + + __tablename__ = 'ai_text2sql_history' + + id: Mapped[id_key] = mapped_column(init=False) + user_id: Mapped[int] = mapped_column(sa.BigInteger, comment='用户 ID') + question: Mapped[str] = mapped_column(UniversalText, comment='自然语言问题') + sql: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='生成 SQL') + executed: Mapped[int] = mapped_column(default=0, comment='是否已执行(0否 1是)') + row_count: Mapped[int] = mapped_column(default=0, comment='结果行数') + error: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='错误信息') + duration_ms: Mapped[int] = mapped_column(default=0, comment='耗时(毫秒)') diff --git a/plugin.toml b/plugin.toml index 8e8a829..445c0c8 100644 --- a/plugin.toml +++ b/plugin.toml @@ -1,7 +1,7 @@ [plugin] summary = "AI" version = "0.2.0" -description = "为系统提供 AI 赋能" +description = "为系统提供 AI 赋能(含 Text2SQL 自然语言查库)" author = "wu-clan" tags = ["agent", "ai", "mcp"] database = ["mysql", "postgresql"] @@ -13,3 +13,18 @@ router = ["v1"] AI_CODE_MODE_TOOLS = [] AI_HTTP_MAX_RETRIES = 5 AI_MCP_MAX_RETRIES = 1 +# [ Web Search ] +AI_EXA_API_KEY = "" +AI_TAVILY_API_KEY = "" +# [ Text2SQL ] +AI_TEXT2SQL_ENABLED = false +AI_TEXT2SQL_SCHEMA = "fba" +AI_TEXT2SQL_MAX_ROWS = 200 +AI_TEXT2SQL_TIMEOUT = 15 +AI_TEXT2SQL_MAX_RETRIES = 2 +AI_TEXT2SQL_PROVIDER_ID = 0 +AI_TEXT2SQL_MODEL_ID = "" +AI_TEXT2SQL_READONLY_HOST = "" +AI_TEXT2SQL_READONLY_PORT = 0 +AI_TEXT2SQL_READONLY_USER = "" +AI_TEXT2SQL_READONLY_PASSWORD = "" diff --git a/schema/chat.py b/schema/chat.py index 23aee7a..73a1c8f 100644 --- a/schema/chat.py +++ b/schema/chat.py @@ -26,6 +26,7 @@ class AIChatRuntimeParam(AIChatSchemaBase): enable_builtin_tools: bool = Field(default=True, description='是否启用项目内置工具') mcp_ids: list[int] | None = Field(default=None, description='启用的 MCP ID 列表') web_search: AIWebSearchType = Field(default=AIWebSearchType.off, description='网络搜索模式') + text2sql_dataset_id: int | None = Field(default=None, description='Text2SQL 数据集 ID;None 关闭,有值则按该数据集取数') class AIChatModelSettingsParam(AIChatSchemaBase): diff --git a/schema/text2sql.py b/schema/text2sql.py new file mode 100644 index 0000000..83aa808 --- /dev/null +++ b/schema/text2sql.py @@ -0,0 +1,151 @@ +from datetime import datetime +from typing import Any + +from pydantic import ConfigDict, Field + +from backend.common.schema import SchemaBase + + +# ---------------- 数据集 ---------------- + + +class Text2SqlDatasetSchemaBase(SchemaBase): + """Text2SQL 数据集基础模型""" + + name: str = Field(description='数据集名称') + description: str | None = Field(None, description='描述') + enabled: int = Field(1, description='是否启用(0停用 1启用)') + sort: int = Field(0, description='排序') + + +class CreateText2SqlDatasetParam(Text2SqlDatasetSchemaBase): + """新增数据集""" + + +class UpdateText2SqlDatasetParam(SchemaBase): + """更新数据集(部分更新)""" + + name: str | None = Field(None, description='数据集名称') + description: str | None = Field(None, description='描述') + enabled: int | None = Field(None, description='是否启用(0停用 1启用)') + sort: int | None = Field(None, description='排序') + + +class GetText2SqlDatasetDetail(Text2SqlDatasetSchemaBase): + """数据集详情""" + + model_config = ConfigDict(from_attributes=True) + + id: int = Field(description='ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') + + +class Text2SqlDatasetEnabled(SchemaBase): + """启用的数据集(chat 选择器用)""" + + id: int = Field(description='数据集 ID') + name: str = Field(description='数据集名称') + description: str | None = Field(None, description='描述') + + +# ---------------- 已选数据表 ---------------- + + +class Text2SqlTableSchemaBase(SchemaBase): + """Text2SQL 已选数据表基础模型""" + + dataset_id: int = Field(description='所属数据集 ID') + table_name: str = Field(description='表名') + schema_name: str = Field('fba', description='库名/schema') + table_comment: str | None = Field(None, description='表注释') + custom_desc: str | None = Field(None, description='自定义语义描述(喂给 Agent 提升精度)') + enabled: int = Field(1, description='是否启用(0停用 1启用)') + sort: int = Field(0, description='排序') + + +class CreateText2SqlTableParam(Text2SqlTableSchemaBase): + """新增已选表""" + + +class UpdateText2SqlTableParam(SchemaBase): + """更新已选表(部分更新)""" + + table_comment: str | None = Field(None, description='表注释') + custom_desc: str | None = Field(None, description='自定义语义描述') + enabled: int | None = Field(None, description='是否启用(0停用 1启用)') + sort: int | None = Field(None, description='排序') + + +class GetText2SqlTableDetail(Text2SqlTableSchemaBase): + """已选表详情""" + + model_config = ConfigDict(from_attributes=True) + + id: int = Field(description='ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') + + +class Text2SqlTableSelectable(SchemaBase): + """可挑选的数据库表(来自反查,不入库)""" + + table_name: str = Field(description='表名') + table_comment: str | None = Field(None, description='表注释') + selected: bool = Field(False, description='是否已挑选') + + +class Text2SqlExampleSchemaBase(SchemaBase): + """Text2SQL Few-shot 样例基础模型""" + + dataset_id: int = Field(description='所属数据集 ID') + question: str = Field(description='自然语言问题') + sql: str = Field(description='示范 SQL(只读 SELECT)') + related_tables: str | None = Field(None, description='相关表(逗号分隔,用于召回)') + note: str | None = Field(None, description='备注') + enabled: int = Field(1, description='是否启用(0停用 1启用)') + sort: int = Field(0, description='排序') + + +class CreateText2SqlExampleParam(Text2SqlExampleSchemaBase): + """新增 Few-shot 样例""" + + +class UpdateText2SqlExampleParam(SchemaBase): + """更新 Few-shot 样例(部分更新)""" + + question: str | None = Field(None, description='自然语言问题') + sql: str | None = Field(None, description='示范 SQL') + related_tables: str | None = Field(None, description='相关表(逗号分隔)') + note: str | None = Field(None, description='备注') + enabled: int | None = Field(None, description='是否启用(0停用 1启用)') + sort: int | None = Field(None, description='排序') + + +class GetText2SqlExampleDetail(Text2SqlExampleSchemaBase): + """Few-shot 样例详情""" + + model_config = ConfigDict(from_attributes=True) + + id: int = Field(description='ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') + + +class Text2SqlQueryParam(SchemaBase): + """自然语言查询参数""" + + question: str = Field(description='自然语言问题') + dataset_id: int | None = Field(None, description='数据集 ID;不传则使用全部已启用表') + + +class Text2SqlQueryResult(SchemaBase): + """自然语言查询结果""" + + sql: str = Field(description='生成的只读 SQL') + summary: str = Field(description='结果摘要') + columns: list[str] = Field(default_factory=list, description='结果列') + rows: list[dict[str, Any]] = Field(default_factory=list, description='结果行(最多 max_rows)') + row_count: int = Field(0, description='命中总行数') + duration_ms: int = Field(0, description='耗时(毫秒)') + history_id: int | None = Field(None, description='历史记录 ID') diff --git a/service/text2sql_service.py b/service/text2sql_service.py new file mode 100644 index 0000000..aa589da --- /dev/null +++ b/service/text2sql_service.py @@ -0,0 +1,412 @@ +from collections.abc import Sequence +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from backend.common.exception import errors +from backend.common.pagination import paging_data +from backend.core.conf import settings +from backend.plugin.ai.crud.crud_text2sql_dataset import text2sql_dataset_dao +from backend.plugin.ai.crud.crud_text2sql_example import text2sql_example_dao +from backend.plugin.ai.crud.crud_text2sql_table import text2sql_table_dao +from backend.plugin.ai.model import AIText2SqlDataset, AIText2SqlExample, AIText2SqlTable +from backend.plugin.ai.schema.text2sql import ( + CreateText2SqlDatasetParam, + CreateText2SqlExampleParam, + CreateText2SqlTableParam, + Text2SqlDatasetEnabled, + Text2SqlTableSelectable, + UpdateText2SqlDatasetParam, + UpdateText2SqlExampleParam, + UpdateText2SqlTableParam, +) +from backend.plugin.ai.text2sql.schema_meta import get_columns, get_tables + + +class Text2SqlService: + """Text2SQL 服务类""" + + # ---------------- 数据集 ---------------- + + @staticmethod + async def get_dataset(*, db: AsyncSession, pk: int) -> AIText2SqlDataset: + """ + 获取数据集详情 + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + row = await text2sql_dataset_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='数据集不存在') + return row + + @staticmethod + async def get_all_datasets(*, db: AsyncSession) -> Sequence[AIText2SqlDataset]: + """ + 获取全部数据集 + + :param db: 数据库会话 + :return: + """ + return await text2sql_dataset_dao.get_all(db) + + @staticmethod + async def get_enabled_datasets(*, db: AsyncSession) -> Sequence[Text2SqlDatasetEnabled]: + """ + 获取全部启用的数据集(供 chat 选择器) + + :param db: 数据库会话 + :return: + """ + rows = await text2sql_dataset_dao.get_enabled(db) + return [ + Text2SqlDatasetEnabled(id=row.id, name=row.name, description=row.description) + for row in rows + ] + + @staticmethod + async def get_dataset_list( + *, + db: AsyncSession, + name: str | None, + enabled: int | None, + ) -> dict[str, Any]: + """ + 分页获取数据集 + + :param db: 数据库会话 + :param name: 数据集名称(模糊) + :param enabled: 是否启用 + :return: + """ + sel = await text2sql_dataset_dao.get_select(name=name, enabled=enabled) + return await paging_data(db, sel) + + @staticmethod + async def create_dataset(*, db: AsyncSession, obj: CreateText2SqlDatasetParam) -> None: + """ + 创建数据集 + + :param db: 数据库会话 + :param obj: 创建参数 + :return: + """ + existing = await text2sql_dataset_dao.get_by_name(db, obj.name) + if existing: + raise errors.ForbiddenError(msg='数据集名称已存在') + await text2sql_dataset_dao.create(db, obj) + + @staticmethod + async def update_dataset(*, db: AsyncSession, pk: int, obj: UpdateText2SqlDatasetParam) -> int: + """ + 更新数据集 + + :param db: 数据库会话 + :param pk: 记录 ID + :param obj: 更新参数 + :return: + """ + row = await text2sql_dataset_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='数据集不存在') + return await text2sql_dataset_dao.update(db, pk, obj) + + @staticmethod + async def delete_dataset(*, db: AsyncSession, pk: int) -> int: + """ + 删除数据集(逻辑删除;其下表与样例保留但不再被该数据集聚合) + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + row = await text2sql_dataset_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='数据集不存在') + return await text2sql_dataset_dao.delete(db, pk) + + # ---------------- 数据源管理(已选表) ---------------- + + @staticmethod + async def list_selectable_tables( + *, + db: AsyncSession, + dataset_id: int, + table_schema: str | None = None, + ) -> Sequence[Text2SqlTableSelectable]: + """ + 列出可挑选的数据库表(反查 information_schema),并标记在该数据集内是否已挑选 + + :param db: 数据库会话 + :param dataset_id: 所属数据集 ID + :param table_schema: 库名/schema,缺省取 AI_TEXT2SQL_SCHEMA + :return: + """ + schema = table_schema or settings.AI_TEXT2SQL_SCHEMA + rows = await get_tables(db, schema) + selected = await text2sql_table_dao.select_models(db, dataset_id=dataset_id, deleted=0) + selected_names = {row.table_name for row in selected} + return [ + Text2SqlTableSelectable( + table_name=row['table_name'], + table_comment=row['table_comment'], + selected=row['table_name'] in selected_names, + ) + for row in rows + ] + + @staticmethod + async def get_table_columns( + *, + db: AsyncSession, + table_name: str, + table_schema: str | None = None, + ) -> list[dict[str, Any]]: + """ + 获取表列信息(供列预览与拼装 DDL 上下文) + + :param db: 数据库会话 + :param table_name: 表名 + :param table_schema: 库名/schema + :return: + """ + schema = table_schema or settings.AI_TEXT2SQL_SCHEMA + rows = await get_columns(db, schema, table_name) + return [dict(row) for row in rows] + + @staticmethod + async def get_selected(*, db: AsyncSession, pk: int) -> AIText2SqlTable: + """ + 获取已选表详情 + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + row = await text2sql_table_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='已选表记录不存在') + return row + + @staticmethod + async def get_all_selected(*, db: AsyncSession) -> Sequence[AIText2SqlTable]: + """ + 获取全部已选表 + + :param db: 数据库会话 + :return: + """ + return await text2sql_table_dao.get_all(db) + + @staticmethod + async def get_enabled(*, db: AsyncSession, dataset_id: int | None = None) -> Sequence[AIText2SqlTable]: + """ + 获取全部启用的已选表(Agent 可见表集合) + + :param db: 数据库会话 + :param dataset_id: 数据集 ID;传入则仅返回该数据集的表 + :return: + """ + return await text2sql_table_dao.get_enabled(db, dataset_id=dataset_id) + + @staticmethod + async def get_selected_list( + *, + db: AsyncSession, + dataset_id: int | None, + schema_name: str | None, + table_name: str | None, + enabled: int | None, + ) -> dict[str, Any]: + """ + 分页获取已选表 + + :param db: 数据库会话 + :param dataset_id: 所属数据集 ID + :param schema_name: 库名/schema + :param table_name: 表名 + :param enabled: 是否启用 + :return: + """ + sel = await text2sql_table_dao.get_select( + schema_name=schema_name, + table_name=table_name, + enabled=enabled, + dataset_id=dataset_id, + ) + return await paging_data(db, sel) + + @staticmethod + async def select_table(*, db: AsyncSession, obj: CreateText2SqlTableParam) -> None: + """ + 挑选表(新增到已选表,同一数据集内表名唯一) + + :param db: 数据库会话 + :param obj: 创建参数 + :return: + """ + existing = await text2sql_table_dao.get_by_name( + db, + schema_name=obj.schema_name, + table_name=obj.table_name, + dataset_id=obj.dataset_id, + ) + if existing: + raise errors.ForbiddenError(msg='该数据集内此表已挑选') + cols = await get_columns(db, obj.schema_name, obj.table_name) + if not cols: + raise errors.NotFoundError(msg='数据库表不存在') + await text2sql_table_dao.create(db, obj) + + @staticmethod + async def update_selected(*, db: AsyncSession, pk: int, obj: UpdateText2SqlTableParam) -> int: + """ + 更新已选表 + + :param db: 数据库会话 + :param pk: 记录 ID + :param obj: 更新参数 + :return: + """ + row = await text2sql_table_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='已选表记录不存在') + return await text2sql_table_dao.update(db, pk, obj) + + @staticmethod + async def unselect_table(*, db: AsyncSession, pk: int) -> int: + """ + 取消挑选(逻辑删除已选表) + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + row = await text2sql_table_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='已选表记录不存在') + return await text2sql_table_dao.delete(db, pk) + + # ---------------- Few-shot 样例 ---------------- + + @staticmethod + async def get_example(*, db: AsyncSession, pk: int) -> AIText2SqlExample: + """ + 获取样例详情 + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + row = await text2sql_example_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='样例不存在') + return row + + @staticmethod + async def get_all_examples(*, db: AsyncSession, dataset_id: int | None = None) -> Sequence[AIText2SqlExample]: + """ + 获取全部样例 + + :param db: 数据库会话 + :param dataset_id: 数据集 ID;传入则仅返回该数据集的样例 + :return: + """ + return await text2sql_example_dao.get_all_enabled(db, dataset_id=dataset_id) + + @staticmethod + async def get_example_list( + *, + db: AsyncSession, + dataset_id: int | None, + question: str | None, + enabled: int | None, + ) -> dict[str, Any]: + """ + 分页获取样例 + + :param db: 数据库会话 + :param dataset_id: 所属数据集 ID + :param question: 自然语言问题(模糊) + :param enabled: 是否启用 + :return: + """ + sel = await text2sql_example_dao.get_select( + question=question, + enabled=enabled, + dataset_id=dataset_id, + ) + return await paging_data(db, sel) + + @staticmethod + async def create_example(*, db: AsyncSession, obj: CreateText2SqlExampleParam) -> None: + """ + 创建样例 + + :param db: 数据库会话 + :param obj: 创建参数 + :return: + """ + await text2sql_example_dao.create(db, obj) + + @staticmethod + async def update_example(*, db: AsyncSession, pk: int, obj: UpdateText2SqlExampleParam) -> int: + """ + 更新样例 + + :param db: 数据库会话 + :param pk: 记录 ID + :param obj: 更新参数 + :return: + """ + row = await text2sql_example_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='样例不存在') + return await text2sql_example_dao.update(db, pk, obj) + + @staticmethod + async def delete_example(*, db: AsyncSession, pk: int) -> int: + """ + 删除样例 + + :param db: 数据库会话 + :param pk: 记录 ID + :return: + """ + row = await text2sql_example_dao.get(db, pk) + if not row: + raise errors.NotFoundError(msg='样例不存在') + return await text2sql_example_dao.delete(db, pk) + + @staticmethod + async def get_examples_for( + *, + db: AsyncSession, + tables: set[str], + dataset_id: int | None = None, + limit: int = 5, + ) -> list[dict[str, str]]: + """ + 按命中表召回 Few-shot 样例(供 Agent 提升精度) + + 无 related_tables 的样例视为通用样例,始终召回。 + + :param db: 数据库会话 + :param tables: 本次涉及的表名集合 + :param dataset_id: 数据集 ID;传入则仅在该数据集样例中召回 + :param limit: 最多召回条数 + :return: + """ + examples = await text2sql_example_dao.get_all_enabled(db, dataset_id=dataset_id) + table_set = {name.lower() for name in tables} + matched: list[AIText2SqlExample] = [] + for example in examples: + related = {name.strip().lower() for name in (example.related_tables or '').split(',') if name.strip()} + if not related or (related & table_set): + matched.append(example) + return [{'question': e.question, 'sql': e.sql} for e in matched[:limit]] + + +text2sql_service: Text2SqlService = Text2SqlService() diff --git a/sql/mysql/destroy.sql b/sql/mysql/destroy.sql index 2a423d3..e760304 100644 --- a/sql/mysql/destroy.sql +++ b/sql/mysql/destroy.sql @@ -15,7 +15,11 @@ where name in ( 'AIMcpManage', 'AddAIMcp', 'EditAIMcp', - 'DeleteAIMcp' + 'DeleteAIMcp', + 'AIText2SqlDataset', + 'AddAIText2SqlDataset', + 'EditAIText2SqlDataset', + 'DeleteAIText2SqlDataset' ); delete from sys_menu where name = 'PluginAI'; @@ -26,3 +30,8 @@ drop table if exists ai_quick_phrase; drop table if exists ai_model; drop table if exists ai_provider; drop table if exists ai_mcp; + +drop table if exists ai_text2sql_history; +drop table if exists ai_text2sql_example; +drop table if exists ai_text2sql_table; +drop table if exists ai_text2sql_dataset; diff --git a/sql/mysql/destroy_snowflake.sql b/sql/mysql/destroy_snowflake.sql index 4b94e2c..aa95a3d 100644 --- a/sql/mysql/destroy_snowflake.sql +++ b/sql/mysql/destroy_snowflake.sql @@ -12,7 +12,11 @@ where name in ( 'AIMcpManage', 'AddAIMcp', 'EditAIMcp', - 'DeleteAIMcp' + 'DeleteAIMcp', + 'AIText2SqlDataset', + 'AddAIText2SqlDataset', + 'EditAIText2SqlDataset', + 'DeleteAIText2SqlDataset' ); delete from sys_menu where name = 'PluginAI'; @@ -23,3 +27,8 @@ drop table if exists ai_quick_phrase; drop table if exists ai_model; drop table if exists ai_provider; drop table if exists ai_mcp; + +drop table if exists ai_text2sql_history; +drop table if exists ai_text2sql_example; +drop table if exists ai_text2sql_table; +drop table if exists ai_text2sql_dataset; diff --git a/sql/mysql/init.sql b/sql/mysql/init.sql index b03a81b..2380dac 100644 --- a/sql/mysql/init.sql +++ b/sql/mysql/init.sql @@ -21,6 +21,11 @@ values ('MCP 管理', 'AIMcpManage', '/plugins/ai/mcp', 4, 'simple-icons:modelco set @ai_mcp_menu_id = LAST_INSERT_ID(); +insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time) +values ('数据集', 'AIText2SqlDataset', '/plugins/ai/text2sql', 5, 'mdi:database-search-outline', 1, '/plugins/ai/views/text2sql/index', null, 1, 1, 1, '', null, @ai_menu_id, now(), null); + +set @ai_text2sql_menu_id = LAST_INSERT_ID(); + insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time) values ('新增供应商', 'AddAIProvider', null, 0, null, 2, null, 'ai:provider:add', 1, 0, 1, '', null, @ai_model_service_menu_id, now(), null), @@ -34,4 +39,7 @@ values ('删除快捷短语', 'DeleteAIQuickPhrase', null, 0, null, 2, null, 'ai:quick-phrase:del', 1, 0, 1, '', null, @ai_quick_phrase_menu_id, now(), null), ('新增MCP', 'AddAIMcp', null, 0, null, 2, null, 'ai:mcp:add', 1, 0, 1, '', null, @ai_mcp_menu_id, now(), null), ('修改MCP', 'EditAIMcp', null, 0, null, 2, null, 'ai:mcp:edit', 1, 0, 1, '', null, @ai_mcp_menu_id, now(), null), -('删除MCP', 'DeleteAIMcp', null, 0, null, 2, null, 'ai:mcp:del', 1, 0, 1, '', null, @ai_mcp_menu_id, now(), null); +('删除MCP', 'DeleteAIMcp', null, 0, null, 2, null, 'ai:mcp:del', 1, 0, 1, '', null, @ai_mcp_menu_id, now(), null), +('新增数据集', 'AddAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:add', 1, 0, 1, '', null, @ai_text2sql_menu_id, now(), null), +('修改数据集', 'EditAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:edit', 1, 0, 1, '', null, @ai_text2sql_menu_id, now(), null), +('删除数据集', 'DeleteAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:del', 1, 0, 1, '', null, @ai_text2sql_menu_id, now(), null); diff --git a/sql/mysql/init_snowflake.sql b/sql/mysql/init_snowflake.sql index 0d6b9cc..be0b621 100644 --- a/sql/mysql/init_snowflake.sql +++ b/sql/mysql/init_snowflake.sql @@ -16,3 +16,11 @@ values (2147098509659213836, '新增MCP', 'AddAIMcp', null, 0, null, 2, null, 'ai:mcp:add', 1, 0, 1, '', null, 2147098509659213835, now(), null), (2147098509659213837, '修改MCP', 'EditAIMcp', null, 0, null, 2, null, 'ai:mcp:edit', 1, 0, 1, '', null, 2147098509659213835, now(), null), (2147098509659213838, '删除MCP', 'DeleteAIMcp', null, 0, null, 2, null, 'ai:mcp:del', 1, 0, 1, '', null, 2147098509659213835, now(), null); + +-- Text2SQL 数据集(与 init.sql 对齐) +insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time) +values +(2147098509659213840, '数据集', 'AIText2SqlDataset', '/plugins/ai/text2sql', 5, 'mdi:database-search-outline', 1, '/plugins/ai/views/text2sql/index', null, 1, 1, 1, '', null, 2147098509659213824, now(), null), +(2147098509659213841, '新增数据集', 'AddAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:add', 1, 0, 1, '', null, 2147098509659213840, now(), null), +(2147098509659213842, '修改数据集', 'EditAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:edit', 1, 0, 1, '', null, 2147098509659213840, now(), null), +(2147098509659213843, '删除数据集', 'DeleteAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:del', 1, 0, 1, '', null, 2147098509659213840, now(), null); diff --git a/sql/postgresql/destroy.sql b/sql/postgresql/destroy.sql index b5e3328..4151e9a 100644 --- a/sql/postgresql/destroy.sql +++ b/sql/postgresql/destroy.sql @@ -15,7 +15,11 @@ where name in ( 'AIMcpManage', 'AddAIMcp', 'EditAIMcp', - 'DeleteAIMcp' + 'DeleteAIMcp', + 'AIText2SqlDataset', + 'AddAIText2SqlDataset', + 'EditAIText2SqlDataset', + 'DeleteAIText2SqlDataset' ); delete from sys_menu where name = 'PluginAI'; @@ -27,4 +31,9 @@ drop table if exists ai_model; drop table if exists ai_provider; drop table if exists ai_mcp; +drop table if exists ai_text2sql_history; +drop table if exists ai_text2sql_example; +drop table if exists ai_text2sql_table; +drop table if exists ai_text2sql_dataset; + select setval(pg_get_serial_sequence('sys_menu', 'id'), coalesce(max(id), 0) + 1, true) from sys_menu; diff --git a/sql/postgresql/destroy_snowflake.sql b/sql/postgresql/destroy_snowflake.sql index 4b94e2c..aa95a3d 100644 --- a/sql/postgresql/destroy_snowflake.sql +++ b/sql/postgresql/destroy_snowflake.sql @@ -12,7 +12,11 @@ where name in ( 'AIMcpManage', 'AddAIMcp', 'EditAIMcp', - 'DeleteAIMcp' + 'DeleteAIMcp', + 'AIText2SqlDataset', + 'AddAIText2SqlDataset', + 'EditAIText2SqlDataset', + 'DeleteAIText2SqlDataset' ); delete from sys_menu where name = 'PluginAI'; @@ -23,3 +27,8 @@ drop table if exists ai_quick_phrase; drop table if exists ai_model; drop table if exists ai_provider; drop table if exists ai_mcp; + +drop table if exists ai_text2sql_history; +drop table if exists ai_text2sql_example; +drop table if exists ai_text2sql_table; +drop table if exists ai_text2sql_dataset; diff --git a/sql/postgresql/init.sql b/sql/postgresql/init.sql index d242988..b9f2309 100644 --- a/sql/postgresql/init.sql +++ b/sql/postgresql/init.sql @@ -4,6 +4,7 @@ declare ai_model_service_menu_id bigint; ai_quick_phrase_menu_id bigint; ai_mcp_menu_id bigint; + ai_text2sql_menu_id bigint; begin insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time) values ('AI', 'PluginAI', '/plugins/ai', 11, 'tabler:robot', 0, null, null, 1, 1, 1, '', null, null, now(), null) @@ -24,6 +25,10 @@ begin values ('MCP 管理', 'AIMcpManage', '/plugins/ai/mcp', 4, 'simple-icons:modelcontextprotocol', 1, '/plugins/ai/views/mcp/index', null, 1, 1, 1, '', null, ai_menu_id, now(), null) returning id into ai_mcp_menu_id; + insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time) + values ('数据集', 'AIText2SqlDataset', '/plugins/ai/text2sql', 5, 'mdi:database-search-outline', 1, '/plugins/ai/views/text2sql/index', null, 1, 1, 1, '', null, ai_menu_id, now(), null) + returning id into ai_text2sql_menu_id; + insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time) values ('新增供应商', 'AddAIProvider', null, 0, null, 2, null, 'ai:provider:add', 1, 0, 1, '', null, ai_model_service_menu_id, now(), null), @@ -37,7 +42,10 @@ begin ('删除快捷短语', 'DeleteAIQuickPhrase', null, 0, null, 2, null, 'ai:quick-phrase:del', 1, 0, 1, '', null, ai_quick_phrase_menu_id, now(), null), ('新增MCP', 'AddAIMcp', null, 0, null, 2, null, 'ai:mcp:add', 1, 0, 1, '', null, ai_mcp_menu_id, now(), null), ('修改MCP', 'EditAIMcp', null, 0, null, 2, null, 'ai:mcp:edit', 1, 0, 1, '', null, ai_mcp_menu_id, now(), null), - ('删除MCP', 'DeleteAIMcp', null, 0, null, 2, null, 'ai:mcp:del', 1, 0, 1, '', null, ai_mcp_menu_id, now(), null); + ('删除MCP', 'DeleteAIMcp', null, 0, null, 2, null, 'ai:mcp:del', 1, 0, 1, '', null, ai_mcp_menu_id, now(), null), + ('新增数据集', 'AddAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:add', 1, 0, 1, '', null, ai_text2sql_menu_id, now(), null), + ('修改数据集', 'EditAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:edit', 1, 0, 1, '', null, ai_text2sql_menu_id, now(), null), + ('删除数据集', 'DeleteAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:del', 1, 0, 1, '', null, ai_text2sql_menu_id, now(), null); end $$; select setval(pg_get_serial_sequence('sys_menu', 'id'), coalesce(max(id), 0) + 1, true) from sys_menu; diff --git a/sql/postgresql/init_snowflake.sql b/sql/postgresql/init_snowflake.sql index 0d6b9cc..be0b621 100644 --- a/sql/postgresql/init_snowflake.sql +++ b/sql/postgresql/init_snowflake.sql @@ -16,3 +16,11 @@ values (2147098509659213836, '新增MCP', 'AddAIMcp', null, 0, null, 2, null, 'ai:mcp:add', 1, 0, 1, '', null, 2147098509659213835, now(), null), (2147098509659213837, '修改MCP', 'EditAIMcp', null, 0, null, 2, null, 'ai:mcp:edit', 1, 0, 1, '', null, 2147098509659213835, now(), null), (2147098509659213838, '删除MCP', 'DeleteAIMcp', null, 0, null, 2, null, 'ai:mcp:del', 1, 0, 1, '', null, 2147098509659213835, now(), null); + +-- Text2SQL 数据集(与 init.sql 对齐) +insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time) +values +(2147098509659213840, '数据集', 'AIText2SqlDataset', '/plugins/ai/text2sql', 5, 'mdi:database-search-outline', 1, '/plugins/ai/views/text2sql/index', null, 1, 1, 1, '', null, 2147098509659213824, now(), null), +(2147098509659213841, '新增数据集', 'AddAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:add', 1, 0, 1, '', null, 2147098509659213840, now(), null), +(2147098509659213842, '修改数据集', 'EditAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:edit', 1, 0, 1, '', null, 2147098509659213840, now(), null), +(2147098509659213843, '删除数据集', 'DeleteAIText2SqlDataset', null, 0, null, 2, null, 'ai:text2sql:dataset:del', 1, 0, 1, '', null, 2147098509659213840, now(), null); diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..9a49ea4 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""AI 插件测试包""" diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py new file mode 100644 index 0000000..5d2e0e4 --- /dev/null +++ b/tests/test_guardrails.py @@ -0,0 +1,110 @@ +"""Text2SQL 安全护栏单测(纯逻辑,无 DB) + +覆盖:放行普通 SELECT/JOIN/CTE/子查询/已有 LIMIT;拦截非 SELECT、多语句、INTO、越表(含 CTE 体/UNION 内越表)。 +""" + +import pytest + +from backend.plugin.ai.text2sql.exceptions import TableNotAllowedError, UnsafeSqlError +from backend.plugin.ai.text2sql.guardrails import validate_and_normalize + +ALLOWLIST = {'users', 'orders', 'products'} + + +def _guard(sql: str, **kwargs: object) -> str: + allowlist = kwargs.get('allowlist', ALLOWLIST) # type: ignore[arg-type] + max_rows = kwargs.get('max_rows', 100) # type: ignore[arg-type] + return validate_and_normalize(sql, allowlist=allowlist, max_rows=max_rows) + + +def test_allows_plain_select_injects_limit() -> None: + sql = _guard('SELECT * FROM users') + assert 'LIMIT 100' in sql.upper() + + +def test_allows_join_and_where() -> None: + sql = _guard('SELECT u.id FROM users u JOIN orders o ON o.user_id = u.id WHERE u.id > 0') + assert 'LIMIT 100' in sql.upper() + + +def test_allows_left_join() -> None: + sql = _guard('SELECT u.id FROM users u LEFT JOIN orders o ON o.user_id = u.id') + assert 'LIMIT 100' in sql.upper() + + +def test_allows_cte() -> None: + sql = _guard('WITH active AS (SELECT id FROM users WHERE deleted = 0) SELECT * FROM active') + assert 'LIMIT 100' in sql.upper() + + +def test_allows_subquery_in_from() -> None: + sql = _guard('SELECT * FROM (SELECT id FROM users) sub') + assert 'LIMIT 100' in sql.upper() + + +def test_preserves_existing_limit() -> None: + sql = _guard('SELECT * FROM users LIMIT 5') + assert sql.upper().count('LIMIT') == 1 + + +def test_accepts_lowercase_select() -> None: + sql = _guard('select id from users') + assert 'LIMIT 100' in sql.upper() + + +def test_accepts_multiline_and_comment() -> None: + sql = _guard('/* get users */\nSELECT\n id\nFROM\n users\nWHERE 1 = 1') + assert 'LIMIT 100' in sql.upper() + + +def test_uses_custom_max_rows() -> None: + sql = _guard('SELECT * FROM users', max_rows=42) + assert 'LIMIT 42' in sql.upper() + + +@pytest.mark.parametrize( + ['sql'], + [ + ["INSERT INTO users (id) VALUES (1)"], + ["UPDATE users SET name = 'x'"], + ['DELETE FROM users'], + ['DROP TABLE users'], + ['ALTER TABLE users ADD COLUMN x INT'], + ['TRUNCATE TABLE users'], + ['CREATE TABLE x (id INT)'], + ['GRANT SELECT ON users TO bob'], + ], +) +def test_blocks_non_select(sql: str) -> None: + with pytest.raises(UnsafeSqlError): + _guard(sql) + + +def test_blocks_multi_statement() -> None: + with pytest.raises(UnsafeSqlError): + _guard('SELECT * FROM users; DROP TABLE users') + + +def test_blocks_select_into() -> None: + with pytest.raises(UnsafeSqlError): + _guard('SELECT * INTO new_tbl FROM users') + + +def test_blocks_union_exfil_from_non_allowlisted() -> None: + with pytest.raises(TableNotAllowedError): + _guard('SELECT id FROM users UNION SELECT password FROM sys_user') + + +def test_blocks_non_allowlisted_table_inside_cte() -> None: + with pytest.raises(TableNotAllowedError): + _guard('WITH active AS (SELECT * FROM secrets) SELECT * FROM active') + + +def test_blocks_non_allowlisted_table() -> None: + with pytest.raises(TableNotAllowedError): + _guard('SELECT * FROM secrets') + + +def test_blocks_empty_sql() -> None: + with pytest.raises(UnsafeSqlError): + _guard(' ') diff --git a/text2sql/__init__.py b/text2sql/__init__.py new file mode 100644 index 0000000..d0e592d --- /dev/null +++ b/text2sql/__init__.py @@ -0,0 +1 @@ +"""AI Text2SQL 子系统:基于原生 Pydantic AI 的自然语言查库(只读安全)。""" diff --git a/text2sql/engine.py b/text2sql/engine.py new file mode 100644 index 0000000..8fb7530 --- /dev/null +++ b/text2sql/engine.py @@ -0,0 +1,307 @@ +"""Text2SQL 核心引擎(原生 Pydantic AI)。 + +流程:解析已挑选表 + few-shot → 复用供应商适配器建 pydantic-ai 模型 → 构建 Agent +(system_prompt + list_tables/describe_table/execute_sql 工具,execute_sql 强制过护栏) +→ 运行得到 {sql, summary} → 对最终 SQL 再次过护栏并在只读引擎执行取数 → 落历史。 + +注意:pydantic-ai 为 2.x beta,本模块用稳定的装饰器 API(@agent.system_prompt / @agent.tool), +规避构造参数命名差异;上线前务必用真实模型跑通一次(见 README/计划)。 +""" + +import asyncio +import json +import time +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel, Field +from pydantic_ai import Agent, RunContext +from sqlalchemy import text as sa_text +from sqlalchemy.ext.asyncio import AsyncSession + +from backend.common.exception import errors +from backend.common.log import log +from backend.core.conf import settings +from backend.database.db import async_db_session +from backend.plugin.ai.chat.session import AgentSession +from backend.plugin.ai.crud.crud_model import ai_model_dao +from backend.plugin.ai.crud.crud_provider import ai_provider_dao +from backend.plugin.ai.model import AIText2SqlHistory, AIText2SqlTable +from backend.plugin.ai.providers.registry import get_provider_adapter +from backend.plugin.ai.text2sql.exceptions import TableNotAllowedError, UnsafeSqlError +from backend.plugin.ai.text2sql.guardrails import validate_and_normalize +from backend.plugin.ai.text2sql.readonly_db import get_readonly_session +from backend.plugin.ai.text2sql.schema_meta import get_columns + + +@dataclass +class Text2SqlDeps: + """Agent 运行依赖""" + + readonly_session: AsyncSession + schema: str + tables: list[tuple[str, str | None]] # 启用的已选表 (name, 描述) + allowlist: set[str] + examples: list[dict[str, str]] + max_rows: int + timeout: int + + +class Text2SqlResult(BaseModel): + """Agent 结构化输出""" + + sql: str = Field(description='最终只读 SELECT SQL') + summary: str = Field(description='对查询结果的中文摘要') + + +def _build_prompt(deps: Text2SqlDeps) -> str: + tables = '\n'.join(f'- {name}:{comment or "(无注释)"}' for name, comment in deps.tables) or '- (无)' + examples = '' + if deps.examples: + blocks = '\n'.join(f'问题:{e["question"]}\nSQL:{e["sql"]}' for e in deps.examples) + examples = f'\n\n参考示例:\n{blocks}' + return ( + '你是 FBA Text2SQL 助手,将用户的自然语言问题转为只读 SQL,并给出结果摘要。\n' + f'数据库 schema:{deps.schema}\n' + f'仅可查询以下表(严禁查询其他表):\n{tables}{examples}\n\n' + '规则:\n' + '1. 只生成只读 SELECT(禁止 INSERT/UPDATE/DELETE/DDL/多语句/SELECT INTO)。\n' + '2. 只查询上面列出的表。\n' + '3. 需要列信息时调用 describe_table;验证查询时调用 execute_sql。\n' + '4. 若 execute_sql 报错,请根据错误修正 SQL 后重试。\n' + '5. 最终输出包含字段:sql(最终 SQL)与 summary(对结果的中文摘要)。' + ) + + +def _build_agent(model: Any) -> Agent: # type: ignore[type-arg] + agent = Agent( + model=model, + deps_type=Text2SqlDeps, + output_type=Text2SqlResult, + retries=int(settings.AI_TEXT2SQL_MAX_RETRIES), + ) + + @agent.system_prompt + def _system_prompt(ctx: RunContext[Text2SqlDeps]) -> str: + return _build_prompt(ctx.deps) + + @agent.tool + async def list_tables(ctx: RunContext[Text2SqlDeps]) -> str: + items = '\n'.join(f'- {name}:{comment or "(无注释)"}' for name, comment in ctx.deps.tables) + return f'可查询表:\n{items}' + + @agent.tool + async def describe_table(ctx: RunContext[Text2SqlDeps], table_name: str) -> str: + if table_name.lower() not in ctx.deps.allowlist: + return f'表 {table_name} 不在可查询范围内' + rows = await get_columns(ctx.deps.readonly_session, ctx.deps.schema, table_name) + lines = [ + f"- {row['column_name']} {row['column_type']}{' [PK]' if row['is_pk'] else ''}:{row['column_comment'] or ''}" + for row in rows + ] + return f'表 {table_name} 列:\n' + '\n'.join(lines) + + @agent.tool + async def execute_sql(ctx: RunContext[Text2SqlDeps], sql: str) -> str: + try: + safe = validate_and_normalize(sql, allowlist=ctx.deps.allowlist, max_rows=ctx.deps.max_rows) + except (UnsafeSqlError, TableNotAllowedError) as exc: + return f'校验失败,请改写:{exc}' + try: + result = await asyncio.wait_for( + ctx.deps.readonly_session.execute(sa_text(safe)), + timeout=ctx.deps.timeout, + ) + except TimeoutError: + return '执行超时,请简化查询' + except Exception as exc: # noqa: BLE001 + return f'执行失败,请检查 SQL:{exc}' + rows = result.mappings().all() + columns = list(rows[0].keys()) if rows else [] + preview = [dict(row) for row in rows[:20]] + return ( + f'列:{columns}\n命中行数:{len(rows)}\n' + f'预览(前 {len(preview)} 行):\n{json.dumps(preview, ensure_ascii=False, default=str)}' + ) + + return agent + + +async def _resolve_model( + *, + db: AsyncSession, + provider_id: int | None = None, + model_id: str | None = None, +) -> tuple[Any, Any]: + """ + 解析 text2sql 使用的供应商与模型 + + 优先使用调用方传入的 provider_id / model_id(例如聊天会话当前选择的模型); + 均未传入时回退到 .env 配置的 AI_TEXT2SQL_PROVIDER_ID / AI_TEXT2SQL_MODEL_ID。 + + :param db: 数据库会话 + :param provider_id: 供应商 ID,缺省回退配置 + :param model_id: 模型 ID,缺省回退配置 + :return: (provider, model) + """ + if not provider_id or not model_id: + provider_id = settings.AI_TEXT2SQL_PROVIDER_ID + model_id = settings.AI_TEXT2SQL_MODEL_ID + if not provider_id or not model_id: + raise errors.RequestError(msg='未配置 Text2SQL 模型,请在 .env 设置 AI_TEXT2SQL_PROVIDER_ID 与 AI_TEXT2SQL_MODEL_ID') + provider = await ai_provider_dao.get(db, provider_id) + if not provider or not provider.status: + raise errors.RequestError(msg='Text2SQL 供应商不可用,请检查供应商状态') + model = await ai_model_dao.get_by_model_and_provider(db, model_id, provider_id) + if not model or not model.status: + raise errors.RequestError(msg='Text2SQL 模型不可用,请检查模型状态') + return provider, model + + +async def _execute_final( + *, + readonly_session: AsyncSession, + sql: str, + allowlist: set[str], + max_rows: int, +) -> tuple[list[str], list[dict[str, Any]], int]: + """ + 对最终 SQL 再次过护栏并在只读引擎执行,返回列、数据、总行数 + + :raises UnsafeSqlError: 未通过护栏 + :raises TableNotAllowedError: 越表 + """ + safe = validate_and_normalize(sql, allowlist=allowlist, max_rows=max_rows) + result = await readonly_session.execute(sa_text(safe)) + rows = result.mappings().all() + columns = list(rows[0].keys()) if rows else [] + data = [dict(row) for row in rows[:max_rows]] + return columns, data, len(rows) + + +async def run_query( + *, + db: AsyncSession, + question: str, + user_id: int, + selected_tables: Sequence[AIText2SqlTable], + examples: list[dict[str, str]], + provider_id: int | None = None, + model_id: str | None = None, +) -> dict[str, Any]: + """ + 执行一次 Text2SQL 查询(端到端) + + :param db: 主库会话(用于解析模型,只读) + :param question: 自然语言问题 + :param user_id: 用户 ID + :param selected_tables: 启用的已选表 + :param examples: 召回的 few-shot 样例 + :param provider_id: 指定供应商 ID(缺省回退配置);聊天 capability 传入会话当前模型 + :param model_id: 指定模型 ID(缺省回退配置);聊天 capability 传入会话当前模型 + :return: {sql, summary, columns, rows, row_count, duration_ms, history_id} + """ + if not selected_tables: + raise errors.RequestError(msg='尚未挑选任何数据表,请先在数据源管理中挑选') + + schema = settings.AI_TEXT2SQL_SCHEMA + allowlist = {table.table_name.lower() for table in selected_tables} + tables_info = [(table.table_name, table.custom_desc or table.table_comment) for table in selected_tables] + max_rows = int(settings.AI_TEXT2SQL_MAX_ROWS) + timeout = int(settings.AI_TEXT2SQL_TIMEOUT) + + provider, model = await _resolve_model(db=db, provider_id=provider_id, model_id=model_id) + adapter = get_provider_adapter(provider.type) + adapter.validate_model_id(model.model_id) + + started = time.perf_counter() + session = await AgentSession.open( + adapter=adapter, + model_name=model.model_id, + api_key=provider.api_key, + base_url=provider.api_host, + ) + final_sql: str | None = None + summary = '' + columns: list[str] = [] + rows: list[dict[str, Any]] = [] + row_count = 0 + error_msg: str | None = None + try: + agent = _build_agent(session.model) + async with get_readonly_session() as readonly_session: + deps = Text2SqlDeps( + readonly_session=readonly_session, + schema=schema, + tables=tables_info, + allowlist=allowlist, + examples=examples, + max_rows=max_rows, + timeout=timeout, + ) + run_result = await agent.run(question, deps=deps) + final_sql = run_result.output.sql + summary = run_result.output.summary + columns, rows, row_count = await _execute_final( + readonly_session=readonly_session, + sql=final_sql, + allowlist=allowlist, + max_rows=max_rows, + ) + except Exception as exc: # noqa: BLE001 + error_msg = str(exc) + log.error(f'Text2SQL 查询失败: {exc}') + raise + finally: + duration_ms = int((time.perf_counter() - started) * 1000) + await session.aclose() + + # 历史记录(best-effort,不影响主流程) + history_id = await _write_history( + user_id=user_id, + question=question, + sql=final_sql, + row_count=row_count, + duration_ms=duration_ms, + error=error_msg, + ) + + return { + 'sql': final_sql or '', + 'summary': summary, + 'columns': columns, + 'rows': rows, + 'row_count': row_count, + 'duration_ms': duration_ms, + 'history_id': history_id, + } + + +async def _write_history( + *, + user_id: int, + question: str, + sql: str | None, + row_count: int, + duration_ms: int, + error: str | None, +) -> int | None: + """best-effort 写入查询历史,失败仅告警""" + try: + async with async_db_session.begin() as hist_db: + record = AIText2SqlHistory( + user_id=user_id, + question=question, + sql=sql, + executed=0 if error else 1, + row_count=row_count, + error=error, + duration_ms=duration_ms, + ) + hist_db.add(record) + await hist_db.flush() + return record.id + except Exception as exc: # noqa: BLE001 + log.warning(f'Text2SQL 历史写入失败: {exc}') + return None diff --git a/text2sql/exceptions.py b/text2sql/exceptions.py new file mode 100644 index 0000000..8349dda --- /dev/null +++ b/text2sql/exceptions.py @@ -0,0 +1,24 @@ +"""Text2SQL 异常定义。 + + +服务层负责将这些异常映射为对应的 HTTP 错误(见 backend/common/exception/errors): +- UnsafeSqlError -> 400(SQL 未通过安全护栏) +- TableNotAllowedError -> 400(引用了未挑选/未授权的表) +- Text2SqlTimeoutError -> 408(执行超时) +""" + + +class Text2SqlError(Exception): + """Text2SQL 基础异常""" + + +class UnsafeSqlError(Text2SqlError): + """SQL 未通过安全护栏(非只读 / 多语句 / 危险关键字 等)""" + + +class TableNotAllowedError(UnsafeSqlError): + """SQL 引用了未挑选或未授权的表""" + + +class Text2SqlTimeoutError(Text2SqlError): + """SQL 执行超时""" diff --git a/text2sql/guardrails.py b/text2sql/guardrails.py new file mode 100644 index 0000000..1a867de --- /dev/null +++ b/text2sql/guardrails.py @@ -0,0 +1,88 @@ +"""Text2SQL 安全护栏。 + +设计原则(fail-closed): +1. 只允许**单条只读 SELECT**(基于 sqlparse.get_type,权威判定,避免关键字误伤列名)。 +2. 拒绝多语句(防 `;` 注入第二条写语句)。 +3. 拒绝 `SELECT ... INTO`(PG 建表 / MySQL 写文件·变量等副作用)。 +4. **表白名单**:用正则全文扫描 FROM/JOIN 引用表,覆盖子查询体、CTE 体、UNION 等 + 任意嵌套位置——管理员控制 AI 可见表的边界,防越表取数(如 sys_user 密码)。 + CTE 名视为查询内合法“虚拟表”,不要求在白名单中。 +5. 缺 LIMIT 时注入 `LIMIT `;已有 LIMIT 则保留(行数上限在执行层兜底)。 + +本模块保持纯净(仅依赖 sqlparse + re + 自身 exceptions),便于无 DB 单测。 +真正的只读权限与超时由执行层(只读账号 + readonly_db)兜底,护栏是应用层第一道关。 + +已知 v1 局限(由只读账号兜底,且为 fail-closed 偏向拒绝): +- 注释/字符串字面量里出现 `from/join <词>` 可能被误当作表名而拒绝(极少见)。 +- 列名恰好为 `from`(如 `SELECT from FROM t`)等保留字边界场景。 +""" + +import re + +import sqlparse + +from backend.plugin.ai.text2sql.exceptions import TableNotAllowedError, UnsafeSqlError + +_INTO_RE = re.compile(r'\binto\b', re.IGNORECASE) +_HAS_LIMIT_RE = re.compile(r'\blimit\b', re.IGNORECASE) +# FROM/JOIN 后的表名(支持 schema.table,取末段去引号) +_TABLE_REF_RE = re.compile(r'\b(?:from|join)\s+([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)?)', re.IGNORECASE) +# CTE 名称:WITH [RECURSIVE] name AS ( 或 , name AS ( +_CTE_NAME_RE = re.compile(r'(?:\bwith\b\s+(?:recursive\s+)?|,)\s*([A-Za-z_]\w*)\s+as\s*\(', re.IGNORECASE) + + +def _extract_table_names(text: str) -> set[str]: + """ + 全文提取 FROM/JOIN 引用的表名(覆盖子查询体 / CTE 体 / UNION 等嵌套位置) + + :param text: SQL 文本 + :return: 小写裸表名集合 + """ + names: set[str] = set() + for match in _TABLE_REF_RE.findall(text): + names.add(match.split('.')[-1].strip().lower()) + return names + + +def _extract_cte_names(text: str) -> set[str]: + """提取 CTE 名称(视为查询内合法的“虚拟表”,不要求在 DB 白名单中)""" + return {name.lower() for name in _CTE_NAME_RE.findall(text)} + + +def validate_and_normalize(sql: str, *, allowlist: set[str], max_rows: int) -> str: + """ + 校验并归一化 SQL:仅放行单条只读 SELECT,强制表白名单与 LIMIT + + :param sql: 待校验 SQL + :param allowlist: 允许查询的表名集合(小写裸名) + :param max_rows: 缺省 LIMIT 行数上限 + :return: 归一化后的安全 SQL(缺 LIMIT 时已注入) + :raises UnsafeSqlError: 空 / 多语句 / 非 SELECT / 含 INTO 等 + :raises TableNotAllowedError: 引用了不在白名单内的表 + """ + if not sql or not sql.strip(): + raise UnsafeSqlError('SQL 为空') + + statements = [s for s in sqlparse.parse(sql) if s.token_first(skip_ws=True, skip_cm=True) is not None] + if len(statements) != 1: + raise UnsafeSqlError('仅允许单条 SQL 语句') + + statement = statements[0] + if statement.get_type() != 'SELECT': + raise UnsafeSqlError('仅允许 SELECT 查询') + + text = str(statement) + if _INTO_RE.search(text): + raise UnsafeSqlError('禁止 SELECT ... INTO 等带副作用写操作') + + referenced = _extract_table_names(text) + allowed = {name.lower() for name in allowlist} | _extract_cte_names(text) + disallowed = referenced - allowed + if disallowed: + raise TableNotAllowedError(f'引用了未授权的表: {", ".join(sorted(disallowed))}') + + normalized = statement.value.strip().rstrip(';').strip() + if not _HAS_LIMIT_RE.search(normalized): + normalized = f'{normalized}\nLIMIT {max(1, int(max_rows))}' + + return normalized diff --git a/text2sql/readonly_db.py b/text2sql/readonly_db.py new file mode 100644 index 0000000..c5ddc3f --- /dev/null +++ b/text2sql/readonly_db.py @@ -0,0 +1,90 @@ +from typing import Any + +from sqlalchemy import URL +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine + +from backend.common.enums import DataBaseType +from backend.common.log import log +from backend.core.conf import settings + +# 只读引擎与会话工厂(懒加载) +_readonly_engine: AsyncEngine | None = None +_readonly_session_maker: async_sessionmaker[AsyncSession | Any] | None = None + + +def _readonly_url_or_none() -> URL | None: + """ + 根据只读账号配置生成连接 URL + + :return: 配置完整时返回只读连接 URL,否则返回 None(调用方回退主库) + """ + host = settings.AI_TEXT2SQL_READONLY_HOST + user = settings.AI_TEXT2SQL_READONLY_USER + password = settings.AI_TEXT2SQL_READONLY_PASSWORD + if not host or not user: + return None + + driver = 'mysql+asyncmy' if DataBaseType.mysql == settings.DATABASE_TYPE else 'postgresql+asyncpg' + url = URL.create( + drivername=driver, + username=user, + password=password, + host=host, + port=settings.AI_TEXT2SQL_READONLY_PORT or settings.DATABASE_PORT, + database=settings.DATABASE_SCHEMA, + ) + if DataBaseType.mysql == settings.DATABASE_TYPE: + url = url.update_query_dict({'charset': settings.DATABASE_CHARSET}) + return url + + +def get_readonly_engine() -> AsyncEngine: + """ + 获取只读引擎 + + 未配置只读账号时回退主库引擎并告警(此时更依赖服务端 sqlparse 护栏兜底)。 + + :return: 只读异步引擎 + """ + global _readonly_engine, _readonly_session_maker + if _readonly_engine is not None: + return _readonly_engine + + url = _readonly_url_or_none() + if url is None: + log.warning('AI Text2SQL 未配置只读账号(AI_TEXT2SQL_READONLY_*),回退主库引擎并强制护栏') + from backend.database.db import async_engine as main_engine # noqa: PLC0415 + + _readonly_engine = main_engine + else: + _readonly_engine = create_async_engine( + url, + future=True, + pool_size=5, + max_overflow=10, + pool_timeout=30, + pool_recycle=3600, + pool_pre_ping=True, + ) + + _readonly_session_maker = async_sessionmaker( + bind=_readonly_engine, + class_=AsyncSession, + autoflush=False, + expire_on_commit=False, + ) + return _readonly_engine + + +def get_readonly_session() -> AsyncSession: + """ + 获取只读会话(上下文管理器) + + 用法:`async with get_readonly_session() as session: ...` + + :return: 只读异步会话 + """ + if _readonly_session_maker is None: + get_readonly_engine() + assert _readonly_session_maker is not None + return _readonly_session_maker() diff --git a/text2sql/schema_meta.py b/text2sql/schema_meta.py new file mode 100644 index 0000000..eb425da --- /dev/null +++ b/text2sql/schema_meta.py @@ -0,0 +1,120 @@ +from collections.abc import Sequence + +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncSession + +from backend.common.enums import DataBaseType +from backend.core.conf import settings + + +async def get_tables(db: AsyncSession, table_schema: str) -> Sequence[RowMapping]: + """ + 获取指定 schema 下可被 Text2SQL 挑选的业务表 + + 复用 code_generator 的反查思路,排除插件自身表(ai_*)与代码生成表(gen_*)。 + + :param db: 数据库会话 + :param table_schema: 数据库 schema 名称 + :return: + """ + if DataBaseType.mysql == settings.DATABASE_TYPE: + sql = """ + select + table_name as table_name, + table_comment as table_comment + from + information_schema.tables + where + table_schema = :table_schema + and table_name not like 'ai_%' + and table_name not like 'gen_%' + order by + table_name; + """ + else: + sql = """ + select + c.relname as table_name, + obj_description (c.oid) as table_comment + from + pg_class c + left join pg_namespace n on n.oid = c.relnamespace + where + c.relkind = 'r' + and n.nspname = :table_schema + and c.relname not like 'ai_%' + and c.relname not like 'gen_%' + order by + c.relname; + """ + result = await db.execute(text(sql).bindparams(table_schema=table_schema)) + return result.mappings().all() + + +async def get_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[RowMapping]: + """ + 获取表的列信息(含主键、可空、注释、类型),用于拼装 DDL 上下文与列预览 + + :param db: 数据库会话 + :param table_schema: 数据库 schema 名称 + :param table_name: 表名 + :return: + """ + if DataBaseType.mysql == settings.DATABASE_TYPE: + sql = """ + select + column_name as column_name, + case + when column_key = 'pri' then 1 + else 0 + end as is_pk, + case + when is_nullable = 'no' or column_key = 'pri' then 0 + else 1 + end as is_nullable, + ordinal_position as sort, + column_comment as column_comment, + column_type as column_type + from + information_schema.columns + where + table_name = :table_name + and table_schema = :table_schema + order by + sort; + """ + else: + sql = """ + select + a.attname as column_name, + case + when exists ( + select 1 from pg_constraint c + where c.conrelid = t.oid and c.contype = 'p' and a.attnum = any (c.conkey) + ) then 1 + else 0 + end as is_pk, + case + when a.attnotnull or exists ( + select 1 from pg_constraint c + where c.conrelid = t.oid and c.contype = 'p' and a.attnum = any (c.conkey) + ) then 0 + else 1 + end as is_nullable, + a.attnum as sort, + col_description (t.oid, a.attnum) as column_comment, + pg_catalog.format_type (a.atttypid, a.atttypmod) as column_type + from + pg_attribute a + join pg_class t on a.attrelid = t.oid + join pg_namespace n on n.oid = t.relnamespace + where + a.attnum > 0 + and not a.attisdropped + and t.relname = :table_name + and n.nspname = :table_schema + order by + sort; + """ + result = await db.execute(text(sql).bindparams(table_schema=table_schema, table_name=table_name)) + return result.mappings().all() From 09c56a6b478cc5f27edffc13f1349eb59045bc28 Mon Sep 17 00:00:00 2001 From: dengjingren Date: Thu, 25 Jun 2026 23:03:15 +0800 Subject: [PATCH 2/4] fix(text2sql): security hardening + default-model consolidation Address upstream review blockers for the Text2SQL feature. Security (fail-closed): - Rewrite guardrail with sqlglot AST: full schema.table allowlist (fixes mysql.user->user namespace-collision bypass), reject tableless recon (@@hostname/USER()/SLEEP/LOAD_FILE...), deny dangerous funcs/vars, scan for write/DDL nodes (DELETE...RETURNING in subquery), LIMIT clamp. - Make readonly DB mandatory (no main-DB fallback). - Wire AI_TEXT2SQL_ENABLED (capability builder + run_query). - _execute_final: add asyncio.wait_for timeout; remove dead Text2SqlTimeoutError. Model consolidation (M4): - Add AIDefaultModelScene.text2sql; resolve via ai_default_model_service. - Remove AI_TEXT2SQL_PROVIDER_ID / AI_TEXT2SQL_MODEL_ID. Tests: 41 guardrail cases incl. namespace-collision, tableless recon, INTO OUTFILE, DELETE...RETURNING, large-LIMIT clamp. --- .env.example | 3 - capabilities/text2sql.py | 3 + enums.py | 1 + plugin.toml | 2 - requirements.txt | 1 + tests/test_guardrails.py | 120 ++++++++++++++++++-- text2sql/engine.py | 66 ++++++++--- text2sql/exceptions.py | 11 +- text2sql/guardrails.py | 232 +++++++++++++++++++++++++++------------ text2sql/readonly_db.py | 31 +++--- 10 files changed, 352 insertions(+), 118 deletions(-) diff --git a/.env.example b/.env.example index fee537e..5ba7aee 100644 --- a/.env.example +++ b/.env.example @@ -8,9 +8,6 @@ AI_TEXT2SQL_SCHEMA=fba AI_TEXT2SQL_MAX_ROWS=200 AI_TEXT2SQL_TIMEOUT=15 AI_TEXT2SQL_MAX_RETRIES=2 -# 默认模型(providers 表 id + 模型 id;留空则取首个启用的 OpenAI 兼容供应商+模型) -AI_TEXT2SQL_PROVIDER_ID=0 -AI_TEXT2SQL_MODEL_ID= # 只读数据库账号(强烈建议配置仅 SELECT 权限账号;留空则回退主库并强制护栏) AI_TEXT2SQL_READONLY_HOST= AI_TEXT2SQL_READONLY_PORT=0 diff --git a/capabilities/text2sql.py b/capabilities/text2sql.py index 6c36eb8..5096a0e 100644 --- a/capabilities/text2sql.py +++ b/capabilities/text2sql.py @@ -23,6 +23,9 @@ async def build_text2sql_capability(ctx: CapabilityContext) -> CapabilityResult: :param ctx: 能力构建上下文 :return: """ + if not settings.AI_TEXT2SQL_ENABLED: + # 总开关关闭时不向 chat 注入工具(/queries 也在 run_query 处拒绝) + return CapabilityResult(capability=None) dataset_id = ctx.forwarded_props.text2sql_dataset_id if not dataset_id: return CapabilityResult(capability=None) diff --git a/enums.py b/enums.py index 148f087..daf1ec9 100644 --- a/enums.py +++ b/enums.py @@ -5,6 +5,7 @@ class AIDefaultModelScene(StrEnum): """AI 默认模型场景""" assistant = 'assistant' + text2sql = 'text2sql' class AIProviderType(IntEnum): diff --git a/plugin.toml b/plugin.toml index 0b80bba..082360a 100644 --- a/plugin.toml +++ b/plugin.toml @@ -25,8 +25,6 @@ AI_TEXT2SQL_SCHEMA = "fba" AI_TEXT2SQL_MAX_ROWS = 200 AI_TEXT2SQL_TIMEOUT = 15 AI_TEXT2SQL_MAX_RETRIES = 2 -AI_TEXT2SQL_PROVIDER_ID = 0 -AI_TEXT2SQL_MODEL_ID = "" AI_TEXT2SQL_READONLY_HOST = "" AI_TEXT2SQL_READONLY_PORT = 0 AI_TEXT2SQL_READONLY_USER = "" diff --git a/requirements.txt b/requirements.txt index 92c8910..8e83f0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pydantic-ai-slim[openai,google,anthropic,xai,openrouter,mcp,retries,tavily,duckduckgo,ag-ui,exa]==2.0.0 pydantic-ai-harness[code-mode]>=0.4.0 +sqlglot>=25.0.0 diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index 5d2e0e4..bf75ecb 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -1,6 +1,8 @@ -"""Text2SQL 安全护栏单测(纯逻辑,无 DB) +"""Text2SQL 安全护栏单测(纯逻辑,无 DB)。 -覆盖:放行普通 SELECT/JOIN/CTE/子查询/已有 LIMIT;拦截非 SELECT、多语句、INTO、越表(含 CTE 体/UNION 内越表)。 +覆盖:放行普通 SELECT/JOIN/CTE/子查询/UNION/已有 LIMIT;拦截非 SELECT、多语句、 +INTO OUTFILE、越表(含命名空间碰撞 mysql.user→user、CTE 体/UNION 内越表)、 +无表侦察(@@hostname/USER()/SLEEP/LOAD_FILE 等)、大 LIMIT 钳制。 """ import pytest @@ -8,13 +10,24 @@ from backend.plugin.ai.text2sql.exceptions import TableNotAllowedError, UnsafeSqlError from backend.plugin.ai.text2sql.guardrails import validate_and_normalize -ALLOWLIST = {'users', 'orders', 'products'} +# 白名单按完整 schema.table 给出(小写);default_schema 用于 SQL 未显式指定 schema 时 +ALLOWLIST = {'public.users', 'public.orders', 'public.products', 'public.user'} +DEFAULT_SCHEMA = 'public' def _guard(sql: str, **kwargs: object) -> str: allowlist = kwargs.get('allowlist', ALLOWLIST) # type: ignore[arg-type] max_rows = kwargs.get('max_rows', 100) # type: ignore[arg-type] - return validate_and_normalize(sql, allowlist=allowlist, max_rows=max_rows) + default_schema = kwargs.get('default_schema', DEFAULT_SCHEMA) # type: ignore[arg-type] + return validate_and_normalize( + sql, + allowlist=allowlist, + max_rows=max_rows, + default_schema=default_schema, + ) + + +# ---------------- 放行 ---------------- def test_allows_plain_select_injects_limit() -> None: @@ -42,9 +55,16 @@ def test_allows_subquery_in_from() -> None: assert 'LIMIT 100' in sql.upper() -def test_preserves_existing_limit() -> None: +def test_allows_union_of_allowlisted_tables() -> None: + sql = _guard('SELECT id FROM users UNION SELECT id FROM orders') + assert 'UNION' in sql.upper() + assert 'LIMIT 100' in sql.upper() + + +def test_preserves_small_existing_limit() -> None: sql = _guard('SELECT * FROM users LIMIT 5') assert sql.upper().count('LIMIT') == 1 + assert 'LIMIT 5' in sql.upper().replace(' ', ' ') def test_accepts_lowercase_select() -> None: @@ -62,6 +82,9 @@ def test_uses_custom_max_rows() -> None: assert 'LIMIT 42' in sql.upper() +# ---------------- 拦截:非只读 / 多语句 ---------------- + + @pytest.mark.parametrize( ['sql'], [ @@ -85,9 +108,35 @@ def test_blocks_multi_statement() -> None: _guard('SELECT * FROM users; DROP TABLE users') -def test_blocks_select_into() -> None: +def test_blocks_select_into_outfile() -> None: + # sqlglot 不解析 INTO OUTFILE → fail-closed with pytest.raises(UnsafeSqlError): - _guard('SELECT * INTO new_tbl FROM users') + _guard("SELECT * FROM users INTO OUTFILE '/tmp/x'") + + +def test_blocks_delete_returning_in_subquery() -> None: + # PG: SELECT * FROM (DELETE ... RETURNING ...) x —— 子查询内含写操作 + with pytest.raises(UnsafeSqlError): + _guard('SELECT * FROM (DELETE FROM users RETURNING id) x') + + +# ---------------- 拦截:表白名单 / 命名空间碰撞 ---------------- + + +def test_blocks_system_table_via_namespace_collision() -> None: + # 关键回归:mysql.user 经命名空间剥离不得与白名单内的 public.user 碰撞放行 + with pytest.raises(TableNotAllowedError): + _guard('SELECT authentication_string FROM mysql.user') + + +def test_blocks_information_schema() -> None: + with pytest.raises(TableNotAllowedError): + _guard('SELECT * FROM information_schema.tables') + + +def test_blocks_non_allowlisted_table() -> None: + with pytest.raises(TableNotAllowedError): + _guard('SELECT * FROM secrets') def test_blocks_union_exfil_from_non_allowlisted() -> None: @@ -100,9 +149,62 @@ def test_blocks_non_allowlisted_table_inside_cte() -> None: _guard('WITH active AS (SELECT * FROM secrets) SELECT * FROM active') -def test_blocks_non_allowlisted_table() -> None: +def test_explicit_schema_must_match() -> None: + # 即使表名在白名单,schema 不是 public 也应拒(default_schema=public) with pytest.raises(TableNotAllowedError): - _guard('SELECT * FROM secrets') + _guard('SELECT * FROM other.users') + + +# ---------------- 拦截:无表侦察 / 危险函数 ---------------- + + +@pytest.mark.parametrize( + ['sql'], + [ + ['SELECT @@hostname'], + ['SELECT @@version'], + ['SELECT @@datadir'], + ['SELECT USER()'], + ['SELECT CURRENT_USER()'], + ['SELECT VERSION()'], + ['SELECT CONNECTION_ID()'], + ['SELECT 1'], + ['SELECT 1 + 1'], + ], +) +def test_blocks_tableless_recon(sql: str) -> None: + with pytest.raises(UnsafeSqlError): + _guard(sql) + + +def test_blocks_dangerous_func_even_with_table() -> None: + # 即便引用了白名单表,危险函数也要拒 + with pytest.raises(UnsafeSqlError): + _guard('SELECT SLEEP(5) FROM users') + with pytest.raises(UnsafeSqlError): + _guard("SELECT LOAD_FILE('/etc/passwd') FROM users") + + +def test_blocks_user_variable() -> None: + with pytest.raises(UnsafeSqlError): + _guard('SELECT @x := 1 FROM users') + + +# ---------------- LIMIT 钳制 ---------------- + + +def test_clamps_large_limit() -> None: + sql = _guard('SELECT * FROM users LIMIT 1000000') + assert 'LIMIT 100' in sql.upper() + assert '1000000' not in sql + + +def test_keeps_limit_below_cap() -> None: + sql = _guard('SELECT * FROM users LIMIT 10') + assert 'LIMIT 10' in sql.upper() + + +# ---------------- 其它 ---------------- def test_blocks_empty_sql() -> None: diff --git a/text2sql/engine.py b/text2sql/engine.py index 8fb7530..76dbe1d 100644 --- a/text2sql/engine.py +++ b/text2sql/engine.py @@ -20,6 +20,7 @@ from sqlalchemy import text as sa_text from sqlalchemy.ext.asyncio import AsyncSession +from backend.common.enums import DataBaseType from backend.common.exception import errors from backend.common.log import log from backend.core.conf import settings @@ -27,8 +28,10 @@ from backend.plugin.ai.chat.session import AgentSession from backend.plugin.ai.crud.crud_model import ai_model_dao from backend.plugin.ai.crud.crud_provider import ai_provider_dao +from backend.plugin.ai.enums import AIDefaultModelScene from backend.plugin.ai.model import AIText2SqlHistory, AIText2SqlTable from backend.plugin.ai.providers.registry import get_provider_adapter +from backend.plugin.ai.service.default_model_service import ai_default_model_service from backend.plugin.ai.text2sql.exceptions import TableNotAllowedError, UnsafeSqlError from backend.plugin.ai.text2sql.guardrails import validate_and_normalize from backend.plugin.ai.text2sql.readonly_db import get_readonly_session @@ -42,10 +45,11 @@ class Text2SqlDeps: readonly_session: AsyncSession schema: str tables: list[tuple[str, str | None]] # 启用的已选表 (name, 描述) - allowlist: set[str] + allowlist: set[str] # 授权的 'schema.table'(小写) examples: list[dict[str, str]] max_rows: int timeout: int + dialect: str | None # sqlglot 方言(mysql/postgres) class Text2SqlResult(BaseModel): @@ -93,7 +97,8 @@ async def list_tables(ctx: RunContext[Text2SqlDeps]) -> str: @agent.tool async def describe_table(ctx: RunContext[Text2SqlDeps], table_name: str) -> str: - if table_name.lower() not in ctx.deps.allowlist: + ref = f'{ctx.deps.schema.strip().lower()}.{table_name.strip().lower()}' + if ref not in ctx.deps.allowlist: return f'表 {table_name} 不在可查询范围内' rows = await get_columns(ctx.deps.readonly_session, ctx.deps.schema, table_name) lines = [ @@ -105,7 +110,13 @@ async def describe_table(ctx: RunContext[Text2SqlDeps], table_name: str) -> str: @agent.tool async def execute_sql(ctx: RunContext[Text2SqlDeps], sql: str) -> str: try: - safe = validate_and_normalize(sql, allowlist=ctx.deps.allowlist, max_rows=ctx.deps.max_rows) + safe = validate_and_normalize( + sql, + allowlist=ctx.deps.allowlist, + max_rows=ctx.deps.max_rows, + default_schema=ctx.deps.schema, + dialect=ctx.deps.dialect, + ) except (UnsafeSqlError, TableNotAllowedError) as exc: return f'校验失败,请改写:{exc}' try: @@ -138,18 +149,17 @@ async def _resolve_model( 解析 text2sql 使用的供应商与模型 优先使用调用方传入的 provider_id / model_id(例如聊天会话当前选择的模型); - 均未传入时回退到 .env 配置的 AI_TEXT2SQL_PROVIDER_ID / AI_TEXT2SQL_MODEL_ID。 + 均未传入时回退到默认模型配置(AIDefaultModelScene.text2sql 场景)。 :param db: 数据库会话 - :param provider_id: 供应商 ID,缺省回退配置 - :param model_id: 模型 ID,缺省回退配置 + :param provider_id: 供应商 ID,缺省回退默认模型配置 + :param model_id: 模型 ID,缺省回退默认模型配置 :return: (provider, model) """ if not provider_id or not model_id: - provider_id = settings.AI_TEXT2SQL_PROVIDER_ID - model_id = settings.AI_TEXT2SQL_MODEL_ID - if not provider_id or not model_id: - raise errors.RequestError(msg='未配置 Text2SQL 模型,请在 .env 设置 AI_TEXT2SQL_PROVIDER_ID 与 AI_TEXT2SQL_MODEL_ID') + default_model = await ai_default_model_service.get(db=db, scene=AIDefaultModelScene.text2sql) + provider_id = default_model.provider_id + model_id = default_model.model_id provider = await ai_provider_dao.get(db, provider_id) if not provider or not provider.status: raise errors.RequestError(msg='Text2SQL 供应商不可用,请检查供应商状态') @@ -165,15 +175,31 @@ async def _execute_final( sql: str, allowlist: set[str], max_rows: int, + default_schema: str, + dialect: str | None, + timeout: int, ) -> tuple[list[str], list[dict[str, Any]], int]: """ 对最终 SQL 再次过护栏并在只读引擎执行,返回列、数据、总行数 :raises UnsafeSqlError: 未通过护栏 :raises TableNotAllowedError: 越表 + :raises RequestError: 执行超时(LIMIT 已由护栏钳制,结果集有界) """ - safe = validate_and_normalize(sql, allowlist=allowlist, max_rows=max_rows) - result = await readonly_session.execute(sa_text(safe)) + safe = validate_and_normalize( + sql, + allowlist=allowlist, + max_rows=max_rows, + default_schema=default_schema, + dialect=dialect, + ) + try: + result = await asyncio.wait_for( + readonly_session.execute(sa_text(safe)), + timeout=timeout, + ) + except TimeoutError as exc: # noqa: PERF203 + raise errors.RequestError(msg=f'Text2SQL 查询执行超时({timeout}s),请简化查询') from exc rows = result.mappings().all() columns = list(rows[0].keys()) if rows else [] data = [dict(row) for row in rows[:max_rows]] @@ -198,18 +224,24 @@ async def run_query( :param user_id: 用户 ID :param selected_tables: 启用的已选表 :param examples: 召回的 few-shot 样例 - :param provider_id: 指定供应商 ID(缺省回退配置);聊天 capability 传入会话当前模型 - :param model_id: 指定模型 ID(缺省回退配置);聊天 capability 传入会话当前模型 + :param provider_id: 指定供应商 ID(缺省回退默认模型配置);聊天 capability 传入会话当前模型 + :param model_id: 指定模型 ID(缺省回退默认模型配置);聊天 capability 传入会话当前模型 :return: {sql, summary, columns, rows, row_count, duration_ms, history_id} """ + if not settings.AI_TEXT2SQL_ENABLED: + raise errors.ForbiddenError(msg='Text2SQL 未启用(AI_TEXT2SQL_ENABLED=false)') if not selected_tables: raise errors.RequestError(msg='尚未挑选任何数据表,请先在数据源管理中挑选') schema = settings.AI_TEXT2SQL_SCHEMA - allowlist = {table.table_name.lower() for table in selected_tables} + allowlist = { + f'{table.schema_name.strip().lower()}.{table.table_name.strip().lower()}' + for table in selected_tables + } tables_info = [(table.table_name, table.custom_desc or table.table_comment) for table in selected_tables] max_rows = int(settings.AI_TEXT2SQL_MAX_ROWS) timeout = int(settings.AI_TEXT2SQL_TIMEOUT) + dialect = 'mysql' if settings.DATABASE_TYPE == DataBaseType.mysql else 'postgres' provider, model = await _resolve_model(db=db, provider_id=provider_id, model_id=model_id) adapter = get_provider_adapter(provider.type) @@ -239,6 +271,7 @@ async def run_query( examples=examples, max_rows=max_rows, timeout=timeout, + dialect=dialect, ) run_result = await agent.run(question, deps=deps) final_sql = run_result.output.sql @@ -248,6 +281,9 @@ async def run_query( sql=final_sql, allowlist=allowlist, max_rows=max_rows, + default_schema=schema, + dialect=dialect, + timeout=timeout, ) except Exception as exc: # noqa: BLE001 error_msg = str(exc) diff --git a/text2sql/exceptions.py b/text2sql/exceptions.py index 8349dda..97e5a6b 100644 --- a/text2sql/exceptions.py +++ b/text2sql/exceptions.py @@ -1,10 +1,11 @@ -"""Text2SQL 异常定义。 +"""Text2SQL 异常定义. 服务层负责将这些异常映射为对应的 HTTP 错误(见 backend/common/exception/errors): - UnsafeSqlError -> 400(SQL 未通过安全护栏) - TableNotAllowedError -> 400(引用了未挑选/未授权的表) -- Text2SqlTimeoutError -> 408(执行超时) + +执行超时由 engine 直接抛 errors.RequestError(-> 400),不在此单独建类。 """ @@ -13,12 +14,8 @@ class Text2SqlError(Exception): class UnsafeSqlError(Text2SqlError): - """SQL 未通过安全护栏(非只读 / 多语句 / 危险关键字 等)""" + """SQL 未通过安全护栏(非只读 / 多语句 / 危险函数或变量 等)""" class TableNotAllowedError(UnsafeSqlError): """SQL 引用了未挑选或未授权的表""" - - -class Text2SqlTimeoutError(Text2SqlError): - """SQL 执行超时""" diff --git a/text2sql/guardrails.py b/text2sql/guardrails.py index 1a867de..362d70d 100644 --- a/text2sql/guardrails.py +++ b/text2sql/guardrails.py @@ -1,88 +1,184 @@ -"""Text2SQL 安全护栏。 - -设计原则(fail-closed): -1. 只允许**单条只读 SELECT**(基于 sqlparse.get_type,权威判定,避免关键字误伤列名)。 -2. 拒绝多语句(防 `;` 注入第二条写语句)。 -3. 拒绝 `SELECT ... INTO`(PG 建表 / MySQL 写文件·变量等副作用)。 -4. **表白名单**:用正则全文扫描 FROM/JOIN 引用表,覆盖子查询体、CTE 体、UNION 等 - 任意嵌套位置——管理员控制 AI 可见表的边界,防越表取数(如 sys_user 密码)。 - CTE 名视为查询内合法“虚拟表”,不要求在白名单中。 -5. 缺 LIMIT 时注入 `LIMIT `;已有 LIMIT 则保留(行数上限在执行层兜底)。 - -本模块保持纯净(仅依赖 sqlparse + re + 自身 exceptions),便于无 DB 单测。 -真正的只读权限与超时由执行层(只读账号 + readonly_db)兜底,护栏是应用层第一道关。 - -已知 v1 局限(由只读账号兜底,且为 fail-closed 偏向拒绝): -- 注释/字符串字面量里出现 `from/join <词>` 可能被误当作表名而拒绝(极少见)。 -- 列名恰好为 `from`(如 `SELECT from FROM t`)等保留字边界场景。 +"""Text2SQL 安全护栏(sqlglot AST,fail-closed)。 + +相比 v1(sqlparse.get_type + 正则提表名)的关键修正: +1. **表白名单按完整 ``schema.table`` 校验**,不再剥离命名空间——堵死 + ``mysql.user`` / ``information_schema.*`` / ``sys.*`` 因表名碰撞(如业务表恰叫 + ``user``)而越权读取的路径。 +2. **强制至少引用一张白名单内的真实表**——阻断无 ``FROM`` 的侦察语句 + (``@@hostname`` / ``USER()`` / ``VERSION()`` / ``SLEEP`` / ``LOAD_FILE`` …)。 +3. **拒绝危险函数与系统/会话变量**,即便语句引用了白名单表(如 ``SELECT USER() FROM t``)。 +4. 解析失败 / 多语句 / 非 SELECT 一律拒绝(fail-closed);``DATABASE()`` / ``SCHEMA()`` + / ``SELECT ... INTO OUTFILE`` 等 sqlglot 解析不了的形态自动落入选拒绝分支。 + 另对整棵树做写/DDL 节点扫描,防 ``DELETE ... RETURNING`` 嵌入子查询等形态。 +5. **LIMIT 服务端钳制到 ≤ max_rows**,防 LLM 生成大 LIMIT 配合 ``.all()`` 造成 OOM。 +6. CTE 名视为查询内合法"虚拟表",不要求落在白名单中。 + +只读账号(readonly_db)仍是真正的权限边界,本护栏是应用层第一道关。 """ -import re +from __future__ import annotations -import sqlparse +import sqlglot +from sqlglot import exp from backend.plugin.ai.text2sql.exceptions import TableNotAllowedError, UnsafeSqlError -_INTO_RE = re.compile(r'\binto\b', re.IGNORECASE) -_HAS_LIMIT_RE = re.compile(r'\blimit\b', re.IGNORECASE) -# FROM/JOIN 后的表名(支持 schema.table,取末段去引号) -_TABLE_REF_RE = re.compile(r'\b(?:from|join)\s+([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)?)', re.IGNORECASE) -# CTE 名称:WITH [RECURSIVE] name AS ( 或 , name AS ( -_CTE_NAME_RE = re.compile(r'(?:\bwith\b\s+(?:recursive\s+)?|,)\s*([A-Za-z_]\w*)\s+as\s*\(', re.IGNORECASE) - - -def _extract_table_names(text: str) -> set[str]: - """ - 全文提取 FROM/JOIN 引用的表名(覆盖子查询体 / CTE 体 / UNION 等嵌套位置) - - :param text: SQL 文本 - :return: 小写裸表名集合 +# 允许的只读根节点类名:SELECT 及集合操作(UNION / INTERSECT / EXCEPT) +_READ_ONLY_ROOT_NAMES = frozenset({'Select', 'Union', 'Intersect', 'Except'}) + +# 任意位置出现即拒绝的写/DDL 节点类名(用类名匹配,兼容不同 sqlglot 版本的节点命名) +_WRITE_NODE_NAMES = frozenset({ + 'Insert', 'Update', 'Delete', 'Drop', 'Alter', 'Create', + 'TruncateTable', 'AlterTable', 'Merge', 'Command', + 'Grant', 'Revoke', 'AddConstraint', 'DropConstraint', +}) + +# 即便引用了白名单表也一律禁止的函数调用(侦察 / DoS / 文件读 / 锁) +# 注:只匹配 exp.Anonymous(函数调用),列名/别名同名不受影响。 +_DANGEROUS_FUNCS = frozenset({ + 'load_file', + 'sleep', + 'benchmark', + 'get_lock', + 'release_lock', + 'user', + 'current_user', + 'system_user', + 'session_user', + 'database', + 'schema', + 'connection_id', + 'version', + 'current_database', + 'current_setting', + 'pg_sleep', + 'pg_read_file', +}) + + +def _table_ref(table: exp.Table, default_schema: str) -> str: + """Table 节点 -> 规范化 ``schema.table``(小写);未显式指定 schema 时用 default_schema。""" + schema = (table.db or default_schema).strip().lower() + return f'{schema}.{table.name.strip().lower()}' + + +def _cte_names(root: exp.Expression) -> set[str]: + """查询内定义的 CTE 名(小写),视为合法"虚拟表"。""" + return {cte.alias.strip().lower() for cte in root.find_all(exp.CTE)} + + +def _find_dangerous(root: exp.Expression) -> str | None: + """返回命中的危险描述(用于报错),未命中返回 None。""" + for node in root.walk(): + if isinstance(node, exp.Parameter): + # @var / @@var —— 系统/会话变量,侦察用,一律拒 + return f'系统/会话变量不被允许: {node.sql()}' + if isinstance(node, exp.Anonymous): + name = (node.name or '').strip().lower() + if name in _DANGEROUS_FUNCS: + return f'危险函数不被允许: {node.name.upper()}()' + return None + + +def _classify_tables( + root: exp.Expression, + *, + allowlist: set[str], + ctes: set[str], + default_schema: str, +) -> tuple[list[str], list[str]]: + """分类所有 Table 引用。 + + :return: (授权的真实表引用, 未授权的表引用);CTE 引用不计入未授权。 """ - names: set[str] = set() - for match in _TABLE_REF_RE.findall(text): - names.add(match.split('.')[-1].strip().lower()) - return names - - -def _extract_cte_names(text: str) -> set[str]: - """提取 CTE 名称(视为查询内合法的“虚拟表”,不要求在 DB 白名单中)""" - return {name.lower() for name in _CTE_NAME_RE.findall(text)} - - -def validate_and_normalize(sql: str, *, allowlist: set[str], max_rows: int) -> str: - """ - 校验并归一化 SQL:仅放行单条只读 SELECT,强制表白名单与 LIMIT + real: list[str] = [] + disallowed: list[str] = [] + seen: set[str] = set() + for table in root.find_all(exp.Table): + ref = _table_ref(table, default_schema) + if ref in seen: + continue + seen.add(ref) + # CTE 引用(虚拟表):跳过;其定义体内的真实表由另一个 Table 节点捕获 + if table.name.strip().lower() in ctes: + continue + if ref in allowlist: + real.append(ref) + else: + disallowed.append(ref) + return real, disallowed + + +def _clamp_limit(query: exp.Expression, max_rows: int) -> None: + """将最外层 LIMIT 钳制到 ≤ max_rows(原地改 AST);缺失则注入。""" + cap = max(1, int(max_rows)) + limit = query.args.get('limit') + if limit is None: + query.set('limit', exp.Limit(expression=exp.Literal.number(cap))) + return + expr = limit.expression # type: ignore[attr-defined] + if isinstance(expr, exp.Literal) and expr.is_int: + if int(expr.name) > cap: + limit.set('expression', exp.Literal.number(cap)) # type: ignore[attr-defined] + return + # 非字面量(表达式/占位符等):保守替换为 cap,不冒险评估 + limit.set('expression', exp.Literal.number(cap)) # type: ignore[attr-defined] + + +def validate_and_normalize( + sql: str, + *, + allowlist: set[str], + max_rows: int, + default_schema: str, + dialect: str | None = None, +) -> str: + """校验并归一化 SQL:仅放行单条只读 SELECT,强制表白名单、禁侦察、LIMIT 钳制。 :param sql: 待校验 SQL - :param allowlist: 允许查询的表名集合(小写裸名) - :param max_rows: 缺省 LIMIT 行数上限 - :return: 归一化后的安全 SQL(缺 LIMIT 时已注入) - :raises UnsafeSqlError: 空 / 多语句 / 非 SELECT / 含 INTO 等 + :param allowlist: 允许查询的 ``schema.table`` 集合(小写) + :param max_rows: 缺省/上限 LIMIT 行数 + :param default_schema: SQL 未显式指定 schema 时使用的默认 schema(即业务库) + :param dialect: sqlglot 方言(``mysql`` / ``postgres`` 等),默认通用 + :return: 归一化后的安全 SQL(已注入或钳制 LIMIT) + :raises UnsafeSqlError: 空 / 解析失败 / 多语句 / 非只读 / 含写节点 / 含危险函数或变量 / 无白名单表 :raises TableNotAllowedError: 引用了不在白名单内的表 """ if not sql or not sql.strip(): raise UnsafeSqlError('SQL 为空') - statements = [s for s in sqlparse.parse(sql) if s.token_first(skip_ws=True, skip_cm=True) is not None] + try: + parsed = sqlglot.parse(sql, read=dialect) + except Exception as exc: # noqa: BLE001 — 解析失败一律 fail-closed + raise UnsafeSqlError(f'SQL 解析失败: {exc}') from exc + + statements = [s for s in parsed if s is not None] if len(statements) != 1: raise UnsafeSqlError('仅允许单条 SQL 语句') - statement = statements[0] - if statement.get_type() != 'SELECT': - raise UnsafeSqlError('仅允许 SELECT 查询') - - text = str(statement) - if _INTO_RE.search(text): - raise UnsafeSqlError('禁止 SELECT ... INTO 等带副作用写操作') - - referenced = _extract_table_names(text) - allowed = {name.lower() for name in allowlist} | _extract_cte_names(text) - disallowed = referenced - allowed + stmt = statements[0] + if type(stmt).__name__ not in _READ_ONLY_ROOT_NAMES: + raise UnsafeSqlError('仅允许只读 SELECT(含 UNION/INTERSECT/EXCEPT)查询') + for node in stmt.walk(): + if type(node).__name__ in _WRITE_NODE_NAMES: + raise UnsafeSqlError(f'仅允许只读查询: 检测到 {type(node).__name__}') + + danger = _find_dangerous(stmt) + if danger: + raise UnsafeSqlError(danger) + + allowlist = {name.strip().lower() for name in allowlist} + ctes = _cte_names(stmt) + real_refs, disallowed = _classify_tables( + stmt, + allowlist=allowlist, + ctes=ctes, + default_schema=default_schema, + ) if disallowed: raise TableNotAllowedError(f'引用了未授权的表: {", ".join(sorted(disallowed))}') + if not real_refs: + raise UnsafeSqlError('查询必须引用至少一张已授权的数据表(禁止无表侦察语句)') - normalized = statement.value.strip().rstrip(';').strip() - if not _HAS_LIMIT_RE.search(normalized): - normalized = f'{normalized}\nLIMIT {max(1, int(max_rows))}' - - return normalized + _clamp_limit(stmt, max_rows) + return stmt.sql(dialect=dialect) diff --git a/text2sql/readonly_db.py b/text2sql/readonly_db.py index c5ddc3f..4de88e4 100644 --- a/text2sql/readonly_db.py +++ b/text2sql/readonly_db.py @@ -4,6 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from backend.common.enums import DataBaseType +from backend.common.exception import errors from backend.common.log import log from backend.core.conf import settings @@ -42,9 +43,11 @@ def get_readonly_engine() -> AsyncEngine: """ 获取只读引擎 - 未配置只读账号时回退主库引擎并告警(此时更依赖服务端 sqlparse 护栏兜底)。 + fail-closed:未配置只读账号时**拒绝回退主库**——执行 LLM 生成的 SQL 必须落在 + 显式配置、仅 SELECT 权限的只读账号上。主库可写,绝不可作为 Text2SQL 的执行目标。 :return: 只读异步引擎 + :raises RequestError: 未配置只读账号 """ global _readonly_engine, _readonly_session_maker if _readonly_engine is not None: @@ -52,21 +55,21 @@ def get_readonly_engine() -> AsyncEngine: url = _readonly_url_or_none() if url is None: - log.warning('AI Text2SQL 未配置只读账号(AI_TEXT2SQL_READONLY_*),回退主库引擎并强制护栏') - from backend.database.db import async_engine as main_engine # noqa: PLC0415 - - _readonly_engine = main_engine - else: - _readonly_engine = create_async_engine( - url, - future=True, - pool_size=5, - max_overflow=10, - pool_timeout=30, - pool_recycle=3600, - pool_pre_ping=True, + raise errors.RequestError( + msg='AI Text2SQL 未配置只读账号(AI_TEXT2SQL_READONLY_HOST/USER/PASSWORD),' + '拒绝执行:执行 LLM 生成的 SQL 必须使用仅 SELECT 权限的只读账号', ) + _readonly_engine = create_async_engine( + url, + future=True, + pool_size=5, + max_overflow=10, + pool_timeout=30, + pool_recycle=3600, + pool_pre_ping=True, + ) + _readonly_session_maker = async_sessionmaker( bind=_readonly_engine, class_=AsyncSession, From c2880b5e33d36c21fbe08c867826ed00ad36fbfa Mon Sep 17 00:00:00 2001 From: dengjingren Date: Thu, 25 Jun 2026 23:08:29 +0800 Subject: [PATCH 3/4] refactor(text2sql): single-shot agent, drop bespoke tool loop (M1) Replace the bespoke 3-tool pydantic-ai Agent (list_tables/describe_table/ execute_sql) with a single structured-output call (output_type=Text2SqlResult, no tools). Table/column context is pre-fetched server-side and inlined into the system prompt. Why: the tool loop duplicated the plugin's existing capability/builtin_toolset pipeline (reviewer 'major reject reason'). The execute_sql self-correction loop was redundant -- the final SQL is re-guarded + re-executed by _execute_final regardless. Net effect: smaller attack surface (the Agent can no longer execute SQL at all), one LLM round-trip instead of N. run_query return contract unchanged. _resolve_model / _execute_final / _write_history untouched. Drops now-unused imports (json, guardrail exceptions). --- text2sql/engine.py | 114 +++++++++++++++------------------------------ 1 file changed, 38 insertions(+), 76 deletions(-) diff --git a/text2sql/engine.py b/text2sql/engine.py index 76dbe1d..a7dbf4e 100644 --- a/text2sql/engine.py +++ b/text2sql/engine.py @@ -1,15 +1,15 @@ -"""Text2SQL 核心引擎(原生 Pydantic AI)。 +"""Text2SQL 核心引擎(单次结构化输出,无工具循环)。 -流程:解析已挑选表 + few-shot → 复用供应商适配器建 pydantic-ai 模型 → 构建 Agent -(system_prompt + list_tables/describe_table/execute_sql 工具,execute_sql 强制过护栏) -→ 运行得到 {sql, summary} → 对最终 SQL 再次过护栏并在只读引擎执行取数 → 落历史。 +流程:预取已选表列信息 + few-shot → 内联进 system_prompt → 复用供应商适配器建 +pydantic-ai 模型 → 单次 Agent.run 得到 {sql, summary}(Agent 不执行任何 SQL)→ +对最终 SQL 过 sqlglot 护栏并在只读引擎执行取数 → 落历史。 -注意:pydantic-ai 为 2.x beta,本模块用稳定的装饰器 API(@agent.system_prompt / @agent.tool), -规避构造参数命名差异;上线前务必用真实模型跑通一次(见 README/计划)。 +安全边界全部在执行层:Agent 产出的 SQL 不被信任,由 _execute_final 重新过护栏 + +只读账号执行。如此去掉了原先 bespoke 的 list_tables/describe_table/execute_sql +工具循环(与插件既有 capability/tool 体系重复),并收紧攻击面(Agent 再也碰不到执行)。 """ import asyncio -import json import time from collections.abc import Sequence from dataclasses import dataclass @@ -32,7 +32,6 @@ from backend.plugin.ai.model import AIText2SqlHistory, AIText2SqlTable from backend.plugin.ai.providers.registry import get_provider_adapter from backend.plugin.ai.service.default_model_service import ai_default_model_service -from backend.plugin.ai.text2sql.exceptions import TableNotAllowedError, UnsafeSqlError from backend.plugin.ai.text2sql.guardrails import validate_and_normalize from backend.plugin.ai.text2sql.readonly_db import get_readonly_session from backend.plugin.ai.text2sql.schema_meta import get_columns @@ -40,16 +39,12 @@ @dataclass class Text2SqlDeps: - """Agent 运行依赖""" + """Agent 运行依赖(执行层的只读会话/白名单/超时由 _execute_final 持有,不入 Agent)""" - readonly_session: AsyncSession schema: str - tables: list[tuple[str, str | None]] # 启用的已选表 (name, 描述) - allowlist: set[str] # 授权的 'schema.table'(小写) + tables_info: list[tuple[str, str | None]] # 启用的已选表 (name, 描述) + columns_map: dict[str, Any] # table_name -> 列信息(预取,内联进 system_prompt) examples: list[dict[str, str]] - max_rows: int - timeout: int - dialect: str | None # sqlglot 方言(mysql/postgres) class Text2SqlResult(BaseModel): @@ -60,25 +55,37 @@ class Text2SqlResult(BaseModel): def _build_prompt(deps: Text2SqlDeps) -> str: - tables = '\n'.join(f'- {name}:{comment or "(无注释)"}' for name, comment in deps.tables) or '- (无)' + blocks = [] + for name, comment in deps.tables_info: + cols = deps.columns_map.get(name) or [] + col_lines = '\n'.join( + f" - {row['column_name']} {row['column_type']}{' [PK]' if row['is_pk'] else ''}" + f":{row['column_comment'] or ''}" + for row in cols + ) or ' - (无列信息)' + blocks.append(f'- {name}({comment or "无注释"}):\n{col_lines}') + tables = '\n'.join(blocks) or '- (无)' examples = '' if deps.examples: - blocks = '\n'.join(f'问题:{e["question"]}\nSQL:{e["sql"]}' for e in deps.examples) - examples = f'\n\n参考示例:\n{blocks}' + ex_blocks = '\n'.join(f'问题:{e["question"]}\nSQL:{e["sql"]}' for e in deps.examples) + examples = f'\n\n参考示例:\n{ex_blocks}' return ( '你是 FBA Text2SQL 助手,将用户的自然语言问题转为只读 SQL,并给出结果摘要。\n' f'数据库 schema:{deps.schema}\n' - f'仅可查询以下表(严禁查询其他表):\n{tables}{examples}\n\n' + f'仅可查询以下表及其列(严禁查询其他表、系统表、信息架构):\n{tables}{examples}\n\n' '规则:\n' - '1. 只生成只读 SELECT(禁止 INSERT/UPDATE/DELETE/DDL/多语句/SELECT INTO)。\n' - '2. 只查询上面列出的表。\n' - '3. 需要列信息时调用 describe_table;验证查询时调用 execute_sql。\n' - '4. 若 execute_sql 报错,请根据错误修正 SQL 后重试。\n' - '5. 最终输出包含字段:sql(最终 SQL)与 summary(对结果的中文摘要)。' + '1. 只生成单条只读 SELECT(禁止 INSERT/UPDATE/DELETE/DDL/多语句/SELECT INTO)。\n' + '2. 只查询上面列出的表,按列名与类型构造 SQL。\n' + '3. 最终输出包含字段:sql(最终 SQL)与 summary(对结果的中文摘要)。' ) def _build_agent(model: Any) -> Agent: # type: ignore[type-arg] + """单次结构化输出 Agent:无工具,列信息已内联进 system_prompt。 + + Agent 不执行任何 SQL——其产出的 sql 字段会被 _execute_final 重新过护栏 + 只读执行, + 因此无需(也不应)在此再提供 execute_sql 工具或自纠正循环。 + """ agent = Agent( model=model, deps_type=Text2SqlDeps, @@ -90,52 +97,6 @@ def _build_agent(model: Any) -> Agent: # type: ignore[type-arg] def _system_prompt(ctx: RunContext[Text2SqlDeps]) -> str: return _build_prompt(ctx.deps) - @agent.tool - async def list_tables(ctx: RunContext[Text2SqlDeps]) -> str: - items = '\n'.join(f'- {name}:{comment or "(无注释)"}' for name, comment in ctx.deps.tables) - return f'可查询表:\n{items}' - - @agent.tool - async def describe_table(ctx: RunContext[Text2SqlDeps], table_name: str) -> str: - ref = f'{ctx.deps.schema.strip().lower()}.{table_name.strip().lower()}' - if ref not in ctx.deps.allowlist: - return f'表 {table_name} 不在可查询范围内' - rows = await get_columns(ctx.deps.readonly_session, ctx.deps.schema, table_name) - lines = [ - f"- {row['column_name']} {row['column_type']}{' [PK]' if row['is_pk'] else ''}:{row['column_comment'] or ''}" - for row in rows - ] - return f'表 {table_name} 列:\n' + '\n'.join(lines) - - @agent.tool - async def execute_sql(ctx: RunContext[Text2SqlDeps], sql: str) -> str: - try: - safe = validate_and_normalize( - sql, - allowlist=ctx.deps.allowlist, - max_rows=ctx.deps.max_rows, - default_schema=ctx.deps.schema, - dialect=ctx.deps.dialect, - ) - except (UnsafeSqlError, TableNotAllowedError) as exc: - return f'校验失败,请改写:{exc}' - try: - result = await asyncio.wait_for( - ctx.deps.readonly_session.execute(sa_text(safe)), - timeout=ctx.deps.timeout, - ) - except TimeoutError: - return '执行超时,请简化查询' - except Exception as exc: # noqa: BLE001 - return f'执行失败,请检查 SQL:{exc}' - rows = result.mappings().all() - columns = list(rows[0].keys()) if rows else [] - preview = [dict(row) for row in rows[:20]] - return ( - f'列:{columns}\n命中行数:{len(rows)}\n' - f'预览(前 {len(preview)} 行):\n{json.dumps(preview, ensure_ascii=False, default=str)}' - ) - return agent @@ -263,15 +224,16 @@ async def run_query( try: agent = _build_agent(session.model) async with get_readonly_session() as readonly_session: + # 预取列信息内联进 prompt(替代原 describe_table 工具);Agent 全程不执行任何 SQL + columns_map = { + table.table_name: await get_columns(readonly_session, schema, table.table_name) + for table in selected_tables + } deps = Text2SqlDeps( - readonly_session=readonly_session, schema=schema, - tables=tables_info, - allowlist=allowlist, + tables_info=tables_info, + columns_map=columns_map, examples=examples, - max_rows=max_rows, - timeout=timeout, - dialect=dialect, ) run_result = await agent.run(question, deps=deps) final_sql = run_result.output.sql From 89feedcd3a000b8eebbe929d581d1cb47f4a0dc4 Mon Sep 17 00:00:00 2001 From: dengjingren Date: Thu, 25 Jun 2026 23:49:12 +0800 Subject: [PATCH 4/4] fix(text2sql): broaden tool description to cue log/count/stat queries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The chat capability tool text2sql_query was described as 'FBA 业务数据' with order/supplier examples, so the model did not invoke it for log/count questions (e.g. 'how many operation logs today'). Broaden the description to explicitly cue logs/counts/stats and state that any database-data question should prefer this tool. --- capabilities/text2sql.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/capabilities/text2sql.py b/capabilities/text2sql.py index 5096a0e..3e2bf33 100644 --- a/capabilities/text2sql.py +++ b/capabilities/text2sql.py @@ -65,11 +65,16 @@ def _build_text2sql_toolset(*, provider_id: int, model_id: str, dataset_id: int) @toolset.tool async def text2sql_query(ctx: RunContext[ChatAgentDeps], question: str) -> str: """ - 用自然语言查询当前所选数据集中的 FBA 业务数据(Text2SQL)。 + 查询系统数据库中已授权的业务/运营数据(Text2SQL:自然语言转 SQL 只读查询)。 - 仅对该数据集中启用的表执行只读查询,全程受安全护栏保护。 - 适合回答统计、聚合、明细类数据问题,例如: - 「最近 7 天订单总金额是多少」「每个供应商有多少条记录」「按金额倒序的前 10 笔订单」。 + 只要用户的问题涉及"数据库里的数据"——日志、记录、订单、用户、统计、计数、 + 排名、趋势、明细、聚合等——都应优先调用本工具,而不是让用户去接外部系统。 + 仅查询当前所选数据集中已启用的表,全程只读、受安全护栏保护。 + + 典型问题(均应调用本工具): + - 「今天有多少条操作日志 / 登录日志」「最近的日志记录」 + - 「最近 7 天订单总金额是多少」「每个供应商有多少条记录」 + - 「按金额倒序的前 10 笔订单」「今天/本周/某段时间内 XXX 的数量/计数/汇总」 :param question: 用户的自然语言数据问题 :return: