diff --git a/weaviate_agents/query/classes/__init__.py b/weaviate_agents/query/classes/__init__.py index 1f9c3f1..1bc05ef 100644 --- a/weaviate_agents/query/classes/__init__.py +++ b/weaviate_agents/query/classes/__init__.py @@ -24,6 +24,7 @@ IsNullPropertyFilter, ModelUnitUsage, NumericMetrics, + ParsedAskModeResponse, ProgressDetails, ProgressMessage, QueryAgentResponse, @@ -79,6 +80,7 @@ "GeoPropertyFilter", "UnknownPropertyAggregation", "UnknownPropertyFilter", + "ParsedAskModeResponse", "ProgressDetails", "ProgressMessage", "QueryWithCollection", diff --git a/weaviate_agents/query/classes/response.py b/weaviate_agents/query/classes/response.py index 61c99b2..5291050 100644 --- a/weaviate_agents/query/classes/response.py +++ b/weaviate_agents/query/classes/response.py @@ -15,7 +15,7 @@ ) from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator from typing_extensions import TypedDict from weaviate.outputs.query import QueryReturn @@ -479,6 +479,13 @@ def display(self) -> None: return None +T = TypeVar("T", bound=Union[dict, BaseModel]) + + +class ParsedAskModeResponse(AskModeResponse, Generic[T]): + final_answer_parsed: SerializeAsAny[T] + + class ResearchModeResponse(BaseModel): output_type: Literal["final_state"] = "final_state" diff --git a/weaviate_agents/query/query_agent.py b/weaviate_agents/query/query_agent.py index 7f0756c..3ede129 100644 --- a/weaviate_agents/query/query_agent.py +++ b/weaviate_agents/query/query_agent.py @@ -1,5 +1,6 @@ +import warnings from abc import ABC, abstractmethod -from json import JSONDecodeError +from json import JSONDecodeError, loads from typing import ( Any, AsyncGenerator, @@ -8,18 +9,21 @@ Generic, Literal, Optional, + TypeVar, Union, overload, ) import httpx from httpx_sse import ServerSentEvent, aconnect_sse, connect_sse +from pydantic import BaseModel from typing_extensions import deprecated from 
weaviate.client import WeaviateAsyncClient, WeaviateClient from weaviate_agents.base import ClientType, _BaseAgent from weaviate_agents.query.classes import ( AskModeResponse, + ParsedAskModeResponse, ProgressMessage, QueryAgentCollectionConfig, QueryAgentResponse, @@ -36,6 +40,10 @@ SearchModeResponse, ) +# Bound to BaseModel so a `output_format=MyModel` call flows the concrete model +# type through to `ParsedAskModeResponse[MyModel].final_answer_parsed`. +M = TypeVar("M", bound=BaseModel) + class _BaseQueryAgent(Generic[ClientType], _BaseAgent[ClientType], ABC): def __init__( @@ -60,6 +68,7 @@ def _prepare_request_body( query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, context: Optional[QueryAgentResponse] = None, + output_format: Union[None, dict[str, Any], type[BaseModel]] = None, **kwargs, ) -> dict: """Prepare the request body for the query. @@ -68,6 +77,7 @@ def _prepare_request_body( query: The natural language query string for the agent. collections: The collections to query. Will override any collections if passed in the constructor. context: Optional previous response from the agent. + output_format: The format of the output to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a `dict` (a Draft 2020-12 JSON Schema). **kwargs: Additional keyword arguments to pass to the request body. 
""" collections = collections or self._collections @@ -79,6 +89,14 @@ def _prepare_request_body( if isinstance(query, str) else ConversationContext(messages=query).model_dump(mode="json") ) + + if isinstance(output_format, type) and issubclass(output_format, BaseModel): + output_format_json = output_format.model_json_schema() + elif isinstance(output_format, dict): + output_format_json = output_format + else: + output_format_json = None + output = { "query": query_request, "collections": [ @@ -91,6 +109,7 @@ def _prepare_request_body( ], "headers": self._connection.additional_headers, "system_prompt": self._system_prompt, + "output_format": output_format_json, **kwargs, } if context is not None: @@ -167,13 +186,51 @@ def run( """ pass + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Union[ + ParsedAskModeResponse[M], Coroutine[Any, Any, ParsedAskModeResponse[M]] + ]: ... + + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + ParsedAskModeResponse[dict], Coroutine[Any, Any, ParsedAskModeResponse[dict]] + ]: ... + + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, + ) -> Union[AskModeResponse, Coroutine[Any, Any, AskModeResponse]]: ... 
+ @abstractmethod def ask( self, query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - ) -> Union[AskModeResponse, Coroutine[Any, Any, AskModeResponse]]: + output_format: Union[dict[str, Any], type[BaseModel], None] = None, + ) -> Union[ + AskModeResponse, + ParsedAskModeResponse[Any], + Coroutine[Any, Any, Union[AskModeResponse, ParsedAskModeResponse[Any]]], + ]: pass @overload @@ -277,6 +334,47 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Union[ + Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None, None + ], + AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None + ], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], + None, + None, + ], + AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], None + ], + ]: ... 
+ + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -290,11 +388,72 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Union[ + Generator[Union[ProgressMessage, StreamedTokens], None, None], + AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], ]: ... + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, + ) -> Union[ + Generator[Union[ProgressMessage, StreamedTokens], None, None], + AsyncGenerator[Union[ProgressMessage, StreamedTokens], None], + ]: ... 
+ + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Union[ + Generator[Union[StreamedTokens, ParsedAskModeResponse[M]], None, None], + AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[M]], None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + Generator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None, None], + AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None], + ]: ... + @overload def ask_stream( self, @@ -303,6 +462,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> Union[ Generator[Union[StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[StreamedTokens, AskModeResponse], None], @@ -316,6 +476,37 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Union[ + Generator[StreamedTokens, None, None], + AsyncGenerator[StreamedTokens, None], + ]: ... 
+ + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Union[ + Generator[StreamedTokens, None, None], + AsyncGenerator[StreamedTokens, None], + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> Union[ Generator[StreamedTokens, None, None], AsyncGenerator[StreamedTokens, None], @@ -329,6 +520,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[ Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None], AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None], @@ -577,12 +769,42 @@ def run( return QueryAgentResponse(**response.json()) + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> ParsedAskModeResponse[M]: ... + + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> ParsedAskModeResponse[dict]: ... 
+ + @overload + def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, + ) -> AskModeResponse: ... + def ask( self, query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - ) -> AskModeResponse: + output_format: Union[dict[str, Any], type[BaseModel], None] = None, + ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: """Run the Query Agent ask mode. Perform an agentic search on the collections and return a natural language answer to the query. @@ -596,6 +818,12 @@ def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. + The LLM will conform to the output format specified. + The parsed answer is returned in the `final_answer_parsed` field of the response; `final_answer` keeps the raw answer string. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. + Defaults to `None`. Returns: An instance of :class:`~weaviate_agents.query.classes.response.AskModeResponse` which contains the final answer, sources, @@ -608,9 +836,30 @@ def ask( ... collections=["FinancialContracts"], ... ) >>> agent.ask("What are the terms of the contract signed by John Smith in May 2025?") + + >>> from weaviate_agents import QueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. 
Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = QueryAgent( + ... client=client, + ... collections=["FinancialContracts"], + ... ) + >>> result = agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) + >>> print(type(result.final_answer_parsed)) + """ request_body = self._prepare_request_body( - query=query, collections=collections, result_evaluation=result_evaluation + query=query, + collections=collections, + result_evaluation=result_evaluation, + output_format=output_format, ) response = httpx.post( @@ -623,7 +872,7 @@ def ask( if response.is_error: raise Exception(response.text) - return AskModeResponse(**response.json()) + return _parse_ask_result(response=response.json(), output_format=output_format) @overload def stream( @@ -718,6 +967,35 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None, None + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], None, None + ]: ... 
+ + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> Generator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None ]: ... @@ -730,8 +1008,57 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, + ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Generator[Union[StreamedTokens, ParsedAskModeResponse[M]], None, None]: ... 
+ + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None, None]: ... + @overload def ask_stream( self, @@ -740,6 +1067,7 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> Generator[Union[StreamedTokens, AskModeResponse], None, None]: ... @overload @@ -750,6 +1078,31 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> Generator[StreamedTokens, None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> Generator[StreamedTokens, None, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> Generator[StreamedTokens, None, None]: ... 
def ask_stream( @@ -759,6 +1112,7 @@ def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ): """Run the Query Agent ask mode and stream the response. @@ -773,6 +1127,16 @@ def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. + The LLM will conform to the output format specified. + The parsed answer is returned in the `final_answer_parsed` field of the response; `final_answer` keeps the raw answer string. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. + Whilst streaming, the :class:`~weaviate_agents.query.classes.response.StreamedTokens` will return delta text + tokens on the final answer as it is being constructed as raw string tokens, not a JSON object. + When the final answer is complete, it is returned as a :class:`~weaviate_agents.query.classes.response.ParsedAskModeResponse` + whose `final_answer_parsed` field will be of the type specified. + Defaults to `None`. Returns: A generator of the response stream. @@ -796,6 +1160,31 @@ def ask_stream( ... print(result.delta, end='', flush=True) ... elif isinstance(result, ProgressMessage): ... print(result.message) + + >>> from weaviate_agents import QueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = QueryAgent( + ... 
client=client, + ... collections=["FinancialContracts"], + ... ) + >>> for result in agent.ask_stream( + ... "What contracts were signed by Jane Doe in 2024? What were they about?", + ... output_format=AnswerWithSources + ... ): + ... if isinstance(result, AskModeResponse): + ... result.display() + ... elif isinstance(result, StreamedTokens): + ... print(result.delta, end='', flush=True) + ... elif isinstance(result, ProgressMessage): + ... print(result.message) """ request_body = self._prepare_request_body( query=query, @@ -803,6 +1192,7 @@ def ask_stream( include_progress=include_progress, include_final_state=include_final_state, result_evaluation=result_evaluation, + output_format=output_format, ) with httpx.Client() as client: with connect_sse( @@ -818,7 +1208,7 @@ def ask_stream( raise Exception(events.response.text) for sse in events.iter_sse(): - output = _parse_sse(sse, mode="ask") + output = _parse_sse(sse, mode="ask", output_format=output_format) if isinstance(output, ProgressMessage): if include_progress: yield output @@ -1222,12 +1612,42 @@ async def run( return QueryAgentResponse(**response.json()) + @overload + async def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> ParsedAskModeResponse[M]: ... + + @overload + async def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> ParsedAskModeResponse[dict]: ... + + @overload + async def ask( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, + ) -> AskModeResponse: ... 
+ async def ask( self, query: Union[str, list[ChatMessage]], collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, result_evaluation: Literal["llm", "none"] = "none", - ) -> AskModeResponse: + output_format: Union[dict[str, Any], type[BaseModel], None] = None, + ) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]: """Run the Query Agent ask mode. Perform an agentic search on the collections and return a natural language answer to the query. @@ -1240,6 +1660,12 @@ async def ask( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. Defaults to ``"none"``. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. + The LLM will conform to the output format specified. + The parsed answer is returned in the `final_answer_parsed` field of the response; `final_answer` keeps the raw answer string. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. + Defaults to `None`. Returns: An instance of :class:`~weaviate_agents.query.classes.response.AskModeResponse` which contains the final answer, sources, @@ -1252,9 +1678,30 @@ async def ask( ... collections=["FinancialContracts"], ... ) >>> await agent.ask("What are the terms of the contract signed by John Smith in May 2025?") + + >>> from weaviate_agents import AsyncQueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = AsyncQueryAgent( + ... client=client, + ... collections=["FinancialContracts"], + ... 
) + >>> result = await agent.ask("What contracts were signed by Jane Doe in 2024? What were they about?", output_format=AnswerWithSources) + >>> print(type(result.final_answer_parsed)) + """ request_body = self._prepare_request_body( - query=query, collections=collections, result_evaluation=result_evaluation + query=query, + collections=collections, + result_evaluation=result_evaluation, + output_format=output_format, ) async with httpx.AsyncClient() as client: @@ -1268,7 +1715,9 @@ async def ask( if response.is_error: raise Exception(response.text) - return AskModeResponse(**response.json()) + return _parse_ask_result( + response=response.json(), output_format=output_format + ) @overload def stream( @@ -1364,6 +1813,35 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]], None + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[ + Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]], None + ]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> AsyncGenerator[ Union[ProgressMessage, StreamedTokens, AskModeResponse], None ]: ... 
@@ -1376,6 +1854,31 @@ def ask_stream( include_progress: Literal[True] = True, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[True] = True, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ... @overload @@ -1386,6 +1889,31 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[True] = True, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[M]], None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[Union[StreamedTokens, ParsedAskModeResponse[dict]], None]: ... 
+ + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[True] = True, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> AsyncGenerator[Union[StreamedTokens, AskModeResponse], None]: ... @overload @@ -1396,6 +1924,31 @@ def ask_stream( include_progress: Literal[False] = False, include_final_state: Literal[False] = False, result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: type[M], + ) -> AsyncGenerator[StreamedTokens, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + *, + output_format: dict[str, Any], + ) -> AsyncGenerator[StreamedTokens, None]: ... + + @overload + def ask_stream( + self, + query: Union[str, list[ChatMessage]], + collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None, + include_progress: Literal[False] = False, + include_final_state: Literal[False] = False, + result_evaluation: Literal["llm", "none"] = "none", + output_format: None = None, ) -> AsyncGenerator[StreamedTokens, None]: ... async def ask_stream( @@ -1405,6 +1958,7 @@ async def ask_stream( include_progress: bool = True, include_final_state: bool = True, result_evaluation: Literal["llm", "none"] = "none", + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ): """Run the Query Agent ask mode and stream the response. @@ -1419,6 +1973,16 @@ async def ask_stream( Also populates the fields `missing_information` and `is_partial_answer` of the response. If ``"none"``, the result will not be evaluated, and the sources will not be filtered. 
Defaults to ``"none"``. + output_format: The structured output format to return. Either `None` (default, no structured output), a `BaseModel` subclass, or a dictionary. + This enforces the output format of the final answer to be of this schema. + The LLM will conform to the output format specified. + The parsed answer is returned in the `final_answer_parsed` field of the response; `final_answer` keeps the raw answer string. + When passing a `dict`, the dictionary must conform to the Draft 2020-12 JSON Schema specification. + Whilst streaming, the :class:`~weaviate_agents.query.classes.response.StreamedTokens` will return delta text + tokens on the final answer as it is being constructed as raw string tokens, not a JSON object. + When the final answer is complete, it is returned as a :class:`~weaviate_agents.query.classes.response.ParsedAskModeResponse` + whose `final_answer_parsed` field will be of the type specified. + Defaults to `None`. Returns: A generator of the response stream. @@ -1442,6 +2006,31 @@ async def ask_stream( ... print(result.delta, end='', flush=True) ... elif isinstance(result, ProgressMessage): ... print(result.message) + + >>> from weaviate_agents import AsyncQueryAgent + >>> from pydantic import BaseModel, Field + >>> + >>> class CitedText(BaseModel): + ... text: str + ... sources: list[str] = Field(description="The sources that support this section of text. Can be empty.") + >>> + >>> class AnswerWithSources(BaseModel): + ... texts: list[CitedText] + >>> + >>> agent = AsyncQueryAgent( + ... client=client, + ... collections=["FinancialContracts"], + ... ) + >>> async for result in agent.ask_stream( + ... "What contracts were signed by Jane Doe in 2024? What were they about?", + ... output_format=AnswerWithSources + ... ): + ... if isinstance(result, AskModeResponse): + ... result.display() + ... elif isinstance(result, StreamedTokens): + ... print(result.delta, end='', flush=True) + ... elif isinstance(result, ProgressMessage): + ... 
print(result.message) """ request_body = self._prepare_request_body( query=query, @@ -1449,6 +2038,7 @@ async def ask_stream( include_progress=include_progress, include_final_state=include_final_state, result_evaluation=result_evaluation, + output_format=output_format, ) async with httpx.AsyncClient() as client: async with aconnect_sse( @@ -1464,7 +2054,7 @@ async def ask_stream( raise Exception(events.response.text) async for sse in events.aiter_sse(): - output = _parse_sse(sse, mode="ask") + output = _parse_sse(sse, mode="ask", output_format=output_format) if isinstance(output, ProgressMessage): if include_progress: yield output @@ -1806,30 +2396,49 @@ async def suggest_queries( @overload def _parse_sse( - sse: ServerSentEvent, mode: Literal["query"] + sse: ServerSentEvent, + mode: Literal["query"], + output_format: Union[dict[str, Any], type[BaseModel], None] = None, ) -> Union[ProgressMessage, StreamedTokens, QueryAgentResponse]: ... @overload def _parse_sse( - sse: ServerSentEvent, mode: Literal["ask"] -) -> Union[ProgressMessage, StreamedTokens, AskModeResponse]: ... + sse: ServerSentEvent, + mode: Literal["research"], + output_format: Union[dict[str, Any], type[BaseModel], None] = None, +) -> Union[ProgressMessage, StreamedThoughts, StreamedTokens, ResearchModeResponse]: ... @overload def _parse_sse( - sse: ServerSentEvent, mode: Literal["research"] -) -> Union[ProgressMessage, StreamedThoughts, StreamedTokens, ResearchModeResponse]: ... + sse: ServerSentEvent, mode: Literal["ask"], output_format: dict[str, Any] +) -> Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[dict]]: ... + + +@overload +def _parse_sse( + sse: ServerSentEvent, mode: Literal["ask"], output_format: type[M] +) -> Union[ProgressMessage, StreamedTokens, ParsedAskModeResponse[M]]: ... + + +@overload +def _parse_sse( + sse: ServerSentEvent, mode: Literal["ask"], output_format: None +) -> Union[ProgressMessage, StreamedTokens, AskModeResponse]: ... 
def _parse_ask_result(
    response: dict[str, Any],
    output_format: Union[dict[str, Any], type[BaseModel], None],
) -> Union[AskModeResponse, ParsedAskModeResponse[Any]]:
    """Build the final ask-mode response, parsing structured output if requested.

    Args:
        response: The decoded final-state payload. Its ``"final_answer"`` entry
            is read as the raw answer text. The mapping is not modified.
        output_format: ``None`` for a plain answer, a ``BaseModel`` subclass to
            validate the answer against, or a ``dict`` to have the answer
            decoded as a JSON object.

    Returns:
        A ``ParsedAskModeResponse`` carrying ``final_answer_parsed`` when the
        answer could be parsed into the requested shape; otherwise a plain
        ``AskModeResponse`` whose ``final_answer`` holds the raw string.

    Raises:
        KeyError: If structured output is requested and ``response`` has no
            ``"final_answer"`` entry.

    Note:
        With a ``BaseModel`` subclass, any error raised by
        ``model_validate_json`` propagates to the caller; only the ``dict``
        path degrades to a warning.
    """
    # Not overloaded: callers pass the broad union, and the precise return
    # type (ParsedAskModeResponse[M]) is conveyed by the public ask overloads.
    if isinstance(output_format, type) and issubclass(output_format, BaseModel):
        parsed_model = output_format.model_validate_json(response["final_answer"])
        # Build a new mapping rather than mutating the caller's dict.
        return ParsedAskModeResponse[BaseModel](
            **{**response, "final_answer_parsed": parsed_model}
        )

    if isinstance(output_format, dict):
        try:
            parsed = loads(response["final_answer"])
        except JSONDecodeError:
            parsed = None

        if isinstance(parsed, dict):
            return ParsedAskModeResponse[dict](
                **{**response, "final_answer_parsed": parsed}
            )

        # The answer is not a JSON object. Feeding a non-dict value into
        # ParsedAskModeResponse[dict] would fail validation, so fall back to
        # the plain response, which still exposes the raw string answer.
        warnings.warn(
            "Unable to decode final answer as dictionary, returning as string"
        )
        return AskModeResponse(**response)

    return AskModeResponse(**response)
Pretty(result)