Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib
import utils


def run_tau_bench_rollouts(
Expand Down
1 change: 1 addition & 0 deletions contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import experiment
import gepa_utils
from google.genai import types
import utils

_OUTPUT_DIR = flags.DEFINE_string(
'output_dir',
Expand Down
11 changes: 11 additions & 0 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..events.event import Event
from ..events.event_actions import EventActions
from ..events.event_actions import EventCompaction
from ..models.cache_metadata import CacheMetadata
from ..utils.vertex_ai_utils import get_express_mode_api_key
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
Expand Down Expand Up @@ -311,6 +312,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
else None
),
}
if event.cache_metadata:
metadata_dict['cache_metadata'] = event.cache_metadata.model_dump(
exclude_none=True, mode='json'
)
if event.grounding_metadata:
metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump(
exclude_none=True, mode='json'
Expand Down Expand Up @@ -481,6 +486,10 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
getattr(event_metadata, 'grounding_metadata', None),
types.GroundingMetadata,
)
cache_metadata = _session_util.decode_model(
getattr(event_metadata, 'cache_metadata', None),
CacheMetadata,
)
else:
long_running_tool_ids = None
partial = None
Expand All @@ -491,6 +500,7 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
compaction_data = None
usage_metadata_data = None
grounding_metadata = None
cache_metadata = None

if actions:
actions_dict = actions.model_dump(exclude_none=True, mode='python')
Expand Down Expand Up @@ -539,6 +549,7 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
branch=branch,
custom_metadata=custom_metadata,
grounding_metadata=grounding_metadata,
cache_metadata=cache_metadata,
long_running_tool_ids=long_running_tool_ids,
usage_metadata=usage_metadata,
)
51 changes: 51 additions & 0 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ def _convert_to_object(data):
'requested_auth_configs',
'rawEvent',
'raw_event',
'cache_metadata',
'usage_metadata',
]:
kwargs[key] = value
else:
Expand Down Expand Up @@ -1306,3 +1308,52 @@ class DummyModel(pydantic.BaseModel):

assert appended_event.actions.compaction is not None
assert appended_event.actions.compaction.start_timestamp == 1000.0


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_with_cache_and_usage_metadata():
  """Verifies cache_metadata and usage_metadata survive an append/get cycle.

  Appends an event that carries both metadata payloads, re-fetches the
  session, and checks the last stored event still holds identical values.
  """
  service = mock_vertex_ai_session_service()
  original_session = await service.get_session(
      app_name='123', user_id='user', session_id='1'
  )
  assert original_session is not None

  # Representative, non-default values so equality checks are meaningful.
  expected_cache = CacheMetadata(
      cache_name='projects/123/locations/us-central1/cachedContents/456',
      expire_time=9999999999.0,
      fingerprint='abc123hash',
      invocations_used=3,
      contents_count=10,
      created_at=1700000000.0,
  )
  expected_usage = genai_types.GenerateContentResponseUsageMetadata(
      prompt_token_count=100,
      candidates_token_count=50,
      total_token_count=150,
      cached_content_token_count=80,
  )

  await service.append_event(
      original_session,
      Event(
          invocation_id='cache_test_invocation',
          author='model',
          timestamp=1734005536.0,
          content=genai_types.Content(
              parts=[genai_types.Part(text='cached response')]
          ),
          cache_metadata=expected_cache,
          usage_metadata=expected_usage,
      ),
  )

  refreshed = await service.get_session(
      app_name='123', user_id='user', session_id='1'
  )
  assert refreshed is not None

  stored_event = refreshed.events[-1]
  # Both metadata payloads round-trip unchanged through the session store.
  assert stored_event.cache_metadata == expected_cache
  assert stored_event.usage_metadata == expected_usage