Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions src/utils/pydantic_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Helpers for running Pydantic AI agents against Llama Stack (Responses API compatibility)."""

from __future__ import annotations

from typing import Any, Final, cast

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from llama_stack_client import AsyncLlamaStackClient
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

from models.common.responses.responses_api_params import ResponsesApiParams
from pydantic_ai_lightspeed.llamastack import LlamaStackProvider

# Names of dumped ``ResponsesApiParams`` fields that
# ``_model_settings_from_responses_params`` copies into ``extra_body``
# instead of mapping onto a dedicated model-settings key.
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
    (
        "conversation",
        "max_infer_iters",
        "tools",
        "tool_choice",
        "include",
        "text",
        "reasoning",
        "prompt",
        "metadata",
        "max_tool_calls",
        "safety_identifier",
    )
)


def _llama_stack_provider_from_client(
client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
) -> LlamaStackProvider:
"""Construct a Pydantic AI Llama Stack provider backed by the same client as ``/query``."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial | 💤 Low value

Expand helper docstrings to Google conventions.

_llama_stack_provider_from_client (and _model_settings_from_responses_params at Line 51) use single-line docstrings without Parameters/Returns/Raises sections. As per coding guidelines, "Follow Google Python docstring conventions with required sections: Parameters, Returns, Raises".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/utils/pydantic_ai.py` at line 35, The single-line docstrings for
_llama_stack_provider_from_client and _model_settings_from_responses_params must
be expanded to follow Google Python docstring conventions: update each function
to include Parameters (documenting args such as client, responses_params, etc.
with types and purpose), Returns (describe the return type and what it
represents), and Raises (list any exceptions the function may raise); ensure the
wording matches existing parameter names and types used in those functions and
keep the descriptions concise and accurate so the new docstrings fully replace
the current one-line descriptions.

if isinstance(client, AsyncLlamaStackAsLibraryClient):
return LlamaStackProvider(library_client=client)
api_key = client.api_key or "not-needed"
base = str(client.base_url).rstrip("/")
base_url = base if base.endswith("/v1") else f"{base}/v1"
return LlamaStackProvider(
base_url=base_url,
api_key=api_key,
http_client=client._client, # pylint: disable=protected-access
)
Comment on lines +41 to +45
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🌐 Web query:

In llama-stack-client 0.6.0, does AsyncLlamaStackClient expose a public way to access or reuse its underlying httpx.AsyncClient, or is _client intentionally private/internal?

💡 Result:

In llama-stack-client 0.6.0, AsyncLlamaStackClient does not publicly expose (or promise) direct access to/reuse of its underlying httpx.AsyncClient; the underlying HTTP client is intended to be managed via the public http_client configuration/injection API, not by reaching into a private/internal attribute like _client. Evidence: - The official client Python README documents that for async usage you may configure the HTTP backend by passing an http_client when constructing AsyncLlamaStackClient (including using DefaultAioHttpClient) and shows using the client as an async context manager [1]. This is the supported “reuse/customize” mechanism for the underlying transport/client. - The same README also provides a documented mechanism to “directly override” the underlying HTTP client (via DefaultHttpxClient) and use with_options(http_client=...) for customizing on a per-request basis, indicating the library expects callers to inject/override the HTTP client rather than extracting it from internals [1]. So, _client is intentionally private/internal: callers should reuse the HTTP resources by passing the desired http_client (e.g., a DefaultHttpxClient with a custom httpx.AsyncClient/transport) into AsyncLlamaStackClient, or by using the async context manager lifecycle, rather than accessing a hidden _client attribute directly [1].

Citations:


Don’t depend on AsyncLlamaStackClient._client; use the public http_client injection path.

The helper passes a protected/internal attribute into LlamaStackProvider:

return LlamaStackProvider(
    base_url=base_url,
    api_key=api_key,
    http_client=client._client,  # pylint: disable=protected-access
)

llama-stack-client’s docs indicate the underlying HTTP client should be managed via the public http_client configuration/override API, so reaching into _client is an unstable dependency that can break at runtime even if mocks keep CI green. Refactor _llama_stack_provider_from_client to accept/pass through a publicly injected http_client (or construct the needed client using that same public injection) rather than extracting it from AsyncLlamaStackClient.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/utils/pydantic_ai.py` around lines 41 - 45, The helper
_llama_stack_provider_from_client currently reaches into
AsyncLlamaStackClient._client (protected) when building a LlamaStackProvider;
change it to rely on the public HTTP-injection path instead: update
_llama_stack_provider_from_client to accept a public http_client parameter (or
obtain the client's configured http_client via its public API), and pass that
http_client into LlamaStackProvider(base_url=..., api_key=..., http_client=...)
rather than using client._client; ensure callers of
_llama_stack_provider_from_client are updated to provide the public http_client
argument.



def _model_settings_from_responses_params(
    responses_params: ResponsesApiParams,
) -> OpenAIResponsesModelSettings:
    """Translate ``ResponsesApiParams`` into Pydantic AI Responses model settings.

    Fields listed in ``_LLS_RESPONSES_EXTRA_FIELDS`` are grouped under
    ``extra_body``; the remaining supported fields are mapped onto their
    dedicated settings keys.

    Parameters:
        responses_params: Request parameters to translate. Dumped with
            ``exclude_none=True`` to collect the ``extra_body`` entries.

    Returns:
        A settings mapping that always contains ``openai_store``. It contains
        ``extra_body`` only when at least one extra field is present in the dump,
        ``max_tokens``, ``temperature``, ``parallel_tool_calls`` and
        ``openai_previous_response_id`` only when the source value is not
        ``None``, and ``extra_headers`` only when the source headers are truthy.
    """
    dumped = responses_params.model_dump(exclude_none=True)
    lls_only = {
        name: value
        for name, value in dumped.items()
        if name in _LLS_RESPONSES_EXTRA_FIELDS
    }

    result: dict[str, Any] = {}
    if lls_only:
        result["extra_body"] = lls_only

    # (settings key, source value) pairs that are copied only when set.
    optional_scalars = (
        ("max_tokens", responses_params.max_output_tokens),
        ("temperature", responses_params.temperature),
        ("parallel_tool_calls", responses_params.parallel_tool_calls),
    )
    for settings_key, source_value in optional_scalars:
        if source_value is not None:
            result[settings_key] = source_value

    headers = responses_params.extra_headers
    if headers:
        # Copy so the settings do not alias the caller's header mapping.
        result["extra_headers"] = dict(headers)

    # ``store`` is forwarded unconditionally, whatever its value.
    result["openai_store"] = responses_params.store

    prior_response_id = responses_params.previous_response_id
    if prior_response_id is not None:
        result["openai_previous_response_id"] = prior_response_id

    return cast(OpenAIResponsesModelSettings, result)


def build_agent(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
    responses_params: ResponsesApiParams,
) -> Agent[None, str]:
    """Create a Pydantic AI agent configured from ``responses_params``.

    The agent's model is an ``OpenAIResponsesModel`` named by
    ``responses_params.model``. Its provider is derived from ``client`` via
    ``_llama_stack_provider_from_client`` and its settings from
    ``_model_settings_from_responses_params``, which places Llama-Stack-specific
    fields under ``extra_body``.

    Parameters:
        client: Llama Stack client (remote or library) the provider is built from.
        responses_params: Source of the model identifier, the instructions and
            the model settings.

    Returns:
        An ``Agent`` wrapping the configured model, created with
        ``responses_params.instructions`` and ``defer_model_check=True``.
    """
    responses_model = OpenAIResponsesModel(
        responses_params.model,
        provider=_llama_stack_provider_from_client(client),
        settings=_model_settings_from_responses_params(responses_params),
    )
    agent_options = {
        "instructions": responses_params.instructions,
        "defer_model_check": True,
    }
    return Agent(responses_model, **agent_options)
270 changes: 270 additions & 0 deletions tests/unit/utils/test_pydantic_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
"""Unit tests for utils/pydantic_ai module."""

# pylint: disable=protected-access

import httpx
import pytest
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from pytest_mock import MockerFixture

from utils.pydantic_ai import (
_LLS_RESPONSES_EXTRA_FIELDS,
_llama_stack_provider_from_client,
_model_settings_from_responses_params,
build_agent,
)


class TestLlamaStackProviderFromClient:
    """Cover provider construction for library and remote Llama Stack clients."""

    def test_library_client(self, mocker: MockerFixture) -> None:
        """A library client is stored on the provider as its library client."""
        library_client = mocker.Mock(
            spec=AsyncLlamaStackAsLibraryClient,
            provider_data=None,
        )

        result = _llama_stack_provider_from_client(library_client)

        assert result._library_client is library_client

    def test_remote_client_with_api_key(self, mocker: MockerFixture) -> None:
        """A remote client's API key and host end up on the provider."""
        remote = mocker.Mock(
            base_url="http://my-server:8321",
            api_key="my-secret",
            _client=mocker.Mock(spec=httpx.AsyncClient),
        )

        result = _llama_stack_provider_from_client(remote)

        assert result.client.api_key == "my-secret"
        assert "my-server:8321" in result.base_url

    def test_remote_client_without_api_key(self, mocker: MockerFixture) -> None:
        """A remote client with no API key falls back to 'not-needed'."""
        remote = mocker.Mock(
            base_url="http://my-server:8321",
            api_key=None,
            _client=mocker.Mock(spec=httpx.AsyncClient),
        )

        result = _llama_stack_provider_from_client(remote)

        assert result.client.api_key == "not-needed"

    def test_remote_client_passes_http_client(self, mocker: MockerFixture) -> None:
        """The remote client's underlying HTTP client is reused by the provider."""
        transport = mocker.Mock(spec=httpx.AsyncClient)
        remote = mocker.Mock(
            base_url="http://my-server:8321",
            api_key="key",
            _client=transport,
        )

        result = _llama_stack_provider_from_client(remote)

        assert result._client._client is transport


class TestModelSettingsFromResponsesParams:
    """Check how ``_model_settings_from_responses_params`` maps each field."""

    @pytest.fixture(name="minimal_params")
    def minimal_params_fixture(self, mocker: MockerFixture) -> object:
        """Provide a params mock with every optional field unset and ``store`` False."""
        stub = mocker.Mock(
            max_output_tokens=None,
            temperature=None,
            parallel_tool_calls=None,
            extra_headers=None,
            store=False,
            previous_response_id=None,
        )
        stub.model_dump.return_value = {"model": "test/model", "input": "hello"}
        return stub

    def test_minimal_params_returns_store_false(self, minimal_params: object) -> None:
        """With nothing else set, the settings still carry openai_store=False."""
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_store"] is False

    def test_minimal_params_no_extra_body(self, minimal_params: object) -> None:
        """No extra fields in the dump means no extra_body key at all."""
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert "extra_body" not in result

    def test_max_output_tokens_mapped(self, minimal_params: object) -> None:
        """max_output_tokens surfaces as the max_tokens setting."""
        minimal_params.max_output_tokens = 1024  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["max_tokens"] == 1024

    def test_temperature_mapped(self, minimal_params: object) -> None:
        """temperature is forwarded unchanged."""
        minimal_params.temperature = 0.7  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["temperature"] == 0.7

    def test_parallel_tool_calls_mapped(self, minimal_params: object) -> None:
        """parallel_tool_calls is forwarded unchanged."""
        minimal_params.parallel_tool_calls = True  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["parallel_tool_calls"] is True

    def test_extra_headers_mapped(self, minimal_params: object) -> None:
        """extra_headers come through as an equal dict."""
        minimal_params.extra_headers = {"x-custom": "value"}  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["extra_headers"] == {"x-custom": "value"}

    def test_store_true_mapped(self, minimal_params: object) -> None:
        """store=True surfaces as openai_store=True."""
        minimal_params.store = True  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_store"] is True

    def test_previous_response_id_mapped(self, minimal_params: object) -> None:
        """previous_response_id surfaces as openai_previous_response_id."""
        minimal_params.previous_response_id = "resp_abc123"  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_previous_response_id"] == "resp_abc123"

    def test_extra_body_from_lls_fields(self, mocker: MockerFixture) -> None:
        """Llama-Stack-specific fields in the dump are collected under extra_body."""
        stub = mocker.Mock(
            max_output_tokens=None,
            temperature=None,
            parallel_tool_calls=None,
            extra_headers=None,
            store=False,
            previous_response_id=None,
        )
        stub.model_dump.return_value = {
            "model": "test/model",
            "conversation": "conv-123",
            "max_infer_iters": 5,
            "tools": [{"type": "function"}],
            "tool_choice": "auto",
        }

        result = _model_settings_from_responses_params(stub)

        assert "extra_body" in result
        body = result["extra_body"]
        assert body["conversation"] == "conv-123"
        assert body["max_infer_iters"] == 5
        assert body["tools"] == [{"type": "function"}]
        assert body["tool_choice"] == "auto"

    def test_extra_body_only_includes_known_fields(self, mocker: MockerFixture) -> None:
        """Keys outside _LLS_RESPONSES_EXTRA_FIELDS never reach extra_body."""
        stub = mocker.Mock(
            max_output_tokens=None,
            temperature=None,
            parallel_tool_calls=None,
            extra_headers=None,
            store=False,
            previous_response_id=None,
        )
        stub.model_dump.return_value = {
            "model": "test/model",
            "conversation": "conv-1",
            "unknown_field": "should-not-appear",
        }

        result = _model_settings_from_responses_params(stub)

        assert "unknown_field" not in result.get("extra_body", {})
        assert result["extra_body"]["conversation"] == "conv-1"


class TestLlsResponsesExtraFields:
    """Pin the type and exact contents of ``_LLS_RESPONSES_EXTRA_FIELDS``."""

    def test_is_frozenset(self) -> None:
        """The constant must be an immutable frozenset."""
        assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset)

    def test_contains_expected_fields(self) -> None:
        """The constant must equal exactly this set of field names."""
        wanted = frozenset(
            (
                "conversation",
                "max_infer_iters",
                "tools",
                "tool_choice",
                "include",
                "text",
                "reasoning",
                "prompt",
                "metadata",
                "max_tool_calls",
                "safety_identifier",
            )
        )
        assert wanted == _LLS_RESPONSES_EXTRA_FIELDS


