Skip to content

Commit ec0df6c

Browse files
declan-scaleclaude
andcommitted
feat(openai-temporal): render hosted/server-side tool calls in TemporalStreamingModel
web_search, file_search, code_interpreter, image generation, and server-side mcp calls run inside the Responses API: they never become function_call items and the SDK RunHooks never fire, so the streaming model dropped them entirely (no UI event). Surface each completed hosted-tool output item as a ToolRequestContent + ToolResponseContent pair, mirroring how function_call tools render. This upstreams the hosted-tool half of the golden agent's vendored RichToolStreamingModel so that agent can stop vendoring. The reasoning double-render fix it also carried is already handled upstream (a533598), so only the hosted-tool handling is ported here. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 0fadfd7 commit ec0df6c

2 files changed

Lines changed: 216 additions & 1 deletion

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from agentex.lib.core.tracing.tracer import AsyncTracer
6767
from agentex.types.task_message_delta import TextDelta, ToolRequestDelta, ReasoningContentDelta, ReasoningSummaryDelta
6868
from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
69-
from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent
69+
from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent, ToolResponseContent
7070
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
7171
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
7272
streaming_task_id,
@@ -123,6 +123,90 @@ def _serialize_item(item: Any) -> dict[str, Any]:
123123
return item_dict
124124

125125

126+
# Responses-API output items for server-side / hosted tools. These execute inside
127+
# the Responses API, so they never become function_call items AND the SDK's
128+
# RunHooks (on_tool_start/on_tool_end) never fire for them. The streaming loop
129+
# must surface them explicitly, as a tool request + response pair, when the item
130+
# completes (by then it carries the full query/result).
131+
_HOSTED_TOOL_TYPES = frozenset(
132+
{
133+
"web_search_call",
134+
"file_search_call",
135+
"code_interpreter_call",
136+
"image_generation_call",
137+
"mcp_call",
138+
"computer_call",
139+
"local_shell_call",
140+
}
141+
)
142+
143+
# Cap on the rendered hosted-tool result string (UI / trace readability).
144+
_HOSTED_TOOL_RESULT_CAP = 2000
145+
146+
147+
def _coerce_args(raw: Any) -> dict[str, Any]:
148+
"""Best-effort coerce a hosted-tool's arguments to a dict for the UI."""
149+
if raw is None:
150+
return {}
151+
if isinstance(raw, dict):
152+
return raw
153+
if isinstance(raw, str):
154+
try:
155+
parsed = json.loads(raw)
156+
return parsed if isinstance(parsed, dict) else {"value": parsed}
157+
except (json.JSONDecodeError, ValueError):
158+
return {"raw": raw}
159+
serialized = _serialize_item(raw)
160+
return serialized if isinstance(serialized, dict) else {"value": str(raw)}
161+
162+
163+
def _hosted_tool_request(item: Any) -> tuple[str, str, dict[str, Any]]:
164+
"""Extract (call_id, display_name, arguments) from a hosted-tool item."""
165+
itype = getattr(item, "type", "") or ""
166+
call_id = (
167+
getattr(item, "id", "")
168+
or getattr(item, "call_id", "")
169+
or f"hosted_{uuid.uuid4().hex[:8]}"
170+
)
171+
name = itype[:-5] if itype.endswith("_call") else itype # web_search_call -> web_search
172+
args: dict[str, Any] = {}
173+
if itype == "web_search_call":
174+
action = getattr(item, "action", None)
175+
if action is not None:
176+
args = _coerce_args(action)
177+
elif itype == "file_search_call":
178+
args = {"queries": list(getattr(item, "queries", []) or [])}
179+
elif itype == "code_interpreter_call":
180+
args = {"code": getattr(item, "code", "") or ""}
181+
elif itype == "mcp_call":
182+
mcp_name = getattr(item, "name", None) or "mcp"
183+
server = getattr(item, "server_label", None)
184+
name = f"{server}.{mcp_name}" if server else mcp_name
185+
args = _coerce_args(getattr(item, "arguments", None))
186+
return call_id, name, args
187+
188+
189+
def _hosted_tool_result(item: Any) -> str:
190+
"""Extract a short result string from a completed hosted-tool item."""
191+
itype = getattr(item, "type", "") or ""
192+
if itype == "mcp_call":
193+
err = getattr(item, "error", None)
194+
if err:
195+
return f"error: {err}"
196+
out = getattr(item, "output", None)
197+
if out:
198+
return str(out)
199+
elif itype == "code_interpreter_call":
200+
outputs = getattr(item, "outputs", None)
201+
if outputs:
202+
return json.dumps([_serialize_item(o) for o in outputs])[:_HOSTED_TOOL_RESULT_CAP]
203+
elif itype == "file_search_call":
204+
results = getattr(item, "results", None)
205+
if results:
206+
return json.dumps([_serialize_item(r) for r in results])[:_HOSTED_TOOL_RESULT_CAP]
207+
return str(getattr(item, "status", "completed") or "completed")
208+
209+
126210
class TemporalStreamingModel(Model):
127211
"""Custom model implementation with streaming support."""
128212

@@ -481,6 +565,31 @@ def _convert_tool_choice(self, tool_choice: Any) -> Any:
481565
# Pass through as-is for other types
482566
return tool_choice
483567

