diff --git a/pyproject.toml b/pyproject.toml index 6ea339047..c0c5f4a09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ lambda-worker-otel = [ "opentelemetry-semantic-conventions>=0.40b0,<1", "opentelemetry-sdk-extension-aws>=2.0.0,<3", ] +tool-registry = ["anthropic>=0.40.0"] +tool-registry-openai = ["anthropic>=0.40.0", "openai>=1.0.0"] aioboto3 = [ "aioboto3>=10.4.0", "types-aioboto3[s3]>=10.4.0", diff --git a/temporalio/contrib/tool_registry/README.md b/temporalio/contrib/tool_registry/README.md new file mode 100644 index 000000000..5ed0e6ccb --- /dev/null +++ b/temporalio/contrib/tool_registry/README.md @@ -0,0 +1,373 @@ +# temporalio[tool-registry] + +LLM tool-calling primitives for Temporal activities — define tools once, use them with +Anthropic or OpenAI. + +## Before you start + +A Temporal Activity is a function that Temporal monitors and retries automatically on failure. Temporal streams progress between retries via heartbeats — that's the mechanism `agentic_session` uses to resume a crashed LLM conversation mid-turn. + +`run_tool_loop` works standalone in any async function — no Temporal server needed. Add `agentic_session` only when you need crash-safe resume inside a Temporal activity. + +`agentic_session` requires a running Temporal worker — it reads and writes heartbeat state from the active activity context. Use `run_tool_loop` standalone for scripts, one-off jobs, or any code that runs outside a Temporal worker. + +New to Temporal? → https://docs.temporal.io/develop + +## Install + +```bash +pip install "temporalio[tool-registry]" # Anthropic only +pip install "temporalio[tool-registry-openai]" # Anthropic + OpenAI +``` + +## Quickstart + +Tool definitions use [JSON Schema](https://json-schema.org/understanding-json-schema/) for `input_schema`. The quickstart uses a single string field; for richer schemas refer to the JSON Schema docs. 
+ +```python +from temporalio import activity +from temporalio.contrib.tool_registry import ToolRegistry, run_tool_loop + +@activity.defn # Remove for standalone use — no worker needed +async def analyze(prompt: str) -> list[str]: + results: list[str] = [] + tools = ToolRegistry() + + @tools.handler({ + "name": "flag_issue", + "description": "Flag a problem found in the analysis", + "input_schema": { + "type": "object", + "properties": {"description": {"type": "string"}}, + "required": ["description"], + }, + }) + def handle_flag(inp: dict) -> str: + results.append(inp["description"]) + return "recorded" # this string is sent back to the LLM as the tool result + + await run_tool_loop( + provider="anthropic", # reads ANTHROPIC_API_KEY from environment; or use "openai" + system="You are a code reviewer. Call flag_issue for each problem you find.", + prompt=prompt, + tools=tools, + ) + return results +``` + +## Feature matrix + +| Feature | `tool_registry` | `openai_agents` | +|---|---|---| +| Anthropic (claude-*) | ✓ | ✗ | +| OpenAI (gpt-*) | ✓ | ✓ | +| MCP tool wrapping | ✓ | ✓ | +| Crash-safe heartbeat resume | ✓ (via `agentic_session`) | ✗ | +| Agent orchestration (handoffs, etc.) | ✗ | ✓ | + +Use `openai_agents`, `google_adk_agents`, or `langgraph` when you are already building with those frameworks and want each model call to be a separately observable, retryable Temporal activity. +Use `tool_registry` for direct Anthropic support, crash-safe sessions that survive server-side session expiry, or when you need the same implementation pattern across all six Temporal SDKs (Go, Java, Ruby, .NET have no framework-level integrations). + +## Sandbox passthrough + +You need this if you register both workflows and activities on the same `Worker` instance. If your activities run on a dedicated worker (no workflows registered), skip this section. + +The Temporal workflow sandbox blocks third-party imports. 
If your activity +worker runs alongside a sandboxed workflow worker, use `ToolRegistryPlugin`: + +```python +from temporalio.contrib.tool_registry import ToolRegistryPlugin +from temporalio.worker import Worker + +worker = Worker( + client, + task_queue="my-queue", + plugins=[ToolRegistryPlugin(provider="anthropic")], + workflows=[MyWorkflow], + activities=[analyze], +) +``` + +## MCP integration + +MCP tool wrapping is supported via `ToolRegistry.from_mcp_tools()`. See the MCP integration guide for a complete example including server setup. + +### Selecting a model + +The default model is `"claude-sonnet-4-6"` (Anthropic) or `"gpt-4o"` (OpenAI). Pass `model=` to `run_tool_loop`: + +```python +await run_tool_loop( + provider="anthropic", + model="claude-3-5-sonnet-20241022", + system="...", + prompt=prompt, + tools=tools, +) +``` + +Model IDs are defined by the provider — see Anthropic or OpenAI docs for current names. + +### OpenAI + +```python +await run_tool_loop( + provider="openai", # reads OPENAI_API_KEY from environment + system="...", + prompt=prompt, + tools=tools, +) +``` + +## Crash-safe agentic sessions + +For multi-turn LLM conversations that must survive activity retries, use +`agentic_session`. It saves conversation history via `activity.heartbeat()` +on every turn and restores it automatically on retry. 
+ +```python +from temporalio.contrib.tool_registry import ToolRegistry, agentic_session + +@activity.defn +async def long_analysis(prompt: str) -> list[str]: + async with agentic_session() as session: + tools = ToolRegistry() + + @tools.handler({"name": "flag", "description": "...", "input_schema": {"type": "object"}}) + def handle_flag(inp: dict) -> str: + session.results.append(inp) + return "ok" # this string is sent back to the LLM as the tool result + + await session.run_tool_loop( + registry=tools, + provider="anthropic", + system="...", + prompt=prompt, + ) + return session.results +``` + +## Human-in-the-loop tool calls + +A tool handler can block waiting for a human decision before returning a result to the +LLM. Because conversation state is stored in the heartbeat — not in a provider-side +session — the activity can wait hours without losing context. Framework plugins that rely +on API session IDs cannot do this: those sessions expire. + +The pattern: the handler starts a Temporal workflow that notifies a reviewer, then blocks +on its result. The human signals the workflow to approve or reject. If the activity +crashes while waiting, the next retry re-attaches to the same workflow via a deterministic +ID — no duplicate notifications, no lost decisions. + +The rejection reason is returned to the LLM as the tool result. The model can read it and +revise its next proposal accordingly. 
+ +```python +import asyncio +from datetime import timedelta +from temporalio import activity, workflow +from temporalio.client import Client, WorkflowIDConflictPolicy +from temporalio.contrib.tool_registry import ToolRegistry, agentic_session + + +# ── Approval workflow ────────────────────────────────────────────────────────── + +@workflow.defn +class FixApprovalWorkflow: + """Waits for a human to approve or reject a proposed code fix.""" + + def __init__(self) -> None: + self._decision: dict | None = None + + @workflow.run + async def run(self, fix: dict) -> dict: + # Notify reviewer here (Slack, email, etc.) using workflow.execute_activity + await workflow.wait_condition( + lambda: self._decision is not None, + timeout=timedelta(hours=24), + ) + return self._decision or {"approved": False, "reason": "timed out"} + + @workflow.signal + def decide(self, decision: dict) -> None: + self._decision = decision + + +# ── Activity ─────────────────────────────────────────────────────────────────── + +@activity.defn +async def review_and_fix(diff: str) -> list[dict]: + """Review a code diff; each proposed fix requires human sign-off.""" + + async with agentic_session() as session: + tools = ToolRegistry() + + @tools.handler({ + "name": "propose_fix", + "description": ( + "Propose a code fix requiring human approval. " + "Returns 'approved' or 'rejected: '. " + "If rejected, revise your approach using the stated reason." + ), + "input_schema": { + "type": "object", + "properties": { + "file": {"type": "string"}, + "description": {"type": "string"}, + "patch": {"type": "string", "description": "Unified diff to apply"}, + }, + "required": ["file", "description", "patch"], + }, + }) + async def handle_propose_fix(inp: dict) -> str: + client = await Client.connect("localhost:7233") + + # Deterministic ID: crash-retry re-attaches to the existing workflow + # rather than starting a duplicate. 
If the human already decided before + # the crash, handle.result() returns immediately. + wf_id = f"fix-{activity.info().activity_id}-{inp['file']}" + handle = await client.start_workflow( + FixApprovalWorkflow.run, + inp, + id=wf_id, + task_queue="approvals", + id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING, + ) + + # agentic_session heartbeats before each LLM turn, not during tool + # execution. Heartbeat manually here so the activity isn't timed out + # while waiting for a human reviewer. + async def _heartbeat() -> None: + while True: + activity.heartbeat() + await asyncio.sleep(10) + + hb = asyncio.create_task(_heartbeat()) + try: + decision = await handle.result() + finally: + hb.cancel() + + if decision["approved"]: + session.results.append(inp) + return "approved — fix applied" + return f"rejected: {decision.get('reason', 'no reason given')}" + + await session.run_tool_loop( + registry=tools, + provider="anthropic", + system=( + "You are a code reviewer. Propose fixes using propose_fix. " + "If a fix is rejected, revise your approach using the stated reason." 
+ ), + prompt=f"Review this diff and propose fixes:\n\n{diff}", + ) + return session.results +``` + +Reviewer signals approval from any Temporal client: + +```python +handle = client.get_workflow_handle(wf_id) +await handle.signal(FixApprovalWorkflow.decide, {"approved": True}) +# or with a rejection reason the LLM will read: +await handle.signal(FixApprovalWorkflow.decide, {"approved": False, "reason": "scope too broad — fix one thing at a time"}) +``` + +## Testing without an API key + +```python +from temporalio.contrib.tool_registry import ToolRegistry +from temporalio.contrib.tool_registry.testing import MockProvider, ResponseBuilder + +tools = ToolRegistry() + +@tools.handler({"name": "flag", "description": "d", "input_schema": {"type": "object"}}) +def handle_flag(inp: dict) -> str: + return "ok" # this string is sent back to the LLM as the tool result + +provider = MockProvider([ + ResponseBuilder.tool_call("flag", {"description": "stale API"}), + ResponseBuilder.done("done"), +]) +messages = [{"role": "user", "content": "analyze"}] +provider.run_loop(messages, tools) # synchronous +assert len(messages) > 2 +``` + +## Integration testing with real providers + +To run the integration tests against live Anthropic and OpenAI APIs: + +```bash +RUN_INTEGRATION_TESTS=1 \ + ANTHROPIC_API_KEY=sk-ant-... \ + OPENAI_API_KEY=sk-proj-... \ + uv run pytest tests/contrib/tool_registry/ -v +``` + +Tests skip automatically when `RUN_INTEGRATION_TESTS` is unset. Real API calls +incur billing — expect a few cents per full test run. + +## Storing application results + +`session.results` accumulates application-level results during the tool loop. +Elements are serialized to JSON inside each heartbeat checkpoint — they must be +plain maps/dicts with JSON-serializable values. A non-serializable value raises +a non-retryable `ApplicationError` at heartbeat time rather than silently losing +data on the next retry. 
+ +### Storing typed results + +Convert your domain type to a plain dict at the tool-call site and back after +the session: + +```python +import dataclasses + +@dataclasses.dataclass +class Finding: + type: str + file: str + +# Inside tool handler: +session.results.append(dataclasses.asdict(Finding(type="smell", file="foo.py"))) + +# After session: +findings = [Finding(**r) for r in session.results] +``` + +## Per-turn LLM timeout + +Individual LLM calls inside the tool loop are unbounded by default. A hung HTTP +connection holds the activity open until Temporal's `ScheduleToCloseTimeout` +fires — potentially many minutes. Set a per-turn timeout on the provider client: + +```python +import anthropic +client = anthropic.Anthropic(api_key=..., timeout=30.0) +await session.run_tool_loop(..., client=client) +``` + +Recommended timeouts: + +| Model type | Recommended | +|---|---| +| Standard (Claude 3.x, GPT-4o) | 30 s | +| Reasoning (o1, o3, extended thinking) | 300 s | + +### Activity-level timeout + +Set `schedule_to_close_timeout` on the activity options to bound the entire conversation: + +```python +await workflow.execute_activity( + long_analysis, + prompt, + schedule_to_close_timeout=timedelta(seconds=600), +) +``` + +The per-turn client timeout and `schedule_to_close_timeout` are complementary: +- Per-turn timeout fires if one LLM call hangs (protects against a single stuck turn) +- `schedule_to_close_timeout` bounds the entire conversation including all retries (protects against runaway multi-turn loops) diff --git a/temporalio/contrib/tool_registry/__init__.py b/temporalio/contrib/tool_registry/__init__.py new file mode 100644 index 000000000..803f1e34b --- /dev/null +++ b/temporalio/contrib/tool_registry/__init__.py @@ -0,0 +1,56 @@ +"""Support for LLM tool-calling within Temporal activities. + +This package provides :class:`ToolRegistry`, a unified interface for defining +LLM tools once and exporting provider-specific schemas for Anthropic or +OpenAI. 
It also provides :func:`run_tool_loop`, a convenience function for +running a complete multi-turn tool-calling conversation from within a Temporal +activity. + +For crash-safe multi-turn sessions with automatic heartbeat-based checkpoint +and resume, see :mod:`_session` (available after installing with the +``tool-registry`` extra). + +Install:: + + pip install "temporalio[tool-registry]" # Anthropic only + pip install "temporalio[tool-registry-openai]" # Anthropic + OpenAI + +Quickstart:: + + from temporalio import activity + from temporalio.contrib.tool_registry import ToolRegistry, run_tool_loop + + @activity.defn + async def my_llm_activity(prompt: str) -> str: + tools = ToolRegistry() + + @tools.handler({"name": "flag", "description": "Flag an issue", + "input_schema": {"type": "object", + "properties": {"msg": {"type": "string"}}}}) + def handle_flag(inp: dict) -> str: + results.append(inp["msg"]) + return "recorded" + + results: list[str] = [] + await run_tool_loop(provider="anthropic", system="You are ...", + prompt=prompt, tools=tools) + return ", ".join(results) + +See ``README.md`` in this package for full documentation. +""" + +from temporalio.contrib.tool_registry._plugin import ToolRegistryPlugin +from temporalio.contrib.tool_registry._providers import run_tool_loop +from temporalio.contrib.tool_registry._registry import ToolRegistry +from temporalio.contrib.tool_registry._session import AgenticSession, agentic_session + +from . import testing + +__all__ = [ + "AgenticSession", + "ToolRegistry", + "ToolRegistryPlugin", + "agentic_session", + "run_tool_loop", + "testing", +] diff --git a/temporalio/contrib/tool_registry/_plugin.py b/temporalio/contrib/tool_registry/_plugin.py new file mode 100644 index 000000000..a0de170e3 --- /dev/null +++ b/temporalio/contrib/tool_registry/_plugin.py @@ -0,0 +1,86 @@ +"""ToolRegistryPlugin — Temporal plugin for LLM tool-calling activities. 
class ToolRegistryPlugin(SimplePlugin):
    """Worker plugin that lets LLM client libraries through the workflow sandbox.

    When the worker's workflow runner is a :class:`SandboxedWorkflowRunner`,
    the plugin returns a copy of it whose restrictions additionally pass
    through the ``anthropic`` and/or ``openai`` modules. Any other runner is
    returned untouched.

    Args:
        provider: Which client library to pass through: ``"anthropic"``
            (default), ``"openai"``, or ``"both"``.
        anthropic_model: Default Anthropic model name. Stored on the plugin;
            nothing in this class reads it back.
        openai_model: Default OpenAI model name. Stored on the plugin;
            nothing in this class reads it back.

    Example::

        plugin = ToolRegistryPlugin(provider="anthropic")
        worker = Worker(client, task_queue="q", plugins=[plugin], ...)
    """

    def __init__(
        self,
        provider: str = "anthropic",
        anthropic_model: str = "claude-sonnet-4-6",
        openai_model: str = "gpt-4o",
    ) -> None:
        """Record the configuration and install the runner-rewriting hook."""
        self._provider = provider
        self._anthropic_model = anthropic_model
        self._openai_model = openai_model

        # NOTE(review): a provider other than "anthropic" / "openai" / "both"
        # selects no modules, so the runner is silently left unchanged.
        modules = tuple(
            module for module in ("anthropic", "openai") if provider in (module, "both")
        )

        def _with_passthrough(runner: WorkflowRunner | None) -> WorkflowRunner:
            if runner is None:
                raise ValueError("No WorkflowRunner provided to ToolRegistryPlugin.")
            # Only a sandboxed runner carries restrictions that can be widened.
            if not modules or not isinstance(runner, SandboxedWorkflowRunner):
                return runner
            widened = runner.restrictions.with_passthrough_modules(*modules)
            return dataclasses.replace(runner, restrictions=widened)

        super().__init__(
            name="ToolRegistryPlugin",
            workflow_runner=_with_passthrough,
        )
def _blocks_to_dicts(content: Any) -> list[dict[str, Any]]:
    """Convert Anthropic response content into plain, JSON-serialisable dicts.

    The conversation history is stored in heartbeat state, so SDK content
    block objects have to be flattened to ``dict`` before being appended.

    Args:
        content: The ``content`` of a response. Normally a list of blocks;
            anything else is stringified into a single text block.

    Returns:
        One dict per block. Dicts are passed through as-is, objects exposing
        ``model_dump()`` are dumped, other objects fall back to ``vars()``,
        and anything left is wrapped as ``{"type": "text", "text": str(x)}``.
    """
    if not isinstance(content, list):
        return [{"type": "text", "text": str(content)}]

    plain: list[dict[str, Any]] = []
    for block in content:
        if isinstance(block, dict):
            plain.append(block)
        elif hasattr(block, "model_dump"):
            plain.append(block.model_dump())
        elif hasattr(block, "__dict__"):
            plain.append(dict(vars(block)))
        else:
            plain.append({"type": "text", "text": str(block)})
    return plain


class AnthropicProvider:
    """Multi-turn Anthropic tool-calling loop.

    Args:
        registry: Tool registry whose handlers are called on each tool-use block.
        system: System prompt.
        client: A synchronous ``anthropic.Anthropic`` client instance. If
            ``None``, one is constructed from the ``ANTHROPIC_API_KEY``
            environment variable (``KeyError`` if it is unset).
        model: Model name (default: ``claude-sonnet-4-6``).
    """

    def __init__(
        self,
        registry: ToolRegistry,
        system: str,
        client: Any = None,
        model: str = "claude-sonnet-4-6",
    ) -> None:
        """Initialize AnthropicProvider."""
        self._registry = registry
        self._system = system
        self._model = model
        if client is None:
            # Imported lazily so the package works without the optional extra.
            import anthropic

            self._client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
        else:
            self._client = client

    async def run_turn(self, messages: list[dict[str, Any]]) -> bool:
        """Execute one turn of the conversation.

        Appends the assistant response (and any tool results) to *messages*
        in-place. Handler exceptions are not raised; they are reported back
        to the model as a ``tool_result`` with ``is_error`` set.

        Returns:
            ``True`` when the loop should stop (no tool calls, or the model
            reported ``end_turn``), ``False`` to continue.
        """
        import asyncio

        # The client is synchronous: run the HTTP round-trip in a worker
        # thread so it does not stall the event loop (and with it every other
        # task sharing the loop) for the duration of the request.
        response = await asyncio.to_thread(
            self._client.messages.create,
            model=self._model,
            max_tokens=4096,
            system=self._system,
            tools=self._registry.to_anthropic(),  # type: ignore[arg-type]
            messages=messages,  # type: ignore[arg-type]
        )

        assistant_content = _blocks_to_dicts(response.content)
        messages.append({"role": "assistant", "content": assistant_content})

        tool_calls = [b for b in assistant_content if b.get("type") == "tool_use"]
        if not tool_calls or response.stop_reason == "end_turn":
            return True  # done

        tool_results = []
        for call in tool_calls:
            is_error = False
            try:
                result = await self._registry.adispatch(
                    call["name"], call.get("input", {})
                )
            except Exception as e:
                # Surface the failure to the model instead of aborting the loop.
                result = f"error: {e}"
                is_error = True
            entry: dict[str, Any] = {
                "type": "tool_result",
                "tool_use_id": call["id"],
                "content": result,
            }
            if is_error:
                entry["is_error"] = True
            tool_results.append(entry)
        # Anthropic expects all results of one turn in a single user message.
        messages.append({"role": "user", "content": tool_results})
        return False

    async def run_loop(self, messages: list[dict[str, Any]]) -> None:
        """Run turns until the model stops using tools."""
        while not await self.run_turn(messages):
            pass


class OpenAIProvider:
    """Multi-turn OpenAI function-calling loop.

    Args:
        registry: Tool registry whose handlers are called on each function call.
        system: System prompt.
        client: A synchronous ``openai.OpenAI`` client instance. If ``None``,
            one is constructed from the ``OPENAI_API_KEY`` environment
            variable (``KeyError`` if it is unset).
        model: Model name (default: ``gpt-4o``).
    """

    def __init__(
        self,
        registry: ToolRegistry,
        system: str,
        client: Any = None,
        model: str = "gpt-4o",
    ) -> None:
        """Initialize OpenAIProvider."""
        self._registry = registry
        self._system = system
        self._model = model
        if client is None:
            # Imported lazily so the package works without the optional extra.
            import openai

            self._client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        else:
            self._client = client

    async def run_turn(self, messages: list[dict[str, Any]]) -> bool:
        """Execute one turn of the conversation.

        Appends the assistant response (and any tool results) to *messages*
        in-place. Every recorded function call gets exactly one ``tool``
        reply: a handler exception or unparsable JSON arguments produce an
        ``"error: ..."`` reply instead of raising.

        Returns:
            ``True`` when the loop should stop (no function calls, or a
            finish reason of ``stop`` / ``length``), ``False`` to continue.
        """
        import asyncio

        # The system prompt is not stored in the history; prepend it per call.
        full_messages = [{"role": "system", "content": self._system}] + messages

        # Synchronous client: keep the blocking HTTP call off the event loop.
        response = await asyncio.to_thread(
            self._client.chat.completions.create,
            model=self._model,
            tools=self._registry.to_openai(),  # type: ignore[arg-type]
            messages=full_messages,  # type: ignore[arg-type]
        )

        choice = response.choices[0]
        message = choice.message
        # Only function calls can be dispatched; other tool-call types are dropped.
        function_calls = [
            tc for tc in (message.tool_calls or []) if tc.type == "function"
        ]

        # Convert to plain dict for heartbeat-safe storage. The key is omitted
        # when there is nothing to record, so an empty ``tool_calls`` list is
        # never written into the history.
        msg_dict: dict[str, Any] = {"role": "assistant", "content": message.content}
        if function_calls:
            msg_dict["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in function_calls
            ]
        messages.append(msg_dict)

        if not function_calls or choice.finish_reason in ("stop", "length"):
            return True

        for tc in function_calls:
            try:
                # Arguments are model-generated text and may not be valid
                # JSON; parsing inside the try keeps the history consistent.
                args = json.loads(tc.function.arguments or "{}")
                result = await self._registry.adispatch(tc.function.name, args)
            except Exception as e:
                result = f"error: {e}"
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": result,
                }
            )
        return False

    async def run_loop(self, messages: list[dict[str, Any]]) -> None:
        """Run turns until the model stops calling functions."""
        while not await self.run_turn(messages):
            pass
