From a75c22eaac6f8bb09e8c51caaf70f384650d603f Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Fri, 29 May 2026 03:22:18 +0800 Subject: [PATCH] fix: keep advanced session message writes atomic --- .../memory/advanced_sqlite_session.py | 20 +++--------- .../memory/test_advanced_sqlite_session.py | 32 +++++++++++++++++++ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 83c289bdf8..66b5861bda 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -133,26 +133,14 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: def _add_items_sync(): """Synchronous helper to add items and structure metadata together.""" with self._locked_connection() as conn: - # Keep both writes in one critical section so message IDs and metadata stay aligned. - self._insert_items(conn, items) - conn.commit() try: + # Keep both writes in one transaction so message IDs and metadata stay aligned. + self._insert_items(conn, items) self._insert_structure_metadata(conn, items) conn.commit() - except Exception as e: + except Exception: conn.rollback() - self._logger.error( - f"Failed to add structure metadata for session {self.session_id}: {e}" - ) - try: - deleted_count = self._cleanup_orphaned_messages_sync(conn) - if deleted_count: - conn.commit() - else: - conn.rollback() - except Exception as cleanup_error: - conn.rollback() - self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}") + raise await asyncio.to_thread(_add_items_sync) diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index ad4b5c4d86..bf1976e1bc 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -2,6 +2,7 @@ import asyncio import json +import sqlite3 import tempfile from pathlib import Path from typing import Any, cast @@ -101,6 +102,37 @@ async def test_advanced_session_basic_functionality(agent: Agent): session.close() +async def test_add_items_rolls_back_when_structure_metadata_fails(): + class BrokenMetadataSession(AdvancedSQLiteSession): + def _insert_structure_metadata( + self, + conn: sqlite3.Connection, + items: list[TResponseInputItem], + ) -> None: + raise RuntimeError("metadata write failed") + + session = BrokenMetadataSession(session_id="metadata_failure_test", create_tables=True) + + with pytest.raises(RuntimeError, match="metadata write failed"): + await session.add_items([{"role": "user", "content": "hello"}]) + + with session._locked_connection() as conn: + message_count = conn.execute( + f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?", + (session.session_id,), + ).fetchone()[0] + structure_count = conn.execute( + "SELECT COUNT(*) FROM message_structure WHERE session_id = ?", + (session.session_id,), + ).fetchone()[0] + + assert message_count == 0 + assert structure_count == 0 + assert await session.get_items() == [] + + session.close() + + async def test_advanced_session_respects_custom_table_names(): """AdvancedSQLiteSession should consistently use configured table names.""" session = AdvancedSQLiteSession(