From 53633b8347264098a7abb622fbded914d99abf64 Mon Sep 17 00:00:00 2001 From: dongkseo Date: Sat, 13 Dec 2025 23:49:21 +0900 Subject: [PATCH 1/6] feat(runner): add metadata parameter to Runner.run_async() Add support for passing per-request metadata through the agent execution pipeline. This enables use cases like: - Passing user_id, trace_id, or session context to callbacks - Enabling memory injection in before_model_callback - Supporting request-specific context without using ContextVar workarounds Changes: - Add `metadata` field to LlmRequest model - Add `metadata` field to InvocationContext model - Add `metadata` parameter to Runner.run_async() and related methods - Propagate metadata from InvocationContext to LlmRequest in base_llm_flow - Add unit tests for metadata functionality Closes #2978 --- src/google/adk/agents/invocation_context.py | 9 ++ .../adk/flows/llm_flows/base_llm_flow.py | 4 +- src/google/adk/models/llm_request.py | 10 ++ src/google/adk/runners.py | 23 ++- tests/unittests/test_runners.py | 139 ++++++++++++++++++ 5 files changed, 182 insertions(+), 3 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 7fdbaee89b..b0cb89730f 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -218,6 +218,15 @@ class InvocationContext(BaseModel): credential_by_key: dict[str, AuthCredential] = Field(default_factory=dict) """The resolved credentials for this invocation, keyed by credential_key.""" + metadata: Optional[dict[str, Any]] = None + """Per-request metadata passed from Runner.run_async(). + + This field allows passing arbitrary metadata that can be accessed during + the invocation lifecycle, particularly in callbacks like before_model_callback. + Common use cases include passing user_id, trace_id, memory context keys, or + other request-specific context that needs to be available during processing. + """ + _invocation_cost_manager: _InvocationCostManager = PrivateAttr( default_factory=_InvocationCostManager ) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 2e45708d9e..75db798dcf 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -478,7 +478,7 @@ async def run_live( """Runs the flow using live api.""" from google.genai import errors - llm_request = LlmRequest() + llm_request = LlmRequest(metadata=invocation_context.metadata) event_id = Event.new_id() # Preprocess before calling the LLM. @@ -819,7 +819,7 @@ async def _run_one_step_async( invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: """One step means one LLM call.""" - llm_request = LlmRequest() + llm_request = LlmRequest(metadata=invocation_context.metadata) # Preprocess before calling the LLM. async with Aclosing( diff --git a/src/google/adk/models/llm_request.py b/src/google/adk/models/llm_request.py index 37f1852bd7..5e7621dfd2 100644 --- a/src/google/adk/models/llm_request.py +++ b/src/google/adk/models/llm_request.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from typing import Any from typing import Optional from typing import Union @@ -100,6 +101,15 @@ class LlmRequest(BaseModel): the full history. """ + metadata: Optional[dict[str, Any]] = None + """Per-request metadata for callbacks and custom processing. + + This field allows passing arbitrary metadata from the Runner.run_async() + call to callbacks like before_model_callback. This is useful for passing + request-specific context such as user_id, trace_id, or memory context keys + that need to be available during model invocation. + """ + def append_instructions( self, instructions: Union[list[str], types.Content] ) -> list[types.Content]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 3a8e49c3f2..b32b4c65f7 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -509,6 +509,7 @@ async def run_async( new_message: Optional[types.Content] = None, state_delta: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, + metadata: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[Event, None]: """Main entry method to run the agent in this runner. @@ -526,6 +527,9 @@ async def run_async( new_message: A new message to append to the session. state_delta: Optional state changes to apply to the session. run_config: The run config for the agent. + metadata: Optional per-request metadata that will be passed to callbacks. + This allows passing request-specific context such as user_id, trace_id, + or memory context keys to before_model_callback and other callbacks. Yields: The events generated by the agent. @@ -535,6 +539,8 @@ async def run_async( new_message are None. """ run_config = run_config or RunConfig() + # Create a shallow copy to isolate from caller's modifications + metadata = metadata.copy() if metadata else None if new_message and not new_message.role: new_message.role = 'user' @@ -542,6 +548,7 @@ async def run_async( async def _run_with_trace( new_message: Optional[types.Content] = None, invocation_id: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[Event, None]: with tracer.start_as_current_span('invocation'): session = await self._get_or_create_session( @@ -572,6 +579,7 @@ async def _run_with_trace( new_message=new_message, run_config=run_config, state_delta=state_delta, + metadata=metadata, ) else: invocation_id = self._resolve_invocation_id( @@ -583,6 +591,7 @@ async def _run_with_trace( new_message=new_message, run_config=run_config, state_delta=state_delta, + metadata=metadata, ) else: invocation_context = ( @@ -592,6 +601,7 @@ async def _run_with_trace( invocation_id=invocation_id, run_config=run_config, state_delta=state_delta, + metadata=metadata, ) ) if invocation_context.end_of_agents.get( @@ -628,7 +638,9 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: skip_token_compaction=invocation_context.token_compaction_checked, ) - async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: + async with Aclosing( + _run_with_trace(new_message, invocation_id, metadata) + ) as agen: async for event in agen: yield event @@ -1323,6 +1335,7 @@ async def _setup_context_for_new_invocation( new_message: types.Content, run_config: RunConfig, state_delta: Optional[dict[str, Any]], + metadata: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Sets up the context for a new invocation. @@ -1331,6 +1344,7 @@ async def _setup_context_for_new_invocation( new_message: The new message to process and append to the session. run_config: The run config of the agent. state_delta: Optional state changes to apply to the session. + metadata: Optional per-request metadata to pass to callbacks. Returns: The invocation context for the new invocation. @@ -1340,6 +1354,7 @@ async def _setup_context_for_new_invocation( session, new_message=new_message, run_config=run_config, + metadata=metadata, ) # Step 2: Handle new message, by running callbacks and appending to # session. @@ -1362,6 +1377,7 @@ async def _setup_context_for_resumed_invocation( invocation_id: Optional[str], run_config: RunConfig, state_delta: Optional[dict[str, Any]], + metadata: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Sets up the context for a resumed invocation. @@ -1371,6 +1387,7 @@ async def _setup_context_for_resumed_invocation( invocation_id: The invocation id to resume. run_config: The run config of the agent. state_delta: Optional state changes to apply to the session. + metadata: Optional per-request metadata to pass to callbacks. Returns: The invocation context for the resumed invocation. @@ -1396,6 +1413,7 @@ async def _setup_context_for_resumed_invocation( new_message=user_message, run_config=run_config, invocation_id=invocation_id, + metadata=metadata, ) # Step 3: Maybe handle new message. if new_message: @@ -1444,6 +1462,7 @@ def _new_invocation_context( new_message: Optional[types.Content] = None, live_request_queue: Optional[LiveRequestQueue] = None, run_config: Optional[RunConfig] = None, + metadata: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Creates a new invocation context. @@ -1453,6 +1472,7 @@ def _new_invocation_context( new_message: The new message for the context. live_request_queue: The live request queue for the context. run_config: The run config for the context. + metadata: Optional per-request metadata for the context. Returns: The new invocation context. @@ -1487,6 +1507,7 @@ def _new_invocation_context( live_request_queue=live_request_queue, run_config=run_config, resumability_config=self.resumability_config, + metadata=metadata, ) def _new_invocation_context_for_live( diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index e3bcaeb2a4..9bc1b72ca7 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -31,6 +31,8 @@ from google.adk.cli.utils.agent_loader import AgentLoader from google.adk.errors.session_not_found_error import SessionNotFoundError from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -1512,5 +1514,142 @@ async def test_get_session_config_limits_events(): assert len(limited_session.events) == 3 +class TestRunnerMetadata: + """Tests for Runner metadata parameter functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.session_service = InMemorySessionService() + self.artifact_service = InMemoryArtifactService() + self.root_agent = MockLlmAgent("root_agent") + self.runner = Runner( + app_name="test_app", + agent=self.root_agent, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + def test_new_invocation_context_with_metadata(self): + """Test that _new_invocation_context correctly passes metadata.""" + mock_session = Session( + id=TEST_SESSION_ID, + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + events=[], + ) + + test_metadata = {"user_id": "test123", "trace_id": "trace456"} + invocation_context = self.runner._new_invocation_context( + mock_session, metadata=test_metadata + ) + + assert invocation_context.metadata == test_metadata + assert invocation_context.metadata["user_id"] == "test123" + assert invocation_context.metadata["trace_id"] == "trace456" + + def test_new_invocation_context_without_metadata(self): + """Test that _new_invocation_context works without metadata.""" + mock_session = Session( + id=TEST_SESSION_ID, + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + events=[], + ) + + invocation_context = self.runner._new_invocation_context(mock_session) + + assert invocation_context.metadata is None + + @pytest.mark.asyncio + async def test_run_async_passes_metadata_to_invocation_context(self): + """Test that run_async correctly passes metadata to before_model_callback.""" + # Capture metadata received in callback + captured_metadata = None + + def before_model_callback(callback_context, llm_request): + nonlocal captured_metadata + captured_metadata = llm_request.metadata + # Return a response to skip actual LLM call + return LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ) + ) + + # Create agent with before_model_callback + agent_with_callback = LlmAgent( + name="callback_agent", + model="gemini-2.0-flash", + before_model_callback=before_model_callback, + ) + + runner_with_callback = Runner( + app_name="test_app", + agent=agent_with_callback, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + session = await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + test_metadata = {"experiment_id": "exp-001", "variant": "B"} + + async for event in runner_with_callback.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + metadata=test_metadata, + ): + pass + + # Verify metadata was passed to before_model_callback + assert captured_metadata is not None + assert captured_metadata == test_metadata + assert captured_metadata["experiment_id"] == "exp-001" + assert captured_metadata["variant"] == "B" + + def test_metadata_field_in_invocation_context(self): + """Test that InvocationContext model accepts metadata field.""" + mock_session = Session( + id=TEST_SESSION_ID, + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + events=[], + ) + + test_metadata = {"key1": "value1", "key2": 123} + + # This should not raise a validation error + invocation_context = InvocationContext( + session_service=self.session_service, + invocation_id="test_inv_id", + agent=self.root_agent, + session=mock_session, + metadata=test_metadata, + ) + + assert invocation_context.metadata == test_metadata + + def test_metadata_field_in_llm_request(self): + """Test that LlmRequest model accepts metadata field.""" + test_metadata = {"context_key": "ctx123", "user_info": {"name": "test"}} + + llm_request = LlmRequest(metadata=test_metadata) + + assert llm_request.metadata == test_metadata + assert llm_request.metadata["context_key"] == "ctx123" + assert llm_request.metadata["user_info"]["name"] == "test" + + def test_llm_request_without_metadata(self): + """Test that LlmRequest works without metadata.""" + llm_request = LlmRequest() + + assert llm_request.metadata is None + + if __name__ == "__main__": pytest.main([__file__]) From f09a5c0b863c325028d44a33b426e745471a4821 Mon Sep 17 00:00:00 2001 From: dongkseo Date: Sat, 20 Dec 2025 18:50:20 +0900 Subject: [PATCH 2/6] docs: clarify shallow copy behavior in docstring and add isolation test - Add note in docstring about shallow copy behavior for nested objects - Add test_metadata_shallow_copy_isolation to verify: - Top-level changes are isolated from original dict - Nested object modifications affect original (shallow copy) --- src/google/adk/runners.py | 4 +++ tests/unittests/test_runners.py | 61 +++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index b32b4c65f7..d416751c09 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -530,6 +530,10 @@ async def run_async( metadata: Optional per-request metadata that will be passed to callbacks. This allows passing request-specific context such as user_id, trace_id, or memory context keys to before_model_callback and other callbacks. + Note: A shallow copy is made of this dictionary, so top-level changes + within callbacks won't affect the original. However, modifications to + nested mutable objects (e.g., nested dicts or lists) will affect the + original. Yields: The events generated by the agent. diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 9bc1b72ca7..0ac8186534 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -1650,6 +1650,67 @@ def test_llm_request_without_metadata(self): assert llm_request.metadata is None + @pytest.mark.asyncio + async def test_metadata_shallow_copy_isolation(self): + """Test that shallow copy isolates top-level changes but shares nested objects.""" + # Track modifications made in callback + callback_received_metadata = None + + def before_model_callback(callback_context, llm_request): + nonlocal callback_received_metadata + callback_received_metadata = llm_request.metadata + # Modify top-level key (should NOT affect original due to shallow copy) + llm_request.metadata["top_level_key"] = "modified_in_callback" + # Modify nested object (WILL affect original due to shallow copy) + llm_request.metadata["nested"]["inner_key"] = "modified_nested" + return LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ) + ) + + agent_with_callback = LlmAgent( + name="callback_agent", + model="gemini-2.0-flash", + before_model_callback=before_model_callback, + ) + + runner_with_callback = Runner( + app_name="test_app", + agent=agent_with_callback, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + # Original metadata with nested mutable object + original_metadata = { + "top_level_key": "original_value", + "nested": {"inner_key": "original_nested"}, + } + + async for event in runner_with_callback.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + metadata=original_metadata, + ): + pass + + # Verify callback received metadata + assert callback_received_metadata is not None + + # Top-level changes in callback should NOT affect original (shallow copy) + assert original_metadata["top_level_key"] == "original_value" + + # Nested object changes in callback WILL affect original (shallow copy behavior) + assert original_metadata["nested"]["inner_key"] == "modified_nested" + if __name__ == "__main__": pytest.main([__file__]) From b7ded7b51612ca32c31455f067b39f015a58ac1b Mon Sep 17 00:00:00 2001 From: dongkseo Date: Sat, 20 Dec 2025 18:54:07 +0900 Subject: [PATCH 3/6] fix: preserve empty dict metadata instead of converting to None - Change `if metadata` to `if metadata is not None` for truthiness check - Empty dict {} was incorrectly converted to None due to falsy check - Add test_empty_metadata_dict_not_converted_to_none to prevent regression --- src/google/adk/runners.py | 2 +- tests/unittests/test_runners.py | 47 +++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index d416751c09..6ab1536a5d 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -544,7 +544,7 @@ async def run_async( """ run_config = run_config or RunConfig() # Create a shallow copy to isolate from caller's modifications - metadata = metadata.copy() if metadata else None + metadata = metadata.copy() if metadata is not None else None if new_message and not new_message.role: new_message.role = 'user' diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 0ac8186534..6d9c9fca00 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -1650,6 +1650,53 @@ def test_llm_request_without_metadata(self): assert llm_request.metadata is None + @pytest.mark.asyncio + async def test_empty_metadata_dict_not_converted_to_none(self): + """Test that empty dict {} is preserved and not converted to None.""" + captured_metadata = None + + def before_model_callback(callback_context, llm_request): + nonlocal captured_metadata + captured_metadata = llm_request.metadata + return LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ) + ) + + agent_with_callback = LlmAgent( + name="callback_agent", + model="gemini-2.0-flash", + before_model_callback=before_model_callback, + ) + + runner_with_callback = Runner( + app_name="test_app", + agent=agent_with_callback, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + # Pass empty dict - should NOT become None + async for event in runner_with_callback.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + metadata={}, + ): + pass + + # Empty dict should be preserved, not converted to None + assert captured_metadata is not None + assert captured_metadata == {} + assert isinstance(captured_metadata, dict) + @pytest.mark.asyncio async def test_metadata_shallow_copy_isolation(self): """Test that shallow copy isolates top-level changes but shares nested objects.""" From 1b38c7dd4d6d88f6b3c498368008e9fad23b5627 Mon Sep 17 00:00:00 2001 From: donggyun112 Date: Thu, 8 Jan 2026 13:29:01 +0900 Subject: [PATCH 4/6] feat(runner): add metadata parameter to run(), run_live(), run_debug() Add metadata support to all run methods for consistency: - run(): sync wrapper, passes metadata to run_async() - run_live(): live mode, passes metadata through invocation context - run_debug(): debug helper, passes metadata to run_async() Also update InvocationContext docstring to reflect all supported entry points. --- src/google/adk/agents/invocation_context.py | 4 +- src/google/adk/runners.py | 12 ++ tests/unittests/test_runners.py | 121 ++++++++++++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index b0cb89730f..8ac24a93a9 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -219,12 +219,14 @@ class InvocationContext(BaseModel): """The resolved credentials for this invocation, keyed by credential_key.""" metadata: Optional[dict[str, Any]] = None - """Per-request metadata passed from Runner.run_async(). + """Per-request metadata passed from Runner entry points. This field allows passing arbitrary metadata that can be accessed during the invocation lifecycle, particularly in callbacks like before_model_callback. Common use cases include passing user_id, trace_id, memory context keys, or other request-specific context that needs to be available during processing. + + Supported entry points: run(), run_async(), run_live(), run_debug(). """ _invocation_cost_manager: _InvocationCostManager = PrivateAttr( diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 6ab1536a5d..ce9dc35acc 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -442,6 +442,7 @@ def run( session_id: str, new_message: types.Content, run_config: Optional[RunConfig] = None, + metadata: Optional[dict[str, Any]] = None, ) -> Generator[Event, None, None]: """Runs the agent. @@ -459,6 +460,7 @@ def run( session_id: The session ID of the session. new_message: A new message to append to the session. run_config: The run config for the agent. + metadata: Optional per-request metadata that will be passed to callbacks. Yields: The events generated by the agent. @@ -474,6 +476,7 @@ async def _invoke_run_async(): session_id=session_id, new_message=new_message, run_config=run_config, + metadata=metadata, ) ) as agen: async for event in agen: @@ -1047,6 +1050,7 @@ async def run_live( live_request_queue: LiveRequestQueue, run_config: Optional[RunConfig] = None, session: Optional[Session] = None, + metadata: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[Event, None]: """Runs the agent in live mode (experimental feature). @@ -1088,6 +1092,7 @@ async def run_live( run_config: The run config for the agent. session: The session to use. This parameter is deprecated, please use `user_id` and `session_id` instead. + metadata: Optional per-request metadata that will be passed to callbacks. Yields: AsyncGenerator[Event, None]: An asynchronous generator that yields @@ -1102,6 +1107,7 @@ async def run_live( Either `session` or both `user_id` and `session_id` must be provided. """ run_config = run_config or RunConfig() + metadata = metadata.copy() if metadata is not None else None # Some native audio models requires the modality to be set. So we set it to # AUDIO by default. if run_config.response_modalities is None: @@ -1129,6 +1135,7 @@ async def run_live( session, live_request_queue=live_request_queue, run_config=run_config, + metadata=metadata, ) root_agent = self.agent @@ -1235,6 +1242,7 @@ async def run_debug( run_config: RunConfig | None = None, quiet: bool = False, verbose: bool = False, + metadata: dict[str, Any] | None = None, ) -> list[Event]: """Debug helper for quick agent experimentation and testing. @@ -1258,6 +1266,7 @@ async def run_debug( shown). verbose: If True, shows detailed tool calls and responses. Defaults to False for cleaner output showing only final agent responses. + metadata: Optional per-request metadata that will be passed to callbacks. Returns: list[Event]: All events from all messages. @@ -1324,6 +1333,7 @@ async def run_debug( session_id=session.id, new_message=types.UserContent(parts=[types.Part(text=message)]), run_config=run_config, + metadata=metadata, ): if not quiet: print_event(event, verbose=verbose) @@ -1520,6 +1530,7 @@ def _new_invocation_context_for_live( *, live_request_queue: LiveRequestQueue, run_config: Optional[RunConfig] = None, + metadata: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Creates a new invocation context for live multi-agent.""" run_config = run_config or RunConfig() @@ -1538,6 +1549,7 @@ def _new_invocation_context_for_live( session, live_request_queue=live_request_queue, run_config=run_config, + metadata=metadata, ) async def _handle_new_message( diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 6d9c9fca00..106660f619 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -22,6 +22,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig +from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig @@ -36,6 +37,7 @@ from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService +from tests.unittests import testing_utils from google.adk.sessions.session import Session from google.genai import types import pytest @@ -1758,6 +1760,125 @@ def before_model_callback(callback_context, llm_request): # Nested object changes in callback WILL affect original (shallow copy behavior) assert original_metadata["nested"]["inner_key"] == "modified_nested" + def test_new_invocation_context_for_live_with_metadata(self): + """Test that _new_invocation_context_for_live correctly passes metadata.""" + mock_session = Session( + id=TEST_SESSION_ID, + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + events=[], + ) + + test_metadata = {"user_id": "live_user", "trace_id": "live_trace"} + invocation_context = self.runner._new_invocation_context_for_live( + mock_session, metadata=test_metadata + ) + + assert invocation_context.metadata == test_metadata + assert invocation_context.metadata["user_id"] == "live_user" + + @pytest.mark.asyncio + async def test_run_sync_passes_metadata(self): + """Test that sync run() correctly passes metadata to run_async().""" + captured_metadata = None + + def before_model_callback(callback_context, llm_request): + nonlocal captured_metadata + captured_metadata = llm_request.metadata + return LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ) + ) + + agent_with_callback = LlmAgent( + name="callback_agent", + model="gemini-2.0-flash", + before_model_callback=before_model_callback, + ) + + runner_with_callback = Runner( + app_name="test_app", + agent=agent_with_callback, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + test_metadata = {"sync_key": "sync_value"} + + for event in runner_with_callback.run( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + metadata=test_metadata, + ): + pass + + assert captured_metadata is not None + assert captured_metadata["sync_key"] == "sync_value" + + @pytest.mark.asyncio + async def test_run_live_passes_metadata_to_llm_request(self): + """Test that run_live() passes metadata through live pipeline to LlmRequest.""" + import asyncio + + # Create MockModel to capture LlmRequest + mock_model = testing_utils.MockModel.create( + responses=[ + LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Live response")] + ) + ) + ] + ) + + agent_with_mock = LlmAgent( + name="live_mock_agent", + model=mock_model, + ) + + runner_with_mock = Runner( + app_name="test_app", + agent=agent_with_mock, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + test_metadata = {"live_key": "live_value", "trace_id": "live_trace_123"} + live_queue = LiveRequestQueue() + live_queue.close() # Close immediately to end the live session + + async def consume_events(): + async for event in runner_with_mock.run_live( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + live_request_queue=live_queue, + metadata=test_metadata, + ): + pass + + try: + await asyncio.wait_for(consume_events(), timeout=2) + except asyncio.TimeoutError: + pass # Expected - live session may not terminate cleanly + + # Verify MockModel received LlmRequest with correct metadata + assert len(mock_model.requests) > 0 + assert mock_model.requests[0].metadata is not None + assert mock_model.requests[0].metadata["live_key"] == "live_value" + assert mock_model.requests[0].metadata["trace_id"] == "live_trace_123" + if __name__ == "__main__": pytest.main([__file__]) From b5c055e57f735334ab2741c78534dfec3bf4fef4 Mon Sep 17 00:00:00 2001 From: donggyun112 Date: Tue, 20 Jan 2026 14:02:16 +0900 Subject: [PATCH 5/6] fix: resolve isort duplicate import and add missing live_request_queue argument in test - Remove duplicate LiveRequestQueue import caused by merge conflict - Fix import ordering for tests.unittests module - Add required live_request_queue argument to test_new_invocation_context_for_live_with_metadata test --- tests/unittests/test_runners.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 106660f619..64354dcf99 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -22,7 +22,6 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig -from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig @@ -37,11 +36,12 @@ from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService -from tests.unittests import testing_utils from google.adk.sessions.session import Session from google.genai import types import pytest +from tests.unittests import testing_utils + TEST_APP_ID = "test_app" TEST_USER_ID = "test_user" TEST_SESSION_ID = "test_session" @@ -1770,8 +1770,9 @@ def test_new_invocation_context_for_live_with_metadata(self): ) test_metadata = {"user_id": "live_user", "trace_id": "live_trace"} + live_queue = LiveRequestQueue() invocation_context = self.runner._new_invocation_context_for_live( - mock_session, metadata=test_metadata + mock_session, live_request_queue=live_queue, metadata=test_metadata ) assert invocation_context.metadata == test_metadata From 0f92dd676d17e7b8ae1786c07033ca394c6f143c Mon Sep 17 00:00:00 2001 From: donggyun112 Date: Sat, 7 Mar 2026 21:17:38 +0900 Subject: [PATCH 6/6] fix: move LiveRequestQueue to top-level import to avoid CI failures --- tests/unittests/test_runners.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 64354dcf99..ff4946de82 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -23,6 +23,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App @@ -342,8 +343,6 @@ async def test_run_live_auto_create_session(): ) # An empty LiveRequestQueue is sufficient for our mock agent. - from google.adk.agents.live_request_queue import LiveRequestQueue - live_queue = LiveRequestQueue() agen = runner.run_live(