Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
4 changes: 4 additions & 0 deletions src/google/adk/sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@
'DatabaseSessionService',
'InMemorySessionService',
'Session',
'SessionDataTransformer',
'State',
'VertexAiSessionService',
]


def __getattr__(name: str):
if name == 'SessionDataTransformer':
from .session_data_transformer import SessionDataTransformer
return SessionDataTransformer
if name == 'DatabaseSessionService':
try:
from .database_session_service import DatabaseSessionService
Expand Down
40 changes: 36 additions & 4 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from .schemas.v1 import StorageSession as StorageSessionV1
from .schemas.v1 import StorageUserState as StorageUserStateV1
from .session import Session
from .session_data_transformer import SessionDataTransformer
from .state import State

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -188,7 +189,13 @@ def __init__(self, version: str):
class DatabaseSessionService(BaseSessionService):
"""A session service that uses a database for storage."""

def __init__(self, db_url: str, **kwargs: Any):
def __init__(
self,
db_url: str,
*,
transformer: Optional[SessionDataTransformer] = None,
**kwargs: Any,
):
"""Initializes the database session service with a database URL."""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
Expand Down Expand Up @@ -248,6 +255,7 @@ def __init__(self, db_url: str, **kwargs: Any):
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
self._session_locks_guard = asyncio.Lock()
self.transformer = transformer

def _get_schema_classes(self) -> _SchemaClasses:
return _SchemaClasses(self._db_schema_version)
Expand Down Expand Up @@ -446,7 +454,12 @@ async def create_session(
)

