def build_cas_url(conversation_id: str) -> tuple[str, str]:
    """Return ``(url, socketio_path)`` for a CAS socket.io connection.

    ``CAS_WEBSOCKET_HOST`` is checked first (local dev override, plain
    ``ws://``); otherwise the host is taken from ``UIPATH_URL`` and the
    connection uses ``wss://``.

    Args:
        conversation_id: Conversation identifier sent as the
            ``conversationId`` query parameter. It is percent-encoded so that
            reserved characters cannot corrupt the query string.

    Returns:
        A tuple of the websocket URL and the socket.io path to connect on.

    Raises:
        RuntimeError: If UIPATH_URL is not set or has no netloc (and
            CAS_WEBSOCKET_HOST is not set).
    """
    from urllib.parse import quote

    # Encode once up front; unreserved characters (letters, digits, "-", "_",
    # ".", "~") pass through untouched, so GUID-style ids are unchanged.
    query = f"conversationId={quote(conversation_id, safe='')}"

    cas_ws_host = os.environ.get("CAS_WEBSOCKET_HOST")
    if cas_ws_host:
        logger.warning(
            "CAS_WEBSOCKET_HOST is set. Using local CAS at '%s'.", cas_ws_host
        )
        return f"ws://{cas_ws_host}?{query}", "/socket.io"

    base_url = os.environ.get("UIPATH_URL")
    if not base_url:
        raise RuntimeError(
            "UIPATH_URL environment variable required for conversational mode"
        )
    parsed = urlparse(base_url)
    # urlparse leaves netloc empty when the value has no "scheme://" prefix.
    if not parsed.netloc:
        raise RuntimeError(f"Invalid UIPATH_URL format: {base_url}")
    return (
        f"wss://{parsed.netloc}?{query}",
        "autopilotforeveryone_/websocket_/socket.io",
    )


def build_cas_auth_headers(ctx: UiPathRuntimeContext) -> dict[str, str]:
    """Build authentication headers for a CAS socket.io connection.

    Tenant and organization ids come from ``ctx`` when set and fall back to
    the ``UIPATH_TENANT_ID`` / ``UIPATH_ORGANIZATION_ID`` environment
    variables. The bearer token is read from ``UIPATH_ACCESS_TOKEN``. Every
    missing value becomes an empty string, so all four headers are always
    present.
    """
    access_token = os.environ.get("UIPATH_ACCESS_TOKEN", "")
    return {
        "Authorization": f"Bearer {access_token}",
        "X-UiPath-Internal-TenantId": ctx.tenant_id
        or os.environ.get("UIPATH_TENANT_ID", ""),
        "X-UiPath-Internal-AccountId": ctx.org_id
        or os.environ.get("UIPATH_ORGANIZATION_ID", ""),
        "X-UiPath-ConversationId": ctx.conversation_id or "",
    }
assert context.exchange_id is not None, "exchange_id must be set in context" - # Extract host from UIPATH_URL - base_url = os.environ.get("UIPATH_URL") - if not base_url: - raise RuntimeError( - "UIPATH_URL environment variable required for conversational mode" - ) - - parsed = urlparse(base_url) - if not parsed.netloc: - raise RuntimeError(f"Invalid UIPATH_URL format: {base_url}") - - host = parsed.netloc - - # Construct WebSocket URL for CAS - websocket_url = f"wss://{host}?conversationId={context.conversation_id}" - websocket_path = "autopilotforeveryone_/websocket_/socket.io" - - if os.environ.get("CAS_WEBSOCKET_HOST"): - websocket_url = f"ws://{os.environ.get('CAS_WEBSOCKET_HOST')}?conversationId={context.conversation_id}" - websocket_path = "/socket.io" - logger.warning( - f"CAS_WEBSOCKET_HOST is set. Using websocket_url '{websocket_url}{websocket_path}'." - ) - - # Build headers from context - headers = { - "Authorization": f"Bearer {os.environ.get('UIPATH_ACCESS_TOKEN', '')}", - "X-UiPath-Internal-TenantId": f"{context.tenant_id}" - or os.environ.get("UIPATH_TENANT_ID", ""), - "X-UiPath-Internal-AccountId": f"{context.org_id}" - or os.environ.get("UIPATH_ORGANIZATION_ID", ""), - "X-UiPath-ConversationId": context.conversation_id, - } + websocket_url, websocket_path = build_cas_url(context.conversation_id) + headers = build_cas_auth_headers(context) return SocketIOChatBridge( websocket_url=websocket_url, @@ -548,4 +559,4 @@ def get_chat_bridge( ) -__all__ = ["get_chat_bridge"] +__all__ = ["build_cas_auth_headers", "build_cas_url", "get_chat_bridge"] diff --git a/packages/uipath/src/uipath/_cli/_chat/_voice_session.py b/packages/uipath/src/uipath/_cli/_chat/_voice_session.py new file mode 100644 index 000000000..e34d6b73b --- /dev/null +++ b/packages/uipath/src/uipath/_cli/_chat/_voice_session.py @@ -0,0 +1,200 @@ +"""Voice tool-call session — persistent socket.io connection to CAS.""" + +from __future__ import annotations + +import asyncio +import contextlib 
import logging
from collections.abc import Awaitable, Callable
from typing import Any