class ToolRegistry:
    """Registry mapping tool names to definitions and handlers.

    Tools are registered in Anthropic's ``tool_use`` JSON format (with
    ``name``, ``description``, and ``input_schema`` keys). The registry
    can then export the same tools for Anthropic or OpenAI providers, and
    dispatch incoming tool calls to the appropriate handler.

    Handlers may be synchronous (``def``) or asynchronous (``async def``).
    Each tool name appears at most once: registering a name again replaces
    both its definition and its handler.
    """

    def __init__(self) -> None:
        """Initialize ToolRegistry with empty definitions and handlers."""
        self._definitions: list[dict[str, Any]] = []
        self._handlers: dict[str, Callable[..., Any]] = {}

    # ── Registration ──────────────────────────────────────────────────────────

    def _register(self, definition: dict[str, Any], fn: Callable[..., Any]) -> None:
        """Store *definition* and *fn*, replacing any entry with the same name."""
        # Read the name first so a malformed definition fails before any
        # state is touched.
        name = definition["name"]
        for index, existing in enumerate(self._definitions):
            if existing.get("name") == name:
                # Keep the original position; providers reject duplicate names.
                self._definitions[index] = definition
                break
        else:
            self._definitions.append(definition)
        self._handlers[name] = fn

    def handler(self, definition: dict[str, Any]) -> Callable:
        """Decorator that registers a handler for the given tool definition.

        Registering a name that already exists (for example a placeholder
        created by :meth:`from_mcp_tools`) replaces the earlier definition
        and handler instead of adding a duplicate.

        Args:
            definition: Tool definition in Anthropic ``tool_use`` format —
                must contain ``name``; ``description`` and ``input_schema``
                are expected by the providers.

        Returns:
            A decorator that registers the wrapped callable and returns it
            unchanged.

        Raises:
            KeyError: When the decorator is applied and ``definition`` has
                no ``name`` key.

        Example::

            @tools.handler({"name": "my_tool", "description": "...",
                            "input_schema": {...}})
            async def handle_my_tool(inp: dict) -> str:
                return await some_async_call(inp)
        """

        def decorator(fn: Callable[..., Any]) -> Callable:
            self._register(definition, fn)
            return fn

        return decorator

    @classmethod
    def from_mcp_tools(cls, tools: list[Any]) -> "ToolRegistry":
        """Build a ``ToolRegistry`` from a list of MCP ``Tool`` objects.

        Each tool contributes its ``name``, ``description`` and
        ``inputSchema`` attributes. The resulting registry has no-op
        handlers (returning an empty string); replace them by registering
        the same name again via :meth:`handler`.

        Args:
            tools: List of ``mcp.Tool`` (or any object with ``name``,
                ``description``, and ``inputSchema`` attributes).

        Returns:
            A new :class:`ToolRegistry` with definitions populated from
            the MCP tools.
        """
        registry = cls()
        for tool in tools:
            definition = {
                "name": tool.name,
                "description": tool.description or "",
                "input_schema": tool.inputSchema
                or {"type": "object", "properties": {}},
            }
            registry._register(definition, lambda _inp: "")
        return registry

    # ── Schema export ─────────────────────────────────────────────────────────

    def to_anthropic(self) -> list[dict[str, Any]]:
        """Return tool definitions in Anthropic ``tool_use`` format.

        The definitions are returned exactly as registered — no conversion
        needed because the registry stores them in Anthropic format.

        Returns:
            A new list holding the registered definition dicts (the dicts
            themselves are shared, not copied).
        """
        return list(self._definitions)

    def to_openai(self) -> list[dict[str, Any]]:
        """Return tool definitions in OpenAI function-calling format.

        Converts each Anthropic-format definition to the OpenAI
        ``{"type": "function", "function": {...}}`` shape, mapping
        ``input_schema`` to ``parameters``. A missing ``description``
        becomes ``""`` and a missing ``input_schema`` becomes an empty
        object schema.

        Returns:
            List of dicts in OpenAI tool format.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": defn["name"],
                    "description": defn.get("description", ""),
                    "parameters": defn.get(
                        "input_schema", {"type": "object", "properties": {}}
                    ),
                },
            }
            for defn in self._definitions
        ]

    # ── Dispatch ──────────────────────────────────────────────────────────────

    def _lookup(self, name: str) -> Callable[..., Any]:
        """Return the handler registered for *name* or raise ``KeyError``."""
        handler = self._handlers.get(name)
        if handler is None:
            raise KeyError(f"Unknown tool: {name!r}")
        return handler

    def dispatch(self, name: str, input_dict: dict[str, Any]) -> str:
        """Call the handler registered for ``name`` with ``input_dict``.

        Synchronous version. Use :meth:`adispatch` for handlers that are
        ``async def`` or otherwise return an awaitable.

        Args:
            name: Tool name as returned by the model.
            input_dict: Parsed tool input as a plain ``dict``.

        Returns:
            String result from the handler.

        Raises:
            KeyError: If no handler is registered for ``name``.
            TypeError: If the handler is async (use :meth:`adispatch` instead).
        """
        handler = self._lookup(name)
        async_message = (
            f"Handler for {name!r} is async — use `await registry.adispatch(...)` "
            "from an async context."
        )
        if inspect.iscoroutinefunction(handler):
            raise TypeError(async_message)
        result = handler(input_dict)
        if inspect.isawaitable(result):
            # A wrapper or callable object handed back an awaitable: it cannot
            # be resolved here, so fail loudly rather than return it.
            if inspect.iscoroutine(result):
                result.close()  # avoids a "never awaited" warning
            raise TypeError(async_message)
        return result

    async def adispatch(self, name: str, input_dict: dict[str, Any]) -> str:
        """Call the handler registered for ``name`` with ``input_dict``.

        Supports both synchronous and asynchronous handlers: the handler is
        called, and if what it returns is awaitable it is awaited. This also
        covers wrappers and callable objects that are not themselves
        ``async def`` functions. This is the method called internally by all
        providers.

        Args:
            name: Tool name as returned by the model.
            input_dict: Parsed tool input as a plain ``dict``.

        Returns:
            String result from the handler.

        Raises:
            KeyError: If no handler is registered for ``name``.
        """
        result = self._lookup(name)(input_dict)
        if inspect.isawaitable(result):
            result = await result
        return result
+ +Example:: + + from temporalio.contrib.tool_registry import ToolRegistry, agentic_session + + @activity.defn + async def analyze(prompt: str) -> list[dict]: + async with agentic_session() as session: + tools = ToolRegistry() + + @tools.handler({"name": "flag", "description": "...", + "input_schema": {"type": "object", + "properties": {"msg": {"type": "string"}}}}) + def handle_flag(inp: dict) -> str: + session.results.append(inp) + return "recorded" + + await session.run_tool_loop( + registry=tools, + provider="anthropic", + system="You are a code reviewer...", + prompt=prompt, + ) + return session.results +""" + +from __future__ import annotations + +import contextlib +import dataclasses +import json +import logging +import os +from collections.abc import AsyncGenerator +from typing import Any + +from temporalio import activity +from temporalio.exceptions import ApplicationError + +_logger = logging.getLogger(__name__) + +from temporalio.contrib.tool_registry._registry import ToolRegistry + + +@dataclasses.dataclass +class AgenticSession: + """Holds conversation state across a multi-turn LLM tool-use loop. + + Instances are created by :func:`agentic_session` and should not normally + be constructed directly. On activity retry, :func:`agentic_session` + deserializes the saved state from :attr:`activity.info().heartbeat_details` + and passes it to the constructor. + + Attributes: + messages: Full conversation history in provider-neutral format. + Appended to by :meth:`run_tool_loop` on each turn. + results: Application-level results accumulated during the session. + Serialized to JSON for checkpoint storage — elements must be + JSON-serializable (plain dicts or dataclass instances). 
+ """ + + messages: list[dict[str, Any]] = dataclasses.field(default_factory=list) + results: list[Any] = dataclasses.field(default_factory=list) + + async def run_tool_loop( + self, + registry: ToolRegistry, + provider: str, + system: str, + prompt: str, + heartbeat_every: int = 1, + model: str | None = None, + client: Any = None, + ) -> None: + """Run the agentic tool-use loop to completion. + + If :attr:`messages` is empty (fresh start), adds ``prompt`` as the + first user message. Otherwise resumes from the existing conversation + state (retry case). + + Checkpoints via :meth:`_checkpoint` before each LLM call. If the + activity is cancelled due to a heartbeat timeout, the next attempt + will restore from the last checkpoint and continue from there. + + Args: + registry: Tool registry with handlers for all tools the model + may call. + provider: LLM provider — ``"anthropic"`` or ``"openai"``. + system: System prompt. + prompt: Initial user message. Ignored on resume (messages already set). + heartbeat_every: Heartbeat every N turns (default 1 = every turn). + Increase to reduce heartbeat overhead for cheap/fast models. + A crash between heartbeats replays from the previous checkpoint. + model: Model name override. If ``None``, the provider default + is used (``claude-sonnet-4-6`` / ``gpt-4o``). + client: Pre-constructed LLM client. Useful in tests to avoid + API key requirements. + + Raises: + ValueError: If ``provider`` is not ``"anthropic"`` or ``"openai"``. 
+ """ + from temporalio.contrib.tool_registry._providers import ( + AnthropicProvider, + OpenAIProvider, + ) + + if not self.messages: + self.messages = [{"role": "user", "content": prompt}] + + kwargs: dict[str, Any] = {} + if model is not None: + kwargs["model"] = model + if client is not None: + kwargs["client"] = client + + if provider == "anthropic": + if client is None: + import anthropic as _anthropic + + kwargs["client"] = _anthropic.Anthropic( + api_key=os.environ["ANTHROPIC_API_KEY"] + ) + p = AnthropicProvider(registry, system, **kwargs) + turn = 0 + while True: + turn += 1 + if (turn - 1) % heartbeat_every == 0: + self._checkpoint() + if await p.run_turn(self.messages): + break + + elif provider == "openai": + if client is None: + import openai as _openai + + kwargs["client"] = _openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + p_oa = OpenAIProvider(registry, system, **kwargs) + turn = 0 + while True: + turn += 1 + if (turn - 1) % heartbeat_every == 0: + self._checkpoint() + if await p_oa.run_turn(self.messages): + break + + else: + raise ValueError( + f"Unknown provider {provider!r}. Use 'anthropic' or 'openai'." + ) + + def _checkpoint(self) -> None: + """Heartbeat the serialized conversation state for crash-safe resume. + + :func:`activity.heartbeat` sends the payload to the Temporal server. + On retry, ``activity.info().heartbeat_details[0]`` contains this JSON. + :func:`agentic_session` reads it on entry and restores messages + + results. + + Results are serialized with :func:`dataclasses.asdict` if they are + dataclass instances, or left as-is if they are already plain dicts. + + Raises: + ApplicationError: (non-retryable) If any result is not JSON-serializable. 
+ """ + + def _serialize_result(result: Any, idx: int) -> dict[str, Any]: + if dataclasses.is_dataclass(result) and not isinstance(result, type): + s = dataclasses.asdict(result) + else: + try: + s = dict(result) + except (TypeError, ValueError) as e: + raise ApplicationError( + f"AgenticSession: results[{idx}] cannot be converted to dict: {e}. " + "Store only plain dicts or dataclass instances.", + non_retryable=True, + ) from e + try: + json.dumps(s) + except (TypeError, ValueError) as e: + raise ApplicationError( + f"AgenticSession: results[{idx}] is not JSON-serializable: {e}. " + "Store only plain dicts or dataclasses with JSON-serializable fields.", + non_retryable=True, + ) from e + return s + + serialized_results = [_serialize_result(r, i) for i, r in enumerate(self.results)] + + activity.heartbeat( + json.dumps( + { + "version": 1, + "messages": self.messages, + "results": serialized_results, + } + ) + ) + + +@contextlib.asynccontextmanager +async def agentic_session() -> AsyncGenerator[AgenticSession, None]: + """Async context manager for a durable, checkpointed LLM tool-use session. + + On entry, restores conversation state (messages + results) from + :attr:`activity.info().heartbeat_details`, if present. This handles the + retry case — the session resumes mid-conversation instead of restarting + from the first turn. + + On exit, the session's final state is available via the yielded + :class:`AgenticSession` object. + + Usage:: + + async with agentic_session() as session: + tools = ToolRegistry() + # ... register handlers ... + await session.run_tool_loop( + registry=tools, + provider="anthropic", + system=SYSTEM, + prompt=prompt, + ) + # session.results contains all results accumulated during the run + + Retry behavior: + A 10-turn LLM analysis that crashes on turn 9 resumes from turn 9. + Token cost on retry is proportional to remaining turns, not total turns. + This matters for long analyses with many tool calls. 
+ + Note: + Logs a warning if heartbeat details are present but cannot be decoded; + the session then starts fresh rather than raising. + """ + details = activity.info().heartbeat_details + saved: dict[str, Any] = {} + if details: + try: + decoded = json.loads(details[0]) + # Valid JSON that is not an object (e.g. "42") has no .get(); treat it + # as undecodable so the documented start-fresh fallback applies. + if not isinstance(decoded, dict): + raise TypeError( + f"expected a JSON object, got {type(decoded).__name__}" + ) + saved = decoded + v = saved.get("version") + if v is None: + _logger.warning( + "AgenticSession: checkpoint has no version field" + " — may be from an older release" + ) + elif v != 1: + _logger.warning( + "AgenticSession: checkpoint version %s, expected 1 — starting fresh", v + ) + saved = {} + except (json.JSONDecodeError, TypeError, IndexError) as e: + _logger.warning( + "AgenticSession: failed to decode checkpoint, starting fresh: %s", e + ) + + session = AgenticSession( + messages=saved.get("messages", []), + results=saved.get("results", []), + ) + yield session diff --git a/temporalio/contrib/tool_registry/testing.py b/temporalio/contrib/tool_registry/testing.py new file mode 100644 index 000000000..1adae7d7e --- /dev/null +++ b/temporalio/contrib/tool_registry/testing.py @@ -0,0 +1,269 @@ +"""Testing utilities for :mod:`temporalio.contrib.tool_registry`. + +Provides mock objects that allow unit tests to exercise ToolRegistry, +AgenticSession, and run_tool_loop without an API key or a running Temporal +server. Also provides :class:`MockAgenticSession` for testing code that +uses :func:`agentic_session` without any LLM calls. + +Example:: + + from temporalio.contrib.tool_registry.testing import MockProvider, ResponseBuilder + + # Script two turns: first turn uses a tool, second turn is done. + provider = MockProvider([ + ResponseBuilder.tool_call("flag_issue", {"description": "wrong"}), + ResponseBuilder.done("Analysis complete."), + ]) + + messages = [{"role": "user", "content": "analyze this"}] + provider.run_loop(messages, registry) # registry must have a "flag_issue" handler + assert len(messages) == 4 # user + assistant + tool_result + assistant
+""" + +from __future__ import annotations + +import uuid +from typing import Any + +from temporalio.contrib.tool_registry._registry import ToolRegistry + + +class ResponseBuilder: + """Factories for scripting mock LLM turn sequences.""" + + @staticmethod + def tool_call( + tool_name: str, + tool_input: dict[str, Any], + call_id: str | None = None, + ) -> dict[str, Any]: + """Create a mock assistant message that makes a single tool call. + + Args: + tool_name: Name of the tool to call. + tool_input: Input dict passed to the tool. + call_id: Optional tool-use ID. Auto-generated if not provided. + + Returns: + A dict in Anthropic-style assistant content format. + """ + cid = call_id or f"test_{uuid.uuid4().hex[:8]}" + return { + "_mock_stop": False, + "content": [ + {"type": "tool_use", "id": cid, "name": tool_name, "input": tool_input} + ], + } + + @staticmethod + def done(text: str = "Done.") -> dict[str, Any]: + """Create a mock assistant message that ends the loop. + + Args: + text: Assistant text response. + + Returns: + A dict indicating the loop should stop. + """ + return { + "_mock_stop": True, + "content": [{"type": "text", "text": text}], + } + + +class MockProvider: + """LLM provider that returns pre-scripted responses without API calls. + + Responses are consumed in order. Once exhausted, the loop stops. + + Args: + responses: Sequence of response dicts produced by :class:`ResponseBuilder`. + + Example:: + + provider = MockProvider([ + ResponseBuilder.tool_call("greet", {"name": "world"}), + ResponseBuilder.done("said hello"), + ]) + messages = [{"role": "user", "content": "greet world"}] + provider.run_loop(messages, registry) + """ + + def __init__(self, responses: list[dict[str, Any]]) -> None: + """Initialize MockProvider with a list of scripted responses.""" + self._responses = list(responses) + self._index = 0 + + def run_turn(self, messages: list[dict[str, Any]], registry: ToolRegistry) -> bool: + """Execute one scripted turn. 
+ + Returns: + ``True`` when done, ``False`` to continue. + """ + if self._index >= len(self._responses): + return True + + response = self._responses[self._index] + self._index += 1 + stop = response.get("_mock_stop", True) + content = response.get("content", []) + + messages.append({"role": "assistant", "content": content}) + + if not stop: + tool_results = [] + for block in content: + if block.get("type") == "tool_use": + result = registry.dispatch(block["name"], block.get("input", {})) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block["id"], + "content": result, + } + ) + if tool_results: + messages.append({"role": "user", "content": tool_results}) + + return stop + + def run_loop( + self, + messages: list[dict[str, Any]], + registry: ToolRegistry | None = None, + ) -> None: + """Run all scripted turns until exhausted or a done response is reached.""" + if registry is None: + registry = ToolRegistry() + while not self.run_turn(messages, registry): + pass + + +class FakeToolRegistry(ToolRegistry): + """A :class:`ToolRegistry` that records all dispatch calls. + + Useful for asserting which tools were called and with what inputs during a + test run. + + Example:: + + fake = FakeToolRegistry() + fake.handler({"name": "greet", "description": "d", "input_schema": {}})( + lambda inp: "ok" + ) + fake.dispatch("greet", {"name": "world"}) + assert fake.calls == [("greet", {"name": "world"})] + """ + + def __init__(self) -> None: + """Initialize FakeToolRegistry with an empty calls list.""" + super().__init__() + self.calls: list[tuple[str, dict[str, Any]]] = [] + + def dispatch(self, name: str, input_dict: dict[str, Any]) -> str: + """Record the call then delegate to the registered handler.""" + self.calls.append((name, input_dict)) + return super().dispatch(name, input_dict) + + async def adispatch(self, name: str, input_dict: dict[str, Any]) -> str: + """Record the call then delegate to the registered handler (async path). + + The base ``adispatch`` looks the handler up directly instead of going + through :meth:`dispatch`, so without this override async-path calls + would not appear in ``calls``. + """ + self.calls.append((name, input_dict)) + return await super().adispatch(name, input_dict) + + +class MockAgenticSession: + """A pre-canned :class:`AgenticSession` that returns fixed results without LLM calls.
+ + Use this to test code that calls :func:`agentic_session` and inspects + ``session.results`` without needing an API key or a running server. + + Args: + results: Pre-canned results to populate ``session.results`` immediately + on construction. + + Example:: + + from temporalio.contrib.tool_registry.testing import MockAgenticSession + from contextlib import asynccontextmanager + + async def run_with_mock(prompt: str) -> list: + session = MockAgenticSession([{"type": "missing", "symbol": "x", "description": "gone"}]) + # Simulate the agentic loop completing without LLM calls + return session.results + """ + + def __init__(self, results: list[Any] | None = None) -> None: + """Initialize MockAgenticSession with optional pre-seeded results.""" + self.messages: list[dict[str, Any]] = [] + self.results: list[Any] = list(results or []) + + async def run_tool_loop( + self, + registry: "ToolRegistry", + provider: str, + system: str, + prompt: str, + heartbeat_every: int = 1, + model: str | None = None, + client: Any = None, + ) -> None: + """No-op stand-in for :meth:`AgenticSession.run_tool_loop`. + + Takes the same parameters in the same order as the real method + (including ``heartbeat_every``) so code under test can call either + one. All of them except ``prompt`` are ignored: no LLM is called and + no checkpoint is written. ``prompt`` seeds ``messages`` only when the + history is empty. + """ + if not self.messages: + self.messages = [{"role": "user", "content": prompt}] + # No LLM calls; session.results already set by constructor. + + def _checkpoint(self) -> None: + """No-op — does not call activity.heartbeat() in tests.""" + + +class CrashAfterTurns: + """Simulates an activity crash after ``n`` turns. + + Use this in integration tests to verify that :class:`AgenticSession` + correctly resumes from a checkpoint after a crash. + + Args: + n: Number of turns to complete before raising :exc:`RuntimeError`. + + Example:: + + # First invocation crashes after turn 2. + # Second invocation (retry) should resume from turn 2's checkpoint.
+ crasher = CrashAfterTurns(2) + """ + + def __init__(self, n: int) -> None: + """Initialize CrashAfterTurns to simulate a crash after n turns.""" + self._n = n + self._count = 0 + + def run_turn(self, messages: list[dict[str, Any]], registry: ToolRegistry) -> bool: + """Execute one turn, raising RuntimeError after n turns are complete. + + Args: + messages: Conversation message history, mutated in-place. + registry: Tool registry (unused — crash-only provider). + + Returns: + ``True`` once the ``n``-th (last allowed) turn completes, ``False`` for earlier turns. + """ + self._count += 1 + if self._count > self._n: + raise RuntimeError( + f"CrashAfterTurns: simulated crash after {self._n} turns" + ) + # Produce a no-op done response + messages.append( + {"role": "assistant", "content": [{"type": "text", "text": "..."}]} + ) + return self._count >= self._n + + def run_loop( + self, + messages: list[dict[str, Any]], + registry: ToolRegistry | None = None, + ) -> None: + """Run turns until n turns complete or a crash is triggered. + + Args: + messages: Conversation message history, mutated in-place. + registry: Optional tool registry; a default empty one is used if None. + """ + if registry is None: + registry = ToolRegistry() + while not self.run_turn(messages, registry): + pass diff --git a/tests/contrib/tool_registry/__init__.py b/tests/contrib/tool_registry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contrib/tool_registry/test_agentic_session.py b/tests/contrib/tool_registry/test_agentic_session.py new file mode 100644 index 000000000..1ae4e2208 --- /dev/null +++ b/tests/contrib/tool_registry/test_agentic_session.py @@ -0,0 +1,430 @@ +"""Unit tests for AgenticSession and agentic_session. + +Tests run without an API key or Temporal server. activity.info() and +activity.heartbeat() are mocked to avoid needing a running worker.
+""" + +from __future__ import annotations + +import dataclasses +import json +from unittest.mock import MagicMock, patch + +import pytest + +from temporalio.contrib.tool_registry import ( + AgenticSession, + ToolRegistry, + agentic_session, +) +from temporalio.contrib.tool_registry.testing import ( + MockAgenticSession, +) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +def _make_activity_info(heartbeat_details=None): + info = MagicMock() + info.heartbeat_details = heartbeat_details or [] + return info + + +# ── agentic_session context manager tests ──────────────────────────────────── + + +async def test_fresh_start_empty_state(): + """No heartbeat details → session starts with empty messages and results.""" + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.info.return_value = _make_activity_info() + async with agentic_session() as session: + assert session.messages == [] + assert session.results == [] + + +async def test_restores_from_checkpoint(): + """Heartbeat details present → session restores messages and results on retry.""" + saved = { + "version": 1, + "messages": [{"role": "user", "content": "analyze this"}], + "results": [{"type": "missing", "symbol": "patched", "description": "removed"}], + } + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.info.return_value = _make_activity_info( + heartbeat_details=[json.dumps(saved)] + ) + async with agentic_session() as session: + assert session.messages == saved["messages"] + assert len(session.results) == 1 + assert session.results[0]["type"] == "missing" + assert session.results[0]["symbol"] == "patched" + + +async def test_ignores_invalid_checkpoint(): + """Corrupted heartbeat details → session starts fresh (no crash).""" + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.info.return_value = _make_activity_info( + 
heartbeat_details=["not valid json{{"] + ) + async with agentic_session() as session: + assert session.messages == [] + assert session.results == [] + + +# ── AgenticSession._checkpoint tests ───────────────────────────────────────── + + +def test_checkpoint_serializes_messages_and_results(): + """_checkpoint() heartbeats JSON with messages + results.""" + heartbeat_calls = [] + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.heartbeat.side_effect = lambda s: heartbeat_calls.append( + json.loads(s) + ) + + @dataclasses.dataclass + class Result: + type: str + symbol: str + description: str + + session = AgenticSession( + messages=[{"role": "user", "content": "hi"}], + results=[Result(type="missing", symbol="x", description="gone")], + ) + session._checkpoint() + + assert len(heartbeat_calls) == 1 + payload = heartbeat_calls[0] + assert payload["messages"] == [{"role": "user", "content": "hi"}] + assert len(payload["results"]) == 1 + assert payload["results"][0]["type"] == "missing" + + +def test_checkpoint_empty_state(): + """_checkpoint() with no messages/results produces valid empty JSON.""" + heartbeat_calls = [] + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.heartbeat.side_effect = lambda s: heartbeat_calls.append( + json.loads(s) + ) + AgenticSession()._checkpoint() + + assert heartbeat_calls[0] == {"version": 1, "messages": [], "results": []} + + +def test_checkpoint_plain_dict_results(): + """_checkpoint() handles plain dict results (not just dataclasses).""" + heartbeat_calls = [] + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.heartbeat.side_effect = lambda s: heartbeat_calls.append( + json.loads(s) + ) + session = AgenticSession( + results=[ + {"type": "deprecated", "symbol": "old_api", "description": "use new"} + ] + ) + session._checkpoint() + + assert heartbeat_calls[0]["results"][0]["symbol"] == 
"old_api" + + +# ── AgenticSession.run_tool_loop tests ──────────────────────────────────────── + +_ENV = {"ANTHROPIC_API_KEY": "sk-ant-test"} + + +def _make_mock_anthropic_client( + responses: list[bool], + tool_name: str = "noop", +) -> MagicMock: + """Build a mock Anthropic client. + + Args: + responses: list of bools — True = done (end_turn, no tools), + False = not done (returns a tool_use block to continue). + tool_name: Tool name for tool_use blocks when not done. + """ + client = MagicMock() + mock_responses = [] + for done in responses: + msg = MagicMock() + if done: + msg.content = [MagicMock(type="text", text="done")] + msg.stop_reason = "end_turn" + else: + # Return a tool_use block so run_turn continues + tool_block = MagicMock() + tool_block.type = "tool_use" + tool_block.id = "test_id" + tool_block.name = tool_name + tool_block.input = {} + # model_dump needed for _blocks_to_dicts + tool_block.model_dump.return_value = { + "type": "tool_use", + "id": "test_id", + "name": tool_name, + "input": {}, + } + msg.content = [tool_block] + msg.stop_reason = "tool_use" + mock_responses.append(msg) + client.messages.create.side_effect = mock_responses + return client + + +async def test_run_tool_loop_adds_prompt_on_fresh_start(): + """On fresh start, run_tool_loop adds the user prompt as the first message.""" + import os + + session = AgenticSession() + mock_client = _make_mock_anthropic_client([True]) # done on first turn + + with ( + patch("temporalio.contrib.tool_registry._session.activity") as mock_activity, + patch.dict(os.environ, _ENV), + ): + mock_activity.heartbeat = MagicMock() + await session.run_tool_loop( + registry=ToolRegistry(), + provider="anthropic", + system="system", + prompt="my prompt", + client=mock_client, + ) + + assert session.messages[0] == {"role": "user", "content": "my prompt"} + + +async def test_run_tool_loop_skips_prompt_on_resume(): + """On retry with existing messages, the prompt is not added again.""" + import os + + existing 
= [ + {"role": "user", "content": "original prompt"}, + {"role": "assistant", "content": [{"type": "text", "text": "ok"}]}, + ] + session = AgenticSession(messages=list(existing)) + mock_client = _make_mock_anthropic_client([True]) # done on first turn + + with ( + patch("temporalio.contrib.tool_registry._session.activity") as mock_activity, + patch.dict(os.environ, _ENV), + ): + mock_activity.heartbeat = MagicMock() + await session.run_tool_loop( + registry=ToolRegistry(), + provider="anthropic", + system="system", + prompt="new prompt that should be ignored", + client=mock_client, + ) + + # First two messages unchanged + assert session.messages[:2] == existing + + +async def test_run_tool_loop_checkpoints_each_turn(): + """_checkpoint is called once per turn before the LLM call.""" + import os + + session = AgenticSession(messages=[{"role": "user", "content": "go"}]) + checkpoint_count = [0] + # Script 3 turns: first 2 return not-done (tool_use), third returns done + registry = ToolRegistry() + registry.handler({"name": "noop", "description": "d", "input_schema": {}})( + lambda _: "ok" + ) + mock_client = _make_mock_anthropic_client([False, False, True], tool_name="noop") + + def counting_checkpoint(): + checkpoint_count[0] += 1 + + with ( + patch("temporalio.contrib.tool_registry._session.activity") as mock_activity, + patch.object(session, "_checkpoint", side_effect=counting_checkpoint), + patch.dict(os.environ, _ENV), + ): + mock_activity.heartbeat = MagicMock() + await session.run_tool_loop( + registry=registry, + provider="anthropic", + system="s", + prompt="ignored", + client=mock_client, + ) + + assert checkpoint_count[0] == 3 + + +async def test_run_tool_loop_unknown_provider_raises(): + """Unknown provider raises ValueError.""" + session = AgenticSession(messages=[{"role": "user", "content": "x"}]) + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.heartbeat = MagicMock() + with pytest.raises(ValueError, 
match="Unknown provider"): + await session.run_tool_loop( + registry=ToolRegistry(), + provider="gemini", + system="s", + prompt="p", + ) + + +# ── Checkpoint round-trip test (T6) ────────────────────────────────────────── + + +def test_checkpoint_round_trip_preserves_tool_calls(): + """Round-trip: checkpoint with tool_calls serializes/deserializes correctly. + + Catches the class of bug where nested dicts lose their type after a + JSON serialize→deserialize cycle (mirrors the .NET List bug). + """ + tool_calls_in_memory = [ + { + "id": "call_abc", + "type": "function", + "function": {"name": "my_tool", "arguments": '{"x": 1}'}, + } + ] + assistant_msg = {"role": "assistant", "tool_calls": tool_calls_in_memory} + result = {"type": "smell", "file": "foo.py"} + + session = AgenticSession(messages=[assistant_msg], results=[result]) + heartbeat_calls: list[str] = [] + + with patch("temporalio.contrib.tool_registry._session.activity") as mock_activity: + mock_activity.heartbeat.side_effect = lambda s: heartbeat_calls.append(s) + session._checkpoint() + + assert len(heartbeat_calls) == 1 + restored = json.loads(heartbeat_calls[0]) + + assert restored["messages"][0]["role"] == "assistant" + tool_calls_restored = restored["messages"][0]["tool_calls"] + assert isinstance(tool_calls_restored, list) + assert len(tool_calls_restored) == 1 + assert tool_calls_restored[0]["id"] == "call_abc" + assert tool_calls_restored[0]["function"]["name"] == "my_tool" + assert restored["results"][0]["type"] == "smell" + assert restored["results"][0]["file"] == "foo.py" + + +# ── heartbeat_every tests (T7) ──────────────────────────────────────────────── + + +async def test_heartbeat_every_default_checkpoints_each_turn(): + """heartbeat_every=1 (default) checkpoints before every turn.""" + import os + + session = AgenticSession(messages=[{"role": "user", "content": "go"}]) + registry = ToolRegistry() + registry.handler({"name": "noop", "description": "d", "input_schema": {}})( + lambda _: 
"ok" + ) + mock_client = _make_mock_anthropic_client([False, False, True], tool_name="noop") + checkpoint_count = [0] + + def counting_checkpoint(): + checkpoint_count[0] += 1 + + with ( + patch("temporalio.contrib.tool_registry._session.activity") as mock_activity, + patch.object(session, "_checkpoint", side_effect=counting_checkpoint), + patch.dict(os.environ, _ENV), + ): + mock_activity.heartbeat = MagicMock() + await session.run_tool_loop( + registry=registry, + provider="anthropic", + system="s", + prompt="ignored", + heartbeat_every=1, + client=mock_client, + ) + + assert checkpoint_count[0] == 3 # one checkpoint per turn + + +async def test_heartbeat_every_n_skips_turns(): + """heartbeat_every=3 checkpoints on turns 1, 4, 7, ...""" + import os + + session = AgenticSession(messages=[{"role": "user", "content": "go"}]) + registry = ToolRegistry() + registry.handler({"name": "noop", "description": "d", "input_schema": {}})( + lambda _: "ok" + ) + # 4 turns: [tool, tool, tool, done] + mock_client = _make_mock_anthropic_client( + [False, False, False, True], tool_name="noop" + ) + checkpoint_count = [0] + + def counting_checkpoint(): + checkpoint_count[0] += 1 + + with ( + patch("temporalio.contrib.tool_registry._session.activity") as mock_activity, + patch.object(session, "_checkpoint", side_effect=counting_checkpoint), + patch.dict(os.environ, _ENV), + ): + mock_activity.heartbeat = MagicMock() + await session.run_tool_loop( + registry=registry, + provider="anthropic", + system="s", + prompt="ignored", + heartbeat_every=3, + client=mock_client, + ) + + # 4 turns, heartbeat_every=3 → checkpoints on turns 1 and 4 → 2 checkpoints + assert checkpoint_count[0] == 2 + + +# ── MockAgenticSession tests ────────────────────────────────────────────────── + + +async def test_mock_agentic_session_returns_pre_canned_results(): + """MockAgenticSession returns fixed results without LLM calls.""" + session = MockAgenticSession( + results=[{"type": "deprecated", "symbol": 
"old_fn", "description": "removed"}] + ) + await session.run_tool_loop( + registry=ToolRegistry(), + provider="anthropic", + system="s", + prompt="p", + ) + assert len(session.results) == 1 + assert session.results[0]["symbol"] == "old_fn" + + +async def test_mock_agentic_session_empty_results(): + """MockAgenticSession with no results starts empty.""" + session = MockAgenticSession() + assert session.results == [] + + +# ── Integration test (skipped unless RUN_INTEGRATION_TESTS=true) ───────────── + + +@pytest.mark.skipif( + not __import__("os").environ.get("RUN_INTEGRATION_TESTS"), + reason="RUN_INTEGRATION_TESTS not set", +) +async def test_crash_resume(): + """Integration: activity crashes mid-loop; second attempt resumes from checkpoint. + + Uses WorkflowEnvironment to run a real Temporal worker. The first activity + attempt crashes after 2 turns; the second attempt should restore from the + turn-2 checkpoint and complete from there, not from turn 0. + """ + # This test requires a running Temporal server (temporal server start-dev) + # and would use WorkflowEnvironment.start_local() to spin up a test server. + # Omitted here to keep the test file self-contained; see the project README + # for instructions on running the full integration suite. + pytest.skip("Full integration test requires WorkflowEnvironment setup — see README") diff --git a/tests/contrib/tool_registry/test_tool_registry.py b/tests/contrib/tool_registry/test_tool_registry.py new file mode 100644 index 000000000..9be464c32 --- /dev/null +++ b/tests/contrib/tool_registry/test_tool_registry.py @@ -0,0 +1,453 @@ +"""Unit tests for ToolRegistry. + +Tests run without an API key or Temporal server. LLM calls are replaced by +:class:`MockProvider` from ``testing.py``. 
+""" + +from __future__ import annotations + +import asyncio + +import pytest + +from temporalio.contrib.tool_registry import ToolRegistry, run_tool_loop +from temporalio.contrib.tool_registry.testing import ( + CrashAfterTurns, + FakeToolRegistry, + MockProvider, + ResponseBuilder, +) + +# ── ToolRegistry unit tests ─────────────────────────────────────────────────── + + +def test_dispatch_calls_handler(): + registry = ToolRegistry() + + @registry.handler({"name": "greet", "description": "d", "input_schema": {}}) + def handle_greet(inp: dict) -> str: + return f"hello {inp.get('name')}" + + assert registry.dispatch("greet", {"name": "world"}) == "hello world" + + +def test_dispatch_unknown_raises(): + registry = ToolRegistry() + with pytest.raises(KeyError, match="unknown"): + registry.dispatch("unknown", {}) + + +def test_adispatch_sync_handler(): + """adispatch works with a plain def handler.""" + registry = ToolRegistry() + + @registry.handler({"name": "greet", "description": "d", "input_schema": {}}) + def handle_greet(inp: dict) -> str: + return f"hello {inp.get('name')}" + + result = asyncio.run(registry.adispatch("greet", {"name": "world"})) + assert result == "hello world" + + +def test_adispatch_async_handler(): + """adispatch awaits async handlers.""" + registry = ToolRegistry() + + @registry.handler({"name": "async_greet", "description": "d", "input_schema": {}}) + async def handle_async_greet(inp: dict) -> str: + return f"async hello {inp.get('name')}" + + result = asyncio.run(registry.adispatch("async_greet", {"name": "world"})) + assert result == "async hello world" + + +def test_dispatch_async_handler_raises_typeerror(): + """dispatch() on an async handler raises TypeError — use adispatch instead.""" + registry = ToolRegistry() + + @registry.handler({"name": "async_tool", "description": "d", "input_schema": {}}) + async def handle(inp: dict) -> str: + return "async result" + + with pytest.raises(TypeError, match="adispatch"): + 
registry.dispatch("async_tool", {}) + + +def test_to_openai_format(): + registry = ToolRegistry() + + @registry.handler( + { + "name": "my_tool", + "description": "Does something useful.", + "input_schema": { + "type": "object", + "properties": {"arg": {"type": "string"}}, + "required": ["arg"], + }, + } + ) + def handle(inp: dict) -> str: + return "ok" + + result = registry.to_openai() + assert len(result) == 1 + converted = result[0] + assert converted["type"] == "function" + assert converted["function"]["name"] == "my_tool" + assert converted["function"]["description"] == "Does something useful." + assert "arg" in converted["function"]["parameters"]["properties"] + + +def test_to_anthropic_returns_definitions_unchanged(): + defn = {"name": "t", "description": "d", "input_schema": {}} + registry = ToolRegistry() + registry.handler(defn)(lambda inp: "ok") + assert registry.to_anthropic() == [defn] + + +def test_multiple_tools(): + registry = ToolRegistry() + registry.handler({"name": "alpha", "description": "a", "input_schema": {}})( + lambda _: "a" + ) + registry.handler({"name": "beta", "description": "b", "input_schema": {}})( + lambda _: "b" + ) + result = registry.to_openai() + assert len(result) == 2 + assert result[0]["function"]["name"] == "alpha" + assert result[1]["function"]["name"] == "beta" + + +def test_from_mcp_tools(): + """from_mcp_tools wraps MCP-style objects into definitions.""" + + class FakeMCPTool: + def __init__(self, name: str, description: str, schema: dict): + self.name = name + self.description = description + self.inputSchema = schema + + mcp_tools = [ + FakeMCPTool("search", "Search files", {"type": "object", "properties": {}}), + FakeMCPTool("read", "Read a file", {"type": "object", "properties": {}}), + ] + + registry = ToolRegistry.from_mcp_tools(mcp_tools) + assert len(registry.to_anthropic()) == 2 + names = {d["name"] for d in registry.to_anthropic()} + assert names == {"search", "read"} + + +# ── MockProvider tests 
──────────────────────────────────────────────────────── + + +def test_mock_provider_dispatches_tool_calls(): + """MockProvider dispatches tool calls and runs the loop to completion.""" + collected: list[str] = [] + registry = ToolRegistry() + + @registry.handler({"name": "collect", "description": "d", "input_schema": {}}) + def handle_collect(inp: dict) -> str: + collected.append(inp.get("value", "")) + return "ok" + + provider = MockProvider( + [ + ResponseBuilder.tool_call("collect", {"value": "first"}), + ResponseBuilder.tool_call("collect", {"value": "second"}), + ResponseBuilder.done("all done"), + ] + ) + messages: list[dict] = [{"role": "user", "content": "go"}] + provider.run_loop(messages, registry) + + assert collected == ["first", "second"] + + +def test_mock_provider_stops_on_done(): + provider = MockProvider([ResponseBuilder.done("finished")]) + messages: list[dict] = [{"role": "user", "content": "x"}] + provider.run_loop(messages) + # One user message + one assistant message + assert len(messages) == 2 + assert messages[-1]["role"] == "assistant" + + +def test_mock_provider_stops_when_exhausted(): + """If responses are exhausted, run_loop stops cleanly.""" + provider = MockProvider([]) + messages: list[dict] = [{"role": "user", "content": "x"}] + provider.run_loop(messages) + assert len(messages) == 1 # nothing added + + +# ── FakeToolRegistry tests ──────────────────────────────────────────────────── + + +def test_fake_registry_records_calls(): + fake = FakeToolRegistry() + + @fake.handler({"name": "greet", "description": "d", "input_schema": {}}) + def handle(inp: dict) -> str: + return "ok" + + fake.dispatch("greet", {"name": "world"}) + fake.dispatch("greet", {"name": "temporal"}) + + assert fake.calls == [("greet", {"name": "world"}), ("greet", {"name": "temporal"})] + + +# ── run_tool_loop tests ─────────────────────────────────────────────────────── + + +def test_run_tool_loop_unknown_provider_raises(): + async def _run(): + await 
run_tool_loop( + provider="gemini", + system="s", + prompt="p", + tools=ToolRegistry(), + ) + + with pytest.raises(ValueError, match="gemini"): + asyncio.run(_run()) + + +# ── CrashAfterTurns tests ───────────────────────────────────────────────────── + + +def test_crash_after_turns_raises(): + crasher = CrashAfterTurns(1) + messages: list[dict] = [{"role": "user", "content": "x"}] + # First turn: fine + crasher.run_turn(messages, ToolRegistry()) + # Second turn: crashes + with pytest.raises(RuntimeError, match="simulated crash"): + crasher.run_turn(messages, ToolRegistry()) + + +# ── is_error / handler error tests ─────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_anthropic_handler_error_sets_is_error_and_does_not_crash(): + """Handler exceptions are caught; the tool result carries is_error=True.""" + from temporalio.contrib.tool_registry._providers import AnthropicProvider + + registry = ToolRegistry() + + @registry.handler({"name": "boom", "description": "d", "input_schema": {}}) + def handle(inp: dict) -> str: + raise ValueError("intentional failure") + + # Minimal Anthropic mock: first call returns a tool_use, second returns end_turn. + calls: list[int] = [] + + class _MockMessages: + def create(self, **_kwargs): # type: ignore[override] + calls.append(1) + if len(calls) == 1: + return _FakeResponse( + content=[{"type": "tool_use", "id": "c1", "name": "boom", "input": {}}], + stop_reason="tool_use", + ) + return _FakeResponse(content=[{"type": "text", "text": "done"}], stop_reason="end_turn") + + class _FakeClient: + messages = _MockMessages() + + class _FakeResponse: + def __init__(self, content, stop_reason): + self.content = content + self.stop_reason = stop_reason + + provider = AnthropicProvider(registry, "sys", client=_FakeClient()) + messages: list[dict] = [{"role": "user", "content": "go"}] + await provider.run_turn(messages) + + # messages[1] is the assistant message; messages[2] is the tool result wrapper. 
+ tool_result_msg = messages[2] + assert tool_result_msg["role"] == "user" + tool_result = tool_result_msg["content"][0] + assert tool_result["type"] == "tool_result" + assert tool_result["is_error"] is True + assert "intentional failure" in tool_result["content"] + + +@pytest.mark.asyncio +async def test_async_handler_invoked_via_adispatch(): + """Async handlers are awaited by providers via adispatch.""" + from temporalio.contrib.tool_registry._providers import AnthropicProvider + + registry = ToolRegistry() + invocations: list[str] = [] + + @registry.handler({"name": "async_tool", "description": "d", "input_schema": {}}) + async def handle(inp: dict) -> str: + invocations.append("called") + return "async result" + + calls: list[int] = [] + + class _MockMessages: + def create(self, **_kwargs): + calls.append(1) + if len(calls) == 1: + return _FakeResponse( + content=[{"type": "tool_use", "id": "c1", "name": "async_tool", "input": {}}], + stop_reason="tool_use", + ) + return _FakeResponse(content=[{"type": "text", "text": "done"}], stop_reason="end_turn") + + class _FakeClient: + messages = _MockMessages() + + class _FakeResponse: + def __init__(self, content, stop_reason): + self.content = content + self.stop_reason = stop_reason + + provider = AnthropicProvider(registry, "sys", client=_FakeClient()) + messages: list[dict] = [{"role": "user", "content": "go"}] + await provider.run_turn(messages) + + assert invocations == ["called"] + # messages[1] is the assistant message; messages[2] is the tool result wrapper. 
+ tool_result_msg = messages[2] + tool_result = tool_result_msg["content"][0] + assert tool_result["content"] == "async result" + assert "is_error" not in tool_result + + +@pytest.mark.asyncio +async def test_openai_handler_error_does_not_crash(): + """Handler exceptions in OpenAI provider are caught and returned as error strings.""" + from temporalio.contrib.tool_registry._providers import OpenAIProvider + + registry = ToolRegistry() + + @registry.handler({"name": "boom", "description": "d", "input_schema": {}}) + def handle(inp: dict) -> str: + raise RuntimeError("openai error test") + + class _MockCompletions: + def create(self, **_kwargs): # type: ignore[override] + return _FakeResp() + + class _FakeClient: + class chat: + completions = _MockCompletions() + + class _FakeTc: + id = "tc1" + type = "function" + + class function: + name = "boom" + arguments = "{}" + + class _FakeMsg: + content = None + tool_calls = [_FakeTc()] + + class _FakeChoice: + message = _FakeMsg() + finish_reason = "tool_calls" + + class _FakeResp: + choices = [_FakeChoice()] + + provider = OpenAIProvider(registry, "sys", client=_FakeClient()) + messages: list[dict] = [{"role": "user", "content": "go"}] + # Should not raise even though the handler throws. 
+ try: + await provider.run_turn(messages) + except Exception as e: + pytest.fail(f"run_turn raised unexpectedly: {e}") + + tool_msg = messages[-1] + assert tool_msg["role"] == "tool" + assert "openai error test" in tool_msg["content"] + + +# ── Integration test (skipped unless RUN_INTEGRATION_TESTS=true) ───────────── + + +@pytest.mark.skipif( + not __import__("os").environ.get("RUN_INTEGRATION_TESTS"), + reason="RUN_INTEGRATION_TESTS not set", +) +@pytest.mark.asyncio +async def test_integration_anthropic_real_call(): + """End-to-end: run_tool_loop with real Anthropic API call.""" + import os + + assert os.environ.get("ANTHROPIC_API_KEY"), "ANTHROPIC_API_KEY required" + + collected: list[str] = [] + tools = ToolRegistry() + + @tools.handler( + { + "name": "record", + "description": "Record a value", + "input_schema": { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + }, + } + ) + def handle_record(inp: dict) -> str: + collected.append(inp["value"]) + return "recorded" + + await run_tool_loop( + provider="anthropic", + system="You must call record() exactly once with value='hello'.", + prompt="Please call the record tool with value='hello'.", + tools=tools, + ) + + assert "hello" in collected + + +@pytest.mark.skipif( + not __import__("os").environ.get("RUN_INTEGRATION_TESTS"), + reason="RUN_INTEGRATION_TESTS not set", +) +@pytest.mark.asyncio +async def test_integration_openai_real_call(): + """End-to-end: run_tool_loop with real OpenAI API call.""" + import os + + assert os.environ.get("OPENAI_API_KEY"), "OPENAI_API_KEY required" + + collected: list[str] = [] + tools = ToolRegistry() + + @tools.handler( + { + "name": "record", + "description": "Record a value", + "input_schema": { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + }, + } + ) + def handle_record(inp: dict) -> str: + collected.append(inp["value"]) + return "recorded" + + await run_tool_loop( + 
provider="openai", + system="You must call record() exactly once with value='hello'.", + prompt="Please call the record tool with value='hello'.", + tools=tools, + ) + + assert "hello" in collected diff --git a/uv.lock b/uv.lock index 6d824cf92..ea336cfa2 100644 --- a/uv.lock +++ b/uv.lock @@ -246,6 +246,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anthropic" +version = "0.94.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/d7/11a649b986da06aaeb81334632f1843d70e3797f54ca4a9c5d604b7987d0/anthropic-0.94.0.tar.gz", hash = "sha256:dde8c57de73538c5136c1bca9b16da92e75446b53a73562cc911574e4934435c", size = 654236, upload-time = "2026-04-10T22:27:59.853Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/ac/7185750e8688f6ff79b0e3d6a61372c88b81ba81fcda8798c70598e18aca/anthropic-0.94.0-py3-none-any.whl", hash = "sha256:42550b401eed8fcd7f6654234560f99c428306301bca726d32bca4bfb6feb748", size = 627519, upload-time = "2026-04-10T22:27:57.541Z" }, +] + [[package]] name = "antlr4-python3-runtime" version = "4.13.2" @@ -5049,6 +5068,13 @@ opentelemetry = [ pydantic = [ { name = "pydantic" }, ] +tool-registry = [ + { name = "anthropic" }, +] +tool-registry-openai = [ + { name = "anthropic" }, + { name = "openai" }, +] [package.dev-dependencies] dev = [ @@ -5090,11 +5116,14 @@ dev = [ [package.metadata] requires-dist = [ { name = "aioboto3", marker = "extra == 'aioboto3'", specifier = 
">=10.4.0" }, + { name = "anthropic", marker = "extra == 'tool-registry'", specifier = ">=0.40.0" }, + { name = "anthropic", marker = "extra == 'tool-registry-openai'", specifier = ">=0.40.0" }, { name = "google-adk", marker = "extra == 'google-adk'", specifier = ">=1.27.0,<2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "langsmith", marker = "extra == 'langsmith'", specifier = ">=0.7.0,<0.8" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, { name = "nexus-rpc", specifier = "==1.4.0" }, + { name = "openai", marker = "extra == 'tool-registry-openai'", specifier = ">=1.0.0" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.3,<0.7" }, { name = "opentelemetry-api", marker = "extra == 'lambda-worker-otel'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, @@ -5110,7 +5139,7 @@ requires-dist = [ { name = "types-protobuf", specifier = ">=3.20,<7.0.0" }, { name = "typing-extensions", specifier = ">=4.2.0,<5" }, ] -provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langsmith", "lambda-worker-otel", "aioboto3"] +provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langsmith", "lambda-worker-otel", "tool-registry", "tool-registry-openai", "aioboto3"] [package.metadata.requires-dev] dev = [