-
Notifications
You must be signed in to change notification settings - Fork 88
LCORE-2309: Added Pydantic AI Bridge #1817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| """Helpers for running Pydantic AI agents against Llama Stack (Responses API compatibility).""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Final, cast | ||
|
|
||
| from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient | ||
| from llama_stack_client import AsyncLlamaStackClient | ||
| from pydantic_ai import Agent | ||
| from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings | ||
|
|
||
| from models.common.responses.responses_api_params import ResponsesApiParams | ||
| from pydantic_ai_lightspeed.llamastack import LlamaStackProvider | ||
|
|
||
| _LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset( | ||
| { | ||
| "conversation", | ||
| "max_infer_iters", | ||
| "tools", | ||
| "tool_choice", | ||
| "include", | ||
| "text", | ||
| "reasoning", | ||
| "prompt", | ||
| "metadata", | ||
| "max_tool_calls", | ||
| "safety_identifier", | ||
| } | ||
| ) | ||
|
|
||
|
|
||
| def _llama_stack_provider_from_client( | ||
| client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient, | ||
| ) -> LlamaStackProvider: | ||
| """Construct a Pydantic AI Llama Stack provider backed by the same client as ``/query``.""" | ||
| if isinstance(client, AsyncLlamaStackAsLibraryClient): | ||
| return LlamaStackProvider(library_client=client) | ||
| api_key = client.api_key or "not-needed" | ||
| base = str(client.base_url).rstrip("/") | ||
| base_url = base if base.endswith("/v1") else f"{base}/v1" | ||
| return LlamaStackProvider( | ||
| base_url=base_url, | ||
| api_key=api_key, | ||
| http_client=client._client, # pylint: disable=protected-access | ||
| ) | ||
|
Comment on lines
+41
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: In llama-stack-client 0.6.0, AsyncLlamaStackClient does not publicly expose (or promise) direct access to/reuse of its underlying httpx.AsyncClient; the underlying HTTP client is intended to be managed via the public http_client configuration/injection API, not by reaching into a private/internal attribute like _client. Evidence: - The official client Python README documents that for async usage you may configure the HTTP backend by passing an http_client when constructing AsyncLlamaStackClient (including using DefaultAioHttpClient) and shows using the client as an async context manager [1]. This is the supported “reuse/customize” mechanism for the underlying transport/client. - The same README also provides a documented mechanism to “directly override” the underlying HTTP client (via DefaultHttpxClient) and use with_options(http_client=...) for customizing on a per-request basis, indicating the library expects callers to inject/override the HTTP client rather than extracting it from internals [1]. So, _client is intentionally private/internal: callers should reuse the HTTP resources by passing the desired http_client (e.g., a DefaultHttpxClient with a custom httpx.AsyncClient/transport) into AsyncLlamaStackClient, or by using the async context manager lifecycle, rather than accessing a hidden Citations: Don’t depend on The helper passes a protected/internal attribute into return LlamaStackProvider(
base_url=base_url,
api_key=api_key,
http_client=client._client, # pylint: disable=protected-access
)
🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def _model_settings_from_responses_params( | ||
| responses_params: ResponsesApiParams, | ||
| ) -> OpenAIResponsesModelSettings: | ||
| """Map ``ResponsesApiParams`` into Pydantic AI OpenAI Responses model settings.""" | ||
| payload = responses_params.model_dump(exclude_none=True) | ||
| extra_body = {k: v for k, v in payload.items() if k in _LLS_RESPONSES_EXTRA_FIELDS} | ||
| settings_dict: dict[str, Any] = {} | ||
| if extra_body: | ||
| settings_dict["extra_body"] = extra_body | ||
| if responses_params.max_output_tokens is not None: | ||
| settings_dict["max_tokens"] = responses_params.max_output_tokens | ||
| if responses_params.temperature is not None: | ||
| settings_dict["temperature"] = responses_params.temperature | ||
| if responses_params.parallel_tool_calls is not None: | ||
| settings_dict["parallel_tool_calls"] = responses_params.parallel_tool_calls | ||
| if responses_params.extra_headers: | ||
| settings_dict["extra_headers"] = dict(responses_params.extra_headers) | ||
| settings_dict["openai_store"] = responses_params.store | ||
| if responses_params.previous_response_id is not None: | ||
| settings_dict["openai_previous_response_id"] = ( | ||
| responses_params.previous_response_id | ||
| ) | ||
| return cast(OpenAIResponsesModelSettings, settings_dict) | ||
|
|
||
|
|
||
| def build_agent( | ||
| client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient, | ||
| responses_params: ResponsesApiParams, | ||
| ) -> Agent[None, str]: | ||
| """Build a Pydantic AI agent that mirrors ``responses_params`` on the Llama Stack backend. | ||
|
|
||
| Uses ``LlamaStackProvider`` with the same ``AsyncLlamaStackClient`` (or library client) | ||
| as the query endpoint, and ``OpenAIResponsesModel`` so requests follow the Responses API. | ||
| Llama-Stack-specific fields (conversation, tools, MCP headers, etc.) are passed via | ||
| ``model_settings['extra_body']`` so they merge into the OpenAI client request body. | ||
|
|
||
| Parameters: | ||
| client: Initialized Llama Stack client from ``AsyncLlamaStackClientHolder().get_client()``. | ||
| responses_params: Parameters produced by ``prepare_responses_params`` for this turn. | ||
|
|
||
| Returns: | ||
| ``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same | ||
| stack configuration as ``client.responses.create(**responses_params.model_dump())``. | ||
| """ | ||
| provider = _llama_stack_provider_from_client(client) | ||
| settings = _model_settings_from_responses_params(responses_params) | ||
|
|
||
| model = OpenAIResponsesModel( | ||
| responses_params.model, | ||
| provider=provider, | ||
| settings=settings, | ||
| ) | ||
| return Agent( | ||
| model, | ||
| instructions=responses_params.instructions, | ||
| defer_model_check=True, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,270 @@ | ||
| """Unit tests for utils/pydantic_ai module.""" | ||
|
|
||
| # pylint: disable=protected-access | ||
|
|
||
| import httpx | ||
| import pytest | ||
| from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient | ||
| from pytest_mock import MockerFixture | ||
|
|
||
| from utils.pydantic_ai import ( | ||
| _LLS_RESPONSES_EXTRA_FIELDS, | ||
| _llama_stack_provider_from_client, | ||
| _model_settings_from_responses_params, | ||
| build_agent, | ||
| ) | ||
|
|
||
|
|
||
| class TestLlamaStackProviderFromClient: | ||
| """Tests for _llama_stack_provider_from_client factory.""" | ||
|
|
||
| def test_library_client(self, mocker: MockerFixture) -> None: | ||
| """Test that a library client creates a provider with library_client kwarg.""" | ||
| mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient) | ||
| mock_lib_client.provider_data = None | ||
|
|
||
| provider = _llama_stack_provider_from_client(mock_lib_client) | ||
|
|
||
| assert provider._library_client is mock_lib_client | ||
|
|
||
| def test_remote_client_with_api_key(self, mocker: MockerFixture) -> None: | ||
| """Test that a remote client uses its api_key.""" | ||
| mock_client = mocker.Mock() | ||
| mock_client.base_url = "http://my-server:8321" | ||
| mock_client.api_key = "my-secret" | ||
| mock_client._client = mocker.Mock(spec=httpx.AsyncClient) | ||
|
|
||
| provider = _llama_stack_provider_from_client(mock_client) | ||
|
|
||
| assert provider.client.api_key == "my-secret" | ||
| assert "my-server:8321" in provider.base_url | ||
|
|
||
| def test_remote_client_without_api_key(self, mocker: MockerFixture) -> None: | ||
| """Test that a remote client without api_key defaults to 'not-needed'.""" | ||
| mock_client = mocker.Mock() | ||
| mock_client.base_url = "http://my-server:8321" | ||
| mock_client.api_key = None | ||
| mock_client._client = mocker.Mock(spec=httpx.AsyncClient) | ||
|
|
||
| provider = _llama_stack_provider_from_client(mock_client) | ||
|
|
||
| assert provider.client.api_key == "not-needed" | ||
|
|
||
| def test_remote_client_passes_http_client(self, mocker: MockerFixture) -> None: | ||
| """Test that a remote client's internal http_client is forwarded.""" | ||
| mock_http_client = mocker.Mock(spec=httpx.AsyncClient) | ||
| mock_client = mocker.Mock() | ||
| mock_client.base_url = "http://my-server:8321" | ||
| mock_client.api_key = "key" | ||
| mock_client._client = mock_http_client | ||
|
|
||
| provider = _llama_stack_provider_from_client(mock_client) | ||
|
|
||
| assert provider._client._client is mock_http_client | ||
|
|
||
|
|
||
| class TestModelSettingsFromResponsesParams: | ||
| """Tests for _model_settings_from_responses_params mapping.""" | ||
|
|
||
| @pytest.fixture(name="minimal_params") | ||
| def minimal_params_fixture(self, mocker: MockerFixture) -> object: | ||
| """Create minimal ResponsesApiParams mock with required fields only.""" | ||
| params = mocker.Mock() | ||
| params.model_dump.return_value = {"model": "test/model", "input": "hello"} | ||
| params.max_output_tokens = None | ||
| params.temperature = None | ||
| params.parallel_tool_calls = None | ||
| params.extra_headers = None | ||
| params.store = False | ||
| params.previous_response_id = None | ||
| return params | ||
|
|
||
| def test_minimal_params_returns_store_false(self, minimal_params: object) -> None: | ||
| """Test that minimal params produce settings with openai_store=False.""" | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert settings["openai_store"] is False | ||
|
|
||
| def test_minimal_params_no_extra_body(self, minimal_params: object) -> None: | ||
| """Test that minimal params without extra fields omit extra_body.""" | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert "extra_body" not in settings | ||
|
|
||
| def test_max_output_tokens_mapped(self, minimal_params: object) -> None: | ||
| """Test that max_output_tokens is mapped to max_tokens.""" | ||
| minimal_params.max_output_tokens = 1024 # type: ignore[attr-defined] | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert settings["max_tokens"] == 1024 | ||
|
|
||
| def test_temperature_mapped(self, minimal_params: object) -> None: | ||
| """Test that temperature is passed through.""" | ||
| minimal_params.temperature = 0.7 # type: ignore[attr-defined] | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert settings["temperature"] == 0.7 | ||
|
|
||
| def test_parallel_tool_calls_mapped(self, minimal_params: object) -> None: | ||
| """Test that parallel_tool_calls is passed through.""" | ||
| minimal_params.parallel_tool_calls = True # type: ignore[attr-defined] | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert settings["parallel_tool_calls"] is True | ||
|
|
||
| def test_extra_headers_mapped(self, minimal_params: object) -> None: | ||
| """Test that extra_headers are converted to a dict.""" | ||
| minimal_params.extra_headers = {"x-custom": "value"} # type: ignore[attr-defined] | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert settings["extra_headers"] == {"x-custom": "value"} | ||
|
|
||
| def test_store_true_mapped(self, minimal_params: object) -> None: | ||
| """Test that store=True is passed as openai_store.""" | ||
| minimal_params.store = True # type: ignore[attr-defined] | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert settings["openai_store"] is True | ||
|
|
||
| def test_previous_response_id_mapped(self, minimal_params: object) -> None: | ||
| """Test that previous_response_id is passed as openai_previous_response_id.""" | ||
| minimal_params.previous_response_id = "resp_abc123" # type: ignore[attr-defined] | ||
| settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] | ||
| assert settings["openai_previous_response_id"] == "resp_abc123" | ||
|
|
||
| def test_extra_body_from_lls_fields(self, mocker: MockerFixture) -> None: | ||
| """Test that LLS-specific fields are placed into extra_body.""" | ||
| params = mocker.Mock() | ||
| params.model_dump.return_value = { | ||
| "model": "test/model", | ||
| "conversation": "conv-123", | ||
| "max_infer_iters": 5, | ||
| "tools": [{"type": "function"}], | ||
| "tool_choice": "auto", | ||
| } | ||
| params.max_output_tokens = None | ||
| params.temperature = None | ||
| params.parallel_tool_calls = None | ||
| params.extra_headers = None | ||
| params.store = False | ||
| params.previous_response_id = None | ||
|
|
||
| settings = _model_settings_from_responses_params(params) | ||
|
|
||
| assert "extra_body" in settings | ||
| assert settings["extra_body"]["conversation"] == "conv-123" | ||
| assert settings["extra_body"]["max_infer_iters"] == 5 | ||
| assert settings["extra_body"]["tools"] == [{"type": "function"}] | ||
| assert settings["extra_body"]["tool_choice"] == "auto" | ||
|
|
||
| def test_extra_body_only_includes_known_fields(self, mocker: MockerFixture) -> None: | ||
| """Test that extra_body only includes fields in _LLS_RESPONSES_EXTRA_FIELDS.""" | ||
| params = mocker.Mock() | ||
| params.model_dump.return_value = { | ||
| "model": "test/model", | ||
| "conversation": "conv-1", | ||
| "unknown_field": "should-not-appear", | ||
| } | ||
| params.max_output_tokens = None | ||
| params.temperature = None | ||
| params.parallel_tool_calls = None | ||
| params.extra_headers = None | ||
| params.store = False | ||
| params.previous_response_id = None | ||
|
|
||
| settings = _model_settings_from_responses_params(params) | ||
|
|
||
| assert "unknown_field" not in settings.get("extra_body", {}) | ||
| assert settings["extra_body"]["conversation"] == "conv-1" | ||
|
|
||
|
|
||
| class TestLlsResponsesExtraFields: | ||
| """Tests for the _LLS_RESPONSES_EXTRA_FIELDS constant.""" | ||
|
|
||
| def test_is_frozenset(self) -> None: | ||
| """Test that _LLS_RESPONSES_EXTRA_FIELDS is a frozenset.""" | ||
| assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset) | ||
|
|
||
| def test_contains_expected_fields(self) -> None: | ||
| """Test that key fields are present.""" | ||
| expected = { | ||
| "conversation", | ||
| "max_infer_iters", | ||
| "tools", | ||
| "tool_choice", | ||
| "include", | ||
| "text", | ||
| "reasoning", | ||
| "prompt", | ||
| "metadata", | ||
| "max_tool_calls", | ||
| "safety_identifier", | ||
| } | ||
| assert expected == _LLS_RESPONSES_EXTRA_FIELDS | ||
|
|
||
|
|
||
| class TestBuildAgent: | ||
| """Tests for the build_agent factory function.""" | ||
|
|
||
| def test_returns_agent_with_correct_model(self, mocker: MockerFixture) -> None: | ||
| """Test that build_agent returns an Agent with the specified model name.""" | ||
| mock_client = mocker.Mock() | ||
| mock_client.base_url = "http://localhost:8321" | ||
| mock_client.api_key = "test-key" | ||
| mock_client._client = mocker.Mock(spec=httpx.AsyncClient) | ||
|
|
||
| mock_params = mocker.Mock() | ||
| mock_params.model = "provider/my-model" | ||
| mock_params.instructions = "Be helpful." | ||
| mock_params.model_dump.return_value = { | ||
| "model": "provider/my-model", | ||
| "conversation": "conv-1", | ||
| } | ||
| mock_params.max_output_tokens = None | ||
| mock_params.temperature = None | ||
| mock_params.parallel_tool_calls = None | ||
| mock_params.extra_headers = None | ||
| mock_params.store = False | ||
| mock_params.previous_response_id = None | ||
|
|
||
| agent = build_agent(mock_client, mock_params) | ||
|
|
||
| assert agent is not None | ||
|
|
||
| def test_agent_has_instructions(self, mocker: MockerFixture) -> None: | ||
| """Test that build_agent passes instructions to the Agent.""" | ||
| mock_client = mocker.Mock() | ||
| mock_client.base_url = "http://localhost:8321" | ||
| mock_client.api_key = "test-key" | ||
| mock_client._client = mocker.Mock(spec=httpx.AsyncClient) | ||
|
|
||
| mock_params = mocker.Mock() | ||
| mock_params.model = "provider/my-model" | ||
| mock_params.instructions = "You are a helpful assistant." | ||
| mock_params.model_dump.return_value = {"model": "provider/my-model"} | ||
| mock_params.max_output_tokens = None | ||
| mock_params.temperature = None | ||
| mock_params.parallel_tool_calls = None | ||
| mock_params.extra_headers = None | ||
| mock_params.store = False | ||
| mock_params.previous_response_id = None | ||
|
|
||
| agent = build_agent(mock_client, mock_params) | ||
|
|
||
| assert "You are a helpful assistant." in agent._instructions | ||
|
|
||
| def test_agent_with_library_client(self, mocker: MockerFixture) -> None: | ||
| """Test that build_agent works with a library client.""" | ||
| mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient) | ||
| mock_lib_client.provider_data = None | ||
|
|
||
| mock_params = mocker.Mock() | ||
| mock_params.model = "provider/my-model" | ||
| mock_params.instructions = None | ||
| mock_params.model_dump.return_value = { | ||
| "model": "provider/my-model", | ||
| "conversation": "conv-1", | ||
| } | ||
| mock_params.max_output_tokens = None | ||
| mock_params.temperature = None | ||
| mock_params.parallel_tool_calls = None | ||
| mock_params.extra_headers = None | ||
| mock_params.store = True | ||
| mock_params.previous_response_id = None | ||
|
|
||
| agent = build_agent(mock_lib_client, mock_params) | ||
|
|
||
| assert agent is not None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧹 Nitpick | 🔵 Trivial | 💤 Low value
Expand helper docstrings to Google conventions.
_llama_stack_provider_from_client(and_model_settings_from_responses_paramsat Line 51) use single-line docstrings withoutParameters/Returns/Raisessections. As per coding guidelines, "Follow Google Python docstring conventions with required sections: Parameters, Returns, Raises".🤖 Prompt for AI Agents