diff --git a/README.md b/README.md index 5ffcbfe..279e216 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,10 @@ tRPC-Agent-Python provides an end-to-end foundation for agent building, orchestr - **Multi-paradigm agent orchestration**: Built-in orchestration supports `ChainAgent` / `ParallelAgent` / `CycleAgent` / `TransferAgent`, with `GraphAgent` for graph-based orchestration. - **Graph orchestration capability (`GraphAgent`)**: Use DSL to orchestrate `Agent` / `Tool` / `MCP` / `Knowledge` / `CodeExecutor` in one unified flow. -- **Efficient integration with Python AI ecosystems**: Agent ecosystem extensions (`claude-agent-sdk` / `LangGraph`, etc.) / Tool ecosystem extensions (`mcp`, etc.) / Knowledge ecosystem extensions (`LangChain`, etc.) / Model ecosystem extensions (`LiteLLM`, etc.) / Memory ecosystem extensions (`Mem0`, etc.). +- **Efficient integration with Python AI ecosystems**: Agent ecosystem extensions (`claude-agent-sdk` / `LangGraph`, etc.) / Tool ecosystem extensions (`mcp`, etc.) / Knowledge ecosystem extensions (`LangChain`, etc.) / Model ecosystem extensions (`LiteLLM`, etc.) / Memory ecosystem extensions (`Mem0`, `Mempalace`, etc.). - **Agent ecosystem extensions**: Supports `LangGraphAgent` / `ClaudeAgent` / `TeamAgent` (Agno-Like). - **Tool ecosystem extensions**: `FunctionTool` / File tools / `MCPToolset` / LangChain Tool / Agent-as-Tool. -- **Complete memory capability (`Session` / `Memory`)**: `Session` manages messages and state within a single session, while `Memory` manages cross-session long-term memory and personalization. Persistence supports `InMemory` / `Redis` / `SQL`; `Memory` also supports `Mem0`. +- **Complete memory capability (`Session` / `Memory`)**: `Session` manages messages and state within a single session, while `Memory` manages cross-session long-term memory and personalization. Persistence supports `InMemory` / `Redis` / `SQL`; `Memory` also supports `Mem0`、`Mempalace`. - **Production-grade knowledge capability**: Built on LangChain components with first-class RAG support. - **CodeExecutor extension capability**: Supports local / container executors for code execution and task grounding. - **Skills extension capability**: Supports `SKILL.md`-based skill systems for reusable capabilities and dynamic tooling. @@ -84,7 +84,7 @@ pip install trpc-agent-py Install optional capabilities as needed: ```bash -pip install trpc-agent-py[a2a,ag-ui,knowledge,agent-claude,mem0,langfuse] +pip install trpc-agent-py[a2a,ag-ui,knowledge,agent-claude,mem0, Mempalace, langfuse] ``` ### Develop Weather Agent @@ -457,7 +457,7 @@ Related docs: This group helps you: - Session: manage per-session messages, summaries, and state -- Memory: manage cross-session long-term memory (including Mem0) +- Memory: manage cross-session long-term memory (including Mem0, Mempalace) - Knowledge: cover document loading, retrieval, RAG, and prompt templates ### 10. Serving and Protocols @@ -545,12 +545,12 @@ The framework is organized in an event-driven architecture where each layer can - **Runner layer**: Unified execution entry, coordinating Session / Memory / Artifact services - **Tool layer**: FunctionTool / file tools / MCPToolset / Skill tools - **Model layer**: OpenAIModel / AnthropicModel / LiteLLMModel -- **Memory layer**: SessionService / MemoryService / SessionSummarizer / Mem0MemoryService +- **Memory layer**: SessionService / MemoryService / SessionSummarizer / Mem0MemoryService / MempalaceMemoryService - **Knowledge layer**: Production-grade LangChain-based knowledge and RAG capability - **Execution and skill layer**: CodeExecutor (local / container) / Skills - **Service layer**: FastAPI / A2A / AG-UI - **Observability layer**: OpenTelemetry tracing/metrics, integrable with platforms like Langfuse -- **Ecosystem adapter layer**: claude-agent-sdk / mcp / LangChain / LiteLLM / Mem0 plugged into the main chain through model/tool/memory adapters +- **Ecosystem adapter layer**: claude-agent-sdk / mcp / LangChain / LiteLLM / Mem0 / Mempalace plugged into the main chain through model/tool/memory adapters Key packages: diff --git a/README.zh_CN.md b/README.zh_CN.md index a4cf0e6..b5619c2 100644 --- a/README.zh_CN.md +++ b/README.zh_CN.md @@ -16,10 +16,10 @@ tRPC-Agent-Python 提供从 Agent 构建、编排、工具接入、会话记忆 - **多范式 Agent 编排**:预设编排支持 ChainAgent / ParallelAgent / CycleAgent / TransferAgent,同时支持 GraphAgent 图编排 - **图编排能力(GraphAgent)**:通过 DSL 统一编排 Agent / Tool / MCP / Knowledge / CodeExecutor -- **高效接入 Python AI 生态扩展**:Agent 生态扩展(claude-agent-sdk / LangGraph 等)/ 工具生态扩展(mcp 等)/ 知识库生态扩展(LangChain 等)/ 模型生态扩展(LiteLLM 等)/ 记忆生态扩展(Mem0 等) +- **高效接入 Python AI 生态扩展**:Agent 生态扩展(claude-agent-sdk / LangGraph 等)/ 工具生态扩展(mcp 等)/ 知识库生态扩展(LangChain 等)/ 模型生态扩展(LiteLLM 等)/ 记忆生态扩展(Mem0、Mempalace等) - **Agent 生态扩展**:支持 LangGraphAgent / ClaudeAgent / TeamAgent(Agno-Like) - **Tool 生态扩展**:FunctionTool / 文件工具 / MCPToolset / LangChain Tool / Agent-as-Tool -- **完善的记忆能力(Session / Memory)**:Session 负责单会话内的消息与状态管理,Memory 负责跨会话长期记忆与个性化信息沉淀。持久化支持 InMemory / Redis / SQL,Memory 还支持 Mem0 +- **完善的记忆能力(Session / Memory)**:Session 负责单会话内的消息与状态管理,Memory 负责跨会话长期记忆与个性化信息沉淀。持久化支持 InMemory / Redis / SQL,Memory 还支持 Mem0、Mempalace - **生产级知识库能力**:知识库能力基于 LangChain 组件构建,支持 RAG 场景 - **CodeExecutor 扩展能力**:支持本地 / 容器执行器,用于支持 Agent 的代码执行与任务落地能力 - **Skills 扩展能力**:支持 SKILL.md 技能体系,用于支持 Agent 的技能复用与动态工具化能力 @@ -84,7 +84,7 @@ pip install trpc-agent-py 按需安装扩展能力: ```bash -pip install trpc-agent-py[a2a,ag-ui,knowledge,agent-claude,mem0,langfuse] +pip install trpc-agent-py[a2a,ag-ui,knowledge,agent-claude,mem0, Mempalace, langfuse] ``` @@ -458,7 +458,7 @@ skill_tool_set = SkillToolSet(repository=repository, run_tool_kwargs=tool_kwargs 这组示例可以帮你: - Session:管理单会话的消息、摘要与状态 -- Memory:管理跨会话长期记忆(含 Mem0) +- Memory:管理跨会话长期记忆(含 Mem0, Mempalace) - Knowledge:覆盖文档加载、检索、RAG、提示模板等链路 ### 10. 服务化与协议 @@ -546,12 +546,12 @@ skill_tool_set = SkillToolSet(repository=repository, run_tool_kwargs=tool_kwargs - **Runner 层**:统一执行入口,负责 Session/Memory/Artifact 等服务协同 - **Tool 层**:FunctionTool / 文件工具 / MCPToolset / Skill 工具 - **Model 层**:OpenAIModel / AnthropicModel / LiteLLMModel -- **Memory 层**:SessionService / MemoryService / SessionSummarizer / Mem0MemoryService +- **Memory 层**:SessionService / MemoryService / SessionSummarizer / Mem0MemoryService / MempalaceMemoryService - **Knowledge 层**:基于 LangChain 的生产级知识库能力(RAG) - **执行与技能层**:CodeExecutor(本地/容器)/ Skills - **服务层**:FastAPI / A2A / AG-UI - **观测层**:OpenTelemetry tracing/metrics,可对接 Langfuse 等平台 -- **生态适配层**:claude-agent-sdk / mcp / LangChain / LiteLLM / Mem0,通过模型/工具/记忆适配器接入主链路 +- **生态适配层**:claude-agent-sdk / mcp / LangChain / LiteLLM / Mem0 / MemoryService,通过模型/工具/记忆适配器接入主链路 关键包一览: diff --git a/examples/memory_service_with_sql/agent/agent.py b/examples/memory_service_with_sql/agent/agent.py index 44885d8..1f75d3b 100644 --- a/examples/memory_service_with_sql/agent/agent.py +++ b/examples/memory_service_with_sql/agent/agent.py @@ -10,6 +10,8 @@ from trpc_agent_sdk.models import OpenAIModel from trpc_agent_sdk.tools import FunctionTool from trpc_agent_sdk.tools import load_memory_tool +from trpc_agent_sdk.types import GenerateContentConfig +from trpc_agent_sdk.types import HttpOptions from .config import get_model_config from .prompts import INSTRUCTION @@ -25,12 +27,18 @@ def _create_model() -> LLMModel: def create_agent() -> LlmAgent: """ Create an agent""" + generate_content_config = GenerateContentConfig( + http_options=HttpOptions(extra_body={"chat_template_kwargs": { + "enable_thinking": False + }}), + ) agent = LlmAgent( name="assistant", description="A helpful assistant for conversation", model=_create_model(), # You can change this to your preferred model instruction=INSTRUCTION, tools=[FunctionTool(get_weather_report), load_memory_tool], + generate_content_config=generate_content_config, ) return agent diff --git a/examples/session_service_with_sql/agent/agent.py b/examples/session_service_with_sql/agent/agent.py index 2979490..da07751 100644 --- a/examples/session_service_with_sql/agent/agent.py +++ b/examples/session_service_with_sql/agent/agent.py @@ -9,6 +9,8 @@ from trpc_agent_sdk.models import LLMModel from trpc_agent_sdk.models import OpenAIModel from trpc_agent_sdk.tools import FunctionTool +from trpc_agent_sdk.types import GenerateContentConfig +from trpc_agent_sdk.types import HttpOptions from .config import get_model_config from .prompts import INSTRUCTION @@ -24,12 +26,18 @@ def _create_model() -> LLMModel: def create_agent() -> LlmAgent: """ Create an agent""" + generate_content_config = GenerateContentConfig( + http_options=HttpOptions(extra_body={"chat_template_kwargs": { + "enable_thinking": False + }}), + ) agent = LlmAgent( name="assistant", description="A helpful assistant for conversation", model=_create_model(), # You can change this to your preferred model instruction=INSTRUCTION, tools=[FunctionTool(get_weather_report)], + generate_content_config=generate_content_config, ) return agent diff --git a/examples/session_service_with_sql/run_agent.py b/examples/session_service_with_sql/run_agent.py index 5f328f1..431a3af 100644 --- a/examples/session_service_with_sql/run_agent.py +++ b/examples/session_service_with_sql/run_agent.py @@ -76,10 +76,9 @@ async def run_weather_agent(): ] for query in demo_queries: - # Use a new session for each query - user_content = Content(parts=[Part.from_text(text=query)]) + print(f"👤 User: {query}") print("🤖 Assistant: ", end="", flush=True) async for event in runner.run_async(user_id=user_id, session_id=current_session_id, new_message=user_content): # Check if event.content exists diff --git a/tests/common/test_compatible.py b/tests/common/test_compatible.py index 6e758e2..8b4abdd 100644 --- a/tests/common/test_compatible.py +++ b/tests/common/test_compatible.py @@ -7,7 +7,7 @@ Covers: - PY_310 version flag -- checkenum() with standard enums, IntEnum, and fallback path +- check_enum() with standard enums, IntEnum, and fallback path - OSDetector: platform detection, properties, get_os_name, get_os_info, __str__ - OS_DETECTOR module-level singleton """ @@ -21,7 +21,7 @@ import pytest -from trpc_agent_sdk.common._compatible import OS_DETECTOR, OSDetector, PY_310, checkenum +from trpc_agent_sdk.common._compatible import OS_DETECTOR, OSDetector, PY_310, check_enum # --------------------------------------------------------------------------- @@ -40,7 +40,7 @@ def test_py310_matches_runtime(self): # --------------------------------------------------------------------------- -# checkenum +# check_enum # --------------------------------------------------------------------------- @@ -63,32 +63,32 @@ class _FlagEnum(enum.Flag): class TestCheckenum: - """Tests for checkenum().""" + """Tests for check_enum().""" def test_valid_enum_member(self): - assert checkenum(_Color.RED, _Color) is True + assert check_enum(_Color.RED, _Color) is True def test_invalid_enum_member(self): - assert checkenum("yellow", _Color) is False + assert check_enum("yellow", _Color) is False def test_valid_int_enum_member(self): - assert checkenum(_Priority.HIGH, _Priority) is True + assert check_enum(_Priority.HIGH, _Priority) is True def test_invalid_int_enum_member(self): - assert checkenum(99, _Priority) is False + assert check_enum(99, _Priority) is False def test_valid_flag_enum_member(self): - assert checkenum(_FlagEnum.READ, _FlagEnum) is True + assert check_enum(_FlagEnum.READ, _FlagEnum) is True def test_string_value_is_found_by_value(self): # Python 3.12+ enum __contains__ matches by value - assert checkenum("red", _Color) is True + assert check_enum("red", _Color) is True def test_string_not_matching_any_value(self): - assert checkenum("magenta", _Color) is False + assert check_enum("magenta", _Color) is False def test_none_is_not_member(self): - assert checkenum(None, _Color) is False + assert check_enum(None, _Color) is False def test_fallback_to_members_values(self): """When ``in`` raises, falls back to __members__.values().""" @@ -108,8 +108,8 @@ def __contains__(self, item): def __iter__(self): raise TypeError("broken __iter__") - assert checkenum("a", _BadContains()) is True - assert checkenum("c", _BadContains()) is False + assert check_enum("a", _BadContains()) is True + assert check_enum("c", _BadContains()) is False # --------------------------------------------------------------------------- diff --git a/tests/common/test_init.py b/tests/common/test_init.py index 7fd0273..2003357 100644 --- a/tests/common/test_init.py +++ b/tests/common/test_init.py @@ -11,11 +11,11 @@ from __future__ import annotations import trpc_agent_sdk.common as common_mod -from trpc_agent_sdk.common import OS_DETECTOR, OSDetector, checkenum +from trpc_agent_sdk.common import OS_DETECTOR, OSDetector, check_enum from trpc_agent_sdk.common._compatible import ( OS_DETECTOR as _ORIG_OS_DETECTOR, OSDetector as _OrigOSDetector, - checkenum as _orig_checkenum, + check_enum as _orig_check_enum, ) @@ -23,7 +23,7 @@ class TestPublicExports: """Ensure __init__.py re-exports the right objects.""" def test_all_contains_expected_names(self): - assert set(common_mod.__all__) == {"OSDetector", "OS_DETECTOR", "checkenum"} + assert set(common_mod.__all__) == {"OSDetector", "OS_DETECTOR", "check_enum"} def test_os_detector_class_is_same_object(self): assert OSDetector is _OrigOSDetector @@ -32,4 +32,4 @@ def test_os_detector_instance_is_same_object(self): assert OS_DETECTOR is _ORIG_OS_DETECTOR def test_checkenum_is_same_function(self): - assert checkenum is _orig_checkenum + assert check_enum is _orig_check_enum diff --git a/tests/memory/test_sql_memory_service.py b/tests/memory/test_sql_memory_service.py index a8143a1..7a7ab8c 100644 --- a/tests/memory/test_sql_memory_service.py +++ b/tests/memory/test_sql_memory_service.py @@ -278,17 +278,6 @@ async def test_store_skips_events_without_content(self): svc._sql_storage.add.assert_not_called() svc._sql_storage.commit.assert_not_called() - async def test_store_raises_on_non_session(self): - svc = SqlMemoryService.__new__(SqlMemoryService) - svc._memory_service_config = _make_config_no_ttl() - svc._sql_storage = _patch_sql_storage() - svc._SqlMemoryService__cleanup_task = None - svc._SqlMemoryService__cleanup_stop_event = None - - with pytest.raises(TypeError, match="Content must be a Session"): - await svc.store_session("not a session") - - # --------------------------------------------------------------------------- # SqlMemoryService — search_memory # --------------------------------------------------------------------------- diff --git a/tests/models/test_anthropic_model_ext.py b/tests/models/test_anthropic_model_ext.py index d0bb071..60630c0 100644 --- a/tests/models/test_anthropic_model_ext.py +++ b/tests/models/test_anthropic_model_ext.py @@ -519,14 +519,5 @@ def test_part_with_inline_data_is_valid(self): request = LlmRequest(contents=[Content(parts=[part], role="user")]) model.validate_request(request) - def test_part_with_no_content_raises(self): - """Part with no meaningful content raises ValueError.""" - model = _model() - part = Part() - request = LlmRequest(contents=[Content(parts=[part], role="user")]) - with pytest.raises(ValueError, match="Content parts must have"): - model.validate_request(request) - - if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/models/test_openai_model.py b/tests/models/test_openai_model.py index dd70515..d806e7b 100644 --- a/tests/models/test_openai_model.py +++ b/tests/models/test_openai_model.py @@ -148,18 +148,6 @@ def test_validate_request_with_multiple_contents(self): # Should not raise model.validate_request(request) - def test_validate_request_with_parts_without_content(self): - """Test validating request with parts that have no actual content raises ValueError.""" - model = OpenAIModel(model_name="gpt-4", api_key="test_key") - - # Create a part with all None fields - part = Part() - content = Content(parts=[part], role="user") - request = LlmRequest(contents=[content], config=None, tools_dict={}) - - with pytest.raises(ValueError, match="Content parts must have"): - model.validate_request(request) - def test_properties_and_config(self): """Test model properties and config.""" model = OpenAIModel( diff --git a/tests/sessions/test_sql_session_service.py b/tests/sessions/test_sql_session_service.py index 9925f60..6c72795 100644 --- a/tests/sessions/test_sql_session_service.py +++ b/tests/sessions/test_sql_session_service.py @@ -90,12 +90,39 @@ def test_from_event_with_function_call(self): storage_event = SessionStorageEvent.from_event(session, event) assert storage_event.content is not None + def test_from_event_drops_empty_parts(self): + session = Session(id="s1", app_name="app", user_id="user", save_key="k") + event = Event( + invocation_id="inv-1", + author="agent", + content=Content(parts=[Part()]), + ) + storage_event = SessionStorageEvent.from_event(session, event) + assert storage_event.content is None + def test_from_event_no_content(self): session = Session(id="s1", app_name="app", user_id="user", save_key="k") event = Event(invocation_id="inv-1", author="agent", actions=EventActions()) storage_event = SessionStorageEvent.from_event(session, event) assert storage_event.content is None + def test_to_event_drops_legacy_empty_parts(self): + storage_event = SessionStorageEvent( + id="e1", + app_name="app", + user_id="user", + session_id="s1", + invocation_id="inv-1", + author="agent", + actions=EventActions(), + long_running_tool_ids=set(), + timestamp=datetime.now(), + model_flags=1, + content={"parts": [{}], "role": "model"}, + ) + event = storage_event.to_event() + assert event.content is None + def test_long_running_tool_ids_property(self): session = Session(id="s1", app_name="app", user_id="user", save_key="k") event = _make_event() diff --git a/tests/storage/test_sql_common.py b/tests/storage/test_sql_common.py index 639c654..bfd83d1 100644 --- a/tests/storage/test_sql_common.py +++ b/tests/storage/test_sql_common.py @@ -332,6 +332,14 @@ def test_process_bind_param_spanner(self): result = dpt.process_bind_param(value, dialect) assert pickle.loads(result) == value + def test_process_bind_param_mysql(self): + dpt = DynamicPickleType() + dialect = _make_dialect("mysql") + value = {"key": "value", "nums": [1, 2, 3]} + result = dpt.process_bind_param(value, dialect) + assert isinstance(result, bytes) + assert pickle.loads(result) == value + def test_process_bind_param_non_spanner(self): dpt = DynamicPickleType() dialect = _make_dialect("sqlite") @@ -352,6 +360,14 @@ def test_process_result_value_spanner(self): result = dpt.process_result_value(pickled, dialect) assert result == original + def test_process_result_value_mysql(self): + dpt = DynamicPickleType() + dialect = _make_dialect("mysql") + original = {"key": "value", "nums": [1, 2, 3]} + pickled = pickle.dumps(original) + result = dpt.process_result_value(pickled, dialect) + assert result == original + def test_process_result_value_non_spanner(self): dpt = DynamicPickleType() dialect = _make_dialect("sqlite") diff --git a/trpc_agent_sdk/common/__init__.py b/trpc_agent_sdk/common/__init__.py index 7311c69..aa6652a 100644 --- a/trpc_agent_sdk/common/__init__.py +++ b/trpc_agent_sdk/common/__init__.py @@ -9,10 +9,10 @@ from ._compatible import OSDetector from ._compatible import OS_DETECTOR -from ._compatible import checkenum +from ._compatible import check_enum __all__ = [ "OSDetector", "OS_DETECTOR", - "checkenum", + "check_enum", ] diff --git a/trpc_agent_sdk/common/_compatible.py b/trpc_agent_sdk/common/_compatible.py index 9b783ad..19f3681 100644 --- a/trpc_agent_sdk/common/_compatible.py +++ b/trpc_agent_sdk/common/_compatible.py @@ -9,11 +9,12 @@ import sys from typing import Any from typing import Dict +from enum import Enum PY_310 = sys.version_info >= (3, 10) -def checkenum(value, enum_class) -> bool: +def check_enum(value: Any, enum_class: type[Enum]) -> bool: """Check if a value is a valid member of an enum class.""" try: return value in enum_class diff --git a/trpc_agent_sdk/memory/_in_memory_memory_service.py b/trpc_agent_sdk/memory/_in_memory_memory_service.py index c477734..992a9b2 100644 --- a/trpc_agent_sdk/memory/_in_memory_memory_service.py +++ b/trpc_agent_sdk/memory/_in_memory_memory_service.py @@ -200,7 +200,7 @@ def _start_cleanup_task(self) -> None: return self.__cleanup_stop_event = asyncio.Event() - self.__cleanup_task = asyncio.create_task(self._cleanup_loop()) + self.__cleanup_task = asyncio.get_event_loop().create_task(self._cleanup_loop()) logger.debug("Cleanup task created") def _stop_cleanup_task(self) -> None: diff --git a/trpc_agent_sdk/memory/_sql_memory_service.py b/trpc_agent_sdk/memory/_sql_memory_service.py index 7ac2a42..244d72a 100644 --- a/trpc_agent_sdk/memory/_sql_memory_service.py +++ b/trpc_agent_sdk/memory/_sql_memory_service.py @@ -43,6 +43,7 @@ from trpc_agent_sdk.storage import SqlStorage from trpc_agent_sdk.storage import decode_content from trpc_agent_sdk.storage import decode_grounding_metadata +from trpc_agent_sdk.storage import sanitize_content_json from ._utils import extract_words_lower from ._utils import format_timestamp @@ -110,7 +111,7 @@ def update_event(self, session: Session, event: Event): self.error_message = event.error_message self.interrupted = event.interrupted if event.content: - self.content = event.content.model_dump(exclude_none=True, mode="json") + self.content = sanitize_content_json(event.content.model_dump(exclude_none=True, mode="json")) if event.grounding_metadata: self.grounding_metadata = event.grounding_metadata.model_dump(exclude_none=True, mode="json") if event.custom_metadata: @@ -135,7 +136,7 @@ def from_event(cls, session: Session, event: Event) -> MemStorageEvent: interrupted=event.interrupted, ) if event.content: - storage_event.content = event.content.model_dump(exclude_none=True, mode="json") + storage_event.content = sanitize_content_json(event.content.model_dump(exclude_none=True, mode="json")) if event.grounding_metadata: storage_event.grounding_metadata = event.grounding_metadata.model_dump(exclude_none=True, mode="json") if event.custom_metadata: @@ -150,7 +151,7 @@ def to_event(self) -> Event: branch=self.branch, actions=self.actions, # type: ignore timestamp=self.timestamp.timestamp(), - content=decode_content(self.content), + content=decode_content(sanitize_content_json(self.content)), long_running_tool_ids=self.long_running_tool_ids, partial=self.partial, turn_complete=self.turn_complete, @@ -194,15 +195,15 @@ async def store_session(self, session: Session, agent_context: Optional[AgentCon Only stores events that are not expired based on event_ttl_seconds. """ - if not isinstance(session, Session): - raise TypeError(f"Content must be a Session, got {type(session)}") - async with self._sql_storage.create_db_session() as sql_session: is_exist = False for event in session.events: if not event.is_model_visible(): continue - if event.content and event.content.parts: + if not event.content or not event.content.parts: + continue + content = sanitize_content_json(event.content.model_dump(exclude_none=True, mode="json")) + if content: is_exist = True # Check if the event already exists event_key = SqlKey(key=(event.id, session.save_key, session.id), storage_cls=MemStorageEvent) @@ -324,7 +325,7 @@ def _start_cleanup_task(self) -> None: return self.__cleanup_stop_event = asyncio.Event() - self.__cleanup_task = asyncio.create_task(self._cleanup_loop()) + self.__cleanup_task = asyncio.get_event_loop().create_task(self._cleanup_loop()) logger.debug("Memory cleanup task created") def _stop_cleanup_task(self) -> None: diff --git a/trpc_agent_sdk/memory/mem0_memory_service.py b/trpc_agent_sdk/memory/mem0_memory_service.py index 9fccd4e..2878109 100644 --- a/trpc_agent_sdk/memory/mem0_memory_service.py +++ b/trpc_agent_sdk/memory/mem0_memory_service.py @@ -328,7 +328,7 @@ def _start_cleanup_task(self) -> None: return self.__cleanup_stop_event = asyncio.Event() - self.__cleanup_task = asyncio.create_task(self._cleanup_loop()) + self.__cleanup_task = asyncio.get_event_loop().create_task(self._cleanup_loop()) logger.debug("Mem0 memory cleanup task created") def _stop_cleanup_task(self) -> None: diff --git a/trpc_agent_sdk/models/_anthropic_model.py b/trpc_agent_sdk/models/_anthropic_model.py index 845d960..064d813 100644 --- a/trpc_agent_sdk/models/_anthropic_model.py +++ b/trpc_agent_sdk/models/_anthropic_model.py @@ -38,8 +38,6 @@ from ._llm_response import LlmResponse from ._registry import register_model -_VALID_ROLES: set[str] = {const.USER, const.ASSISTANT, const.MODEL, const.SYSTEM} - class _FinishReason(str, Enum): """Reasons why model generation finished.""" @@ -117,41 +115,6 @@ def _create_async_client(self): **self.client_args, ) - @override - def validate_request(self, request: LlmRequest) -> None: - """Validate the request before processing.""" - - if not request.contents: - raise ValueError("At least one content is required") - - # Validate content structure - for content in request.contents: - if not content.parts: - raise ValueError("Content must have at least one part") - - # Check if content has valid role - if content.role and content.role not in _VALID_ROLES: - raise ValueError(f"Invalid content role: {content.role}") - - # Validate parts have content - has_content = False - for part in content.parts: - condition_iter = [ - part.text, - part.function_call, - part.function_response, - part.code_execution_result, - part.executable_code, - part.inline_data, - ] - has_content = any(condition_iter) - if has_content: - break - - if not has_content: - raise ValueError("Content parts must have text, function_call, function_response, " - "code_execution_result, executable_code, or inline_data") - def _to_claude_role(self, role: Optional[str]) -> Literal["user", "assistant"]: """Convert role to Claude format.""" if role in [const.MODEL, const.ASSISTANT]: diff --git a/trpc_agent_sdk/models/_llm_model.py b/trpc_agent_sdk/models/_llm_model.py index fb94226..940afa1 100644 --- a/trpc_agent_sdk/models/_llm_model.py +++ b/trpc_agent_sdk/models/_llm_model.py @@ -27,6 +27,8 @@ from ._llm_request import LlmRequest from ._llm_response import LlmResponse +_VALID_ROLES: set[str] = {const.USER, const.ASSISTANT, const.MODEL, const.SYSTEM} + class LLMModel(FilterRunner): """Abstract base class for all model implementations.""" @@ -106,7 +108,6 @@ async def _generate_async_impl(self, Error responses should have error_code and error_message set. """ - @abstractmethod def validate_request(self, request: LlmRequest) -> None: """Validate the request before processing. @@ -119,6 +120,17 @@ def validate_request(self, request: LlmRequest) -> None: Raises: ValueError: If request is invalid """ + if not request.contents: + raise ValueError("At least one content is required") + + # Validate content structure + for content in request.contents: + if not content.parts: + raise ValueError("Content must have at least one part") + + # Check if content has valid role + if content.role and content.role not in _VALID_ROLES: + raise ValueError(f"Invalid content role: {content.role}") @property def name(self) -> str: diff --git a/trpc_agent_sdk/models/_openai_model.py b/trpc_agent_sdk/models/_openai_model.py index d1c9e2e..e5ad05d 100644 --- a/trpc_agent_sdk/models/_openai_model.py +++ b/trpc_agent_sdk/models/_openai_model.py @@ -24,7 +24,7 @@ import openai from pydantic import BaseModel -from trpc_agent_sdk.common import checkenum +from trpc_agent_sdk.common import check_enum from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.log import logger from trpc_agent_sdk.types import Content @@ -44,8 +44,6 @@ from .tool_prompt import get_factory from .tool_prompt._base import ToolPrompt -VALID_ROLES: set[str] = {const.USER, const.ASSISTANT, const.MODEL, const.SYSTEM} - class ToolCall(BaseModel): """Represents a tool call made by the model.""" @@ -214,37 +212,6 @@ def _create_tool_prompt(self) -> ToolPrompt: return factory.create(self.tool_prompt) return self.tool_prompt() - @override - def validate_request(self, request: LlmRequest) -> None: - """Validate the request before processing.""" - - if not request.contents: - raise ValueError("At least one content is required") - - # Validate content structure - for content in request.contents: - if not content.parts: - raise ValueError("Content must have at least one part") - - # Check if content has valid role - if content.role and content.role not in VALID_ROLES: - raise ValueError(f"Invalid content role: {content.role}") - - # Validate parts have content - has_content = False - for part in content.parts: - condition_iter = [ - part.text, part.function_call, part.function_response, part.code_execution_result, - part.executable_code, part.inline_data - ] - has_content = any(condition_iter) - if has_content: - break - - if not has_content: - raise ValueError("Content parts must have text, function_call, function_response, " - "code_execution_result, executable_code, or inline_data") - def _get_part_thought_signature(self, part: Part) -> str: """Get thought_signature from Part as str; return dummy if missing. See https://ai.google.dev/gemini-api/docs/thought-signatures (Gemini 3+). @@ -551,7 +518,7 @@ def _validate_and_fix_openai_messages(self, messages: List[Dict[str, Any]]) -> L def _parse_finish_reason(self, finish_reason: str) -> FinishReason: """Convert OpenAI finish reason to our enum.""" - if not checkenum(finish_reason, FinishReason): + if not check_enum(finish_reason, FinishReason): return FinishReason.ERROR return FinishReason(finish_reason) diff --git a/trpc_agent_sdk/sessions/_in_memory_session_service.py b/trpc_agent_sdk/sessions/_in_memory_session_service.py index 938e202..fc5f400 100644 --- a/trpc_agent_sdk/sessions/_in_memory_session_service.py +++ b/trpc_agent_sdk/sessions/_in_memory_session_service.py @@ -399,7 +399,7 @@ def _start_cleanup_task(self) -> None: return self.__cleanup_stop_event = asyncio.Event() - self.__cleanup_task = asyncio.create_task(self._cleanup_loop()) + self.__cleanup_task = asyncio.get_event_loop().create_task(self._cleanup_loop()) logger.debug("Cleanup task created") def _stop_cleanup_task(self) -> None: diff --git a/trpc_agent_sdk/sessions/_sql_session_service.py b/trpc_agent_sdk/sessions/_sql_session_service.py index b7e715a..dc07e2b 100644 --- a/trpc_agent_sdk/sessions/_sql_session_service.py +++ b/trpc_agent_sdk/sessions/_sql_session_service.py @@ -65,6 +65,7 @@ from trpc_agent_sdk.storage import decode_content from trpc_agent_sdk.storage import decode_grounding_metadata from trpc_agent_sdk.storage import decode_usage_metadata +from trpc_agent_sdk.storage import sanitize_content_json from trpc_agent_sdk.utils import user_key from ._base_session_service import BaseSessionService @@ -76,6 +77,23 @@ from ._utils import merge_state +def _event_field_or_default(field_name: str, value: Any) -> Any: + """Use Event's default when legacy SQL rows contain NULL for non-null Event fields.""" + if value is not None: + return value + return Event.model_fields[field_name].default + + +def _event_object_to_storage(value: Optional[str]) -> str: + """Store object as a non-null string for compatibility with existing SQL schemas.""" + return value or "" + + +def _event_object_from_storage(value: Optional[str]) -> Optional[str]: + """Restore Event.object default from the legacy empty-string storage sentinel.""" + return value or Event.model_fields["object"].default + + class SessionStorageBase(DeclarativeBase): """Base class for SqlSessionService tables only. @@ -161,28 +179,29 @@ class SessionStorageEvent(SessionStorageBase): author: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH)) actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - branch: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) - request_id: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) - parent_invocation_id: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) - tag: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) - filter_key: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) - requires_completion: Mapped[bool] = mapped_column(Boolean, nullable=True) - version: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + branch: Mapped[Optional[str]] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) + request_id: Mapped[Optional[str]] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) + parent_invocation_id: Mapped[Optional[str]] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), + nullable=True) + tag: Mapped[Optional[str]] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) + filter_key: Mapped[Optional[str]] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) + requires_completion: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True, default=False) + version: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, default=0) timestamp: Mapped[PreciseTimestamp] = mapped_column(PreciseTimestamp, default=func.now()) - visible: Mapped[bool] = mapped_column(Boolean, nullable=True) + visible: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True, default=True) object: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=False, default="") - model_flags: Mapped[int] = mapped_column(Integer, nullable=False, default=1) - - partial: Mapped[bool] = mapped_column(Boolean, nullable=True) - turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) - error_code: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) - error_message: Mapped[str] = mapped_column(UTF8MB4String(1024), nullable=True) - interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) - response_id: Mapped[str] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) - content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) - grounding_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) - custom_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) - usage_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) + model_flags: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, default=1) + + partial: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + turn_complete: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + error_code: Mapped[Optional[str]] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) + error_message: Mapped[Optional[str]] = mapped_column(UTF8MB4String(1024), nullable=True) + interrupted: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + response_id: Mapped[Optional[str]] = mapped_column(UTF8MB4String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True) + content: Mapped[Optional[dict[str, Any]]] = mapped_column(DynamicJSON, nullable=True) + grounding_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(DynamicJSON, nullable=True) + custom_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(DynamicJSON, nullable=True) + usage_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(DynamicJSON, nullable=True) storage_session: Mapped[StorageSession] = relationship( "StorageSession", @@ -226,7 +245,7 @@ def from_event(cls, session: Session, event: Event) -> SessionStorageEvent: version=event.version, timestamp=datetime.fromtimestamp(event.timestamp), visible=event.visible, - object=event.object, + object=_event_object_to_storage(event.object), model_flags=event.model_flags, partial=event.partial, turn_complete=event.turn_complete, @@ -236,7 +255,7 @@ def from_event(cls, session: Session, event: Event) -> SessionStorageEvent: response_id=event.response_id, ) if event.content: - storage_event.content = event.content.model_dump(exclude_none=True, mode="json") + storage_event.content = sanitize_content_json(event.content.model_dump(exclude_none=True, mode="json")) if event.grounding_metadata: storage_event.grounding_metadata = event.grounding_metadata.model_dump(exclude_none=True, mode="json") if event.custom_metadata: @@ -257,11 +276,11 @@ def to_event(self) -> Event: parent_invocation_id=self.parent_invocation_id, tag=self.tag, filter_key=self.filter_key, - requires_completion=self.requires_completion, - version=self.version, - visible=self.visible, - object=self.object, - model_flags=self.model_flags, + requires_completion=_event_field_or_default("requires_completion", self.requires_completion), + version=_event_field_or_default("version", self.version), + visible=_event_field_or_default("visible", self.visible), + object=_event_object_from_storage(self.object), + model_flags=_event_field_or_default("model_flags", self.model_flags), timestamp=self.timestamp.timestamp(), partial=self.partial, turn_complete=self.turn_complete, @@ -269,7 +288,7 @@ def to_event(self) -> Event: error_message=self.error_message, interrupted=self.interrupted, response_id=self.response_id, - content=decode_content(self.content), + content=decode_content(sanitize_content_json(self.content)), grounding_metadata=decode_grounding_metadata(self.grounding_metadata), custom_metadata=self.custom_metadata, usage_metadata=decode_usage_metadata(self.usage_metadata), @@ -574,7 +593,7 @@ async def _update_app_state(self, sql_session: SqlSession, app_name: str, state_ await self._sql_storage.add(sql_session, storage_app_state) else: storage_app_state.state = app_state # type: ignore - storage_app_state.update_time = func.now() + storage_app_state.update_time = datetime.now() return app_state @@ -604,7 +623,7 @@ async def _get_app_state(self, sql_session: SqlSession, app_name: str) -> dict[s if storage_app_state: if not self._session_config.is_expired_by_timestamp(storage_app_state.update_time.timestamp()): app_state = storage_app_state.state - storage_app_state.update_time = func.now() + storage_app_state.update_time = datetime.now() await self._sql_storage.commit(sql_session) return app_state @@ -617,7 +636,7 @@ async def _get_user_state(self, sql_session: SqlSession, app_name: str, user_id: if storage_user_state: if not self._session_config.is_expired_by_timestamp(storage_user_state.update_time.timestamp()): user_state = storage_user_state.state - storage_user_state.update_time = func.now() + storage_user_state.update_time = datetime.now() await self._sql_storage.commit(sql_session) return user_state @@ -633,7 +652,7 @@ async def _get_session(self, sql_session: SqlSession, app_name: str, user_id: st logger.debug("Session %s is expired", session_id) return None - storage_session.update_time = func.now() + storage_session.update_time = datetime.now() await self._sql_storage.commit(sql_session) return storage_session @@ -645,7 +664,7 @@ async def _cleanup_expired_async(self) -> None: Deletes all expired data in three batch SQL DELETE statements. """ async with self._sql_storage.create_db_session() as sql_session: - # Calculate expiration threshold once (using database local time) + # Calculate expiration threshold once in application time for cross-database compatibility. expire_before = datetime.now() - timedelta(seconds=self._session_config.ttl.ttl_seconds) total_deleted = 0 @@ -721,7 +740,7 @@ def _start_cleanup_task(self) -> None: return self.__cleanup_stop_event = asyncio.Event() - self.__cleanup_task = asyncio.create_task(self._cleanup_loop()) + self.__cleanup_task = asyncio.get_event_loop().create_task(self._cleanup_loop()) logger.debug("Cleanup task created") def _stop_cleanup_task(self) -> None: diff --git a/trpc_agent_sdk/storage/__init__.py b/trpc_agent_sdk/storage/__init__.py index a7d38bb..6b839bf 100644 --- a/trpc_agent_sdk/storage/__init__.py +++ b/trpc_agent_sdk/storage/__init__.py @@ -31,6 +31,7 @@ from ._sql_common import decode_grounding_metadata from ._sql_common import decode_usage_metadata from ._sql_common import GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY +from ._sql_common import sanitize_content_json from ._sql_common import TypeDecoratorHookRegistry __all__ = [ @@ -60,5 +61,6 @@ "decode_grounding_metadata", "decode_usage_metadata", "GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY", + "sanitize_content_json", "TypeDecoratorHookRegistry", ] diff --git a/trpc_agent_sdk/storage/_sql.py b/trpc_agent_sdk/storage/_sql.py index b836cd1..539c624 100644 --- a/trpc_agent_sdk/storage/_sql.py +++ b/trpc_agent_sdk/storage/_sql.py @@ -125,6 +125,9 @@ def __init__(self, is_async: bool, db_url: str, metadata: Optional[MetaData] = N self.__metadata = metadata if metadata is not None else StorageData.metadata self.__is_async = is_async self.__db_url = db_url + self.__sessionmaker_kwargs: dict[str, Any] = kwargs.pop("sessionmaker_kwargs", {}) + expire_on_commit: bool = kwargs.pop("expire_on_commit", True) + self.__sessionmaker_kwargs.setdefault("expire_on_commit", expire_on_commit) self.__kwargs = kwargs def _migrate_missing_columns(self, connection: Connection) -> None: @@ -226,14 +229,14 @@ async def _async_inspect(): async with db_engine.begin() as conn: await conn.run_sync(self.__metadata.create_all) await conn.run_sync(self._migrate_missing_columns) - self._database_session_factory = async_sessionmaker(bind=db_engine) + self._database_session_factory = async_sessionmaker(bind=db_engine, **self.__sessionmaker_kwargs) else: db_engine: SqlEngine = create_engine(self.__db_url, **self.__kwargs) self.inspector = inspect(db_engine) self.__metadata.create_all(db_engine) with db_engine.begin() as conn: self._migrate_missing_columns(conn) - self._database_session_factory = sessionmaker(bind=db_engine) + self._database_session_factory = sessionmaker(bind=db_engine, **self.__sessionmaker_kwargs) if db_engine.dialect.name == "sqlite": listen_target = db_engine.sync_engine if isinstance(db_engine, AsyncEngine) else db_engine diff --git a/trpc_agent_sdk/storage/_sql_common.py b/trpc_agent_sdk/storage/_sql_common.py index 165e991..07fd662 100644 --- a/trpc_agent_sdk/storage/_sql_common.py +++ b/trpc_agent_sdk/storage/_sql_common.py @@ -102,6 +102,33 @@ def run_process_result_hooks(cls, decorator: TypeDecorator, value: Any, dialect: # Global class object used as unified registration entry. GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY = TypeDecoratorHookRegistry +_VALID_PART_PAYLOAD_FIELDS = ( + "text", + "function_call", + "function_response", + "code_execution_result", + "executable_code", + "inline_data", +) + + +def sanitize_content_json(content: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: + """Drop empty parts so persisted history remains valid for model requests.""" + if not content: + return None + + parts = content.get("parts") or [] + valid_parts = [ + part for part in parts + if isinstance(part, dict) and any(part.get(field) for field in _VALID_PART_PAYLOAD_FIELDS) + ] + if not valid_parts: + return None + + sanitized_content = dict(content) + sanitized_content["parts"] = valid_parts + return sanitized_content + def decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]: """Decode a content object from a JSON dictionary. @@ -296,7 +323,7 @@ def process_bind_param(self, value: Any, dialect: Dialect) -> Any: if hook_result is not None: return hook_result if value is not None: - if dialect.name == "spanner+spanner": + if dialect.name in ("mysql", "spanner+spanner"): return pickle.dumps(value) return value @@ -305,7 +332,7 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Any: if hook_result is not None: return hook_result if value is not None: - if dialect.name == "spanner+spanner": + if dialect.name in ("mysql", "spanner+spanner"): return pickle.loads(value) return value