Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fastapi_startkit/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ ai = [
"google-generativeai>=0.8.0",
]

langchain = [
"langchain>=1.0.0",
"langchain-core>=1.0.0",
]

[dependency-groups]
dev = [
"dumpdie>=1.5.0",
Expand All @@ -68,6 +73,8 @@ dev = [
"sqlalchemy[asyncio]>=2.0.38",
"fastapi[standard]>=0.124.4",
"faker>=40.13.0",
"langchain>=1.0.0",
"langchain-core>=1.0.0",
]


Expand Down
2 changes: 2 additions & 0 deletions fastapi_startkit/src/fastapi_startkit/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .config import AIConfig, AnthropicConfig, GoogleConfig, OpenAIConfig
from .decorators import max_steps, max_tokens, memory, model, provider, timeout, top_p
from .document import Document
from .fakes import fake_chat_model
from .image import Image, ImageResponse
from .image_factory import ImageFactory
from .providers.ai_provider import AIProvider
Expand All @@ -38,6 +39,7 @@
"AudioResponse",
"AudioFactory",
"Document",
"fake_chat_model",
"GoogleConfig",
"Image",
"ImageFactory",
Expand Down
64 changes: 51 additions & 13 deletions fastapi_startkit/src/fastapi_startkit/ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from __future__ import annotations

import fnmatch
from typing import Any, Callable, Iterator, Optional, Type
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Optional, Type

from .document import Document
from .response import AgentResponse, AgentSnapshot

if TYPE_CHECKING:
from langchain_core.messages import AIMessage


class Agent:
"""
Expand Down Expand Up @@ -39,7 +42,7 @@ class Agent:
}

def __init__(self):
self._fakes: dict[str, AgentResponse | AgentSnapshot] = {}
self._fakes: dict[str, AgentResponse | AgentSnapshot | "AIMessage"] = {}
self._call_log: list[dict] = []

# ── Lifecycle — override in subclasses ──────────────────────────────────
Expand Down Expand Up @@ -97,10 +100,7 @@ def prompt(

match = self._match_fake(message)
if match is not None:
if isinstance(match, AgentSnapshot):
response = match.resolve(self, message, **_run_kwargs)
else:
response = match
response = self._resolve_fake(match, message, _run_kwargs)
self._log_call("prompt", message)
return self.after(response)

Expand All @@ -124,20 +124,45 @@ def stream(
self._log_call("stream", message)
fake = self._match_fake(message)
if fake is not None:
if isinstance(fake, AgentSnapshot):
response = fake.resolve(self, message)
else:
response = fake
response = self._resolve_fake(fake, message)
yield response.content
return
yield from self._stream(message, system=system, model=model, provider_options=provider_options)

def fake(self, patterns: dict[str, AgentResponse | AgentSnapshot]) -> "Agent":
"""Register fake responses for testing. Keys are glob patterns."""
def fake(self, patterns: dict[str, "AgentResponse | AgentSnapshot | AIMessage"]) -> "Agent":
"""Register fake responses for testing. Keys are glob patterns.

A value may be an :class:`AgentResponse`, an :class:`AgentSnapshot`
(record/replay), or a LangChain ``AIMessage`` (converted on match,
preserving ``content`` and ``tool_calls``). When a prompt matches a
pattern the agent returns the fake without ever calling the provider.
"""
for pattern, value in patterns.items():
self._fakes[pattern] = value
return self

def record(self, path: str, *, pattern: str = "*") -> "Agent":
"""Record-and-replay (VCR style) for prompts matching ``pattern``.

On the first run the real provider is called and the response is saved to
``path``; every later run replays that recording without hitting the API.
Returns ``self`` for chaining.
"""
self._fakes[pattern] = AgentSnapshot(path=path)
return self

@staticmethod
def fake_model(turns: Iterable[Any]):
"""Return a LangChain fake chat model that replays ``turns`` in order.

Pass it to ``langchain.agents.create_agent(model=..., tools=[...])`` to run
a full agent loop — tool calls included — offline. Requires the
``langchain`` extra. See :func:`fastapi_startkit.ai.fakes.fake_chat_model`.
"""
from .fakes import fake_chat_model

return fake_chat_model(turns)

def assert_prompted(self, times: int | None = None) -> None:
"""Assert that prompt() or stream() was called."""
calls = [c for c in self._call_log if c["method"] in ("prompt", "stream")]
Expand All @@ -158,12 +183,25 @@ def reset(self) -> "Agent":

# ── Internal helpers ────────────────────────────────────────────────────

def _match_fake(self, message: str) -> Optional[AgentResponse | AgentSnapshot]:
def _match_fake(self, message: str) -> Optional[Any]:
for pattern, value in self._fakes.items():
if fnmatch.fnmatch(message.lower(), pattern.lower()):
return value
return None

def _resolve_fake(self, match: Any, message: str, run_kwargs: dict | None = None) -> AgentResponse:
"""Turn a registered fake into an AgentResponse without calling the provider."""
if isinstance(match, AgentSnapshot):
return match.resolve(self, message, **(run_kwargs or {}))
if isinstance(match, AgentResponse):
return match

from .fakes import is_ai_message, to_agent_response

if is_ai_message(match):
return to_agent_response(match)
raise TypeError(f"Unsupported fake value: {type(match)!r}")

def _log_call(self, method: str, message: str) -> None:
self._call_log.append({"method": method, "message": message})

Expand Down
94 changes: 94 additions & 0 deletions fastapi_startkit/src/fastapi_startkit/ai/fakes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""LangChain test-harness helpers — run agents offline, without a real provider.

These utilities let tests exercise a *real* LangChain agent (``create_agent``)
without calling a model API, by replaying a scripted sequence of assistant turns
through a fake chat model. They require the ``langchain`` extra::

pip install "fastapi-startkit[langchain]"

Example — drive a full tool-calling loop with no network::

from langchain_core.messages import AIMessage, ToolCall
from langchain.agents import create_agent
from fastapi_startkit.ai import fake_chat_model

model = fake_chat_model([
AIMessage(content="", tool_calls=[
ToolCall(name="search_jobs", args={"query": "python"}, id="c1", type="tool_call"),
]),
AIMessage(content="Here is a Python Developer role at Shopify."),
])
agent = create_agent(model=model, tools=[search_jobs])
result = agent.invoke({"messages": [{"role": "user", "content": "Find me a python job"}]})
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable

from .response import AgentResponse

if TYPE_CHECKING:
from langchain_core.messages import AIMessage


def _require_langchain():
try:
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
except ImportError as exc: # pragma: no cover - exercised only without the extra
raise ImportError(
"LangChain is required for the agent test harness. "
"Install it with: pip install \"fastapi-startkit[langchain]\""
) from exc
return GenericFakeChatModel, AIMessage


def fake_chat_model(turns: Iterable[Any]):
"""Return a fake chat model that replays ``turns`` in order.

Each turn is an ``AIMessage`` (which may carry ``tool_calls``) or a ``str``
(shorthand for ``AIMessage(content=...)``). Pass the result straight to
``langchain.agents.create_agent(model=..., tools=[...])`` to run a complete
agent loop — tool calls included — without hitting a real provider.
"""
generic_model, ai_message = _require_langchain()

class _FakeChatModel(generic_model):
# create_agent calls bind_tools() when tools are present; GenericFakeChatModel
# leaves it unimplemented. The scripted turns already encode the model's
# decisions, so the bound tool schemas are irrelevant — binding is a no-op.
def bind_tools(self, tools, **kwargs):
return self

normalized = [t if isinstance(t, ai_message) else ai_message(content=str(t)) for t in turns]
return _FakeChatModel(messages=iter(normalized))


def is_ai_message(value: Any) -> bool:
"""Return True if ``value`` is a LangChain ``AIMessage`` (False if langchain is absent)."""
try:
from langchain_core.messages import AIMessage
except ImportError:
return False
return isinstance(value, AIMessage)


def to_agent_response(value: "AIMessage") -> AgentResponse:
"""Convert a LangChain ``AIMessage`` into an :class:`AgentResponse`."""
_, ai_message = _require_langchain()
if not isinstance(value, ai_message):
raise TypeError(f"Expected an AIMessage, got {type(value)!r}")

content = value.content if isinstance(value.content, str) else str(value.content)
tool_calls = list(getattr(value, "tool_calls", None) or [])

usage: dict = {}
meta = getattr(value, "usage_metadata", None)
if meta:
usage = {
"input": meta.get("input_tokens", 0),
"output": meta.get("output_tokens", 0),
}

return AgentResponse(content=content, tool_calls=tool_calls, usage=usage, raw=value)
97 changes: 97 additions & 0 deletions fastapi_startkit/tests/ai/test_agent_langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Tests for the LangChain test harness — fake AIMessage fakes and fake_chat_model.

These exercise a real LangChain agent loop entirely offline. They are skipped
automatically when the optional ``langchain`` extra is not installed.
"""

import pytest

pytest.importorskip("langchain")
pytest.importorskip("langchain_core")

from langchain.agents import create_agent # noqa: E402
from langchain_core.messages import AIMessage, ToolCall # noqa: E402
from langchain_core.tools import tool # noqa: E402

from fastapi_startkit.ai import fake_chat_model # noqa: E402
from fastapi_startkit.ai.agent import Agent # noqa: E402


class SimpleAgent(Agent):
pass


@tool
def search_jobs(query: str):
"""Search for jobs matching the query."""
mock_db = [
{"title": "Python Developer", "company": "Shopify"},
{"title": "Data Engineer", "company": "Google"},
]
return [job for job in mock_db if query.lower() in job["title"].lower()] or mock_db


# ─── fake() accepts AIMessage values ─────────────────────────────────────────


def test_fake_with_aimessage_value_returns_content():
agent = SimpleAgent()
agent.fake({"*hi*": AIMessage(content="hello there")})

result = agent.prompt("hi friend")
assert result.content == "hello there"
agent.assert_prompted(times=1)


def test_fake_with_aimessage_preserves_tool_calls():
agent = SimpleAgent()
call = ToolCall(name="search_jobs", args={"query": "python"}, id="c1", type="tool_call")
agent.fake({"*job*": AIMessage(content="", tool_calls=[call])})

result = agent.prompt("find me a job")
assert result.content == ""
assert result.tool_calls and result.tool_calls[0]["name"] == "search_jobs"


def test_fake_aimessage_does_not_call_run():
agent = SimpleAgent()
agent.fake({"*": AIMessage(content="faked")})

called = []
agent._run = lambda *a, **kw: called.append(True) # type: ignore[method-assign]

agent.prompt("anything")
assert called == []


# ─── fake_chat_model drives a full create_agent loop offline ──────────────────


def test_fake_chat_model_runs_tool_calling_agent_offline():
model = fake_chat_model(
[
AIMessage(
content="",
tool_calls=[ToolCall(name="search_jobs", args={"query": "python"}, id="c1", type="tool_call")],
),
AIMessage(content="Here is a Python Developer role at Shopify."),
]
)
agent = create_agent(model=model, tools=[search_jobs])

result = agent.invoke({"messages": [{"role": "user", "content": "Find me a python job"}]})
messages = result["messages"]

assert messages[-1].content == "Here is a Python Developer role at Shopify."
assert any(getattr(m, "type", None) == "tool" for m in messages), "the tool should have run"


def test_fake_chat_model_accepts_plain_strings():
model = fake_chat_model(["just a string answer"])
result = model.invoke("hello")
assert result.content == "just a string answer"


def test_agent_fake_model_staticmethod_builds_usable_model():
model = Agent.fake_model([AIMessage(content="from staticmethod")])
assert model.invoke("hi").content == "from staticmethod"
Loading
Loading