Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,14 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
def _add_items_sync():
"""Synchronous helper to add items and structure metadata together."""
with self._locked_connection() as conn:
# Keep both writes in one critical section so message IDs and metadata stay aligned.
self._insert_items(conn, items)
conn.commit()
try:
# Keep both writes in one transaction so message IDs and metadata stay aligned.
self._insert_items(conn, items)
self._insert_structure_metadata(conn, items)
conn.commit()
except Exception as e:
except Exception:
conn.rollback()
self._logger.error(
f"Failed to add structure metadata for session {self.session_id}: {e}"
)
try:
deleted_count = self._cleanup_orphaned_messages_sync(conn)
if deleted_count:
conn.commit()
else:
conn.rollback()
except Exception as cleanup_error:
conn.rollback()
self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
raise

await asyncio.to_thread(_add_items_sync)

Expand Down
32 changes: 32 additions & 0 deletions tests/extensions/memory/test_advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
import sqlite3
import tempfile
from pathlib import Path
from typing import Any, cast
Expand Down Expand Up @@ -101,6 +102,37 @@ async def test_advanced_session_basic_functionality(agent: Agent):
session.close()


async def test_add_items_rolls_back_when_structure_metadata_fails():
class BrokenMetadataSession(AdvancedSQLiteSession):
def _insert_structure_metadata(
self,
conn: sqlite3.Connection,
items: list[TResponseInputItem],
) -> None:
raise RuntimeError("metadata write failed")

session = BrokenMetadataSession(session_id="metadata_failure_test", create_tables=True)

with pytest.raises(RuntimeError, match="metadata write failed"):
await session.add_items([{"role": "user", "content": "hello"}])

with session._locked_connection() as conn:
message_count = conn.execute(
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]
structure_count = conn.execute(
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
(session.session_id,),
).fetchone()[0]

assert message_count == 0
assert structure_count == 0
assert await session.get_items() == []

session.close()


async def test_advanced_session_respects_custom_table_names():
"""AdvancedSQLiteSession should consistently use configured table names."""
session = AdvancedSQLiteSession(
Expand Down