diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index b56d7e62fb..89849810f8 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -473,9 +473,18 @@ async def _iter_llm_responses( if self.streaming: stream = self.provider.text_chat_stream(**payload) async for resp in stream: # type: ignore + if not resp.is_chunk: + self._sanitize_tool_call_entries(resp) yield resp else: - yield await self.provider.text_chat(**payload) + yield await self._text_chat_with_sanitized_tool_calls(**payload) + + async def _text_chat_with_sanitized_tool_calls( + self, **kwargs: T.Any + ) -> LLMResponse: + resp = await self.provider.text_chat(**kwargs) + self._sanitize_tool_call_entries(resp) + return resp async def _iter_llm_responses_with_fallback( self, @@ -1269,6 +1278,102 @@ def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool: text = (llm_resp.completion_text or "").strip() return bool(text) + @staticmethod + def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None: + """Drop malformed tool calls and keep names/args/ids aligned.""" + raw_tool_names = llm_resp.tools_call_name + if not raw_tool_names: + return + + if not isinstance(raw_tool_names, list): + logger.warning( + "Dropped malformed tool calls because tool names are invalid." + ) + llm_resp.tools_call_name = [] + llm_resp.tools_call_args = [] + llm_resp.tools_call_ids = [] + llm_resp.tools_call_extra_content = {} + return + + tool_names: list[str] = [] + tool_args: list[dict[str, T.Any]] = [] + tool_ids: list[str] = [] + tool_extra_content: dict[str, dict[str, T.Any]] = {} + dropped_count = 0 + changed = False + raw_tool_args = llm_resp.tools_call_args + raw_tool_ids = llm_resp.tools_call_ids + tool_args_source = raw_tool_args if isinstance(raw_tool_args, list) else [] + tool_ids_source = raw_tool_ids if isinstance(raw_tool_ids, list) else [] + if tool_args_source is not raw_tool_args: + changed = True + if tool_ids_source is not raw_tool_ids: + changed = True + raw_extra_content = llm_resp.tools_call_extra_content + extra_content_map = ( + raw_extra_content if isinstance(raw_extra_content, dict) else {} + ) + if extra_content_map is not raw_extra_content: + changed = True + + for idx, tool_name in enumerate(raw_tool_names): + if not isinstance(tool_name, str) or not tool_name: + dropped_count += 1 + changed = True + continue + + tool_names.append(tool_name) + + if idx < len(tool_args_source): + raw_args = tool_args_source[idx] + if isinstance(raw_args, dict): + tool_args.append(raw_args) + else: + tool_args.append({}) + changed = True + else: + tool_args.append({}) + changed = True + + if idx < len(tool_ids_source): + raw_tool_id = tool_ids_source[idx] + tool_id = ( + raw_tool_id + if isinstance(raw_tool_id, str) and raw_tool_id + else f"call_{uuid.uuid4().hex[:8]}" + ) + if tool_id != raw_tool_id: + changed = True + else: + tool_id = f"call_{uuid.uuid4().hex[:8]}" + changed = True + tool_ids.append(tool_id) + + extra_content = extra_content_map.get(tool_id) + if isinstance(extra_content, dict) and extra_content: + tool_extra_content[tool_id] = extra_content + + if len(tool_names) != len(raw_tool_names): + changed = True + if len(tool_args) != len(tool_args_source): + changed = True + if len(tool_ids) != len(tool_ids_source): + changed = True + if tool_extra_content != extra_content_map: + changed = True + + if dropped_count: + logger.warning( + f"Dropped {dropped_count} malformed tool call(s) with " + "non-string or empty names." + ) + + if changed: + llm_resp.tools_call_name = tool_names + llm_resp.tools_call_args = tool_args + llm_resp.tools_call_ids = tool_ids + llm_resp.tools_call_extra_content = tool_extra_content + def _build_tool_subset(self, tool_set: ToolSet, tool_names: list[str]) -> ToolSet: """Build a subset of tools from the given tool set based on tool names.""" subset = ToolSet() @@ -1300,7 +1405,7 @@ async def _resolve_tool_exec( ) if param_subset.tools and tool_names: contexts = self._build_tool_requery_context(tool_names) - requery_resp = await self.provider.text_chat( + requery_resp = await self._text_chat_with_sanitized_tool_calls( contexts=self._sanitize_contexts_for_provider(contexts), func_tool=param_subset, model=self.req.model, @@ -1327,7 +1432,7 @@ async def _resolve_tool_exec( tool_names, extra_instruction=self.SKILLS_LIKE_REQUERY_REPAIR_INSTRUCTION, ) - repair_resp = await self.provider.text_chat( + repair_resp = await self._text_chat_with_sanitized_tool_calls( contexts=self._sanitize_contexts_for_provider(repair_contexts), func_tool=param_subset, model=self.req.model, diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index b4464680fb..5e07b10514 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -1364,6 +1364,177 @@ async def text_chat(self, **kwargs) -> LLMResponse: assert parts[0].text == "一张猫的照片" +@pytest.mark.asyncio +async def test_skills_like_requery_drops_malformed_tool_names(): + captured_contexts = [] + + class MalformedToolNameProvider(MockProvider): + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + if self.call_count == 1: + return LLMResponse( + role="assistant", + completion_text="选择工具", + tools_call_name=["test_tool", None], + tools_call_args=[{"query": "test"}, {"query": "bad"}], + tools_call_ids=["call_1", "call_bad"], + usage=TokenUsage(input_other=10, output=5), + ) + if self.call_count == 2: + captured_contexts.extend(kwargs.get("contexts") or []) + return LLMResponse( + role="assistant", + completion_text="调用工具", + tools_call_name=["test_tool"], + tools_call_args=[{"query": "actual"}], + tools_call_ids=["call_2"], + usage=TokenUsage(input_other=10, output=5), + ) + return LLMResponse( + role="assistant", + completion_text="最终回复", + usage=TokenUsage(input_other=10, output=5), + ) + + provider = MalformedToolNameProvider() + tool = FunctionTool( + name="test_tool", + description="测试", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + req = ProviderRequest( + prompt="调用工具", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=req, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=MockHooks(), + tool_schema_mode="skills_like", + ) + + async for _ in runner.step(): + pass + + assert provider.call_count == 2 + assert captured_contexts[0]["content"].startswith( + "You have decided to call tool(s): test_tool." + ) + tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0].tool_call_id == "call_2" + + +@pytest.mark.asyncio +async def test_tool_loop_drops_malformed_tool_names_before_execution(): + class MalformedToolNameProvider(MockProvider): + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + return LLMResponse( + role="assistant", + completion_text="调用工具", + tools_call_name=["test_tool", None], + tools_call_args=[{"query": "actual"}, {"query": "bad"}], + tools_call_ids=["call_1", "call_bad"], + usage=TokenUsage(input_other=10, output=5), + ) + + provider = MalformedToolNameProvider() + tool = FunctionTool( + name="test_tool", + description="测试", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + req = ProviderRequest( + prompt="调用工具", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=req, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=MockHooks(), + ) + + async for _ in runner.step(): + pass + + tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0].tool_call_id == "call_1" + + +@pytest.mark.asyncio +async def test_tool_loop_sanitizes_malformed_tool_ids_args_and_extra_content(): + class MalformedToolFieldsProvider(MockProvider): + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + response = LLMResponse( + role="assistant", + completion_text="调用工具", + tools_call_name=["test_tool", "test_tool"], + tools_call_args=[{"query": "actual"}, "bad-args"], + tools_call_ids=["call_1", None], + usage=TokenUsage(input_other=10, output=5), + ) + response.tools_call_extra_content = None + return response + + class CapturingHooks(MockHooks): + def __init__(self): + super().__init__() + self.tool_args_list = [] + + async def on_tool_start(self, run_context, tool, tool_args): + await super().on_tool_start(run_context, tool, tool_args) + self.tool_args_list.append(tool_args) + + provider = MalformedToolFieldsProvider() + hooks = CapturingHooks() + tool = FunctionTool( + name="test_tool", + description="测试", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + req = ProviderRequest( + prompt="调用工具", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=req, + run_context=ContextWrapper(context=None), + tool_executor=cast(Any, MockToolExecutor()), + agent_hooks=hooks, + ) + + async for _ in runner.step(): + pass + + tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] + assert len(tool_messages) == 2 + tool_call_ids = [msg.tool_call_id for msg in tool_messages] + assert tool_call_ids[0] == "call_1" + assert tool_call_ids[1] is not None + assert tool_call_ids[1].startswith("call_") + assert hooks.tool_args_list == [{"query": "actual"}, {}] + + @pytest.mark.asyncio async def test_follow_up_accepted_when_active_and_not_stopping( runner, mock_provider, provider_request, mock_tool_executor, mock_hooks