diff --git a/test/query_agent/test_query_model.py b/test/query_agent/test_query_model.py index 3bb0c76..65b13a9 100644 --- a/test/query_agent/test_query_model.py +++ b/test/query_agent/test_query_model.py @@ -2119,3 +2119,97 @@ async def test_async_suggest_queries_failure(monkeypatch): str(exc_info.value) == "{'error': {'message': 'Test error message', 'code': 'test_error_code', 'details': {'info': 'test detail'}}}" ) + + +def test_suggest_queries_with_conversation(monkeypatch): + captured = {} + + def fake_post_with_capture(url, headers=None, json=None, timeout=None): + captured["json"] = json + return FakeResponse(200, FAKE_SUGGEST_QUERIES_SUCCESS_JSON) + + monkeypatch.setattr(httpx, "post", fake_post_with_capture) + dummy_client = DummyClient() + agent = QueryAgent( + dummy_client, ["test_collection"], agents_host="http://dummy-agent" + ) + agent._connection = dummy_client + agent._headers = dummy_client.additional_headers + + chat_messages: list[ChatMessage] = [ + {"role": "user", "content": "What topics are covered?"}, + {"role": "assistant", "content": "The collection covers ML and economics."}, + ] + + result = agent.suggest_queries(["test_collection"], conversation=chat_messages) + + assert isinstance(result, SuggestQueryResponse) + assert captured["json"]["conversation_context"] == {"messages": chat_messages} + + +def test_suggest_queries_without_conversation(monkeypatch): + captured = {} + + def fake_post_with_capture(url, headers=None, json=None, timeout=None): + captured["json"] = json + return FakeResponse(200, FAKE_SUGGEST_QUERIES_SUCCESS_JSON) + + monkeypatch.setattr(httpx, "post", fake_post_with_capture) + dummy_client = DummyClient() + agent = QueryAgent( + dummy_client, ["test_collection"], agents_host="http://dummy-agent" + ) + agent._connection = dummy_client + agent._headers = dummy_client.additional_headers + + agent.suggest_queries(["test_collection"]) + + assert "conversation_context" not in captured["json"] + + +async def test_async_suggest_queries_with_conversation(monkeypatch): + captured = {} + + async def fake_async_post_with_capture(*args, **kwargs): + captured["json"] = kwargs.get("json") + return FakeResponse(200, FAKE_SUGGEST_QUERIES_SUCCESS_JSON) + + monkeypatch.setattr(httpx.AsyncClient, "post", fake_async_post_with_capture) + dummy_client = DummyClient() + agent = AsyncQueryAgent( + dummy_client, ["test_collection"], agents_host="http://dummy-agent" + ) + agent._connection = dummy_client + agent._headers = dummy_client.additional_headers + + chat_messages: list[ChatMessage] = [ + {"role": "user", "content": "What topics are covered?"}, + {"role": "assistant", "content": "The collection covers ML and economics."}, + ] + + result = await agent.suggest_queries( + ["test_collection"], conversation=chat_messages + ) + + assert isinstance(result, SuggestQueryResponse) + assert captured["json"]["conversation_context"] == {"messages": chat_messages} + + +async def test_async_suggest_queries_without_conversation(monkeypatch): + captured = {} + + async def fake_async_post_with_capture(*args, **kwargs): + captured["json"] = kwargs.get("json") + return FakeResponse(200, FAKE_SUGGEST_QUERIES_SUCCESS_JSON) + + monkeypatch.setattr(httpx.AsyncClient, "post", fake_async_post_with_capture) + dummy_client = DummyClient() + agent = AsyncQueryAgent( + dummy_client, ["test_collection"], agents_host="http://dummy-agent" + ) + agent._connection = dummy_client + agent._headers = dummy_client.additional_headers + + await agent.suggest_queries(["test_collection"]) + + assert "conversation_context" not in captured["json"] diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index d3e9a64..5797c24 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -508,6 +508,7 @@ def suggest_queries( collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, num_queries: int = 3, instructions: Optional[str] = None, + conversation: Optional[list[ChatMessage]] = None, ) -> Union[SuggestQueryResponse, Coroutine[Any, Any, SuggestQueryResponse]]: pass @@ -1074,6 +1075,7 @@ def suggest_queries( collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, num_queries: int = 3, instructions: Optional[str] = None, + conversation: Optional[list[ChatMessage]] = None, ) -> SuggestQueryResponse: """Suggest queries for the data in your collections. @@ -1090,6 +1092,10 @@ def suggest_queries( instructions: Optional natural language guidance for the style, topic, or language of the suggested queries. This is supplied in addition to the agent's system instructions. + conversation: + Optional list of chat messages representing a prior conversation. + When provided, the suggested queries will be contextualised as + follow-up questions to the conversation. Returns: An instance of :class:`~weaviate_agents.query.classes.response.SuggestQueryResponse` which @@ -1136,6 +1142,10 @@ def suggest_queries( } if instructions is not None: request_body["instructions"] = instructions + if conversation is not None: + request_body["conversation_context"] = ConversationContext( + messages=conversation + ).model_dump(mode="json") response = httpx.post( self.query_url + "/suggest_queries", @@ -1714,6 +1724,7 @@ async def suggest_queries( collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, num_queries: int = 3, instructions: Optional[str] = None, + conversation: Optional[list[ChatMessage]] = None, ) -> SuggestQueryResponse: """Suggest queries for the data in your collections. @@ -1730,6 +1741,10 @@ async def suggest_queries( instructions: Optional natural language guidance for the style, topic, or language of the suggested queries. This is supplied in addition to the agent's system instructions. + conversation: + Optional list of chat messages representing a prior conversation. + When provided, the suggested queries will be contextualised as + follow-up questions to the conversation. Returns: An instance of :class:`~weaviate_agents.query.classes.response.SuggestQueryResponse` which @@ -1776,6 +1791,10 @@ async def suggest_queries( } if instructions is not None: request_body["instructions"] = instructions + if conversation is not None: + request_body["conversation_context"] = ConversationContext( + messages=conversation + ).model_dump(mode="json") async with httpx.AsyncClient() as client: response = await client.post(