# Extract state deltas
state_deltas = _session_util.extract_state_delta(state)
transformed_state = (
self.transformer.before_persist_state(state)
if self.transformer and state is not None
else state
)
state_deltas = _session_util.extract_state_delta(transformed_state)
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state = state_deltas["session"]
Expand Down Expand Up @@ -479,6 +492,8 @@ async def create_session(
merged_state = _merge_state(
storage_app_state.state, storage_user_state.state, session_state
)
if self.transformer:
merged_state = self.transformer.after_load_state(merged_state)
session = storage_session.to_session(
state=merged_state, is_sqlite=is_sqlite
)
Expand Down Expand Up @@ -540,9 +555,16 @@ async def get_session(

# Merge states
merged_state = _merge_state(app_state, user_state, session_state)
if self.transformer:
merged_state = self.transformer.after_load_state(merged_state)

# Convert storage session to session
events = [e.to_event() for e in reversed(storage_events)]
events = []
for e in reversed(storage_events):
evt = e.to_event()
if self.transformer:
evt = self.transformer.after_load_event(evt)
events.append(evt)
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
session = storage_session.to_session(
state=merged_state, events=events, is_sqlite=is_sqlite
Expand Down Expand Up @@ -596,6 +618,8 @@ async def list_sessions(
session_state = storage_session.state
user_state = user_states_map.get(storage_session.user_id, {})
merged_state = _merge_state(app_state, user_state, session_state)
if self.transformer:
merged_state = self.transformer.after_load_state(merged_state)
sessions.append(
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
)
Expand Down Expand Up @@ -640,6 +664,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
if event.actions and event.actions.state_delta
else {}
)
if self.transformer:
state_delta = self.transformer.before_persist_state(state_delta)
state_deltas = _session_util.extract_state_delta(state_delta)
has_app_delta = bool(state_deltas["app"])
has_user_delta = bool(state_deltas["user"])
Expand Down Expand Up @@ -735,7 +761,13 @@ async def append_event(self, session: Session, event: Event) -> Event:
else:
update_time = datetime.fromtimestamp(event.timestamp)
storage_session.update_time = update_time
sql_session.add(schema.StorageEvent.from_event(session, event))

transformed_event = (
self.transformer.before_persist_event(event)
if self.transformer
else event
)
sql_session.add(schema.StorageEvent.from_event(session, transformed_event))

await sql_session.commit()

Expand Down
43 changes: 43 additions & 0 deletions src/google/adk/sessions/session_data_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Mapping
from typing import Protocol

from google.adk.events.event import Event


class SessionDataTransformer(Protocol):
"""Hook protocol for selectively transforming DB session records before persist/load.
This is useful for implementing field-level encryption, PII masking, or secret
scrubbing at the storage boundary without modifying the in-memory core structures,
as long as the transformation yields valid storage dictionaries and Events.
"""

def before_persist_event(self, event: Event) -> Event:
"""Invoked just before serializing and persisting an Event to the database."""
...

def after_load_event(self, event: Event) -> Event:
"""Invoked immediately after loading and deserializing an Event from the database."""
...

def before_persist_state(self, state: Mapping[str, Any]) -> dict[str, Any]:
"""Invoked before persisting state changes (can be full state or partial deltas)."""
...

def after_load_state(self, state: Mapping[str, Any]) -> dict[str, Any]:
"""Invoked after loading a combined application/user/session state dict."""
...
86 changes: 86 additions & 0 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,3 +1626,89 @@ async def tracking_fn(**kwargs):
finally:
database_session_service._select_required_state = original_fn
await service.close()

import json


class MockPIIMaskerTransformer:
def before_persist_state(self, state):
return {k: f"{v}_masked" if isinstance(v, str) else v for k, v in state.items()}

def after_load_state(self, state):
return {k: v.replace("_masked", "") if isinstance(v, str) and v.endswith("_masked") else v for k, v in state.items()}

def before_persist_event(self, event: Event) -> Event:
new_event = event.model_copy() if hasattr(event, "model_copy") else event.copy()
if new_event.invocation_id:
new_event.invocation_id += "_masked"
return new_event

def after_load_event(self, event: Event) -> Event:
new_event = event.model_copy() if hasattr(event, "model_copy") else event.copy()
if new_event.invocation_id and new_event.invocation_id.endswith("_masked"):
new_event.invocation_id = new_event.invocation_id.replace("_masked", "")
return new_event

@pytest.mark.asyncio
async def test_session_data_transformer():
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:', transformer=MockPIIMaskerTransformer())
try:
session = await service.create_session(
app_name='app', user_id='user', session_id='s1', state={'app:secret': 'foo', 'user:pii': 'bar'}
)
assert session.state == {'app:secret': 'foo', 'user:pii': 'bar'}

# Verify persistence has been masked
async with service.db_engine.connect() as conn:
from sqlalchemy import text
result = await conn.execute(text("SELECT state FROM app_states WHERE app_name = 'app'"))
app_state_json = result.scalar()
assert "foo_masked" in json.dumps(app_state_json)

event = Event(invocation_id='inv1', author='user', actions=EventActions(state_delta={'sk1': 'pass'}))
returned_event = await service.append_event(session, event)

assert returned_event.invocation_id == 'inv1'
assert session.state.get('sk1') == 'pass'

# Check event persistence
async with service.db_engine.connect() as conn:
result = await conn.execute(text("SELECT id, state FROM sessions WHERE id = 's1'"))
row = result.fetchone()
assert "pass_masked" in json.dumps(row[1])

result_evt = await conn.execute(text("SELECT event_data FROM events WHERE session_id = 's1' LIMIT 1"))
evt_payload = result_evt.scalar()
assert "inv1_masked" in json.dumps(evt_payload)

# Check retrieval unmasks
loaded_session = await service.get_session(app_name='app', user_id='user', session_id='s1')
assert loaded_session.state == {'app:secret': 'foo', 'user:pii': 'bar', 'sk1': 'pass'}
assert len(loaded_session.events) == 1
assert loaded_session.events[0].invocation_id == 'inv1'
finally:
await service.close()

class ErrorMaskerTransformer:
def before_persist_state(self, state):
raise ValueError("Transformer exception test")

def after_load_state(self, state):
return state

def before_persist_event(self, event: Event) -> Event:
return event

def after_load_event(self, event: Event) -> Event:
return event

@pytest.mark.asyncio
async def test_session_data_transformer_handles_exception():
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:', transformer=ErrorMaskerTransformer())
try:
with pytest.raises(ValueError, match="Transformer exception test"):
await service.create_session(
app_name='app', user_id='user', session_id='s1', state={'app:secret': 'foo'}
)
finally:
await service.close()