class TestBuildAgent:
    """Smoke tests for ``build_agent`` with remote and library clients."""

    def test_returns_agent_with_correct_model(self, mocker: MockerFixture) -> None:
        """build_agent yields an agent for a remote client (only non-None is asserted)."""
        remote = mocker.Mock(
            base_url="http://localhost:8321",
            api_key="test-key",
            _client=mocker.Mock(spec=httpx.AsyncClient),
        )
        params = mocker.Mock(
            model="provider/my-model",
            instructions="Be helpful.",
            max_output_tokens=None,
            temperature=None,
            parallel_tool_calls=None,
            extra_headers=None,
            store=False,
            previous_response_id=None,
        )
        params.model_dump.return_value = {
            "model": "provider/my-model",
            "conversation": "conv-1",
        }

        built = build_agent(remote, params)

        assert built is not None

    def test_agent_has_instructions(self, mocker: MockerFixture) -> None:
        """The instructions from the params are present on the built agent."""
        remote = mocker.Mock(
            base_url="http://localhost:8321",
            api_key="test-key",
            _client=mocker.Mock(spec=httpx.AsyncClient),
        )
        params = mocker.Mock(
            model="provider/my-model",
            instructions="You are a helpful assistant.",
            max_output_tokens=None,
            temperature=None,
            parallel_tool_calls=None,
            extra_headers=None,
            store=False,
            previous_response_id=None,
        )
        params.model_dump.return_value = {"model": "provider/my-model"}

        built = build_agent(remote, params)

        assert "You are a helpful assistant." in built._instructions

    def test_agent_with_library_client(self, mocker: MockerFixture) -> None:
        """build_agent also yields an agent when given a library client."""
        library_client = mocker.Mock(
            spec=AsyncLlamaStackAsLibraryClient,
            provider_data=None,
        )
        params = mocker.Mock(
            model="provider/my-model",
            instructions=None,
            max_output_tokens=None,
            temperature=None,
            parallel_tool_calls=None,
            extra_headers=None,
            store=True,
            previous_response_id=None,
        )
        params.model_dump.return_value = {
            "model": "provider/my-model",
            "conversation": "conv-1",
        }

        built = build_agent(library_client, params)

        assert built is not None
Loading