diff --git a/pyproject.toml b/pyproject.toml index 3a2a587b5d..ba87a0c93a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ test = [ "a2a-sdk>=0.3.0,<0.4.0", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ + "google-cloud-firestore>=2.11.0", "google-cloud-parametermanager>=0.4.0, <1.0.0", "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", @@ -158,6 +159,7 @@ extensions = [ "beautifulsoup4>=3.2.2", # For load_web_page tool. "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+ "docker>=7.0.0", # For ContainerCodeExecutor + "google-cloud-firestore>=2.11.0", # For Firestore services "google-cloud-parametermanager>=0.4.0, <1.0.0", "kubernetes>=29.0.0", # For GkeCodeExecutor "k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode diff --git a/src/google/adk/errors/already_exists_error.py b/src/google/adk/errors/already_exists_error.py index 8bd14f9ad6..bf8d357a81 100644 --- a/src/google/adk/errors/already_exists_error.py +++ b/src/google/adk/errors/already_exists_error.py @@ -18,7 +18,7 @@ class AlreadyExistsError(Exception): """Represents an error that occurs when an entity already exists.""" - def __init__(self, message="The resource already exists."): + def __init__(self, message: str = "The resource already exists."): """Initializes the AlreadyExistsError exception. Args: diff --git a/src/google/adk/integrations/firestore/__init__.py b/src/google/adk/integrations/firestore/__init__.py new file mode 100644 index 0000000000..7c76d28c93 --- /dev/null +++ b/src/google/adk/integrations/firestore/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +"""Firestore integrations for ADK.""" diff --git a/src/google/adk/integrations/firestore/_stop_words.py b/src/google/adk/integrations/firestore/_stop_words.py new file mode 100644 index 0000000000..b72cc5b6cc --- /dev/null +++ b/src/google/adk/integrations/firestore/_stop_words.py @@ -0,0 +1,151 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +DEFAULT_STOP_WORDS = { + "a", + "about", + "above", + "after", + "again", + "against", + "all", + "am", + "an", + "and", + "any", + "are", + "as", + "at", + "be", + "because", + "been", + "before", + "being", + "below", + "between", + "both", + "but", + "by", + "can", + "could", + "did", + "do", + "does", + "doing", + "don", + "down", + "during", + "each", + "else", + "few", + "for", + "from", + "further", + "had", + "has", + "have", + "having", + "he", + "her", + "here", + "hers", + "herself", + "him", + "himself", + "his", + "how", + "i", + "if", + "in", + "into", + "is", + "it", + "its", + "itself", + "just", + "may", + "me", + "might", + "more", + "most", + "must", + "my", + "myself", + "no", + "nor", + "not", + "now", + "of", + "off", + "on", + "once", + "only", + "or", + "other", + "our", + "ours", + "ourselves", + "out", + "over", + "own", + "s", + "same", + "shall", + "she", + "should", + "so", + "some", + "such", + "t", + "than", + "that", + "the", + "their", + "theirs", + "them", + "themselves", + "then", + "there", + "these", + "they", + "this", + "those", + "through", + "to", + "too", + "under", + "until", + "up", + "very", + "was", + "we", + "were", + "what", + "when", + "where", + "which", + "who", + "whom", + "why", + "will", + "with", + "would", + "you", + "your", + "yours", + "yourself", + "yourselves", +} diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py new file mode 100644 index 0000000000..1d711c35cd --- /dev/null +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -0,0 +1,195 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +import os +import re +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING + +from google.cloud.firestore_v1.base_query import FieldFilter +from typing_extensions import override + +from ...events.event import Event +from ...memory import _utils +from ...memory.base_memory_service import BaseMemoryService +from ...memory.base_memory_service import SearchMemoryResponse +from ...memory.memory_entry import MemoryEntry +from ._stop_words import DEFAULT_STOP_WORDS + +if TYPE_CHECKING: + from google.cloud import firestore + + from ...sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_EVENTS_COLLECTION = "events" +DEFAULT_MEMORIES_COLLECTION = "memories" + + +class FirestoreMemoryService(BaseMemoryService): # type: ignore[misc] + """Memory service that uses Google Cloud Firestore as the backend. + + It uses the existing session data to create memories in a top-level memory collection. + """ + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + events_collection: Optional[str] = None, + stop_words: Optional[set[str]] = None, + memories_collection: Optional[str] = None, + ): + """Initializes the Firestore memory service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + events_collection: The name of the events collection or collection group. + Defaults to 'events'. + stop_words: A set of words to ignore when extracting keywords. Defaults to + a standard English stop words list. + memories_collection: The name of the memories collection. Defaults to + 'memories'. + """ + if client is None: + from google.cloud import firestore + + self.client = firestore.AsyncClient() + else: + self.client = client + self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION + self.memories_collection = ( + memories_collection or DEFAULT_MEMORIES_COLLECTION + ) + self.stop_words = ( + stop_words if stop_words is not None else DEFAULT_STOP_WORDS + ) + + @override + async def add_session_to_memory(self, session: Session) -> None: + """Extracts keywords from session events and stores them in the memories collection.""" + batch = self.client.batch() + count = 0 + + for event in session.events: + if not event.content or not event.content.parts: + continue + + text = " ".join([part.text for part in event.content.parts if part.text]) + if not text: + continue + + keywords = self._extract_keywords(text) + if not keywords: + continue + + doc_ref = self.client.collection(self.memories_collection).document() + batch.set( + doc_ref, + { + "appName": session.app_name, + "userId": session.user_id, + "keywords": list(keywords), + "author": event.author, + "content": event.content.model_dump( + exclude_none=True, mode="json" + ), + "timestamp": event.timestamp, + }, + ) + count += 1 + if count >= 500: + await batch.commit() + batch = self.client.batch() + count = 0 + + if count > 0: + await batch.commit() + + def _extract_keywords(self, text: str) -> set[str]: + """Extracts keywords from text, ignoring stop words.""" + words = re.findall(r"[A-Za-z]+", text.lower()) + return {word for word in words if word not in self.stop_words} + + async def _search_by_keyword( + self, app_name: str, user_id: str, keyword: str + ) -> list[MemoryEntry]: + """Searches for events matching a single keyword.""" + query = ( + self.client.collection(self.memories_collection) + .where(filter=FieldFilter("appName", "==", app_name)) + .where(filter=FieldFilter("userId", "==", user_id)) + .where(filter=FieldFilter("keywords", "array_contains", keyword)) + ) + + docs = await query.get() + entries = [] + for doc in docs: + data = doc.to_dict() + if data and "content" in data: + try: + from google.genai import types + + content = types.Content.model_validate(data["content"]) + entries.append( + MemoryEntry( + content=content, + author=data.get("author", ""), + timestamp=_utils.format_timestamp(data.get("timestamp", 0.0)), + ) + ) + except Exception as e: + logger.warning(f"Failed to parse memory entry: {e}") + + return entries + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + """Searches memory for events matching the query.""" + keywords = self._extract_keywords(query) + if not keywords: + return SearchMemoryResponse() + + tasks = [ + self._search_by_keyword(app_name, user_id, keyword) + for keyword in keywords + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + seen = set() + memories = [] + for result_list in results: + if isinstance(result_list, BaseException): + logger.warning(f"Memory keyword search partial failure: {result_list}") + continue + for entry in result_list: + content_text = "" + if entry.content and entry.content.parts: + content_text = " ".join( + [part.text for part in entry.content.parts if part.text] + ) + key = (entry.author, content_text, entry.timestamp) + if key not in seen: + seen.add(key) + memories.append(entry) + + return SearchMemoryResponse(memories=memories) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py new file mode 100644 index 0000000000..83b97c33c2 --- /dev/null +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -0,0 +1,586 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from datetime import datetime +from datetime import timezone +import logging +import os +from typing import Any +from typing import AsyncIterator +from typing import cast +from typing import Optional +from typing import TYPE_CHECKING + +_SessionLockKey = tuple[str, str, str] + +if TYPE_CHECKING: + from google.cloud import firestore + +from pydantic import BaseModel + +from ...events.event import Event +from ...sessions import _session_util +from ...sessions.base_session_service import BaseSessionService +from ...sessions.base_session_service import GetSessionConfig +from ...sessions.base_session_service import ListSessionsResponse +from ...sessions.session import Session +from ...sessions.state import State + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_ROOT_COLLECTION = "adk-session" +DEFAULT_SESSIONS_COLLECTION = "sessions" +DEFAULT_EVENTS_COLLECTION = "events" +DEFAULT_APP_STATE_COLLECTION = "app_states" +DEFAULT_USER_STATE_COLLECTION = "user_states" + + +class FirestoreSessionService(BaseSessionService): # type: ignore[misc] + """Session service that uses Google Cloud Firestore as the backend. + + Hierarchy for sessions: + adk-session + ↳ + ↳ users + ↳ + ↳ sessions + ↳ + ↳ events + ↳ + + Hierarchy for shared App/User state configurations: + app_states + ↳ + + user_states + ↳ + ↳ users + ↳ + """ + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + root_collection: Optional[str] = None, + ): + """Initializes the Firestore session service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + root_collection: The root collection name. Defaults to 'adk-session' or + the value of ADK_FIRESTORE_ROOT_COLLECTION env var. + """ + try: + from google.cloud import firestore + except ImportError as e: + raise ImportError( + "FirestoreSessionService requires google-cloud-firestore. " + "Install it with: pip install google-cloud-firestore" + ) from e + + self.client = client or firestore.AsyncClient() + self.root_collection = ( + root_collection + or os.environ.get("ADK_FIRESTORE_ROOT_COLLECTION") + or DEFAULT_ROOT_COLLECTION + ) + self.sessions_collection = DEFAULT_SESSIONS_COLLECTION + + # Per-session locks used to serialize append_event calls in this process. + self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {} + self._session_lock_ref_count: dict[_SessionLockKey, int] = {} + self._session_locks_guard = asyncio.Lock() + self.events_collection = DEFAULT_EVENTS_COLLECTION + self.app_state_collection = DEFAULT_APP_STATE_COLLECTION + self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + + @asynccontextmanager + async def _with_session_lock( + self, *, app_name: str, user_id: str, session_id: str + ) -> AsyncIterator[None]: + """Serializes event appends for the same session within this process.""" + lock_key = (app_name, user_id, session_id) + async with self._session_locks_guard: + lock = self._session_locks.get(lock_key, asyncio.Lock()) + self._session_locks[lock_key] = lock + self._session_lock_ref_count[lock_key] = ( + self._session_lock_ref_count.get(lock_key, 0) + 1 + ) + + try: + async with lock: + yield + finally: + async with self._session_locks_guard: + remaining = self._session_lock_ref_count.get(lock_key, 0) - 1 + if remaining <= 0 and not lock.locked(): + self._session_lock_ref_count.pop(lock_key, None) + self._session_locks.pop(lock_key, None) + else: + self._session_lock_ref_count[lock_key] = remaining + + @staticmethod + def _merge_state( + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], + ) -> dict[str, Any]: + """Merge app, user, and session states into a single state dictionary.""" + import copy + + merged_state = copy.deepcopy(session_state) + for key, value in app_state.items(): + merged_state[State.APP_PREFIX + key] = value + for key, value in user_state.items(): + merged_state[State.USER_PREFIX + key] = value + return merged_state + + def _get_sessions_ref( + self, app_name: str, user_id: str + ) -> firestore.AsyncCollectionReference: + return ( + self.client.collection(self.root_collection) + .document(app_name) + .collection("users") + .document(user_id) + .collection(self.sessions_collection) + ) + + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + """Creates a new session in Firestore.""" + from google.cloud import firestore + + if not session_id: + from ...platform import uuid as platform_uuid + + session_id = platform_uuid.new_uuid() + + initial_state = state or {} + now = firestore.SERVER_TIMESTAMP + + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + + # Extract state deltas + state_deltas = _session_util.extract_state_delta(initial_state) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state = state_deltas["session"] + + app_ref = self.client.collection(self.app_state_collection).document( + app_name + ) + user_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + + session_data = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": session_state, + "createTime": now, + "updateTime": now, + "revision": 1, + } + + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _create_txn(transaction: firestore.AsyncTransaction) -> None: + # 1. Reads + snap = await session_ref.get(transaction=transaction) + if snap.exists: + from ...errors.already_exists_error import AlreadyExistsError + + raise AlreadyExistsError(f"Session {session_id} already exists.") + + app_snap = ( + await app_ref.get(transaction=transaction) + if app_state_delta + else None + ) + user_snap = ( + await user_ref.get(transaction=transaction) + if user_state_delta + else None + ) + + # 2. Writes + if app_state_delta: + current_app = ( + app_snap.to_dict() if (app_snap and app_snap.exists) else {} + ) + current_app.update(app_state_delta) + transaction.set(app_ref, current_app, merge=True) + + if user_state_delta: + current_user = ( + user_snap.to_dict() if (user_snap and user_snap.exists) else {} + ) + current_user.update(user_state_delta) + transaction.set(user_ref, current_user, merge=True) + + transaction.set(session_ref, session_data) + + transaction_obj = self.client.transaction() + await _create_txn(transaction_obj) + + storage_app_doc = await app_ref.get() + storage_app_state = ( + storage_app_doc.to_dict() if storage_app_doc.exists else {} + ) + storage_user_doc = await user_ref.get() + storage_user_state = ( + storage_user_doc.to_dict() if storage_user_doc.exists else {} + ) + + merged_state = self._merge_state( + storage_app_state, storage_user_state, session_state + ) + + local_now = datetime.now(timezone.utc).timestamp() + + session = Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged_state, + events=[], + last_update_time=local_now, + ) + session._storage_update_marker = "1" + return session + + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + """Gets a session from Firestore.""" + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + doc = await session_ref.get() + + if not doc.exists: + return None + + data = doc.to_dict() + if not data: + return None + + # Fetch events + events_ref = session_ref.collection(self.events_collection) + query = events_ref.order_by("timestamp") + + if config: + if config.after_timestamp: + after_dt = datetime.fromtimestamp(config.after_timestamp) + query = query.where("timestamp", ">=", after_dt) + if config.num_recent_events: + query = query.limit_to_last(config.num_recent_events) + + events_docs = await query.get() + events = [] + for event_doc in events_docs: + event_data = event_doc.to_dict() + if event_data and "event_data" in event_data: + ed = event_data["event_data"] + events.append(Event.model_validate(ed)) + + # Let's continue getting session. + session_state = data.get("state", {}) + + # Fetch shared state + app_ref = self.client.collection(self.app_state_collection).document( + app_name + ) + user_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + app_doc = await app_ref.get() + app_state = app_doc.to_dict() if app_doc.exists else {} + user_doc = await user_ref.get() + user_state = user_doc.to_dict() if user_doc.exists else {} + + merged_state = self._merge_state(app_state, user_state, session_state) + + # Convert timestamp + update_time = data.get("updateTime") + last_update_time = 0.0 + if update_time: + if isinstance(update_time, datetime): + last_update_time = update_time.timestamp() + else: + try: + last_update_time = float(update_time) + except (ValueError, TypeError): + pass + + current_revision = data.get("revision", 0) + session = Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged_state, + events=events, + last_update_time=last_update_time, + ) + session._storage_update_marker = ( + str(current_revision) if current_revision > 0 else None + ) + return session + + async def list_sessions( + self, *, app_name: str, user_id: Optional[str] = None + ) -> ListSessionsResponse: + """Lists sessions from Firestore.""" + if user_id: + query = self._get_sessions_ref(app_name, user_id).where( + "appName", "==", app_name + ) + docs = await query.get() + else: + query = self.client.collection_group(self.sessions_collection).where( + "appName", "==", app_name + ) + docs = await query.get() + + # Fetch shared state once + app_ref = self.client.collection(self.app_state_collection).document( + app_name + ) + app_doc = await app_ref.get() + app_state = app_doc.to_dict() if app_doc.exists else {} + + user_states_map = {} + if user_id: + user_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + user_doc = await user_ref.get() + if user_doc.exists: + user_states_map[user_id] = user_doc.to_dict() + else: + users_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + ) + users_docs = await users_ref.get() + for u_doc in users_docs: + user_states_map[u_doc.id] = u_doc.to_dict() + + sessions = [] + for doc in docs: + data = doc.to_dict() + if data: + u_id = data["userId"] + s_state = data.get("state", {}) + u_state = user_states_map.get(u_id, {}) + merged = self._merge_state(app_state, u_state, s_state) + + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state=merged, + events=[], + last_update_time=0.0, + ) + ) + + return ListSessionsResponse(sessions=sessions) + + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + """Deletes a session and its events from Firestore.""" + from google.cloud import firestore + + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _mark_deleting_txn( + transaction: firestore.AsyncTransaction, + ) -> None: + snap = await session_ref.get(transaction=transaction) + if snap.exists: + transaction.update(session_ref, {"status": "DELETING"}) + + try: + transaction_obj = self.client.transaction() + await _mark_deleting_txn(transaction_obj) + except Exception: + pass + + events_ref = session_ref.collection(self.events_collection) + + batch = self.client.batch() + count = 0 + async for event_doc in events_ref.stream(): + batch.delete(event_doc.reference) + count += 1 + if count >= 500: + await batch.commit() + batch = self.client.batch() + count = 0 + if count > 0: + await batch.commit() + + await session_ref.delete() + + async def append_event(self, session: Session, event: Event) -> Event: + """Appends an event to a session in Firestore.""" + from google.cloud import firestore + + if event.partial: + return event + + self._apply_temp_state(session, event) + event = self._trim_temp_delta_state(event) + + session_ref = self._get_sessions_ref( + session.app_name, session.user_id + ).document(session.id) + + state_delta = ( + event.actions.state_delta + if event.actions and event.actions.state_delta + else {} + ) + state_deltas = _session_util.extract_state_delta(state_delta) + app_updates = state_deltas["app"] + user_updates = state_deltas["user"] + session_updates = state_deltas["session"] + + app_ref = self.client.collection(self.app_state_collection).document( + session.app_name + ) + user_ref = ( + self.client.collection(self.user_state_collection) + .document(session.app_name) + .collection("users") + .document(session.user_id) + ) + + async with self._with_session_lock( + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + ): + + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _append_txn(transaction: firestore.AsyncTransaction) -> int: + # 1. Reads + session_snap = await session_ref.get(transaction=transaction) + if not session_snap.exists: + raise ValueError(f"Session {session.id} not found.") + + session_doc = session_snap.to_dict() or {} + if session_doc.get("status") == "DELETING": + raise ValueError(f"Session {session.id} is currently being deleted.") + + current_revision = session_doc.get("revision", 0) + + if session._storage_update_marker is not None: + if session._storage_update_marker != str(current_revision): + raise ValueError( + "The session has been modified in storage since it was loaded. " + "Please reload the session before appending more events." + ) + + app_snap = ( + await app_ref.get(transaction=transaction) if app_updates else None + ) + user_snap = ( + await user_ref.get(transaction=transaction) + if user_updates + else None + ) + + # 2. Writes + if app_updates and app_snap is not None: + current_app = app_snap.to_dict() if app_snap.exists else {} + current_app.update(app_updates) + transaction.set(app_ref, current_app, merge=True) + + if user_updates and user_snap is not None: + current_user = user_snap.to_dict() if user_snap.exists else {} + current_user.update(user_updates) + transaction.set(user_ref, current_user, merge=True) + + for k, v in session_updates.items(): + session.state[k] = v + + new_revision = current_revision + 1 + session_only_state = { + k: v + for k, v in session.state.items() + if not k.startswith(State.APP_PREFIX) + and not k.startswith(State.USER_PREFIX) + } + transaction.update( + session_ref, + { + "state": session_only_state, + "updateTime": firestore.SERVER_TIMESTAMP, + "revision": new_revision, + }, + ) + + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) + event_data = event.model_dump(exclude_none=True, mode="json") + transaction.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) + + return cast(int, new_revision) + + transaction_obj = self.client.transaction() + new_revision_count = await _append_txn(transaction_obj) + session._storage_update_marker = str(new_revision_count) + + await super().append_event(session, event) + return event diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py new file mode 100644 index 0000000000..afa7f75cac --- /dev/null +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -0,0 +1,388 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.events.event import Event +from google.adk.integrations.firestore.firestore_memory_service import FirestoreMemoryService +from google.cloud.firestore_v1.base_query import FieldFilter +from google.genai import types +import pytest + + +@pytest.fixture +def mock_firestore_client(): + client = mock.MagicMock() + collection_ref = mock.MagicMock() + client.collection.return_value = collection_ref + + collection_ref.where.return_value = collection_ref + + doc_snapshot = mock.MagicMock() + doc_snapshot.to_dict.return_value = {} + + collection_ref.get = mock.AsyncMock(return_value=[doc_snapshot]) + + return client + + +def test_extract_keywords(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + text = "The quick brown fox jumps over the lazy dog." + keywords = service._extract_keywords(text) + + assert "the" not in keywords + assert "over" not in keywords + assert "quick" in keywords + assert "brown" in keywords + assert "fox" in keywords + assert "jumps" in keywords + assert "lazy" in keywords + assert "dog" in keywords + + +@pytest.mark.asyncio +async def test_search_memory_empty_query(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="" + ) + assert not response.memories + mock_firestore_client.collection.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_memory_with_results(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick fox" + + doc_snapshot = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + 0 + ] + + content = types.Content(parts=[types.Part.from_text(text="quick fox jumps")]) + + doc_snapshot.to_dict.return_value = { + "appName": app_name, + "userId": user_id, + "author": "user", + "content": content.model_dump(exclude_none=True, mode="json"), + "timestamp": 1234567890.0, + } + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert response.memories + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + + mock_firestore_client.collection.assert_called_with("memories") + collection_ref = mock_firestore_client.collection.return_value + + assert collection_ref.where.call_count == 6 + calls = collection_ref.where.call_args_list + + app_name_calls = 0 + user_id_calls = 0 + keyword_calls = 0 + + for call in calls: + kwargs = call.kwargs + filt = kwargs.get("filter") + if filt: + if ( + filt.field_path == "appName" + and filt.op_string == "==" + and filt.value == app_name + ): + app_name_calls += 1 + elif ( + filt.field_path == "userId" + and filt.op_string == "==" + and filt.value == user_id + ): + user_id_calls += 1 + elif filt.field_path == "keywords" and filt.op_string == "array_contains": + + if filt.value in ["quick", "fox"]: + keyword_calls += 1 + + assert app_name_calls == 2 + assert user_id_calls == 2 + assert keyword_calls == 2 + + +@pytest.mark.asyncio +async def test_search_memory_deduplication(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick fox" + + content = types.Content(parts=[types.Part.from_text(text="quick fox jumps")]) + + doc_snapshot1 = mock.MagicMock() + doc_snapshot1.to_dict.return_value = { + "appName": app_name, + "userId": user_id, + "author": "user", + "content": content.model_dump(exclude_none=True, mode="json"), + "timestamp": 1234567890.0, + } + + doc_snapshot2 = mock.MagicMock() + doc_snapshot2.to_dict.return_value = { + "appName": app_name, + "userId": user_id, + "author": "user", + "content": content.model_dump(exclude_none=True, mode="json"), + "timestamp": 1234567890.0, + } + + get_mock = mock.AsyncMock(side_effect=[[doc_snapshot1], [doc_snapshot2]]) + + mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value.get = ( + get_mock + ) + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert response.memories + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + + +@pytest.mark.asyncio +async def test_search_memory_parsing_error(mock_firestore_client, caplog): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick" + + doc_snapshot = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + 0 + ] + doc_snapshot.to_dict.return_value = {"content": "invalid_data"} + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert not response.memories + assert "Failed to parse memory entry" in caplog.text + + +@pytest.mark.asyncio +async def test_search_memory_only_stop_words(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="the and or" + ) + assert not response.memories + mock_firestore_client.collection.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_memory_partial_failures(mock_firestore_client, caplog): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "fox quick" + + coll_ref = ( + mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value + ) + + doc_snapshot = mock.MagicMock() + doc_snapshot.to_dict.return_value = { + "content": {"parts": [{"text": "quick response"}]}, + "author": "user", + "timestamp": 1234567890.0, + } + + call_count = 0 + + async def mock_get(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("Mock generic network failure standalone") + return [doc_snapshot] + + coll_ref.get = mock.AsyncMock(side_effect=mock_get) + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + assert "Memory keyword search partial failure" in caplog.text + + +def test_init_default_client(): + with mock.patch("google.cloud.firestore.AsyncClient") as mock_client_class: + mock_instance = mock.MagicMock() + mock_client_class.return_value = mock_instance + + service = FirestoreMemoryService() + + mock_client_class.assert_called_once() + assert service.client == mock_instance + + +@pytest.mark.asyncio +async def test_add_session_to_memory(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + content = types.Content(parts=[types.Part.from_text(text="quick brown fox")]) + event = Event( + invocation_id="test_inv", + author="user", + content=content, + timestamp=1234567890.0, + ) + session.events.append(event) + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + batch.commit = mock.AsyncMock() + + doc_ref = mock.MagicMock() + mock_firestore_client.collection.return_value.document.return_value = doc_ref + + await service.add_session_to_memory(session) + + mock_firestore_client.batch.assert_called_once() + mock_firestore_client.collection.assert_called_with("memories") + batch.set.assert_called_once() + batch.commit.assert_called_once() + + args, kwargs = batch.set.call_args + assert args[0] == doc_ref + data = args[1] + assert data["appName"] == "test_app" + assert data["userId"] == "test_user" + assert "quick" in data["keywords"] + assert data["author"] == "user" + assert data["timestamp"] == 1234567890.0 + + +@pytest.mark.asyncio +async def test_add_session_to_memory_no_events(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + + await service.add_session_to_memory(session) + + mock_firestore_client.batch.assert_called_once() + batch.set.assert_not_called() + batch.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_session_to_memory_no_keywords(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + content = types.Content(parts=[types.Part.from_text(text="the and or")]) + event = Event(invocation_id="test_inv", author="user", content=content) + session.events.append(event) + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + + await service.add_session_to_memory(session) + + mock_firestore_client.batch.assert_called_once() + batch.set.assert_not_called() + batch.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_session_to_memory_commit_error(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + content = types.Content(parts=[types.Part.from_text(text="quick brown fox")]) + event = Event(invocation_id="test_inv", author="user", content=content) + session.events.append(event) + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + batch.commit = mock.AsyncMock( + side_effect=Exception("Firestore commit failed") + ) + + with pytest.raises(Exception, match="Firestore commit failed"): + await service.add_session_to_memory(session) + + +@pytest.mark.asyncio +async def test_add_session_to_memory_exceeds_batch_limit(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + for i in range(501): + content = types.Content( + parts=[types.Part.from_text(text=f"event keyword {i}")] + ) + event = Event( + invocation_id=f"test_inv_{i}", + author="user", + content=content, + timestamp=1234567890.0 + i, + ) + session.events.append(event) + + batch1 = mock.MagicMock() + batch2 = mock.MagicMock() + batch1.commit = mock.AsyncMock() + batch2.commit = mock.AsyncMock() + mock_firestore_client.batch.side_effect = [batch1, batch2] + + await service.add_session_to_memory(session) + + assert mock_firestore_client.batch.call_count == 2 + assert batch1.set.call_count == 500 + batch1.commit.assert_called_once() + assert batch2.set.call_count == 1 + batch2.commit.assert_called_once() diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py new file mode 100644 index 0000000000..1445bfe0ef --- /dev/null +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -0,0 +1,757 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.events.event import Event +from google.adk.integrations.firestore.firestore_session_service import FirestoreSessionService +import pytest + + +@pytest.fixture +def mock_firestore_client(): + client = mock.MagicMock() + collection_ref = mock.MagicMock() + doc_ref = mock.MagicMock() + subcollection_ref = mock.MagicMock() + subdoc_ref = mock.MagicMock() + sessions_coll_ref = mock.MagicMock() + sessions_doc_ref = mock.MagicMock() + + client.collection.return_value = collection_ref + collection_ref.document.return_value = doc_ref + doc_ref.collection.return_value = subcollection_ref + subcollection_ref.document.return_value = subdoc_ref + subdoc_ref.collection.return_value = sessions_coll_ref + sessions_coll_ref.document.return_value = sessions_doc_ref + + doc_snapshot = mock.MagicMock() + doc_snapshot.exists = False + doc_snapshot.to_dict.return_value = {} + + subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + sessions_doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + + sessions_doc_ref.set = mock.AsyncMock() + sessions_doc_ref.delete = mock.AsyncMock() + + events_collection_ref = mock.MagicMock() + sessions_doc_ref.collection.return_value = events_collection_ref + events_collection_ref.order_by.return_value = events_collection_ref + events_collection_ref.where.return_value = events_collection_ref + events_collection_ref.limit_to_last.return_value = events_collection_ref + events_collection_ref.get = mock.AsyncMock(return_value=[]) + + sessions_coll_ref.get = mock.AsyncMock(return_value=[]) + sessions_coll_ref.where.return_value = sessions_coll_ref + + client.collection_group.return_value = collection_ref + + batch = mock.MagicMock() + client.batch.return_value = batch + batch.commit = mock.AsyncMock() + + return client + + +def test_init_missing_dependency(): + import builtins + + original_import = builtins.__import__ + + def mock_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "google.cloud" and "firestore" in fromlist: + raise ImportError("Mocked import error") + return original_import(name, globals, locals, fromlist, level) + + with mock.patch("builtins.__import__", side_effect=mock_import): + with pytest.raises(ImportError, match="requires google-cloud-firestore"): + FirestoreSessionService() + + +@pytest.mark.asyncio +async def test_create_session(mock_firestore_client): + + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + session = await service.create_session(app_name=app_name, user_id=user_id) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id + + mock_firestore_client.collection.assert_any_call("adk-session") + mock_firestore_client.collection.assert_any_call("app_states") + mock_firestore_client.collection.assert_any_call("user_states") + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + + from google.cloud import firestore + + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called_once() + args, kwargs = transaction.set.call_args + assert args[0] == session_doc_ref + assert args[1]["id"] == session.id + assert args[1]["appName"] == app_name + assert args[1]["userId"] == user_id + assert args[1]["state"] == {} + assert args[1]["createTime"] == firestore.SERVER_TIMESTAMP + assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP + + +@pytest.mark.asyncio +async def test_get_session_not_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is None + + mock_firestore_client.collection.assert_called_with("adk-session") + root_coll = mock_firestore_client.collection.return_value + root_coll.document.assert_called_with(app_name) + app_ref = root_coll.document.return_value + app_ref.collection.assert_called_with("users") + users_coll = app_ref.collection.return_value + users_coll.document.assert_called_with(user_id) + user_ref = users_coll.document.return_value + user_ref.collection.assert_called_with("sessions") + sessions_ref = user_ref.collection.return_value + sessions_ref.document.assert_called_with(session_id) + + +@pytest.mark.asyncio +async def test_get_session_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + sessions_doc_ref = sessions_ref.document.return_value + + session_snap = mock.MagicMock() + session_snap.exists = True + session_snap.to_dict.return_value = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": {"key": "value"}, + "updateTime": 1234567890.0, + } + sessions_doc_ref.get.return_value = session_snap + + # Decouple app and user documents so they do not duplicate values + app_state_coll = mock_firestore_client.collection.return_value + app_doc_ref = app_state_coll.document.return_value + app_snap = mock.MagicMock() + app_snap.exists = False + app_snap.to_dict.return_value = {} + app_doc_ref.get.return_value = app_snap + + user_state_coll = mock_firestore_client.collection.return_value + user_doc_ref = user_state_coll.document.return_value + user_snap = mock.MagicMock() + user_snap.exists = False + user_snap.to_dict.return_value = {} + user_doc_ref.get.return_value = user_snap + + events_collection_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) + event_doc = mock.MagicMock() + event_doc.to_dict.return_value = { + "event_data": {"invocation_id": "test_inv", "author": "user"} + } + events_collection_ref.get = mock.AsyncMock(return_value=[event_doc]) + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.id == session_id + assert session.state == {"key": "value"} + assert len(session.events) == 1 + assert session.events[0].invocation_id == "test_inv" + + +@pytest.mark.asyncio +async def test_delete_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + events_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) + event_doc = mock.AsyncMock() + + async def to_async_iter(iterable): + for item in iterable: + yield item + + events_ref.stream.return_value = to_async_iter([event_doc]) + + await service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + events_ref.stream.assert_called_once() + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + batch.delete.assert_called_once_with(event_doc.reference) + batch.commit.assert_called_once() + + session_doc_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value + ) + session_doc_ref.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_append_event(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + event = Event(invocation_id="test_inv", author="user") + + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 0} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + await service.append_event(session, event) + + from google.cloud import firestore + + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called() # Invoked for events appends + transaction.update.assert_called_once() # Invoked for session revisions + + args, kwargs = transaction.update.call_args + assert args[1]["revision"] == 1 + assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP + + +@pytest.mark.asyncio +async def test_append_event_with_state_delta(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + + event = mock.MagicMock() + event.partial = False + event.id = "test_event_id" + event.actions.state_delta = { + "_app_my_key": "app_val", + "_user_my_key": "user_val", + "session_key": "session_val", + } + event.model_dump.return_value = {"id": "test_event_id", "author": "user"} + + service._update_app_state_transactional = mock.AsyncMock() + service._update_user_state_transactional = mock.AsyncMock() + + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 0} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + await service.append_event(session, event) + + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called() + + assert session.state["session_key"] == "session_val" + + from google.cloud import firestore + + transaction.update.assert_called_once() + args, kwargs = transaction.update.call_args + # In modular Firestore configurations alignments, updating variables mock assertions core setups + assert args[1]["state"] == session.state + assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP + + +@pytest.mark.asyncio +async def test_append_event_with_temp_state(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.events.event import Event + from google.adk.events.event import EventActions + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + + event = Event( + invocation_id="test_inv", + author="user", + actions=EventActions( + state_delta={"temp:k1": "v1", "session_key": "session_val"} + ), + ) + + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 0} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + await service.append_event(session, event) + + # 1. Verify it was applied in-memory + assert session.state["temp:k1"] == "v1" + assert session.state["session_key"] == "session_val" + + # 2. Verify it was trimmed before Firestore save + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called() + + # Filter calls for the one that actually sets the event data + event_set_calls = [ + call + for call in transaction.set.call_args_list + if len(call[0]) > 1 + and isinstance(call[0][1], dict) + and "event_data" in call[0][1] + ] + assert len(event_set_calls) == 1 + event_data = event_set_calls[0][0][1]["event_data"] + + # Temporary keys should be deleted from delta before snapshot + assert "temp:k1" not in event_data["actions"]["state_delta"] + assert event_data["actions"]["state_delta"]["session_key"] == "session_val" + + +@pytest.mark.asyncio +async def test_list_sessions_with_user_id(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": user_id, + "state": {"session_key": "session_val"}, + } + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + sessions_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + elif name == service.root_collection: + return sessions_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = True + app_doc.to_dict.return_value = {"app_key": "app_val"} + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.exists = True + user_doc.to_dict.return_value = {"user_key": "user_val"} + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + user_doc_ref.get = mock.AsyncMock(return_value=user_doc) + + app_doc_in_root = mock.MagicMock() + sessions_coll.document.return_value = app_doc_in_root + users_coll = mock.MagicMock() + app_doc_in_root.collection.return_value = users_coll + user_doc_in_users = mock.MagicMock() + users_coll.document.return_value = user_doc_in_users + sessions_subcoll = mock.MagicMock() + user_doc_in_users.collection.return_value = sessions_subcoll + sessions_query = mock.MagicMock() + sessions_subcoll.where.return_value = sessions_query + sessions_query.get = mock.AsyncMock(return_value=[session_doc]) + + response = await service.list_sessions(app_name=app_name, user_id=user_id) + + assert len(response.sessions) == 1 + session = response.sessions[0] + assert session.id == "session1" + assert session.state["session_key"] == "session_val" + assert session.state["app:app_key"] == "app_val" + assert session.state["user:user_key"] == "user_val" + + +@pytest.mark.asyncio +async def test_list_sessions_without_user_id(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": "user1", + "state": {"session_key": "session_val"}, + } + + mock_firestore_client.collection_group.return_value.where.return_value.get = ( + mock.AsyncMock(return_value=[session_doc]) + ) + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = True + app_doc.to_dict.return_value = {"app_key": "app_val"} + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.id = "user1" + user_doc.to_dict.return_value = {"user_key": "user_val"} + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + users_coll.get = mock.AsyncMock(return_value=[user_doc]) + + response = await service.list_sessions(app_name=app_name) + + assert len(response.sessions) == 1 + session = response.sessions[0] + assert session.id == "session1" + assert session.state["app:app_key"] == "app_val" + assert session.state["user:user_key"] == "user_val" + + mock_firestore_client.collection_group.assert_called_once_with("sessions") + mock_firestore_client.collection_group.return_value.where.assert_called_once_with( + "appName", "==", app_name + ) + + +@pytest.mark.asyncio +async def test_list_sessions_filters_other_apps(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": "user1", + "state": {"session_key": "session_val"}, + } + + mock_firestore_client.collection_group.return_value.where.return_value.get = ( + mock.AsyncMock(return_value=[session_doc]) + ) + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = True + app_doc.to_dict.return_value = {"app_key": "app_val"} + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.id = "user1" + user_doc.to_dict.return_value = {"user_key": "user_val"} + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + users_coll.get = mock.AsyncMock(return_value=[user_doc]) + + response = await service.list_sessions(app_name=app_name) + + assert len(response.sessions) == 1 + assert response.sessions[0].id == "session1" + assert response.sessions[0].app_name == app_name + + mock_firestore_client.collection_group.assert_called_once_with("sessions") + mock_firestore_client.collection_group.return_value.where.assert_called_once_with( + "appName", "==", app_name + ) + + +@pytest.mark.asyncio +async def test_create_session_already_exists(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) + doc_snapshot.exists = True + + from google.adk.errors.already_exists_error import AlreadyExistsError + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + with pytest.raises(AlreadyExistsError): + await service.create_session( + app_name=app_name, user_id=user_id, session_id="existing_id" + ) + + +@pytest.mark.asyncio +async def test_get_session_with_config(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = { + "id": session_id, + "appName": app_name, + "userId": user_id, + } + + events_collection_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) + + from google.adk.sessions.base_session_service import GetSessionConfig + + config = GetSessionConfig(after_timestamp=1234567890.0, num_recent_events=5) + + await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id, config=config + ) + + events_collection_ref.where.assert_called_once() + events_collection_ref.limit_to_last.assert_called_once_with(5) + + +@pytest.mark.asyncio +async def test_delete_session_batching(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + events_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) + + dummy_docs = [mock.MagicMock() for _ in range(501)] + + async def to_async_iter(iterable): + for item in iterable: + yield item + + events_ref.stream.return_value = to_async_iter(dummy_docs) + + batch = mock_firestore_client.batch.return_value + + await service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert batch.commit.call_count == 2 + assert batch.delete.call_count == 501 + + +@pytest.mark.asyncio +async def test_append_event_partial(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + event = Event(invocation_id="test_inv", author="user", partial=True) + + result = await service.append_event(session, event) + + assert result == event + mock_firestore_client.batch.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.asyncio +async def test_get_session_empty_data(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = {} + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is None + + +@pytest.mark.asyncio +async def test_list_sessions_missing_states(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": user_id, + "state": {"session_key": "session_val"}, + } + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + sessions_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + elif name == service.root_collection: + return sessions_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = False + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.exists = False + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + user_doc_ref.get = mock.AsyncMock(return_value=user_doc) + + app_doc_in_root = mock.MagicMock() + sessions_coll.document.return_value = app_doc_in_root + users_coll = mock.MagicMock() + app_doc_in_root.collection.return_value = users_coll + user_doc_in_users = mock.MagicMock() + users_coll.document.return_value = user_doc_in_users + sessions_subcoll = mock.MagicMock() + user_doc_in_users.collection.return_value = sessions_subcoll + sessions_query = mock.MagicMock() + sessions_subcoll.where.return_value = sessions_query + sessions_query.get = mock.AsyncMock(return_value=[session_doc]) + + response = await service.list_sessions(app_name=app_name, user_id=user_id) + + assert len(response.sessions) == 1 + session = response.sessions[0] + assert session.id == "session1" + assert session.state["session_key"] == "session_val" + assert "_app_app_key" not in session.state + assert "_user_user_key" not in session.state