Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions test/query_agent/test_query_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,54 @@ def fake_post_with_capture(url, headers=None, json=None, timeout=None):
assert captured["json"]["diversity_weight"] is None


def test_search_only_mode_with_explain_score(monkeypatch):
captured = {}

def fake_post_with_capture(url, headers=None, json=None, timeout=None):
captured["json"] = json
return fake_post_search_only_success()

monkeypatch.setattr(httpx, "post", fake_post_with_capture)
dummy_client = DummyClient()
agent = QueryAgent(
dummy_client, ["test_collection"], agents_host="http://dummy-agent"
)
agent._connection = dummy_client
agent._headers = dummy_client.additional_headers

# Test with explain_score set
results = agent.search("test query", limit=2, explain_score=True)
assert isinstance(results, SearchModeResponse)
assert captured["json"]["explain_score"] is True

# Reset captured json, then paginate — explain_score should persist
captured = {}
results_2 = results.next(limit=2, offset=1)
assert isinstance(results_2, SearchModeResponse)
assert captured["json"]["explain_score"] is True


def test_search_only_mode_without_explain_score(monkeypatch):
captured = {}

def fake_post_with_capture(url, headers=None, json=None, timeout=None):
captured["json"] = json
return fake_post_search_only_success()

monkeypatch.setattr(httpx, "post", fake_post_with_capture)
dummy_client = DummyClient()
agent = QueryAgent(
dummy_client, ["test_collection"], agents_host="http://dummy-agent"
)
agent._connection = dummy_client
agent._headers = dummy_client.additional_headers

# Test without explain_score — should default to False
results = agent.search("test query", limit=2)
assert isinstance(results, SearchModeResponse)
assert captured["json"]["explain_score"] is False


def test_search_only_mode_failure(monkeypatch):
monkeypatch.setattr(httpx, "post", fake_post_failure)
dummy_client = DummyClient()
Expand Down Expand Up @@ -893,6 +941,33 @@ async def fake_post_with_capture(self, url, headers=None, json=None, timeout=Non
assert captured["json"]["diversity_weight"] == 0.7


async def test_async_search_only_mode_with_explain_score(monkeypatch):
captured = {}

async def fake_post_with_capture(self, url, headers=None, json=None, timeout=None):
captured["json"] = json
return await fake_async_post_search_only_success()

monkeypatch.setattr(httpx.AsyncClient, "post", fake_post_with_capture)
dummy_client = DummyClient()
agent = AsyncQueryAgent(
dummy_client, ["test_collection"], agents_host="http://dummy-agent"
)
agent._connection = dummy_client
agent._headers = dummy_client.additional_headers

# Test with explain_score set
results = await agent.search("test query", limit=2, explain_score=True)
assert isinstance(results, AsyncSearchModeResponse)
assert captured["json"]["explain_score"] is True

# Reset captured json, then paginate — explain_score should persist
captured = {}
results_2 = await results.next(limit=2, offset=1)
assert isinstance(results_2, AsyncSearchModeResponse)
assert captured["json"]["explain_score"] is True


async def test_async_search_only_mode_failure(monkeypatch):
monkeypatch.setattr(httpx.AsyncClient, "post", fake_async_post_failure)
dummy_client = DummyClient()
Expand Down
1 change: 1 addition & 0 deletions weaviate_agents/query/classes/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SearchModeRequestBase(BaseModel):
limit: int
offset: int
diversity_weight: Optional[float] = None
explain_score: bool = False


class SearchModeExecutionRequest(SearchModeRequestBase):
Expand Down
10 changes: 10 additions & 0 deletions weaviate_agents/query/query_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,7 @@ def search(
limit: int = 20,
collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
diversity_weight: Optional[float] = None,
explain_score: bool = False,
) -> SearchModeResponse:
"""Run the Query Agent search-only mode.

Expand All @@ -1031,6 +1032,9 @@ def search(
results with MMR reranking.
Higher values push for more topical variety at the cost of relevance.
Defaults to None (no diversity).
explain_score: If True, each result will include a short natural-language
explanation of why it is or isn't relevant to the query.
Defaults to False.

Returns:
An instance of :class:`~weaviate_agents.query.classes.response.SearchModeResponse` for the first page of results. Use
Expand Down Expand Up @@ -1066,6 +1070,7 @@ def search(
collections=collections,
system_prompt=self._system_prompt,
diversity_weight=diversity_weight,
explain_score=explain_score,
)
return searcher.run(limit=limit)

Expand Down Expand Up @@ -1653,6 +1658,7 @@ async def search(
limit: int = 20,
collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
diversity_weight: Optional[float] = None,
explain_score: bool = False,
) -> AsyncSearchModeResponse:
"""Run the Query Agent search-only mode.

Expand All @@ -1671,6 +1677,9 @@ async def search(
results with MMR reranking.
Higher values push for more topical variety at the cost of relevance.
Defaults to None (no diversity).
explain_score: If True, each result will include a short natural-language
explanation of why it is or isn't relevant to the query.
Defaults to False.

Returns:
An instance of :class:`~weaviate_agents.query.classes.response.AsyncSearchModeResponse` for the first page of results. Use
Expand Down Expand Up @@ -1706,6 +1715,7 @@ async def search(
collections=collections,
system_prompt=self._system_prompt,
diversity_weight=diversity_weight,
explain_score=explain_score,
)
return await searcher.run(limit=limit)

Expand Down
4 changes: 4 additions & 0 deletions weaviate_agents/query/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
collections: list[Union[str, QueryAgentCollectionConfig]],
system_prompt: Optional[str],
diversity_weight: Optional[float] = None,
explain_score: bool = False,
):
self.headers = headers
self.connection_headers = connection_headers
Expand All @@ -45,6 +46,7 @@ def __init__(
self.collections = collections
self.system_prompt = system_prompt
self.diversity_weight = diversity_weight
self.explain_score = explain_score
self._cached_searches: Optional[list[QueryResultWithCollectionNormalized]] = (
None
)
Expand All @@ -64,6 +66,7 @@ def _get_request_body(self, limit: int, offset: int) -> dict[str, Any]:
offset=offset,
system_prompt=self.system_prompt,
diversity_weight=self.diversity_weight,
explain_score=self.explain_score,
).model_dump(mode="json")
else:
return SearchModeExecutionRequest(
Expand All @@ -74,6 +77,7 @@ def _get_request_body(self, limit: int, offset: int) -> dict[str, Any]:
offset=offset,
searches=self._cached_searches,
diversity_weight=self.diversity_weight,
explain_score=self.explain_score,
).model_dump(mode="json")


Expand Down
Loading