Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 73 additions & 54 deletions astrbot/core/provider/func_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,35 +666,69 @@ async def _start_mcp_server(
if shutdown_event is None:
shutdown_event = asyncio.Event()

mcp_client: MCPClient | None = None
try:
mcp_client = await asyncio.wait_for(
self._init_mcp_client(name, cfg),
timeout=timeout,
mcp_client = MCPClient()
mcp_client.name = name

connect_done = asyncio.Event()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider replacing the connect_done/connect_error pattern with a single connect Future and centralized cleanup to simplify the MCP client startup lifecycle and error handling.

The main added complexity comes from connect_done + connect_error plus duplicated cleanup inside connect_and_lifecycle. You can keep the “single task owns anyio contexts and cleanup” behavior while simplifying the handshake and reusing the existing cleanup helper.

1. Replace connect_done / connect_error with a single Future

Use a single Future to signal completion/failure of the connect phase, instead of an Event + shared exception variable:

mcp_client = MCPClient()
mcp_client.name = name

loop = asyncio.get_running_loop()
connect_future: asyncio.Future[None] = loop.create_future()

async def connect_and_lifecycle() -> None:
    try:
        await mcp_client.connect_to_server(cfg, name)
        await mcp_client.list_tools_and_save()
    except asyncio.CancelledError as e:
        # Ensure connect_future is completed
        if not connect_future.done():
            connect_future.set_exception(e)
        try:
            await self._cleanup_mcp_client_safely(mcp_client, name)
        except BaseException:
            pass
        raise
    except Exception as e:
        if not connect_future.done():
            connect_future.set_exception(e)
        try:
            await self._cleanup_mcp_client_safely(mcp_client, name)
        except BaseException:
            pass
        return

    # Register tools (same as current behavior)
    self.func_list = [
        f
        for f in self.func_list
        if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
    ]
    for tool in mcp_client.tools:
        func_tool = MCPTool(
            mcp_tool=tool,
            mcp_client=mcp_client,
            mcp_server_name=name,
        )
        self.func_list.append(func_tool)

    logger.info(
        f"Connected to MCP server {name}, "
        f"Tools: {[t.name for t in mcp_client.tools]}"
    )

    # Signal successful connect
    if not connect_future.done():
        connect_future.set_result(None)

    try:
        await shutdown_event.wait()
        logger.info(f"Received shutdown signal for MCP client {name}")
    except asyncio.CancelledError:
        logger.debug(f"MCP client {name} task was cancelled")
        raise
    finally:
        await asyncio.shield(self._terminate_mcp_client(name))

Then the outer timeout / error handling becomes simpler and more idiomatic:

lifecycle_task = asyncio.create_task(
    connect_and_lifecycle(), name=f"mcp-client:{name}"
)

async with self._runtime_lock:
    self._mcp_server_runtime[name] = _MCPServerRuntime(
        name=name,
        client=mcp_client,
        shutdown_event=shutdown_event,
        lifecycle_task=lifecycle_task,
    )
    self._mcp_starting.discard(name)

try:
    await asyncio.wait_for(connect_future, timeout=timeout)
except asyncio.TimeoutError as exc:
    lifecycle_task.cancel()
    await asyncio.gather(lifecycle_task, return_exceptions=True)
    async with self._runtime_lock:
        self._mcp_starting.discard(name)
        self._mcp_server_runtime.pop(name, None)
    raise MCPInitTimeoutError(
        f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
    ) from exc
except asyncio.CancelledError:
    lifecycle_task.cancel()
    await asyncio.gather(lifecycle_task, return_exceptions=True)
    async with self._runtime_lock:
        self._mcp_starting.discard(name)
        self._mcp_server_runtime.pop(name, None)
    raise
except Exception:
    lifecycle_task.cancel()
    await asyncio.gather(lifecycle_task, return_exceptions=True)
    async with self._runtime_lock:
        self._mcp_starting.discard(name)
        self._mcp_server_runtime.pop(name, None)
    raise

This keeps:

  • A single task owning connection, tool registration, shutdown wait, and _terminate_mcp_client.
  • Precise timeout on the connect phase.
  • Centralized cleanup via _cleanup_mcp_client_safely and _terminate_mcp_client.

But removes the manual connect_done signaling and connect_error shared state, making the control flow and error propagation easier to reason about.

connect_error: BaseException | None = None

async def connect_and_lifecycle() -> None:
# Single task that handles connect, lifecycle, and cleanup.

nonlocal connect_error
try:
await mcp_client.connect_to_server(cfg, name)
await mcp_client.list_tools_and_save()
except asyncio.CancelledError:
# cleanup on cancellation
try:
Comment on lines +675 to +684

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Tool registration and logging errors won’t propagate via connect_error, leaving the caller to see only a timeout.

Because the try ends at await mcp_client.list_tools_and_save(), any exception during tool registration or logging won’t set connect_error or connect_done. The caller will then hit the timeout and raise MCPInitTimeoutError instead of the real failure. Please extend the try to include the registration/logging block, or add a dedicated try/except there that sets connect_error and signals connect_done on error.

await mcp_client.cleanup()
except BaseException:
pass
raise
except Exception as e:
connect_error = e
try:
await mcp_client.cleanup()
except Exception:
pass
connect_done.set()
return

# Register tools
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
for tool in mcp_client.tools:
func_tool = MCPTool(
mcp_tool=tool,
mcp_client=mcp_client,
mcp_server_name=name,
)
self.func_list.append(func_tool)

logger.info(
f"Connected to MCP server {name}, "
f"Tools: {[t.name for t in mcp_client.tools]}"
)
except asyncio.TimeoutError as exc:
raise MCPInitTimeoutError(
f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
) from exc
except Exception:
logger.error(f"Failed to initialize MCP client {name}", exc_info=True)
raise
finally:
if mcp_client is None:
async with self._runtime_lock:
self._mcp_starting.discard(name)

async def lifecycle() -> None:
connect_done.set()

try:
await shutdown_event.wait()
logger.info(f"Received shutdown signal for MCP client {name}")
except asyncio.CancelledError:
logger.debug(f"MCP client {name} task was cancelled")
raise
finally:
await self._terminate_mcp_client(name)
# Cleanup in the same task that entered the anyio contexts
await asyncio.shield(self._terminate_mcp_client(name))

lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
lifecycle_task = asyncio.create_task(
connect_and_lifecycle(), name=f"mcp-client:{name}"
)
async with self._runtime_lock:
self._mcp_server_runtime[name] = _MCPServerRuntime(
name=name,
Expand All @@ -704,6 +738,26 @@ async def lifecycle() -> None:
)
self._mcp_starting.discard(name)

try:
await asyncio.wait_for(connect_done.wait(), timeout=timeout)
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
lifecycle_task.cancel()
await asyncio.gather(lifecycle_task, return_exceptions=True)
async with self._runtime_lock:
self._mcp_starting.discard(name)
self._mcp_server_runtime.pop(name, None)
if isinstance(e, asyncio.TimeoutError):
raise MCPInitTimeoutError(
f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
) from e
raise

if connect_error is not None:
async with self._runtime_lock:
self._mcp_starting.discard(name)
self._mcp_server_runtime.pop(name, None)
raise connect_error

async def _shutdown_runtimes(
self,
runtimes: list[_MCPServerRuntime],
Expand Down Expand Up @@ -768,41 +822,6 @@ async def _cleanup_mcp_client_safely(
f"Failed to cleanup MCP client resources {name}: {cleanup_exc}"
)

async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
"""初始化单个MCP客户端"""
mcp_client = MCPClient()
mcp_client.name = name
try:
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
except asyncio.CancelledError:
await self._cleanup_mcp_client_safely(mcp_client, name)
raise
except Exception:
await self._cleanup_mcp_client_safely(mcp_client, name)
raise
logger.debug(f"MCP server {name} list tools response: {tools_res}")
tool_names = [tool.name for tool in tools_res.tools]

# 移除该MCP服务之前的工具(如有)
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]

# 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools:
func_tool = MCPTool(
mcp_tool=tool,
mcp_client=mcp_client,
mcp_server_name=name,
)
self.func_list.append(func_tool)

logger.info(f"Connected to MCP server {name}, Tools: {tool_names}")
return mcp_client

async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
async with self._runtime_lock:
Expand Down