Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/models/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from models.common.query import Attachment, SolrVectorSearchRequest
from models.common.transcripts import Transcript, TranscriptMetadata
from models.common.turn_summary import (
MCPListToolsSummary,
RAGChunk,
RAGContext,
ReferencedDocument,
ToolCallSummary,
ToolInfoSummary,
ToolResultSummary,
TurnSummary,
)
Expand Down Expand Up @@ -48,4 +50,6 @@
"Transcript",
"TranscriptMetadata",
"TurnSummary",
"ToolInfoSummary",
"MCPListToolsSummary",
]
39 changes: 39 additions & 0 deletions src/models/common/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Streaming payload models and event type exports."""

from models.common.agents.stream_payloads import (
EndEventData,
EndStreamPayload,
ErrorEventData,
ErrorStreamPayload,
InterruptedEventData,
InterruptedStreamPayload,
StartEventData,
StartStreamPayload,
StreamEventPayload,
StreamPayloadBase,
TokenChunkData,
TokenStreamPayload,
ToolCallStreamPayload,
ToolResultStreamPayload,
TurnCompleteStreamPayload,
)
from models.common.agents.turn_accumulator import AgentTurnAccumulator

# Public API of this package: each name listed here is imported at the top
# of this module and re-exported unchanged.
__all__ = [
    "StreamPayloadBase",
    "ErrorEventData",
    "StartEventData",
    "InterruptedEventData",
    "EndEventData",
    "ErrorStreamPayload",
    "StartStreamPayload",
    "InterruptedStreamPayload",
    "EndStreamPayload",
    "TokenChunkData",
    "TokenStreamPayload",
    "TurnCompleteStreamPayload",
    "ToolCallStreamPayload",
    "ToolResultStreamPayload",
    "StreamEventPayload",
    "AgentTurnAccumulator",
]
270 changes: 270 additions & 0 deletions src/models/common/agents/stream_payloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
"""Typed JSON bodies for SSE streaming events."""

import json
from typing import Annotated, Literal, Optional, Self, TypeAlias

from pydantic import BaseModel, ConfigDict, Field

from models.api.responses.error import AbstractErrorResponse
from models.common import ReferencedDocument, ToolCallSummary, ToolResultSummary


class StreamPayloadBase(BaseModel):
    """Base for streaming SSE JSON payloads."""

    # Fields that a payload model does not declare are rejected at
    # validation time rather than carried into the serialized event.
    model_config = ConfigDict(extra="forbid")

    def serialize_json(self) -> str:
        """Render this payload as a single SSE ``data:`` frame.

        Returns:
            The JSON-mode dump of the model encoded with ``json.dumps``,
            prefixed with ``data: `` and terminated by a blank line.
        """
        encoded = json.dumps(self.model_dump(mode="json"))
        return "data: " + encoded + "\n\n"

    def serialize_text(self) -> str:
        """Render this payload for plain-text media type clients.

        Returns:
            An empty string; subclasses override this method when the
            event has a plain-text representation.
        """
        return ""


class ErrorEventData(BaseModel):
    """Payload for event: "error"."""

    # HTTP status code for the error.
    status_code: int
    # Short summary of the error.
    response: str
    # Detailed explanation of the error cause.
    cause: str
Comment thread
asimurka marked this conversation as resolved.


class StartEventData(BaseModel):
    """Payload for event: "start"."""

    # Conversation identifier for the stream.
    conversation_id: str
    # Request identifier for the stream.
    request_id: str


class InterruptedEventData(BaseModel):
    """Payload for event: "interrupted"."""

    # Request identifier for the interrupted stream.
    request_id: str


class EndEventData(BaseModel):
    """Nested data for event: "end"."""

    # Documents referenced during the turn.
    referenced_documents: list[ReferencedDocument]
    # Nullable but still required: no default value is declared here.
    truncated: Optional[bool]
    # Input and output token counts for the turn.
    input_tokens: int
    output_tokens: int


class ErrorStreamPayload(StreamPayloadBase):
    """SSE error event body (event + typed data)."""

    event: Literal["error"] = "error"
    data: ErrorEventData

    @classmethod
    def create(cls, *, status_code: int, response: str, cause: str) -> Self:
        """Build an error payload from the individual HTTP error fields.

        Args:
            status_code: HTTP status code for the error.
            response: Short summary of the error.
            cause: Detailed explanation of the error cause.

        Returns:
            A payload whose ``data`` wraps the three given fields.
        """
        error_data = ErrorEventData(
            status_code=status_code,
            response=response,
            cause=cause,
        )
        return cls(data=error_data)

    @classmethod
    def from_error_response(cls, error_response: AbstractErrorResponse) -> Self:
        """Build an error payload from a structured API error response.

        Args:
            error_response: Structured error response model; its
                ``status_code``, ``detail.response`` and ``detail.cause``
                are copied into the payload.

        Returns:
            The payload produced by ``create`` for those three values.
        """
        status_code = error_response.status_code
        detail = error_response.detail
        return cls.create(
            status_code=status_code,
            response=detail.response,
            cause=detail.cause,
        )

    def serialize_text(self) -> str:
        """Render the error as one plain-text status line.

        Returns:
            ``Status: <status_code> - <response> - <cause>``.
        """
        error = self.data
        return f"Status: {error.status_code} - {error.response} - {error.cause}"


class StartStreamPayload(StreamPayloadBase):
    """SSE stream start body."""

    event: Literal["start"] = "start"
    data: StartEventData

    @classmethod
    def create(cls, *, conversation_id: str, request_id: str) -> Self:
        """Build the payload announcing the start of a stream.

        Args:
            conversation_id: Conversation identifier for the stream.
            request_id: Request identifier for the stream.

        Returns:
            A payload whose ``data`` carries both identifiers.
        """
        start_data = StartEventData(
            conversation_id=conversation_id,
            request_id=request_id,
        )
        return cls(data=start_data)


class InterruptedStreamPayload(StreamPayloadBase):
    """SSE interrupted stream body."""

    event: Literal["interrupted"] = "interrupted"
    data: InterruptedEventData

    @classmethod
    def create(cls, *, request_id: str) -> Self:
        """Build the payload signalling that a stream was interrupted.

        Args:
            request_id: Request identifier for the interrupted stream.

        Returns:
            A payload whose ``data`` carries the request identifier.
        """
        interrupted_data = InterruptedEventData(request_id=request_id)
        return cls(data=interrupted_data)


class EndStreamPayload(StreamPayloadBase):
    """SSE end-of-stream body (includes available_quotas beside data)."""

    event: Literal["end"] = "end"
    data: EndEventData
    # Serialized next to ``data`` at the top level of the event, not
    # nested inside it.
    available_quotas: dict[str, int]

    @classmethod
    def create(
        cls,
        *,
        referenced_documents: list[ReferencedDocument],
        input_tokens: int,
        output_tokens: int,
        available_quotas: dict[str, int],
    ) -> Self:
        """Build the payload that closes a stream.

        Args:
            referenced_documents: Documents referenced during the turn.
            input_tokens: Input token count for the turn.
            output_tokens: Output token count for the turn.
            available_quotas: Remaining quota limits by quota name.

        Returns:
            A payload whose ``data.truncated`` is always ``None``.
        """
        end_data = EndEventData(
            referenced_documents=referenced_documents,
            truncated=None,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
        return cls(data=end_data, available_quotas=available_quotas)

    def serialize_text(self) -> str:
        """Render the referenced documents as a plain-text footer.

        Returns:
            A ``---`` separator followed by one ``title: url`` line per
            document that has both a URL and a title, or an empty string
            when no document qualifies.
        """
        doc_lines = [
            f"{doc.doc_title}: {doc.doc_url}"
            for doc in self.data.referenced_documents
            if doc.doc_url and doc.doc_title
        ]
        if not doc_lines:
            return ""
        return "\n\n---\n\n" + "\n".join(doc_lines)


class TokenChunkData(BaseModel):
    """Structured data for token and turn-complete stream lines."""

    # Chunk identifier of the stream line.
    id: int
    # Text carried by the line: a token delta for "token" events, the full
    # assistant text for "turn_complete" events.
    token: str


class TokenStreamPayload(StreamPayloadBase):
    """SSE token delta (event: "token")."""

    event: Literal["token"] = "token"
    data: TokenChunkData

    @classmethod
    def create(cls, *, chunk_id: int, token: str) -> Self:
        """Build a payload for one token delta.

        Args:
            chunk_id: Monotonic chunk identifier for the token delta.
            token: Token text for the delta.

        Returns:
            A payload whose ``data`` holds ``chunk_id`` as ``id`` together
            with the token text.
        """
        chunk = TokenChunkData(id=chunk_id, token=token)
        return cls(data=chunk)

    def serialize_text(self) -> str:
        """Render the delta as plain text.

        Returns:
            The token text exactly as stored, with nothing added.
        """
        return self.data.token


class TurnCompleteStreamPayload(StreamPayloadBase):
    """SSE turn completion (same data shape as token)."""

    event: Literal["turn_complete"] = "turn_complete"
    data: TokenChunkData

    @classmethod
    def create(cls, *, chunk_id: int, token: str) -> Self:
        """Build the payload that reports a completed turn.

        Args:
            chunk_id: Monotonic chunk identifier for the final text.
            token: Full assistant text for the completed turn.

        Returns:
            A payload whose ``data`` holds ``chunk_id`` as ``id`` together
            with the full text.
        """
        final_chunk = TokenChunkData(id=chunk_id, token=token)
        return cls(data=final_chunk)
Comment thread
asimurka marked this conversation as resolved.


class ToolCallStreamPayload(StreamPayloadBase):
    """SSE tool call summary."""

    event: Literal["tool_call"] = "tool_call"
    data: ToolCallSummary

    def serialize_text(self) -> str:
        """Render the tool call as a one-line plain-text marker.

        Returns:
            ``[Tool Call: <name>]`` followed by a newline, where ``<name>``
            is the ``name`` of the wrapped summary.
        """
        tool_name = self.data.name
        return f"[Tool Call: {tool_name}]\n"
Comment thread
asimurka marked this conversation as resolved.


class ToolResultStreamPayload(StreamPayloadBase):
    """SSE tool result summary."""

    event: Literal["tool_result"] = "tool_result"
    data: ToolResultSummary

    def serialize_text(self) -> str:
        """Serialize tool result stream payload to plain text.

        Returns:
            A fixed ``[Tool Result]`` marker line; nothing from ``data`` is
            included in the text form.
        """
        return "[Tool Result]\n"


# Union of every concrete stream payload model. Each member declares a
# distinct ``event`` literal, which is the discriminator field named below.
StreamEventPayload: TypeAlias = Annotated[
    TokenStreamPayload
    | TurnCompleteStreamPayload
    | ToolCallStreamPayload
    | ToolResultStreamPayload
    | EndStreamPayload
    | ErrorStreamPayload
    | InterruptedStreamPayload
    | StartStreamPayload,
    Field(discriminator="event"),
]
46 changes: 46 additions & 0 deletions src/models/common/agents/turn_accumulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Mutable per-turn state for agent response processing."""

