From 5a479d547b0a61d23b73506afd3405338fc096dc Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Tue, 26 May 2026 16:23:10 -0400 Subject: [PATCH] first look --- test/query_agent/test_query_model.py | 75 ++++++++++++++++++++++++ weaviate_agents/query/classes/request.py | 1 + weaviate_agents/query/query_agent.py | 10 ++++ weaviate_agents/query/search.py | 4 ++ 4 files changed, 90 insertions(+) diff --git a/test/query_agent/test_query_model.py b/test/query_agent/test_query_model.py index 3bb0c76..7eabb18 100644 --- a/test/query_agent/test_query_model.py +++ b/test/query_agent/test_query_model.py @@ -797,6 +797,54 @@ def fake_post_with_capture(url, headers=None, json=None, timeout=None): assert captured["json"]["diversity_weight"] is None +def test_search_only_mode_with_explain_score(monkeypatch): + captured = {} + + def fake_post_with_capture(url, headers=None, json=None, timeout=None): + captured["json"] = json + return fake_post_search_only_success() + + monkeypatch.setattr(httpx, "post", fake_post_with_capture) + dummy_client = DummyClient() + agent = QueryAgent( + dummy_client, ["test_collection"], agents_host="http://dummy-agent" + ) + agent._connection = dummy_client + agent._headers = dummy_client.additional_headers + + # Test with explain_score set + results = agent.search("test query", limit=2, explain_score=True) + assert isinstance(results, SearchModeResponse) + assert captured["json"]["explain_score"] is True + + # Reset captured json, then paginate — explain_score should persist + captured = {} + results_2 = results.next(limit=2, offset=1) + assert isinstance(results_2, SearchModeResponse) + assert captured["json"]["explain_score"] is True + + +def test_search_only_mode_without_explain_score(monkeypatch): + captured = {} + + def fake_post_with_capture(url, headers=None, json=None, timeout=None): + captured["json"] = json + return fake_post_search_only_success() + + monkeypatch.setattr(httpx, "post", fake_post_with_capture) + dummy_client = DummyClient() + agent = QueryAgent( + dummy_client, ["test_collection"], agents_host="http://dummy-agent" + ) + agent._connection = dummy_client + agent._headers = dummy_client.additional_headers + + # Test without explain_score — should default to False + results = agent.search("test query", limit=2) + assert isinstance(results, SearchModeResponse) + assert captured["json"]["explain_score"] is False + + def test_search_only_mode_failure(monkeypatch): monkeypatch.setattr(httpx, "post", fake_post_failure) dummy_client = DummyClient() @@ -893,6 +941,33 @@ async def fake_post_with_capture(self, url, headers=None, json=None, timeout=Non assert captured["json"]["diversity_weight"] == 0.7 +async def test_async_search_only_mode_with_explain_score(monkeypatch): + captured = {} + + async def fake_post_with_capture(self, url, headers=None, json=None, timeout=None): + captured["json"] = json + return await fake_async_post_search_only_success() + + monkeypatch.setattr(httpx.AsyncClient, "post", fake_post_with_capture) + dummy_client = DummyClient() + agent = AsyncQueryAgent( + dummy_client, ["test_collection"], agents_host="http://dummy-agent" + ) + agent._connection = dummy_client + agent._headers = dummy_client.additional_headers + + # Test with explain_score set + results = await agent.search("test query", limit=2, explain_score=True) + assert isinstance(results, AsyncSearchModeResponse) + assert captured["json"]["explain_score"] is True + + # Reset captured json, then paginate — explain_score should persist + captured = {} + results_2 = await results.next(limit=2, offset=1) + assert isinstance(results_2, AsyncSearchModeResponse) + assert captured["json"]["explain_score"] is True + + async def test_async_search_only_mode_failure(monkeypatch): monkeypatch.setattr(httpx.AsyncClient, "post", fake_async_post_failure) dummy_client = DummyClient() diff --git a/weaviate_agents/query/classes/request.py b/weaviate_agents/query/classes/request.py index aee9ee7..539458d 100644 --- a/weaviate_agents/query/classes/request.py +++ b/weaviate_agents/query/classes/request.py @@ -18,6 +18,7 @@ class SearchModeRequestBase(BaseModel): limit: int offset: int diversity_weight: Optional[float] = None + explain_score: bool = False class SearchModeExecutionRequest(SearchModeRequestBase): diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index d3e9a64..cfd86d6 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -1014,6 +1014,7 @@ def search( limit: int = 20, collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, diversity_weight: Optional[float] = None, + explain_score: bool = False, ) -> SearchModeResponse: """Run the Query Agent search-only mode. @@ -1031,6 +1032,9 @@ def search( results with MMR reranking. Higher values push for more topical variety at the cost of relevance. Defaults to None (no diversity). + explain_score: If True, each result will include a short natural-language + explanation of why it is or isn't relevant to the query. + Defaults to False. Returns: An instance of :class:`~weaviate_agents.query.classes.response.SearchModeResponse` for the first page of results. Use @@ -1066,6 +1070,7 @@ def search( collections=collections, system_prompt=self._system_prompt, diversity_weight=diversity_weight, + explain_score=explain_score, ) return searcher.run(limit=limit) @@ -1653,6 +1658,7 @@ async def search( limit: int = 20, collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, diversity_weight: Optional[float] = None, + explain_score: bool = False, ) -> AsyncSearchModeResponse: """Run the Query Agent search-only mode. @@ -1671,6 +1677,9 @@ async def search( results with MMR reranking. Higher values push for more topical variety at the cost of relevance. Defaults to None (no diversity). + explain_score: If True, each result will include a short natural-language + explanation of why it is or isn't relevant to the query. + Defaults to False. Returns: An instance of :class:`~weaviate_agents.query.classes.response.AsyncSearchModeResponse` for the first page of results. Use @@ -1706,6 +1715,7 @@ async def search( collections=collections, system_prompt=self._system_prompt, diversity_weight=diversity_weight, + explain_score=explain_score, ) return await searcher.run(limit=limit) diff --git a/weaviate_agents/query/search.py b/weaviate_agents/query/search.py index 77c2365..0642148 100644 --- a/weaviate_agents/query/search.py +++ b/weaviate_agents/query/search.py @@ -36,6 +36,7 @@ def __init__( collections: list[Union[str, QueryAgentCollectionConfig]], system_prompt: Optional[str], diversity_weight: Optional[float] = None, + explain_score: bool = False, ): self.headers = headers self.connection_headers = connection_headers @@ -45,6 +46,7 @@ def __init__( self.collections = collections self.system_prompt = system_prompt self.diversity_weight = diversity_weight + self.explain_score = explain_score self._cached_searches: Optional[list[QueryResultWithCollectionNormalized]] = ( None ) @@ -64,6 +66,7 @@ def _get_request_body(self, limit: int, offset: int) -> dict[str, Any]: offset=offset, system_prompt=self.system_prompt, diversity_weight=self.diversity_weight, + explain_score=self.explain_score, ).model_dump(mode="json") else: return SearchModeExecutionRequest( @@ -74,6 +77,7 @@ def _get_request_body(self, limit: int, offset: int) -> dict[str, Any]: offset=offset, searches=self._cached_searches, diversity_weight=self.diversity_weight, + explain_score=self.explain_score, ).model_dump(mode="json")