Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/engram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
)
from .async_client import AsyncEngramClient
from .client import EngramClient
Expand Down Expand Up @@ -50,6 +51,7 @@
"ToolCallCustomInput",
"ToolCallFuncInput",
"ToolCallInput",
"Topic",
"ValidationError",
"__version__",
]
4 changes: 4 additions & 0 deletions src/engram/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
TopicSelector,
)
from .run import CommittedOperation, CommittedOperations, Run, RunStatus

Expand All @@ -31,4 +33,6 @@
"ToolCallCustomInput",
"ToolCallFuncInput",
"ToolCallInput",
"Topic",
"TopicSelector",
]
17 changes: 16 additions & 1 deletion src/engram/_models/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ class RetrievalConfig:
limit: int | None = None


@dataclass(slots=True)
class Topic:
"""A topic with an optional per-topic property filter.

Use ``None`` as a property value to clear an inherited global filter
for this topic only.
"""

name: str
properties: dict[str, str | None] | None = None


TopicSelector: TypeAlias = str | Topic


@dataclass(slots=True)
class Memory:
id: str
Expand All @@ -106,9 +121,9 @@ class Memory:
created_at: str
updated_at: str
user_id: str | None = None
conversation_id: str | None = None
tags: list[str] | None = None
score: float | None = None
properties: dict[str, str] | None = None


class SearchResults(Sequence[Memory]):
Expand Down
29 changes: 18 additions & 11 deletions src/engram/_resources/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from uuid import UUID

from .._http import AsyncHttpTransport, HttpTransport
from .._models import AddInput, Memory, RetrievalConfig, Run, SearchResults
from .._models import (
AddInput,
Memory,
RetrievalConfig,
Run,
SearchResults,
TopicSelector,
)
from .._serialization import (
build_add_body,
build_memory_params,
Expand Down Expand Up @@ -32,14 +39,14 @@ def add(
input_data: AddInput,
*,
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
properties: dict[str, str] | None = None,
) -> Run:
body = build_add_body(
input_data,
user_id=user_id,
conversation_id=conversation_id,
group=group,
properties=properties,
)
data = self._transport.request("POST", _MEMORIES_PATH, json=body)
return parse_run(data)
Expand Down Expand Up @@ -75,19 +82,19 @@ def search(
self,
*,
query: str,
topics: list[str] | None = None,
topics: list[TopicSelector] | None = None,
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
retrieval_config: RetrievalConfig | None = None,
properties: dict[str, str] | None = None,
) -> SearchResults:
body = build_search_body(
query=query,
topics=topics,
user_id=user_id,
conversation_id=conversation_id,
group=group,
retrieval_config=retrieval_config,
properties=properties,
)
data = self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body)
return parse_search_results(data)
Expand All @@ -104,14 +111,14 @@ async def add(
input_data: AddInput,
*,
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
properties: dict[str, str] | None = None,
) -> Run:
body = build_add_body(
input_data,
user_id=user_id,
conversation_id=conversation_id,
group=group,
properties=properties,
)
data = await self._transport.request("POST", _MEMORIES_PATH, json=body)
return parse_run(data)
Expand Down Expand Up @@ -147,19 +154,19 @@ async def search(
self,
*,
query: str,
topics: list[str] | None = None,
topics: list[TopicSelector] | None = None,
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
retrieval_config: RetrievalConfig | None = None,
properties: dict[str, str] | None = None,
) -> SearchResults:
body = build_search_body(
query=query,
topics=topics,
user_id=user_id,
conversation_id=conversation_id,
group=group,
retrieval_config=retrieval_config,
properties=properties,
)
data = await self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body)
return parse_search_results(data)
43 changes: 33 additions & 10 deletions src/engram/_serialization/_builders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, TypeAlias

from .._models import (
AddInput,
Expand All @@ -9,6 +9,8 @@
RetrievalConfig,
StringInput,
ToolCallInput,
Topic,
TopicSelector,
)


