Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from ..sessions.state import State
from ..utils.agent_info import AgentInfo
from ..utils.agent_info import get_agents_dict
from ..utils.context_utils import Aclosing
Expand Down Expand Up @@ -1915,6 +1916,21 @@ async def event_generator():
events_to_stream = [content_event, artifact_event]

for event_to_stream in events_to_stream:
# Filter temp-scoped state keys before SSE serialization.
# Temp state (prefix "temp:") can contain non-serializable
# objects such as FunctionTool instances stored by
# _call_llm_node. _trim_temp_delta_state() handles this
# for persistence in append_event(), but SSE events are
# serialized before reaching that path.
if (
event_to_stream.actions
and event_to_stream.actions.state_delta
):
event_to_stream.actions.state_delta = {
k: v
for k, v in event_to_stream.actions.state_delta.items()
if not k.startswith(State.TEMP_PREFIX)
}
sse_event = event_to_stream.model_dump_json(
exclude_none=True,
by_alias=True,
Expand Down
79 changes: 79 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,85 @@ async def run_async_raises(self, **kwargs):
assert sse_events == [{"error": "boom"}]


def test_agent_run_sse_filters_temp_state_keys(
test_app, create_test_session, monkeypatch
):
"""Test /run_sse strips temp-scoped state keys before serialization.

Temp state (e.g. ``temp:tools_dict``) can hold non-serializable objects
such as ``FunctionTool`` instances. ``_trim_temp_delta_state`` filters
these for persistence, but SSE events are serialized earlier. This test
verifies that the SSE generator applies the same filtering so that
``model_dump_json`` does not crash.

Regression test for https://github.com/google/adk-python/issues/5051
"""
info = create_test_session

# An object that is intentionally not JSON-serializable, mimicking
# the FunctionTool instances that _call_llm_node stores in temp state.
class _NotSerializable:
pass

async def run_async_with_temp_state(
self,
*,
user_id: str,
session_id: str,
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
):
del user_id, session_id, invocation_id, new_message, state_delta, run_config
yield Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model", parts=[types.Part(text="hello")]
),
actions=EventActions(
state_delta={
"user_request": "hi",
"temp:tools_dict": {"greet": _NotSerializable()},
"temp:other": _NotSerializable(),
}
),
)

monkeypatch.setattr(Runner, "run_async", run_async_with_temp_state)

payload = {
"app_name": info["app_name"],
"user_id": info["user_id"],
"session_id": info["session_id"],
"new_message": {"role": "user", "parts": [{"text": "Hello agent"}]},
"streaming": True,
}

response = test_app.post("/run_sse", json=payload)
assert response.status_code == 200

sse_events = [
json.loads(line.removeprefix("data: "))
for line in response.text.splitlines()
if line.startswith("data: ")
]

assert len(sse_events) == 1
event_data = sse_events[0]

# Content should be intact.
assert event_data["content"]["parts"][0]["text"] == "hello"

# Non-temp state key should survive.
assert event_data["actions"]["stateDelta"]["user_request"] == "hi"

# Temp-scoped keys must be stripped.
assert "temp:tools_dict" not in event_data["actions"]["stateDelta"]
assert "temp:other" not in event_data["actions"]["stateDelta"]


def test_list_artifact_names(test_app, create_test_session):
"""Test listing artifact names for a session."""
info = create_test_session
Expand Down