diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py index 08f5754f1..4076464b6 100644 --- a/src/mcp/server/validation.py +++ b/src/mcp/server/validation.py @@ -4,7 +4,7 @@ """ from mcp.shared.exceptions import MCPError -from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, Tool, ToolChoice +from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, SamplingMessageContentBlock, Tool, ToolChoice def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> bool: @@ -52,6 +52,7 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None: 1. Messages with tool_result content contain ONLY tool_result content 2. tool_result messages are preceded by a message with tool_use 3. tool_result IDs match the tool_use IDs from the previous message + 4. Every tool_use message in the history is followed by matching tool_result content See: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577 @@ -64,24 +65,26 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None: if not messages: return - last_content = messages[-1].content_as_list - has_tool_results = any(c.type == "tool_result" for c in last_content) - - previous_content = messages[-2].content_as_list if len(messages) >= 2 else None - has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content) - - if has_tool_results: - # Per spec: "SamplingMessage with tool result content blocks - # MUST NOT contain other content types." - if any(c.type != "tool_result" for c in last_content): - raise ValueError("The last message must contain only tool_result content if any is present") - if previous_content is None: - raise ValueError("tool_result requires a previous message containing tool_use") - if not has_previous_tool_use: - raise ValueError("tool_result blocks do not match any tool_use in the previous message") - - if has_previous_tool_use and previous_content: - tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"} - tool_result_ids = {c.tool_use_id for c in last_content if c.type == "tool_result"} - if tool_use_ids != tool_result_ids: - raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match") + previous_content: list[SamplingMessageContentBlock] | None = None + for content in (message.content_as_list for message in messages): + has_tool_results = any(c.type == "tool_result" for c in content) + previous_tool_use_ids: set[str] = set() + if previous_content is not None: + previous_tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"} + + if has_tool_results: + # Per spec: "SamplingMessage with tool result content blocks + # MUST NOT contain other content types." + if any(c.type != "tool_result" for c in content): + raise ValueError("A message must contain only tool_result content if any is present") + if previous_content is None: + raise ValueError("tool_result requires a previous message containing tool_use") + if not previous_tool_use_ids: + raise ValueError("tool_result blocks do not match any tool_use in the previous message") + + if previous_tool_use_ids: + tool_result_ids = {c.tool_use_id for c in content if c.type == "tool_result"} + if previous_tool_use_ids != tool_result_ids: + raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match") + + previous_content = content diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 260e56419..1693942fd 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -395,9 +395,7 @@ async def sampling_callback( assert result == snapshot( CallToolResult( - content=[ - TextContent(text="ValueError: The last message must contain only tool_result content if any is present") - ] + content=[TextContent(text="ValueError: A message must contain only tool_result content if any is present")] ) ) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a713a79b6..15eaf125f 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -125,6 +125,116 @@ async def test_send_request_without_back_channel_or_related_id_fails_fast(): assert dispatcher.requests[0][3] == 3 +@pytest.mark.anyio +async def test_create_message_tool_result_validation(): + """Test tool_use/tool_result validation in create_message.""" + dispatcher = StubDispatcher(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}) + session = _make_session( + dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + ) + tool = types.Tool(name="test_tool", input_schema={"type": "object"}) + text = types.TextContent(type="text", text="hello") + tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={}) + tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[]) + + # Case 1: tool_result mixed with other content + with pytest.raises(ValueError, match="only tool_result content"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage(role="user", content=[tool_result, text]), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 2: tool_result without previous message + with pytest.raises(ValueError, match="requires a previous message"): + await session.create_message( + messages=[types.SamplingMessage(role="user", content=tool_result)], + max_tokens=100, + tools=[tool], + ) + + # Case 3: tool_result without previous tool_use + with pytest.raises(ValueError, match="do not match any tool_use"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 4: mismatched tool IDs + with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage( + role="user", + content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]), + ), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 4b: earlier mismatched tool result with a later plain message + with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"): + await session.create_message( + messages=[ + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage( + role="user", + content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]), + ), + types.SamplingMessage(role="assistant", content=text), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 5: text-only message with tools (no tool_results) - passes validation + await session.create_message( + messages=[types.SamplingMessage(role="user", content=text)], + max_tokens=100, + tools=[tool], + ) + + # Case 6: valid matching tool_result/tool_use IDs - passes validation + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 7: validation runs even without `tools` parameter + # (tool loop continuation may omit tools while containing tool_result) + with pytest.raises(ValueError, match="do not match any tool_use"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + ) + + # Case 8: empty messages list - skips validation entirely + no_tools_session = _make_session( + StubDispatcher(result={"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"}), + capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + ) + await no_tools_session.create_message(messages=[], max_tokens=100) + + @pytest.mark.anyio async def test_send_request_validates_result_alias_only(): """Peer results validate alias-only; a snake_case key from the wire is diff --git a/tests/server/test_validation.py b/tests/server/test_validation.py index 19f4eb108..b6d2a8e6f 100644 --- a/tests/server/test_validation.py +++ b/tests/server/test_validation.py @@ -108,6 +108,27 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_mixed_with_ot validate_tool_use_result_messages(messages) +def test_validate_tool_use_result_messages_raises_for_earlier_mixed_tool_result() -> None: + """Raises when an earlier message mixes tool_result with other content.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=[ + ToolResultContent(type="tool_result", tool_use_id="tool-1"), + TextContent(type="text", text="also this"), + ], + ), + SamplingMessage(role="assistant", content=TextContent(type="text", text="done")), + ] + + with pytest.raises(ValueError, match="only tool_result content"): + validate_tool_use_result_messages(messages) + + def test_validate_tool_use_result_messages_raises_when_tool_result_without_previous_tool_use() -> None: """Raises when tool_result appears without preceding tool_use.""" messages = [ @@ -146,6 +167,39 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_matc validate_tool_use_result_messages(messages) +def test_validate_tool_use_result_messages_raises_when_earlier_tool_result_ids_dont_match_tool_use() -> None: + """Raises when an earlier tool_result does not match the previous tool_use.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", tool_use_id="tool-2"), + ), + SamplingMessage(role="assistant", content=TextContent(type="text", text="done")), + ] + + with pytest.raises(ValueError, match="do not match"): + validate_tool_use_result_messages(messages) + + +def test_validate_tool_use_result_messages_raises_when_tool_use_is_not_answered() -> None: + """Raises when a tool_use is followed by a non-tool_result message.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage(role="user", content=TextContent(type="text", text="not a result")), + SamplingMessage(role="assistant", content=TextContent(type="text", text="done")), + ] + + with pytest.raises(ValueError, match="do not match"): + validate_tool_use_result_messages(messages) + + def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_tool_use() -> None: """No error when tool_result IDs match tool_use IDs.""" messages = [ @@ -159,3 +213,34 @@ def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_too ), ] validate_tool_use_result_messages(messages) # Should not raise + + +def test_validate_tool_use_result_messages_no_error_for_multiple_tool_pairs() -> None: + """No error when every tool_use in the history has a matching tool_result.""" + messages = [ + SamplingMessage(role="user", content=TextContent(type="text", text="first")), + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", tool_use_id="tool-1"), + ), + SamplingMessage( + role="assistant", + content=[ + ToolUseContent(type="tool_use", id="tool-2", name="test", input={}), + ToolUseContent(type="tool_use", id="tool-3", name="test", input={}), + ], + ), + SamplingMessage( + role="user", + content=[ + ToolResultContent(type="tool_result", tool_use_id="tool-3"), + ToolResultContent(type="tool_result", tool_use_id="tool-2"), + ], + ), + ] + + validate_tool_use_result_messages(messages)