Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
import os
import tempfile
import threading
from collections.abc import AsyncIterable, AsyncIterator, Generator, Sequence
from contextlib import suppress
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
from dataclasses import asdict, is_dataclass
from pathlib import Path
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
from typing import Protocol, cast

from agent_framework import (
Expand Down Expand Up @@ -465,14 +464,12 @@ async def _handle_inner_agent(
# Run the agent in non-streaming mode
response = await self._agent.run(stream=False, **run_kwargs) # type: ignore[reportUnknownMemberType]

for message in response.messages:
for content in message.contents:
async for item in _to_outputs(
response_event_stream,
content,
approval_storage=self._approval_storage,
):
yield item
async for item in _to_outputs_for_messages(
response_event_stream,
response.messages,
approval_storage=self._approval_storage,
):
yield item
yield response_event_stream.emit_completed()
else:
if tracker is None: # pragma: no cover - defensive, set above
Expand Down Expand Up @@ -613,10 +610,8 @@ async def _handle_inner_workflow(
checkpoint_storage=write_storage,
)

for message in response.messages:
for content in message.contents:
async for item in _to_outputs(response_event_stream, content):
yield item
Comment thread
Hameedkunkanoor marked this conversation as resolved.
async for item in _to_outputs_for_messages(response_event_stream, response.messages):
yield item

await self._delete_not_latest_checkpoints(write_storage, self._agent.workflow.name)
yield response_event_stream.emit_completed()
Expand Down Expand Up @@ -722,7 +717,7 @@ def handle(self, content: Content) -> Generator[ResponseStreamEvent]:
yield self._fc_builder.emit_arguments_delta(args_str)

elif content.type == "mcp_server_tool_call" and content.tool_name:
key = f"{content.server_name or 'default'}::{content.tool_name}"
key = content.call_id or f"{content.server_name or 'default'}::{content.tool_name}"
if self._active_type != "mcp_server_tool_call" or self._active_id != key:
yield from self._close()
yield from self._open_mcp_call(content)
Expand All @@ -731,6 +726,24 @@ def handle(self, content: Content) -> Generator[ResponseStreamEvent]:
if self._mcp_builder is not None:
yield self._mcp_builder.emit_arguments_delta(args_str)

elif (
content.type == "mcp_server_tool_result"
and self._active_type == "mcp_server_tool_call"
and self._mcp_builder is not None
and content.call_id is not None
and content.call_id == self._mcp_builder.item_id
):
accumulated = "".join(self._accumulated)
yield self._mcp_builder.emit_arguments_done(accumulated)
yield self._mcp_builder.emit_completed()
yield self._mcp_builder.emit_done(output=_stringify_mcp_output(content.output))
self._mcp_builder = None
self._active_type = None
self._active_id = None
self._accumulated.clear()
self.needs_async = False
return

else:
yield from self._close()
self.needs_async = True
Expand Down Expand Up @@ -770,9 +783,10 @@ def _open_mcp_call(self, content: Content) -> Generator[ResponseStreamEvent]:
self._mcp_builder = self._stream.add_output_item_mcp_call(
server_label=content.server_name or "default",
name=content.tool_name or "",
item_id=content.call_id,
)
self._active_type = "mcp_server_tool_call"
self._active_id = f"{content.server_name or 'default'}::{content.tool_name}"
self._active_id = content.call_id or f"{content.server_name or 'default'}::{content.tool_name}"
yield self._mcp_builder.emit_added()

def _close(self) -> Generator[ResponseStreamEvent]:
Expand Down Expand Up @@ -920,16 +934,19 @@ async def _item_to_message(item: Item, *, approval_storage: ApprovalStorage | No

if item.type == "mcp_call":
mcp = cast(ItemMcpToolCall, item)
contents = [
Content.from_mcp_server_tool_call(
mcp.id,
mcp.name,
server_name=mcp.server_label,
arguments=mcp.arguments,
)
]
if getattr(mcp, "output", None) is not None:
contents.append(Content.from_mcp_server_tool_result(call_id=mcp.id, output=mcp.output))
return Message(
role="assistant",
contents=[
Content.from_mcp_server_tool_call(
mcp.id,
mcp.name,
server_name=mcp.server_label,
arguments=mcp.arguments,
)
],
contents=contents,
)

if item.type == "mcp_approval_request":
Expand Down Expand Up @@ -1190,16 +1207,19 @@ async def _output_item_to_message(item: OutputItem, *, approval_storage: Approva

if item.type == "mcp_call":
mcp = cast(OutputItemMcpToolCall, item)
contents = [
Content.from_mcp_server_tool_call(
mcp.id,
mcp.name,
server_name=mcp.server_label,
arguments=mcp.arguments,
)
]
if getattr(mcp, "output", None) is not None:
contents.append(Content.from_mcp_server_tool_result(call_id=mcp.id, output=mcp.output))
return Message(
role="assistant",
contents=[
Content.from_mcp_server_tool_call(
mcp.id,
mcp.name,
server_name=mcp.server_label,
arguments=mcp.arguments,
)
],
contents=contents,
)

if item.type == "mcp_approval_request":
Expand Down Expand Up @@ -1576,6 +1596,7 @@ async def _to_outputs(
mcp_call = stream.add_output_item_mcp_call(
server_label=content.server_name or "default",
name=content.tool_name or "",
item_id=content.call_id,
)
yield mcp_call.emit_added()
async for event in mcp_call.aarguments(_arguments_to_str(content.arguments)):
Expand Down Expand Up @@ -1650,4 +1671,91 @@ async def _to_outputs(
logger.warning(f"Content type '{content.type}' is not supported yet. This is usually safe to ignore.")


def _stringify_mcp_output(output: Any) -> str:
    """Convert a hosted MCP output payload into the string carried by ``mcp_call.output``.

    Conversion rules, applied in order:

    - ``None`` becomes the empty string.
    - ``str`` is returned unchanged.
    - ``bytes`` / ``bytearray`` are decoded as UTF-8; undecodable bytes are replaced.
    - A mapping yields its ``"text"`` value when that value is a string; otherwise
      the mapping is serialized as a JSON object, with values JSON cannot encode
      falling back to ``str()``.
    - Any other sequence is converted entry by entry and the pieces are concatenated
      with no separator; ``Content`` entries of type ``"text"`` contribute their text.
    - Anything else is passed through ``str()``.

    Args:
        output: The raw output payload of an MCP tool result.

    Returns:
        The string form of ``output``.
    """
    if output is None:
        return ""
    if isinstance(output, str):
        return output
    if isinstance(output, (bytes, bytearray)):
        # str(b"...") would leak the Python repr ("b'...'") into the response payload.
        return output.decode("utf-8", errors="replace")
    if isinstance(output, Mapping):
        mapping = cast(Mapping[Any, Any], output)
        text = mapping.get("text")
        if isinstance(text, str):
            return text
        # json.dumps only knows how to serialize real dicts. Any other Mapping
        # implementation would be handed to ``default=str`` and come out as a
        # JSON-quoted repr, so copy into a dict first.
        return json.dumps(dict(mapping), default=str)
    if isinstance(output, Sequence):
        # str / bytes / bytearray were handled above, so this is a real container.
        parts: list[str] = []
        for entry in cast(Sequence[object], output):
            if isinstance(entry, Content) and entry.type == "text":
                parts.append(entry.text or "")
            else:
                parts.append(_stringify_mcp_output(entry))
        return "".join(parts)
    return str(output)


def _emit_completed_mcp_call(
    stream: ResponseEventStream,
    call_content: Content,
    *,
    arguments: str,
    output: str,
) -> Generator[ResponseStreamEvent]:
    """Yield the event sequence for one MCP call whose result is already known.

    An MCP call output item is added to ``stream`` using the call's server name
    (``"default"`` when unset), tool name (``""`` when unset) and call id. The
    events are then yielded in this order: added, arguments done, completed,
    done (carrying ``output``).

    Args:
        stream: The response event stream the output item is added to.
        call_content: The MCP tool-call content describing the call.
        arguments: The call arguments, already rendered as a string.
        output: The call output, already rendered as a string.

    Yields:
        The events produced by the MCP call builder, in the order listed above.
    """
    label = call_content.server_name or "default"
    tool = call_content.tool_name or ""
    builder = stream.add_output_item_mcp_call(
        server_label=label,
        name=tool,
        item_id=call_content.call_id,
    )

    # Each emit_* call is made only when the consumer asks for the next event.
    yield builder.emit_added()
    yield builder.emit_arguments_done(arguments)
    yield builder.emit_completed()
    yield builder.emit_done(output=output)


async def _to_outputs_for_messages(
    stream: ResponseEventStream,
    messages: Sequence[Message],
    *,
    approval_storage: ApprovalStorage | None = None,
) -> AsyncIterator[ResponseStreamEvent]:
    """Turn messages into output events, merging hosted-MCP call/result pairs.

    Contents are visited once, in message order and then content order. An
    ``mcp_server_tool_call`` that carries a call id is held back for one step:

    - if the very next content is an ``mcp_server_tool_result`` with the same
      call id, the pair is emitted as a single completed ``mcp_call`` item;
    - otherwise the held call is emitted on its own through ``_to_outputs``
      before the next content is processed.

    Every other content goes straight through ``_to_outputs``. A call still held
    when the input is exhausted is emitted on its own at the end.

    Args:
        stream: The response event stream that output items are added to.
        messages: The messages whose contents are converted.
        approval_storage: Forwarded unchanged to ``_to_outputs``.

    Yields:
        The response stream events for each content, in input order.
    """
    held_call: Content | None = None

    # Flattened view of every content; evaluated lazily, in the original order.
    for part in (item for msg in messages for item in msg.contents):
        if held_call is not None:
            is_matching_result = part.type == "mcp_server_tool_result" and part.call_id == held_call.call_id
            if is_matching_result:
                call_arguments = _arguments_to_str(held_call.arguments)
                call_output = _stringify_mcp_output(part.output)
                for evt in _emit_completed_mcp_call(
                    stream,
                    held_call,
                    arguments=call_arguments,
                    output=call_output,
                ):
                    yield evt
                held_call = None
                continue

            # The next content is not the matching result: emit the held call by itself.
            async for evt in _to_outputs(stream, held_call, approval_storage=approval_storage):
                yield evt
            held_call = None

        if part.type == "mcp_server_tool_call" and part.call_id:
            # Hold the call until the following content shows whether a result comes with it.
            held_call = part
        else:
            async for evt in _to_outputs(stream, part, approval_storage=approval_storage):
                yield evt

    if held_call is not None:
        # Input ended while a call was still held.
        async for evt in _to_outputs(stream, held_call, approval_storage=approval_storage):
            yield evt


# endregion
2 changes: 1 addition & 1 deletion python/packages/foundry_hosting/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
dependencies = [
"agent-framework-core>=1.6.0,<2",
"azure-ai-agentserver-core>=2.0.0b3,<3",
"azure-ai-agentserver-responses>=1.0.0b5,<2",
"azure-ai-agentserver-responses>=1.0.0b7,<2",
"azure-ai-agentserver-invocations>=1.0.0b3,<2",
]

Expand Down
Loading