Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
import tempfile
import threading
from collections.abc import AsyncIterable, AsyncIterator, Generator, Sequence
from contextlib import suppress
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
from dataclasses import asdict, is_dataclass
from pathlib import Path
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
from typing import Protocol, cast

from agent_framework import (
Expand Down Expand Up @@ -73,6 +72,7 @@
MessageContentOutputTextContent,
MessageContentReasoningTextContent,
MessageContentRefusalContent,
MessageRole,
OAuthConsentRequestOutputItem,
OutputItem,
OutputItemApplyPatchToolCall,
Expand Down Expand Up @@ -117,6 +117,8 @@

logger = logging.getLogger(__name__)

# Allowlist key identifying Azure's ``MessageRole`` enum, formatted as
# "<module>:<qualname>". It is derived from the class's own ``__module__`` and
# ``__qualname__`` rather than hard-coded, so the key follows wherever the SDK
# actually defines the enum. Passed as ``allowed_checkpoint_types`` when building
# the per-context ``FileCheckpointStorage``.
_AZURE_RESPONSES_MESSAGE_ROLE_TYPE = f"{MessageRole.__module__}:{MessageRole.__qualname__}"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth pinning the resolved string in a test? At runtime this is `azure.ai.agentserver.responses.models._generated.sdk.models.models._enums:MessageRole`. If the Azure SDK ever re-homes the enum, `__module__` would silently change and the allowlist key would still match the new pickle GLOBAL, so this is robust today. Still, a one-line `assert _AZURE_RESPONSES_MESSAGE_ROLE_TYPE == "..."` would catch the inverse case (the re-export path stays the same, but the canonical module moves and the pickled `module:qualname` shifts).



# region Approval Storage
class ApprovalStorage(Protocol):
Expand Down Expand Up @@ -250,7 +252,12 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin
storage_path = (root_path / context_id).resolve()
if not storage_path.is_relative_to(root_path):
raise RuntimeError(f"Invalid checkpoint context id: {context_id!r}")
return FileCheckpointStorage(storage_path)
return FileCheckpointStorage(
storage_path,
# Keep this provider-specific allowlist narrow. Hosted workflow
# checkpoints can persist Azure's role enum inside Message objects.
allowed_checkpoint_types=[_AZURE_RESPONSES_MESSAGE_ROLE_TYPE],
)
Comment thread
karimbaidar marked this conversation as resolved.


# endregion Approval Storage
Expand Down
81 changes: 81 additions & 0 deletions python/packages/foundry_hosting/tests/test_responses.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The four new tests exercise `_checkpoint_storage_for_context` directly. Should we also drive one of these through `ResponsesHostServer` so the bug-motivating restore path (multi-turn `previous_response_id`) is actually covered? Every existing HTTP-level multi-turn test uses a `MagicMock(spec=RawAgent)` that ignores `checkpoint_storage`, so a regression on the host-side wiring (someone passing the wrong factory at `_responses.py:553-556` or `:651-658`) would not be caught.

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
Message,
RawAgent,
ResponseStream,
WorkflowCheckpoint,
WorkflowCheckpointException,
WorkflowMessage,
)
from azure.ai.agentserver.responses import InMemoryResponseProvider
from mcp import McpError
Expand Down Expand Up @@ -2712,6 +2715,23 @@ def _helper() -> Callable[[str, str], FileCheckpointStorage]:

return _checkpoint_storage_for_context

@staticmethod
def _checkpoint_with_azure_message_role() -> WorkflowCheckpoint:
    """Build a minimal checkpoint whose message role is Azure's ``MessageRole`` enum.

    The checkpoint carries a single ``WorkflowMessage`` under the ``"executor"``
    key. Its payload is a ``Message`` with role ``MessageRole.USER`` and the text
    ``"hello"``.

    Returns:
        A ``WorkflowCheckpoint`` named ``"wf"`` with graph signature hash ``"hash"``.
    """
    from azure.ai.agentserver.responses.models import MessageRole

    user_message = Message(role=MessageRole.USER, contents=[Content.from_text("hello")])
    pending_message = WorkflowMessage(data=user_message, source_id="source")
    return WorkflowCheckpoint(
        workflow_name="wf",
        graph_signature_hash="hash",
        messages={"executor": [pending_message]},
    )

def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None:
helper = self._helper()
root = tmp_path / "root"
Expand All @@ -2720,6 +2740,65 @@ def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None:
assert storage.storage_path.is_dir()
assert storage.storage_path.parent == root.resolve()

async def test_storage_allows_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None:
    """Storage built by the context helper round-trips a checkpoint holding Azure's role enum."""
    from azure.ai.agentserver.responses.models import MessageRole

    checkpoint_root = tmp_path / "root"
    checkpoint_root.mkdir()
    context_storage = self._helper()(str(checkpoint_root), "resp_abc123")
    saved = self._checkpoint_with_azure_message_role()

    await context_storage.save(saved)
    restored = await context_storage.load(saved.checkpoint_id)

    restored_message = restored.messages["executor"][0].data
    assert isinstance(restored_message, Message)
    # Checks the exact class of the restored role, not only equality with USER.
    assert type(restored_message.role) is MessageRole
    assert restored_message.role == MessageRole.USER
    assert restored_message.text == "hello"

async def test_plain_storage_blocks_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None:
    """A ``FileCheckpointStorage`` created without an allowlist saves the checkpoint but fails to load it."""
    plain_storage = FileCheckpointStorage(tmp_path / "plain")
    saved = self._checkpoint_with_azure_message_role()
    await plain_storage.save(saved)

    # The raised error is expected to mention the rejected type by name.
    with pytest.raises(WorkflowCheckpointException, match="MessageRole"):
        await plain_storage.load(saved.checkpoint_id)

async def test_get_latest_restores_azure_message_role(self, tmp_path: Any) -> None:
    """``get_latest`` on the context helper's storage returns the saved checkpoint with its Azure role."""
    from azure.ai.agentserver.responses.models import MessageRole

    checkpoint_root = tmp_path / "root"
    checkpoint_root.mkdir()
    context_storage = self._helper()(str(checkpoint_root), "resp_abc123")
    saved = self._checkpoint_with_azure_message_role()

    await context_storage.save(saved)
    newest = await context_storage.get_latest(workflow_name="wf")

    assert newest is not None
    assert newest.checkpoint_id == saved.checkpoint_id
    newest_message = newest.messages["executor"][0].data
    assert isinstance(newest_message, Message)
    # Checks the exact class of the restored role.
    assert type(newest_message.role) is MessageRole

async def test_get_latest_silently_skips_without_allowlist(
    self, tmp_path: Any, caplog: pytest.LogCaptureFixture
) -> None:
    """Without an allowlist, ``get_latest`` returns ``None`` and a log message names ``MessageRole``."""
    import logging

    plain_storage = FileCheckpointStorage(tmp_path / "plain")
    await plain_storage.save(self._checkpoint_with_azure_message_role())

    # Capture WARNING-and-above records from the "agent_framework" logger during the lookup.
    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        newest = await plain_storage.get_latest(workflow_name="wf")

    assert newest is None
    role_mentions = ["MessageRole" in logged for logged in caplog.messages]
    assert any(role_mentions)

@pytest.mark.parametrize(
"bad_id",
[
Expand Down Expand Up @@ -2923,6 +3002,8 @@ async def test_malicious_context_id_rejected_e2e(self, tmp_path: Any, context_fi
f"before={before} after={after}"
)
assert list(root.iterdir()) == [], f"Checkpoint directory created inside root for {context_field}={bad_id!r}"


# region Agent lifecycle (lazy entry & OAuth consent surfacing)


Expand Down