From d83398af9ff02e77665d3c6934b3daf294bcb0a5 Mon Sep 17 00:00:00 2001 From: dannyjameswilliams Date: Wed, 17 Jun 2026 16:19:33 +0100 Subject: [PATCH 1/5] feat: added parsing of structured outputs --- weaviate_agents/query/classes/response.py | 2 +- weaviate_agents/query/query_agent.py | 121 ++++++++++++++++++++-- weaviate_agents/utils.py | 22 +++- 3 files changed, 134 insertions(+), 11 deletions(-) diff --git a/weaviate_agents/query/classes/response.py b/weaviate_agents/query/classes/response.py index 61c99b2..2e6c035 100644 --- a/weaviate_agents/query/classes/response.py +++ b/weaviate_agents/query/classes/response.py @@ -470,7 +470,7 @@ class AskModeResponse(BaseModel): total_time: float is_partial_answer: Union[bool, None] missing_information: Union[list[str], None] - final_answer: str + final_answer: Union[str, BaseModel] sources: Union[list[Source], None] def display(self) -> None: diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index 7f0756c..82daecb 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -14,6 +14,7 @@ import httpx from httpx_sse import ServerSentEvent, aconnect_sse, connect_sse +from pydantic import BaseModel from typing_extensions import deprecated from weaviate.client import WeaviateAsyncClient, WeaviateClient @@ -60,6 +61,7 @@ def _prepare_request_body( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, context: Optional[QueryAgentResponse] = None, + output_format: Union[type[str], type[BaseModel]] = str, **kwargs, ) -> dict: """Prepare the request body for the query. @@ -68,6 +70,7 @@ def _prepare_request_body( query: The natural language query string for the agent. collections: The collections to query. Will override any collections if passed in the constructor. context: Optional previous response from the agent. + output_format: The format of the output to return. Either a `str` (default) or a `BaseModel` subclass. **kwargs: Additional keyword arguments to pass to the request body. """ collections = collections or self._collections @@ -79,6 +82,11 @@ def _prepare_request_body( if isinstance(query, str) else ConversationContext(messages=query).model_dump(mode="json") ) + if isinstance(output_format, type) and issubclass(output_format, BaseModel): + output_format_json = output_format.model_json_schema() + else: + output_format_json = None + output = { "query": query_request, "collections": [ @@ -91,6 +99,7 @@ def _prepare_request_body( ], "headers": self._connection.additional_headers, "system_prompt": self._system_prompt, + "output_format": output_format_json, **kwargs, } if context is not None: @@ -145,6 +154,15 @@ def _prepare_research_mode_request_body( return output + def _parse_ask_result( + self, response: dict[str, Any], output_format: Union[type[str], type[BaseModel]] + ) -> AskModeResponse: + if isinstance(output_format, type) and issubclass(output_format, BaseModel): + response["final_answer"] = output_format.model_validate_json( + response["final_answer"] + ) + return AskModeResponse(**response) + @deprecated( "QueryAgent.run() is deprecated and will be removed in a future release. " "Use QueryAgent.ask() instead." @@ -277,6 +295,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -290,6 +309,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], @@ -303,6 +323,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Union[ Generator[Union[StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[StreamedTokens, AskModeResponse], None], @@ -316,6 +337,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Union[ Generator[StreamedTokens, None, None], AsyncGenerator[StreamedTokens, None], @@ -329,6 +351,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -582,6 +605,7 @@ def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> AskModeResponse: """Run the Query Agent ask mode. @@ -596,6 +620,12 @@ def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. + output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. + This defines the type of the `final_answer` field in the response. + The LLM will conform to the output format specified. + If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. + If a `str` is provided, it will be returned as a string in the response. + Defaults to `str`. Returns: An instance of :class:`~weaviate_agents.query.classes.response.AskModeResponse` which contains the final answer, sources, @@ -608,9 +638,28 @@ def ask( ... collections=["FinancialContracts"], ... ) >>> agent.ask("What are the terms of the contract signed by John Smith in May 2025?") + + >>> from weaviate_agents import QueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = QueryAgent( + ... client=client, + ... collections=["FinancialContracts"], + ... ) + >>> agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) """ request_body = self._prepare_request_body( - query=query, collections=collections, result_evaluation=result_evaluation + query=query, + collections=collections, + result_evaluation=result_evaluation, + output_format=output_format, ) response = httpx.post( @@ -623,7 +672,9 @@ def ask( if response.is_error: raise Exception(response.text) - return AskModeResponse(**response.json()) + return self._parse_ask_result( + response=response.json(), output_format=output_format + ) @overload def stream( @@ -718,6 +769,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Generator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None ]: ... @@ -730,6 +782,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... @overload @@ -740,6 +793,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Generator[Union[StreamedTokens, AskModeResponse], None, None]: ... @overload @@ -750,6 +804,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> Generator[StreamedTokens, None, None]: ... def ask_stream( @@ -759,6 +814,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ): """Run the Query Agent ask mode and stream the response. @@ -773,6 +829,12 @@ def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. + output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. + This defines the type of the `final_answer` field in the response. + The LLM will conform to the output format specified. + If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. + If a `str` is provided, it will be returned as a string in the response. + Defaults to `str`. Returns: A generator of the response stream. @@ -803,6 +865,7 @@ def ask_stream( include_progress=include_progress, include_final_state=include_final_state, result_evaluation=result_evaluation, + output_format=output_format, ) with httpx.Client() as client: with connect_sse( @@ -824,7 +887,10 @@ def ask_stream( yield output elif isinstance(output, AskModeResponse): if include_final_state: - yield output + yield self._parse_ask_result( + response=output.model_dump(mode="json"), + output_format=output_format, + ) else: yield output @@ -1227,6 +1293,7 @@ async def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> AskModeResponse: """Run the Query Agent ask mode. @@ -1240,6 +1307,12 @@ async def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. + output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. + This defines the type of the `final_answer` field in the response. + The LLM will conform to the output format specified. + If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. + If a `str` is provided, it will be returned as a string in the response. + Defaults to `str`. Returns: An instance of :class:`~weaviate_agents.query.classes.response.AskModeResponse` which contains the final answer, sources, @@ -1252,9 +1325,28 @@ async def ask( ... collections=["FinancialContracts"], ... ) >>> await agent.ask("What are the terms of the contract signed by John Smith in May 2025?") + + >>> from weaviate_agents import QueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = QueryAgent( + ... client=client, + ... collections=["FinancialContracts"], + ... ) + >>> await agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) """ request_body = self._prepare_request_body( - query=query, collections=collections, result_evaluation=result_evaluation + query=query, + collections=collections, + result_evaluation=result_evaluation, + output_format=output_format, ) async with httpx.AsyncClient() as client: @@ -1268,7 +1360,9 @@ async def ask( if response.is_error: raise Exception(response.text) - return AskModeResponse(**response.json()) + return self._parse_ask_result( + response=response.json(), output_format=output_format + ) @overload def stream( @@ -1364,6 +1458,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> AsyncGenerator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None ]: ... @@ -1376,6 +1471,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... @overload @@ -1386,6 +1482,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> AsyncGenerator[Union[StreamedTokens, AskModeResponse], None]: ... @overload @@ -1396,6 +1493,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ) -> AsyncGenerator[StreamedTokens, None]: ... async def ask_stream( @@ -1405,6 +1503,7 @@ async def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], type[BaseModel]] = str, ): """Run the Query Agent ask mode and stream the response. @@ -1419,6 +1518,12 @@ async def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. + output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. + This defines the type of the `final_answer` field in the response. + The LLM will conform to the output format specified. + If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. + If a `str` is provided, it will be returned as a string in the response. + Defaults to `str`. Returns: A generator of the response stream. @@ -1449,6 +1554,7 @@ async def ask_stream( include_progress=include_progress, include_final_state=include_final_state, result_evaluation=result_evaluation, + output_format=output_format, ) async with httpx.AsyncClient() as client: async with aconnect_sse( @@ -1470,7 +1576,10 @@ async def ask_stream( yield output elif isinstance(output, AskModeResponse): if include_final_state: - yield output + yield self._parse_ask_result( + response=output.model_dump(mode="json"), + output_format=output_format, + ) else: yield output diff --git a/weaviate_agents/utils.py b/weaviate_agents/utils.py index 51e760d..5e6eaad 100644 --- a/weaviate_agents/utils.py +++ b/weaviate_agents/utils.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.pretty import Pretty @@ -89,11 +90,24 @@ def print_query_agent_response(response: "QueryAgentResponse"): def print_ask_mode_response(response: "AskModeResponse"): """Prints a formatted response from the Ask Mode using rich.""" - console.print( - Panel( - response.final_answer, title="💬 Ask Mode Response", style="cyan", padding=1 + if isinstance(response.final_answer, BaseModel): + console.print( + Panel( + Pretty(response.final_answer), + title="💬 Ask Mode Response", + style="cyan", + padding=1, + ) + ) + else: + console.print( + Panel( + response.final_answer, + title="💬 Ask Mode Response", + style="cyan", + padding=1, + ) ) - ) for i, result in enumerate(response.searches): search_content = Pretty(result) From 97d467afd2aa4fdd6722bb5474a73e980d0409ae Mon Sep 17 00:00:00 2001 From: dannyjameswilliams Date: Thu, 18 Jun 2026 10:31:28 +0100 Subject: [PATCH 2/5] output_format can be dict, updated parse_ask_result, typing of final answer --- weaviate_agents/query/classes/response.py | 2 +- weaviate_agents/query/query_agent.py | 157 ++++++++++++++++------ weaviate_agents/utils.py | 4 +- 3 files changed, 120 insertions(+), 43 deletions(-) diff --git a/weaviate_agents/query/classes/response.py b/weaviate_agents/query/classes/response.py index 2e6c035..8ffd1f6 100644 --- a/weaviate_agents/query/classes/response.py +++ b/weaviate_agents/query/classes/response.py @@ -470,7 +470,7 @@ class AskModeResponse(BaseModel): total_time: float is_partial_answer: Union[bool, None] missing_information: Union[list[str], None] - final_answer: Union[str, BaseModel] + final_answer: Union[str, dict, BaseModel] sources: Union[list[Source], None] def display(self) -> None: diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index 82daecb..5e66a30 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -1,5 +1,6 @@ +import warnings from abc import ABC, abstractmethod -from json import JSONDecodeError +from json import JSONDecodeError, loads from typing import ( Any, AsyncGenerator, @@ -61,7 +62,7 @@ def _prepare_request_body( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, context: Optional[QueryAgentResponse] = None, - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, **kwargs, ) -> dict: """Prepare the request body for the query. @@ -70,7 +71,7 @@ def _prepare_request_body( query: The natural language query string for the agent. collections: The collections to query. Will override any collections if passed in the constructor. context: Optional previous response from the agent. - output_format: The format of the output to return. Either a `str` (default) or a `BaseModel` subclass. + output_format: The format of the output to return. Either a `str` (default), a `BaseModel` subclass, or a `dict` (a Draft 2020-12 JSON Schema). **kwargs: Additional keyword arguments to pass to the request body. """ collections = collections or self._collections @@ -84,6 +85,8 @@ def _prepare_request_body( ) if isinstance(output_format, type) and issubclass(output_format, BaseModel): output_format_json = output_format.model_json_schema() + elif isinstance(output_format, dict): + output_format_json = output_format else: output_format_json = None @@ -155,12 +158,21 @@ def _prepare_research_mode_request_body( return output def _parse_ask_result( - self, response: dict[str, Any], output_format: Union[type[str], type[BaseModel]] + self, + response: dict[str, Any], + output_format: Union[type[str], dict[str, Any], type[BaseModel]], ) -> AskModeResponse: if isinstance(output_format, type) and issubclass(output_format, BaseModel): response["final_answer"] = output_format.model_validate_json( response["final_answer"] ) + elif isinstance(output_format, dict): + try: + response["final_answer"] = loads(response["final_answer"]) + except JSONDecodeError: + warnings.warn( + "Unable to decode final answer as dictionary, returning as string" + ) return AskModeResponse(**response) @deprecated( @@ -191,6 +203,7 @@ def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Union[AskModeResponse, Coroutine[Any, Any, AskModeResponse]]: pass @@ -295,7 +308,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -309,7 +322,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], @@ -323,7 +336,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Union[ Generator[Union[StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[StreamedTokens, AskModeResponse], None], @@ -337,7 +350,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Union[ Generator[StreamedTokens, None, None], AsyncGenerator[StreamedTokens, None], @@ -351,7 +364,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -605,7 +618,7 @@ def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> AskModeResponse: """Run the Query Agent ask mode. @@ -620,11 +633,11 @@ def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. - This defines the type of the `final_answer` field in the response. + output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. - If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. - If a `str` is provided, it will be returned as a string in the response. + The `final_answer` output field in the response will also be of the type specified. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. Defaults to `str`. Returns: @@ -653,7 +666,9 @@ def ask( ... client=client, ... collections=["FinancialContracts"], ... ) - >>> agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) + >>> result = agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) + >>> print(type(result.final_answer)) + """ request_body = self._prepare_request_body( query=query, @@ -769,7 +784,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Generator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None ]: ... @@ -782,7 +797,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... @overload @@ -793,7 +808,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Generator[Union[StreamedTokens, AskModeResponse], None, None]: ... @overload @@ -804,7 +819,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> Generator[StreamedTokens, None, None]: ... def ask_stream( @@ -814,7 +829,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ): """Run the Query Agent ask mode and stream the response. @@ -829,11 +844,15 @@ def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. - This defines the type of the `final_answer` field in the response. + output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. - If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. - If a `str` is provided, it will be returned as a string in the response. + The `final_answer` output field in the response will also be of the type specified. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. + Whilst streaming, the :class:`~weaviate_agents.query.classes.response.StreamedTokens` will return delta text + tokens on the final answer as it is being constructed as raw string tokens, not a JSON object. + When the final answer is complete, the :class:`~weaviate_agents.query.classes.response.AskModeResponse` will be parsed + and the `final_answer` field will be of the type specified. Defaults to `str`. Returns: @@ -858,6 +877,31 @@ def ask_stream( ... print(result.delta, end='', flush=True) ... elif isinstance(result, ProgressMessage): ... print(result.message) + + >>> from weaviate_agents import QueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = QueryAgent( + ... client=client, + ... collections=["FinancialContracts"], + ... ) + >>> for result in agent.ask_stream( + ... "What contracts were signed by Jane Doe in 2024? What were they about?", + ... output_format=AnswerWithSources + ... ): + ... if isinstance(result, AskModeResponse): + ... result.display() + ... elif isinstance(result, StreamedTokens): + ... print(result.delta, end='', flush=True) + ... elif isinstance(result, ProgressMessage): + ... print(result.message) """ request_body = self._prepare_request_body( query=query, @@ -1293,7 +1337,7 @@ async def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> AskModeResponse: """Run the Query Agent ask mode. @@ -1307,11 +1351,11 @@ async def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. - This defines the type of the `final_answer` field in the response. + output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. - If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. - If a `str` is provided, it will be returned as a string in the response. + The `final_answer` output field in the response will also be of the type specified. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. Defaults to `str`. Returns: @@ -1326,7 +1370,7 @@ async def ask( ... ) >>> await agent.ask("What are the terms of the contract signed by John Smith in May 2025?") - >>> from weaviate_agents import QueryAgent + >>> from weaviate_agents import AsyncQueryAgent >>> from pydantic import BaseModel, Field >>> >>> class CitedText(BaseModel): @@ -1336,11 +1380,13 @@ async def ask( >>> class AnswerWithSources(BaseModel): ... texts: list[CitedText] >>> - >>> agent = QueryAgent( + >>> agent = AsyncQueryAgent( ... client=client, ... collections=["FinancialContracts"], ... ) - >>> await agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) + >>> result = await agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) + >>> print(type(result.final_answer)) + """ request_body = self._prepare_request_body( query=query, @@ -1458,7 +1504,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> AsyncGenerator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None ]: ... @@ -1471,7 +1517,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... @overload @@ -1482,7 +1528,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> AsyncGenerator[Union[StreamedTokens, AskModeResponse], None]: ... @overload @@ -1493,7 +1539,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ) -> AsyncGenerator[StreamedTokens, None]: ... async def ask_stream( @@ -1503,7 +1549,7 @@ async def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], type[BaseModel]] = str, + output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, ): """Run the Query Agent ask mode and stream the response. @@ -1518,11 +1564,15 @@ async def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default) or a `BaseModel` subclass. - This defines the type of the `final_answer` field in the response. + output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. - If a `BaseModel` subclass is provided, it will be serialized to a JSON object and returned in the response. - If a `str` is provided, it will be returned as a string in the response. + The `final_answer` output field in the response will also be of the type specified. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. + Whilst streaming, the :class:`~weaviate_agents.query.classes.response.StreamedTokens` will return delta text + tokens on the final answer as it is being constructed as raw string tokens, not a JSON object. + When the final answer is complete, the :class:`~weaviate_agents.query.classes.response.AskModeResponse` will be parsed + and the `final_answer` field will be of the type specified. Defaults to `str`. Returns: @@ -1547,6 +1597,31 @@ async def ask_stream( ... print(result.delta, end='', flush=True) ... elif isinstance(result, ProgressMessage): ... print(result.message) + + >>> from weaviate_agents import QueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = QueryAgent( + ... client=client, + ... collections=["FinancialContracts"], + ... ) + >>> async for result in agent.ask_stream( + ... "What contracts were signed by Jane Doe in 2024? What were they about?", + ... output_format=AnswerWithSources + ... ): + ... if isinstance(result, AskModeResponse): + ... result.display() + ... elif isinstance(result, StreamedTokens): + ... print(result.delta, end='', flush=True) + ... elif isinstance(result, ProgressMessage): + ... print(result.message) """ request_body = self._prepare_request_body( query=query, diff --git a/weaviate_agents/utils.py b/weaviate_agents/utils.py index 5e6eaad..cee7429 100644 --- a/weaviate_agents/utils.py +++ b/weaviate_agents/utils.py @@ -90,7 +90,9 @@ def print_query_agent_response(response: "QueryAgentResponse"): def print_ask_mode_response(response: "AskModeResponse"): """Prints a formatted response from the Ask Mode using rich.""" - if isinstance(response.final_answer, BaseModel): + if isinstance(response.final_answer, BaseModel) or isinstance( + response.final_answer, dict + ): console.print( Panel( Pretty(response.final_answer), From 38269e8ad02db7285dcf129b49fe3cc10eed5f99 Mon Sep 17 00:00:00 2001 From: dannyjameswilliams Date: Fri, 19 Jun 2026 12:14:26 +0100 Subject: [PATCH 3/5] overload city --- weaviate_agents/query/classes/__init__.py | 2 + weaviate_agents/query/classes/response.py | 11 +- weaviate_agents/query/query_agent.py | 474 +++++++++++++++++++++- weaviate_agents/utils.py | 19 +- 4 files changed, 478 insertions(+), 28 deletions(-) diff --git a/weaviate_agents/query/classes/__init__.py b/weaviate_agents/query/classes/__init__.py index 1f9c3f1..1bc05ef 100644 --- a/weaviate_agents/query/classes/__init__.py +++ b/weaviate_agents/query/classes/__init__.py @@ -24,6 +24,7 @@ IsNullPropertyFilter, ModelUnitUsage, NumericMetrics, + ParsedAskModeResponse, ProgressDetails, ProgressMessage, QueryAgentResponse, @@ -79,6 +80,7 @@ "GeoPropertyFilter", "UnknownPropertyAggregation", "UnknownPropertyFilter", + "ParsedAskModeResponse", "ProgressDetails", "ProgressMessage", "QueryWithCollection", diff --git a/weaviate_agents/query/classes/response.py b/weaviate_agents/query/classes/response.py index 8ffd1f6..5291050 100644 --- a/weaviate_agents/query/classes/response.py +++ b/weaviate_agents/query/classes/response.py @@ -15,7 +15,7 @@ ) from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator from typing_extensions import TypedDict from weaviate.outputs.query import QueryReturn @@ -470,7 +470,7 @@ class AskModeResponse(BaseModel): total_time: float is_partial_answer: Union[bool, None] missing_information: Union[list[str], None] - final_answer: Union[str, dict, BaseModel] + final_answer: str sources: Union[list[Source], None] def display(self) -> None: @@ -479,6 +479,13 @@ def display(self) -> None: return None +T = TypeVar("T", bound=Union[dict, BaseModel]) + + +class ParsedAskModeResponse(AskModeResponse, Generic[T]): + final_answer_parsed: SerializeAsAny[T] + + class ResearchModeResponse(BaseModel): output_type: Literal["final_state"] = "final_state" diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index 5e66a30..3f0ba3e 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -9,6 +9,7 @@ Generic, Literal, Optional, + TypeVar, Union, overload, ) @@ -22,6 +23,7 @@ from weaviate_agents.base import ClientType, _BaseAgent from weaviate_agents.query.classes import ( AskModeResponse, + ParsedAskModeResponse, ProgressMessage, QueryAgentCollectionConfig, QueryAgentResponse, @@ -38,6 +40,10 @@ SearchModeResponse, ) +# Bound to BaseModel so a `output_format=MyModel` call flows the concrete model +# type through to `ParsedAskModeResponse[MyModel].final_answer_parsed`. +M = TypeVar("M", bound=BaseModel) + class _BaseQueryAgent(Generic[ClientType], _BaseAgent[ClientType], ABC): def __init__( @@ -161,18 +167,25 @@ def _parse_ask_result( self, response: dict[str, Any], output_format: Union[type[str], dict[str, Any], type[BaseModel]], - ) -> AskModeResponse: + ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: + # Not overloaded: callers pass the broad union, and the precise return + # type (ParsedAskModeResponse[M]) is conveyed by the public ask overloads. if isinstance(output_format, type) and issubclass(output_format, BaseModel): - response["final_answer"] = output_format.model_validate_json( + response["final_answer_parsed"] = output_format.model_validate_json( response["final_answer"] ) + return ParsedAskModeResponse[BaseModel](**response) + elif isinstance(output_format, dict): try: - response["final_answer"] = loads(response["final_answer"]) + response["final_answer_parsed"] = loads(response["final_answer"]) except JSONDecodeError: warnings.warn( "Unable to decode final answer as dictionary, returning as string" ) + response["final_answer_parsed"] = response["final_answer"] + return ParsedAskModeResponse[dict](**response) + return AskModeResponse(**response) @deprecated( @@ -197,6 +210,39 @@ def run( """ pass + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Union[ + ParsedAskModeResponse[M], Coroutine[Any, Any, ParsedAskModeResponse[M]] + ]: ... + + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + ParsedAskModeResponse[dict], Coroutine[Any, Any, ParsedAskModeResponse[dict]] + ]: ... + + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, + ) -> Union[AskModeResponse, Coroutine[Any, Any, AskModeResponse]]: ... + @abstractmethod def ask( self, @@ -204,7 +250,11 @@ def ask( collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, - ) -> Union[AskModeResponse, Coroutine[Any, Any, AskModeResponse]]: + ) -> Union[ + AskModeResponse, + ParsedAskModeResponse[Any], + Coroutine[Any, Any, Union[AskModeResponse, ParsedAskModeResponse[Any]]], + ]: pass @overload @@ -308,7 +358,47 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Union[ + Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None, None + ], + AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None + ], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], + None, + None, + ], + AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], None + ], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -322,7 +412,37 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Union[ + Generator[Union[ProgressMessage, StreamedTokens], None, None], + AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + Generator[Union[ProgressMessage, StreamedTokens], None, None], + AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], @@ -336,7 +456,37 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Union[ + Generator[Union[StreamedTokens, ParsedAskModeResponse[M]], None, None], + AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[M]], None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + Generator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None, None], + AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Union[ Generator[Union[StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[StreamedTokens, AskModeResponse], None], @@ -350,7 +500,37 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Union[ + Generator[StreamedTokens, None, None], + AsyncGenerator[StreamedTokens, None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + Generator[StreamedTokens, None, None], + AsyncGenerator[StreamedTokens, None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Union[ Generator[StreamedTokens, None, None], AsyncGenerator[StreamedTokens, None], @@ -613,13 +793,42 @@ def run( return QueryAgentResponse(**response.json()) + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> ParsedAskModeResponse[M]: ... + + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> ParsedAskModeResponse[dict]: ... + + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, + ) -> AskModeResponse: ... + def ask( self, query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, - ) -> AskModeResponse: + ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: """Run the Query Agent ask mode. Perform an agentic search on the collections and return a natural language answer to the query. @@ -784,7 +993,35 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None, None + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], None, None + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Generator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None ]: ... @@ -797,7 +1034,31 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... @overload @@ -808,7 +1069,31 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Generator[Union[StreamedTokens, ParsedAskModeResponse[M]], None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Generator[Union[StreamedTokens, AskModeResponse], None, None]: ... @overload @@ -819,7 +1104,31 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> Generator[StreamedTokens, None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[StreamedTokens, None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> Generator[StreamedTokens, None, None]: ... def ask_stream( @@ -1332,13 +1641,42 @@ async def run( return QueryAgentResponse(**response.json()) + @overload + async def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> ParsedAskModeResponse[M]: ... + + @overload + async def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> ParsedAskModeResponse[dict]: ... + + @overload + async def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, + ) -> AskModeResponse: ... + async def ask( self, query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, - ) -> AskModeResponse: + ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: """Run the Query Agent ask mode. Perform an agentic search on the collections and return a natural language answer to the query. @@ -1504,7 +1842,35 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], None + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> AsyncGenerator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None ]: ... @@ -1517,7 +1883,31 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... @overload @@ -1528,7 +1918,31 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[M]], None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> AsyncGenerator[Union[StreamedTokens, AskModeResponse], None]: ... @overload @@ -1539,7 +1953,31 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + *, + output_format: type[M], + ) -> AsyncGenerator[StreamedTokens, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[StreamedTokens, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: type[str] = str, ) -> AsyncGenerator[StreamedTokens, None]: ... async def ask_stream( diff --git a/weaviate_agents/utils.py b/weaviate_agents/utils.py index cee7429..5c2f263 100644 --- a/weaviate_agents/utils.py +++ b/weaviate_agents/utils.py @@ -1,13 +1,16 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union -from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.pretty import Pretty from rich.table import Table if TYPE_CHECKING: - from weaviate_agents.query.classes import AskModeResponse, QueryAgentResponse + from weaviate_agents.query.classes import ( + AskModeResponse, + ParsedAskModeResponse, + QueryAgentResponse, + ) console = Console() @@ -88,14 +91,14 @@ def print_query_agent_response(response: "QueryAgentResponse"): ) -def print_ask_mode_response(response: "AskModeResponse"): +def print_ask_mode_response( + response: Union["AskModeResponse", "ParsedAskModeResponse"], +): """Prints a formatted response from the Ask Mode using rich.""" - if isinstance(response.final_answer, BaseModel) or isinstance( - response.final_answer, dict - ): + if hasattr(response, "final_answer_parsed"): console.print( Panel( - Pretty(response.final_answer), + Pretty(response.final_answer_parsed), title="💬 Ask Mode Response", style="cyan", padding=1, From 7c4de2cafb87b7ec71f94fe9be8c8cf95588b94e Mon Sep 17 00:00:00 2001 From: dannyjameswilliams Date: Fri, 19 Jun 2026 12:27:53 +0100 Subject: [PATCH 4/5] change default from str to None --- weaviate_agents/query/query_agent.py | 57 ++++++++++++++-------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index 3f0ba3e..226d03c 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -68,7 +68,7 @@ def _prepare_request_body( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, context: Optional[QueryAgentResponse] = None, - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + output_format: Union[None, dict[str, Any], type[BaseModel]] = None, **kwargs, ) -> dict: """Prepare the request body for the query. @@ -77,7 +77,7 @@ def _prepare_request_body( query: The natural language query string for the agent. collections: The collections to query. Will override any collections if passed in the constructor. context: Optional previous response from the agent. - output_format: The format of the output to return. Either a `str` (default), a `BaseModel` subclass, or a `dict` (a Draft 2020-12 JSON Schema). + output_format: The format of the output to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a `dict` (a Draft 2020-12 JSON Schema). **kwargs: Additional keyword arguments to pass to the request body. """ collections = collections or self._collections @@ -89,6 +89,7 @@ def _prepare_request_body( if isinstance(query, str) else ConversationContext(messages=query).model_dump(mode="json") ) + if isinstance(output_format, type) and issubclass(output_format, BaseModel): output_format_json = output_format.model_json_schema() elif isinstance(output_format, dict): @@ -166,7 +167,7 @@ def _prepare_research_mode_request_body( def _parse_ask_result( self, response: dict[str, Any], - output_format: Union[type[str], dict[str, Any], type[BaseModel]], + output_format: Union[dict[str, Any], type[BaseModel], None], ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: # Not overloaded: callers pass the broad union, and the precise return # type (ParsedAskModeResponse[M]) is conveyed by the public ask overloads. @@ -240,7 +241,7 @@ def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Union[AskModeResponse, Coroutine[Any, Any, AskModeResponse]]: ... @abstractmethod @@ -249,7 +250,7 @@ def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[ AskModeResponse, ParsedAskModeResponse[Any], @@ -398,7 +399,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -442,7 +443,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], @@ -486,7 +487,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Union[ Generator[Union[StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[StreamedTokens, AskModeResponse], None], @@ -530,7 +531,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Union[ Generator[StreamedTokens, None, None], AsyncGenerator[StreamedTokens, None], @@ -544,7 +545,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -819,7 +820,7 @@ def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> AskModeResponse: ... def ask( @@ -827,7 +828,7 @@ def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: """Run the Query Agent ask mode. @@ -842,7 +843,7 @@ def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. The `final_answer` output field in the response will also be of the type specified. @@ -1021,7 +1022,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Generator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None ]: ... @@ -1058,7 +1059,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... @overload @@ -1093,7 +1094,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Generator[Union[StreamedTokens, AskModeResponse], None, None]: ... @overload @@ -1128,7 +1129,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> Generator[StreamedTokens, None, None]: ... def ask_stream( @@ -1138,7 +1139,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ): """Run the Query Agent ask mode and stream the response. @@ -1153,7 +1154,7 @@ def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. The `final_answer` output field in the response will also be of the type specified. @@ -1667,7 +1668,7 @@ async def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> AskModeResponse: ... async def ask( @@ -1675,7 +1676,7 @@ async def ask( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: """Run the Query Agent ask mode. @@ -1689,7 +1690,7 @@ async def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. The `final_answer` output field in the response will also be of the type specified. @@ -1870,7 +1871,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> AsyncGenerator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None ]: ... @@ -1907,7 +1908,7 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... @overload @@ -1942,7 +1943,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> AsyncGenerator[Union[StreamedTokens, AskModeResponse], None]: ... @overload @@ -1977,7 +1978,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", - output_format: type[str] = str, + output_format: None = None, ) -> AsyncGenerator[StreamedTokens, None]: ... async def ask_stream( @@ -1987,7 +1988,7 @@ async def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", - output_format: Union[type[str], dict[str, Any], type[BaseModel]] = str, + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ): """Run the Query Agent ask mode and stream the response. @@ -2002,7 +2003,7 @@ async def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. - output_format: The structured output format to return. Either a `str` (default), a `BaseModel` subclass, or a dictionary. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. This enforces the output format of the final answer to be of this schema. The LLM will conform to the output format specified. The `final_answer` output field in the response will also be of the type specified. From 20be91ee26d2ab2b0e5fa416c70a2ec495535d4d Mon Sep 17 00:00:00 2001 From: dannyjameswilliams Date: Tue, 23 Jun 2026 17:18:02 +0100 Subject: [PATCH 5/5] refactor: move parse_ask_response to inside parse_sse --- weaviate_agents/query/query_agent.py | 106 +++++++++++++++------------ 1 file changed, 60 insertions(+), 46 deletions(-) diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index 226d03c..3ede129 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -164,31 +164,6 @@ def _prepare_research_mode_request_body( return output - def _parse_ask_result( - self, - response: dict[str, Any], - output_format: Union[dict[str, Any], type[BaseModel], None], - ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: - # Not overloaded: callers pass the broad union, and the precise return - # type (ParsedAskModeResponse[M]) is conveyed by the public ask overloads. - if isinstance(output_format, type) and issubclass(output_format, BaseModel): - response["final_answer_parsed"] = output_format.model_validate_json( - response["final_answer"] - ) - return ParsedAskModeResponse[BaseModel](**response) - - elif isinstance(output_format, dict): - try: - response["final_answer_parsed"] = loads(response["final_answer"]) - except JSONDecodeError: - warnings.warn( - "Unable to decode final answer as dictionary, returning as string" - ) - response["final_answer_parsed"] = response["final_answer"] - return ParsedAskModeResponse[dict](**response) - - return AskModeResponse(**response) - @deprecated( "QueryAgent.run() is deprecated and will be removed in a future release. " "Use QueryAgent.ask() instead." @@ -897,9 +872,7 @@ def ask( if response.is_error: raise Exception(response.text) - return self._parse_ask_result( - response=response.json(), output_format=output_format - ) + return _parse_ask_result(response=response.json(), output_format=output_format) @overload def stream( @@ -1235,16 +1208,13 @@ def ask_stream( raise Exception(events.response.text) for sse in events.iter_sse(): - output = _parse_sse(sse, mode="ask") + output = _parse_sse(sse, mode="ask", output_format=output_format) if isinstance(output, ProgressMessage): if include_progress: yield output elif isinstance(output, AskModeResponse): if include_final_state: - yield self._parse_ask_result( - response=output.model_dump(mode="json"), - output_format=output_format, - ) + yield output else: yield output @@ -1745,7 +1715,7 @@ async def ask( if response.is_error: raise Exception(response.text) - return self._parse_ask_result( + return _parse_ask_result( response=response.json(), output_format=output_format ) @@ -2084,16 +2054,13 @@ async def ask_stream( raise Exception(events.response.text) async for sse in events.aiter_sse(): - output = _parse_sse(sse, mode="ask") + output = _parse_sse(sse, mode="ask", output_format=output_format) if isinstance(output, ProgressMessage): if include_progress: yield output elif isinstance(output, AskModeResponse): if include_final_state: - yield self._parse_ask_result( - response=output.model_dump(mode="json"), - output_format=output_format, - ) + yield output else: yield output @@ -2429,30 +2396,49 @@ async def suggest_queries( @overload def _parse_sse( - sse: ServerSentEvent, mode: Literal["query"] + sse: ServerSentEvent, + mode: Literal["query"], + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[ProgressMessage, StreamedTokens, QueryAgentResponse]: ... @overload def _parse_sse( - sse: ServerSentEvent, mode: Literal["ask"] -) -> Union[ProgressMessage, StreamedTokens, AskModeResponse]: ... + sse: ServerSentEvent, + mode: Literal["research"], + output_format: Union[dict[str, Any], type[BaseModel], None] = None, +) -> Union[ProgressMessage, StreamedThoughts, StreamedTokens, ResearchModeResponse]: ... @overload def _parse_sse( - sse: ServerSentEvent, mode: Literal["research"] -) -> Union[ProgressMessage, StreamedThoughts, StreamedTokens, ResearchModeResponse]: ... + sse: ServerSentEvent, mode: Literal["ask"], output_format: dict[str, Any] +) -> Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]]: ... +@overload +def _parse_sse( + sse: ServerSentEvent, mode: Literal["ask"], output_format: type[M] +) -> Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]]: ... + + +@overload def _parse_sse( - sse: ServerSentEvent, mode: Literal["query", "ask", "research"] + sse: ServerSentEvent, mode: Literal["ask"], output_format: None +) -> Union[ProgressMessage, StreamedTokens, AskModeResponse]: ... + + +def _parse_sse( + sse: ServerSentEvent, + mode: Literal["query", "ask", "research"], + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[ ProgressMessage, StreamedThoughts, StreamedTokens, QueryAgentResponse, AskModeResponse, + ParsedAskModeResponse[Any], ResearchModeResponse, ]: try: @@ -2472,10 +2458,38 @@ def _parse_sse( if mode == "query": return QueryAgentResponse.model_validate(data) elif mode == "ask": - return AskModeResponse.model_validate(data) + return _parse_ask_result( + response=data, + output_format=output_format, + ) elif mode == "research": return ResearchModeResponse.model_validate(data) else: raise Exception( f"Unrecognised event type in response: {sse.event=}, {sse.data=}" ) + + +def _parse_ask_result( + response: dict[str, Any], + output_format: Union[dict[str, Any], type[BaseModel], None], +) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: + # Not overloaded: callers pass the broad union, and the precise return + # type (ParsedAskModeResponse[M]) is conveyed by the public ask overloads. + if isinstance(output_format, type) and issubclass(output_format, BaseModel): + response["final_answer_parsed"] = output_format.model_validate_json( + response["final_answer"] + ) + return ParsedAskModeResponse[BaseModel](**response) + + elif isinstance(output_format, dict): + try: + response["final_answer_parsed"] = loads(response["final_answer"]) + except JSONDecodeError: + warnings.warn( + "Unable to decode final answer as dictionary, returning as string" + ) + response["final_answer_parsed"] = response["final_answer"] + return ParsedAskModeResponse[dict](**response) + + return AskModeResponse(**response)