Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from agentex.lib.core.tracing.tracer import AsyncTracer
from agentex.types.task_message_delta import TextDelta, ToolRequestDelta, ReasoningContentDelta, ReasoningSummaryDelta
from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent
from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent, ToolResponseContent
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
streaming_task_id,
Expand Down Expand Up @@ -123,6 +123,103 @@ def _serialize_item(item: Any) -> dict[str, Any]:
return item_dict


# Responses-API output item types for server-side / hosted tools. These execute
# inside the Responses API, so they never become function_call items AND the
# SDK's RunHooks (on_tool_start/on_tool_end) never fire for them. The streaming
# loop must surface them explicitly, as a tool request + response pair, when the
# item completes (by then it carries the full query/result).
# Stream items are matched against this set by their `type` attribute.
_HOSTED_TOOL_TYPES = frozenset(
    {
        "web_search_call",
        "file_search_call",
        "code_interpreter_call",
        "image_generation_call",
        "mcp_call",
        "computer_call",
        "local_shell_call",
    }
)

# Cap on the rendered hosted-tool result string (UI / trace readability).
# Enforced by slicing the string, so the limit counts characters.
_HOSTED_TOOL_RESULT_CAP = 2000


def _coerce_args(raw: Any) -> dict[str, Any]:
    """Normalise a hosted-tool argument payload into a dict for the UI.

    Mapping applied, in order:

    * ``None`` -> ``{}``
    * a dict -> that same dict object (not copied)
    * a string -> the decoded JSON if it is a JSON object,
      ``{"value": decoded}`` for any other JSON value, or ``{"raw": raw}``
      when the string is not valid JSON
    * anything else -> the result of ``_serialize_item(raw)`` when that is a
      dict, otherwise ``{"value": str(raw)}``
    """
    if raw is None:
        return {}
    if isinstance(raw, dict):
        return raw

    if not isinstance(raw, str):
        as_dict = _serialize_item(raw)
        if isinstance(as_dict, dict):
            return as_dict
        return {"value": str(raw)}

    try:
        decoded = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {"raw": raw}

    if isinstance(decoded, dict):
        return decoded
    return {"value": decoded}


def _hosted_tool_request(item: Any) -> tuple[str, str, dict[str, Any]]:
    """Derive ``(call_id, display_name, arguments)`` from a hosted-tool item.

    * ``call_id`` is the item's ``id``, else its ``call_id``, else a generated
      ``hosted_<8 hex chars>`` placeholder.
    * ``display_name`` is the item type without a trailing ``_call``
      (``web_search_call`` -> ``web_search``); for ``mcp_call`` it is the tool
      name, prefixed with ``<server_label>.`` when a server label is present.
    * ``arguments`` depends on the item type and is ``{}`` when the type has
      no dedicated handling or carries no payload.
    """
    kind = getattr(item, "type", "") or ""

    call_id = getattr(item, "id", "") or getattr(item, "call_id", "")
    if not call_id:
        call_id = f"hosted_{uuid.uuid4().hex[:8]}"

    suffix = "_call"
    name = kind[: -len(suffix)] if kind.endswith(suffix) else kind
    args: dict[str, Any] = {}

    # web_search carries a search/open-page action; computer and local_shell
    # carry a ComputerAction (click/scroll/type/...) or a LocalShellCallAction
    # (command/env/cwd). In every case the action is surfaced as the arguments
    # so the trace shows what the tool actually did, not just its status.
    action_kinds = ("web_search_call", "computer_call", "local_shell_call")

    if kind in action_kinds:
        action = getattr(item, "action", None)
        if action is not None:
            args = _coerce_args(action)
    elif kind == "file_search_call":
        queries = getattr(item, "queries", []) or []
        args = {"queries": list(queries)}
    elif kind == "code_interpreter_call":
        code = getattr(item, "code", "") or ""
        args = {"code": code}
    elif kind == "mcp_call":
        tool = getattr(item, "name", None) or "mcp"
        server = getattr(item, "server_label", None)
        name = f"{server}.{tool}" if server else tool
        args = _coerce_args(getattr(item, "arguments", None))

    return call_id, name, args
Comment thread
declan-scale marked this conversation as resolved.
Comment thread
declan-scale marked this conversation as resolved.


