diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 8a2565c41d..7a6608303b 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -666,25 +666,56 @@ async def _start_mcp_server( if shutdown_event is None: shutdown_event = asyncio.Event() - mcp_client: MCPClient | None = None - try: - mcp_client = await asyncio.wait_for( - self._init_mcp_client(name, cfg), - timeout=timeout, + mcp_client = MCPClient() + mcp_client.name = name + + connect_done = asyncio.Event() + connect_error: BaseException | None = None + + async def connect_and_lifecycle() -> None: + # Single task that handles connect, lifecycle, and cleanup. + + nonlocal connect_error + try: + await mcp_client.connect_to_server(cfg, name) + await mcp_client.list_tools_and_save() + except asyncio.CancelledError: + # cleanup on cancellation + try: + await mcp_client.cleanup() + except BaseException: + pass + raise + except Exception as e: + connect_error = e + try: + await mcp_client.cleanup() + except Exception: + pass + connect_done.set() + return + + # Register tools + self.func_list = [ + f + for f in self.func_list + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) + ] + for tool in mcp_client.tools: + func_tool = MCPTool( + mcp_tool=tool, + mcp_client=mcp_client, + mcp_server_name=name, + ) + self.func_list.append(func_tool) + + logger.info( + f"Connected to MCP server {name}, " + f"Tools: {[t.name for t in mcp_client.tools]}" ) - except asyncio.TimeoutError as exc: - raise MCPInitTimeoutError( - f"Connected to MCP server {name} timeout ({timeout:g} seconds)" - ) from exc - except Exception: - logger.error(f"Failed to initialize MCP client {name}", exc_info=True) - raise - finally: - if mcp_client is None: - async with self._runtime_lock: - self._mcp_starting.discard(name) - async def lifecycle() -> None: + connect_done.set() + try: await shutdown_event.wait() logger.info(f"Received shutdown signal for MCP client {name}") @@ -692,9 +723,12 @@ async def lifecycle() -> None: logger.debug(f"MCP client {name} task was cancelled") raise finally: - await self._terminate_mcp_client(name) + # Cleanup in the same task that entered the anyio contexts + await asyncio.shield(self._terminate_mcp_client(name)) - lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}") + lifecycle_task = asyncio.create_task( + connect_and_lifecycle(), name=f"mcp-client:{name}" + ) async with self._runtime_lock: self._mcp_server_runtime[name] = _MCPServerRuntime( name=name, @@ -704,6 +738,26 @@ async def lifecycle() -> None: ) self._mcp_starting.discard(name) + try: + await asyncio.wait_for(connect_done.wait(), timeout=timeout) + except (asyncio.TimeoutError, asyncio.CancelledError) as e: + lifecycle_task.cancel() + await asyncio.gather(lifecycle_task, return_exceptions=True) + async with self._runtime_lock: + self._mcp_starting.discard(name) + self._mcp_server_runtime.pop(name, None) + if isinstance(e, asyncio.TimeoutError): + raise MCPInitTimeoutError( + f"Connected to MCP server {name} timeout ({timeout:g} seconds)" + ) from e + raise + + if connect_error is not None: + async with self._runtime_lock: + self._mcp_starting.discard(name) + self._mcp_server_runtime.pop(name, None) + raise connect_error + async def _shutdown_runtimes( self, runtimes: list[_MCPServerRuntime], @@ -768,41 +822,6 @@ async def _cleanup_mcp_client_safely( f"Failed to cleanup MCP client resources {name}: {cleanup_exc}" ) - async def _init_mcp_client(self, name: str, config: dict) -> MCPClient: - """初始化单个MCP客户端""" - mcp_client = MCPClient() - mcp_client.name = name - try: - await mcp_client.connect_to_server(config, name) - tools_res = await mcp_client.list_tools_and_save() - except asyncio.CancelledError: - await self._cleanup_mcp_client_safely(mcp_client, name) - raise - except Exception: - await self._cleanup_mcp_client_safely(mcp_client, name) - raise - logger.debug(f"MCP server {name} list tools response: {tools_res}") - tool_names = [tool.name for tool in tools_res.tools] - - # 移除该MCP服务之前的工具(如有) - self.func_list = [ - f - for f in self.func_list - if not (isinstance(f, MCPTool) and f.mcp_server_name == name) - ] - - # 将 MCP 工具转换为 FuncTool 并添加到 func_list - for tool in mcp_client.tools: - func_tool = MCPTool( - mcp_tool=tool, - mcp_client=mcp_client, - mcp_server_name=name, - ) - self.func_list.append(func_tool) - - logger.info(f"Connected to MCP server {name}, Tools: {tool_names}") - return mcp_client - async def _terminate_mcp_client(self, name: str) -> None: """关闭并清理MCP客户端""" async with self._runtime_lock: