Skip to content
Open
2 changes: 1 addition & 1 deletion examples/avatar_agents/tavus/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ This example demonstrates how to create a animated avatar using [Tavus](https://
```bash
# Tavus Config
export TAVUS_API_KEY="..."
export TAVUS_REPLICA_ID="..."
export TAVUS_FACE_ID="..."

# OpenAI config (or other models, tts, stt)
export OPENAI_API_KEY="..."
Expand Down
6 changes: 3 additions & 3 deletions examples/avatar_agents/tavus/agent_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ async def entrypoint(ctx: JobContext):
resume_false_interruption=False,
)

persona_id = os.getenv("TAVUS_PERSONA_ID")
replica_id = os.getenv("TAVUS_REPLICA_ID")
tavus_avatar = tavus.AvatarSession(persona_id=persona_id, replica_id=replica_id)
pal_id = os.getenv("TAVUS_PAL_ID")
face_id = os.getenv("TAVUS_FACE_ID")
tavus_avatar = tavus.AvatarSession(pal_id=pal_id, face_id=face_id)
await tavus_avatar.start(session, room=ctx.room)

# start the agent, it will join the room and wait for the avatar to join
Expand Down
105 changes: 92 additions & 13 deletions livekit-plugins/livekit-plugins-tavus/livekit/plugins/tavus/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import warnings
from typing import Any

import aiohttp
Expand All @@ -22,6 +23,37 @@ class TavusException(Exception):


DEFAULT_API_URL = "https://tavusapi.com/v2"
# Stock face used when the caller provides neither a face nor a pal.
DEFAULT_FACE_ID = "r72f7f7f7c8b"


def _resolve_renamed_arg(
new_value: NotGivenOr[str],
deprecated_value: NotGivenOr[str],
*,
deprecated_name: str,
new_name: str,
) -> NotGivenOr[str]:
# Prefer the new arg; fall back to the deprecated alias and warn only when it's used.
if deprecated_value and not new_value:
warnings.warn(
f"`{deprecated_name}` is deprecated, use `{new_name}` instead",
DeprecationWarning,
stacklevel=3,
)
return new_value or deprecated_value


def _deprecated_env(deprecated_name: str, new_name: str) -> str | None:
# Read a deprecated env var, warning if it's set so callers migrate to `new_name`.
value = os.getenv(deprecated_name)
if value:
warnings.warn(
f"`{deprecated_name}` is deprecated, use `{new_name}` instead",
DeprecationWarning,
stacklevel=3,
)
return value


class TavusAPI:
Expand All @@ -45,26 +77,43 @@ def __init__(
async def create_conversation(
self,
*,
face_id: NotGivenOr[str] = NOT_GIVEN,
pal_id: NotGivenOr[str] = NOT_GIVEN,
replica_id: NotGivenOr[str] = NOT_GIVEN,
persona_id: NotGivenOr[str] = NOT_GIVEN,
properties: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
extra_payload: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> str:
replica_id = replica_id or (os.getenv("TAVUS_REPLICA_ID") or NOT_GIVEN)
if not replica_id:
raise TavusException("TAVUS_REPLICA_ID must be set")

persona_id = persona_id or (os.getenv("TAVUS_PERSONA_ID") or NOT_GIVEN)
if not persona_id:
# create a persona if not provided
persona_id = await self.create_persona()
# `replica_id`/`persona_id` are deprecated aliases for `face_id`/`pal_id`.
face_id = _resolve_renamed_arg(
face_id, replica_id, deprecated_name="replica_id", new_name="face_id"
)
pal_id = _resolve_renamed_arg(
pal_id, persona_id, deprecated_name="persona_id", new_name="pal_id"
)

face_id = (
face_id
or os.getenv("TAVUS_FACE_ID")
or _deprecated_env("TAVUS_REPLICA_ID", "TAVUS_FACE_ID")
or NOT_GIVEN
)
pal_id = (
pal_id
or os.getenv("TAVUS_PAL_ID")
or _deprecated_env("TAVUS_PERSONA_ID", "TAVUS_PAL_ID")
or NOT_GIVEN
)

if not pal_id:
# no pal to reuse, so create one — falling back to the default face
pal_id = await self.create_pal(default_face_id=face_id or DEFAULT_FACE_ID)
Comment on lines +108 to +110

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 Behavioral change: missing face/pal no longer raises an exception

Previously, if no replica_id was provided (via argument or env var), create_conversation raised TavusException("TAVUS_REPLICA_ID must be set"). Now, if neither face_id nor pal_id is provided, the code silently auto-creates a pal using DEFAULT_FACE_ID = "r72f7f7f7c8b" (livekit-plugins/livekit-plugins-tavus/livekit/plugins/tavus/api.py:27). This is a deliberate UX improvement (zero-config start), but existing users who relied on the error to catch missing configuration will no longer get that safety net. The default face ID is hardcoded; if it becomes invalid on the Tavus side, the API call will fail with a less clear error from the Tavus API rather than a local validation error.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


properties = properties or {}
payload = {
"replica_id": replica_id,
"persona_id": persona_id,
"properties": properties,
}
payload: dict[str, Any] = {"pal_id": pal_id, "properties": properties}
# send face_id only when given; otherwise the pal's default_face_id is used
if face_id:
payload["face_id"] = face_id
if utils.is_given(extra_payload):
payload.update(extra_payload)

Expand All @@ -74,12 +123,42 @@ async def create_conversation(
response_data = await self._post("conversations", payload)
return response_data["conversation_id"] # type: ignore

async def create_pal(
self,
name: NotGivenOr[str] = NOT_GIVEN,
*,
default_face_id: str,
extra_payload: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> str:
name = name or utils.shortuuid("lk_pal_")

payload = {
"pal_name": name,
"default_face_id": default_face_id,
"pipeline_mode": "echo",
"layers": {
"transport": {"transport_type": "livekit"},
},
}

if utils.is_given(extra_payload):
payload.update(extra_payload)

response_data = await self._post("pals", payload)
return response_data["pal_id"] # type: ignore

async def create_persona(
self,
name: NotGivenOr[str] = NOT_GIVEN,
*,
extra_payload: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> str:
# Deprecated: use create_pal(). Kept on the legacy /v2/personas endpoint.
warnings.warn(
"`create_persona` is deprecated, use `create_pal` instead",
DeprecationWarning,
stacklevel=2,
)
name = name or utils.shortuuid("lk_persona_")

payload = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from livekit.agents.voice.avatar import AvatarSession as BaseAvatarSession, DataStreamAudioOutput
from livekit.agents.voice.room_io import ATTRIBUTE_PUBLISH_ON_BEHALF

from .api import TavusAPI, TavusException
from .api import TavusAPI, TavusException, _resolve_renamed_arg
from .log import logger

SAMPLE_RATE = 24000
Expand All @@ -31,6 +31,8 @@ class AvatarSession(BaseAvatarSession):
def __init__(
self,
*,
face_id: NotGivenOr[str] = NOT_GIVEN,
pal_id: NotGivenOr[str] = NOT_GIVEN,
replica_id: NotGivenOr[str] = NOT_GIVEN,
persona_id: NotGivenOr[str] = NOT_GIVEN,
api_url: NotGivenOr[str] = NOT_GIVEN,
Expand All @@ -43,8 +45,13 @@ def __init__(
self._http_session: aiohttp.ClientSession | None = None
self._conn_options = conn_options
self.conversation_id: str | None = None
self._persona_id = persona_id
self._replica_id = replica_id
# `replica_id`/`persona_id` are deprecated aliases for `face_id`/`pal_id`.
self._pal_id = _resolve_renamed_arg(
pal_id, persona_id, deprecated_name="persona_id", new_name="pal_id"
)
self._face_id = _resolve_renamed_arg(
face_id, replica_id, deprecated_name="replica_id", new_name="face_id"
)
self._api = TavusAPI(
api_url=api_url,
api_key=api_key,
Expand Down Expand Up @@ -104,8 +111,8 @@ async def start(

logger.debug("starting avatar session")
self.conversation_id = await self._api.create_conversation(
persona_id=self._persona_id,
replica_id=self._replica_id,
pal_id=self._pal_id,
face_id=self._face_id,
properties={"livekit_ws_url": livekit_url, "livekit_room_token": livekit_token},
)

Expand Down
144 changes: 144 additions & 0 deletions livekit-plugins/livekit-plugins-tavus/tests/test_tavus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import warnings
Comment thread
tinalenguyen marked this conversation as resolved.
from unittest.mock import AsyncMock, patch

import pytest

from livekit.agents.utils import http_context
from livekit.plugins.tavus.api import DEFAULT_FACE_ID, TavusAPI
from livekit.plugins.tavus.avatar import AvatarSession

pytestmark = pytest.mark.unit


@pytest.fixture(autouse=True)
def _env(monkeypatch):
for v in ("TAVUS_FACE_ID", "TAVUS_PAL_ID", "TAVUS_REPLICA_ID", "TAVUS_PERSONA_ID"):
monkeypatch.delenv(v, raising=False)
monkeypatch.setenv("TAVUS_API_KEY", "test-key")


def _api() -> TavusAPI:
# session is unused because _post is always mocked in these tests
return TavusAPI(session=object()) # type: ignore[arg-type]


def _mock_post() -> AsyncMock:
return AsyncMock(return_value={"conversation_id": "conv1", "persona_id": "pal_auto"})


def _no_deprecation(rec: list[warnings.WarningMessage]) -> bool:
return not [w for w in rec if issubclass(w.category, DeprecationWarning)]


async def test_new_args_map_to_unchanged_wire_keys():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with warnings.catch_warnings(record=True) as rec:
warnings.simplefilter("always")
cid = await api.create_conversation(face_id="f1", pal_id="p1")
assert cid == "conv1"
payload = m.call_args.args[1]
assert payload["face_id"] == "f1"
assert payload["pal_id"] == "p1"
assert _no_deprecation(rec)


async def test_deprecated_args_still_work_and_warn():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with pytest.warns(DeprecationWarning) as rec:
await api.create_conversation(replica_id="r1", persona_id="x1")
payload = m.call_args.args[1]
assert payload["face_id"] == "r1"
assert payload["pal_id"] == "x1"
msgs = [str(w.message) for w in rec]
assert any("replica_id" in s and "face_id" in s for s in msgs)
assert any("persona_id" in s and "pal_id" in s for s in msgs)


async def test_no_warning_when_new_and_deprecated_both_given():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with warnings.catch_warnings(record=True) as rec:
warnings.simplefilter("always")
await api.create_conversation(
face_id="f1", replica_id="r1", pal_id="p1", persona_id="x1"
)
# the new values win, so the deprecated aliases are unused -> no warning
assert _no_deprecation(rec)
payload = m.call_args.args[1]
assert payload["face_id"] == "f1"
assert payload["pal_id"] == "p1"


async def test_new_env_vars_fallback(monkeypatch):
monkeypatch.setenv("TAVUS_FACE_ID", "envf")
monkeypatch.setenv("TAVUS_PAL_ID", "envp")
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with warnings.catch_warnings(record=True) as rec:
warnings.simplefilter("always")
await api.create_conversation()
payload = m.call_args.args[1]
assert payload["face_id"] == "envf"
assert payload["pal_id"] == "envp"
assert _no_deprecation(rec)


async def test_deprecated_env_vars_still_work_and_warn(monkeypatch):
monkeypatch.setenv("TAVUS_REPLICA_ID", "oldf")
monkeypatch.setenv("TAVUS_PERSONA_ID", "oldp")
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with pytest.warns(DeprecationWarning):
await api.create_conversation()
payload = m.call_args.args[1]
assert payload["face_id"] == "oldf"
assert payload["pal_id"] == "oldp"


async def test_auto_creates_pal_when_none_given():
api = _api()
post = AsyncMock(side_effect=[{"pal_id": "pal_new"}, {"conversation_id": "conv1"}])
with patch.object(api, "_post", new=post):
await api.create_conversation(face_id="f1")
pal_endpoint, pal_payload = post.call_args_list[0].args
conv_endpoint, conv_payload = post.call_args_list[1].args
assert pal_endpoint == "pals"
assert pal_payload["default_face_id"] == "f1"
assert conv_endpoint == "conversations"
assert conv_payload["face_id"] == "f1"
assert conv_payload["pal_id"] == "pal_new"


async def test_pal_id_only_skips_pal_creation_and_omits_face():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
await api.create_conversation(pal_id="p1")
endpoints = [c.args[0] for c in m.call_args_list]
assert "pals" not in endpoints # an existing pal carries its own default face
payload = m.call_args.args[1]
assert payload["pal_id"] == "p1"
assert "face_id" not in payload


async def test_defaults_face_when_neither_given():
api = _api()
post = AsyncMock(side_effect=[{"pal_id": "pal_new"}, {"conversation_id": "conv1"}])
with patch.object(api, "_post", new=post):
await api.create_conversation()
pal_endpoint, pal_payload = post.call_args_list[0].args
assert pal_endpoint == "pals"
assert pal_payload["default_face_id"] == DEFAULT_FACE_ID


async def test_avatar_session_resolves_new_and_deprecated_args():
async with http_context.open():
with pytest.warns(DeprecationWarning):
deprecated = AvatarSession(replica_id="r9", persona_id="x9")
assert deprecated._face_id == "r9"
assert deprecated._pal_id == "x9"

renamed = AvatarSession(face_id="f9", pal_id="p9")
assert renamed._face_id == "f9"
assert renamed._pal_id == "p9"
Loading