Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 98 additions & 14 deletions agentrun/tool/api/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

import httpx

from agentrun.tool.model import ToolInfo, ToolSchema
from agentrun.tool.model import ToolInfo
from agentrun.utils.config import Config
from agentrun.utils.log import logger
from agentrun.utils.ram_signature import get_agentrun_signed_headers

_MCP_METADATA_TIMEOUT_SECONDS = 30.0


def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
Expand All @@ -30,9 +33,6 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
return loop


from agentrun.utils.ram_signature import get_agentrun_signed_headers


class _AgentrunRamAuth(httpx.Auth):
"""httpx Auth handler:为每次请求动态生成 RAM 签名。

Expand Down Expand Up @@ -144,6 +144,54 @@ def is_streamable(self) -> bool:
"""是否使用 Streamable HTTP 传输 / Whether to use Streamable HTTP transport"""
return self.session_affinity == "MCP_STREAMABLE"

def _metadata_timeout_seconds(self) -> float:
timeout = self.config.get_timeout()
if timeout and timeout > 0:
return min(float(timeout), _MCP_METADATA_TIMEOUT_SECONDS)
return _MCP_METADATA_TIMEOUT_SECONDS

def _invoke_timeout_seconds(self) -> float:
timeout = self.config.get_timeout()
if timeout and timeout > 0:
return float(timeout)
return 600.0

async def _wait_for_mcp_request(
self,
awaitable: Any,
operation: str,
timeout: float,
) -> Any:
try:
return await asyncio.wait_for(awaitable, timeout=timeout)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"MCP {operation} timed out after {timeout:g}s for endpoint"
f" {self.endpoint}"
) from exc

def _find_mcp_timeout_error(
self, exc: BaseException
) -> Optional[TimeoutError]:
if isinstance(exc, TimeoutError) and str(exc).startswith("MCP "):
return exc

nested_exceptions = getattr(exc, "exceptions", None)
if not nested_exceptions:
return None

for nested_exc in nested_exceptions:
timeout_error = self._find_mcp_timeout_error(nested_exc)
if timeout_error is not None:
return timeout_error

return None

def _raise_mcp_timeout_if_present(self, exc: BaseException) -> None:
timeout_error = self._find_mcp_timeout_error(exc)
if timeout_error is not None:
raise timeout_error

def _build_ram_auth(self, url: str) -> tuple:
"""当目标是 agentrun-data 域名时,改写 URL 并返回 httpx Auth handler。

Expand Down Expand Up @@ -199,8 +247,17 @@ async def list_tools_async(self) -> List[ToolInfo]:
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.list_tools()
metadata_timeout = self._metadata_timeout_seconds()
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
metadata_timeout,
)
result = await self._wait_for_mcp_request(
session.list_tools(),
"list_tools",
metadata_timeout,
)
return [
ToolInfo.from_mcp_tool(tool)
for tool in result.tools
Expand All @@ -215,8 +272,17 @@ async def list_tools_async(self) -> List[ToolInfo]:
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.list_tools()
metadata_timeout = self._metadata_timeout_seconds()
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
metadata_timeout,
)
result = await self._wait_for_mcp_request(
session.list_tools(),
"list_tools",
metadata_timeout,
)
return [
ToolInfo.from_mcp_tool(tool)
for tool in result.tools
Expand All @@ -226,6 +292,9 @@ async def list_tools_async(self) -> List[ToolInfo]:
"mcp package is not installed. Install it with: pip install mcp"
)
return []
except Exception as exc:
self._raise_mcp_timeout_if_present(exc)
raise

def list_tools(self) -> List[ToolInfo]:
"""同步获取工具列表 / Get tool list synchronously
Expand Down Expand Up @@ -266,9 +335,15 @@ async def call_tool_async(
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.call_tool(
name, arguments=arguments or {}
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
self._metadata_timeout_seconds(),
)
result = await self._wait_for_mcp_request(
session.call_tool(name, arguments=arguments or {}),
f"call_tool {name}",
self._invoke_timeout_seconds(),
)
return result
else:
Expand All @@ -281,16 +356,25 @@ async def call_tool_async(
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.call_tool(
name, arguments=arguments or {}
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
self._metadata_timeout_seconds(),
)
result = await self._wait_for_mcp_request(
session.call_tool(name, arguments=arguments or {}),
f"call_tool {name}",
self._invoke_timeout_seconds(),
)
return result
except ImportError:
raise ImportError(
"mcp package is required for MCP tool calls. "
"Install it with: pip install mcp"
)
except Exception as exc:
self._raise_mcp_timeout_if_present(exc)
raise

def call_tool(
self,
Expand Down
87 changes: 87 additions & 0 deletions tests/e2e/test_mcp_malformed_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""E2E regression tests for malformed MCP streamable-http responses."""

