Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
347 changes: 204 additions & 143 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8781,171 +8781,232 @@ async def generate_content_stream(
# The image shows a flat lay arrangement of freshly baked blueberry
# scones.
"""
if getattr(
self._api_client, 'vertexai', False
) and _extra_utils.has_agent_platform_mcp_servers(config):
raise NotImplementedError(
'MCP servers are not yet supported for streaming in the Agent'
' Platform API.'
)
if not config:
parsed_config = None
elif isinstance(config, dict):
parsed_config = types.GenerateContentConfig(**config)
else:
parsed_config = config.model_copy(deep=True)

# Retrieve and cache any MCP sessions if provided.
incompatible_tools_indexes = (
_extra_utils.find_afc_incompatible_tool_indexes(
config,
parsed_config,
is_agent_platform=getattr(self._api_client, 'vertexai', False),
)
)
# Retrieve and cache any MCP sessions if provided.
parsed_config, mcp_to_genai_tool_adapters = (
await _extra_utils.parse_config_for_mcp_sessions(
config,
is_agent_platform=getattr(self._api_client, 'vertexai', False),
)
)
if _extra_utils.should_disable_afc(parsed_config):
response = await self._generate_content_stream(
model=model, contents=contents, config=parsed_config
)

# Pass-through async generator: re-yields every chunk of `response` unchanged.
# `response` is not a parameter; it is read from the enclosing scope, where it
# was assigned from `await self._generate_content_stream(...)` just above.
# The `model`, `contents` and `config` parameters are not referenced in the body.
async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def]
async for chunk in response: # type: ignore[attr-defined]
yield chunk

return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
async def stream_generator(): # type: ignore[no-untyped-def]
# Use AsyncExitStack to keep MCP connections alive across the entire stream
async with contextlib.AsyncExitStack() as stack:
current_config = parsed_config

if incompatible_tools_indexes:
original_tools_length = 0
if isinstance(config, types.GenerateContentConfig):
if config.tools:
original_tools_length = len(config.tools)
elif isinstance(config, dict):
tools = config.get('tools', [])
if tools:
original_tools_length = len(tools)
if len(incompatible_tools_indexes) != original_tools_length:
indices_str = ', '.join(map(str, incompatible_tools_indexes))
logger.warning(
'Tools at indices [%s] are not compatible with automatic function '
'calling (AFC). AFC is disabled. If AFC is intended, please '
'include python callables in the tool list, and do not include '
'function declaration and MCP server in the tool list.',
indices_str,
if (
getattr(self._api_client, 'vertexai', False)
and _extra_utils.has_agent_platform_mcp_servers(current_config)
and current_config is not None
):
new_tools: list[Any] = []
if current_config.tools:
for tool in current_config.tools:
if isinstance(tool, types.Tool) and tool.mcp_servers:
if (
tool.function_declarations
or tool.google_search
or tool.retrieval
or tool.google_search_retrieval
or tool.code_execution
):
tool_copy = tool.model_copy(update={'mcp_servers': None})
new_tools.append(tool_copy)

for server in tool.mcp_servers:
if (
getattr(server, 'streamable_http_transport', None)
is not None
):
raise ValueError(
"The 'streamable_http_transport' parameter is only"
' supported in Gemini Developer API mode, not in Gemini'
' Enterprise Agent Platform mode.'
)

if server.name is not None:
session = await stack.enter_async_context(
_mcp_utils._connect_agent_platform_mcp(
self._api_client, server.name
)
)
new_tools.append(session)
else:
raise ValueError(
"Agent Platform MCP servers require a 'name' field."
)
else:
new_tools.append(tool)
current_config.tools = new_tools

final_parsed_config, mcp_to_genai_tool_adapters = (
await _extra_utils.parse_config_for_mcp_sessions(
current_config,
is_agent_platform=getattr(self._api_client, 'vertexai', False),
)
)
response = await self._generate_content_stream(
model=model, contents=contents, config=parsed_config
)

# Pass-through async generator: re-yields every chunk of `response` unchanged.
# `response` is not a parameter; it is read from the enclosing scope, where it
# was assigned from `await self._generate_content_stream(...)` just above.
# The `model`, `contents` and `config` parameters are not referenced in the body.
async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def]
async for chunk in response: # type: ignore[attr-defined]
yield chunk

return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
if _extra_utils.should_disable_afc(final_parsed_config):
response = await self._generate_content_stream(
model=model, contents=contents, config=final_parsed_config
)
async for chunk in response: # type: ignore[attr-defined]
yield chunk
return

if incompatible_tools_indexes:
original_tools_length = 0
if isinstance(config, types.GenerateContentConfig):
if config.tools:
original_tools_length = len(config.tools)
elif isinstance(config, dict):
tools = config.get('tools', [])
if tools:
original_tools_length = len(tools)
if len(incompatible_tools_indexes) != original_tools_length:
indices_str = ', '.join(map(str, incompatible_tools_indexes))
logger.warning(
'Tools at indices [%s] are not compatible with automatic'
' function calling (AFC). AFC is disabled. If AFC is intended,'
' please include python callables in the tool list, and do not'
' include function declaration and MCP server in the tool'
' list.',
indices_str,
)
response = await self._generate_content_stream(
model=model, contents=contents, config=final_parsed_config
)
async for chunk in response: # type: ignore[attr-defined]
yield chunk
return

# With tool compatibility confirmed, validate that the configuration settings
# are compatible with each other and raise an error if they are not.
_extra_utils.raise_error_for_afc_incompatible_config(parsed_config)
_extra_utils.raise_error_for_afc_incompatible_config(
final_parsed_config
)

async def async_generator(model, contents, config): # type: ignore[no-untyped-def]
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
logger.info(
f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.'
)
automatic_function_calling_history: list[types.Content] = []
func_response_parts = None
chunk = None
i = 0
while remaining_remote_calls_afc > 0:
function_map = _extra_utils.get_function_map(
config, mcp_to_genai_tool_adapters, is_caller_method_async=True
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
final_parsed_config
)
config_to_call = config.model_copy(deep=True) if config else None
if function_map:
config_to_call = _extra_utils.get_usage_header(config_to_call)
i += 1
response = await self._generate_content_stream(
model=model, contents=contents, config=config_to_call
logger.info(
'AFC is enabled with max remote calls:'
f' {remaining_remote_calls_afc}.'
)
# TODO: b/453739108 - make AFC logic more robust like the other 3 methods.
if i > 1:
logger.info(f'AFC remote call {i} is done.')
remaining_remote_calls_afc -= 1
if i > 1 and remaining_remote_calls_afc == 0:
logger.info(
'Reached max remote calls for automatic function calling.'
automatic_function_calling_history: list[types.Content] = []
func_response_parts = None
chunk = None
i = 0
loop_contents = contents

while remaining_remote_calls_afc > 0:

function_map = _extra_utils.get_function_map(
final_parsed_config,
mcp_to_genai_tool_adapters,
is_caller_method_async=True,
)

if i == 1:
# First request gets a function call.
# Then get function response parts.
# Yield chunks only if there's no function response parts.
async for chunk in response: # type: ignore[attr-defined]
if not function_map:
contents = _extra_utils.append_chunk_contents(contents, chunk)
yield chunk
else:
if (
not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
)
if not func_response_parts:
contents = _extra_utils.append_chunk_contents(contents, chunk)
yield chunk
final_parsed_config_to_call = (
final_parsed_config.model_copy(deep=True)
if final_parsed_config
else None
)
if function_map:
final_parsed_config_to_call = _extra_utils.get_usage_header(
final_parsed_config_to_call
)

else:
# Second request and beyond, yield chunks.
async for chunk in response: # type: ignore[attr-defined]
i += 1

if _extra_utils.should_append_afc_history(config):
chunk.automatic_function_calling_history = (
automatic_function_calling_history
)
contents = _extra_utils.append_chunk_contents(contents, chunk)
yield chunk
if (
chunk is None
or not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
response = await self._generate_content_stream(
model=model,
contents=loop_contents,
config=final_parsed_config_to_call,
)
if not function_map:
break

if not func_response_parts:
break
if i > 1:
logger.info(f'AFC remote call {i} is done.')
remaining_remote_calls_afc -= 1
if i > 1 and remaining_remote_calls_afc == 0:
logger.info(
'Reached max remote calls for automatic function calling.'
)

if chunk is None:
continue
# Append function response parts to contents for the next request.
func_call_content = chunk.candidates[0].content
func_response_content = types.Content(
role='user',
parts=func_response_parts,
)
contents = t.t_contents(contents)
if not automatic_function_calling_history:
automatic_function_calling_history.extend(contents)
if isinstance(contents, list) and func_call_content is not None:
contents.append(func_call_content)
contents.append(func_response_content)
if func_call_content is not None:
automatic_function_calling_history.append(func_call_content)
automatic_function_calling_history.append(func_response_content)
if i == 1:
async for chunk in response: # type: ignore[attr-defined]
if not function_map:
loop_contents = _extra_utils.append_chunk_contents(
loop_contents, chunk
)
yield chunk
else:
if (
not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
)
if not func_response_parts:
loop_contents = _extra_utils.append_chunk_contents(
loop_contents, chunk
)
yield chunk
else:
async for chunk in response: # type: ignore[attr-defined]
if _extra_utils.should_append_afc_history(final_parsed_config):
chunk.automatic_function_calling_history = (
automatic_function_calling_history
)
loop_contents = _extra_utils.append_chunk_contents(
loop_contents, chunk
)
yield chunk
if (
chunk is None
or not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
)

return async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
if not function_map or not func_response_parts:
break

if chunk is None:
continue

# Append function response parts to contents for the next request.
func_call_content = chunk.candidates[0].content
func_response_content = types.Content(
role='user',
parts=func_response_parts,
)
loop_contents = t.t_contents(loop_contents) # type: ignore[assignment]
if not automatic_function_calling_history:
automatic_function_calling_history.extend(loop_contents) # type: ignore[arg-type]
if isinstance(loop_contents, list) and func_call_content is not None:
loop_contents.append(func_call_content) # type: ignore[arg-type]
loop_contents.append(func_response_content) # type: ignore[arg-type]
if func_call_content is not None:
automatic_function_calling_history.append(func_call_content)
automatic_function_calling_history.append(func_response_content)

return stream_generator() # type: ignore[no-untyped-call, no-any-return]

async def edit_image(
self,
Expand Down
Loading
Loading