From 560de5cab5f897ed878a3c77c4257b0b8cdeff11 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Sun, 17 May 2026 00:37:39 +0200 Subject: [PATCH 1/2] Validate full sampling tool result history --- src/mcp/server/validation.py | 47 +++++++------- tests/server/test_session.py | 112 ++++++++++++++++++++++++++++++++ tests/server/test_validation.py | 85 ++++++++++++++++++++++++ 3 files changed, 222 insertions(+), 22 deletions(-) diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py index 08f5754f1..4076464b6 100644 --- a/src/mcp/server/validation.py +++ b/src/mcp/server/validation.py @@ -4,7 +4,7 @@ """ from mcp.shared.exceptions import MCPError -from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, Tool, ToolChoice +from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, SamplingMessageContentBlock, Tool, ToolChoice def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> bool: @@ -52,6 +52,7 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None: 1. Messages with tool_result content contain ONLY tool_result content 2. tool_result messages are preceded by a message with tool_use 3. tool_result IDs match the tool_use IDs from the previous message + 4. Every tool_use message in the history is followed by matching tool_result content See: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577 @@ -64,24 +65,26 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None: if not messages: return - last_content = messages[-1].content_as_list - has_tool_results = any(c.type == "tool_result" for c in last_content) - - previous_content = messages[-2].content_as_list if len(messages) >= 2 else None - has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content) - - if has_tool_results: - # Per spec: "SamplingMessage with tool result content blocks - # MUST NOT contain other content types." - if any(c.type != "tool_result" for c in last_content): - raise ValueError("The last message must contain only tool_result content if any is present") - if previous_content is None: - raise ValueError("tool_result requires a previous message containing tool_use") - if not has_previous_tool_use: - raise ValueError("tool_result blocks do not match any tool_use in the previous message") - - if has_previous_tool_use and previous_content: - tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"} - tool_result_ids = {c.tool_use_id for c in last_content if c.type == "tool_result"} - if tool_use_ids != tool_result_ids: - raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match") + previous_content: list[SamplingMessageContentBlock] | None = None + for content in (message.content_as_list for message in messages): + has_tool_results = any(c.type == "tool_result" for c in content) + previous_tool_use_ids: set[str] = set() + if previous_content is not None: + previous_tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"} + + if has_tool_results: + # Per spec: "SamplingMessage with tool result content blocks + # MUST NOT contain other content types." + if any(c.type != "tool_result" for c in content): + raise ValueError("A message must contain only tool_result content if any is present") + if previous_content is None: + raise ValueError("tool_result requires a previous message containing tool_use") + if not previous_tool_use_ids: + raise ValueError("tool_result blocks do not match any tool_use in the previous message") + + if previous_tool_use_ids: + tool_result_ids = {c.tool_use_id for c in content if c.type == "tool_result"} + if previous_tool_use_ids != tool_result_ids: + raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match") + + previous_content = content diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a713a79b6..7e66e6e38 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -125,6 +125,118 @@ async def test_send_request_without_back_channel_or_related_id_fails_fast(): assert dispatcher.requests[0][3] == 3 +@pytest.mark.anyio +async def test_create_message_tool_result_validation(): + """Test tool_use/tool_result validation in create_message.""" + dispatcher = StubDispatcher( + result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"} + ) + session = _make_session( + dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + ) + tool = types.Tool(name="test_tool", input_schema={"type": "object"}) + text = types.TextContent(type="text", text="hello") + tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={}) + tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[]) + + # Case 1: tool_result mixed with other content + with pytest.raises(ValueError, match="only tool_result content"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage(role="user", content=[tool_result, text]), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 2: tool_result without previous message + with pytest.raises(ValueError, match="requires a previous message"): + await session.create_message( + messages=[types.SamplingMessage(role="user", content=tool_result)], + max_tokens=100, + tools=[tool], + ) + + # Case 3: tool_result without previous tool_use + with pytest.raises(ValueError, match="do not match any tool_use"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 4: mismatched tool IDs + with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage( + role="user", + content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]), + ), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 4b: earlier mismatched tool result with a later plain message + with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"): + await session.create_message( + messages=[ + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage( + role="user", + content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]), + ), + types.SamplingMessage(role="assistant", content=text), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 5: text-only message with tools (no tool_results) - passes validation + await session.create_message( + messages=[types.SamplingMessage(role="user", content=text)], + max_tokens=100, + tools=[tool], + ) + + # Case 6: valid matching tool_result/tool_use IDs - passes validation + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 7: validation runs even without `tools` parameter + # (tool loop continuation may omit tools while containing tool_result) + with pytest.raises(ValueError, match="do not match any tool_use"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + ) + + # Case 8: empty messages list - skips validation entirely + no_tools_session = _make_session( + StubDispatcher(result={"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"}), + capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + ) + await no_tools_session.create_message(messages=[], max_tokens=100) + + @pytest.mark.anyio async def test_send_request_validates_result_alias_only(): """Peer results validate alias-only; a snake_case key from the wire is diff --git a/tests/server/test_validation.py b/tests/server/test_validation.py index 19f4eb108..b6d2a8e6f 100644 --- a/tests/server/test_validation.py +++ b/tests/server/test_validation.py @@ -108,6 +108,27 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_mixed_with_ot validate_tool_use_result_messages(messages) +def test_validate_tool_use_result_messages_raises_for_earlier_mixed_tool_result() -> None: + """Raises when an earlier message mixes tool_result with other content.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=[ + ToolResultContent(type="tool_result", tool_use_id="tool-1"), + TextContent(type="text", text="also this"), + ], + ), + SamplingMessage(role="assistant", content=TextContent(type="text", text="done")), + ] + + with pytest.raises(ValueError, match="only tool_result content"): + validate_tool_use_result_messages(messages) + + def test_validate_tool_use_result_messages_raises_when_tool_result_without_previous_tool_use() -> None: """Raises when tool_result appears without preceding tool_use.""" messages = [ @@ -146,6 +167,39 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_matc validate_tool_use_result_messages(messages) +def test_validate_tool_use_result_messages_raises_when_earlier_tool_result_ids_dont_match_tool_use() -> None: + """Raises when an earlier tool_result does not match the previous tool_use.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", tool_use_id="tool-2"), + ), + SamplingMessage(role="assistant", content=TextContent(type="text", text="done")), + ] + + with pytest.raises(ValueError, match="do not match"): + validate_tool_use_result_messages(messages) + + +def test_validate_tool_use_result_messages_raises_when_tool_use_is_not_answered() -> None: + """Raises when a tool_use is followed by a non-tool_result message.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage(role="user", content=TextContent(type="text", text="not a result")), + SamplingMessage(role="assistant", content=TextContent(type="text", text="done")), + ] + + with pytest.raises(ValueError, match="do not match"): + validate_tool_use_result_messages(messages) + + def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_tool_use() -> None: """No error when tool_result IDs match tool_use IDs.""" messages = [ @@ -159,3 +213,34 @@ def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_too ), ] validate_tool_use_result_messages(messages) # Should not raise + + +def test_validate_tool_use_result_messages_no_error_for_multiple_tool_pairs() -> None: + """No error when every tool_use in the history has a matching tool_result.""" + messages = [ + SamplingMessage(role="user", content=TextContent(type="text", text="first")), + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", tool_use_id="tool-1"), + ), + SamplingMessage( + role="assistant", + content=[ + ToolUseContent(type="tool_use", id="tool-2", name="test", input={}), + ToolUseContent(type="tool_use", id="tool-3", name="test", input={}), + ], + ), + SamplingMessage( + role="user", + content=[ + ToolResultContent(type="tool_result", tool_use_id="tool-3"), + ToolResultContent(type="tool_result", tool_use_id="tool-2"), + ], + ), + ] + + validate_tool_use_result_messages(messages) From 5cbbff53b18e4e65fb21a0511e1d1c1ad6fe3a2b Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Tue, 9 Jun 2026 14:57:57 +0200 Subject: [PATCH 2/2] Fix sampling validation test expectations --- tests/interaction/lowlevel/test_sampling.py | 4 +--- tests/server/test_session.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 260e56419..1693942fd 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -395,9 +395,7 @@ async def sampling_callback( assert result == snapshot( CallToolResult( - content=[ - TextContent(text="ValueError: The last message must contain only tool_result content if any is present") - ] + content=[TextContent(text="ValueError: A message must contain only tool_result content if any is present")] ) ) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 7e66e6e38..15eaf125f 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -128,9 +128,7 @@ async def test_send_request_without_back_channel_or_related_id_fails_fast(): @pytest.mark.anyio async def test_create_message_tool_result_validation(): """Test tool_use/tool_result validation in create_message.""" - dispatcher = StubDispatcher( - result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"} - ) + dispatcher = StubDispatcher(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}) session = _make_session( dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) )