From 89ef6c38f7719c1c3b9bd5313274912030dd95b4 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 25 May 2026 13:54:23 +0000 Subject: [PATCH 1/3] Add a BackgroundAgentsProvider for python --- .../packages/core/agent_framework/__init__.py | 12 + .../_harness/_background_agents.py | 496 ++++++++++++++++ .../core/test_harness_background_agents.py | 530 ++++++++++++++++++ 3 files changed, 1038 insertions(+) create mode 100644 python/packages/core/agent_framework/_harness/_background_agents.py create mode 100644 python/packages/core/tests/core/test_harness_background_agents.py diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 356051da3f..497afe37be 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -79,6 +79,13 @@ tool_calls_present, ) from ._feature_stage import ExperimentalFeature, ReleaseCandidateFeature +from ._harness._background_agents import ( + DEFAULT_BACKGROUND_AGENTS_RUNTIME_SOURCE_ID, + DEFAULT_BACKGROUND_AGENTS_SOURCE_ID, + BackgroundAgentsProvider, + BackgroundTaskInfo, + BackgroundTaskStatus, +) from ._harness._memory import ( DEFAULT_MEMORY_SOURCE_ID, MemoryContextProvider, @@ -297,6 +304,8 @@ "AGENT_FRAMEWORK_USER_AGENT", "APP_INFO", "COMPACTION_STATE_KEY", + "DEFAULT_BACKGROUND_AGENTS_RUNTIME_SOURCE_ID", + "DEFAULT_BACKGROUND_AGENTS_SOURCE_ID", "DEFAULT_MAX_ITERATIONS", "DEFAULT_MEMORY_SOURCE_ID", "DEFAULT_MODE_SOURCE_ID", @@ -332,6 +341,9 @@ "AgentSession", "AggregatingSkillsSource", "Annotation", + "BackgroundAgentsProvider", + "BackgroundTaskInfo", + "BackgroundTaskStatus", "BaseAgent", "BaseChatClient", "BaseEmbeddingClient", diff --git a/python/packages/core/agent_framework/_harness/_background_agents.py b/python/packages/core/agent_framework/_harness/_background_agents.py new file mode 100644 index 0000000000..d2f820fc06 --- /dev/null +++ 
b/python/packages/core/agent_framework/_harness/_background_agents.py @@ -0,0 +1,496 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""BackgroundAgentsProvider: enables an agent to delegate work to background sub-agents asynchronously. + +This module provides :class:`BackgroundAgentsProvider`, a context provider that allows +a parent agent to start background tasks on child agents, wait for their completion, +and retrieve results. Each background task runs in its own session concurrently. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from .._agents import SupportsAgentRun +from .._feature_stage import ExperimentalFeature, experimental +from .._sessions import AgentSession, ContextProvider, SessionContext +from .._tools import tool +from .._types import AgentResponse, Message + +DEFAULT_BACKGROUND_AGENTS_SOURCE_ID = "background_agents" +DEFAULT_BACKGROUND_AGENTS_RUNTIME_SOURCE_ID = "background_agents_runtime" + +DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS = """\ +## Background Agents + +You have access to background agents that can perform work on your behalf. + +- Use the `background_agents_*` tools to start tasks on background agents and check their results. +- Creating a background task does not block, and background tasks run concurrently. +- Important: Always wait for outstanding tasks to finish before you finish processing. +- Important: After retrieving results from a completed task, clear it with \ +background_agents_clear_completed_task to free memory, unless you plan to continue it with \ +background_agents_continue_task. 
+ +{background_agents}""" + + +class BackgroundTaskStatus(str, Enum): + """Status of a background task.""" + + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + LOST = "lost" + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +@dataclass +class BackgroundTaskInfo: + """Metadata for a single background task.""" + + id: int + agent_name: str + description: str + status: BackgroundTaskStatus = BackgroundTaskStatus.RUNNING + result_text: str | None = None + error_text: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Serialize for session state persistence.""" + data: dict[str, Any] = { + "id": self.id, + "agent_name": self.agent_name, + "description": self.description, + "status": self.status.value, + } + if self.result_text is not None: + data["result_text"] = self.result_text + if self.error_text is not None: + data["error_text"] = self.error_text + return data + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BackgroundTaskInfo: + """Deserialize from session state.""" + return cls( + id=data["id"], + agent_name=data["agent_name"], + description=data["description"], + status=BackgroundTaskStatus(data["status"]), + result_text=data.get("result_text"), + error_text=data.get("error_text"), + ) + + +@dataclass +class _RuntimeState: + """Non-serializable per-session runtime state for background tasks.""" + + in_flight_tasks: dict[int, asyncio.Task[AgentResponse[Any]]] = field(default_factory=dict) + background_sessions: dict[int, AgentSession] = field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Module-level helper functions (following ModeProvider pattern) +# --------------------------------------------------------------------------- + + +def _validate_and_build_agent_dict(agents: Sequence[SupportsAgentRun]) -> dict[str, SupportsAgentRun]: + """Validate agents and build a case-insensitive lookup dict. 
+ + Raises: + ValueError: If agents is empty, an agent has no name, or names are not unique. + """ + if not agents: + raise ValueError("At least one background agent must be provided.") + + agent_dict: dict[str, SupportsAgentRun] = {} + for agent in agents: + name = agent.name + if not name or not name.strip(): + raise ValueError("All background agents must have a non-empty name.") + key = name.lower() + if key in agent_dict: + raise ValueError( + f"Duplicate background agent name: '{name}'. Agent names must be unique (case-insensitive)." + ) + agent_dict[key] = agent + return agent_dict + + +def _build_agent_list_text(agents: dict[str, SupportsAgentRun]) -> str: + """Build text listing available background agents.""" + lines = ["Available background agents:"] + for agent in agents.values(): + line = f"- {agent.name}" + if agent.description: + line += f": {agent.description}" + lines.append(line) + return "\n".join(lines) + + +def _get_provider_state(session: AgentSession, *, source_id: str) -> dict[str, Any]: + """Load or initialize serializable provider state from session.""" + state = session.state.get(source_id) + if state is None: + state = {"next_task_id": 1, "tasks": []} + session.state[source_id] = state + return state # type: ignore[return-value] + + +def _save_provider_state(session: AgentSession, state: dict[str, Any], *, source_id: str) -> None: + """Persist serializable state to session.""" + session.state[source_id] = state + + +def _get_tasks(state: dict[str, Any]) -> list[BackgroundTaskInfo]: + """Parse task list from state dict.""" + return [BackgroundTaskInfo.from_dict(t) for t in state.get("tasks", [])] + + +def _save_tasks(state: dict[str, Any], tasks: list[BackgroundTaskInfo]) -> None: + """Serialize task list back to state dict.""" + state["tasks"] = [t.to_dict() for t in tasks] + + +def _finalize_task( + task_info: BackgroundTaskInfo, + completed_task: asyncio.Task[AgentResponse[Any]], + runtime: _RuntimeState, +) -> None: + """Extract 
results from a completed asyncio task and update task info.""" + # Check cancelled() first: Task.exception() raises CancelledError on a cancelled task. + if completed_task.cancelled(): + task_info.status = BackgroundTaskStatus.FAILED + task_info.error_text = "Task was canceled." + elif (exception := completed_task.exception()) is not None: + task_info.status = BackgroundTaskStatus.FAILED + task_info.error_text = str(exception) + else: + task_info.status = BackgroundTaskStatus.COMPLETED + task_info.result_text = completed_task.result().text + runtime.in_flight_tasks.pop(task_info.id, None) + + +def _refresh_task_state( + session: AgentSession, state: dict[str, Any], runtime: _RuntimeState, *, source_id: str +) -> list[BackgroundTaskInfo]: + """Refresh status of in-flight tasks and return updated task list.""" + tasks = _get_tasks(state) + changed = False + + for task_info in tasks: + if task_info.status != BackgroundTaskStatus.RUNNING: + continue + + in_flight = runtime.in_flight_tasks.get(task_info.id) + if in_flight is None: + task_info.status = BackgroundTaskStatus.LOST + changed = True + continue + + if in_flight.done(): + _finalize_task(task_info, in_flight, runtime) + changed = True + + if changed: + _save_tasks(state, tasks) + _save_provider_state(session, state, source_id=source_id) + + return tasks + + +# --------------------------------------------------------------------------- +# Provider class +# --------------------------------------------------------------------------- + + +@experimental(feature_id=ExperimentalFeature.HARNESS) +class BackgroundAgentsProvider(ContextProvider): + """Context provider that enables an agent to delegate work to background sub-agents. + + The ``BackgroundAgentsProvider`` allows a parent agent to start background tasks on child agents, + wait for their completion, and retrieve results. Each background task runs in its own session and + executes concurrently.
+ + This provider exposes the following tools to the agent: + + - ``background_agents_start_task`` — Start a background task on a named agent with text input. + - ``background_agents_wait_for_first_completion`` — Block until the first of the specified tasks completes. + - ``background_agents_get_task_results`` — Retrieve the text output of a completed background task. + - ``background_agents_get_all_tasks`` — List all background tasks with their IDs, statuses, and descriptions. + - ``background_agents_continue_task`` — Send follow-up input to a completed task's session to resume work. + - ``background_agents_clear_completed_task`` — Remove a completed task and release its session. + """ + + def __init__( + self, + agents: Sequence[SupportsAgentRun], + *, + source_id: str = DEFAULT_BACKGROUND_AGENTS_SOURCE_ID, + runtime_source_id: str | None = None, + instructions: str | None = None, + ) -> None: + """Initialize the background agents provider. + + Args: + agents: Collection of background agents available for delegation. + Each agent must have a non-empty, unique name (case-insensitive). + + Keyword Args: + source_id: Unique source ID for serializable task state in session. + runtime_source_id: Unique source ID for non-serializable runtime state + (in-flight asyncio tasks and background sessions). Defaults to + ``"{source_id}_runtime"``. + instructions: Optional instruction override. May include ``{background_agents}`` + placeholder which will be replaced with the agent listing. + + Raises: + ValueError: If agents is empty, an agent has no name, or names are not unique. + """ + super().__init__(source_id) + + self._agents = _validate_and_build_agent_dict(agents) + self._runtime_source_id = runtime_source_id or f"{source_id}_runtime" + + # Build instructions with agent listing. 
+ base_instructions = instructions if instructions is not None else DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS + agent_list_text = _build_agent_list_text(self._agents) + self._instructions = base_instructions.replace("{background_agents}", agent_list_text) + + # Per-session runtime state (non-serializable), keyed by the parent session's session_id. + self._runtime: dict[str, _RuntimeState] = {} + + @property + def runtime_source_id(self) -> str: + """The source ID used for non-serializable runtime state.""" + return self._runtime_source_id + + def _get_runtime(self, session: AgentSession) -> _RuntimeState: + """Get or create runtime state for a session.""" + session_id = session.session_id + if session_id not in self._runtime: + self._runtime[session_id] = _RuntimeState() + return self._runtime[session_id] + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Inject background agent tools and instructions before the model runs.""" + del agent, state + + provider_state = _get_provider_state(session, source_id=self.source_id) + runtime = self._get_runtime(session) + source_id = self.source_id + + @tool(name="background_agents_start_task", approval_mode="never_require") + def background_agents_start_task(agent_name: str, input: str, description: str) -> str: + """Start a background task on a named agent. Returns a confirmation with the task ID.""" + key = agent_name.lower() + if key not in self._agents: + available = ", ".join(a.name or "" for a in self._agents.values()) + return f"Error: No background agent found with name '{agent_name}'.
Available agents: {available}" + + bg_agent = self._agents[key] + task_id = provider_state.get("next_task_id", 1) + provider_state["next_task_id"] = task_id + 1 + + task_info = BackgroundTaskInfo( + id=task_id, + agent_name=agent_name, + description=description, + ) + tasks = _get_tasks(provider_state) + tasks.append(task_info) + _save_tasks(provider_state, tasks) + + # Create a dedicated session for this background task. + sub_session = bg_agent.create_session() + + # Start the task concurrently. + async_task = asyncio.create_task(bg_agent.run(input, session=sub_session)) + runtime.in_flight_tasks[task_id] = async_task + runtime.background_sessions[task_id] = sub_session + + _save_provider_state(session, provider_state, source_id=source_id) + return f"Background task {task_id} started on agent '{agent_name}'." + + @tool(name="background_agents_wait_for_first_completion", approval_mode="never_require") + async def background_agents_wait_for_first_completion(task_ids: list[int]) -> str: + """Block until the first of the specified background tasks completes. Returns the completed task's ID.""" + if not task_ids: + return "Error: No task IDs provided." + + # Collect in-flight tasks matching the requested IDs. + waitable: list[tuple[int, asyncio.Task[AgentResponse[Any]]]] = [] + for tid in task_ids: + in_flight = runtime.in_flight_tasks.get(tid) + if in_flight is not None: + waitable.append((tid, in_flight)) + + if not waitable: + # Refresh state to catch any that completed. + tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id) + already_complete = next( + (t for t in tasks if t.id in task_ids and t.status != BackgroundTaskStatus.RUNNING), None + ) + if already_complete is not None: + return ( + f"Task {already_complete.id} is not running; current status: {already_complete.status.value}." + ) + return "Error: None of the specified task IDs correspond to running tasks." + + # Wait for the first one to complete. 
+ done, _ = await asyncio.wait( + [t for _, t in waitable], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Find which ID completed. + completed_id: int | None = None + for tid, task in waitable: + if task in done: + completed_id = tid + break + + # Finalize the completed task. + tasks = _get_tasks(provider_state) + task_info = next((t for t in tasks if t.id == completed_id), None) + if task_info is not None and completed_id is not None: + completed_task = runtime.in_flight_tasks.get(completed_id) + if completed_task is not None: + _finalize_task(task_info, completed_task, runtime) + _save_tasks(provider_state, tasks) + _save_provider_state(session, provider_state, source_id=source_id) + + status_str = task_info.status.value if task_info else "Unknown" + return f"Task {completed_id} finished with status: {status_str}." + + @tool(name="background_agents_get_task_results", approval_mode="never_require") + def background_agents_get_task_results(task_id: int) -> str: + """Get the text output of a background task by its ID.""" + tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id) + task_info = next((t for t in tasks if t.id == task_id), None) + + if task_info is None: + return f"Error: No task found with ID {task_id}." + + if task_info.status == BackgroundTaskStatus.COMPLETED: + return task_info.result_text or "(no output)" + if task_info.status == BackgroundTaskStatus.FAILED: + return f"Task failed: {task_info.error_text or 'Unknown error'}" + if task_info.status == BackgroundTaskStatus.LOST: + return "Task state was lost (reference unavailable)." + if task_info.status == BackgroundTaskStatus.RUNNING: + return f"Task {task_id} is still running." + return f"Task {task_id} has status: {task_info.status.value}." 
+ + @tool(name="background_agents_get_all_tasks", approval_mode="never_require") + def background_agents_get_all_tasks() -> str: + """List all background tasks with their IDs, statuses, agent names, and descriptions.""" + tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id) + + if not tasks: + return "No tasks." + + lines = ["Tasks:"] + for t in tasks: + lines.append(f"- Task {t.id} [{t.status.value}] ({t.agent_name}): {t.description}") + return "\n".join(lines) + + @tool(name="background_agents_continue_task", approval_mode="never_require") + def background_agents_continue_task(task_id: int, text: str) -> str: + """Send follow-up input to a completed or failed task to resume its work.""" + tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id) + task_info = next((t for t in tasks if t.id == task_id), None) + + if task_info is None: + return f"Error: No task found with ID {task_id}." + + if task_info.status == BackgroundTaskStatus.LOST: + return ( + f"Error: Task {task_id} cannot be continued because its session was lost. Start a new task instead." + ) + + if task_info.status == BackgroundTaskStatus.RUNNING: + return f"Error: Task {task_id} is still running. Wait for it to complete before continuing." + + key = task_info.agent_name.lower() + if key not in self._agents: + return f"Error: Agent '{task_info.agent_name}' is no longer available." + + sub_session = runtime.background_sessions.get(task_id) + if sub_session is None: + return f"Error: Session for task {task_id} is no longer available." + + bg_agent = self._agents[key] + + # Reset task state and start a new run on the existing session. 
+ task_info.status = BackgroundTaskStatus.RUNNING + task_info.result_text = None + task_info.error_text = None + _save_tasks(provider_state, tasks) + + async_task = asyncio.create_task(bg_agent.run(text, session=sub_session)) + runtime.in_flight_tasks[task_id] = async_task + + _save_provider_state(session, provider_state, source_id=source_id) + return f"Task {task_id} continued with new input." + + @tool(name="background_agents_clear_completed_task", approval_mode="never_require") + def background_agents_clear_completed_task(task_id: int) -> str: + """Remove a completed or failed task and release its session to free memory.""" + tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id) + task_info = next((t for t in tasks if t.id == task_id), None) + + if task_info is None: + return f"Error: No task found with ID {task_id}." + + if task_info.status == BackgroundTaskStatus.RUNNING: + return f"Error: Task {task_id} is still running. Wait for it to complete before clearing." + + # Remove the task from state. + tasks = [t for t in tasks if t.id != task_id] + _save_tasks(provider_state, tasks) + + # Clean up runtime references. + runtime.in_flight_tasks.pop(task_id, None) + runtime.background_sessions.pop(task_id, None) + + _save_provider_state(session, provider_state, source_id=source_id) + return f"Task {task_id} cleared." + + # Inject instructions and current task status. + context.extend_instructions(self.source_id, [self._instructions]) + context.extend_tools( + self.source_id, + [ + background_agents_start_task, + background_agents_wait_for_first_completion, + background_agents_get_task_results, + background_agents_get_all_tasks, + background_agents_continue_task, + background_agents_clear_completed_task, + ], + ) + + # Include current task status as context message if there are tasks. 
+ tasks = _get_tasks(provider_state) + if tasks: + status_lines = ["### Current background tasks"] + for t in tasks: + status_lines.append(f"- Task {t.id} [{t.status.value}] ({t.agent_name}): {t.description}") + context.extend_messages( + self.source_id, + [Message(role="user", contents=["\n".join(status_lines)])], + ) diff --git a/python/packages/core/tests/core/test_harness_background_agents.py b/python/packages/core/tests/core/test_harness_background_agents.py new file mode 100644 index 0000000000..5e51180770 --- /dev/null +++ b/python/packages/core/tests/core/test_harness_background_agents.py @@ -0,0 +1,530 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from agent_framework import ( + AgentResponse, + AgentSession, + BackgroundAgentsProvider, + BackgroundTaskInfo, + BackgroundTaskStatus, + Message, +) +from agent_framework._sessions import SessionContext + +# --- Test Helpers --- + + +class _FakeAgent: + """Minimal agent stub for testing background agent delegation.""" + + def __init__( + self, + name: str, + description: str | None = None, + *, + response_text: str = "done", + delay: float = 0.0, + should_fail: bool = False, + ): + self.id = f"agent-{name}" + self.name = name + self.description = description + self._response_text = response_text + self._delay = delay + self._should_fail = should_fail + + def create_session(self, *, session_id: str | None = None) -> AgentSession: + return AgentSession(session_id=session_id) + + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + return AgentSession(service_session_id=service_session_id, session_id=session_id) + + async def run( + self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any + ) -> AgentResponse[Any]: + if self._delay > 0: + await asyncio.sleep(self._delay) + if self._should_fail: + raise RuntimeError("Agent execution 
failed") + return AgentResponse(messages=[Message(role="assistant", contents=[self._response_text])]) + + +def _make_provider(*agents: _FakeAgent) -> BackgroundAgentsProvider: + """Create a provider with given agents.""" + return BackgroundAgentsProvider(agents) + + +def _make_session() -> AgentSession: + """Create a session for testing.""" + return AgentSession() + + +async def _get_tools(provider: BackgroundAgentsProvider, session: AgentSession) -> dict[str, Any]: + """Run before_run and return tools by name.""" + context = SessionContext(input_messages=[]) + await provider.before_run(agent=None, session=session, context=context, state={}) + tools_by_name: dict[str, Any] = {} + for t in context.tools: + tools_by_name[t.name if hasattr(t, "name") else str(t)] = t + return tools_by_name + + +async def _invoke_tool(tool_obj: Any, **kwargs: Any) -> str: + """Invoke a FunctionTool and return the raw result string.""" + return await tool_obj.invoke(arguments=kwargs, skip_parsing=True) + + +# --- Constructor Tests --- + + +def test_constructor_requires_at_least_one_agent() -> None: + """Should reject empty agent list.""" + with pytest.raises(ValueError, match="At least one background agent"): + BackgroundAgentsProvider([]) + + +def test_constructor_requires_agent_names() -> None: + """Should reject agents with no name.""" + agent = _FakeAgent("") + with pytest.raises(ValueError, match="non-empty name"): + BackgroundAgentsProvider([agent]) + + +def test_constructor_rejects_duplicate_names() -> None: + """Should reject duplicate agent names (case-insensitive).""" + agent1 = _FakeAgent("Research") + agent2 = _FakeAgent("research") + with pytest.raises(ValueError, match="Duplicate background agent name"): + BackgroundAgentsProvider([agent1, agent2]) + + +def test_constructor_valid_agents() -> None: + """Should succeed with valid unique agents.""" + provider = BackgroundAgentsProvider([_FakeAgent("Alpha"), _FakeAgent("Beta")]) + assert provider.source_id == 
"background_agents" + + +def test_constructor_custom_source_id() -> None: + """Should accept custom source_id.""" + provider = BackgroundAgentsProvider([_FakeAgent("Agent1")], source_id="custom_bg") + assert provider.source_id == "custom_bg" + assert provider.runtime_source_id == "custom_bg_runtime" + + +def test_constructor_custom_runtime_source_id() -> None: + """Should accept explicit runtime_source_id.""" + provider = BackgroundAgentsProvider([_FakeAgent("Agent1")], source_id="custom_bg", runtime_source_id="my_runtime") + assert provider.source_id == "custom_bg" + assert provider.runtime_source_id == "my_runtime" + + +def test_constructor_default_runtime_source_id() -> None: + """Default runtime_source_id should be derived from source_id.""" + provider = BackgroundAgentsProvider([_FakeAgent("Agent1")]) + assert provider.runtime_source_id == "background_agents_runtime" + + +# --- Tool Injection Tests --- + + +async def test_before_run_injects_six_tools() -> None: + """before_run should inject exactly 6 tools.""" + provider = _make_provider(_FakeAgent("Worker")) + tools = await _get_tools(provider, _make_session()) + assert len(tools) == 6 + expected_names = { + "background_agents_start_task", + "background_agents_wait_for_first_completion", + "background_agents_get_task_results", + "background_agents_get_all_tasks", + "background_agents_continue_task", + "background_agents_clear_completed_task", + } + assert set(tools.keys()) == expected_names + + +async def test_before_run_injects_instructions() -> None: + """before_run should inject instructions mentioning agent names.""" + provider = _make_provider(_FakeAgent("ResearchBot", "Does research")) + context = SessionContext(input_messages=[]) + session = _make_session() + await provider.before_run(agent=None, session=session, context=context, state={}) + all_instructions = " ".join(context.instructions) + assert "ResearchBot" in all_instructions + assert "Does research" in all_instructions + + +# --- Start Task 
Tests --- + + +async def test_start_task_success() -> None: + """Should start a task and return confirmation.""" + provider = _make_provider(_FakeAgent("Worker", response_text="result")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Worker", + input="do something", + description="test task", + ) + assert "task 1 started" in result.lower() + assert "Worker" in result + + +async def test_start_task_unknown_agent() -> None: + """Should return error for unknown agent name.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = await _invoke_tool( + tools["background_agents_start_task"], + agent_name="NonExistent", + input="do something", + description="test", + ) + assert "Error" in result + assert "NonExistent" in result + + +async def test_start_task_increments_ids() -> None: + """Task IDs should increment sequentially.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + r1 = await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Worker", + input="task 1", + description="first", + ) + r2 = await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Worker", + input="task 2", + description="second", + ) + assert "task 1 started" in r1.lower() + assert "task 2 started" in r2.lower() + + +# --- Get All Tasks Tests --- + + +async def test_get_all_tasks_empty() -> None: + """Should return 'No tasks.' 
when no tasks exist.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = await _invoke_tool(tools["background_agents_get_all_tasks"]) + assert "No tasks" in result + + +async def test_get_all_tasks_shows_tasks() -> None: + """Should list all tasks with status and description.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Worker", + input="hello", + description="my task", + ) + result = await _invoke_tool(tools["background_agents_get_all_tasks"]) + assert "my task" in result + assert "Worker" in result + + +# --- Wait for Completion Tests --- + + +async def test_wait_for_first_completion() -> None: + """Should wait and return when a task completes.""" + provider = _make_provider(_FakeAgent("Fast", response_text="fast result", delay=0.01)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Fast", + input="go", + description="fast task", + ) + result = await _invoke_tool( + tools["background_agents_wait_for_first_completion"], + task_ids=[1], + ) + assert "finished" in result.lower() + assert "completed" in result.lower() + + +async def test_wait_empty_task_ids() -> None: + """Should return error for empty task_ids.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = await _invoke_tool( + tools["background_agents_wait_for_first_completion"], + task_ids=[], + ) + assert "Error" in result + + +async def test_wait_no_running_tasks() -> None: + """Should return error when no specified tasks are running.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = 
await _invoke_tool( + tools["background_agents_wait_for_first_completion"], + task_ids=[999], + ) + assert "Error" in result or "not running" in result.lower() + + +# --- Get Task Results Tests --- + + +async def test_get_task_results_completed() -> None: + """Should return result text for completed task.""" + provider = _make_provider(_FakeAgent("Worker", response_text="the answer", delay=0.01)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Worker", + input="query", + description="test", + ) + # Wait for completion. + await _invoke_tool( + tools["background_agents_wait_for_first_completion"], + task_ids=[1], + ) + result = await _invoke_tool( + tools["background_agents_get_task_results"], + task_id=1, + ) + assert result == "the answer" + + +async def test_get_task_results_running() -> None: + """Should indicate task is still running.""" + provider = _make_provider(_FakeAgent("Slow", delay=10.0)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Slow", + input="query", + description="slow task", + ) + result = await _invoke_tool( + tools["background_agents_get_task_results"], + task_id=1, + ) + assert "still running" in result.lower() + + +async def test_get_task_results_failed() -> None: + """Should return error text for failed task.""" + provider = _make_provider(_FakeAgent("Broken", should_fail=True, delay=0.01)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Broken", + input="query", + description="will fail", + ) + await _invoke_tool( + tools["background_agents_wait_for_first_completion"], + task_ids=[1], + ) + result = await _invoke_tool( + tools["background_agents_get_task_results"], + task_id=1, + ) + assert "failed" in result.lower() + + +async 
def test_get_task_results_not_found() -> None: + """Should return error for non-existent task.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = await _invoke_tool( + tools["background_agents_get_task_results"], + task_id=999, + ) + assert "Error" in result + + +# --- Continue Task Tests --- + + +async def test_continue_task_after_completion() -> None: + """Should be able to continue a completed task.""" + provider = _make_provider(_FakeAgent("Worker", response_text="first result", delay=0.01)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Worker", + input="first input", + description="continuable", + ) + await _invoke_tool( + tools["background_agents_wait_for_first_completion"], + task_ids=[1], + ) + result = await _invoke_tool( + tools["background_agents_continue_task"], + task_id=1, + text="follow up", + ) + assert "continued" in result.lower() + + +async def test_continue_task_still_running() -> None: + """Should return error if task is still running.""" + provider = _make_provider(_FakeAgent("Slow", delay=10.0)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Slow", + input="input", + description="running", + ) + result = await _invoke_tool( + tools["background_agents_continue_task"], + task_id=1, + text="follow up", + ) + assert "still running" in result.lower() + + +async def test_continue_task_not_found() -> None: + """Should return error for non-existent task.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = await _invoke_tool( + tools["background_agents_continue_task"], + task_id=999, + text="hello", + ) + assert "Error" in result + + +# --- Clear Task Tests --- + + 
+async def test_clear_completed_task() -> None: + """Should clear a completed task.""" + provider = _make_provider(_FakeAgent("Worker", response_text="done", delay=0.01)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Worker", + input="task", + description="clearable", + ) + await _invoke_tool( + tools["background_agents_wait_for_first_completion"], + task_ids=[1], + ) + result = await _invoke_tool( + tools["background_agents_clear_completed_task"], + task_id=1, + ) + assert "cleared" in result.lower() + + # Verify task is gone. + all_tasks = await _invoke_tool(tools["background_agents_get_all_tasks"]) + assert "No tasks" in all_tasks + + +async def test_clear_running_task_error() -> None: + """Should return error when clearing a running task.""" + provider = _make_provider(_FakeAgent("Slow", delay=10.0)) + session = _make_session() + tools = await _get_tools(provider, session) + + await _invoke_tool( + tools["background_agents_start_task"], + agent_name="Slow", + input="task", + description="still going", + ) + result = await _invoke_tool( + tools["background_agents_clear_completed_task"], + task_id=1, + ) + assert "still running" in result.lower() + + +async def test_clear_not_found() -> None: + """Should return error for non-existent task.""" + provider = _make_provider(_FakeAgent("Worker")) + session = _make_session() + tools = await _get_tools(provider, session) + + result = await _invoke_tool( + tools["background_agents_clear_completed_task"], + task_id=999, + ) + assert "Error" in result + + +# --- BackgroundTaskInfo Tests --- + + +def test_task_info_serialization() -> None: + """BackgroundTaskInfo should round-trip through to_dict/from_dict.""" + info = BackgroundTaskInfo( + id=1, + agent_name="Worker", + description="test task", + status=BackgroundTaskStatus.COMPLETED, + result_text="hello", + ) + data = info.to_dict() + restored = 
BackgroundTaskInfo.from_dict(data) + assert restored.id == 1 + assert restored.agent_name == "Worker" + assert restored.status == BackgroundTaskStatus.COMPLETED + assert restored.result_text == "hello" + assert restored.error_text is None + + +def test_task_status_enum_values() -> None: + """BackgroundTaskStatus should have expected values.""" + assert BackgroundTaskStatus.RUNNING == "running" + assert BackgroundTaskStatus.COMPLETED == "completed" + assert BackgroundTaskStatus.FAILED == "failed" + assert BackgroundTaskStatus.LOST == "lost" From 5fb589a2cd10d486a194f29d2315bb05c7bcbd58 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 25 May 2026 14:16:57 +0000 Subject: [PATCH 2/3] Address PR comments and fix linting warnings --- .../packages/core/agent_framework/__init__.py | 2 - .../_harness/_background_agents.py | 62 +++++++++-------- .../core/test_harness_background_agents.py | 68 +++++++++++-------- 3 files changed, 71 insertions(+), 61 deletions(-) diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 497afe37be..6890aae47b 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -80,7 +80,6 @@ ) from ._feature_stage import ExperimentalFeature, ReleaseCandidateFeature from ._harness._background_agents import ( - DEFAULT_BACKGROUND_AGENTS_RUNTIME_SOURCE_ID, DEFAULT_BACKGROUND_AGENTS_SOURCE_ID, BackgroundAgentsProvider, BackgroundTaskInfo, @@ -304,7 +303,6 @@ "AGENT_FRAMEWORK_USER_AGENT", "APP_INFO", "COMPACTION_STATE_KEY", - "DEFAULT_BACKGROUND_AGENTS_RUNTIME_SOURCE_ID", "DEFAULT_BACKGROUND_AGENTS_SOURCE_ID", "DEFAULT_MAX_ITERATIONS", "DEFAULT_MEMORY_SOURCE_ID", diff --git a/python/packages/core/agent_framework/_harness/_background_agents.py b/python/packages/core/agent_framework/_harness/_background_agents.py index d2f820fc06..7df474f100 100644 --- 
a/python/packages/core/agent_framework/_harness/_background_agents.py +++ b/python/packages/core/agent_framework/_harness/_background_agents.py @@ -10,10 +10,10 @@ from __future__ import annotations import asyncio -from collections.abc import Sequence +from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any +from typing import Any, cast from .._agents import SupportsAgentRun from .._feature_stage import ExperimentalFeature, experimental @@ -22,7 +22,6 @@ from .._types import AgentResponse, Message DEFAULT_BACKGROUND_AGENTS_SOURCE_ID = "background_agents" -DEFAULT_BACKGROUND_AGENTS_RUNTIME_SOURCE_ID = "background_agents_runtime" DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS = """\ ## Background Agents @@ -91,8 +90,12 @@ def from_dict(cls, data: dict[str, Any]) -> BackgroundTaskInfo: class _RuntimeState: """Non-serializable per-session runtime state for background tasks.""" - in_flight_tasks: dict[int, asyncio.Task[AgentResponse[Any]]] = field(default_factory=dict) - background_sessions: dict[int, AgentSession] = field(default_factory=dict) + in_flight_tasks: dict[int, asyncio.Task[AgentResponse[Any]]] = field( + default_factory=lambda: {} # pyright: ignore[reportUnknownLambdaType] + ) + background_sessions: dict[int, AgentSession] = field( + default_factory=lambda: {} # pyright: ignore[reportUnknownLambdaType] + ) # --------------------------------------------------------------------------- @@ -100,6 +103,11 @@ class _RuntimeState: # --------------------------------------------------------------------------- +async def _run_agent(awaitable: Awaitable[AgentResponse[Any]]) -> AgentResponse[Any]: + """Wrap an Awaitable in a proper coroutine for use with asyncio.create_task.""" + return await awaitable + + def _validate_and_build_agent_dict(agents: Sequence[SupportsAgentRun]) -> dict[str, SupportsAgentRun]: """Validate agents and build a case-insensitive lookup dict. 
@@ -138,9 +146,10 @@ def _get_provider_state(session: AgentSession, *, source_id: str) -> dict[str, A """Load or initialize serializable provider state from session.""" state = session.state.get(source_id) if state is None: - state = {"next_task_id": 1, "tasks": []} - session.state[source_id] = state - return state # type: ignore[return-value] + initial: dict[str, Any] = {"next_task_id": 1, "tasks": []} + session.state[source_id] = initial + return initial + return cast(dict[str, Any], state) def _save_provider_state(session: AgentSession, state: dict[str, Any], *, source_id: str) -> None: @@ -164,16 +173,17 @@ def _finalize_task( runtime: _RuntimeState, ) -> None: """Extract results from a completed asyncio task and update task info.""" - exception = completed_task.exception() - if exception is not None: - task_info.status = BackgroundTaskStatus.FAILED - task_info.error_text = str(exception) - elif completed_task.cancelled(): + if completed_task.cancelled(): task_info.status = BackgroundTaskStatus.FAILED task_info.error_text = "Task was canceled." else: - task_info.status = BackgroundTaskStatus.COMPLETED - task_info.result_text = completed_task.result().text + exception = completed_task.exception() + if exception is not None: + task_info.status = BackgroundTaskStatus.FAILED + task_info.error_text = str(exception) + else: + task_info.status = BackgroundTaskStatus.COMPLETED + task_info.result_text = completed_task.result().text runtime.in_flight_tasks.pop(task_info.id, None) @@ -233,7 +243,6 @@ def __init__( agents: Sequence[SupportsAgentRun], *, source_id: str = DEFAULT_BACKGROUND_AGENTS_SOURCE_ID, - runtime_source_id: str | None = None, instructions: str | None = None, ) -> None: """Initialize the background agents provider. @@ -244,9 +253,6 @@ def __init__( Keyword Args: source_id: Unique source ID for serializable task state in session. - runtime_source_id: Unique source ID for non-serializable runtime state - (in-flight asyncio tasks and background sessions). 
Defaults to - ``"{source_id}_runtime"``. instructions: Optional instruction override. May include ``{background_agents}`` placeholder which will be replaced with the agent listing. @@ -256,21 +262,18 @@ def __init__( super().__init__(source_id) self._agents = _validate_and_build_agent_dict(agents) - self._runtime_source_id = runtime_source_id or f"{source_id}_runtime" # Build instructions with agent listing. base_instructions = instructions if instructions is not None else DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS agent_list_text = _build_agent_list_text(self._agents) self._instructions = base_instructions.replace("{background_agents}", agent_list_text) - # Per-session runtime state (non-serializable), keyed by runtime_source_id. + # Per-session runtime state (non-serializable), keyed by session_id. + # Note: Runtime state (in-flight asyncio.Task objects, child AgentSession handles) + # is inherently non-serializable and cannot survive process restarts. If the provider + # instance is lost, _refresh_task_state() marks orphaned tasks as LOST. self._runtime: dict[str, _RuntimeState] = {} - @property - def runtime_source_id(self) -> str: - """The source ID used for non-serializable runtime state.""" - return self._runtime_source_id - def _get_runtime(self, session: AgentSession) -> _RuntimeState: """Get or create runtime state for a session.""" session_id = session.session_id @@ -318,7 +321,7 @@ def background_agents_start_task(agent_name: str, input: str, description: str) sub_session = bg_agent.create_session() # Start the task concurrently. 
- async_task = asyncio.create_task(bg_agent.run(input, session=sub_session)) + async_task = asyncio.create_task(_run_agent(bg_agent.run(input, session=sub_session))) runtime.in_flight_tasks[task_id] = async_task runtime.background_sessions[task_id] = sub_session @@ -441,7 +444,7 @@ def background_agents_continue_task(task_id: int, text: str) -> str: task_info.error_text = None _save_tasks(provider_state, tasks) - async_task = asyncio.create_task(bg_agent.run(text, session=sub_session)) + async_task = asyncio.create_task(_run_agent(bg_agent.run(text, session=sub_session))) runtime.in_flight_tasks[task_id] = async_task _save_provider_state(session, provider_state, source_id=source_id) @@ -485,7 +488,8 @@ def background_agents_clear_completed_task(task_id: int) -> str: ) # Include current task status as context message if there are tasks. - tasks = _get_tasks(provider_state) + # Refresh first to get accurate statuses for any tasks that completed between turns. + tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id) if tasks: status_lines = ["### Current background tasks"] for t in tasks: diff --git a/python/packages/core/tests/core/test_harness_background_agents.py b/python/packages/core/tests/core/test_harness_background_agents.py index 5e51180770..34f0893df9 100644 --- a/python/packages/core/tests/core/test_harness_background_agents.py +++ b/python/packages/core/tests/core/test_harness_background_agents.py @@ -17,6 +17,10 @@ ) from agent_framework._sessions import SessionContext +# Suppress "coroutine was never awaited" warnings from task cancellation in tests. +# This occurs when cancelling tasks that wrap coroutines through _run_agent(). 
+pytestmark = pytest.mark.filterwarnings("ignore::RuntimeWarning:asyncio") + # --- Test Helpers --- @@ -114,20 +118,6 @@ def test_constructor_custom_source_id() -> None: """Should accept custom source_id.""" provider = BackgroundAgentsProvider([_FakeAgent("Agent1")], source_id="custom_bg") assert provider.source_id == "custom_bg" - assert provider.runtime_source_id == "custom_bg_runtime" - - -def test_constructor_custom_runtime_source_id() -> None: - """Should accept explicit runtime_source_id.""" - provider = BackgroundAgentsProvider([_FakeAgent("Agent1")], source_id="custom_bg", runtime_source_id="my_runtime") - assert provider.source_id == "custom_bg" - assert provider.runtime_source_id == "my_runtime" - - -def test_constructor_default_runtime_source_id() -> None: - """Default runtime_source_id should be derived from source_id.""" - provider = BackgroundAgentsProvider([_FakeAgent("Agent1")]) - assert provider.runtime_source_id == "background_agents_runtime" # --- Tool Injection Tests --- @@ -335,11 +325,17 @@ async def test_get_task_results_running() -> None: input="query", description="slow task", ) - result = await _invoke_tool( - tools["background_agents_get_task_results"], - task_id=1, - ) - assert "still running" in result.lower() + try: + result = await _invoke_tool( + tools["background_agents_get_task_results"], + task_id=1, + ) + assert "still running" in result.lower() + finally: + runtime = provider._get_runtime(session) + for task in list(runtime.in_flight_tasks.values()): + task.cancel() + await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True) async def test_get_task_results_failed() -> None: @@ -417,12 +413,18 @@ async def test_continue_task_still_running() -> None: input="input", description="running", ) - result = await _invoke_tool( - tools["background_agents_continue_task"], - task_id=1, - text="follow up", - ) - assert "still running" in result.lower() + try: + result = await _invoke_tool( + 
tools["background_agents_continue_task"], + task_id=1, + text="follow up", + ) + assert "still running" in result.lower() + finally: + runtime = provider._get_runtime(session) + for task in list(runtime.in_flight_tasks.values()): + task.cancel() + await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True) async def test_continue_task_not_found() -> None: @@ -481,11 +483,17 @@ async def test_clear_running_task_error() -> None: input="task", description="still going", ) - result = await _invoke_tool( - tools["background_agents_clear_completed_task"], - task_id=1, - ) - assert "still running" in result.lower() + try: + result = await _invoke_tool( + tools["background_agents_clear_completed_task"], + task_id=1, + ) + assert "still running" in result.lower() + finally: + runtime = provider._get_runtime(session) + for task in list(runtime.in_flight_tasks.values()): + task.cancel() + await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True) async def test_clear_not_found() -> None: From 0ec75d7687b602e31e10e70f0b09dc425cc4ad13 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Tue, 26 May 2026 15:04:39 +0000 Subject: [PATCH 3/3] Address PR comment --- .../_harness/_background_agents.py | 43 ++++++++++++++----- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/python/packages/core/agent_framework/_harness/_background_agents.py b/python/packages/core/agent_framework/_harness/_background_agents.py index 7df474f100..c329af4aa9 100644 --- a/python/packages/core/agent_framework/_harness/_background_agents.py +++ b/python/packages/core/agent_framework/_harness/_background_agents.py @@ -10,13 +10,14 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Sequence +from collections.abc import Awaitable, MutableMapping, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, cast +from typing import Any, 
ClassVar, cast from .._agents import SupportsAgentRun from .._feature_stage import ExperimentalFeature, experimental +from .._serialization import SerializationMixin from .._sessions import AgentSession, ContextProvider, SessionContext from .._tools import tool from .._types import AgentResponse, Message @@ -48,33 +49,53 @@ class BackgroundTaskStatus(str, Enum): @experimental(feature_id=ExperimentalFeature.HARNESS) -@dataclass -class BackgroundTaskInfo: +class BackgroundTaskInfo(SerializationMixin): """Metadata for a single background task.""" + DEFAULT_EXCLUDE: ClassVar[set[str]] = set() + id: int agent_name: str description: str - status: BackgroundTaskStatus = BackgroundTaskStatus.RUNNING - result_text: str | None = None - error_text: str | None = None + status: BackgroundTaskStatus + result_text: str | None + error_text: str | None + __slots__ = ("agent_name", "description", "error_text", "id", "result_text", "status") - def to_dict(self) -> dict[str, Any]: + def __init__( + self, + id: int, + agent_name: str, + description: str, + status: BackgroundTaskStatus = BackgroundTaskStatus.RUNNING, + result_text: str | None = None, + error_text: str | None = None, + ) -> None: + """Initialize a background task info entry.""" + self.id = id + self.agent_name = agent_name + self.description = description + self.status = status + self.result_text = result_text + self.error_text = error_text + + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: """Serialize for session state persistence.""" + del exclude data: dict[str, Any] = { "id": self.id, "agent_name": self.agent_name, "description": self.description, "status": self.status.value, } - if self.result_text is not None: + if not exclude_none or self.result_text is not None: data["result_text"] = self.result_text - if self.error_text is not None: + if not exclude_none or self.error_text is not None: data["error_text"] = self.error_text return data @classmethod - def 
from_dict(cls, data: dict[str, Any]) -> BackgroundTaskInfo: + def from_dict(cls, data: MutableMapping[str, Any], **kwargs: Any) -> BackgroundTaskInfo: """Deserialize from session state.""" return cls( id=data["id"],