Expand Down Expand Up @@ -65,20 +67,40 @@ def _serialize_conversation_content(content: ConversationInput) -> dict[str, Any
return {"conversation": conversation}


_SerializedTopic: TypeAlias = str | dict[str, str | dict[str, str | None]]


def _serialize_topic(topic: TopicSelector) -> _SerializedTopic:
if isinstance(topic, str):
return topic
if isinstance(topic, Topic):
out: dict[str, str | dict[str, str | None]] = {"name": topic.name}
if topic.properties is not None:
out["properties"] = dict(topic.properties)
return out
raise TypeError(f"Unsupported topic type: {type(topic)}") # pragma: no cover


def _serialize_topics(topics: list[TopicSelector] | None) -> list[_SerializedTopic] | None:
if topics is None:
return None
return [_serialize_topic(t) for t in topics]


def build_add_body(
input_data: AddInput,
*,
user_id: str | None,
conversation_id: str | None,
group: str | None,
properties: dict[str, str] | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {"input": _serialize_input(input_data)}
if user_id is not None:
body["user_id"] = user_id
if conversation_id is not None:
body["conversation_id"] = conversation_id
if group is not None:
body["group"] = group
if properties is not None:
body["properties"] = dict(properties)
return body


Expand All @@ -98,24 +120,25 @@ def build_memory_params(
def build_search_body(
*,
query: str,
topics: list[str] | None,
topics: list[TopicSelector] | None,
user_id: str | None,
conversation_id: str | None,
group: str | None,
retrieval_config: RetrievalConfig | None,
properties: dict[str, str] | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {"query": query}
if retrieval_config is not None:
body["retrieval_config"] = {
"retrieval_type": retrieval_config.retrieval_type,
"limit": retrieval_config.limit,
}
if topics is not None:
body["topics"] = topics
serialized_topics = _serialize_topics(topics)
if serialized_topics is not None:
body["topics"] = serialized_topics
if user_id is not None:
body["user_id"] = user_id
if conversation_id is not None:
body["conversation_id"] = conversation_id
if group is not None:
body["group"] = group
if properties is not None:
body["properties"] = dict(properties)
return body
2 changes: 1 addition & 1 deletion src/engram/_serialization/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def parse_memory(data: dict[str, Any]) -> Memory:
created_at=data["created_at"],
updated_at=data["updated_at"],
user_id=data.get("user_id"),
conversation_id=data.get("conversation_id"),
tags=data.get("tags"),
score=data.get("score"),
properties=data.get("properties"),
)


Expand Down
48 changes: 44 additions & 4 deletions tests/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
StringInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
)
from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient
from engram.errors import APIError, AuthenticationError, ValidationError
Expand Down Expand Up @@ -130,7 +131,7 @@ async def test_add_conversation() -> None:
result = await client.memories.add(
[{"role": "user", "content": "hi"}],
user_id="u1",
conversation_id="c1",
properties={"conversation_id": "c1"},
)
assert result.run_id == "r3"

Expand Down Expand Up @@ -233,7 +234,7 @@ async def test_add_conversation_content() -> None:
result = await client.memories.add(
ConversationInput(messages=[MessageInput(role="user", content="hi")]),
user_id="u1",
conversation_id="c1",
properties={"conversation_id": "c1"},
)
assert result.run_id == "r5"

Expand Down Expand Up @@ -262,15 +263,15 @@ def handler(request: httpx.Request) -> httpx.Response:
],
metadata={"session_id": "s1"},
),
conversation_id="c1",
properties={"conversation_id": "c1"},
)
body = json.loads(captured[0].content)
conv = body["input"]["conversation"]
assert conv["metadata"] == {"session_id": "s1"}
assert conv["messages"][1]["tool_calls"] == [
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}}
]
assert body["conversation_id"] == "c1"
assert body["properties"] == {"conversation_id": "c1"}


# ── memories.get ────────────────────────────────────────────────────────
Expand Down Expand Up @@ -371,6 +372,45 @@ def handler(request: httpx.Request) -> httpx.Response:
assert body["retrieval_config"]["limit"] == 5


# ── properties / list ───────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_add_sends_properties() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"run_id": "r1", "status": "pending"})

client = _make_client_with_handler(handler)
await client.memories.add(
"hello",
properties={"region": "eu"},
)
body = json.loads(captured[0].content)
assert body["properties"] == {"region": "eu"}


@pytest.mark.asyncio
async def test_search_sends_topic_filters() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"memories": [], "total": 0})

client = _make_client_with_handler(handler)
await client.memories.search(
query="q",
topics=[Topic(name="t1", properties={"region": "eu"})],
properties={"tier": "pro"},
)
body = json.loads(captured[0].content)
assert body["topics"] == [{"name": "t1", "properties": {"region": "eu"}}]
assert body["properties"] == {"tier": "pro"}


# ── runs.get ────────────────────────────────────────────────────────────


Expand Down
Loading
Loading