def _hosted_tool_result(item: Any) -> str:
    """Render a compact result string for a completed hosted-tool item.

    * ``mcp_call``: ``"error: <error>"`` when an error is set, else the
      stringified ``output`` when present.
    * ``code_interpreter_call`` / ``file_search_call``: the JSON dump of the
      serialized ``outputs`` / ``results``, cut to ``_HOSTED_TOOL_RESULT_CAP``.
    * ``image_generation_call``: a size placeholder rather than the payload.
    * Anything else, or any of the above with no payload: the item's
      ``status``, defaulting to ``"completed"``.
    """
    kind = getattr(item, "type", "") or ""
    # Item types whose result is a list of entries to serialize, mapped to the
    # attribute holding that list.
    list_fields = {
        "code_interpreter_call": "outputs",
        "file_search_call": "results",
    }

    if kind == "mcp_call":
        failure = getattr(item, "error", None)
        if failure:
            return f"error: {failure}"
        output = getattr(item, "output", None)
        if output:
            return str(output)
    elif kind in list_fields:
        entries = getattr(item, list_fields[kind], None)
        if entries:
            rendered = json.dumps([_serialize_item(entry) for entry in entries])
            return rendered[:_HOSTED_TOOL_RESULT_CAP]
    elif kind == "image_generation_call":
        # `result` is base64 image data; surface a compact reference instead of
        # dumping the (large) payload into the trace.
        image = getattr(item, "result", None)
        if image:
            return f"<image generated: {len(image)} bytes>"

    status = getattr(item, "status", "completed")
    return str(status or "completed")


class TemporalStreamingModel(Model):
"""Custom model implementation with streaming support."""

Expand Down Expand Up @@ -481,6 +578,31 @@ def _convert_tool_choice(self, tool_choice: Any) -> Any:
# Pass through as-is for other types
return tool_choice

async def _post_tool_message(self, task_id: str, content: Any) -> None:
    """Publish one complete tool request/response message for a task.

    Used for hosted/server-side tool calls (web_search, file_search,
    code_interpreter, image generation, server-side mcp, ...) that execute
    inside the Responses API and so never produce function_call items or fire
    RunHooks. The streaming loop calls this twice per completed hosted tool:
    once with a ToolRequestContent and once with a ToolResponseContent.

    The message is opened with ``content`` as its initial content and then
    sent as a single ``StreamTaskMessageFull`` update; no incremental deltas
    are emitted.

    Args:
        task_id: Task the message is attached to.
        content: Message content to publish.

    Any exception is logged as a warning and suppressed: surfacing a tool call
    in the UI must never break the turn.
    """
    try:
        async with adk.streaming.streaming_task_message_context(
            task_id=task_id,
            initial_content=content,
        ) as message_ctx:
            full_update = StreamTaskMessageFull(
                parent_task_message=message_ctx.task_message,
                content=content,
                type="full",
            )
            await message_ctx.stream_update(full_update)
    except Exception as exc:  # noqa: BLE001 - UI surfacing must never break a turn
        logger.warning(f"[TemporalStreamingModel] failed to post hosted-tool message: {exc}")

