From f515a9f8e07cf882de21d18845709d8cfff04612 Mon Sep 17 00:00:00 2001 From: SivaPusthak Date: Sat, 11 Apr 2026 12:25:32 +0530 Subject: [PATCH] feat(sessions): introduce pluggable SessionDataTransformer hooks for masking in DatabaseSessionService --- contributing/samples/gepa/experiment.py | 1 - contributing/samples/gepa/run_experiment.py | 1 - src/google/adk/sessions/__init__.py | 4 + .../adk/sessions/database_session_service.py | 40 ++++++++- .../adk/sessions/session_data_transformer.py | 43 ++++++++++ .../sessions/test_session_service.py | 86 +++++++++++++++++++ 6 files changed, 169 insertions(+), 6 deletions(-) create mode 100644 src/google/adk/sessions/session_data_transformer.py diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index 7505eda346..0e2a85f5cf 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -22,12 +22,16 @@ 'DatabaseSessionService', 'InMemorySessionService', 'Session', + 'SessionDataTransformer', 'State', 'VertexAiSessionService', ] def __getattr__(name: str): + if name == 'SessionDataTransformer': + from .session_data_transformer import SessionDataTransformer + return SessionDataTransformer if name == 'DatabaseSessionService': try: from .database_session_service import DatabaseSessionService diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index d033f1f234..7dec7e4b58 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -60,6 +60,7 @@ from .schemas.v1 import StorageSession as StorageSessionV1 from .schemas.v1 import StorageUserState as StorageUserStateV1 from .session import Session +from .session_data_transformer import SessionDataTransformer from .state import State logger = logging.getLogger("google_adk." + __name__) @@ -188,7 +189,13 @@ def __init__(self, version: str): class DatabaseSessionService(BaseSessionService): """A session service that uses a database for storage.""" - def __init__(self, db_url: str, **kwargs: Any): + def __init__( + self, + db_url: str, + *, + transformer: Optional[SessionDataTransformer] = None, + **kwargs: Any, + ): """Initializes the database session service with a database URL.""" # 1. Create DB engine for db connection # 2. Create all tables based on schema @@ -248,6 +255,7 @@ def __init__(self, db_url: str, **kwargs: Any): self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {} self._session_lock_ref_count: dict[_SessionLockKey, int] = {} self._session_locks_guard = asyncio.Lock() + self.transformer = transformer def _get_schema_classes(self) -> _SchemaClasses: return _SchemaClasses(self._db_schema_version) @@ -446,7 +454,12 @@ async def create_session( ) # Extract state deltas - state_deltas = _session_util.extract_state_delta(state) + transformed_state = ( + self.transformer.before_persist_state(state) + if self.transformer and state is not None + else state + ) + state_deltas = _session_util.extract_state_delta(transformed_state) app_state_delta = state_deltas["app"] user_state_delta = state_deltas["user"] session_state = state_deltas["session"] @@ -479,6 +492,8 @@ async def create_session( merged_state = _merge_state( storage_app_state.state, storage_user_state.state, session_state ) + if self.transformer: + merged_state = self.transformer.after_load_state(merged_state) session = storage_session.to_session( state=merged_state, is_sqlite=is_sqlite ) @@ -540,9 +555,16 @@ async def get_session( # Merge states merged_state = _merge_state(app_state, user_state, session_state) + if self.transformer: + merged_state = self.transformer.after_load_state(merged_state) # Convert storage session to session - events = [e.to_event() for e in reversed(storage_events)] + events = [] + for e in reversed(storage_events): + evt = e.to_event() + if self.transformer: + evt = self.transformer.after_load_event(evt) + events.append(evt) is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT session = storage_session.to_session( state=merged_state, events=events, is_sqlite=is_sqlite @@ -596,6 +618,8 @@ async def list_sessions( session_state = storage_session.state user_state = user_states_map.get(storage_session.user_id, {}) merged_state = _merge_state(app_state, user_state, session_state) + if self.transformer: + merged_state = self.transformer.after_load_state(merged_state) sessions.append( storage_session.to_session(state=merged_state, is_sqlite=is_sqlite) ) @@ -640,6 +664,8 @@ async def append_event(self, session: Session, event: Event) -> Event: if event.actions and event.actions.state_delta else {} ) + if self.transformer: + state_delta = self.transformer.before_persist_state(state_delta) state_deltas = _session_util.extract_state_delta(state_delta) has_app_delta = bool(state_deltas["app"]) has_user_delta = bool(state_deltas["user"]) @@ -735,7 +761,13 @@ async def append_event(self, session: Session, event: Event) -> Event: else: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time - sql_session.add(schema.StorageEvent.from_event(session, event)) + + transformed_event = ( + self.transformer.before_persist_event(event) + if self.transformer + else event + ) + sql_session.add(schema.StorageEvent.from_event(session, transformed_event)) await sql_session.commit() diff --git a/src/google/adk/sessions/session_data_transformer.py b/src/google/adk/sessions/session_data_transformer.py new file mode 100644 index 0000000000..f0031818c0 --- /dev/null +++ b/src/google/adk/sessions/session_data_transformer.py @@ -0,0 +1,43 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any +from typing import Mapping +from typing import Protocol + +from google.adk.events.event import Event + + +class SessionDataTransformer(Protocol): + """Hook protocol for selectively transforming DB session records before persist/load. + + This is useful for implementing field-level encryption, PII masking, or secret + scrubbing at the storage boundary without modifying the in-memory core structures, + as long as the transformation yields valid storage dictionaries and Events. + """ + + def before_persist_event(self, event: Event) -> Event: + """Invoked just before serializing and persisting an Event to the database.""" + ... + + def after_load_event(self, event: Event) -> Event: + """Invoked immediately after loading and deserializing an Event from the database.""" + ... + + def before_persist_state(self, state: Mapping[str, Any]) -> dict[str, Any]: + """Invoked before persisting state changes (can be full state or partial deltas).""" + ... + + def after_load_state(self, state: Mapping[str, Any]) -> dict[str, Any]: + """Invoked after loading a combined application/user/session state dict.""" + ... diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 2d7d89f15f..aefd6025ec 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1626,3 +1626,89 @@ async def tracking_fn(**kwargs): finally: database_session_service._select_required_state = original_fn await service.close() + +import json + + +class MockPIIMaskerTransformer: + def before_persist_state(self, state): + return {k: f"{v}_masked" if isinstance(v, str) else v for k, v in state.items()} + + def after_load_state(self, state): + return {k: v.replace("_masked", "") if isinstance(v, str) and v.endswith("_masked") else v for k, v in state.items()} + + def before_persist_event(self, event: Event) -> Event: + new_event = event.model_copy() if hasattr(event, "model_copy") else event.copy() + if new_event.invocation_id: + new_event.invocation_id += "_masked" + return new_event + + def after_load_event(self, event: Event) -> Event: + new_event = event.model_copy() if hasattr(event, "model_copy") else event.copy() + if new_event.invocation_id and new_event.invocation_id.endswith("_masked"): + new_event.invocation_id = new_event.invocation_id.replace("_masked", "") + return new_event + +@pytest.mark.asyncio +async def test_session_data_transformer(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:', transformer=MockPIIMaskerTransformer()) + try: + session = await service.create_session( + app_name='app', user_id='user', session_id='s1', state={'app:secret': 'foo', 'user:pii': 'bar'} + ) + assert session.state == {'app:secret': 'foo', 'user:pii': 'bar'} + + # Verify persistence has been masked + async with service.db_engine.connect() as conn: + from sqlalchemy import text + result = await conn.execute(text("SELECT state FROM app_states WHERE app_name = 'app'")) + app_state_json = result.scalar() + assert "foo_masked" in json.dumps(app_state_json) + + event = Event(invocation_id='inv1', author='user', actions=EventActions(state_delta={'sk1': 'pass'})) + returned_event = await service.append_event(session, event) + + assert returned_event.invocation_id == 'inv1' + assert session.state.get('sk1') == 'pass' + + # Check event persistence + async with service.db_engine.connect() as conn: + result = await conn.execute(text("SELECT id, state FROM sessions WHERE id = 's1'")) + row = result.fetchone() + assert "pass_masked" in json.dumps(row[1]) + + result_evt = await conn.execute(text("SELECT event_data FROM events WHERE session_id = 's1' LIMIT 1")) + evt_payload = result_evt.scalar() + assert "inv1_masked" in json.dumps(evt_payload) + + # Check retrieval unmasks + loaded_session = await service.get_session(app_name='app', user_id='user', session_id='s1') + assert loaded_session.state == {'app:secret': 'foo', 'user:pii': 'bar', 'sk1': 'pass'} + assert len(loaded_session.events) == 1 + assert loaded_session.events[0].invocation_id == 'inv1' + finally: + await service.close() + +class ErrorMaskerTransformer: + def before_persist_state(self, state): + raise ValueError("Transformer exception test") + + def after_load_state(self, state): + return state + + def before_persist_event(self, event: Event) -> Event: + return event + + def after_load_event(self, event: Event) -> Event: + return event + +@pytest.mark.asyncio +async def test_session_data_transformer_handles_exception(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:', transformer=ErrorMaskerTransformer()) + try: + with pytest.raises(ValueError, match="Transformer exception test"): + await service.create_session( + app_name='app', user_id='user', session_id='s1', state={'app:secret': 'foo'} + ) + finally: + await service.close()