diff --git a/src/models/common/__init__.py b/src/models/common/__init__.py index 4d6a3b837..797894d1c 100644 --- a/src/models/common/__init__.py +++ b/src/models/common/__init__.py @@ -17,10 +17,12 @@ from models.common.query import Attachment, SolrVectorSearchRequest from models.common.transcripts import Transcript, TranscriptMetadata from models.common.turn_summary import ( + MCPListToolsSummary, RAGChunk, RAGContext, ReferencedDocument, ToolCallSummary, + ToolInfoSummary, ToolResultSummary, TurnSummary, ) @@ -48,4 +50,6 @@ "Transcript", "TranscriptMetadata", "TurnSummary", + "ToolInfoSummary", + "MCPListToolsSummary", ] diff --git a/src/models/common/agents/__init__.py b/src/models/common/agents/__init__.py new file mode 100644 index 000000000..a6ac96da2 --- /dev/null +++ b/src/models/common/agents/__init__.py @@ -0,0 +1,39 @@ +"""Streaming payload models and event type exports.""" + +from models.common.agents.stream_payloads import ( + EndEventData, + EndStreamPayload, + ErrorEventData, + ErrorStreamPayload, + InterruptedEventData, + InterruptedStreamPayload, + StartEventData, + StartStreamPayload, + StreamEventPayload, + StreamPayloadBase, + TokenChunkData, + TokenStreamPayload, + ToolCallStreamPayload, + ToolResultStreamPayload, + TurnCompleteStreamPayload, +) +from models.common.agents.turn_accumulator import AgentTurnAccumulator + +__all__ = [ + "StreamPayloadBase", + "ErrorEventData", + "StartEventData", + "InterruptedEventData", + "EndEventData", + "ErrorStreamPayload", + "StartStreamPayload", + "InterruptedStreamPayload", + "EndStreamPayload", + "TokenChunkData", + "TokenStreamPayload", + "TurnCompleteStreamPayload", + "ToolCallStreamPayload", + "ToolResultStreamPayload", + "StreamEventPayload", + "AgentTurnAccumulator", +] diff --git a/src/models/common/agents/stream_payloads.py b/src/models/common/agents/stream_payloads.py new file mode 100644 index 000000000..cc81993ee --- /dev/null +++ b/src/models/common/agents/stream_payloads.py @@ -0,0 +1,270 @@ +"""Typed JSON bodies for SSE streaming events.""" + +import json +from typing import Annotated, Literal, Optional, Self, TypeAlias + +from pydantic import BaseModel, ConfigDict, Field + +from models.api.responses.error import AbstractErrorResponse +from models.common import ReferencedDocument, ToolCallSummary, ToolResultSummary + + +class StreamPayloadBase(BaseModel): + """Base for streaming SSE JSON payloads.""" + + model_config = ConfigDict(extra="forbid") + + def serialize_json(self) -> str: + """Format this payload as an SSE ``data:`` line.""" + return f"data: {json.dumps(self.model_dump(mode='json'))}\n\n" + + def serialize_text(self) -> str: + """Format this payload as plain text for text media type clients.""" + return "" + + +class ErrorEventData(BaseModel): + """Payload for event: "error".""" + + status_code: int + response: str + cause: str + + +class StartEventData(BaseModel): + """Payload for event: "start".""" + + conversation_id: str + request_id: str + + +class InterruptedEventData(BaseModel): + """Payload for event: "interrupted".""" + + request_id: str + + +class EndEventData(BaseModel): + """Nested data for event: "end".""" + + referenced_documents: list[ReferencedDocument] + truncated: Optional[bool] + input_tokens: int + output_tokens: int + + +class ErrorStreamPayload(StreamPayloadBase): + """SSE error event body (event + typed data).""" + + event: Literal["error"] = "error" + data: ErrorEventData + + @classmethod + def create(cls, *, status_code: int, response: str, cause: str) -> Self: + """Create an error stream payload from HTTP error fields. + + Args: + status_code: HTTP status code for the error. + response: Short summary of the error. + cause: Detailed explanation of the error cause. + + Returns: + Error stream payload instance. + """ + return cls( + data=ErrorEventData(status_code=status_code, response=response, cause=cause) + ) + + @classmethod + def from_error_response(cls, error_response: AbstractErrorResponse) -> Self: + """Create an error stream payload from a structured API error response. + + Args: + error_response: Structured error response model. + + Returns: + Error stream payload instance. + """ + return cls.create( + status_code=error_response.status_code, + response=error_response.detail.response, + cause=error_response.detail.cause, + ) + + def serialize_text(self) -> str: + """Serialize error stream payload to plain text.""" + return f"Status: {self.data.status_code} - {self.data.response} - {self.data.cause}" + + +class StartStreamPayload(StreamPayloadBase): + """SSE stream start body.""" + + event: Literal["start"] = "start" + data: StartEventData + + @classmethod + def create(cls, *, conversation_id: str, request_id: str) -> Self: + """Create a stream-start payload. + + Args: + conversation_id: Conversation identifier for the stream. + request_id: Request identifier for the stream. + + Returns: + Start stream payload instance. + """ + return cls( + data=StartEventData(conversation_id=conversation_id, request_id=request_id) + ) + + +class InterruptedStreamPayload(StreamPayloadBase): + """SSE interrupted stream body.""" + + event: Literal["interrupted"] = "interrupted" + data: InterruptedEventData + + @classmethod + def create(cls, *, request_id: str) -> Self: + """Create an interrupted-stream payload. + + Args: + request_id: Request identifier for the interrupted stream. + + Returns: + Interrupted stream payload instance. + """ + return cls(data=InterruptedEventData(request_id=request_id)) + + +class EndStreamPayload(StreamPayloadBase): + """SSE end-of-stream body (includes available_quotas beside data).""" + + event: Literal["end"] = "end" + data: EndEventData + available_quotas: dict[str, int] + + @classmethod + def create( + cls, + *, + referenced_documents: list[ReferencedDocument], + input_tokens: int, + output_tokens: int, + available_quotas: dict[str, int], + ) -> Self: + """Create an end-of-stream payload. + + Args: + referenced_documents: Documents referenced during the turn. + input_tokens: Input token count for the turn. + output_tokens: Output token count for the turn. + available_quotas: Remaining quota limits by quota name. + + Returns: + End stream payload instance. + """ + return cls( + data=EndEventData( + referenced_documents=referenced_documents, + truncated=None, + input_tokens=input_tokens, + output_tokens=output_tokens, + ), + available_quotas=available_quotas, + ) + + def serialize_text(self) -> str: + """Serialize end stream payload to plain text.""" + ref_docs_string = "\n".join( + f"{doc.doc_title}: {doc.doc_url}" + for doc in self.data.referenced_documents + if doc.doc_url and doc.doc_title + ) + return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" + + +class TokenChunkData(BaseModel): + """Structured data for token and turn-complete stream lines.""" + + id: int + token: str + + +class TokenStreamPayload(StreamPayloadBase): + """SSE token delta (event: "token").""" + + event: Literal["token"] = "token" + data: TokenChunkData + + @classmethod + def create(cls, *, chunk_id: int, token: str) -> Self: + """Create a token stream payload. + + Args: + chunk_id: Monotonic chunk identifier for the token delta. + token: Token text for the delta. + + Returns: + Token stream payload instance. + """ + return cls(data=TokenChunkData(id=chunk_id, token=token)) + + def serialize_text(self) -> str: + """Serialize token stream payload to plain text.""" + return self.data.token + + +class TurnCompleteStreamPayload(StreamPayloadBase): + """SSE turn completion (same data shape as token).""" + + event: Literal["turn_complete"] = "turn_complete" + data: TokenChunkData + + @classmethod + def create(cls, *, chunk_id: int, token: str) -> Self: + """Create a turn-complete stream payload. + + Args: + chunk_id: Monotonic chunk identifier for the final text. + token: Full assistant text for the completed turn. + + Returns: + Turn-complete stream payload instance. + """ + return cls(data=TokenChunkData(id=chunk_id, token=token)) + + +class ToolCallStreamPayload(StreamPayloadBase): + """SSE tool call summary.""" + + event: Literal["tool_call"] = "tool_call" + data: ToolCallSummary + + def serialize_text(self) -> str: + """Serialize tool call stream payload to plain text.""" + return f"[Tool Call: {self.data.name}]\n" + + +class ToolResultStreamPayload(StreamPayloadBase): + """SSE tool result summary.""" + + event: Literal["tool_result"] = "tool_result" + data: ToolResultSummary + + def serialize_text(self) -> str: + """Serialize tool result stream payload to plain text.""" + return "[Tool Result]\n" + + +StreamEventPayload: TypeAlias = Annotated[ + TokenStreamPayload + | TurnCompleteStreamPayload + | ToolCallStreamPayload + | ToolResultStreamPayload + | EndStreamPayload + | ErrorStreamPayload + | InterruptedStreamPayload + | StartStreamPayload, + Field(discriminator="event"), +] diff --git a/src/models/common/agents/turn_accumulator.py b/src/models/common/agents/turn_accumulator.py new file mode 100644 index 000000000..7c3040e6f --- /dev/null +++ b/src/models/common/agents/turn_accumulator.py @@ -0,0 +1,46 @@ +"""Mutable per-turn state for agent response processing.""" + +from dataclasses import dataclass, field +from typing import Final + +from pydantic_ai import AgentRunResult + +from models.common.turn_summary import TurnSummary + + +@dataclass(slots=True) +class AgentTurnAccumulator: # pylint: disable=too-many-instance-attributes + """Information accumulator for a single interaction turn. + + Attributes: + vector_store_ids: Vector store IDs used to resolve RAG source labels. + rag_id_mapping: Maps vector store IDs to user-facing source names. + turn_summary: Aggregated turn output (text, tools, RAG, token usage). + run_result: Agent run result (streaming only). + chunk_id: Monotonic SSE chunk index (streaming only). + text_parts: Buffered text deltas before turn_complete (streaming only). + tool_round: Current tool-call round for summary labeling. + round_increment_pending: Whether to bump tool_round on the next step. + emitted_tool_call_ids: Tool call IDs already sent or recorded. + emitted_tool_result_ids: Tool result IDs already sent or recorded. + seen_docs: Referenced-document keys already added (deduplication). + """ + + vector_store_ids: Final[list[str]] + rag_id_mapping: Final[dict[str, str]] + turn_summary: TurnSummary + run_result: AgentRunResult[str] | None = None + chunk_id: int = 0 + text_parts: list[str] = field(default_factory=list) + tool_round: int = 1 + round_increment_pending: bool = False + emitted_tool_call_ids: set[str] = field(default_factory=set) + emitted_tool_result_ids: set[str] = field(default_factory=set) + seen_docs: set[tuple[str, str]] = field(default_factory=set) + + def increment_round_if_pending(self) -> None: + """Increment tool_round if round_increment_pending is True.""" + if not self.round_increment_pending: + return + self.tool_round += 1 + self.round_increment_pending = False diff --git a/src/models/common/turn_summary.py b/src/models/common/turn_summary.py index 9bf8cb04e..c3de9e572 100644 --- a/src/models/common/turn_summary.py +++ b/src/models/common/turn_summary.py @@ -108,3 +108,29 @@ class TurnSummary(BaseModel): rag_chunks: list[RAGChunk] = Field(default_factory=list) referenced_documents: list[ReferencedDocument] = Field(default_factory=list) token_usage: TokenCounter = Field(default_factory=TokenCounter) + + +class ToolInfoSummary(BaseModel): + """Model representing metadata for a single tool exposed by MCP list tools.""" + + name: str = Field(description="Tool name") + description: Optional[str] = Field( + default=None, + description="Human-readable tool description", + ) + input_schema: Optional[dict[str, Any]] = Field( + default=None, + description="JSON schema for the tool input", + ) + + +class MCPListToolsSummary(BaseModel): + """Model representing MCP list tools payload serialized into tool results.""" + + server_label: str = Field( + description="MCP server label associated with the tool list", + ) + tools: list[ToolInfoSummary] = Field( + default_factory=list, + description="Tools exposed by the MCP server", + )