logger = logging.getLogger(__name__)

# socket.io connect handshake timeout
_CONNECT_TIMEOUT_SECONDS = 15.0
# After session end, give in-flight tool tasks this long to finish before
# disconnecting. The phone call may have already ended (the agent's
# end-session tool runs client-side), but tool work in flight should still
# complete its side effects.
_SHUTDOWN_DRAIN_SECONDS = 30.0


ToolHandler = Callable[[dict[str, Any]], Awaitable[tuple[str, bool]]]
"""Callback invoked once per tool call.

Receives a single call dict (`{"callId", "toolName", "args"}`) and returns
`(result_text, is_error)` — the values sent back as the
`voice_tool_result` response. The session unpacks batched
`voice_tool_call` envelopes and invokes this handler once per call,
dispatching them in parallel.
"""


class VoiceToolCallSessionError(RuntimeError):
    """Raised when the voice tool-call session cannot connect to CAS."""


class VoiceToolCallSession:
    """Persistent socket.io connection to CAS for the voice tool-call loop.

    CAS fetches the agent config from Orchestrator directly — this session
    only carries tool-call traffic:

    - Receives `voice_tool_call` events from CAS. Each event is a batch
      (`{"calls": [...]}`); each entry dispatches as an independent task.
    - Emits one `voice_tool_result` per `callId` in any order.
    - Exits on `voice_end_session` or socket disconnect.

    Parallel dispatch: Gemini Live can request multiple tool calls in a
    single batch. Spawning a task per call keeps the socket reader free for
    `voice_end_session` and avoids serializing multi-call turns.

    Graceful shutdown: when the session ends, in-flight tool tasks have
    `drain_timeout` seconds (default `_SHUTDOWN_DRAIN_SECONDS`) to finish.
    Tasks still running after that are cancelled before the socket is
    disconnected.

    Example:
        ```python
        async def handle(call: dict[str, Any]) -> tuple[str, bool]:
            return await run_my_tool(call["toolName"], call["args"]), False

        url, path = build_cas_url(conversation_id)
        url += "&voiceUrt=true"
        session = VoiceToolCallSession(
            url=url,
            socketio_path=path,
            headers=build_cas_auth_headers(ctx),
            tool_handler=handle,
        )
        await session.run()
        ```
    """

    def __init__(
        self,
        url: str,
        socketio_path: str,
        headers: dict[str, str],
        tool_handler: ToolHandler,
        *,
        connect_timeout: float = _CONNECT_TIMEOUT_SECONDS,
        drain_timeout: float = _SHUTDOWN_DRAIN_SECONDS,
    ) -> None:
        """Store connection settings; no network activity happens here.

        Args:
            url: Websocket URL passed to the socket.io client's `connect`.
            socketio_path: socket.io path passed to `connect`.
            headers: HTTP headers passed to `connect`.
            tool_handler: Awaitable callback run once per tool call.
            connect_timeout: Seconds allowed for the connect handshake.
            drain_timeout: Seconds in-flight tool tasks get to finish after
                the session ends.
        """
        self._url = url
        self._socketio_path = socketio_path
        self._headers = headers
        self._tool_handler = tool_handler
        self._connect_timeout = connect_timeout
        self._drain_timeout = drain_timeout
        self._client: Any = None
        self._done = asyncio.Event()
        self._in_flight: set[asyncio.Task[None]] = set()

    async def run(self) -> None:
        """Connect, dispatch tool calls until session ends, then disconnect.

        Raises:
            VoiceToolCallSessionError: If connecting to CAS fails.
        """
        from socketio import AsyncClient  # type: ignore[import-untyped]

        self._client = AsyncClient(logger=logger, engineio_logger=logger)
        self._client.on("connect", self._handle_connect)
        self._client.on("disconnect", self._handle_disconnect)
        self._client.on("connect_error", self._handle_connect_error)
        self._client.on("voice_tool_call", self._handle_tool_call)
        self._client.on("voice_end_session", self._handle_session_end)

        try:
            await asyncio.wait_for(
                self._client.connect(
                    url=self._url,
                    socketio_path=self._socketio_path,
                    headers=self._headers,
                    transports=["websocket"],
                ),
                timeout=self._connect_timeout,
            )
        except Exception as exc:
            # Best-effort cleanup of a half-open client; the connect failure
            # is the error worth reporting.
            with contextlib.suppress(Exception):
                await self._client.disconnect()
            # A timeout stringifies to "", so fall back to the type name to
            # keep the message informative.
            detail = str(exc) or type(exc).__name__
            raise VoiceToolCallSessionError(
                f"Failed to connect to CAS voice endpoint: {detail}"
            ) from exc

        try:
            await self._done.wait()
            await self._drain_in_flight()
        finally:
            with contextlib.suppress(Exception):
                await self._client.disconnect()

    async def _drain_in_flight(self) -> None:
        """Wait for in-flight tool tasks to finish, capped by the drain timeout.

        Tasks still running when the timeout expires are counted, logged and
        then cancelled, so none of them outlives the socket.
        """
        # Snapshot: done-callbacks mutate self._in_flight while we wait.
        tasks = set(self._in_flight)
        if not tasks:
            return
        logger.info(
            "[Voice] Session ended with %d in-flight tool task(s); draining (max %.0fs)",
            len(tasks),
            self._drain_timeout,
        )
        # asyncio.wait (unlike wait_for + gather) does not cancel on timeout,
        # so `unfinished` is the true set of tasks that overran.
        _, unfinished = await asyncio.wait(tasks, timeout=self._drain_timeout)
        if not unfinished:
            return
        logger.warning(
            "[Voice] %d tool task(s) did not complete within %.0fs of session end",
            len(unfinished),
            self._drain_timeout,
        )
        for task in unfinished:
            task.cancel()
        # Let cancellation unwind before the caller disconnects the socket.
        await asyncio.gather(*unfinished, return_exceptions=True)

    async def _handle_connect(self) -> None:
        """Log a successful socket.io connection."""
        logger.info("[Voice] Socket.io connected to CAS")

    async def _handle_disconnect(self) -> None:
        """Log the disconnect and signal `run()` to finish."""
        logger.info("[Voice] Socket.io disconnected from CAS")
        self._done.set()

    async def _handle_connect_error(self, data: Any) -> None:
        """Log a socket.io connection error."""
        # sio.connect() also raises; this handler exists for logging only.
        logger.error("[Voice] Socket.io connection error: %s", data)

    async def _handle_tool_call(self, data: Any) -> None:
        """Spawn one task per call in the batch and return immediately.

        Returning fast keeps the socket.io reader free to receive
        `voice_end_session` while tools are still running. Malformed payloads
        and non-dict batch entries are logged and skipped.
        """
        if self._done.is_set():
            # Session is shutting down; ignore late events to avoid emitting
            # results into a dead socket.
            return

        calls = data.get("calls") if isinstance(data, dict) else None
        if not isinstance(calls, list) or not calls:
            logger.warning("[Voice] voice_tool_call missing/empty 'calls': %r", data)
            return

        for call in calls:
            if not isinstance(call, dict):
                # Without a dict there is no callId to answer, so skip it.
                logger.warning(
                    "[Voice] Skipping malformed voice_tool_call entry: %r", call
                )
                continue
            task = asyncio.create_task(self._execute_one(call))
            self._in_flight.add(task)
            task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task: asyncio.Task[None]) -> None:
        """Stop tracking a finished task and surface any unexpected failure."""
        self._in_flight.discard(task)
        if task.cancelled():
            return
        # Retrieving the exception also prevents asyncio's "exception was
        # never retrieved" warning.
        exc = task.exception()
        if exc is not None:
            logger.error("[Voice] Tool task failed unexpectedly: %r", exc)

    async def _execute_one(self, call: dict[str, Any]) -> None:
        """Run one tool call and emit its `voice_tool_result`.

        A handler exception is reported to CAS as an error result. An emit
        failure is logged and otherwise ignored (best effort).
        """
        call_id = call.get("callId", "")
        tool_name = call.get("toolName", "?")
        logger.info("[Voice] voice_tool_call dispatched: %s (%s)", tool_name, call_id)
        try:
            result_text, is_error = await self._tool_handler(call)
        except Exception as exc:
            logger.exception("[Voice] Tool call execution failed: %s", tool_name)
            result_text, is_error = str(exc), True

        try:
            await self._client.emit(
                "voice_tool_result",
                {"callId": call_id, "result": result_text, "isError": is_error},
            )
        except Exception:
            # Still best effort, but do not claim the result was sent.
            logger.warning(
                "[Voice] Failed to send voice_tool_result for %s",
                call_id,
                exc_info=True,
            )
            return
        logger.info("[Voice] voice_tool_result sent: %s (isError=%s)", call_id, is_error)

    async def _handle_session_end(self, data: Any) -> None:
        """Signal `run()` to finish when CAS ends the session."""
        logger.info("[Voice] voice_end_session received, shutting down")
        self._done.set()