-
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
fix: ignore malformed tool call names #8929
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1364,6 +1364,177 @@ async def text_chat(self, **kwargs) -> LLMResponse: | |
| assert parts[0].text == "<image_caption>一张猫的照片</image_caption>" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_skills_like_requery_drops_malformed_tool_names(): | ||
| captured_contexts = [] | ||
|
|
||
| class MalformedToolNameProvider(MockProvider): | ||
| async def text_chat(self, **kwargs) -> LLMResponse: | ||
| self.call_count += 1 | ||
| if self.call_count == 1: | ||
| return LLMResponse( | ||
| role="assistant", | ||
| completion_text="选择工具", | ||
| tools_call_name=["test_tool", None], | ||
| tools_call_args=[{"query": "test"}, {"query": "bad"}], | ||
| tools_call_ids=["call_1", "call_bad"], | ||
|
Comment on lines
+1371
to
+1380
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Consider adding a test that covers malformed tool IDs and non-dict args to exercise ID regeneration and arg fallback. The sanitization now also regenerates invalid Suggested implementation: @pytest.mark.asyncio
async def test_skills_like_requery_drops_malformed_tool_names():
captured_contexts = []To cover malformed tool IDs and non-dict args (ID regeneration and arg fallback), add the following new test in @pytest.mark.asyncio
async def test_skills_like_requery_sanitizes_tool_ids_and_args():
captured_contexts = []
class MalformedToolIdAndArgsProvider(MockProvider):
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
if self.call_count == 1:
# First response has:
# - a valid tool name and a second tool with a malformed ID
# - a non-dict tool arg (string) that should be coerced to {}
return LLMResponse(
role="assistant",
completion_text="选择工具",
tools_call_name=["test_tool", "test_tool"],
tools_call_args=[
{"query": "good"},
"non-dict-arg",
],
tools_call_ids=[
"call_1",
None, # malformed ID, should be regenerated
],
usage=TokenUsage(input_other=10, output=5),
)
if self.call_count == 2:
# Second response should be the re-query after tool execution.
# Capture contexts to assert they stayed aligned with names and ids.
captured_contexts.extend(kwargs.get("contexts") or [])
return LLMResponse(
role="assistant",
completion_text="调用工具",
tools_call_name=[],
tools_call_args=[],
tools_call_ids=[],
usage=TokenUsage(input_other=5, output=3),
)
return LLMResponse(
role="assistant",
completion_text="done",
tools_call_name=[],
tools_call_args=[],
tools_call_ids=[],
usage=TokenUsage(input_other=1, output=1),
)
provider = MalformedToolIdAndArgsProvider()
provider.call_count = 0
# Use the same runner / harness as the other skills-like-requery tests.
# For example, if other tests do something like:
#
# result = await run_skills_like_requery_agent(
# provider=provider,
# tools=[test_tool],
# )
#
# mirror that here. The exact call should match the existing tests:
result = await run_skills_like_requery_agent(
provider=provider,
tools=[test_tool],
)
# Ensure the agent ran through the tool loop.
assert provider.call_count >= 2
# IDs should have been sanitized:
# - same number of contexts as tool calls
# - no empty/None IDs
assert len(captured_contexts) == 2
tool_call_ids = [ctx.tool_call_id for ctx in captured_contexts]
assert all(tool_call_ids)
assert len(set(tool_call_ids)) == 2
# Args sanitization: the non-dict arg should have been coerced to {}.
# We expect one context with the original "good" args, and one with {}.
args_list = [ctx.tool_call_args for ctx in captured_contexts]
assert {"query": "good"} in args_list
assert {} in args_listYou will need to:
|
||
| usage=TokenUsage(input_other=10, output=5), | ||
| ) | ||
| if self.call_count == 2: | ||
| captured_contexts.extend(kwargs.get("contexts") or []) | ||
| return LLMResponse( | ||
| role="assistant", | ||
| completion_text="调用工具", | ||
| tools_call_name=["test_tool"], | ||
| tools_call_args=[{"query": "actual"}], | ||
| tools_call_ids=["call_2"], | ||
| usage=TokenUsage(input_other=10, output=5), | ||
| ) | ||
| return LLMResponse( | ||
| role="assistant", | ||
| completion_text="最终回复", | ||
| usage=TokenUsage(input_other=10, output=5), | ||
| ) | ||
|
|
||
| provider = MalformedToolNameProvider() | ||
| tool = FunctionTool( | ||
| name="test_tool", | ||
| description="测试", | ||
| parameters={"type": "object", "properties": {"query": {"type": "string"}}}, | ||
| handler=AsyncMock(), | ||
| ) | ||
| req = ProviderRequest( | ||
| prompt="调用工具", | ||
| func_tool=ToolSet(tools=[tool]), | ||
| contexts=[], | ||
| ) | ||
| runner = ToolLoopAgentRunner() | ||
|
|
||
| await runner.reset( | ||
| provider=provider, | ||
| request=req, | ||
| run_context=ContextWrapper(context=None), | ||
| tool_executor=cast(Any, MockToolExecutor()), | ||
| agent_hooks=MockHooks(), | ||
| tool_schema_mode="skills_like", | ||
| ) | ||
|
|
||
| async for _ in runner.step(): | ||
| pass | ||
|
|
||
| assert provider.call_count == 2 | ||
| assert captured_contexts[0]["content"].startswith( | ||
| "You have decided to call tool(s): test_tool." | ||
| ) | ||
| tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] | ||
| assert len(tool_messages) == 1 | ||
| assert tool_messages[0].tool_call_id == "call_2" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_loop_drops_malformed_tool_names_before_execution(): | ||
| class MalformedToolNameProvider(MockProvider): | ||
| async def text_chat(self, **kwargs) -> LLMResponse: | ||
| self.call_count += 1 | ||
| return LLMResponse( | ||
| role="assistant", | ||
| completion_text="调用工具", | ||
| tools_call_name=["test_tool", None], | ||
| tools_call_args=[{"query": "actual"}, {"query": "bad"}], | ||
| tools_call_ids=["call_1", "call_bad"], | ||
| usage=TokenUsage(input_other=10, output=5), | ||
| ) | ||
|
|
||
| provider = MalformedToolNameProvider() | ||
| tool = FunctionTool( | ||
| name="test_tool", | ||
| description="测试", | ||
| parameters={"type": "object", "properties": {"query": {"type": "string"}}}, | ||
| handler=AsyncMock(), | ||
| ) | ||
| req = ProviderRequest( | ||
| prompt="调用工具", | ||
| func_tool=ToolSet(tools=[tool]), | ||
| contexts=[], | ||
| ) | ||
| runner = ToolLoopAgentRunner() | ||
|
|
||
| await runner.reset( | ||
| provider=provider, | ||
| request=req, | ||
| run_context=ContextWrapper(context=None), | ||
| tool_executor=cast(Any, MockToolExecutor()), | ||
| agent_hooks=MockHooks(), | ||
| ) | ||
|
|
||
| async for _ in runner.step(): | ||
| pass | ||
|
|
||
| tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] | ||
| assert len(tool_messages) == 1 | ||
| assert tool_messages[0].tool_call_id == "call_1" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_loop_sanitizes_malformed_tool_ids_args_and_extra_content(): | ||
| class MalformedToolFieldsProvider(MockProvider): | ||
| async def text_chat(self, **kwargs) -> LLMResponse: | ||
| self.call_count += 1 | ||
| response = LLMResponse( | ||
| role="assistant", | ||
| completion_text="调用工具", | ||
| tools_call_name=["test_tool", "test_tool"], | ||
| tools_call_args=[{"query": "actual"}, "bad-args"], | ||
| tools_call_ids=["call_1", None], | ||
| usage=TokenUsage(input_other=10, output=5), | ||
| ) | ||
| response.tools_call_extra_content = None | ||
| return response | ||
|
|
||
| class CapturingHooks(MockHooks): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.tool_args_list = [] | ||
|
|
||
| async def on_tool_start(self, run_context, tool, tool_args): | ||
| await super().on_tool_start(run_context, tool, tool_args) | ||
| self.tool_args_list.append(tool_args) | ||
|
|
||
| provider = MalformedToolFieldsProvider() | ||
| hooks = CapturingHooks() | ||
| tool = FunctionTool( | ||
| name="test_tool", | ||
| description="测试", | ||
| parameters={"type": "object", "properties": {"query": {"type": "string"}}}, | ||
| handler=AsyncMock(), | ||
| ) | ||
| req = ProviderRequest( | ||
| prompt="调用工具", | ||
| func_tool=ToolSet(tools=[tool]), | ||
| contexts=[], | ||
| ) | ||
| runner = ToolLoopAgentRunner() | ||
|
|
||
| await runner.reset( | ||
| provider=provider, | ||
| request=req, | ||
| run_context=ContextWrapper(context=None), | ||
| tool_executor=cast(Any, MockToolExecutor()), | ||
| agent_hooks=hooks, | ||
| ) | ||
|
|
||
| async for _ in runner.step(): | ||
| pass | ||
|
|
||
| tool_messages = [msg for msg in runner.run_context.messages if msg.role == "tool"] | ||
| assert len(tool_messages) == 2 | ||
| tool_call_ids = [msg.tool_call_id for msg in tool_messages] | ||
| assert tool_call_ids[0] == "call_1" | ||
| assert tool_call_ids[1] is not None | ||
| assert tool_call_ids[1].startswith("call_") | ||
| assert hooks.tool_args_list == [{"query": "actual"}, {}] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_follow_up_accepted_when_active_and_not_stopping( | ||
| runner, mock_provider, provider_request, mock_tool_executor, mock_hooks | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To prevent potential
TypeErrorexceptions, we should defensively handle cases wheretools_call_args,tools_call_ids, ortools_call_extra_contentareNoneat runtime. This can happen with certain custom or mock providers that do not fully populate these fields.