Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 62 additions & 51 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import TYPE_CHECKING

from google.adk.platform import time as platform_time
from google.genai import errors as genai_errors
from google.genai import types
from opentelemetry import trace
from websockets.exceptions import ConnectionClosed
Expand Down Expand Up @@ -738,61 +739,71 @@ def get_author_for_event(llm_response):
else:
return invocation_context.agent.name

while True:
async with Aclosing(llm_connection.receive()) as agen:
async for llm_response in agen:
if llm_response.live_session_resumption_update:
logger.info(
'Update session resumption handle:'
f' {llm_response.live_session_resumption_update}.'
)
invocation_context.live_session_resumption_handle = (
llm_response.live_session_resumption_update.new_handle
try:
while True:
async with Aclosing(llm_connection.receive()) as agen:
async for llm_response in agen:
if llm_response.live_session_resumption_update:
logger.info(
'Update session resumption handle:'
f' {llm_response.live_session_resumption_update}.'
)
invocation_context.live_session_resumption_handle = (
llm_response.live_session_resumption_update.new_handle
)
if llm_response.go_away:
logger.info(f'Received go away signal: {llm_response.go_away}')
# The server signals that it will close the connection soon.
# We proactively raise ConnectionClosed to trigger the reconnection
# logic in run_live, which will use the latest session handle.
raise ConnectionClosed(None, None)

model_response_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=get_author_for_event(llm_response),
)
if llm_response.go_away:
logger.info(f'Received go away signal: {llm_response.go_away}')
# The server signals that it will close the connection soon.
# We proactively raise ConnectionClosed to trigger the reconnection
# logic in run_live, which will use the latest session handle.
raise ConnectionClosed(None, None)

model_response_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=get_author_for_event(llm_response),
)

async with Aclosing(
self._postprocess_live(
invocation_context,
llm_request,
llm_response,
model_response_event,
)
) as agen:
async for event in agen:
# Cache output audio chunks from model responses
# TODO: support video data
if (
invocation_context.run_config.save_live_blob
and event.content
and event.content.parts
and event.content.parts[0].inline_data
and event.content.parts[0].inline_data.mime_type.startswith(
'audio/'
)
):
audio_blob = types.Blob(
data=event.content.parts[0].inline_data.data,
mime_type=event.content.parts[0].inline_data.mime_type,
)
self.audio_cache_manager.cache_audio(
invocation_context, audio_blob, cache_type='output'
async with Aclosing(
self._postprocess_live(
invocation_context,
llm_request,
llm_response,
model_response_event,
)
) as agen:
async for event in agen:
# Cache output audio chunks from model responses
# TODO: support video data
if (
invocation_context.run_config.save_live_blob
and event.content
and event.content.parts
and event.content.parts[0].inline_data
and event.content.parts[0].inline_data.mime_type.startswith(
'audio/'
)
):
audio_blob = types.Blob(
data=event.content.parts[0].inline_data.data,
mime_type=event.content.parts[0].inline_data.mime_type,
)
self.audio_cache_manager.cache_audio(
invocation_context, audio_blob, cache_type='output'
)

yield event
# Give opportunity for other tasks to run.
await asyncio.sleep(0)
yield event
# Give opportunity for other tasks to run.
await asyncio.sleep(0)
except ConnectionClosedOK:
pass
except genai_errors.APIError as e:
# google-genai >= 1.62.0 converts ConnectionClosedOK into APIError with
# WebSocket close code 1000 (RFC 6455 Normal Closure). Treat it the same
# as ConnectionClosedOK so that a clean session close does not propagate
# as an unexpected error.
if e.code != 1000:
raise

async def run_async(
self, invocation_context: InvocationContext
Expand Down
147 changes: 147 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow_connection_close.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for _receive_from_model connection-close handling in BaseLlmFlow.

google-genai >= 1.62.0 converts websockets.ConnectionClosedOK into
google.genai.errors.APIError with WebSocket close code 1000 (RFC 6455 Normal
Closure). These tests verify that both ConnectionClosedOK and APIError(1000)
are treated as a clean session end rather than an unexpected error, while
APIError with any other code is still re-raised.
"""

from unittest import mock

from google.adk.agents.llm_agent import Agent
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.models.llm_request import LlmRequest
from google.genai import errors as genai_errors
import pytest
from websockets.exceptions import ConnectionClosedOK

from ... import testing_utils


class _TestFlow(BaseLlmFlow):
"""Minimal concrete subclass of BaseLlmFlow for testing."""

pass


async def _collect(agen):
"""Drain an async generator and return all yielded items."""
items = []
async for item in agen:
items.append(item)
return items


def _make_raising_connection(exc):
"""Return a mock LLM connection whose receive() raises *exc* immediately."""

async def _raise():
raise exc
yield # make it an async generator

connection = mock.MagicMock()
connection.receive = _raise
return connection


@pytest.fixture
def flow():
return _TestFlow()


@pytest.fixture
async def invocation_context():
agent = Agent(name='test_agent', model='mock')
return await testing_utils.create_invocation_context(
agent=agent, user_content=''
)


@pytest.fixture
def llm_request():
return LlmRequest()


# ---------------------------------------------------------------------------
# ConnectionClosedOK — pre-existing behaviour, must remain silent
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_receive_from_model_connection_closed_ok_is_silent(
flow, invocation_context, llm_request
):
"""ConnectionClosedOK must be swallowed so the live session ends cleanly."""
connection = _make_raising_connection(ConnectionClosedOK(None, None))

events = await _collect(
flow._receive_from_model(
connection, 'evt-1', invocation_context, llm_request
)
)

assert events == []


# ---------------------------------------------------------------------------
# APIError(code=1000) — new behaviour for google-genai >= 1.62.0
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_receive_from_model_api_error_1000_is_silent(
flow, invocation_context, llm_request
):
"""APIError with code 1000 (Normal Closure) must be swallowed.

google-genai >= 1.62.0 wraps ConnectionClosedOK as APIError(1000).
This should be treated identically to ConnectionClosedOK.
"""
error = genai_errors.APIError(1000, {}, None)
connection = _make_raising_connection(error)

events = await _collect(
flow._receive_from_model(
connection, 'evt-2', invocation_context, llm_request
)
)

assert events == []


# ---------------------------------------------------------------------------
# APIError with a non-1000 code — must still propagate as a real error
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_receive_from_model_api_error_non_1000_is_raised(
flow, invocation_context, llm_request
):
"""APIError with a code other than 1000 must propagate unchanged."""
error = genai_errors.APIError(500, {}, None)
connection = _make_raising_connection(error)

with pytest.raises(genai_errors.APIError) as exc_info:
await _collect(
flow._receive_from_model(
connection, 'evt-3', invocation_context, llm_request
)
)

assert exc_info.value.code == 500