diff --git a/src/engram/__init__.py b/src/engram/__init__.py index 9fbb4b5..6e06efb 100644 --- a/src/engram/__init__.py +++ b/src/engram/__init__.py @@ -14,6 +14,7 @@ ToolCallCustomInput, ToolCallFuncInput, ToolCallInput, + Topic, ) from .async_client import AsyncEngramClient from .client import EngramClient @@ -50,6 +51,7 @@ "ToolCallCustomInput", "ToolCallFuncInput", "ToolCallInput", + "Topic", "ValidationError", "__version__", ] diff --git a/src/engram/_models/__init__.py b/src/engram/_models/__init__.py index b9906e2..04cbfc1 100644 --- a/src/engram/_models/__init__.py +++ b/src/engram/_models/__init__.py @@ -11,6 +11,8 @@ ToolCallCustomInput, ToolCallFuncInput, ToolCallInput, + Topic, + TopicSelector, ) from .run import CommittedOperation, CommittedOperations, Run, RunStatus @@ -31,4 +33,6 @@ "ToolCallCustomInput", "ToolCallFuncInput", "ToolCallInput", + "Topic", + "TopicSelector", ] diff --git a/src/engram/_models/memory.py b/src/engram/_models/memory.py index 17ed194..6cd79d8 100644 --- a/src/engram/_models/memory.py +++ b/src/engram/_models/memory.py @@ -96,6 +96,21 @@ class RetrievalConfig: limit: int | None = None +@dataclass(slots=True) +class Topic: + """A topic with an optional per-topic property filter. + + Use ``None`` as a property value to clear an inherited global filter + for this topic only. + """ + + name: str + properties: dict[str, str | None] | None = None + + +TopicSelector: TypeAlias = str | Topic + + @dataclass(slots=True) class Memory: id: str @@ -106,9 +121,9 @@ class Memory: created_at: str updated_at: str user_id: str | None = None - conversation_id: str | None = None tags: list[str] | None = None score: float | None = None + properties: dict[str, str] | None = None class SearchResults(Sequence[Memory]): diff --git a/src/engram/_resources/memories.py b/src/engram/_resources/memories.py index 1b22627..2db72d5 100644 --- a/src/engram/_resources/memories.py +++ b/src/engram/_resources/memories.py @@ -3,7 +3,14 @@ from uuid import UUID from .._http import AsyncHttpTransport, HttpTransport -from .._models import AddInput, Memory, RetrievalConfig, Run, SearchResults +from .._models import ( + AddInput, + Memory, + RetrievalConfig, + Run, + SearchResults, + TopicSelector, +) from .._serialization import ( build_add_body, build_memory_params, @@ -32,14 +39,14 @@ def add( input_data: AddInput, *, user_id: str | None = None, - conversation_id: str | None = None, group: str | None = None, + properties: dict[str, str] | None = None, ) -> Run: body = build_add_body( input_data, user_id=user_id, - conversation_id=conversation_id, group=group, + properties=properties, ) data = self._transport.request("POST", _MEMORIES_PATH, json=body) return parse_run(data) @@ -75,19 +82,19 @@ def search( self, *, query: str, - topics: list[str] | None = None, + topics: list[TopicSelector] | None = None, user_id: str | None = None, - conversation_id: str | None = None, group: str | None = None, retrieval_config: RetrievalConfig | None = None, + properties: dict[str, str] | None = None, ) -> SearchResults: body = build_search_body( query=query, topics=topics, user_id=user_id, - conversation_id=conversation_id, group=group, retrieval_config=retrieval_config, + properties=properties, ) data = self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body) return parse_search_results(data) @@ -104,14 +111,14 @@ async def add( input_data: AddInput, *, user_id: str | None = None, - conversation_id: str | None = None, group: str | None = None, + properties: dict[str, str] | None = None, ) -> Run: body = build_add_body( input_data, user_id=user_id, - conversation_id=conversation_id, group=group, + properties=properties, ) data = await self._transport.request("POST", _MEMORIES_PATH, json=body) return parse_run(data) @@ -147,19 +154,19 @@ async def search( self, *, query: str, - topics: list[str] | None = None, + topics: list[TopicSelector] | None = None, user_id: str | None = None, - conversation_id: str | None = None, group: str | None = None, retrieval_config: RetrievalConfig | None = None, + properties: dict[str, str] | None = None, ) -> SearchResults: body = build_search_body( query=query, topics=topics, user_id=user_id, - conversation_id=conversation_id, group=group, retrieval_config=retrieval_config, + properties=properties, ) data = await self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body) return parse_search_results(data) diff --git a/src/engram/_serialization/_builders.py b/src/engram/_serialization/_builders.py index 7a975db..ef28334 100644 --- a/src/engram/_serialization/_builders.py +++ b/src/engram/_serialization/_builders.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, TypeAlias from .._models import ( AddInput, @@ -9,6 +9,8 @@ RetrievalConfig, StringInput, ToolCallInput, + Topic, + TopicSelector, ) @@ -65,20 +67,40 @@ def _serialize_conversation_content(content: ConversationInput) -> dict[str, Any return {"conversation": conversation} +_SerializedTopic: TypeAlias = str | dict[str, str | dict[str, str | None]] + + +def _serialize_topic(topic: TopicSelector) -> _SerializedTopic: + if isinstance(topic, str): + return topic + if isinstance(topic, Topic): + out: dict[str, str | dict[str, str | None]] = {"name": topic.name} + if topic.properties is not None: + out["properties"] = dict(topic.properties) + return out + raise TypeError(f"Unsupported topic type: {type(topic)}") # pragma: no cover + + +def _serialize_topics(topics: list[TopicSelector] | None) -> list[_SerializedTopic] | None: + if topics is None: + return None + return [_serialize_topic(t) for t in topics] + + def build_add_body( input_data: AddInput, *, user_id: str | None, - conversation_id: str | None, group: str | None, + properties: dict[str, str] | None = None, ) -> dict[str, Any]: body: dict[str, Any] = {"input": _serialize_input(input_data)} if user_id is not None: body["user_id"] = user_id - if conversation_id is not None: - body["conversation_id"] = conversation_id if group is not None: body["group"] = group + if properties is not None: + body["properties"] = dict(properties) return body @@ -98,11 +120,11 @@ def build_memory_params( def build_search_body( *, query: str, - topics: list[str] | None, + topics: list[TopicSelector] | None, user_id: str | None, - conversation_id: str | None, group: str | None, retrieval_config: RetrievalConfig | None, + properties: dict[str, str] | None = None, ) -> dict[str, Any]: body: dict[str, Any] = {"query": query} if retrieval_config is not None: @@ -110,12 +132,13 @@ def build_search_body( "retrieval_type": retrieval_config.retrieval_type, "limit": retrieval_config.limit, } - if topics is not None: - body["topics"] = topics + serialized_topics = _serialize_topics(topics) + if serialized_topics is not None: + body["topics"] = serialized_topics if user_id is not None: body["user_id"] = user_id - if conversation_id is not None: - body["conversation_id"] = conversation_id if group is not None: body["group"] = group + if properties is not None: + body["properties"] = dict(properties) return body diff --git a/src/engram/_serialization/_parsers.py b/src/engram/_serialization/_parsers.py index 92a01fc..662493d 100644 --- a/src/engram/_serialization/_parsers.py +++ b/src/engram/_serialization/_parsers.py @@ -30,9 +30,9 @@ def parse_memory(data: dict[str, Any]) -> Memory: created_at=data["created_at"], updated_at=data["updated_at"], user_id=data.get("user_id"), - conversation_id=data.get("conversation_id"), tags=data.get("tags"), score=data.get("score"), + properties=data.get("properties"), ) diff --git a/tests/test_client_async.py b/tests/test_client_async.py index 64d35f8..e77e1b3 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -14,6 +14,7 @@ StringInput, ToolCallFuncInput, ToolCallInput, + Topic, ) from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient from engram.errors import APIError, AuthenticationError, ValidationError @@ -130,7 +131,7 @@ async def test_add_conversation() -> None: result = await client.memories.add( [{"role": "user", "content": "hi"}], user_id="u1", - conversation_id="c1", + properties={"conversation_id": "c1"}, ) assert result.run_id == "r3" @@ -233,7 +234,7 @@ async def test_add_conversation_content() -> None: result = await client.memories.add( ConversationInput(messages=[MessageInput(role="user", content="hi")]), user_id="u1", - conversation_id="c1", + properties={"conversation_id": "c1"}, ) assert result.run_id == "r5" @@ -262,7 +263,7 @@ def handler(request: httpx.Request) -> httpx.Response: ], metadata={"session_id": "s1"}, ), - conversation_id="c1", + properties={"conversation_id": "c1"}, ) body = json.loads(captured[0].content) conv = body["input"]["conversation"] @@ -270,7 +271,7 @@ def handler(request: httpx.Request) -> httpx.Response: assert conv["messages"][1]["tool_calls"] == [ {"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}} ] - assert body["conversation_id"] == "c1" + assert body["properties"] == {"conversation_id": "c1"} # ── memories.get ──────────────────────────────────────────────────────── @@ -371,6 +372,45 @@ def handler(request: httpx.Request) -> httpx.Response: assert body["retrieval_config"]["limit"] == 5 +# ── properties / list ─────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_add_sends_properties() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + await client.memories.add( + "hello", + properties={"region": "eu"}, + ) + body = json.loads(captured[0].content) + assert body["properties"] == {"region": "eu"} + + +@pytest.mark.asyncio +async def test_search_sends_topic_filters() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + client = _make_client_with_handler(handler) + await client.memories.search( + query="q", + topics=[Topic(name="t1", properties={"region": "eu"})], + properties={"tier": "pro"}, + ) + body = json.loads(captured[0].content) + assert body["topics"] == [{"name": "t1", "properties": {"region": "eu"}}] + assert body["properties"] == {"tier": "pro"} + + # ── runs.get ──────────────────────────────────────────────────────────── diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py index 5e666c9..682823d 100644 --- a/tests/test_client_sync.py +++ b/tests/test_client_sync.py @@ -14,6 +14,7 @@ StringInput, ToolCallFuncInput, ToolCallInput, + Topic, ) from engram.client import DEFAULT_BASE_URL, EngramClient from engram.errors import APIError, AuthenticationError, ValidationError @@ -130,7 +131,7 @@ def test_add_conversation() -> None: result = client.memories.add( [{"role": "user", "content": "hi"}], user_id="u1", - conversation_id="c1", + properties={"conversation_id": "c1"}, ) assert result.run_id == "r3" @@ -161,11 +162,11 @@ def handler(request: httpx.Request) -> httpx.Response: client = _make_client_with_handler(handler) messages = [{"role": "user", "content": "hi"}] - client.memories.add(messages, conversation_id="c1") + client.memories.add(messages, properties={"conversation_id": "c1"}) body = json.loads(captured[0].content) assert body == { "input": {"conversation": {"messages": messages}}, - "conversation_id": "c1", + "properties": {"conversation_id": "c1"}, } @@ -244,7 +245,7 @@ def test_add_conversation_content() -> None: result = client.memories.add( ConversationInput(messages=[MessageInput(role="user", content="hi")]), user_id="u1", - conversation_id="c1", + properties={"conversation_id": "c1"}, ) assert result.run_id == "r5" @@ -272,7 +273,7 @@ def handler(request: httpx.Request) -> httpx.Response: ], metadata={"session_id": "s1"}, ), - conversation_id="c1", + properties={"conversation_id": "c1"}, ) body = json.loads(captured[0].content) conv = body["input"]["conversation"] @@ -280,7 +281,7 @@ def handler(request: httpx.Request) -> httpx.Response: assert conv["messages"][1]["tool_calls"] == [ {"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}} ] - assert body["conversation_id"] == "c1" + assert body["properties"] == {"conversation_id": "c1"} # ── memories.get ──────────────────────────────────────────────────────── @@ -388,6 +389,62 @@ def handler(request: httpx.Request) -> httpx.Response: assert "retrieval_config" not in body +# ── properties support ────────────────────────────────────────────────── + + +def test_add_sends_properties() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + client.memories.add( + "hello", + user_id="u1", + properties={"region": "eu", "tier": "pro"}, + ) + body = json.loads(captured[0].content) + assert body["properties"] == {"region": "eu", "tier": "pro"} + + +def test_search_sends_properties_and_topic_filters() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + client = _make_client_with_handler(handler) + client.memories.search( + query="q", + topics=[ + "plain", + Topic(name="scoped", properties={"region": "eu"}), + Topic(name="cleared", properties={"region": None}), + ], + properties={"tier": "pro"}, + ) + body = json.loads(captured[0].content) + assert body["properties"] == {"tier": "pro"} + assert body["topics"] == [ + "plain", + {"name": "scoped", "properties": {"region": "eu"}}, + {"name": "cleared", "properties": {"region": None}}, + ] + + +def test_get_memory_returns_properties() -> None: + response_body = { + **SAMPLE_MEMORY_RESPONSE, + "properties": {"region": "eu", "tier": "pro"}, + } + client = _make_client(body=response_body) + mem = client.memories.get("m1") + assert mem.properties == {"region": "eu", "tier": "pro"} + + # ── runs.get ──────────────────────────────────────────────────────────── diff --git a/tests/test_imports.py b/tests/test_imports.py index fb29e1f..3453f89 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -23,6 +23,7 @@ def test_public_imports() -> None: ToolCallCustomInput, ToolCallFuncInput, ToolCallInput, + Topic, ValidationError, ) @@ -48,6 +49,7 @@ def test_public_imports() -> None: assert isinstance(ToolCallCustomInput, type) assert isinstance(ToolCallFuncInput, type) assert isinstance(ToolCallInput, type) + assert isinstance(Topic, type) expected_exports = { "APIError", @@ -72,6 +74,7 @@ def test_public_imports() -> None: "ToolCallCustomInput", "ToolCallFuncInput", "ToolCallInput", + "Topic", "ValidationError", "__version__", } diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 6328d81..98283cb 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -8,6 +8,7 @@ ToolCallCustomInput, ToolCallFuncInput, ToolCallInput, + Topic, ) from engram._serialization import ( build_add_body, @@ -26,7 +27,6 @@ def test_build_add_body_str() -> None: body = build_add_body( "hello world", user_id=None, - conversation_id=None, group=None, ) assert body == {"input": {"string": {"content": ["hello world"]}}} @@ -36,13 +36,11 @@ def test_build_add_body_str_with_options() -> None: body = build_add_body( "hello", user_id="u1", - conversation_id="c1", group="g1", ) assert body == { "input": {"string": {"content": ["hello"]}}, "user_id": "u1", - "conversation_id": "c1", "group": "g1", } @@ -51,7 +49,6 @@ def test_build_add_body_pre_extracted() -> None: body = build_add_body( PreExtractedInput(items=[PreExtractedItem(content="fact", topic="topic")]), user_id=None, - conversation_id=None, group=None, ) assert body == { @@ -67,13 +64,11 @@ def test_build_add_body_conversation() -> None: body = build_add_body( messages, user_id="u1", - conversation_id="c1", group=None, ) assert body == { "input": {"conversation": {"messages": messages}}, "user_id": "u1", - "conversation_id": "c1", } @@ -81,7 +76,6 @@ def test_build_add_body_string_content() -> None: body = build_add_body( StringInput(content="hello world"), user_id=None, - conversation_id=None, group=None, ) assert body == {"input": {"string": {"content": ["hello world"]}}} @@ -91,13 +85,11 @@ def test_build_add_body_string_content_with_options() -> None: body = build_add_body( StringInput(content="hello"), user_id="u1", - conversation_id="c1", group="g1", ) assert body == { "input": {"string": {"content": ["hello"]}}, "user_id": "u1", - "conversation_id": "c1", "group": "g1", } @@ -110,7 +102,6 @@ def test_build_add_body_conversation_content() -> None: body = build_add_body( ConversationInput(messages=messages), user_id="u1", - conversation_id="c1", group=None, ) assert body == { @@ -123,7 +114,6 @@ def test_build_add_body_conversation_content() -> None: }, }, "user_id": "u1", - "conversation_id": "c1", } @@ -137,7 +127,6 @@ def test_build_add_body_conversation_content_with_metadata() -> None: updated_at="2024-01-02T00:00:00Z", ), user_id=None, - conversation_id=None, group=None, ) conv = body["input"]["conversation"] @@ -151,7 +140,6 @@ def test_build_add_body_conversation_content_with_message_timestamps() -> None: body = build_add_body( ConversationInput(messages=messages), user_id=None, - conversation_id=None, group=None, ) msg = body["input"]["conversation"]["messages"][0] @@ -173,7 +161,6 @@ def test_build_add_body_conversation_content_with_tool_calls() -> None: body = build_add_body( ConversationInput(messages=messages), user_id=None, - conversation_id=None, group=None, ) msg = body["input"]["conversation"]["messages"][0] @@ -198,7 +185,6 @@ def test_build_add_body_conversation_content_with_custom_tool_calls() -> None: body = build_add_body( ConversationInput(messages=messages), user_id=None, - conversation_id=None, group=None, ) msg = body["input"]["conversation"]["messages"][0] @@ -212,7 +198,6 @@ def test_build_add_body_conversation_content_with_tool_role() -> None: body = build_add_body( ConversationInput(messages=messages), user_id=None, - conversation_id=None, group=None, ) msg = body["input"]["conversation"]["messages"][0] @@ -227,7 +212,6 @@ def test_build_add_body_conversation_content_with_developer_role() -> None: body = build_add_body( ConversationInput(messages=messages), user_id=None, - conversation_id=None, group=None, ) msg = body["input"]["conversation"]["messages"][0] @@ -259,7 +243,6 @@ def test_build_search_body_defaults() -> None: query="test", topics=None, user_id=None, - conversation_id=None, group=None, retrieval_config=None, ) @@ -271,7 +254,6 @@ def test_build_search_body_full() -> None: query="test", topics=["a", "b"], user_id="u1", - conversation_id="c1", group="g1", retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5), ) @@ -281,6 +263,77 @@ def test_build_search_body_full() -> None: assert body["retrieval_config"]["limit"] == 5 +# ── properties on add ─────────────────────────────────────────────────── + + +def test_build_add_body_with_properties() -> None: + body = build_add_body( + "hello", + user_id=None, + group=None, + properties={"region": "eu", "tier": "pro"}, + ) + assert body == { + "input": {"string": {"content": ["hello"]}}, + "properties": {"region": "eu", "tier": "pro"}, + } + + +def test_build_add_body_properties_none_omitted() -> None: + body = build_add_body( + "hello", + user_id=None, + group=None, + properties=None, + ) + assert "properties" not in body + + +# ── properties + topic filters on search ──────────────────────────────── + + +def test_build_search_body_with_properties() -> None: + body = build_search_body( + query="q", + topics=None, + user_id=None, + group=None, + retrieval_config=None, + properties={"region": "eu"}, + ) + assert body == {"query": "q", "properties": {"region": "eu"}} + + +def test_build_search_body_with_topic_filter() -> None: + body = build_search_body( + query="q", + topics=[ + "plain", + Topic(name="scoped", properties={"region": "eu"}), + Topic(name="cleared", properties={"region": None}), + ], + user_id=None, + group=None, + retrieval_config=None, + ) + assert body["topics"] == [ + "plain", + {"name": "scoped", "properties": {"region": "eu"}}, + {"name": "cleared", "properties": {"region": None}}, + ] + + +def test_build_search_body_topic_filter_without_properties() -> None: + body = build_search_body( + query="q", + topics=[Topic(name="t1")], + user_id=None, + group=None, + retrieval_config=None, + ) + assert body["topics"] == [{"name": "t1"}] + + # ── parse_run ─────────────────────────────────────────────────────────── @@ -322,14 +375,15 @@ def test_parse_memory_with_optional_fields() -> None: data = { **SAMPLE_MEMORY, "user_id": "u1", - "conversation_id": "c1", "tags": ["x"], "score": 0.95, + "properties": {"region": "eu", "conversation_id": "c1"}, } mem = parse_memory(data) assert mem.user_id == "u1" assert mem.tags == ["x"] assert mem.score == 0.95 + assert mem.properties == {"region": "eu", "conversation_id": "c1"} # ── parse_search_results ────────────────────────────────────────────────