diff --git a/src/utils/pydantic_ai.py b/src/utils/pydantic_ai.py
new file mode 100644
index 000000000..5df570dc9
--- /dev/null
+++ b/src/utils/pydantic_ai.py
@@ -0,0 +1,104 @@
+"""Helpers for running Pydantic AI agents against Llama Stack (Responses API compatibility)."""
+
+from __future__ import annotations
+
+from typing import Any, Final, cast
+
+from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
+from llama_stack_client import AsyncLlamaStackClient
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
+
+from models.common.responses.responses_api_params import ResponsesApiParams
+from pydantic_ai_lightspeed.llamastack import LlamaStackProvider
+
+_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
+    {
+        "conversation",
+        "max_infer_iters",
+        "tools",
+        "tool_choice",
+        "include",
+        "text",
+        "reasoning",
+        "prompt",
+        "metadata",
+        "max_tool_calls",
+        "safety_identifier",
+    }
+)
+
+
+def _llama_stack_provider_from_client(
+    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
+) -> LlamaStackProvider:
+    """Construct a Pydantic AI Llama Stack provider backed by the same client as ``/query``."""
+    if isinstance(client, AsyncLlamaStackAsLibraryClient):
+        return LlamaStackProvider(library_client=client)
+    api_key = client.api_key or "not-needed"
+    base = str(client.base_url).rstrip("/")
+    base_url = base if base.endswith("/v1") else f"{base}/v1"
+    return LlamaStackProvider(
+        base_url=base_url,
+        api_key=api_key,
+        http_client=client._client,  # pylint: disable=protected-access
+    )
+
+
+def _model_settings_from_responses_params(
+    responses_params: ResponsesApiParams,
+) -> OpenAIResponsesModelSettings:
+    """Map ``ResponsesApiParams`` into Pydantic AI OpenAI Responses model settings."""
+    payload = responses_params.model_dump(exclude_none=True)
+    extra_body = {k: v for k, v in payload.items() if k in _LLS_RESPONSES_EXTRA_FIELDS}
+    settings_dict: dict[str, Any] = {}
+    if extra_body:
+        settings_dict["extra_body"] = extra_body
+    if responses_params.max_output_tokens is not None:
+        settings_dict["max_tokens"] = responses_params.max_output_tokens
+    if responses_params.temperature is not None:
+        settings_dict["temperature"] = responses_params.temperature
+    if responses_params.parallel_tool_calls is not None:
+        settings_dict["parallel_tool_calls"] = responses_params.parallel_tool_calls
+    if responses_params.extra_headers:
+        settings_dict["extra_headers"] = dict(responses_params.extra_headers)
+    settings_dict["openai_store"] = responses_params.store
+    if responses_params.previous_response_id is not None:
+        settings_dict["openai_previous_response_id"] = (
+            responses_params.previous_response_id
+        )
+    return cast(OpenAIResponsesModelSettings, settings_dict)
+
+
+def build_agent(
+    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
+    responses_params: ResponsesApiParams,
+) -> Agent[None, str]:
+    """Build a Pydantic AI agent that mirrors ``responses_params`` on the Llama Stack backend.
+
+    Uses ``LlamaStackProvider`` with the same ``AsyncLlamaStackClient`` (or library client)
+    as the query endpoint, and ``OpenAIResponsesModel`` so requests follow the Responses API.
+    Llama-Stack-specific fields (conversation, tools, MCP headers, etc.) are passed via
+    ``model_settings['extra_body']`` so they merge into the OpenAI client request body.
+
+    Parameters:
+        client: Initialized Llama Stack client from ``AsyncLlamaStackClientHolder().get_client()``.
+        responses_params: Parameters produced by ``prepare_responses_params`` for this turn.
+
+    Returns:
+        ``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same
+        stack configuration as ``client.responses.create(**responses_params.model_dump())``.
+    """
+    provider = _llama_stack_provider_from_client(client)
+    settings = _model_settings_from_responses_params(responses_params)
+
+    model = OpenAIResponsesModel(
+        responses_params.model,
+        provider=provider,
+        settings=settings,
+    )
+    return Agent(
+        model,
+        instructions=responses_params.instructions,
+        defer_model_check=True,
+    )
diff --git a/tests/unit/utils/test_pydantic_ai.py b/tests/unit/utils/test_pydantic_ai.py
new file mode 100644
index 000000000..c7acc37e3
--- /dev/null
+++ b/tests/unit/utils/test_pydantic_ai.py
@@ -0,0 +1,270 @@
+"""Unit tests for utils/pydantic_ai module."""
+
+# pylint: disable=protected-access
+
+import httpx
+import pytest
+from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
+from pytest_mock import MockerFixture
+
+from utils.pydantic_ai import (
+    _LLS_RESPONSES_EXTRA_FIELDS,
+    _llama_stack_provider_from_client,
+    _model_settings_from_responses_params,
+    build_agent,
+)
+
+
+class TestLlamaStackProviderFromClient:
+    """Tests for _llama_stack_provider_from_client factory."""
+
+    def test_library_client(self, mocker: MockerFixture) -> None:
+        """Test that a library client creates a provider with library_client kwarg."""
+        mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
+        mock_lib_client.provider_data = None
+
+        provider = _llama_stack_provider_from_client(mock_lib_client)
+
+        assert provider._library_client is mock_lib_client
+
+    def test_remote_client_with_api_key(self, mocker: MockerFixture) -> None:
+        """Test that a remote client uses its api_key."""
+        mock_client = mocker.Mock()
+        mock_client.base_url = "http://my-server:8321"
+        mock_client.api_key = "my-secret"
+        mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
+
+        provider = _llama_stack_provider_from_client(mock_client)
+
+        assert provider.client.api_key == "my-secret"
+        assert "my-server:8321" in provider.base_url
+
+    def test_remote_client_without_api_key(self, mocker: MockerFixture) -> None:
+        """Test that a remote client without api_key defaults to 'not-needed'."""
+        mock_client = mocker.Mock()
+        mock_client.base_url = "http://my-server:8321"
+        mock_client.api_key = None
+        mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
+
+        provider = _llama_stack_provider_from_client(mock_client)
+
+        assert provider.client.api_key == "not-needed"
+
+    def test_remote_client_passes_http_client(self, mocker: MockerFixture) -> None:
+        """Test that a remote client's internal http_client is forwarded."""
+        mock_http_client = mocker.Mock(spec=httpx.AsyncClient)
+        mock_client = mocker.Mock()
+        mock_client.base_url = "http://my-server:8321"
+        mock_client.api_key = "key"
+        mock_client._client = mock_http_client
+
+        provider = _llama_stack_provider_from_client(mock_client)
+
+        assert provider._client._client is mock_http_client
+
+
+class TestModelSettingsFromResponsesParams:
+    """Tests for _model_settings_from_responses_params mapping."""
+
+    @pytest.fixture(name="minimal_params")
+    def minimal_params_fixture(self, mocker: MockerFixture) -> object:
+        """Create minimal ResponsesApiParams mock with required fields only."""
+        params = mocker.Mock()
+        params.model_dump.return_value = {"model": "test/model", "input": "hello"}
+        params.max_output_tokens = None
+        params.temperature = None
+        params.parallel_tool_calls = None
+        params.extra_headers = None
+        params.store = False
+        params.previous_response_id = None
+        return params
+
+    def test_minimal_params_returns_store_false(self, minimal_params: object) -> None:
+        """Test that minimal params produce settings with openai_store=False."""
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert settings["openai_store"] is False
+
+    def test_minimal_params_no_extra_body(self, minimal_params: object) -> None:
+        """Test that minimal params without extra fields omit extra_body."""
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert "extra_body" not in settings
+
+    def test_max_output_tokens_mapped(self, minimal_params: object) -> None:
+        """Test that max_output_tokens is mapped to max_tokens."""
+        minimal_params.max_output_tokens = 1024  # type: ignore[attr-defined]
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert settings["max_tokens"] == 1024
+
+    def test_temperature_mapped(self, minimal_params: object) -> None:
+        """Test that temperature is passed through."""
+        minimal_params.temperature = 0.7  # type: ignore[attr-defined]
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert settings["temperature"] == 0.7
+
+    def test_parallel_tool_calls_mapped(self, minimal_params: object) -> None:
+        """Test that parallel_tool_calls is passed through."""
+        minimal_params.parallel_tool_calls = True  # type: ignore[attr-defined]
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert settings["parallel_tool_calls"] is True
+
+    def test_extra_headers_mapped(self, minimal_params: object) -> None:
+        """Test that extra_headers are converted to a dict."""
+        minimal_params.extra_headers = {"x-custom": "value"}  # type: ignore[attr-defined]
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert settings["extra_headers"] == {"x-custom": "value"}
+
+    def test_store_true_mapped(self, minimal_params: object) -> None:
+        """Test that store=True is passed as openai_store."""
+        minimal_params.store = True  # type: ignore[attr-defined]
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert settings["openai_store"] is True
+
+    def test_previous_response_id_mapped(self, minimal_params: object) -> None:
+        """Test that previous_response_id is passed as openai_previous_response_id."""
+        minimal_params.previous_response_id = "resp_abc123"  # type: ignore[attr-defined]
+        settings = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
+        assert settings["openai_previous_response_id"] == "resp_abc123"
+
+    def test_extra_body_from_lls_fields(self, mocker: MockerFixture) -> None:
+        """Test that LLS-specific fields are placed into extra_body."""
+        params = mocker.Mock()
+        params.model_dump.return_value = {
+            "model": "test/model",
+            "conversation": "conv-123",
+            "max_infer_iters": 5,
+            "tools": [{"type": "function"}],
+            "tool_choice": "auto",
+        }
+        params.max_output_tokens = None
+        params.temperature = None
+        params.parallel_tool_calls = None
+        params.extra_headers = None
+        params.store = False
+        params.previous_response_id = None
+
+        settings = _model_settings_from_responses_params(params)
+
+        assert "extra_body" in settings
+        assert settings["extra_body"]["conversation"] == "conv-123"
+        assert settings["extra_body"]["max_infer_iters"] == 5
+        assert settings["extra_body"]["tools"] == [{"type": "function"}]
+        assert settings["extra_body"]["tool_choice"] == "auto"
+
+    def test_extra_body_only_includes_known_fields(self, mocker: MockerFixture) -> None:
+        """Test that extra_body only includes fields in _LLS_RESPONSES_EXTRA_FIELDS."""
+        params = mocker.Mock()
+        params.model_dump.return_value = {
+            "model": "test/model",
+            "conversation": "conv-1",
+            "unknown_field": "should-not-appear",
+        }
+        params.max_output_tokens = None
+        params.temperature = None
+        params.parallel_tool_calls = None
+        params.extra_headers = None
+        params.store = False
+        params.previous_response_id = None
+
+        settings = _model_settings_from_responses_params(params)
+
+        assert "unknown_field" not in settings.get("extra_body", {})
+        assert settings["extra_body"]["conversation"] == "conv-1"
+
+
+class TestLlsResponsesExtraFields:
+    """Tests for the _LLS_RESPONSES_EXTRA_FIELDS constant."""
+
+    def test_is_frozenset(self) -> None:
+        """Test that _LLS_RESPONSES_EXTRA_FIELDS is a frozenset."""
+        assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset)
+
+    def test_contains_expected_fields(self) -> None:
+        """Test that key fields are present."""
+        expected = {
+            "conversation",
+            "max_infer_iters",
+            "tools",
+            "tool_choice",
+            "include",
+            "text",
+            "reasoning",
+            "prompt",
+            "metadata",
+            "max_tool_calls",
+            "safety_identifier",
+        }
+        assert expected == _LLS_RESPONSES_EXTRA_FIELDS
+
+
+class TestBuildAgent:
+    """Tests for the build_agent factory function."""
+
+    def test_returns_agent_with_correct_model(self, mocker: MockerFixture) -> None:
+        """Test that build_agent returns an Agent with the specified model name."""
+        mock_client = mocker.Mock()
+        mock_client.base_url = "http://localhost:8321"
+        mock_client.api_key = "test-key"
+        mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
+
+        mock_params = mocker.Mock()
+        mock_params.model = "provider/my-model"
+        mock_params.instructions = "Be helpful."
+        mock_params.model_dump.return_value = {
+            "model": "provider/my-model",
+            "conversation": "conv-1",
+        }
+        mock_params.max_output_tokens = None
+        mock_params.temperature = None
+        mock_params.parallel_tool_calls = None
+        mock_params.extra_headers = None
+        mock_params.store = False
+        mock_params.previous_response_id = None
+
+        agent = build_agent(mock_client, mock_params)
+
+        assert agent is not None
+
+    def test_agent_has_instructions(self, mocker: MockerFixture) -> None:
+        """Test that build_agent passes instructions to the Agent."""
+        mock_client = mocker.Mock()
+        mock_client.base_url = "http://localhost:8321"
+        mock_client.api_key = "test-key"
+        mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
+
+        mock_params = mocker.Mock()
+        mock_params.model = "provider/my-model"
+        mock_params.instructions = "You are a helpful assistant."
+        mock_params.model_dump.return_value = {"model": "provider/my-model"}
+        mock_params.max_output_tokens = None
+        mock_params.temperature = None
+        mock_params.parallel_tool_calls = None
+        mock_params.extra_headers = None
+        mock_params.store = False
+        mock_params.previous_response_id = None
+
+        agent = build_agent(mock_client, mock_params)
+
+        assert "You are a helpful assistant." in agent._instructions
+
+    def test_agent_with_library_client(self, mocker: MockerFixture) -> None:
+        """Test that build_agent works with a library client."""
+        mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
+        mock_lib_client.provider_data = None
+
+        mock_params = mocker.Mock()
+        mock_params.model = "provider/my-model"
+        mock_params.instructions = None
+        mock_params.model_dump.return_value = {
+            "model": "provider/my-model",
+            "conversation": "conv-1",
+        }
+        mock_params.max_output_tokens = None
+        mock_params.temperature = None
+        mock_params.parallel_tool_calls = None
+        mock_params.extra_headers = None
+        mock_params.store = True
+        mock_params.previous_response_id = None
+
+        agent = build_agent(mock_lib_client, mock_params)
+
+        assert agent is not None