diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 4c253014a9..2ec362bd87 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING from google.adk.platform import time as platform_time +from google.genai import errors as genai_errors from google.genai import types from opentelemetry import trace from websockets.exceptions import ConnectionClosed @@ -738,61 +739,71 @@ def get_author_for_event(llm_response): else: return invocation_context.agent.name - while True: - async with Aclosing(llm_connection.receive()) as agen: - async for llm_response in agen: - if llm_response.live_session_resumption_update: - logger.info( - 'Update session resumption handle:' - f' {llm_response.live_session_resumption_update}.' - ) - invocation_context.live_session_resumption_handle = ( - llm_response.live_session_resumption_update.new_handle + try: + while True: + async with Aclosing(llm_connection.receive()) as agen: + async for llm_response in agen: + if llm_response.live_session_resumption_update: + logger.info( + 'Update session resumption handle:' + f' {llm_response.live_session_resumption_update}.' + ) + invocation_context.live_session_resumption_handle = ( + llm_response.live_session_resumption_update.new_handle + ) + if llm_response.go_away: + logger.info(f'Received go away signal: {llm_response.go_away}') + # The server signals that it will close the connection soon. + # We proactively raise ConnectionClosed to trigger the reconnection + # logic in run_live, which will use the latest session handle. + raise ConnectionClosed(None, None) + + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=get_author_for_event(llm_response), ) - if llm_response.go_away: - logger.info(f'Received go away signal: {llm_response.go_away}') - # The server signals that it will close the connection soon. - # We proactively raise ConnectionClosed to trigger the reconnection - # logic in run_live, which will use the latest session handle. - raise ConnectionClosed(None, None) - - model_response_event = Event( - id=Event.new_id(), - invocation_id=invocation_context.invocation_id, - author=get_author_for_event(llm_response), - ) - async with Aclosing( - self._postprocess_live( - invocation_context, - llm_request, - llm_response, - model_response_event, - ) - ) as agen: - async for event in agen: - # Cache output audio chunks from model responses - # TODO: support video data - if ( - invocation_context.run_config.save_live_blob - and event.content - and event.content.parts - and event.content.parts[0].inline_data - and event.content.parts[0].inline_data.mime_type.startswith( - 'audio/' - ) - ): - audio_blob = types.Blob( - data=event.content.parts[0].inline_data.data, - mime_type=event.content.parts[0].inline_data.mime_type, - ) - self.audio_cache_manager.cache_audio( - invocation_context, audio_blob, cache_type='output' + async with Aclosing( + self._postprocess_live( + invocation_context, + llm_request, + llm_response, + model_response_event, ) + ) as agen: + async for event in agen: + # Cache output audio chunks from model responses + # TODO: support video data + if ( + invocation_context.run_config.save_live_blob + and event.content + and event.content.parts + and event.content.parts[0].inline_data + and event.content.parts[0].inline_data.mime_type.startswith( + 'audio/' + ) + ): + audio_blob = types.Blob( + data=event.content.parts[0].inline_data.data, + mime_type=event.content.parts[0].inline_data.mime_type, + ) + self.audio_cache_manager.cache_audio( + invocation_context, audio_blob, cache_type='output' + ) - yield event - # Give opportunity for other tasks to run. - await asyncio.sleep(0) + yield event + # Give opportunity for other tasks to run. + await asyncio.sleep(0) + except ConnectionClosedOK: + pass + except genai_errors.APIError as e: + # google-genai >= 1.62.0 converts ConnectionClosedOK into APIError with + # WebSocket close code 1000 (RFC 6455 Normal Closure). Treat it the same + # as ConnectionClosedOK so that a clean session close does not propagate + # as an unexpected error. + if e.code != 1000: + raise async def run_async( self, invocation_context: InvocationContext diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow_connection_close.py b/tests/unittests/flows/llm_flows/test_base_llm_flow_connection_close.py new file mode 100644 index 0000000000..817c6b2182 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow_connection_close.py @@ -0,0 +1,147 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _receive_from_model connection-close handling in BaseLlmFlow. + +google-genai >= 1.62.0 converts websockets.ConnectionClosedOK into +google.genai.errors.APIError with WebSocket close code 1000 (RFC 6455 Normal +Closure). These tests verify that both ConnectionClosedOK and APIError(1000) +are treated as a clean session end rather than an unexpected error, while +APIError with any other code is still re-raised. +""" + +from unittest import mock + +from google.adk.agents.llm_agent import Agent +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.genai import errors as genai_errors +import pytest +from websockets.exceptions import ConnectionClosedOK + +from ... import testing_utils + + +class _TestFlow(BaseLlmFlow): + """Minimal concrete subclass of BaseLlmFlow for testing.""" + + pass + + +async def _collect(agen): + """Drain an async generator and return all yielded items.""" + items = [] + async for item in agen: + items.append(item) + return items + + +def _make_raising_connection(exc): + """Return a mock LLM connection whose receive() raises *exc* immediately.""" + + async def _raise(): + raise exc + yield # make it an async generator + + connection = mock.MagicMock() + connection.receive = _raise + return connection + + +@pytest.fixture +def flow(): + return _TestFlow() + + +@pytest.fixture +async def invocation_context(): + agent = Agent(name='test_agent', model='mock') + return await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + +@pytest.fixture +def llm_request(): + return LlmRequest() + + +# --------------------------------------------------------------------------- +# ConnectionClosedOK — pre-existing behaviour, must remain silent +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_receive_from_model_connection_closed_ok_is_silent( + flow, invocation_context, llm_request +): + """ConnectionClosedOK must be swallowed so the live session ends cleanly.""" + connection = _make_raising_connection(ConnectionClosedOK(None, None)) + + events = await _collect( + flow._receive_from_model( + connection, 'evt-1', invocation_context, llm_request + ) + ) + + assert events == [] + + +# --------------------------------------------------------------------------- +# APIError(code=1000) — new behaviour for google-genai >= 1.62.0 +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_receive_from_model_api_error_1000_is_silent( + flow, invocation_context, llm_request +): + """APIError with code 1000 (Normal Closure) must be swallowed. + + google-genai >= 1.62.0 wraps ConnectionClosedOK as APIError(1000). + This should be treated identically to ConnectionClosedOK. + """ + error = genai_errors.APIError(1000, {}, None) + connection = _make_raising_connection(error) + + events = await _collect( + flow._receive_from_model( + connection, 'evt-2', invocation_context, llm_request + ) + ) + + assert events == [] + + +# --------------------------------------------------------------------------- +# APIError with a non-1000 code — must still propagate as a real error +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_receive_from_model_api_error_non_1000_is_raised( + flow, invocation_context, llm_request +): + """APIError with a code other than 1000 must propagate unchanged.""" + error = genai_errors.APIError(500, {}, None) + connection = _make_raising_connection(error) + + with pytest.raises(genai_errors.APIError) as exc_info: + await _collect( + flow._receive_from_model( + connection, 'evt-3', invocation_context, llm_request + ) + ) + + assert exc_info.value.code == 500