From 050f191dbd22694acbd18a67c455598bf955301d Mon Sep 17 00:00:00 2001 From: EterUltimate <1831303476@qq.com> Date: Sun, 21 Jun 2026 09:24:33 +0800 Subject: [PATCH 1/2] fix: ignore malformed tool call names --- .../agent/runners/tool_loop_agent_runner.py | 73 ++++++++++++ tests/test_tool_loop_agent_runner.py | 111 ++++++++++++++++++ 2 files changed, 184 insertions(+) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index b56d7e62fb..26c926adfe 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -790,6 +790,8 @@ async def step(self): ) return + self._sanitize_tool_call_entries(llm_resp) + if not llm_resp.tools_call_name: await self._complete_with_assistant_response(llm_resp) @@ -1269,6 +1271,74 @@ def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool: text = (llm_resp.completion_text or "").strip() return bool(text) + @staticmethod + def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None: + """Drop malformed tool calls and keep names/args/ids aligned.""" + if not llm_resp.tools_call_name: + return + + tool_names: list[str] = [] + tool_args: list[dict[str, T.Any]] = [] + tool_ids: list[str] = [] + tool_extra_content: dict[str, dict[str, T.Any]] = {} + dropped_count = 0 + changed = False + + for idx, tool_name in enumerate(llm_resp.tools_call_name): + if not isinstance(tool_name, str) or not tool_name: + dropped_count += 1 + changed = True + continue + + tool_names.append(tool_name) + + if idx < len(llm_resp.tools_call_args): + raw_args = llm_resp.tools_call_args[idx] + if isinstance(raw_args, dict): + tool_args.append(raw_args) + else: + tool_args.append({}) + changed = True + else: + tool_args.append({}) + changed = True + + if idx < len(llm_resp.tools_call_ids): + raw_tool_id = llm_resp.tools_call_ids[idx] + tool_id = ( + raw_tool_id + if isinstance(raw_tool_id, str) and raw_tool_id + else f"call_{uuid.uuid4().hex[:8]}" + ) + if tool_id != raw_tool_id: + changed = True + else: + tool_id = f"call_{uuid.uuid4().hex[:8]}" + changed = True + tool_ids.append(tool_id) + + extra_content = llm_resp.tools_call_extra_content.get(tool_id) + if extra_content: + tool_extra_content[tool_id] = extra_content + + if len(tool_names) != len(llm_resp.tools_call_name): + changed = True + if len(tool_args) != len(llm_resp.tools_call_args): + changed = True + if len(tool_ids) != len(llm_resp.tools_call_ids): + changed = True + + if dropped_count: + logger.warning( + f"Dropped {dropped_count} malformed tool call(s) with empty names." + ) + + if changed: + llm_resp.tools_call_name = tool_names + llm_resp.tools_call_args = tool_args + llm_resp.tools_call_ids = tool_ids + llm_resp.tools_call_extra_content = tool_extra_content + def _build_tool_subset(self, tool_set: ToolSet, tool_names: list[str]) -> ToolSet: """Build a subset of tools from the given tool set based on tool names.""" subset = ToolSet() @@ -1283,6 +1353,7 @@ async def _resolve_tool_exec( llm_resp: LLMResponse, ) -> tuple[LLMResponse, ToolSet | None]: """Used in 'skills_like' tool schema mode to re-query LLM with param-only tool schemas.""" + self._sanitize_tool_call_entries(llm_resp) tool_names = llm_resp.tools_call_name if not tool_names: return llm_resp, self.req.func_tool @@ -1311,6 +1382,7 @@ async def _resolve_tool_exec( request_max_retries=self.request_max_retries, ) if requery_resp: + self._sanitize_tool_call_entries(requery_resp) llm_resp = requery_resp # If the re-query still returns no tool calls, and also does not have a meaningful assistant reply, @@ -1338,6 +1410,7 @@ async def _resolve_tool_exec( request_max_retries=self.request_max_retries, ) if repair_resp: + self._sanitize_tool_call_entries(repair_resp) llm_resp = repair_resp return llm_resp, subset diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index b4464680fb..24a972285c 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -1364,6 +1364,117 @@ async def text_chat(self, **kwargs) -> LLMResponse: assert parts[0].text == "一张猫的照片" +@pytest.mark.asyncio +async def test_skills_like_requery_drops_malformed_tool_names(): + captured_contexts = [] + + class MalformedToolNameProvider(MockProvider): + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + if self.call_count == 1: + return LLMResponse( + role="assistant", + completion_text="选择工具", + tools_call_name=["test_tool", None], + tools_call_args=[{"query": "test"}, {"query": "bad"}], + tools_call_ids=["call_1", "call_bad"], + usage=TokenUsage(input_other=10, output=5), + ) + if self.call_count == 2: + captured_contexts.extend(kwargs.get("contexts") or []) + return LLMResponse( + role="assistant", + completion_text="调用工具", + tools_call_name=["test_tool"], + tools_call_args=[{"query": "actual"}], + tools_call_ids=["call_2"], + usage=TokenUsage(input_other=10, output=5), + ) + return LLMResponse( + role="assistant", + completion_text="最终回复", + usage=TokenUsage(input_other=10, output=5), + ) + + provider = MalformedToolNameProvider() + tool = FunctionTool( + name="test_tool", + description="测试", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + req = ProviderRequest( + prompt="调用工具", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=req, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=MockHooks(), + tool_schema_mode="skills_like", + ) + + async for _ in runner.step(): + pass + + assert provider.call_count == 2 + assert captured_contexts[0]["content"].startswith( + "You have decided to call tool(s): test_tool." + ) + tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0].tool_call_id == "call_2" + + +@pytest.mark.asyncio +async def test_tool_loop_drops_malformed_tool_names_before_execution(): + class MalformedToolNameProvider(MockProvider): + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + return LLMResponse( + role="assistant", + completion_text="调用工具", + tools_call_name=["test_tool", None], + tools_call_args=[{"query": "actual"}, {"query": "bad"}], + tools_call_ids=["call_1", "call_bad"], + usage=TokenUsage(input_other=10, output=5), + ) + + provider = MalformedToolNameProvider() + tool = FunctionTool( + name="test_tool", + description="测试", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + req = ProviderRequest( + prompt="调用工具", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=req, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=MockHooks(), + ) + + async for _ in runner.step(): + pass + + tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0].tool_call_id == "call_1" + + @pytest.mark.asyncio async def test_follow_up_accepted_when_active_and_not_stopping( runner, mock_provider, provider_request, mock_tool_executor, mock_hooks From cdeb1525ce6f0f08c2b67b57f321efd42648bc8a Mon Sep 17 00:00:00 2001 From: EterUltimate <1831303476@qq.com> Date: Sun, 21 Jun 2026 09:42:17 +0800 Subject: [PATCH 2/2] fix: harden tool call sanitization --- .../agent/runners/tool_loop_agent_runner.py | 72 +++++++++++++------ tests/test_tool_loop_agent_runner.py | 60 ++++++++++++++++ 2 files changed, 112 insertions(+), 20 deletions(-) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 26c926adfe..89849810f8 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -473,9 +473,18 @@ async def _iter_llm_responses( if self.streaming: stream = self.provider.text_chat_stream(**payload) async for resp in stream: # type: ignore + if not resp.is_chunk: + self._sanitize_tool_call_entries(resp) yield resp else: - yield await self.provider.text_chat(**payload) + yield await self._text_chat_with_sanitized_tool_calls(**payload) + + async def _text_chat_with_sanitized_tool_calls( + self, **kwargs: T.Any + ) -> LLMResponse: + resp = await self.provider.text_chat(**kwargs) + self._sanitize_tool_call_entries(resp) + return resp async def _iter_llm_responses_with_fallback( self, @@ -790,8 +799,6 @@ async def step(self): ) return - self._sanitize_tool_call_entries(llm_resp) - if not llm_resp.tools_call_name: await self._complete_with_assistant_response(llm_resp) @@ -1274,7 +1281,18 @@ def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool: @staticmethod def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None: """Drop malformed tool calls and keep names/args/ids aligned.""" - if not llm_resp.tools_call_name: + raw_tool_names = llm_resp.tools_call_name + if not raw_tool_names: + return + + if not isinstance(raw_tool_names, list): + logger.warning( + "Dropped malformed tool calls because tool names are invalid." + ) + llm_resp.tools_call_name = [] + llm_resp.tools_call_args = [] + llm_resp.tools_call_ids = [] + llm_resp.tools_call_extra_content = {} return tool_names: list[str] = [] @@ -1283,8 +1301,22 @@ def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None: tool_extra_content: dict[str, dict[str, T.Any]] = {} dropped_count = 0 changed = False + raw_tool_args = llm_resp.tools_call_args + raw_tool_ids = llm_resp.tools_call_ids + tool_args_source = raw_tool_args if isinstance(raw_tool_args, list) else [] + tool_ids_source = raw_tool_ids if isinstance(raw_tool_ids, list) else [] + if tool_args_source is not raw_tool_args: + changed = True + if tool_ids_source is not raw_tool_ids: + changed = True + raw_extra_content = llm_resp.tools_call_extra_content + extra_content_map = ( + raw_extra_content if isinstance(raw_extra_content, dict) else {} + ) + if extra_content_map is not raw_extra_content: + changed = True - for idx, tool_name in enumerate(llm_resp.tools_call_name): + for idx, tool_name in enumerate(raw_tool_names): if not isinstance(tool_name, str) or not tool_name: dropped_count += 1 changed = True @@ -1292,8 +1324,8 @@ def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None: tool_names.append(tool_name) - if idx < len(llm_resp.tools_call_args): - raw_args = llm_resp.tools_call_args[idx] + if idx < len(tool_args_source): + raw_args = tool_args_source[idx] if isinstance(raw_args, dict): tool_args.append(raw_args) else: @@ -1303,8 +1335,8 @@ def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None: tool_args.append({}) changed = True - if idx < len(llm_resp.tools_call_ids): - raw_tool_id = llm_resp.tools_call_ids[idx] + if idx < len(tool_ids_source): + raw_tool_id = tool_ids_source[idx] tool_id = ( raw_tool_id if isinstance(raw_tool_id, str) and raw_tool_id @@ -1317,20 +1349,23 @@ def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None: changed = True tool_ids.append(tool_id) - extra_content = llm_resp.tools_call_extra_content.get(tool_id) - if extra_content: + extra_content = extra_content_map.get(tool_id) + if isinstance(extra_content, dict) and extra_content: tool_extra_content[tool_id] = extra_content - if len(tool_names) != len(llm_resp.tools_call_name): + if len(tool_names) != len(raw_tool_names): + changed = True + if len(tool_args) != len(tool_args_source): changed = True - if len(tool_args) != len(llm_resp.tools_call_args): + if len(tool_ids) != len(tool_ids_source): changed = True - if len(tool_ids) != len(llm_resp.tools_call_ids): + if tool_extra_content != extra_content_map: changed = True if dropped_count: logger.warning( - f"Dropped {dropped_count} malformed tool call(s) with empty names." + f"Dropped {dropped_count} malformed tool call(s) with " + "non-string or empty names." ) if changed: @@ -1353,7 +1388,6 @@ async def _resolve_tool_exec( llm_resp: LLMResponse, ) -> tuple[LLMResponse, ToolSet | None]: """Used in 'skills_like' tool schema mode to re-query LLM with param-only tool schemas.""" - self._sanitize_tool_call_entries(llm_resp) tool_names = llm_resp.tools_call_name if not tool_names: return llm_resp, self.req.func_tool @@ -1371,7 +1405,7 @@ async def _resolve_tool_exec( ) if param_subset.tools and tool_names: contexts = self._build_tool_requery_context(tool_names) - requery_resp = await self.provider.text_chat( + requery_resp = await self._text_chat_with_sanitized_tool_calls( contexts=self._sanitize_contexts_for_provider(contexts), func_tool=param_subset, model=self.req.model, @@ -1382,7 +1416,6 @@ async def _resolve_tool_exec( request_max_retries=self.request_max_retries, ) if requery_resp: - self._sanitize_tool_call_entries(requery_resp) llm_resp = requery_resp # If the re-query still returns no tool calls, and also does not have a meaningful assistant reply, @@ -1399,7 +1432,7 @@ async def _resolve_tool_exec( tool_names, extra_instruction=self.SKILLS_LIKE_REQUERY_REPAIR_INSTRUCTION, ) - repair_resp = await self.provider.text_chat( + repair_resp = await self._text_chat_with_sanitized_tool_calls( contexts=self._sanitize_contexts_for_provider(repair_contexts), func_tool=param_subset, model=self.req.model, @@ -1410,7 +1443,6 @@ async def _resolve_tool_exec( request_max_retries=self.request_max_retries, ) if repair_resp: - self._sanitize_tool_call_entries(repair_resp) llm_resp = repair_resp return llm_resp, subset diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 24a972285c..5e07b10514 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -1475,6 +1475,66 @@ async def text_chat(self, **kwargs) -> LLMResponse: assert tool_messages[0].tool_call_id == "call_1" +@pytest.mark.asyncio +async def test_tool_loop_sanitizes_malformed_tool_ids_args_and_extra_content(): + class MalformedToolFieldsProvider(MockProvider): + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + response = LLMResponse( + role="assistant", + completion_text="调用工具", + tools_call_name=["test_tool", "test_tool"], + tools_call_args=[{"query": "actual"}, "bad-args"], + tools_call_ids=["call_1", None], + usage=TokenUsage(input_other=10, output=5), + ) + response.tools_call_extra_content = None + return response + + class CapturingHooks(MockHooks): + def __init__(self): + super().__init__() + self.tool_args_list = [] + + async def on_tool_start(self, run_context, tool, tool_args): + await super().on_tool_start(run_context, tool, tool_args) + self.tool_args_list.append(tool_args) + + provider = MalformedToolFieldsProvider() + hooks = CapturingHooks() + tool = FunctionTool( + name="test_tool", + description="测试", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + req = ProviderRequest( + prompt="调用工具", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=req, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=hooks, + ) + + async for _ in runner.step(): + pass + + tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] + assert len(tool_messages) == 2 + tool_call_ids = [msg.tool_call_id for msg in tool_messages] + assert tool_call_ids[0] == "call_1" + assert tool_call_ids[1] is not None + assert tool_call_ids[1].startswith("call_") + assert hooks.tool_args_list == [{"query": "actual"}, {}] + + @pytest.mark.asyncio async def test_follow_up_accepted_when_active_and_not_stopping( runner, mock_provider, provider_request, mock_tool_executor, mock_hooks