@override
async def get_response(
self,
Expand Down Expand Up @@ -942,6 +1064,33 @@ async def get_response(
finally:
call_data['context'] = None

elif item and getattr(item, 'type', None) in _HOSTED_TOOL_TYPES:
# Hosted / server-side tool call (web_search, file_search,
# code_interpreter, image generation, server-side mcp, ...).
# These run inside the Responses API: no function_call item
# and no RunHooks fire, so surface the completed call as a
# tool request + response pair (it carries the full
# query/result by the time it's done).
call_id, name, args = _hosted_tool_request(item)
await self._post_tool_message(
task_id,
ToolRequestContent(
author="agent",
tool_call_id=call_id,
name=name,
arguments=args,
),
)
await self._post_tool_message(
task_id,
ToolResponseContent(
author="agent",
tool_call_id=call_id,
name=name,
content={"result": _hosted_tool_result(item)[:_HOSTED_TOOL_RESULT_CAP]},
),
)

elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
# New reasoning part/summary started - reset accumulator
part = getattr(event, 'part', None)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Unit tests for hosted/server-side tool rendering helpers.

These cover the pure extraction helpers used by TemporalStreamingModel to surface
Responses-API hosted tools (web_search, file_search, code_interpreter, mcp, ...)
as ToolRequest/ToolResponse pairs. They never become function_call items, so the
streaming loop must render them explicitly.
"""

from __future__ import annotations

from types import SimpleNamespace

from openai.types.responses.response_output_item import (
LocalShellCall,
ImageGenerationCall,
LocalShellCallAction,
)
from openai.types.responses.response_computer_tool_call import ActionClick, ResponseComputerToolCall
from openai.types.responses.response_function_web_search import ActionSearch, ResponseFunctionWebSearch

from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model import (
_HOSTED_TOOL_TYPES,
_coerce_args,
_hosted_tool_result,
_hosted_tool_request,
)


def test_hosted_tool_types_membership():
    """All hosted item types are registered; plain function calls are not."""
    expected = {
        "web_search_call",
        "file_search_call",
        "code_interpreter_call",
        "image_generation_call",
        "mcp_call",
        "computer_call",
        "local_shell_call",
    }
    missing = expected - _HOSTED_TOOL_TYPES
    assert not missing
    assert "function_call" not in _HOSTED_TOOL_TYPES


def test_coerce_args_variants():
    """None, dicts, JSON objects, other JSON values and non-JSON strings."""
    cases = [
        (None, {}),
        ({"a": 1}, {"a": 1}),
        ('{"a": 1}', {"a": 1}),
        ("[1, 2]", {"value": [1, 2]}),
        ("not json", {"raw": "not json"}),
    ]
    for raw, expected in cases:
        assert _coerce_args(raw) == expected


def test_hosted_tool_request_web_search():
    """A web search item yields its id, the stripped name and its action."""
    # Use the real Responses-API type to prove `action` is a genuine SDK field
    # (it is on ResponseFunctionWebSearch), not a hand-crafted stand-in.
    search_action = ActionSearch(type="search", query="agentex")
    web_search = ResponseFunctionWebSearch(
        id="ws_1",
        status="completed",
        type="web_search_call",
        action=search_action,
    )

    request = _hosted_tool_request(web_search)
    call_id, name, args = request

    assert call_id == "ws_1"
    assert name == "web_search"  # "_call" stripped
    assert args["query"] == "agentex"
    assert args["type"] == "search"


def test_hosted_tool_request_computer_call():
    """A computer call surfaces its click action as the arguments."""
    click = ActionClick(type="click", button="left", x=10, y=20)
    computer_call = ResponseComputerToolCall(
        id="cc_1",
        call_id="ccall_1",
        type="computer_call",
        status="completed",
        pending_safety_checks=[],
        action=click,
    )

    call_id, name, args = _hosted_tool_request(computer_call)

    # The item `id` wins over `call_id` when both are present.
    assert call_id == "cc_1"
    assert name == "computer"
    assert args["type"] == "click"
    assert args["button"] == "left"
    assert (args["x"], args["y"]) == (10, 20)


def test_hosted_tool_request_local_shell_call():
    """A local shell call surfaces the executed command as an argument."""
    command = ["ls", "-la"]
    shell_call = LocalShellCall(
        id="ls_1",
        call_id="lscall_1",
        type="local_shell_call",
        status="completed",
        action=LocalShellCallAction(type="exec", command=["ls", "-la"], env={}),
    )

    call_id, name, args = _hosted_tool_request(shell_call)

    assert call_id == "ls_1"
    assert name == "local_shell"
    assert args["command"] == command


def test_hosted_tool_request_mcp_uses_server_label():
    """An MCP call is named `<server_label>.<tool>` and its JSON args decoded."""
    mcp_item = SimpleNamespace(
        type="mcp_call",
        id="m_1",
        name="search",
        server_label="linear",
        arguments='{"q": "x"}',
    )

    request = _hosted_tool_request(mcp_item)

    assert request[0] == "m_1"
    assert request[1] == "linear.search"
    assert request[2] == {"q": "x"}


def test_hosted_tool_request_file_search_queries():
    """A file search item exposes its queries as the arguments."""
    queries = ["q1", "q2"]
    file_search = SimpleNamespace(type="file_search_call", id="fs_1", queries=["q1", "q2"])

    _call_id, name, args = _hosted_tool_request(file_search)

    assert name == "file_search"
    assert args == {"queries": queries}


def test_hosted_tool_request_falls_back_to_generated_id():
    """With neither `id` nor `call_id`, a `hosted_` id is generated."""
    anonymous = SimpleNamespace(type="code_interpreter_call", code="print(1)")

    request = _hosted_tool_request(anonymous)
    generated_id, name, args = request

    assert generated_id.startswith("hosted_")
    assert name == "code_interpreter"
    assert args == {"code": "print(1)"}


def test_hosted_tool_result_mcp_error_and_output():
    """An MCP error is reported; otherwise the output is returned verbatim."""
    failed = SimpleNamespace(type="mcp_call", error="boom")
    failed_text = _hosted_tool_result(failed)
    assert "boom" in failed_text

    succeeded = SimpleNamespace(type="mcp_call", error=None, output="done")
    succeeded_text = _hosted_tool_result(succeeded)
    assert succeeded_text == "done"


def test_hosted_tool_result_image_generation():
    """Image payloads are summarised by length instead of being dumped."""
    fake_base64 = "QUJD"  # 4 chars of (fake) base64
    image_call = ImageGenerationCall(
        id="ig_1",
        type="image_generation_call",
        status="completed",
        result=fake_base64,
    )

    rendered = _hosted_tool_result(image_call)

    assert rendered == "<image generated: 4 bytes>"


def test_hosted_tool_result_falls_back_to_status():
    """Types without a dedicated result payload report their status."""
    web_search = SimpleNamespace(type="web_search_call", status="completed")

    rendered = _hosted_tool_result(web_search)

    assert rendered == "completed"
Loading