568+
async def _post_tool_message(self, task_id: str, content: Any) -> None:
569+
"""Post a one-shot tool request/response message (no deltas).
570+
571+
Used for hosted/server-side tool calls (web_search, file_search,
572+
code_interpreter, image generation, server-side mcp, ...) that execute
573+
inside the Responses API and so never produce function_call items or fire
574+
RunHooks. Each completed hosted tool is surfaced as a ToolRequestContent +
575+
ToolResponseContent pair. Posting full (no deltas) means the coalescing
576+
path that the streamed reasoning/text contexts use does not apply here.
577+
"""
578+
try:
579+
async with adk.streaming.streaming_task_message_context(
580+
task_id=task_id,
581+
initial_content=content,
582+
) as ctx:
583+
await ctx.stream_update(
584+
StreamTaskMessageFull(
585+
parent_task_message=ctx.task_message,
586+
content=content,
587+
type="full",
588+
)
589+
)
590+
except Exception as e: # noqa: BLE001 - UI surfacing must never break a turn
591+
logger.warning(f"[TemporalStreamingModel] failed to post hosted-tool message: {e}")
592+
484593
@override
485594
async def get_response(
486595
self,
@@ -942,6 +1051,33 @@ async def get_response(
9421051
finally:
9431052
call_data['context'] = None
9441053

1054+
elif item and getattr(item, 'type', None) in _HOSTED_TOOL_TYPES:
1055+
# Hosted / server-side tool call (web_search, file_search,
1056+
# code_interpreter, image generation, server-side mcp, ...).
1057+
# These run inside the Responses API: no function_call item
1058+
# and no RunHooks fire, so surface the completed call as a
1059+
# tool request + response pair (it carries the full
1060+
# query/result by the time it's done).
1061+
call_id, name, args = _hosted_tool_request(item)
1062+
await self._post_tool_message(
1063+
task_id,
1064+
ToolRequestContent(
1065+
author="agent",
1066+
tool_call_id=call_id,
1067+
name=name,
1068+
arguments=args,
1069+
),
1070+
)
1071+
await self._post_tool_message(
1072+
task_id,
1073+
ToolResponseContent(
1074+
author="agent",
1075+
tool_call_id=call_id,
1076+
name=name,
1077+
content={"result": _hosted_tool_result(item)[:_HOSTED_TOOL_RESULT_CAP]},
1078+
),
1079+
)
1080+
9451081
elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
9461082
# New reasoning part/summary started - reset accumulator
9471083
part = getattr(event, 'part', None)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Unit tests for hosted/server-side tool rendering helpers.
2+
3+
These cover the pure extraction helpers used by TemporalStreamingModel to surface
4+
Responses-API hosted tools (web_search, file_search, code_interpreter, mcp, ...)
5+
as ToolRequest/ToolResponse pairs. They never become function_call items, so the
6+
streaming loop must render them explicitly.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from types import SimpleNamespace
12+
13+
from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model import (
14+
_HOSTED_TOOL_TYPES,
15+
_coerce_args,
16+
_hosted_tool_result,
17+
_hosted_tool_request,
18+
)
19+
20+
21+
def test_hosted_tool_types_membership():
22+
for t in ("web_search_call", "file_search_call", "code_interpreter_call",
23+
"image_generation_call", "mcp_call"):
24+
assert t in _HOSTED_TOOL_TYPES
25+
assert "function_call" not in _HOSTED_TOOL_TYPES
26+
27+
28+
def test_coerce_args_variants():
29+
assert _coerce_args(None) == {}
30+
assert _coerce_args({"a": 1}) == {"a": 1}
31+
assert _coerce_args('{"a": 1}') == {"a": 1}
32+
assert _coerce_args("[1, 2]") == {"value": [1, 2]}
33+
assert _coerce_args("not json") == {"raw": "not json"}
34+
35+
36+
def test_hosted_tool_request_web_search():
37+
item = SimpleNamespace(type="web_search_call", id="ws_1",
38+
action={"query": "agentex"})
39+
call_id, name, args = _hosted_tool_request(item)
40+
assert call_id == "ws_1"
41+
assert name == "web_search" # "_call" stripped
42+
assert args == {"query": "agentex"}
43+
44+
45+
def test_hosted_tool_request_mcp_uses_server_label():
46+
item = SimpleNamespace(type="mcp_call", id="m_1", name="search",
47+
server_label="linear", arguments='{"q": "x"}')
48+
call_id, name, args = _hosted_tool_request(item)
49+
assert call_id == "m_1"
50+
assert name == "linear.search"
51+
assert args == {"q": "x"}
52+
53+
54+
def test_hosted_tool_request_file_search_queries():
55+
item = SimpleNamespace(type="file_search_call", id="fs_1",
56+
queries=["q1", "q2"])
57+
_, name, args = _hosted_tool_request(item)
58+
assert name == "file_search"
59+
assert args == {"queries": ["q1", "q2"]}
60+
61+
62+
def test_hosted_tool_request_falls_back_to_generated_id():
63+
item = SimpleNamespace(type="code_interpreter_call", code="print(1)")
64+
call_id, name, args = _hosted_tool_request(item)
65+
assert call_id.startswith("hosted_")
66+
assert name == "code_interpreter"
67+
assert args == {"code": "print(1)"}
68+
69+
70+
def test_hosted_tool_result_mcp_error_and_output():
71+
err_item = SimpleNamespace(type="mcp_call", error="boom")
72+
assert "boom" in _hosted_tool_result(err_item)
73+
ok_item = SimpleNamespace(type="mcp_call", error=None, output="done")
74+
assert _hosted_tool_result(ok_item) == "done"
75+
76+
77+
def test_hosted_tool_result_falls_back_to_status():
78+
item = SimpleNamespace(type="web_search_call", status="completed")
79+
assert _hosted_tool_result(item) == "completed"

0 commit comments

Comments
 (0)