from dataclasses import dataclass, field
from typing import Final

from pydantic_ai import AgentRunResult

from models.common.turn_summary import TurnSummary


@dataclass(slots=True)
class AgentTurnAccumulator:  # pylint: disable=too-many-instance-attributes
    """Mutable state gathered while one interaction turn is processed.

    Attributes:
        vector_store_ids: Vector store IDs used to resolve RAG source labels.
        rag_id_mapping: Mapping from vector store ID to the user-facing
            source name.
        turn_summary: Aggregated output of the turn (text, tools, RAG,
            token usage).
        run_result: Result of the agent run; streaming only, ``None``
            by default.
        chunk_id: Monotonic SSE chunk index; streaming only, starts at 0.
        text_parts: Text deltas buffered before ``turn_complete``;
            streaming only.
        tool_round: Current tool-call round used for summary labeling;
            starts at 1.
        round_increment_pending: Whether ``tool_round`` should be bumped
            on the next step.
        emitted_tool_call_ids: IDs of tool calls already sent or recorded.
        emitted_tool_result_ids: IDs of tool results already sent or
            recorded.
        seen_docs: Keys of referenced documents already added, kept for
            deduplication.
    """

    vector_store_ids: Final[list[str]]
    rag_id_mapping: Final[dict[str, str]]
    turn_summary: TurnSummary
    run_result: AgentRunResult[str] | None = None
    chunk_id: int = 0
    text_parts: list[str] = field(default_factory=list)
    tool_round: int = 1
    round_increment_pending: bool = False
    emitted_tool_call_ids: set[str] = field(default_factory=set)
    emitted_tool_result_ids: set[str] = field(default_factory=set)
    seen_docs: set[tuple[str, str]] = field(default_factory=set)

    def increment_round_if_pending(self) -> None:
        """Advance ``tool_round`` by one when an increment is pending.

        The pending flag is cleared after the bump, so further calls do
        nothing until the flag is set again.
        """
        if self.round_increment_pending:
            self.tool_round += 1
            self.round_increment_pending = False
Loading
Loading