diff --git a/docs/running_agents.md b/docs/running_agents.md index ba6d395c99..f1433c8126 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -406,7 +406,7 @@ settings so the resumed turn continues in the same server-managed conversation. Use `call_model_input_filter` to edit the model input right before the model call. The hook receives the current agent, context, and the combined input items (including session history when present) and returns a new `ModelInputData`. -The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`. +The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`. You may also set `output_schema` on the returned object to replace the response format for that model call — the agent's own `output_type` is used when `output_schema` is `None` or omitted. ```python from agents import Agent, Runner, RunConfig diff --git a/src/agents/extensions/tool_output_trimmer.py b/src/agents/extensions/tool_output_trimmer.py index 26b307f14f..f9d2b055ce 100644 --- a/src/agents/extensions/tool_output_trimmer.py +++ b/src/agents/extensions/tool_output_trimmer.py @@ -152,7 +152,11 @@ def __call__(self, data: CallModelData[Any]) -> ModelInputData: f"saved ~{chars_saved} chars" ) - return _ModelInputData(input=new_items, instructions=model_data.instructions) + return _ModelInputData( + input=new_items, + instructions=model_data.instructions, + output_schema=model_data.output_schema, + ) def _find_recent_boundary(self, items: list[Any]) -> int: """Find the index separating 'old' items from 'recent' items. diff --git a/src/agents/run_config.py b/src/agents/run_config.py index fcc9b01315..905a32afed 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from .agent import Agent + from .agent_output import AgentOutputSchemaBase from .run_context import RunContextWrapper from .sandbox.manifest import Manifest from .sandbox.session.base_sandbox_session import BaseSandboxSession @@ -50,6 +51,10 @@ class ModelInputData: input: list[TResponseInputItem] instructions: str | None + output_schema: AgentOutputSchemaBase | None = None + """Output schema override. When set by a ``call_model_input_filter``, replaces the schema + derived from ``agent.output_type`` for this model call. When ``None``, the agent's schema + is used unchanged.""" @dataclass diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..f106b097e3 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -1373,7 +1373,9 @@ def _tool_search_fingerprint(raw_item: Any) -> str: context_wrapper=context_wrapper, input_items=input, system_instructions=system_prompt, + output_schema=output_schema, ) + output_schema = filtered.output_schema if filtered.output_schema is not None else output_schema if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) hosted_mcp_tool_metadata = collect_mcp_list_tools_metadata(streamed_result._model_input_items) @@ -1760,7 +1762,7 @@ async def run_single_turn( else: input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy) - new_response = await get_new_response( + new_response, output_schema = await get_new_response( bindings, system_prompt, input, @@ -1811,8 +1813,8 @@ async def get_new_response( session: Session | None = None, session_items_to_rewind: list[TResponseInputItem] | None = None, prompt_cache_key_resolver: PromptCacheKeyResolver | None = None, -) -> ModelResponse: - """Call the model and return the raw response, handling retries and hooks.""" +) -> tuple[ModelResponse, AgentOutputSchemaBase | None]: + """Call the model and return the raw response and effective output schema after filtering.""" public_agent = bindings.public_agent execution_agent = bindings.execution_agent filtered = await maybe_filter_model_input( @@ -1821,7 +1823,9 @@ async def get_new_response( context_wrapper=context_wrapper, input_items=input, system_instructions=system_prompt, + output_schema=output_schema, ) + output_schema = filtered.output_schema if filtered.output_schema is not None else output_schema if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) @@ -1917,4 +1921,4 @@ async def rewind_model_request() -> None: hooks.on_llm_end(context_wrapper, public_agent, new_response), ) - return new_response + return new_response, output_schema diff --git a/src/agents/run_internal/turn_preparation.py b/src/agents/run_internal/turn_preparation.py index 0a79ebd813..a2078ba5f4 100644 --- a/src/agents/run_internal/turn_preparation.py +++ b/src/agents/run_internal/turn_preparation.py @@ -55,18 +55,24 @@ async def maybe_filter_model_input( context_wrapper: RunContextWrapper[TContext], input_items: list[TResponseInputItem], system_instructions: str | None, + output_schema: AgentOutputSchemaBase | None = None, ) -> ModelInputData: """Apply optional call_model_input_filter to modify model input.""" effective_instructions = system_instructions effective_input: list[TResponseInputItem] = input_items if run_config.call_model_input_filter is None: - return ModelInputData(input=effective_input, instructions=effective_instructions) + return ModelInputData( + input=effective_input, + instructions=effective_instructions, + output_schema=output_schema, + ) try: model_input = ModelInputData( input=effective_input.copy(), instructions=effective_instructions, + output_schema=output_schema, ) filter_payload: CallModelData[TContext] = CallModelData( model_data=model_input, diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index eb22c70f14..944ebe97c2 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -2449,7 +2449,8 @@ async def test_conversation_lock_rewind_skips_when_no_snapshot() -> None: session_items_to_rewind=[], ) - assert isinstance(result, ModelResponse) + response, _ = result + assert isinstance(response, ModelResponse) assert session.pop_calls == 0 @@ -2494,8 +2495,9 @@ async def test_get_new_response_uses_agent_retry_settings() -> None: session_items_to_rewind=[], ) - assert isinstance(result, ModelResponse) - assert result.usage.requests == 2 + response, _ = result + assert isinstance(response, ModelResponse) + assert response.usage.requests == 2 @pytest.mark.asyncio diff --git a/tests/test_call_model_input_filter.py b/tests/test_call_model_input_filter.py index f0239089c6..3e44f7dc22 100644 --- a/tests/test_call_model_input_filter.py +++ b/tests/test_call_model_input_filter.py @@ -3,8 +3,10 @@ from typing import Any, cast import pytest +from pydantic import BaseModel from agents import Agent, RunConfig, Runner, TResponseInputItem, UserError +from agents.agent_output import AgentOutputSchema from agents.run import CallModelData, ModelInputData from .fake_model import FakeModel @@ -167,3 +169,113 @@ async def filter_fn(data: CallModelData[Any]) -> ModelInputData: ] assert len(outputs) == 1 assert outputs[0]["output"] == "new-value" + + +class _Reply(BaseModel): + answer: str + + +@pytest.mark.asyncio +async def test_filter_can_override_output_schema_non_streamed() -> None: + """Regression test for #3563: filter can replace output_schema on non-streamed run. + + Verifies both that the model call receives the override schema and that the + response is parsed against it (not discarded after get_new_response returns). + """ + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message('{"answer": "hi"}')]) + + override_schema = AgentOutputSchema(_Reply) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + return ModelInputData( + input=data.model_data.input, + instructions=data.model_data.instructions, + output_schema=override_schema, + ) + + result = await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["output_schema"] is override_schema + assert isinstance(result.final_output, _Reply) + assert result.final_output.answer == "hi" + + +@pytest.mark.asyncio +async def test_filter_can_override_output_schema_streamed() -> None: + """Regression test for #3563: filter can replace output_schema on streamed run.""" + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message('{"answer": "hi"}')]) + + override_schema = AgentOutputSchema(_Reply) + + async def filter_fn(data: CallModelData[Any]) -> ModelInputData: + return ModelInputData( + input=data.model_data.input, + instructions=data.model_data.instructions, + output_schema=override_schema, + ) + + result = Runner.run_streamed( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + async for _ in result.stream_events(): + pass + + assert model.last_turn_args["output_schema"] is override_schema + + +@pytest.mark.asyncio +async def test_filter_receives_agent_output_schema() -> None: + """Filter should see the agent's output_schema in model_data so it can inspect or forward it.""" + model = FakeModel() + agent = Agent(name="test", model=model, output_type=_Reply) + model.set_next_output([get_text_message('{"answer": "hi"}')]) + + observed: list[Any] = [] + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + observed.append(data.model_data.output_schema) + return data.model_data + + await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert len(observed) == 1 + assert observed[0] is not None + assert observed[0].name() == "_Reply" + + +@pytest.mark.asyncio +async def test_filter_not_setting_output_schema_preserves_agent_schema() -> None: + """A filter omitting output_schema must not clear the agent's schema.""" + model = FakeModel() + agent = Agent(name="test", model=model, output_type=_Reply) + model.set_next_output([get_text_message('{"answer": "hi"}')]) + + def filter_fn(data: CallModelData[Any]) -> ModelInputData: + # Intentionally omit output_schema to confirm the agent schema is preserved. + return ModelInputData( + input=data.model_data.input, + instructions=data.model_data.instructions, + ) + + await Runner.run( + agent, + input="start", + run_config=RunConfig(call_model_input_filter=filter_fn), + ) + + assert model.last_turn_args["output_schema"] is not None + assert model.last_turn_args["output_schema"].name() == "_Reply"