Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 81 additions & 16 deletions src/agents/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,10 @@ def _serialize_processed_response(
A dictionary representation of the ProcessedResponse.
"""

action_groups = _serialize_tool_action_groups(processed_response)
action_groups = _serialize_tool_action_groups(
processed_response,
agent_identity_keys_by_id=agent_identity_keys_by_id,
)
_serialize_pending_nested_agent_tool_runs(
parent_state=self,
function_entries=action_groups.get("functions", []),
Expand Down Expand Up @@ -1289,16 +1292,28 @@ def _serialize_tool_actions(
return serialized_actions


def _serialize_handoffs(handoffs: Sequence[Any]) -> list[dict[str, Any]]:
def _serialize_handoffs(
handoffs: Sequence[Any],
*,
agent_identity_keys_by_id: Mapping[int, str] | None = None,
) -> list[dict[str, Any]]:
"""Serialize handoff tool calls."""
serialized_handoffs = []
for handoff in handoffs:
handoff_target = handoff.handoff
handoff_name = _get_attr(handoff_target, "tool_name") or _get_attr(handoff_target, "name")
handoff_data: dict[str, Any] = {"tool_name": handoff_name}
if isinstance(handoff_target, Handoff):
target_agent = _get_handoff_target_agent(handoff_target)
if target_agent is not None:
handoff_data["target_agent"] = _serialize_agent_reference(
target_agent,
agent_identity_keys_by_id=agent_identity_keys_by_id,
)
Comment on lines +1309 to +1312
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Bump schema for target-bound handoff snapshots

Adding target_agent changes the serialized RunState shape for pending handoff actions, but snapshots still advertise $schemaVersion 1.10, whose summary is unrelated to this field. Per the repo AGENTS.md RunState schema rule, shape changes need a CURRENT_SCHEMA_VERSION/SCHEMA_VERSION_SUMMARIES update; otherwise an older SDK that already accepts 1.10 can load these new snapshots, ignore target_agent, and resume duplicate-name handoffs with the old ambiguous tool-name behavior instead of failing fast.

Useful? React with 👍 / 👎.

serialized_handoffs.append(
{
"tool_call": _serialize_tool_call_data(handoff.tool_call),
"handoff": {"tool_name": handoff_name},
"handoff": handoff_data,
}
)
return serialized_handoffs
Expand Down Expand Up @@ -1383,6 +1398,8 @@ def _serialize_tool_approval_interruption(

def _serialize_tool_action_groups(
processed_response: ProcessedResponse,
*,
agent_identity_keys_by_id: Mapping[int, str] | None = None,
) -> dict[str, list[dict[str, Any]]]:
"""Serialize tool-related action groups using a shared spec."""
action_specs: list[
Expand Down Expand Up @@ -1455,7 +1472,10 @@ def _serialize_tool_action_groups(
include_params_schema,
) in action_specs
}
serialized["handoffs"] = _serialize_handoffs(processed_response.handoffs)
serialized["handoffs"] = _serialize_handoffs(
processed_response.handoffs,
agent_identity_keys_by_id=agent_identity_keys_by_id,
)
serialized["mcp_approval_requests"] = _serialize_mcp_approval_requests(
processed_response.mcp_approval_requests
)
Expand Down Expand Up @@ -1644,6 +1664,17 @@ def _build_handoffs_map(current_agent: Agent[Any]) -> dict[str, Handoff[Any, Age
return handoffs_map


def _get_handoff_target_agent(handoff: Handoff[Any, Any]) -> Agent[Any] | None:
"""Resolve the target agent captured by a handoff, if available."""
handoff_ref = getattr(handoff, "_agent_ref", None)
target_agent = handoff_ref() if callable(handoff_ref) else None
if target_agent is None:
target_agent = getattr(handoff, "agent", None)
if target_agent is not None and hasattr(target_agent, "handoffs"):
return cast(Agent[Any], target_agent)
return None


async def _restore_pending_nested_agent_tool_runs(
*,
current_agent: Agent[Any],
Expand Down Expand Up @@ -1809,14 +1840,53 @@ def _parse_apply_patch_call(data: dict[str, Any]) -> Any:
return data

def _deserialize_action_groups() -> dict[str, list[Any]]:
def _resolve_handoff_tool_name(data: Mapping[str, Any]) -> NamedToolLookupKey | None:
def _resolve_handoff_from_data(data: Mapping[str, Any]) -> Handoff[Any, Agent[Any]] | None:
handoff_data = data.get("handoff", {})
if not isinstance(handoff_data, Mapping):
return None
tool_name = handoff_data.get("tool_name")
return cast(
NamedToolLookupKey | None, tool_name if isinstance(tool_name, str) else None
)
if not isinstance(tool_name, str) or not tool_name:
return None

target_agent_data = handoff_data.get("target_agent")
if target_agent_data is not None:
target_agent = _resolve_agent_from_data(
target_agent_data,
agent_map,
agent_identity_map,
)
if target_agent is not None:
for handoff in getattr(current_agent, "handoffs", ()):
if not isinstance(handoff, Handoff):
continue
handoff_name = getattr(handoff, "tool_name", None) or getattr(
handoff, "name", None
)
if handoff_name != tool_name:
continue
if _get_handoff_target_agent(handoff) is target_agent:
return handoff

return handoffs_map.get(tool_name)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep target-bound handoffs from falling back by name

When a serialized handoff includes target_agent but the restored current_agent.handoffs has no handoff with that exact target, this falls back to the bare tool-name map. In deployments where the graph changed or duplicate/overridden handoff tool names remain, a target-bound snapshot can still resume against a different agent, which defeats the new binding; reserve the tool-name fallback for legacy entries that do not have target_agent, and skip or raise on unmatched target-bound data.

Useful? React with 👍 / 👎.


def _deserialize_handoff_actions(entries: list[dict[str, Any]]) -> list[Any]:
deserialized: list[Any] = []
for entry in entries or []:
if not isinstance(entry, Mapping):
continue
handoff = _resolve_handoff_from_data(entry)
if handoff is None:
continue
tool_call_data_raw = entry.get("tool_call", {})
tool_call_data = (
dict(tool_call_data_raw) if isinstance(tool_call_data_raw, Mapping) else {}
)
try:
tool_call = ResponseFunctionToolCall(**tool_call_data)
except Exception:
continue
deserialized.append(ToolRunHandoff(tool_call=tool_call, handoff=handoff))
return deserialized

def _resolve_function_tool_name(data: Mapping[str, Any]) -> FunctionToolLookupKey | None:
tool_data = data.get("tool", {})
Expand Down Expand Up @@ -1851,14 +1921,6 @@ def _resolve_function_tool_name(data: Mapping[str, Any]) -> FunctionToolLookupKe
Callable[[Mapping[str, Any]], NamedToolLookupKey | None] | None,
]
] = [
(
"handoffs",
"handoff",
handoffs_map,
lambda data: ResponseFunctionToolCall(**data),
lambda tool_call, handoff: ToolRunHandoff(tool_call=tool_call, handoff=handoff),
_resolve_handoff_tool_name,
),
(
"functions",
"tool",
Expand Down Expand Up @@ -1938,6 +2000,9 @@ def _resolve_function_tool_name(data: Mapping[str, Any]) -> FunctionToolLookupKe
action_factory=action_factory,
name_resolver=name_resolver,
)
action_groups["handoffs"] = _deserialize_handoff_actions(
processed_response_data.get("handoffs", [])
)
return action_groups

action_groups = _deserialize_action_groups()
Expand Down
35 changes: 35 additions & 0 deletions tests/test_run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2520,6 +2520,41 @@ async def test_serialization_uses_duplicate_identities_for_handoff_and_output_gu
assert restored_item.target_agent is third
assert restored._output_guardrail_results[0].agent is third

@pytest.mark.asyncio
async def test_last_processed_handoff_restores_duplicate_target_by_identity(self):
"""Duplicate-name handoff actions should restore to the serialized target agent."""
first = Agent(name="duplicate", instructions="safe")
second = Agent(name="duplicate", instructions="danger")
first_handoff = handoff(first)
second_handoff = handoff(second)
root = Agent(name="router", handoffs=[first_handoff, second_handoff])
first.handoffs = [root]
second.handoffs = [root]

context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
state = make_state(root, context=context, original_input="route", max_turns=2)
tool_call = cast(ResponseFunctionToolCall, get_handoff_tool_call(first))
state._last_processed_response = make_processed_response(
handoffs=[
ToolRunHandoff(
tool_call=tool_call,
handoff=first_handoff,
)
]
)

json_data = state.to_json()
handoff_data = json_data["last_processed_response"]["handoffs"][0]["handoff"]
assert handoff_data["tool_name"] == "transfer_to_duplicate"
assert handoff_data["target_agent"] == {"name": "duplicate", "identity": "duplicate#2"}

restored = await RunState.from_json(root, json_data)

assert restored._last_processed_response is not None
restored_handoff = restored._last_processed_response.handoffs[0].handoff
assert restored_handoff is first_handoff
assert restored_handoff is not second_handoff

async def test_model_response_serialization_roundtrip(self):
"""Test that model responses serialize and deserialize correctly."""

Expand Down