diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index 75dc0f053..38e3503d7 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -66,7 +66,7 @@ from agentex.lib.core.tracing.tracer import AsyncTracer from agentex.types.task_message_delta import TextDelta, ToolRequestDelta, ReasoningContentDelta, ReasoningSummaryDelta from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta -from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent +from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent, ToolResponseContent from agentex.lib.adk.utils._modules.client import create_async_agentex_client from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( streaming_task_id, @@ -123,6 +123,103 @@ def _serialize_item(item: Any) -> dict[str, Any]: return item_dict +# Responses-API output items for server-side / hosted tools. These execute inside +# the Responses API, so they never become function_call items AND the SDK's +# RunHooks (on_tool_start/on_tool_end) never fire for them. The streaming loop +# must surface them explicitly, as a tool request + response pair, when the item +# completes (by then it carries the full query/result). +_HOSTED_TOOL_TYPES = frozenset( + { + "web_search_call", + "file_search_call", + "code_interpreter_call", + "image_generation_call", + "mcp_call", + "computer_call", + "local_shell_call", + } +) + +# Cap on the rendered hosted-tool result string (UI / trace readability). +_HOSTED_TOOL_RESULT_CAP = 2000 + + +def _coerce_args(raw: Any) -> dict[str, Any]: + """Best-effort coerce a hosted-tool's arguments to a dict for the UI.""" + if raw is None: + return {} + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else {"value": parsed} + except (json.JSONDecodeError, ValueError): + return {"raw": raw} + serialized = _serialize_item(raw) + return serialized if isinstance(serialized, dict) else {"value": str(raw)} + + +def _hosted_tool_request(item: Any) -> tuple[str, str, dict[str, Any]]: + """Extract (call_id, display_name, arguments) from a hosted-tool item.""" + itype = getattr(item, "type", "") or "" + call_id = ( + getattr(item, "id", "") + or getattr(item, "call_id", "") + or f"hosted_{uuid.uuid4().hex[:8]}" + ) + name = itype[:-5] if itype.endswith("_call") else itype # web_search_call -> web_search + args: dict[str, Any] = {} + if itype == "web_search_call": + action = getattr(item, "action", None) + if action is not None: + args = _coerce_args(action) + elif itype == "file_search_call": + args = {"queries": list(getattr(item, "queries", []) or [])} + elif itype == "code_interpreter_call": + args = {"code": getattr(item, "code", "") or ""} + elif itype in ("computer_call", "local_shell_call"): + # Both carry an `action` object: a ComputerAction (click/scroll/type/...) + # or a LocalShellCallAction (command/env/cwd). Surface it as the args so + # the trace shows what the tool actually did, not just its status. + action = getattr(item, "action", None) + if action is not None: + args = _coerce_args(action) + elif itype == "mcp_call": + mcp_name = getattr(item, "name", None) or "mcp" + server = getattr(item, "server_label", None) + name = f"{server}.{mcp_name}" if server else mcp_name + args = _coerce_args(getattr(item, "arguments", None)) + return call_id, name, args + + +def _hosted_tool_result(item: Any) -> str: + """Extract a short result string from a completed hosted-tool item.""" + itype = getattr(item, "type", "") or "" + if itype == "mcp_call": + err = getattr(item, "error", None) + if err: + return f"error: {err}" + out = getattr(item, "output", None) + if out: + return str(out) + elif itype == "code_interpreter_call": + outputs = getattr(item, "outputs", None) + if outputs: + return json.dumps([_serialize_item(o) for o in outputs])[:_HOSTED_TOOL_RESULT_CAP] + elif itype == "file_search_call": + results = getattr(item, "results", None) + if results: + return json.dumps([_serialize_item(r) for r in results])[:_HOSTED_TOOL_RESULT_CAP] + elif itype == "image_generation_call": + # `result` is base64 image data; surface a compact reference instead of + # dumping the (large) payload into the trace. + result = getattr(item, "result", None) + if result: + return f"" + return str(getattr(item, "status", "completed") or "completed") + + class TemporalStreamingModel(Model): """Custom model implementation with streaming support.""" @@ -481,6 +578,31 @@ def _convert_tool_choice(self, tool_choice: Any) -> Any: # Pass through as-is for other types return tool_choice + async def _post_tool_message(self, task_id: str, content: Any) -> None: + """Post a one-shot tool request/response message (no deltas). + + Used for hosted/server-side tool calls (web_search, file_search, + code_interpreter, image generation, server-side mcp, ...) that execute + inside the Responses API and so never produce function_call items or fire + RunHooks. Each completed hosted tool is surfaced as a ToolRequestContent + + ToolResponseContent pair. Posting full (no deltas) means the coalescing + path that the streamed reasoning/text contexts use does not apply here. + """ + try: + async with adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=content, + ) as ctx: + await ctx.stream_update( + StreamTaskMessageFull( + parent_task_message=ctx.task_message, + content=content, + type="full", + ) + ) + except Exception as e: # noqa: BLE001 - UI surfacing must never break a turn + logger.warning(f"[TemporalStreamingModel] failed to post hosted-tool message: {e}") + @override async def get_response( self, @@ -942,6 +1064,33 @@ async def get_response( finally: call_data['context'] = None + elif item and getattr(item, 'type', None) in _HOSTED_TOOL_TYPES: + # Hosted / server-side tool call (web_search, file_search, + # code_interpreter, image generation, server-side mcp, ...). + # These run inside the Responses API: no function_call item + # and no RunHooks fire, so surface the completed call as a + # tool request + response pair (it carries the full + # query/result by the time it's done). + call_id, name, args = _hosted_tool_request(item) + await self._post_tool_message( + task_id, + ToolRequestContent( + author="agent", + tool_call_id=call_id, + name=name, + arguments=args, + ), + ) + await self._post_tool_message( + task_id, + ToolResponseContent( + author="agent", + tool_call_id=call_id, + name=name, + content={"result": _hosted_tool_result(item)[:_HOSTED_TOOL_RESULT_CAP]}, + ), + ) + elif isinstance(event, ResponseReasoningSummaryPartAddedEvent): # New reasoning part/summary started - reset accumulator part = getattr(event, 'part', None) diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_hosted_tools.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_hosted_tools.py new file mode 100644 index 000000000..066d6f2ed --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_hosted_tools.py @@ -0,0 +1,135 @@ +"""Unit tests for hosted/server-side tool rendering helpers. + +These cover the pure extraction helpers used by TemporalStreamingModel to surface +Responses-API hosted tools (web_search, file_search, code_interpreter, mcp, ...) +as ToolRequest/ToolResponse pairs. They never become function_call items, so the +streaming loop must render them explicitly. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +from openai.types.responses.response_output_item import ( + LocalShellCall, + ImageGenerationCall, + LocalShellCallAction, +) +from openai.types.responses.response_computer_tool_call import ActionClick, ResponseComputerToolCall +from openai.types.responses.response_function_web_search import ActionSearch, ResponseFunctionWebSearch + +from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model import ( + _HOSTED_TOOL_TYPES, + _coerce_args, + _hosted_tool_result, + _hosted_tool_request, +) + + +def test_hosted_tool_types_membership(): + for t in ("web_search_call", "file_search_call", "code_interpreter_call", + "image_generation_call", "mcp_call", "computer_call", "local_shell_call"): + assert t in _HOSTED_TOOL_TYPES + assert "function_call" not in _HOSTED_TOOL_TYPES + + +def test_coerce_args_variants(): + assert _coerce_args(None) == {} + assert _coerce_args({"a": 1}) == {"a": 1} + assert _coerce_args('{"a": 1}') == {"a": 1} + assert _coerce_args("[1, 2]") == {"value": [1, 2]} + assert _coerce_args("not json") == {"raw": "not json"} + + +def test_hosted_tool_request_web_search(): + # Use the real Responses-API type to prove `action` is a genuine SDK field + # (it is on ResponseFunctionWebSearch), not a hand-crafted stand-in. + item = ResponseFunctionWebSearch( + id="ws_1", + status="completed", + type="web_search_call", + action=ActionSearch(type="search", query="agentex"), + ) + call_id, name, args = _hosted_tool_request(item) + assert call_id == "ws_1" + assert name == "web_search" # "_call" stripped + assert args["query"] == "agentex" + assert args["type"] == "search" + + +def test_hosted_tool_request_computer_call(): + item = ResponseComputerToolCall( + id="cc_1", + call_id="ccall_1", + type="computer_call", + status="completed", + pending_safety_checks=[], + action=ActionClick(type="click", button="left", x=10, y=20), + ) + call_id, name, args = _hosted_tool_request(item) + assert call_id == "cc_1" + assert name == "computer" + assert args["type"] == "click" + assert args["button"] == "left" + assert args["x"] == 10 and args["y"] == 20 + + +def test_hosted_tool_request_local_shell_call(): + item = LocalShellCall( + id="ls_1", + call_id="lscall_1", + type="local_shell_call", + status="completed", + action=LocalShellCallAction(type="exec", command=["ls", "-la"], env={}), + ) + call_id, name, args = _hosted_tool_request(item) + assert call_id == "ls_1" + assert name == "local_shell" + assert args["command"] == ["ls", "-la"] + + +def test_hosted_tool_request_mcp_uses_server_label(): + item = SimpleNamespace(type="mcp_call", id="m_1", name="search", + server_label="linear", arguments='{"q": "x"}') + call_id, name, args = _hosted_tool_request(item) + assert call_id == "m_1" + assert name == "linear.search" + assert args == {"q": "x"} + + +def test_hosted_tool_request_file_search_queries(): + item = SimpleNamespace(type="file_search_call", id="fs_1", + queries=["q1", "q2"]) + _, name, args = _hosted_tool_request(item) + assert name == "file_search" + assert args == {"queries": ["q1", "q2"]} + + +def test_hosted_tool_request_falls_back_to_generated_id(): + item = SimpleNamespace(type="code_interpreter_call", code="print(1)") + call_id, name, args = _hosted_tool_request(item) + assert call_id.startswith("hosted_") + assert name == "code_interpreter" + assert args == {"code": "print(1)"} + + +def test_hosted_tool_result_mcp_error_and_output(): + err_item = SimpleNamespace(type="mcp_call", error="boom") + assert "boom" in _hosted_tool_result(err_item) + ok_item = SimpleNamespace(type="mcp_call", error=None, output="done") + assert _hosted_tool_result(ok_item) == "done" + + +def test_hosted_tool_result_image_generation(): + item = ImageGenerationCall( + id="ig_1", + type="image_generation_call", + status="completed", + result="QUJD", # 4 chars of (fake) base64 + ) + assert _hosted_tool_result(item) == "" + + +def test_hosted_tool_result_falls_back_to_status(): + item = SimpleNamespace(type="web_search_call", status="completed") + assert _hosted_tool_result(item) == "completed"