From 74596ad5a5cb3f2c4390e7b18cec84f5cf2c6716 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Mon, 8 Jun 2026 12:59:38 -0700 Subject: [PATCH] feat: Add Agent Platform MCP support to async generate_content_stream PiperOrigin-RevId: 928724878 --- google/genai/models.py | 347 ++++++++++-------- .../models/test_generate_content_tools.py | 131 +++++-- 2 files changed, 306 insertions(+), 172 deletions(-) diff --git a/google/genai/models.py b/google/genai/models.py index 65e0cf1fd..5afebf537 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -8781,171 +8781,232 @@ async def generate_content_stream( # The image shows a flat lay arrangement of freshly baked blueberry # scones. """ - if getattr( - self._api_client, 'vertexai', False - ) and _extra_utils.has_agent_platform_mcp_servers(config): - raise NotImplementedError( - 'MCP servers are not yet supported for streaming in the Agent' - ' Platform API.' - ) + if not config: + parsed_config = None + elif isinstance(config, dict): + parsed_config = types.GenerateContentConfig(**config) + else: + parsed_config = config.model_copy(deep=True) - # Retrieve and cache any MCP sessions if provided. incompatible_tools_indexes = ( _extra_utils.find_afc_incompatible_tool_indexes( - config, + parsed_config, is_agent_platform=getattr(self._api_client, 'vertexai', False), ) ) - # Retrieve and cache any MCP sessions if provided. - parsed_config, mcp_to_genai_tool_adapters = ( - await _extra_utils.parse_config_for_mcp_sessions( - config, - is_agent_platform=getattr(self._api_client, 'vertexai', False), - ) - ) - if _extra_utils.should_disable_afc(parsed_config): - response = await self._generate_content_stream( - model=model, contents=contents, config=parsed_config - ) - - async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def] - async for chunk in response: # type: ignore[attr-defined] - yield chunk - return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return] + async def stream_generator(): # type: ignore[no-untyped-def] + # Use AsyncExitStack to keep MCP connections alive across the entire stream + async with contextlib.AsyncExitStack() as stack: + current_config = parsed_config - if incompatible_tools_indexes: - original_tools_length = 0 - if isinstance(config, types.GenerateContentConfig): - if config.tools: - original_tools_length = len(config.tools) - elif isinstance(config, dict): - tools = config.get('tools', []) - if tools: - original_tools_length = len(tools) - if len(incompatible_tools_indexes) != original_tools_length: - indices_str = ', '.join(map(str, incompatible_tools_indexes)) - logger.warning( - 'Tools at indices [%s] are not compatible with automatic function ' - 'calling (AFC). AFC is disabled. If AFC is intended, please ' - 'include python callables in the tool list, and do not include ' - 'function declaration and MCP server in the tool list.', - indices_str, + if ( + getattr(self._api_client, 'vertexai', False) + and _extra_utils.has_agent_platform_mcp_servers(current_config) + and current_config is not None + ): + new_tools: list[Any] = [] + if current_config.tools: + for tool in current_config.tools: + if isinstance(tool, types.Tool) and tool.mcp_servers: + if ( + tool.function_declarations + or tool.google_search + or tool.retrieval + or tool.google_search_retrieval + or tool.code_execution + ): + tool_copy = tool.model_copy(update={'mcp_servers': None}) + new_tools.append(tool_copy) + + for server in tool.mcp_servers: + if ( + getattr(server, 'streamable_http_transport', None) + is not None + ): + raise ValueError( + "The 'streamable_http_transport' parameter is only" + ' supported in Gemini Developer API mode, not in Gemini' + ' Enterprise Agent Platform mode.' + ) + + if server.name is not None: + session = await stack.enter_async_context( + _mcp_utils._connect_agent_platform_mcp( + self._api_client, server.name + ) + ) + new_tools.append(session) + else: + raise ValueError( + "Agent Platform MCP servers require a 'name' field." + ) + else: + new_tools.append(tool) + current_config.tools = new_tools + + final_parsed_config, mcp_to_genai_tool_adapters = ( + await _extra_utils.parse_config_for_mcp_sessions( + current_config, + is_agent_platform=getattr(self._api_client, 'vertexai', False), + ) ) - response = await self._generate_content_stream( - model=model, contents=contents, config=parsed_config - ) - async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def] - async for chunk in response: # type: ignore[attr-defined] - yield chunk - - return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return] + if _extra_utils.should_disable_afc(final_parsed_config): + response = await self._generate_content_stream( + model=model, contents=contents, config=final_parsed_config + ) + async for chunk in response: # type: ignore[attr-defined] + yield chunk + return + + if incompatible_tools_indexes: + original_tools_length = 0 + if isinstance(config, types.GenerateContentConfig): + if config.tools: + original_tools_length = len(config.tools) + elif isinstance(config, dict): + tools = config.get('tools', []) + if tools: + original_tools_length = len(tools) + if len(incompatible_tools_indexes) != original_tools_length: + indices_str = ', '.join(map(str, incompatible_tools_indexes)) + logger.warning( + 'Tools at indices [%s] are not compatible with automatic' + ' function calling (AFC). AFC is disabled. If AFC is intended,' + ' please include python callables in the tool list, and do not' + ' include function declaration and MCP server in the tool' + ' list.', + indices_str, + ) + response = await self._generate_content_stream( + model=model, contents=contents, config=final_parsed_config + ) + async for chunk in response: # type: ignore[attr-defined] + yield chunk + return - # With tool compatibility confirmed, validate that the configuration are - # compatible with each other and raise an error if invalid. - _extra_utils.raise_error_for_afc_incompatible_config(parsed_config) + _extra_utils.raise_error_for_afc_incompatible_config( + final_parsed_config + ) - async def async_generator(model, contents, config): # type: ignore[no-untyped-def] - remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config) - logger.info( - f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.' - ) - automatic_function_calling_history: list[types.Content] = [] - func_response_parts = None - chunk = None - i = 0 - while remaining_remote_calls_afc > 0: - function_map = _extra_utils.get_function_map( - config, mcp_to_genai_tool_adapters, is_caller_method_async=True + remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc( + final_parsed_config ) - config_to_call = config.model_copy(deep=True) if config else None - if function_map: - config_to_call = _extra_utils.get_usage_header(config_to_call) - i += 1 - response = await self._generate_content_stream( - model=model, contents=contents, config=config_to_call + logger.info( + 'AFC is enabled with max remote calls:' + f' {remaining_remote_calls_afc}.' ) - # TODO: b/453739108 - make AFC logic more robust like the other 3 methods. - if i > 1: - logger.info(f'AFC remote call {i} is done.') - remaining_remote_calls_afc -= 1 - if i > 1 and remaining_remote_calls_afc == 0: - logger.info( - 'Reached max remote calls for automatic function calling.' + automatic_function_calling_history: list[types.Content] = [] + func_response_parts = None + chunk = None + i = 0 + loop_contents = contents + + while remaining_remote_calls_afc > 0: + + function_map = _extra_utils.get_function_map( + final_parsed_config, + mcp_to_genai_tool_adapters, + is_caller_method_async=True, ) - if i == 1: - # First request gets a function call. - # Then get function response parts. - # Yield chunks only if there's no function response parts. - async for chunk in response: # type: ignore[attr-defined] - if not function_map: - contents = _extra_utils.append_chunk_contents(contents, chunk) - yield chunk - else: - if ( - not chunk.candidates - or not chunk.candidates[0].content - or not chunk.candidates[0].content.parts - ): - break - func_response_parts = ( - await _extra_utils.get_function_response_parts_async( - chunk, function_map - ) - ) - if not func_response_parts: - contents = _extra_utils.append_chunk_contents(contents, chunk) - yield chunk + final_parsed_config_to_call = ( + final_parsed_config.model_copy(deep=True) + if final_parsed_config + else None + ) + if function_map: + final_parsed_config_to_call = _extra_utils.get_usage_header( + final_parsed_config_to_call + ) - else: - # Second request and beyond, yield chunks. - async for chunk in response: # type: ignore[attr-defined] + i += 1 - if _extra_utils.should_append_afc_history(config): - chunk.automatic_function_calling_history = ( - automatic_function_calling_history - ) - contents = _extra_utils.append_chunk_contents(contents, chunk) - yield chunk - if ( - chunk is None - or not chunk.candidates - or not chunk.candidates[0].content - or not chunk.candidates[0].content.parts - ): - break - func_response_parts = ( - await _extra_utils.get_function_response_parts_async( - chunk, function_map - ) + response = await self._generate_content_stream( + model=model, + contents=loop_contents, + config=final_parsed_config_to_call, ) - if not function_map: - break - if not func_response_parts: - break + if i > 1: + logger.info(f'AFC remote call {i} is done.') + remaining_remote_calls_afc -= 1 + if i > 1 and remaining_remote_calls_afc == 0: + logger.info( + 'Reached max remote calls for automatic function calling.' + ) - if chunk is None: - continue - # Append function response parts to contents for the next request. - func_call_content = chunk.candidates[0].content - func_response_content = types.Content( - role='user', - parts=func_response_parts, - ) - contents = t.t_contents(contents) - if not automatic_function_calling_history: - automatic_function_calling_history.extend(contents) - if isinstance(contents, list) and func_call_content is not None: - contents.append(func_call_content) - contents.append(func_response_content) - if func_call_content is not None: - automatic_function_calling_history.append(func_call_content) - automatic_function_calling_history.append(func_response_content) + if i == 1: + async for chunk in response: # type: ignore[attr-defined] + if not function_map: + loop_contents = _extra_utils.append_chunk_contents( + loop_contents, chunk + ) + yield chunk + else: + if ( + not chunk.candidates + or not chunk.candidates[0].content + or not chunk.candidates[0].content.parts + ): + break + func_response_parts = ( + await _extra_utils.get_function_response_parts_async( + chunk, function_map + ) + ) + if not func_response_parts: + loop_contents = _extra_utils.append_chunk_contents( + loop_contents, chunk + ) + yield chunk + else: + async for chunk in response: # type: ignore[attr-defined] + if _extra_utils.should_append_afc_history(final_parsed_config): + chunk.automatic_function_calling_history = ( + automatic_function_calling_history + ) + loop_contents = _extra_utils.append_chunk_contents( + loop_contents, chunk + ) + yield chunk + if ( + chunk is None + or not chunk.candidates + or not chunk.candidates[0].content + or not chunk.candidates[0].content.parts + ): + break + func_response_parts = ( + await _extra_utils.get_function_response_parts_async( + chunk, function_map + ) + ) - return async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return] + if not function_map or not func_response_parts: + break + + if chunk is None: + continue + + # Append function response parts to contents for the next request. + func_call_content = chunk.candidates[0].content + func_response_content = types.Content( + role='user', + parts=func_response_parts, + ) + loop_contents = t.t_contents(loop_contents) # type: ignore[assignment] + if not automatic_function_calling_history: + automatic_function_calling_history.extend(loop_contents) # type: ignore[arg-type] + if isinstance(loop_contents, list) and func_call_content is not None: + loop_contents.append(func_call_content) # type: ignore[arg-type] + loop_contents.append(func_response_content) # type: ignore[arg-type] + if func_call_content is not None: + automatic_function_calling_history.append(func_call_content) + automatic_function_calling_history.append(func_response_content) + + return stream_generator() # type: ignore[no-untyped-call, no-any-return] async def edit_image( self, diff --git a/google/genai/tests/models/test_generate_content_tools.py b/google/genai/tests/models/test_generate_content_tools.py index 4c48c6c73..56f65ebbf 100644 --- a/google/genai/tests/models/test_generate_content_tools.py +++ b/google/genai/tests/models/test_generate_content_tools.py @@ -40,6 +40,8 @@ mcp_types = None ClientSession = None +from ...models import AsyncModels + GOOGLE_HOMEPAGE_FILE_PATH = os.path.abspath( os.path.join(os.path.dirname(__file__), '../data/google_homepage.png') ) @@ -2199,35 +2201,6 @@ async def mock_connect(*args, **kwargs): assert mock_session.call_tool.called -@pytest.mark.asyncio -async def test_client_side_mcp_stream_async_raises(client): - """Test that streaming with Agent Platform MCP raises an error.""" - - if not client._api_client.vertexai: - pytest.skip('Vertex MCP test is not applicable to MLDev.') - - with pytest.raises( - NotImplementedError, - match=( - 'MCP servers are not yet supported for streaming in the Agent' - ' Platform API.' - ) - ): - response = await client.aio.models.generate_content_stream( - model='gemini-2.5-flash', - contents='List my endpoints.', - config={ - 'tools': [ - types.Tool( - mcp_servers=[types.McpServer(name='endpoints')] - ) - ] - } - ) - async for _ in response: - pass - - @pytest.mark.asyncio async def test_client_side_mcp_missing_name_raises(client): """Test that an MCP server without a name raises an error.""" @@ -2250,3 +2223,103 @@ async def test_client_side_mcp_missing_name_raises(client): ] } ) + + +@pytest.mark.asyncio +async def test_agent_platform_mcp_stream_async_unit(client): + """Unit tests the Agent Platform MCP integration for streaming without the replay framework.""" + if not client._api_client.vertexai: + return + + if ClientSession is None: + pytest.skip('MCP library is not installed.') + + class MockAgentPlatformSession(ClientSession): + def __init__(self): + self._read_stream = None + self._write_stream = None + + async def list_tools(self): + return mcp_types.ListToolsResult( + tools=[ + mcp_types.Tool( + name='list_endpoints', + description='Lists all serving Endpoints', + inputSchema={ + 'type': 'object', + 'properties': {'parent': {'type': 'string'}}, + }, + ) + ] + ) + + async def call_tool(self, name: str, arguments: dict[str, typing.Any]): + if name == 'list_endpoints': + return mcp_types.CallToolResult( + content=[mcp_types.TextContent(type='text', text='["endpoint-1", "endpoint-2"]')] + ) + + @contextlib.asynccontextmanager + async def mock_mcp_context(*args, **kwargs): + yield MockAgentPlatformSession() + + turn_1_chunk = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='list_endpoints', + args={'parent': 'projects/vertex-sdk-dev/locations/us-central1'} + ) + ) + ] + ) + ) + ] + ) + + turn_2_chunk = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role='model', + parts=[types.Part(text='You have 2 endpoints.')] + ) + ) + ] + ) + + async def mock_stream_1(*args, **kwargs): + yield turn_1_chunk + + async def mock_stream_2(*args, **kwargs): + yield turn_2_chunk + + with mock.patch.object(_mcp_utils, '_connect_agent_platform_mcp', side_effect=mock_mcp_context) as mock_connect_mcp: + with mock.patch.object(AsyncModels, '_generate_content_stream', side_effect=[mock_stream_1(), mock_stream_2()]) as mock_generate_stream: + + response_stream = await client.aio.models.generate_content_stream( + model='gemini-2.5-flash', + contents='List my endpoints.', + config=types.GenerateContentConfig( + tools=[ + types.Tool( + mcp_servers=[ + types.McpServer(name='endpoints') + ] + ) + ] + ) + ) + + final_text = '' + async for chunk in response_stream: + if chunk.text: + final_text += chunk.text + + assert '2 endpoints' in final_text + mock_connect_mcp.assert_called_once_with(client._api_client, 'endpoints') + assert mock_generate_stream.call_count == 2