Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 108 additions & 3 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,18 @@ async def _iter_llm_responses(
if self.streaming:
stream = self.provider.text_chat_stream(**payload)
async for resp in stream: # type: ignore
if not resp.is_chunk:
self._sanitize_tool_call_entries(resp)
yield resp
else:
yield await self.provider.text_chat(**payload)
yield await self._text_chat_with_sanitized_tool_calls(**payload)

async def _text_chat_with_sanitized_tool_calls(
    self, **kwargs: T.Any
) -> LLMResponse:
    """Run a non-streaming provider chat call and sanitize its tool calls.

    Every keyword argument is forwarded unchanged to
    ``self.provider.text_chat``. The awaited response is handed to
    ``_sanitize_tool_call_entries`` (which edits it in place) and then
    returned to the caller.
    """
    llm_resp = await self.provider.text_chat(**kwargs)
    self._sanitize_tool_call_entries(llm_resp)
    return llm_resp

async def _iter_llm_responses_with_fallback(
self,
Expand Down Expand Up @@ -1269,6 +1278,102 @@ def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool:
text = (llm_resp.completion_text or "").strip()
return bool(text)

@staticmethod
def _sanitize_tool_call_entries(llm_resp: LLMResponse) -> None:
    """Repair the tool-call fields of ``llm_resp`` in place.

    Rules applied, index by index over ``tools_call_name``:

    * a name that is not a non-empty string drops the whole entry;
    * the argument at the same index is kept when it is a dict, otherwise
      (missing or wrong type) it becomes ``{}``;
    * the ID at the same index is kept when it is a non-empty string,
      otherwise a ``call_<8 hex chars>`` ID is generated;
    * extra content is carried over only for surviving IDs whose entry is a
      non-empty dict.

    If ``tools_call_name`` is falsy nothing is touched. If it is truthy but
    not a list, all four tool-call fields are reset to empty containers.
    Otherwise the fields are reassigned only when at least one repair was
    needed, so an already well-formed response keeps its original objects.
    """
    names_in = llm_resp.tools_call_name
    if not names_in:
        return

    if not isinstance(names_in, list):
        logger.warning(
            "Dropped malformed tool calls because tool names are invalid."
        )
        llm_resp.tools_call_name = []
        llm_resp.tools_call_args = []
        llm_resp.tools_call_ids = []
        llm_resp.tools_call_extra_content = {}
        return

    args_field = llm_resp.tools_call_args
    ids_field = llm_resp.tools_call_ids
    extra_field = llm_resp.tools_call_extra_content

    args_is_list = isinstance(args_field, list)
    ids_is_list = isinstance(ids_field, list)
    extra_is_dict = isinstance(extra_field, dict)

    # Wrong-typed companion fields are treated as empty, which by itself
    # forces the sanitized values to be written back.
    args_in: list[T.Any] = args_field if args_is_list else []
    ids_in: list[T.Any] = ids_field if ids_is_list else []
    extra_in: dict[T.Any, T.Any] = extra_field if extra_is_dict else {}
    needs_rewrite = not (args_is_list and ids_is_list and extra_is_dict)

    kept_names: list[str] = []
    kept_args: list[dict[str, T.Any]] = []
    kept_ids: list[str] = []
    kept_extra: dict[str, dict[str, T.Any]] = {}
    skipped = 0

    for position, name in enumerate(names_in):
        if not (isinstance(name, str) and name):
            skipped += 1
            continue

        kept_names.append(name)

        # A missing slot is handled exactly like a wrong-typed one.
        candidate_args = args_in[position] if position < len(args_in) else None
        if isinstance(candidate_args, dict):
            kept_args.append(candidate_args)
        else:
            kept_args.append({})
            needs_rewrite = True

        candidate_id = ids_in[position] if position < len(ids_in) else None
        if isinstance(candidate_id, str) and candidate_id:
            call_id = candidate_id
        else:
            call_id = f"call_{uuid.uuid4().hex[:8]}"
            needs_rewrite = True
        kept_ids.append(call_id)

        # Extra content is keyed by ID, so a regenerated ID never inherits
        # content from the original mapping.
        attached = extra_in.get(call_id)
        if isinstance(attached, dict) and attached:
            kept_extra[call_id] = attached

    # The length checks catch surplus args/IDs that have no matching name;
    # the mapping check catches stale or empty extra-content entries.
    needs_rewrite = (
        needs_rewrite
        or skipped > 0
        or len(kept_args) != len(args_in)
        or len(kept_ids) != len(ids_in)
        or kept_extra != extra_in
    )

    if skipped:
        logger.warning(
            f"Dropped {skipped} malformed tool call(s) with "
            "non-string or empty names."
        )

    if needs_rewrite:
        llm_resp.tools_call_name = kept_names
        llm_resp.tools_call_args = kept_args
        llm_resp.tools_call_ids = kept_ids
        llm_resp.tools_call_extra_content = kept_extra
Comment on lines +1298 to +1375

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential TypeError exceptions, we should defensively handle cases where tools_call_args, tools_call_ids, or tools_call_extra_content are None at runtime. This can happen with certain custom or mock providers that do not fully populate these fields.

        tool_names: list[str] = []
        tool_args: list[dict[str, T.Any]] = []
        tool_ids: list[str] = []
        tool_extra_content: dict[str, dict[str, T.Any]] = {}
        dropped_count = 0
        changed = False

        raw_args_list = llm_resp.tools_call_args if llm_resp.tools_call_args is not None else []
        raw_ids_list = llm_resp.tools_call_ids if llm_resp.tools_call_ids is not None else []
        raw_extra_content = llm_resp.tools_call_extra_content if llm_resp.tools_call_extra_content is not None else {}

        if (
            llm_resp.tools_call_args is None
            or llm_resp.tools_call_ids is None
            or llm_resp.tools_call_extra_content is None
        ):
            changed = True

        for idx, tool_name in enumerate(llm_resp.tools_call_name):
            if not isinstance(tool_name, str) or not tool_name:
                dropped_count += 1
                changed = True
                continue

            tool_names.append(tool_name)

            if idx < len(raw_args_list):
                raw_args = raw_args_list[idx]
                if isinstance(raw_args, dict):
                    tool_args.append(raw_args)
                else:
                    tool_args.append({})
                    changed = True
            else:
                tool_args.append({})
                changed = True

            if idx < len(raw_ids_list):
                raw_tool_id = raw_ids_list[idx]
                tool_id = (
                    raw_tool_id
                    if isinstance(raw_tool_id, str) and raw_tool_id
                    else f"call_{uuid.uuid4().hex[:8]}"
                )
                if tool_id != raw_tool_id:
                    changed = True
            else:
                tool_id = f"call_{uuid.uuid4().hex[:8]}"
                changed = True
            tool_ids.append(tool_id)

            extra_content = raw_extra_content.get(tool_id)
            if extra_content:
                tool_extra_content[tool_id] = extra_content

        if len(tool_names) != len(llm_resp.tools_call_name):
            changed = True
        if len(tool_args) != len(raw_args_list):
            changed = True
        if len(tool_ids) != len(raw_ids_list):
            changed = True

        if dropped_count:
            logger.warning(
                f"Dropped {dropped_count} malformed tool call(s) with empty names."
            )

        if changed:
            llm_resp.tools_call_name = tool_names
            llm_resp.tools_call_args = tool_args
            llm_resp.tools_call_ids = tool_ids
            llm_resp.tools_call_extra_content = tool_extra_content


def _build_tool_subset(self, tool_set: ToolSet, tool_names: list[str]) -> ToolSet:
"""Build a subset of tools from the given tool set based on tool names."""
subset = ToolSet()
Expand Down Expand Up @@ -1300,7 +1405,7 @@ async def _resolve_tool_exec(
)
if param_subset.tools and tool_names:
contexts = self._build_tool_requery_context(tool_names)
requery_resp = await self.provider.text_chat(
requery_resp = await self._text_chat_with_sanitized_tool_calls(
contexts=self._sanitize_contexts_for_provider(contexts),
func_tool=param_subset,
model=self.req.model,
Expand All @@ -1327,7 +1432,7 @@ async def _resolve_tool_exec(
tool_names,
extra_instruction=self.SKILLS_LIKE_REQUERY_REPAIR_INSTRUCTION,
)
repair_resp = await self.provider.text_chat(
repair_resp = await self._text_chat_with_sanitized_tool_calls(
contexts=self._sanitize_contexts_for_provider(repair_contexts),
func_tool=param_subset,
model=self.req.model,
Expand Down
171 changes: 171 additions & 0 deletions tests/test_tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,177 @@ async def text_chat(self, **kwargs) -> LLMResponse:
assert parts[0].text == "<image_caption>一张猫的照片</image_caption>"


@pytest.mark.asyncio
async def test_skills_like_requery_drops_malformed_tool_names():
captured_contexts = []

class MalformedToolNameProvider(MockProvider):
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
if self.call_count == 1:
return LLMResponse(
role="assistant",
completion_text="选择工具",
tools_call_name=["test_tool", None],
tools_call_args=[{"query": "test"}, {"query": "bad"}],
tools_call_ids=["call_1", "call_bad"],
Comment on lines +1371 to +1380

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding a test that covers malformed tool IDs and non-dict args to exercise ID regeneration and arg fallback.

The sanitization now also regenerates invalid tools_call_ids and coerces non-dict tools_call_args to {}; these paths aren’t covered by the new tests. Please add a test where tools_call_ids includes None/"" and tools_call_args includes a non-dict (e.g., list/string) to confirm IDs are regenerated and stay aligned with names, and that invalid args are replaced with {} without affecting execution.

Suggested implementation:

@pytest.mark.asyncio
async def test_skills_like_requery_drops_malformed_tool_names():
    captured_contexts = []

To cover malformed tool IDs and non-dict args (ID regeneration and arg fallback), add the following new test in tests/test_tool_loop_agent_runner.py near the other skills_like_requery tests (e.g., immediately after test_skills_like_requery_drops_malformed_tool_names):

@pytest.mark.asyncio
async def test_skills_like_requery_sanitizes_tool_ids_and_args():
    """Illustrative sketch: malformed tool-call IDs and non-dict arguments.

    The first provider response pairs two ``test_tool`` calls with one
    ``None`` ID and one plain-string argument payload. The assertions expect
    the contexts captured on the second provider call to expose two distinct,
    truthy IDs and argument payloads of ``{"query": "good"}`` and ``{}``.

    NOTE(review): ``run_skills_like_requery_agent`` and ``test_tool`` are
    placeholders that are not defined in the visible code; this sketch cannot
    run until they are swapped for the real harness and tool object.
    """
    captured_contexts = []

    class MalformedToolIdAndArgsProvider(MockProvider):
        async def text_chat(self, **kwargs) -> LLMResponse:
            self.call_count += 1
            if self.call_count == 1:
                # First response carries two calls to the same tool:
                # - index 0 is fully valid (dict args, string ID)
                # - index 1 has a string instead of a dict for its args and
                #   None for its ID
                return LLMResponse(
                    role="assistant",
                    completion_text="选择工具",
                    tools_call_name=["test_tool", "test_tool"],
                    tools_call_args=[
                        {"query": "good"},
                        "non-dict-arg",
                    ],
                    tools_call_ids=[
                        "call_1",
                        None,  # malformed ID, should be regenerated
                    ],
                    usage=TokenUsage(input_other=10, output=5),
                )

            if self.call_count == 2:
                # Record whatever contexts the runner passes on the second call
                # so the assertions below can inspect them.
                # NOTE(review): the original comment called this "the re-query
                # after tool execution"; the sibling skills_like test returns
                # the concrete tool call from its second response, which
                # suggests the re-query precedes execution — confirm.
                captured_contexts.extend(kwargs.get("contexts") or [])
                return LLMResponse(
                    role="assistant",
                    completion_text="调用工具",
                    tools_call_name=[],
                    tools_call_args=[],
                    tools_call_ids=[],
                    usage=TokenUsage(input_other=5, output=3),
                )

            # Any later call ends the conversation with a plain text reply.
            return LLMResponse(
                role="assistant",
                completion_text="done",
                tools_call_name=[],
                tools_call_args=[],
                tools_call_ids=[],
                usage=TokenUsage(input_other=1, output=1),
            )

    provider = MalformedToolIdAndArgsProvider()
    provider.call_count = 0

    # Placeholder driver call: swap in the same runner / harness that the
    # other skills-like-requery tests use. For example, if other tests do
    # something like:
    #
    #   result = await run_skills_like_requery_agent(
    #       provider=provider,
    #       tools=[test_tool],
    #   )
    #
    # mirror that here. The exact call should match the existing tests:
    result = await run_skills_like_requery_agent(
        provider=provider,
        tools=[test_tool],
    )

    # Ensure the agent ran through the tool loop (at least two provider calls).
    assert provider.call_count >= 2

    # IDs should have been sanitized:
    # - same number of contexts as tool calls
    # - no empty/None IDs, and the two IDs are distinct
    # NOTE(review): assumes each captured context exposes a `tool_call_id`
    # attribute; the sibling test indexes contexts like dicts
    # (`captured_contexts[0]["content"]`) — verify the real shape.
    assert len(captured_contexts) == 2
    tool_call_ids = [ctx.tool_call_id for ctx in captured_contexts]
    assert all(tool_call_ids)
    assert len(set(tool_call_ids)) == 2

    # Args sanitization: the non-dict arg should have been coerced to {}.
    # We expect one context with the original "good" args, and one with {}.
    # NOTE(review): `tool_call_args` is likewise an assumed attribute name.
    args_list = [ctx.tool_call_args for ctx in captured_contexts]
    assert {"query": "good"} in args_list
    assert {} in args_list

You will need to:

  1. Replace run_skills_like_requery_agent with the actual helper used by the existing test_skills_like_requery_* tests to drive the agent/tool loop.
  2. Replace test_tool with the actual tool object (or list of tools) used in those tests, so the new test uses the same setup and conventions.
  3. If your ToolContext or equivalent type uses different attribute names than tool_call_id and tool_call_args, adjust the assertions accordingly (mirror how contexts are inspected in existing tests).

usage=TokenUsage(input_other=10, output=5),
)
if self.call_count == 2:
captured_contexts.extend(kwargs.get("contexts") or [])
return LLMResponse(
role="assistant",
completion_text="调用工具",
tools_call_name=["test_tool"],
tools_call_args=[{"query": "actual"}],
tools_call_ids=["call_2"],
usage=TokenUsage(input_other=10, output=5),
)
return LLMResponse(
role="assistant",
completion_text="最终回复",
usage=TokenUsage(input_other=10, output=5),
)

provider = MalformedToolNameProvider()
tool = FunctionTool(
name="test_tool",
description="测试",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
req = ProviderRequest(
prompt="调用工具",
func_tool=ToolSet(tools=[tool]),
contexts=[],
)
runner = ToolLoopAgentRunner()

await runner.reset(
provider=provider,
request=req,
run_context=ContextWrapper(context=None),
tool_executor=cast(Any, MockToolExecutor()),
agent_hooks=MockHooks(),
tool_schema_mode="skills_like",
)

async for _ in runner.step():
pass

assert provider.call_count == 2
assert captured_contexts[0]["content"].startswith(
"You have decided to call tool(s): test_tool."
)
tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"]
assert len(tool_messages) == 1
assert tool_messages[0].tool_call_id == "call_2"


@pytest.mark.asyncio
async def test_tool_loop_drops_malformed_tool_names_before_execution():
    """A tool call named ``None`` must not produce a tool message.

    The provider returns two tool calls, the second with ``None`` as its
    name. After one pass over ``runner.step()`` exactly one tool message is
    expected, carrying the ID of the well-formed call (``call_1``).
    """

    class MalformedToolNameProvider(MockProvider):
        """Answers every call with one valid and one nameless tool call."""

        async def text_chat(self, **kwargs) -> LLMResponse:
            self.call_count += 1
            return LLMResponse(
                role="assistant",
                completion_text="调用工具",
                tools_call_name=["test_tool", None],
                tools_call_args=[{"query": "actual"}, {"query": "bad"}],
                tools_call_ids=["call_1", "call_bad"],
                usage=TokenUsage(input_other=10, output=5),
            )

    llm_provider = MalformedToolNameProvider()
    query_schema = {
        "type": "object",
        "properties": {"query": {"type": "string"}},
    }
    registered_tool = FunctionTool(
        name="test_tool",
        description="测试",
        parameters=query_schema,
        handler=AsyncMock(),
    )
    request = ProviderRequest(
        prompt="调用工具",
        func_tool=ToolSet(tools=[registered_tool]),
        contexts=[],
    )
    agent = ToolLoopAgentRunner()

    await agent.reset(
        provider=llm_provider,
        request=request,
        run_context=ContextWrapper(context=None),
        tool_executor=cast(Any, MockToolExecutor()),
        agent_hooks=MockHooks(),
    )

    # Drain the step generator; only the recorded messages matter here.
    async for _ in agent.step():
        pass

    tool_replies = [
        message
        for message in agent.run_context.messages
        if message.role == "tool"
    ]
    assert len(tool_replies) == 1
    assert tool_replies[0].tool_call_id == "call_1"


@pytest.mark.asyncio
async def test_tool_loop_sanitizes_malformed_tool_ids_args_and_extra_content():
    """Malformed IDs, args and extra content are repaired, not dropped.

    The provider returns two calls to ``test_tool``: the second has a string
    instead of a dict for its args and ``None`` for its ID, and the
    response's ``tools_call_extra_content`` is ``None``. Both calls are
    expected to yield tool messages; the second must carry a ``call_``
    prefixed ID and reach the ``on_tool_start`` hook with ``{}`` as its args.
    """

    class MalformedToolFieldsProvider(MockProvider):
        """Answers every call with one clean and one malformed tool call."""

        async def text_chat(self, **kwargs) -> LLMResponse:
            self.call_count += 1
            llm_resp = LLMResponse(
                role="assistant",
                completion_text="调用工具",
                tools_call_name=["test_tool", "test_tool"],
                tools_call_args=[{"query": "actual"}, "bad-args"],
                tools_call_ids=["call_1", None],
                usage=TokenUsage(input_other=10, output=5),
            )
            # Set after construction so the field is None on the returned
            # response.
            llm_resp.tools_call_extra_content = None
            return llm_resp

    class CapturingHooks(MockHooks):
        """Hooks that record the args passed to every ``on_tool_start``."""

        def __init__(self):
            super().__init__()
            self.tool_args_list = []

        async def on_tool_start(self, run_context, tool, tool_args):
            await super().on_tool_start(run_context, tool, tool_args)
            self.tool_args_list.append(tool_args)

    llm_provider = MalformedToolFieldsProvider()
    recording_hooks = CapturingHooks()
    query_schema = {
        "type": "object",
        "properties": {"query": {"type": "string"}},
    }
    registered_tool = FunctionTool(
        name="test_tool",
        description="测试",
        parameters=query_schema,
        handler=AsyncMock(),
    )
    request = ProviderRequest(
        prompt="调用工具",
        func_tool=ToolSet(tools=[registered_tool]),
        contexts=[],
    )
    agent = ToolLoopAgentRunner()

    await agent.reset(
        provider=llm_provider,
        request=request,
        run_context=ContextWrapper(context=None),
        tool_executor=cast(Any, MockToolExecutor()),
        agent_hooks=recording_hooks,
    )

    # Drain the step generator; only the recorded messages and hook calls
    # are inspected.
    async for _ in agent.step():
        pass

    tool_replies = [
        message
        for message in agent.run_context.messages
        if message.role == "tool"
    ]
    assert len(tool_replies) == 2

    reply_ids = [reply.tool_call_id for reply in tool_replies]
    assert reply_ids[0] == "call_1"
    assert reply_ids[1] is not None
    assert reply_ids[1].startswith("call_")
    assert recording_hooks.tool_args_list == [{"query": "actual"}, {}]


@pytest.mark.asyncio
async def test_follow_up_accepted_when_active_and_not_stopping(
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
Expand Down
Loading