import asyncio
import socket
import threading
import time

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import httpx
import pytest
import uvicorn

from agentrun.tool.api.mcp import ToolMCPSession
from agentrun.utils.config import Config


def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]


def _build_malformed_mcp_app() -> FastAPI:
app = FastAPI()

@app.get("/health")
async def health():
return {"ok": True}

@app.post("/mcp")
async def mcp_endpoint(request: Request):
payload = await request.json()
return JSONResponse(
{
"jsonrpc": "2.0",
"id": payload.get("id"),
"error": {
"code": -32000,
"message": None,
},
}
)

return app


@pytest.fixture
def malformed_mcp_server():
app = _build_malformed_mcp_app()
port = _find_free_port()
config = uvicorn.Config(
app, host="127.0.0.1", port=port, log_level="warning"
)
server = uvicorn.Server(config)

thread = threading.Thread(target=server.run, daemon=True)
thread.start()

base_url = f"http://127.0.0.1:{port}"
for _ in range(50):
try:
httpx.get(f"{base_url}/health", timeout=0.2)
break
except Exception:
time.sleep(0.1)
else:
raise RuntimeError("malformed MCP server did not start")

yield f"{base_url}/mcp"

server.should_exit = True
thread.join(timeout=5)


@pytest.mark.asyncio
async def test_streamable_mcp_malformed_initialize_error_fails_fast(
malformed_mcp_server,
):
session = ToolMCPSession(
endpoint=malformed_mcp_server,
session_affinity="MCP_STREAMABLE",
config=Config(timeout=0.05),
)

with pytest.raises(TimeoutError, match="MCP initialize timed out"):
await asyncio.wait_for(session.list_tools_async(), timeout=1)
59 changes: 58 additions & 1 deletion tests/unittests/tool/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
Tests MCP protocol interaction functionality of ToolMCPSession.
"""

import asyncio
import sys
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentrun.tool.api.mcp import ToolMCPSession
from agentrun.tool.model import ToolInfo
from agentrun.utils.config import Config


class TestToolMCPSessionInit:
Expand Down Expand Up @@ -186,6 +188,36 @@ def mock_import(name, *args, **kwargs):
sys.modules.update(saved_modules)
assert tools == []

@pytest.mark.asyncio
async def test_list_tools_async_initialize_timeout(self):
"""测试 initialize 无响应时不会无限等待"""

async def never_return():
await asyncio.Event().wait()

mock_session = AsyncMock()
mock_session.initialize = never_return
mock_session.list_tools = AsyncMock()

mock_modules = _setup_mock_mcp_modules(mock_session)

with patch.dict(sys.modules, mock_modules):
with patch(
"agentrun.tool.api.mcp._MCP_METADATA_TIMEOUT_SECONDS",
0.01,
):
session = ToolMCPSession(
endpoint="http://example.com/mcp",
session_affinity="MCP_STREAMABLE",
)

with pytest.raises(
TimeoutError, match="MCP initialize timed out"
):
await session.list_tools_async()

mock_session.list_tools.assert_not_called()


class TestToolMCPSessionListTools:
"""测试 list_tools 同步方法"""
Expand Down Expand Up @@ -258,6 +290,31 @@ async def test_call_tool_async_sse_mode(self):

assert result == mock_call_result

@pytest.mark.asyncio
async def test_call_tool_async_timeout(self):
"""测试工具调用无响应时会按 Config.timeout 退出"""

async def never_return(*args, **kwargs):
await asyncio.Event().wait()

mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.call_tool = never_return

mock_modules = _setup_mock_mcp_modules(mock_session)

with patch.dict(sys.modules, mock_modules):
session = ToolMCPSession(
endpoint="http://example.com/mcp",
session_affinity="MCP_STREAMABLE",
config=Config(timeout=0.01),
)

with pytest.raises(
TimeoutError, match="MCP call_tool test_tool timed out"
):
await session.call_tool_async("test_tool", {"key": "val"})

@pytest.mark.asyncio
async def test_call_tool_async_import_error(self):
"""测试 mcp 未安装时抛出 ImportError"""
Expand Down
Loading