diff --git a/src/agents/run_state.py b/src/agents/run_state.py index c5bb8c9faf..782238a4ad 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -790,7 +790,10 @@ def _serialize_processed_response( A dictionary representation of the ProcessedResponse. """ - action_groups = _serialize_tool_action_groups(processed_response) + action_groups = _serialize_tool_action_groups( + processed_response, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) _serialize_pending_nested_agent_tool_runs( parent_state=self, function_entries=action_groups.get("functions", []), @@ -1289,16 +1292,28 @@ def _serialize_tool_actions( return serialized_actions -def _serialize_handoffs(handoffs: Sequence[Any]) -> list[dict[str, Any]]: +def _serialize_handoffs( + handoffs: Sequence[Any], + *, + agent_identity_keys_by_id: Mapping[int, str] | None = None, +) -> list[dict[str, Any]]: """Serialize handoff tool calls.""" serialized_handoffs = [] for handoff in handoffs: handoff_target = handoff.handoff handoff_name = _get_attr(handoff_target, "tool_name") or _get_attr(handoff_target, "name") + handoff_data: dict[str, Any] = {"tool_name": handoff_name} + if isinstance(handoff_target, Handoff): + target_agent = _get_handoff_target_agent(handoff_target) + if target_agent is not None: + handoff_data["target_agent"] = _serialize_agent_reference( + target_agent, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) serialized_handoffs.append( { "tool_call": _serialize_tool_call_data(handoff.tool_call), - "handoff": {"tool_name": handoff_name}, + "handoff": handoff_data, } ) return serialized_handoffs @@ -1383,6 +1398,8 @@ def _serialize_tool_approval_interruption( def _serialize_tool_action_groups( processed_response: ProcessedResponse, + *, + agent_identity_keys_by_id: Mapping[int, str] | None = None, ) -> dict[str, list[dict[str, Any]]]: """Serialize tool-related action groups using a shared spec.""" action_specs: list[ @@ -1455,7 +1472,10 @@ def _serialize_tool_action_groups( include_params_schema, ) in action_specs } - serialized["handoffs"] = _serialize_handoffs(processed_response.handoffs) + serialized["handoffs"] = _serialize_handoffs( + processed_response.handoffs, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) serialized["mcp_approval_requests"] = _serialize_mcp_approval_requests( processed_response.mcp_approval_requests ) @@ -1644,6 +1664,17 @@ def _build_handoffs_map(current_agent: Agent[Any]) -> dict[str, Handoff[Any, Age return handoffs_map +def _get_handoff_target_agent(handoff: Handoff[Any, Any]) -> Agent[Any] | None: + """Resolve the target agent captured by a handoff, if available.""" + handoff_ref = getattr(handoff, "_agent_ref", None) + target_agent = handoff_ref() if callable(handoff_ref) else None + if target_agent is None: + target_agent = getattr(handoff, "agent", None) + if target_agent is not None and hasattr(target_agent, "handoffs"): + return cast(Agent[Any], target_agent) + return None + + async def _restore_pending_nested_agent_tool_runs( *, current_agent: Agent[Any], @@ -1809,14 +1840,53 @@ def _parse_apply_patch_call(data: dict[str, Any]) -> Any: return data def _deserialize_action_groups() -> dict[str, list[Any]]: - def _resolve_handoff_tool_name(data: Mapping[str, Any]) -> NamedToolLookupKey | None: + def _resolve_handoff_from_data(data: Mapping[str, Any]) -> Handoff[Any, Agent[Any]] | None: handoff_data = data.get("handoff", {}) if not isinstance(handoff_data, Mapping): return None tool_name = handoff_data.get("tool_name") - return cast( - NamedToolLookupKey | None, tool_name if isinstance(tool_name, str) else None - ) + if not isinstance(tool_name, str) or not tool_name: + return None + + target_agent_data = handoff_data.get("target_agent") + if target_agent_data is not None: + target_agent = _resolve_agent_from_data( + target_agent_data, + agent_map, + agent_identity_map, + ) + if target_agent is not None: + for handoff in getattr(current_agent, "handoffs", ()): + if not isinstance(handoff, Handoff): + continue + handoff_name = getattr(handoff, "tool_name", None) or getattr( + handoff, "name", None + ) + if handoff_name != tool_name: + continue + if _get_handoff_target_agent(handoff) is target_agent: + return handoff + + return handoffs_map.get(tool_name) + + def _deserialize_handoff_actions(entries: list[dict[str, Any]]) -> list[Any]: + deserialized: list[Any] = [] + for entry in entries or []: + if not isinstance(entry, Mapping): + continue + handoff = _resolve_handoff_from_data(entry) + if handoff is None: + continue + tool_call_data_raw = entry.get("tool_call", {}) + tool_call_data = ( + dict(tool_call_data_raw) if isinstance(tool_call_data_raw, Mapping) else {} + ) + try: + tool_call = ResponseFunctionToolCall(**tool_call_data) + except Exception: + continue + deserialized.append(ToolRunHandoff(tool_call=tool_call, handoff=handoff)) + return deserialized def _resolve_function_tool_name(data: Mapping[str, Any]) -> FunctionToolLookupKey | None: tool_data = data.get("tool", {}) @@ -1851,14 +1921,6 @@ def _resolve_function_tool_name(data: Mapping[str, Any]) -> FunctionToolLookupKe Callable[[Mapping[str, Any]], NamedToolLookupKey | None] | None, ] ] = [ - ( - "handoffs", - "handoff", - handoffs_map, - lambda data: ResponseFunctionToolCall(**data), - lambda tool_call, handoff: ToolRunHandoff(tool_call=tool_call, handoff=handoff), - _resolve_handoff_tool_name, - ), ( "functions", "tool", @@ -1938,6 +2000,9 @@ def _resolve_function_tool_name(data: Mapping[str, Any]) -> FunctionToolLookupKe action_factory=action_factory, name_resolver=name_resolver, ) + action_groups["handoffs"] = _deserialize_handoff_actions( + processed_response_data.get("handoffs", []) + ) return action_groups action_groups = _deserialize_action_groups() diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 7b2de6b859..bfc2036843 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -2520,6 +2520,41 @@ async def test_serialization_uses_duplicate_identities_for_handoff_and_output_gu assert restored_item.target_agent is third assert restored._output_guardrail_results[0].agent is third + @pytest.mark.asyncio + async def test_last_processed_handoff_restores_duplicate_target_by_identity(self): + """Duplicate-name handoff actions should restore to the serialized target agent.""" + first = Agent(name="duplicate", instructions="safe") + second = Agent(name="duplicate", instructions="danger") + first_handoff = handoff(first) + second_handoff = handoff(second) + root = Agent(name="router", handoffs=[first_handoff, second_handoff]) + first.handoffs = [root] + second.handoffs = [root] + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + state = make_state(root, context=context, original_input="route", max_turns=2) + tool_call = cast(ResponseFunctionToolCall, get_handoff_tool_call(first)) + state._last_processed_response = make_processed_response( + handoffs=[ + ToolRunHandoff( + tool_call=tool_call, + handoff=first_handoff, + ) + ] + ) + + json_data = state.to_json() + handoff_data = json_data["last_processed_response"]["handoffs"][0]["handoff"] + assert handoff_data["tool_name"] == "transfer_to_duplicate" + assert handoff_data["target_agent"] == {"name": "duplicate", "identity": "duplicate#2"} + + restored = await RunState.from_json(root, json_data) + + assert restored._last_processed_response is not None + restored_handoff = restored._last_processed_response.handoffs[0].handoff + assert restored_handoff is first_handoff + assert restored_handoff is not second_handoff + async def test_model_response_serialization_roundtrip(self): """Test that model responses serialize and deserialize correctly."""