From 9c6b9a9bb96c4c8a6692b0b5c4a36c9cd40e27a5 Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:54:19 -0600 Subject: [PATCH 1/5] Add temporal-tool-registry module with AgenticSession support Co-Authored-By: Claude Sonnet 4.6 --- gradle/linting.gradle | 2 +- settings.gradle | 1 + temporal-tool-registry/README.md | 202 ++++++++++++ temporal-tool-registry/build.gradle | 34 ++ .../temporal/toolregistry/AgenticSession.java | 172 +++++++++++ .../toolregistry/AnthropicConfig.java | 102 ++++++ .../toolregistry/AnthropicProvider.java | 158 ++++++++++ .../temporal/toolregistry/OpenAIConfig.java | 102 ++++++ .../temporal/toolregistry/OpenAIProvider.java | 193 ++++++++++++ .../io/temporal/toolregistry/Provider.java | 27 ++ .../toolregistry/SessionCheckpoint.java | 27 ++ .../io/temporal/toolregistry/SessionFn.java | 10 + .../temporal/toolregistry/ToolDefinition.java | 85 +++++ .../io/temporal/toolregistry/ToolHandler.java | 21 ++ .../temporal/toolregistry/ToolRegistry.java | 136 ++++++++ .../io/temporal/toolregistry/TurnResult.java | 30 ++ .../toolregistry/testing/CrashAfterTurns.java | 60 ++++ .../toolregistry/testing/DispatchCall.java | 30 ++ .../testing/FakeToolRegistry.java | 44 +++ .../testing/MockAgenticSession.java | 64 ++++ .../toolregistry/testing/MockProvider.java | 116 +++++++ .../toolregistry/testing/MockResponse.java | 85 +++++ .../toolregistry/AgenticSessionTest.java | 292 ++++++++++++++++++ .../toolregistry/ToolRegistryTest.java | 265 ++++++++++++++++ .../testing/TestingUtilitiesTest.java | 240 ++++++++++++++ 25 files changed, 2497 insertions(+), 1 deletion(-) create mode 100644 temporal-tool-registry/README.md create mode 100644 temporal-tool-registry/build.gradle create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicConfig.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIConfig.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIProvider.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/Provider.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionFn.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolDefinition.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolHandler.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/TurnResult.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/CrashAfterTurns.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/DispatchCall.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/FakeToolRegistry.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockAgenticSession.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockProvider.java create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockResponse.java create mode 100644 temporal-tool-registry/src/test/java/io/temporal/toolregistry/AgenticSessionTest.java create mode 100644 temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java create mode 100644 temporal-tool-registry/src/test/java/io/temporal/toolregistry/testing/TestingUtilitiesTest.java diff --git a/gradle/linting.gradle b/gradle/linting.gradle index fbc410d359..e38c9ef792 100644 --- a/gradle/linting.gradle +++ b/gradle/linting.gradle @@ -9,7 +9,7 @@ subprojects { target 'src/*/java/**/*.java' targetExclude '**/generated/*' targetExclude '**/.idea/**' - googleJavaFormat('1.24.0') + googleJavaFormat('1.25.2') } kotlin { diff --git a/settings.gradle b/settings.gradle index 918ceaa28e..1fe7230e3a 100644 --- a/settings.gradle +++ b/settings.gradle @@ -5,6 +5,7 @@ include 'temporal-sdk' include 'temporal-testing' include 'temporal-test-server' include 'temporal-opentracing' +include 'temporal-tool-registry' include 'temporal-kotlin' include 'temporal-spring-boot-autoconfigure' include 'temporal-spring-boot-starter' diff --git a/temporal-tool-registry/README.md b/temporal-tool-registry/README.md new file mode 100644 index 0000000000..685d905601 --- /dev/null +++ b/temporal-tool-registry/README.md @@ -0,0 +1,202 @@ +# temporal-tool-registry + +LLM tool-calling primitives for Temporal activities — define tools once, use with +Anthropic or OpenAI. + +## Before you start + +A Temporal Activity is a function that Temporal monitors and retries automatically on failure. Temporal streams progress between retries via heartbeats — that's the mechanism `AgenticSession` uses to resume a crashed LLM conversation mid-turn. + +`ToolRegistry.runToolLoop` works standalone in any function — no Temporal server needed. Add `AgenticSession` only when you need crash-safe resume inside a Temporal activity. + +`AgenticSession` requires a running Temporal worker — it reads and writes heartbeat state from the active activity context. Use `ToolRegistry.runToolLoop` standalone for scripts, one-off jobs, or any code that runs outside a Temporal worker. + +New to Temporal? → https://docs.temporal.io/develop + +## Install + +Add to your `build.gradle`: + +```groovy +dependencies { + // Replace VERSION with the latest release from https://search.maven.org + implementation 'io.temporal:temporal-tool-registry:VERSION' + // Add only the LLM SDK(s) you use: + implementation 'com.anthropic:anthropic-java:VERSION' // Anthropic + implementation 'com.openai:openai-java:VERSION' // OpenAI +} +``` + +## Quickstart + +Tool definitions use [JSON Schema](https://json-schema.org/understanding-json-schema/) for `inputSchema`. The quickstart uses a single string field; for richer schemas refer to the JSON Schema docs. + +```java +import io.temporal.toolregistry.*; + +@ActivityMethod +public List analyze(String prompt) throws Exception { + List issues = new ArrayList<>(); + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("flag_issue") + .description("Flag a problem found in the analysis") + .inputSchema(Map.of( + "type", "object", + "properties", Map.of("description", Map.of("type", "string")), + "required", List.of("description"))) + .build(), + (Map input) -> { + issues.add((String) input.get("description")); + return "recorded"; // this string is sent back to the LLM as the tool result + }); + + AnthropicConfig cfg = AnthropicConfig.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + Provider provider = new AnthropicProvider(cfg, registry, + "You are a code reviewer. Call flag_issue for each problem you find."); + + ToolRegistry.runToolLoop(provider, registry, "" /* system prompt: "" defers to provider default */, prompt); + return issues; +} +``` + +### Selecting a model + +The default model is `"claude-sonnet-4-6"` (Anthropic) or `"gpt-4o"` (OpenAI). Override with the `model()` builder method: + +```java +AnthropicConfig cfg = AnthropicConfig.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .model("claude-3-5-sonnet-20241022") + .build(); +``` + +Model IDs are defined by the provider — see Anthropic or OpenAI docs for current names. + +### OpenAI + +```java +OpenAIConfig cfg = OpenAIConfig.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .build(); +Provider provider = new OpenAIProvider(cfg, registry, "your system prompt"); +ToolRegistry.runToolLoop(provider, registry, "" /* system prompt: "" defers to provider default */, prompt); +``` + +## Crash-safe agentic sessions + +For multi-turn LLM conversations that must survive activity retries, use +`AgenticSession.runWithSession`. It saves conversation history via +`Activity.getExecutionContext().heartbeat()` on every turn and restores it on retry. + +```java +@ActivityMethod +public List longAnalysis(String prompt) throws Exception { + List issues = new ArrayList<>(); + + AgenticSession.runWithSession(session -> { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder().name("flag").description("...").inputSchema(Map.of("type", "object")).build(), + input -> { session.addIssue(input); return "ok"; /* sent back to LLM */ }); + + AnthropicConfig cfg = AnthropicConfig.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); + Provider provider = new AnthropicProvider(cfg, registry, "your system prompt"); + + session.runToolLoop(provider, registry, "your system prompt", prompt); + issues.addAll(session.getIssues()); // capture after loop completes + }); + + return issues; +} +``` + +## Testing without an API key + +```java +import io.temporal.toolregistry.testing.*; + +@Test +public void testAnalyze() throws Exception { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder().name("flag").description("d") + .inputSchema(Map.of("type", "object")).build(), + input -> "ok"); + + MockProvider provider = new MockProvider( + MockResponse.toolCall("flag", Map.of("description", "stale API")), + MockResponse.done("analysis complete")); + + List> msgs = + ToolRegistry.runToolLoop(provider, registry, "sys", "analyze"); + assertTrue(msgs.size() > 2); +} +``` + +## Integration testing with real providers + +To run the integration tests against live Anthropic and OpenAI APIs: + +```bash +RUN_INTEGRATION_TESTS=1 \ + ANTHROPIC_API_KEY=sk-ant-... \ + OPENAI_API_KEY=sk-proj-... \ + ./gradlew test --tests "*.ToolRegistryTest.testIntegration*" +``` + +Tests skip automatically when `RUN_INTEGRATION_TESTS` is unset. Real API calls +incur billing — expect a few cents per full test run. + +## Storing application results + +`session.getIssues()` accumulates application-level +results during the tool loop. Elements are serialized to JSON inside each heartbeat +checkpoint — they must be plain maps/dicts with JSON-serializable values. A non-serializable +value raises a non-retryable `ApplicationError` at heartbeat time rather than silently +losing data on the next retry. + +### Storing typed results + +Convert your domain type to a plain dict at the tool-call site and back after the session: + +```java +record Issue(String type, String file) {} + +// Inside tool handler: +session.addIssue(Map.of("type", "smell", "file", "Foo.java")); + +// After session (using Jackson for convenient mapping): +// requires jackson-databind in your build.gradle: +// implementation 'com.fasterxml.jackson.core:jackson-databind:VERSION' +ObjectMapper mapper = new ObjectMapper(); +List issues = session.getIssues().stream() + .map(m -> mapper.convertValue(m, Issue.class)) + .toList(); +``` + +## Per-turn LLM timeout + +Individual LLM calls inside the tool loop are unbounded by default. A hung HTTP +connection holds the activity open until Temporal's `ScheduleToCloseTimeout` +fires — potentially many minutes. Set a per-turn timeout on the provider client: + +```java +AnthropicConfig cfg = AnthropicConfig.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .timeout(Duration.ofSeconds(30)) + .build(); +Provider provider = new AnthropicProvider(cfg, registry, "your system prompt"); +// provider now enforces 30s per turn +``` + +Recommended timeouts: + +| Model type | Recommended | +|---|---| +| Standard (Claude 3.x, GPT-4o) | 30 s | +| Reasoning (o1, o3, extended thinking) | 300 s | diff --git a/temporal-tool-registry/build.gradle b/temporal-tool-registry/build.gradle new file mode 100644 index 0000000000..599b37a4f4 --- /dev/null +++ b/temporal-tool-registry/build.gradle @@ -0,0 +1,34 @@ +description = '''Temporal Java SDK Tool Registry - LLM tool-calling primitives for Temporal activities''' + +// Both Anthropic and OpenAI Java SDKs require Java 11+, so this module targets Java 11. +// The core SDK supports Java 8+, but this contrib module is explicitly Java 11+. +afterEvaluate { + compileJava.options.compilerArgs.removeAll { it == '8' } + compileJava.options.compilerArgs.removeAll { it == '--release' } + compileJava.options.compilerArgs.addAll(['--release', '11']) + + compileTestJava.options.compilerArgs.removeAll { it == '8' } + compileTestJava.options.compilerArgs.removeAll { it == '--release' } + compileTestJava.options.compilerArgs.addAll(['--release', '11']) +} + +ext { + anthropicVersion = '2.24.0' // com.anthropic:anthropic-java + openaiVersion = '4.31.0' // com.openai:openai-java +} + +dependencies { + // Not bundled — consumers provide temporal-sdk themselves, just like temporal-opentracing. + compileOnly project(':temporal-sdk') + + // LLM providers are optional compile-time deps; users add only what they use. + compileOnly "com.anthropic:anthropic-java:$anthropicVersion" + compileOnly "com.openai:openai-java:$openaiVersion" + + testImplementation project(':temporal-testing') + testImplementation "junit:junit:${junitVersion}" + testImplementation "org.mockito:mockito-core:${mockitoVersion}" + testImplementation "com.anthropic:anthropic-java:$anthropicVersion" + testImplementation "com.openai:openai-java:$openaiVersion" + testRuntimeOnly "ch.qos.logback:logback-classic:${logbackVersion}" +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java new file mode 100644 index 0000000000..b82a6fee15 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java @@ -0,0 +1,172 @@ +package io.temporal.toolregistry; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.temporal.activity.Activity; +import io.temporal.activity.ActivityExecutionContext; +import io.temporal.client.ActivityCompletionException; +import io.temporal.failure.ApplicationFailure; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Maintains conversation state (messages and issues) across multiple turns of a tool-calling loop, + * with heartbeat checkpointing for crash recovery. + * + *

Use {@link #runWithSession(SessionFn)} inside a Temporal activity to get automatic checkpoint + * restore-on-retry and heartbeat on each turn: + * + *

{@code
+ * AgenticSession.runWithSession(session -> {
+ *     session.runToolLoop(provider, registry, systemPrompt, userPrompt);
+ * });
+ * }
+ * + *

For simple non-resumable loops, use {@link ToolRegistry#runToolLoop} instead. + */ +public class AgenticSession { + + private static final Logger log = LoggerFactory.getLogger(AgenticSession.class); + + private final List> messages = new ArrayList<>(); + private final List> issues = new ArrayList<>(); + + /** Creates an empty session. */ + public AgenticSession() {} + + /** + * Runs the multi-turn LLM tool-calling loop, heartbeating before each turn so Temporal can track + * progress and recover from crashes. + * + *

Must be called from within a Temporal activity (requires an active {@link + * ActivityExecutionContext}). + * + *

If the activity is cancelled, the heartbeat call throws {@link ActivityCompletionException} + * — this propagates out of {@code runToolLoop} and up to the caller. No explicit cancellation + * check is needed. + * + * @param provider the LLM provider adapter + * @param registry the tool registry + * @param system the system prompt + * @param prompt the initial user prompt (ignored if restoring from a checkpoint that already has + * messages) + * @throws ActivityCompletionException if the activity is cancelled + * @throws Exception on API or dispatch errors + */ + public void runToolLoop(Provider provider, ToolRegistry registry, String system, String prompt) + throws Exception { + if (messages.isEmpty()) { + Map userMsg = new java.util.LinkedHashMap<>(); + userMsg.put("role", "user"); + userMsg.put("content", prompt); + messages.add(userMsg); + } + + while (true) { + // Heartbeat before each turn — throws ActivityCompletionException if cancelled. + checkpoint(); + + TurnResult result = provider.runTurn(messages, registry.definitions()); + messages.addAll(result.getNewMessages()); + if (result.isDone()) { + return; + } + } + } + + /** + * Heartbeats the current session state to Temporal. Called automatically by {@link #runToolLoop}, + * but can also be called manually between tool dispatches. + * + *

Throws {@link ActivityCompletionException} if the activity has been cancelled. + * + * @throws ApplicationFailure (non-retryable) if any issue is not JSON-serializable + */ + public void checkpoint() throws ActivityCompletionException { + ObjectMapper mapper = new ObjectMapper(); + for (int i = 0; i < issues.size(); i++) { + try { + mapper.writeValueAsString(issues.get(i)); + } catch (Exception e) { + throw ApplicationFailure.newNonRetryableFailure( + "AgenticSession: issues[" + + i + + "] is not JSON-serializable: " + + e.getMessage() + + ". Store only Map with JSON-serializable values.", + "InvalidArgument"); + } + } + SessionCheckpoint cp = new SessionCheckpoint(messages, issues); + Activity.getExecutionContext().heartbeat(cp); + } + + /** + * Runs {@code fn} inside an {@link AgenticSession}, restoring from a heartbeat checkpoint if one + * exists (i.e., on activity retry after crash). + * + *

Must be called from within a Temporal activity. + * + *

Example: + * + *

{@code
+   * AgenticSession.runWithSession(session -> {
+   *     session.runToolLoop(provider, registry, systemPrompt, userPrompt);
+   * });
+   * }
+ * + * @param fn the function to run with the session + * @throws Exception propagated from {@code fn} + */ + public static void runWithSession(SessionFn fn) throws Exception { + AgenticSession session = new AgenticSession(); + ActivityExecutionContext ctx = Activity.getExecutionContext(); + try { + ctx.getHeartbeatDetails(SessionCheckpoint.class) + .ifPresent( + cp -> { + if (cp.version != 0 && cp.version != 1) { + log.warn( + "AgenticSession: checkpoint version {}, expected 1 — starting fresh", + cp.version); + } else { + if (cp.version == 0) { + log.warn( + "AgenticSession: checkpoint has no version field" + + " — may be from an older release"); + } + session.restore(cp); + } + }); + } catch (Exception e) { + log.warn("AgenticSession: failed to decode checkpoint, starting fresh: {}", e.getMessage()); + } + fn.run(session); + } + + /** Returns an unmodifiable view of the conversation messages. */ + public List> getMessages() { + return Collections.unmodifiableList(messages); + } + + /** Returns an unmodifiable view of the issues collected during the session. */ + public List> getIssues() { + return Collections.unmodifiableList(issues); + } + + /** Appends an issue to the issue list. */ + public void addIssue(Map issue) { + issues.add(issue); + } + + /** Restores session state from a checkpoint. Called by {@link #runWithSession} on retry. */ + void restore(SessionCheckpoint checkpoint) { + messages.clear(); + messages.addAll(checkpoint.messages); + issues.clear(); + issues.addAll(checkpoint.issues); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicConfig.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicConfig.java new file mode 100644 index 0000000000..c671d6ff27 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicConfig.java @@ -0,0 +1,102 @@ +package io.temporal.toolregistry; + +import com.anthropic.client.AnthropicClient; + +/** + * Configuration for {@link AnthropicProvider}. + * + *

Usage: + * + *

{@code
+ * AnthropicConfig cfg = AnthropicConfig.builder()
+ *     .apiKey(System.getenv("ANTHROPIC_API_KEY"))
+ *     .model("claude-sonnet-4-6")
+ *     .build();
+ * AnthropicProvider provider = new AnthropicProvider(cfg, registry, systemPrompt);
+ * }
+ */ +public final class AnthropicConfig { + + private final String apiKey; + private final String model; + private final String baseUrl; + private final AnthropicClient client; + + private AnthropicConfig(Builder builder) { + this.apiKey = builder.apiKey; + this.model = builder.model; + this.baseUrl = builder.baseUrl; + this.client = builder.client; + } + + /** Returns the API key, or {@code null} if {@link #getClient()} is set. */ + public String getApiKey() { + return apiKey; + } + + /** Returns the model name, or {@code null} to use the default ({@code "claude-sonnet-4-6"}). */ + public String getModel() { + return model; + } + + /** Returns the base URL override, or {@code null} to use the default Anthropic endpoint. */ + public String getBaseUrl() { + return baseUrl; + } + + /** + * Returns a pre-constructed client to use instead of building one from {@link #getApiKey()} and + * {@link #getBaseUrl()}. Useful for testing without real API calls. + */ + public AnthropicClient getClient() { + return client; + } + + /** Returns a new builder. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link AnthropicConfig}. */ + public static final class Builder { + + private String apiKey; + private String model; + private String baseUrl; + private AnthropicClient client; + + private Builder() {} + + /** Sets the Anthropic API key. Required unless {@link #client(AnthropicClient)} is set. */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** Sets the model name. Defaults to {@code "claude-sonnet-4-6"}. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Overrides the Anthropic API base URL (e.g. for proxies). */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets a pre-constructed client. When set, {@link #apiKey} and {@link #baseUrl} are ignored. + * Useful for testing. + */ + public Builder client(AnthropicClient client) { + this.client = client; + return this; + } + + /** Builds the config. */ + public AnthropicConfig build() { + return new AnthropicConfig(this); + } + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java new file mode 100644 index 0000000000..fbddaae718 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java @@ -0,0 +1,158 @@ +package io.temporal.toolregistry; + +import com.anthropic.client.AnthropicClient; +import com.anthropic.client.okhttp.AnthropicOkHttpClient; +import com.anthropic.models.messages.Message; +import com.anthropic.models.messages.MessageCreateParams; +import com.anthropic.models.messages.MessageParam; +import com.anthropic.models.messages.StopReason; +import com.anthropic.models.messages.Tool; +import com.anthropic.models.messages.ToolUnion; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * {@link Provider} implementation for the Anthropic Messages API. + * + *

Messages are stored as {@code List>} (checkpoint-safe) and converted to + * Anthropic SDK types via Jackson JSON round-trip before each API call. + * + *

Example: + * + *

{@code
+ * AnthropicConfig cfg = AnthropicConfig.builder()
+ *     .apiKey(System.getenv("ANTHROPIC_API_KEY"))
+ *     .build();
+ * Provider provider = new AnthropicProvider(cfg, registry, "You are a helpful assistant.");
+ * }
+ */ +public class AnthropicProvider implements Provider { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String DEFAULT_MODEL = "claude-sonnet-4-6"; + + private final AnthropicClient client; + private final String model; + private final String system; + private final ToolRegistry registry; + + /** + * Creates an AnthropicProvider. + * + * @param cfg provider configuration (API key, model, etc.) + * @param registry tool registry used for dispatching tool calls + * @param system system prompt + */ + public AnthropicProvider(AnthropicConfig cfg, ToolRegistry registry, String system) { + this.model = cfg.getModel() != null ? cfg.getModel() : DEFAULT_MODEL; + this.system = system; + this.registry = registry; + if (cfg.getClient() != null) { + this.client = cfg.getClient(); + } else { + AnthropicOkHttpClient.Builder builder = + AnthropicOkHttpClient.builder().apiKey(cfg.getApiKey()); + if (cfg.getBaseUrl() != null) { + builder.baseUrl(cfg.getBaseUrl()); + } + this.client = builder.build(); + } + } + + /** + * Executes one turn of the conversation. + * + *

Converts the message history and tool definitions to Anthropic SDK types via JSON + * round-trip, calls the Messages API, dispatches any tool calls, and returns the new messages. + */ + @Override + public TurnResult runTurn(List> messages, List tools) + throws Exception { + // Convert message maps to Anthropic MessageParam via JSON round-trip. + List msgParams = mapsToMessageParams(messages); + + // Convert ToolDefinition list to Anthropic ToolUnion list via JSON round-trip. + List toolUnions = toolDefsToToolUnions(tools); + + Message response = + client + .messages() + .create( + MessageCreateParams.builder() + .model(model) + .maxTokens(4096L) + .system(system) + .messages(msgParams) + .tools(toolUnions) + .build()); + + // Convert response content blocks to plain maps for checkpoint-safe storage. + List> contentMaps = contentBlocksToMaps(response); + + List> newMessages = new ArrayList<>(); + Map assistantMsg = new LinkedHashMap<>(); + assistantMsg.put("role", "assistant"); + assistantMsg.put("content", contentMaps); + newMessages.add(assistantMsg); + + // Collect tool-use blocks. + List> toolCalls = + contentMaps.stream() + .filter(b -> "tool_use".equals(b.get("type"))) + .collect(Collectors.toList()); + + boolean endTurn = response.stopReason().map(r -> r == StopReason.END_TURN).orElse(false); + if (toolCalls.isEmpty() || endTurn) { + return new TurnResult(newMessages, true); + } + + // Dispatch each tool call and collect results. + List> toolResults = new ArrayList<>(toolCalls.size()); + for (Map call : toolCalls) { + String name = (String) call.get("name"); + String id = (String) call.get("id"); + @SuppressWarnings("unchecked") + Map input = (Map) call.get("input"); + String result; + try { + result = registry.dispatch(name, input); + } catch (Exception e) { + result = "error: " + e.getMessage(); + } + Map toolResult = new LinkedHashMap<>(); + toolResult.put("type", "tool_result"); + toolResult.put("tool_use_id", id); + toolResult.put("content", result); + toolResults.add(toolResult); + } + Map toolResultMsg = new LinkedHashMap<>(); + toolResultMsg.put("role", "user"); + toolResultMsg.put("content", toolResults); + newMessages.add(toolResultMsg); + return new TurnResult(newMessages, false); + } + + // ── JSON conversion helpers ────────────────────────────────────────────────── + + private static List mapsToMessageParams(List> messages) + throws Exception { + String json = MAPPER.writeValueAsString(messages); + return MAPPER.readValue(json, new TypeReference>() {}); + } + + private static List toolDefsToToolUnions(List defs) throws Exception { + String json = MAPPER.writeValueAsString(defs); + List tools = MAPPER.readValue(json, new TypeReference>() {}); + return tools.stream().map(ToolUnion::ofTool).collect(Collectors.toList()); + } + + private static List> contentBlocksToMaps(Message response) throws Exception { + String json = MAPPER.writeValueAsString(response.content()); + return MAPPER.readValue(json, new TypeReference>>() {}); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIConfig.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIConfig.java new file mode 100644 index 0000000000..478c600684 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIConfig.java @@ -0,0 +1,102 @@ +package io.temporal.toolregistry; + +import com.openai.client.OpenAIClient; + +/** + * Configuration for {@link OpenAIProvider}. + * + *

Usage: + * + *

{@code
+ * OpenAIConfig cfg = OpenAIConfig.builder()
+ *     .apiKey(System.getenv("OPENAI_API_KEY"))
+ *     .model("gpt-4o")
+ *     .build();
+ * OpenAIProvider provider = new OpenAIProvider(cfg, registry, systemPrompt);
+ * }
+ */ +public final class OpenAIConfig { + + private final String apiKey; + private final String model; + private final String baseUrl; + private final OpenAIClient client; + + private OpenAIConfig(Builder builder) { + this.apiKey = builder.apiKey; + this.model = builder.model; + this.baseUrl = builder.baseUrl; + this.client = builder.client; + } + + /** Returns the API key, or {@code null} if {@link #getClient()} is set. */ + public String getApiKey() { + return apiKey; + } + + /** Returns the model name, or {@code null} to use the default ({@code "gpt-4o"}). */ + public String getModel() { + return model; + } + + /** Returns the base URL override, or {@code null} to use the default OpenAI endpoint. */ + public String getBaseUrl() { + return baseUrl; + } + + /** + * Returns a pre-constructed client to use instead of building one from {@link #getApiKey()} and + * {@link #getBaseUrl()}. Useful for testing without real API calls. + */ + public OpenAIClient getClient() { + return client; + } + + /** Returns a new builder. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link OpenAIConfig}. */ + public static final class Builder { + + private String apiKey; + private String model; + private String baseUrl; + private OpenAIClient client; + + private Builder() {} + + /** Sets the OpenAI API key. Required unless {@link #client(OpenAIClient)} is set. */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** Sets the model name. Defaults to {@code "gpt-4o"}. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Overrides the OpenAI API base URL. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets a pre-constructed client. When set, {@link #apiKey} and {@link #baseUrl} are ignored. + * Useful for testing. + */ + public Builder client(OpenAIClient client) { + this.client = client; + return this; + } + + /** Builds the config. */ + public OpenAIConfig build() { + return new OpenAIConfig(this); + } + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIProvider.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIProvider.java new file mode 100644 index 0000000000..48ccb0bf75 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/OpenAIProvider.java @@ -0,0 +1,193 @@ +package io.temporal.toolregistry; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessage; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionTool; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * {@link Provider} implementation for the OpenAI Chat Completions API. + * + *

Messages are stored as {@code List>} (checkpoint-safe) and converted to + * OpenAI SDK types via Jackson JSON round-trip before each API call. + * + *

Example: + * + *

{@code
+ * OpenAIConfig cfg = OpenAIConfig.builder()
+ *     .apiKey(System.getenv("OPENAI_API_KEY"))
+ *     .build();
+ * Provider provider = new OpenAIProvider(cfg, registry, "You are a helpful assistant.");
+ * }
+ */ +public class OpenAIProvider implements Provider { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String DEFAULT_MODEL = "gpt-4o"; + + private final OpenAIClient client; + private final String model; + private final String system; + private final ToolRegistry registry; + + /** + * Creates an OpenAIProvider. + * + * @param cfg provider configuration (API key, model, etc.) + * @param registry tool registry used for dispatching tool calls + * @param system system prompt + */ + public OpenAIProvider(OpenAIConfig cfg, ToolRegistry registry, String system) { + this.model = cfg.getModel() != null ? cfg.getModel() : DEFAULT_MODEL; + this.system = system; + this.registry = registry; + if (cfg.getClient() != null) { + this.client = cfg.getClient(); + } else { + OpenAIOkHttpClient.Builder builder = OpenAIOkHttpClient.builder().apiKey(cfg.getApiKey()); + if (cfg.getBaseUrl() != null) { + builder.baseUrl(cfg.getBaseUrl()); + } + this.client = builder.build(); + } + } + + /** + * Executes one turn of the conversation. + * + *

Prepends a system message, converts the message history to OpenAI SDK types via JSON + * round-trip, calls the Chat Completions API, dispatches any tool calls, and returns new + * messages. + */ + @Override + public TurnResult runTurn(List> messages, List tools) + throws Exception { + // Build full message list: system prefix + conversation history. + List> full = new ArrayList<>(messages.size() + 1); + Map sysMsg = new LinkedHashMap<>(); + sysMsg.put("role", "system"); + sysMsg.put("content", system); + full.add(sysMsg); + full.addAll(messages); + + // Convert to OpenAI ChatCompletionMessageParam via JSON round-trip. + List chatMsgs = mapsToMessageParams(full); + + // Convert registry definitions to OpenAI ChatCompletionTool via JSON round-trip. + List oaiTools = registryToOpenAITools(); + + ChatCompletion response = + client + .chat() + .completions() + .create( + ChatCompletionCreateParams.builder() + .model(model) + .messages(chatMsgs) + .tools(oaiTools) + .build()); + + if (response.choices().isEmpty()) { + return new TurnResult(new ArrayList<>(), true); + } + + ChatCompletion.Choice choice = response.choices().get(0); + ChatCompletionMessage msg = choice.message(); + + // Build the assistant message map. + Map assistantMsg = new LinkedHashMap<>(); + assistantMsg.put("role", "assistant"); + assistantMsg.put("content", msg.content().orElse(null)); + + // Collect only function-type tool calls (ignore custom/other variants). + List toolCalls = new ArrayList<>(); + for (ChatCompletionMessageToolCall tc : msg.toolCalls().orElse(new ArrayList<>())) { + if (tc.isFunction()) { + toolCalls.add(tc.asFunction()); + } + } + + if (!toolCalls.isEmpty()) { + List> callMaps = new ArrayList<>(toolCalls.size()); + for (ChatCompletionMessageFunctionToolCall tc : toolCalls) { + Map callMap = new LinkedHashMap<>(); + callMap.put("id", tc.id()); + callMap.put("type", "function"); + Map funcMap = new LinkedHashMap<>(); + funcMap.put("name", tc.function().name()); + funcMap.put("arguments", tc.function().arguments()); + callMap.put("function", funcMap); + callMaps.add(callMap); + } + assistantMsg.put("tool_calls", callMaps); + } + + List> newMessages = new ArrayList<>(); + newMessages.add(assistantMsg); + + ChatCompletion.Choice.FinishReason finishReason = choice.finishReason(); + boolean done = + toolCalls.isEmpty() + || finishReason == ChatCompletion.Choice.FinishReason.STOP + || finishReason == ChatCompletion.Choice.FinishReason.LENGTH; + if (done) { + return new TurnResult(newMessages, true); + } + + // Dispatch each tool call. + for (ChatCompletionMessageFunctionToolCall tc : toolCalls) { + String name = tc.function().name(); + String argsJson = tc.function().arguments(); + @SuppressWarnings("unchecked") + Map input = + argsJson != null && !argsJson.isEmpty() + ? MAPPER.readValue(argsJson, Map.class) + : new LinkedHashMap<>(); + String result; + try { + result = registry.dispatch(name, input); + } catch (Exception e) { + result = "error: " + e.getMessage(); + } + Map toolResultMsg = new LinkedHashMap<>(); + toolResultMsg.put("role", "tool"); + toolResultMsg.put("tool_call_id", tc.id()); + toolResultMsg.put("content", result); + newMessages.add(toolResultMsg); + } + return new TurnResult(newMessages, false); + } + + // ── JSON conversion helpers ────────────────────────────────────────────────── + + private static List mapsToMessageParams( + List> messages) throws Exception { + String json = MAPPER.writeValueAsString(messages); + return MAPPER.readValue(json, new TypeReference>() {}); + } + + /** + * Converts the tool registry's definitions to OpenAI ChatCompletionTool via: + * + *

    + *
  1. ToolRegistry.toOpenAI() → List in OpenAI wire format + *
  2. Jackson round-trip → List + *
+ */ + private List registryToOpenAITools() throws Exception { + List> openAIFormat = registry.toOpenAI(); + String json = MAPPER.writeValueAsString(openAIFormat); + return MAPPER.readValue(json, new TypeReference>() {}); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/Provider.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/Provider.java new file mode 100644 index 0000000000..9a461ea16e --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/Provider.java @@ -0,0 +1,27 @@ +package io.temporal.toolregistry; + +import java.util.List; +import java.util.Map; + +/** + * Adapter interface for LLM providers. One implementation per vendor; the loop logic lives in + * {@link ToolRegistry#runToolLoop} and {@link AgenticSession#runToolLoop}. + * + *

Adding a new provider means implementing this single method. + */ +public interface Provider { + + /** + * Executes one turn of the conversation. + * + *

Sends the full message history plus the tool definitions to the LLM, dispatches any tool + * calls using the registry, and returns the new messages to append plus a done flag. + * + * @param messages the current conversation history + * @param tools the tools available to the model + * @return new messages to append and a done flag + * @throws Exception on API or dispatch errors + */ + TurnResult runTurn(List> messages, List tools) + throws Exception; +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java new file mode 100644 index 0000000000..79562be8bd --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java @@ -0,0 +1,27 @@ +package io.temporal.toolregistry; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Data serialized to the Temporal heartbeat on each turn of {@link AgenticSession#runToolLoop}. + * + *

Package-private — callers interact only with {@link AgenticSession}. + */ +class SessionCheckpoint { + + /** Checkpoint schema version. Absent (0) in pre-versioned checkpoints. */ + public int version = 1; + + public List> messages = new ArrayList<>(); + public List> issues = new ArrayList<>(); + + /** No-arg constructor required for Jackson deserialization. */ + SessionCheckpoint() {} + + SessionCheckpoint(List> messages, List> issues) { + this.messages = new ArrayList<>(messages); + this.issues = new ArrayList<>(issues); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionFn.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionFn.java new file mode 100644 index 0000000000..4a4b2e5cf5 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionFn.java @@ -0,0 +1,10 @@ +package io.temporal.toolregistry; + +/** + * Functional callback passed to {@link AgenticSession#runWithSession}. Receives a fully initialized + * (and optionally checkpoint-restored) session. + */ +@FunctionalInterface +public interface SessionFn { + void run(AgenticSession session) throws Exception; +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolDefinition.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolDefinition.java new file mode 100644 index 0000000000..71a0c59f7f --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolDefinition.java @@ -0,0 +1,85 @@ +package io.temporal.toolregistry; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +/** + * Defines an LLM tool in Anthropic's tool_use JSON format. + * + *

The same definition is used for both Anthropic and OpenAI; each {@link Provider} converts the + * schema to the wire format it needs. + * + *

Example: + * + *

{@code
+ * ToolDefinition def = ToolDefinition.builder()
+ *     .name("flag_issue")
+ *     .description("Flag a problem found during analysis")
+ *     .inputSchema(Map.of(
+ *         "type", "object",
+ *         "properties", Map.of("description", Map.of("type", "string")),
+ *         "required", List.of("description")))
+ *     .build();
+ * }
+ */ +public final class ToolDefinition { + + private final String name; + private final String description; + + @JsonProperty("input_schema") + private final Map inputSchema; + + private ToolDefinition(Builder builder) { + this.name = Objects.requireNonNull(builder.name, "name"); + this.description = Objects.requireNonNull(builder.description, "description"); + this.inputSchema = + Collections.unmodifiableMap(Objects.requireNonNull(builder.inputSchema, "inputSchema")); + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + @JsonProperty("input_schema") + public Map getInputSchema() { + return inputSchema; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private String name; + private String description; + private Map inputSchema; + + private Builder() {} + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputSchema(Map inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public ToolDefinition build() { + return new ToolDefinition(this); + } + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolHandler.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolHandler.java new file mode 100644 index 0000000000..342e827bee --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolHandler.java @@ -0,0 +1,21 @@ +package io.temporal.toolregistry; + +import java.util.Map; + +/** + * Called when the LLM invokes a tool. Receives the parsed tool input and returns a string result or + * throws an exception. + * + *

This is a functional interface; implementations can be lambdas: + * + *

{@code
+ * ToolHandler handler = input -> {
+ *     issues.add(input.get("description"));
+ *     return "recorded";
+ * };
+ * }
+ */ +@FunctionalInterface +public interface ToolHandler { + String handle(Map input) throws Exception; +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java new file mode 100644 index 0000000000..5cb9a1555a --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java @@ -0,0 +1,136 @@ +package io.temporal.toolregistry; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Maps tool names to definitions and handlers. + * + *

Tools are registered in Anthropic's {@code tool_use} JSON format. The registry exports them + * for Anthropic or OpenAI and dispatches incoming tool calls to the appropriate handler. + * + *

A ToolRegistry is not safe for concurrent modification; build it before passing it to + * concurrent activities. + * + *

Example: + * + *

{@code
+ * ToolRegistry registry = new ToolRegistry();
+ * registry.register(
+ *     ToolDefinition.builder()
+ *         .name("flag_issue")
+ *         .description("Flag a problem")
+ *         .inputSchema(Map.of("type", "object", "properties",
+ *             Map.of("description", Map.of("type", "string")),
+ *             "required", List.of("description")))
+ *         .build(),
+ *     input -> {
+ *         issues.add((String) input.get("description"));
+ *         return "recorded";
+ *     });
+ * }
+ */ +public class ToolRegistry { + + private final List defs = new ArrayList<>(); + private final Map handlers = new HashMap<>(); + + /** Registers a tool definition and its handler. */ + public void register(ToolDefinition definition, ToolHandler handler) { + defs.add(definition); + handlers.put(definition.getName(), handler); + } + + /** + * Dispatches a tool call to the registered handler. + * + * @throws IllegalArgumentException if no handler is registered for {@code name} + */ + public String dispatch(String name, Map input) throws Exception { + ToolHandler handler = handlers.get(name); + if (handler == null) { + throw new IllegalArgumentException("toolregistry: unknown tool \"" + name + "\""); + } + return handler.handle(input); + } + + /** Returns a snapshot of the registered tool definitions. */ + public List definitions() { + return Collections.unmodifiableList(new ArrayList<>(defs)); + } + + /** + * Returns the tool definitions in Anthropic tool_use format. + * + *
{@code
+   * [{"name": "...", "description": "...", "input_schema": {...}}]
+   * }
+ */ + public List> toAnthropic() { + List> result = new ArrayList<>(defs.size()); + for (ToolDefinition def : defs) { + Map tool = new LinkedHashMap<>(); + tool.put("name", def.getName()); + tool.put("description", def.getDescription()); + tool.put("input_schema", def.getInputSchema()); + result.add(tool); + } + return result; + } + + /** + * Returns the tool definitions in OpenAI function-calling format. + * + *
{@code
+   * [{"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}]
+   * }
+ */ + public List> toOpenAI() { + List> result = new ArrayList<>(defs.size()); + for (ToolDefinition def : defs) { + Map function = new LinkedHashMap<>(); + function.put("name", def.getName()); + function.put("description", def.getDescription()); + function.put("parameters", def.getInputSchema()); + + Map tool = new LinkedHashMap<>(); + tool.put("type", "function"); + tool.put("function", function); + result.add(tool); + } + return result; + } + + /** + * Runs a complete multi-turn LLM tool-calling loop to completion. + * + *

This is the primary entry point for simple, non-resumable loops. For crash-safe sessions + * with heartbeat checkpointing, use {@link AgenticSession#runWithSession}. + * + * @param provider the LLM provider adapter + * @param registry the tool registry (may be the same object, provided for clarity) + * @param system the system prompt + * @param prompt the initial user prompt + * @return the full message history on completion + */ + public static List> runToolLoop( + Provider provider, ToolRegistry registry, String system, String prompt) throws Exception { + List> messages = new ArrayList<>(); + Map userMsg = new LinkedHashMap<>(); + userMsg.put("role", "user"); + userMsg.put("content", prompt); + messages.add(userMsg); + + while (true) { + TurnResult result = provider.runTurn(messages, registry.definitions()); + messages.addAll(result.getNewMessages()); + if (result.isDone()) { + return Collections.unmodifiableList(messages); + } + } + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/TurnResult.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/TurnResult.java new file mode 100644 index 0000000000..c1a962ee21 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/TurnResult.java @@ -0,0 +1,30 @@ +package io.temporal.toolregistry; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** The result of one turn of the LLM conversation loop. */ +public final class TurnResult { + + private final List> newMessages; + private final boolean done; + + public TurnResult(List> newMessages, boolean done) { + this.newMessages = Collections.unmodifiableList(newMessages); + this.done = done; + } + + /** + * The new messages produced during this turn (assistant response and any tool results). Append + * these to the conversation history before calling the next turn. + */ + public List> getNewMessages() { + return newMessages; + } + + /** Returns {@code true} when the loop should stop. */ + public boolean isDone() { + return done; + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/CrashAfterTurns.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/CrashAfterTurns.java new file mode 100644 index 0000000000..5c3a17d802 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/CrashAfterTurns.java @@ -0,0 +1,60 @@ +package io.temporal.toolregistry.testing; + +import io.temporal.toolregistry.Provider; +import io.temporal.toolregistry.ToolDefinition; +import io.temporal.toolregistry.TurnResult; +import java.util.List; +import java.util.Map; + +/** + * A {@link Provider} that throws a {@link RuntimeException} after a fixed number of turns. + * + *

Use this to simulate a crash mid-loop for checkpoint recovery tests: + * + *

{@code
+ * // Crash after 2 turns; each "turn" delegates to an inner provider.
+ * Provider inner = new MockProvider(...);
+ * Provider crasher = new CrashAfterTurns(2, inner);
+ * }
+ */ +public class CrashAfterTurns implements Provider { + + private final int n; + private final Provider delegate; + private int count = 0; + + /** + * Creates a provider that throws after {@code n} successful turns. + * + * @param n number of turns to allow before crashing + * @param delegate the underlying provider to delegate to + */ + public CrashAfterTurns(int n, Provider delegate) { + this.n = n; + this.delegate = delegate; + } + + /** + * Creates a provider that throws after {@code n} successful turns, with no delegate (returns a + * done result immediately until the crash). + * + * @param n number of turns to allow before crashing + */ + public CrashAfterTurns(int n) { + this(n, null); + } + + @Override + public TurnResult runTurn(List> messages, List tools) + throws Exception { + count++; + if (count > n) { + throw new RuntimeException("CrashAfterTurns: crashed after " + n + " turn(s)"); + } + if (delegate != null) { + return delegate.runTurn(messages, tools); + } + // No delegate: return done immediately. + return new TurnResult(new java.util.ArrayList<>(), true); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/DispatchCall.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/DispatchCall.java new file mode 100644 index 0000000000..bd75efebc8 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/DispatchCall.java @@ -0,0 +1,30 @@ +package io.temporal.toolregistry.testing; + +import java.util.Map; + +/** Records a single tool dispatch call made through {@link FakeToolRegistry}. */ +public final class DispatchCall { + + private final String name; + private final Map input; + + DispatchCall(String name, Map input) { + this.name = name; + this.input = input; + } + + /** The name of the tool that was called. */ + public String getName() { + return name; + } + + /** The input map passed to the tool. */ + public Map getInput() { + return input; + } + + @Override + public String toString() { + return "DispatchCall{name=" + name + ", input=" + input + "}"; + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/FakeToolRegistry.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/FakeToolRegistry.java new file mode 100644 index 0000000000..57f2cdf19f --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/FakeToolRegistry.java @@ -0,0 +1,44 @@ +package io.temporal.toolregistry.testing; + +import io.temporal.toolregistry.ToolRegistry; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * A {@link ToolRegistry} subclass that records every dispatch call. + * + *

Use this in tests to verify which tools were called and with what inputs: + * + *

{@code
+ * FakeToolRegistry fake = new FakeToolRegistry();
+ * fake.register(myDef, input -> "ok");
+ *
+ * // ... run the loop ...
+ *
+ * assertEquals(1, fake.getCalls().size());
+ * assertEquals("my_tool", fake.getCalls().get(0).getName());
+ * }
+ */ +public class FakeToolRegistry extends ToolRegistry { + + private final List calls = new ArrayList<>(); + + /** Dispatches to the registered handler and records the call. */ + @Override + public String dispatch(String name, Map input) throws Exception { + calls.add(new DispatchCall(name, input)); + return super.dispatch(name, input); + } + + /** Returns all dispatch calls recorded so far, in order. */ + public List getCalls() { + return Collections.unmodifiableList(calls); + } + + /** Clears the recorded call history. */ + public void clearCalls() { + calls.clear(); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockAgenticSession.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockAgenticSession.java new file mode 100644 index 0000000000..b0bda77c3c --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockAgenticSession.java @@ -0,0 +1,64 @@ +package io.temporal.toolregistry.testing; + +import io.temporal.toolregistry.AgenticSession; +import io.temporal.toolregistry.Provider; +import io.temporal.toolregistry.ToolRegistry; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * A no-op {@link AgenticSession} substitute for tests that need an activity to accept a session but + * don't want to run an actual tool loop. + * + *

Instead of calling an LLM, {@link #runToolLoop} records the prompt and optionally returns + * pre-canned issues. + * + *

Example: + * + *

{@code
+ * MockAgenticSession mock = new MockAgenticSession();
+ * mock.getIssues().add(Map.of("description", "pre-seeded issue"));
+ * }
+ */ +public class MockAgenticSession extends AgenticSession { + + private String capturedPrompt; + private final List> mutableIssues = new ArrayList<>(); + + /** + * Records the prompt and returns immediately without calling an LLM. + * + *

Issues can be pre-seeded via {@link #getMutableIssues()}. + */ + @Override + public void runToolLoop(Provider provider, ToolRegistry registry, String system, String prompt) { + this.capturedPrompt = prompt; + } + + /** Returns the prompt that was passed to {@link #runToolLoop}, or {@code null} if not called. */ + public String getCapturedPrompt() { + return capturedPrompt; + } + + /** + * Returns the mutable issues list. Add entries here before running the session to simulate + * pre-existing issues. + */ + public List> getMutableIssues() { + return mutableIssues; + } + + /** + * Returns the issues list (mutable issues + any added via {@link AgenticSession#addIssue}). + * + *

Note: the returned list is a merged snapshot. + */ + @Override + public List> getIssues() { + List> merged = new ArrayList<>(mutableIssues); + merged.addAll(super.getIssues()); + return Collections.unmodifiableList(merged); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockProvider.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockProvider.java new file mode 100644 index 0000000000..5750f17af3 --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockProvider.java @@ -0,0 +1,116 @@ +package io.temporal.toolregistry.testing; + +import io.temporal.toolregistry.Provider; +import io.temporal.toolregistry.ToolDefinition; +import io.temporal.toolregistry.TurnResult; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * A scripted {@link Provider} that replays pre-configured {@link MockResponse} instances in order. + * + *

Useful for testing session logic without making real API calls. + * + *

Example: + * + *

{@code
+ * MockProvider provider = new MockProvider(
+ *     MockResponse.toolCall("flag_issue", Map.of("description", "bug")),
+ *     MockResponse.done("Done.")
+ * );
+ * }
+ * + *

If a {@link FakeToolRegistry} is set via {@link #withRegistry(FakeToolRegistry)}, tool calls + * are dispatched through it so call history is recorded. + */ +public class MockProvider implements Provider { + + private final List responses; + private int index = 0; + private FakeToolRegistry registry; + + /** Creates a MockProvider with the given scripted responses. */ + public MockProvider(MockResponse... responses) { + this.responses = new ArrayList<>(); + for (MockResponse r : responses) { + this.responses.add(r); + } + } + + /** Creates a MockProvider with the given scripted responses. */ + public MockProvider(List responses) { + this.responses = new ArrayList<>(responses); + } + + /** + * Wires up a {@link FakeToolRegistry} so that tool calls from scripted responses are dispatched + * and recorded. + * + * @return {@code this} for chaining + */ + public MockProvider withRegistry(FakeToolRegistry registry) { + this.registry = registry; + return this; + } + + @Override + public TurnResult runTurn(List> messages, List tools) + throws Exception { + if (index >= responses.size()) { + throw new IllegalStateException( + "MockProvider ran out of scripted responses after " + index + " turn(s)"); + } + MockResponse resp = responses.get(index++); + + List> newMessages = new ArrayList<>(); + + // Build the assistant message. + Map assistantMsg = new LinkedHashMap<>(); + assistantMsg.put("role", "assistant"); + assistantMsg.put("content", resp.getContent()); + newMessages.add(assistantMsg); + + if (resp.isStop()) { + return new TurnResult(newMessages, true); + } + + // Dispatch tool calls if a registry is wired up. + List> toolResults = new ArrayList<>(); + for (Map block : resp.getContent()) { + if ("tool_use".equals(block.get("type"))) { + String name = (String) block.get("name"); + String id = (String) block.get("id"); + @SuppressWarnings("unchecked") + Map input = (Map) block.get("input"); + + String result; + if (registry != null) { + try { + result = registry.dispatch(name, input); + } catch (Exception e) { + result = "error: " + e.getMessage(); + } + } else { + result = "ok"; + } + + Map toolResult = new LinkedHashMap<>(); + toolResult.put("type", "tool_result"); + toolResult.put("tool_use_id", id); + toolResult.put("content", result); + toolResults.add(toolResult); + } + } + + if (!toolResults.isEmpty()) { + Map toolResultMsg = new LinkedHashMap<>(); + toolResultMsg.put("role", "user"); + toolResultMsg.put("content", toolResults); + newMessages.add(toolResultMsg); + } + + return new TurnResult(newMessages, false); + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockResponse.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockResponse.java new file mode 100644 index 0000000000..3f24573bfe --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/testing/MockResponse.java @@ -0,0 +1,85 @@ +package io.temporal.toolregistry.testing; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * A scripted response for {@link MockProvider}. + * + *

Use the factory methods to create responses: + * + *

{@code
+ * // A turn that calls a tool
+ * MockResponse.toolCall("my_tool", Map.of("param", "value"))
+ *
+ * // A final turn that stops the loop
+ * MockResponse.done("All done.")
+ * }
+ */ +public final class MockResponse { + + private final boolean stop; + private final List> content; + + MockResponse(boolean stop, List> content) { + this.stop = stop; + this.content = Collections.unmodifiableList(new ArrayList<>(content)); + } + + /** Returns {@code true} if this response signals the loop to stop. */ + public boolean isStop() { + return stop; + } + + /** The content blocks for the assistant message. */ + public List> getContent() { + return content; + } + + /** + * Creates a response with a single tool-use block. + * + * @param toolName the tool to call + * @param toolInput the tool input + */ + public static MockResponse toolCall(String toolName, Map toolInput) { + return toolCall(toolName, toolInput, UUID.randomUUID().toString()); + } + + /** + * Creates a response with a single tool-use block and an explicit call ID. + * + * @param toolName the tool to call + * @param toolInput the tool input + * @param callId the tool_use id + */ + public static MockResponse toolCall( + String toolName, Map toolInput, String callId) { + Map block = new LinkedHashMap<>(); + block.put("type", "tool_use"); + block.put("id", callId); + block.put("name", toolName); + block.put("input", toolInput); + return new MockResponse(false, Collections.singletonList(block)); + } + + /** + * Creates a final response with optional text content. + * + * @param texts zero or more text strings (joined as separate text blocks) + */ + public static MockResponse done(String... texts) { + List> blocks = new ArrayList<>(); + for (String text : texts) { + Map block = new LinkedHashMap<>(); + block.put("type", "text"); + block.put("text", text); + blocks.add(block); + } + return new MockResponse(true, blocks); + } +} diff --git a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/AgenticSessionTest.java b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/AgenticSessionTest.java new file mode 100644 index 0000000000..a586eb02a6 --- /dev/null +++ b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/AgenticSessionTest.java @@ -0,0 +1,292 @@ +package io.temporal.toolregistry; + +import static org.junit.Assert.*; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.testing.TestActivityEnvironment; +import io.temporal.toolregistry.testing.FakeToolRegistry; +import io.temporal.toolregistry.testing.MockProvider; +import io.temporal.toolregistry.testing.MockResponse; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.junit.Test; + +/** + * Tests for {@link AgenticSession}: runToolLoop, checkpoint restore via {@link + * AgenticSession#runWithSession}. + * + *

All session tests execute inside a {@link TestActivityEnvironment} so that {@link + * io.temporal.activity.Activity#getExecutionContext()} is available. + */ +public class AgenticSessionTest { + + // ── test harness ────────────────────────────────────────────────────────────── + + @ActivityInterface + public interface TestOp { + @ActivityMethod + void execute(); + } + + /** + * Runs {@code task} inside a real activity context provided by {@link TestActivityEnvironment}. + */ + private static void runInActivity(Runnable task) { + TestActivityEnvironment env = TestActivityEnvironment.newInstance(); + try { + env.registerActivitiesImplementations( + new TestOp() { + @Override + public void execute() { + task.run(); + } + }); + env.newActivityStub(TestOp.class).execute(); + } finally { + env.close(); + } + } + + // ── runToolLoop ─────────────────────────────────────────────────────────────── + + @Test + public void testFreshStart() { + MockProvider provider = new MockProvider(MockResponse.done("finished")); + ToolRegistry registry = new ToolRegistry(); + List> captured = new ArrayList<>(); + + runInActivity( + () -> { + AgenticSession session = new AgenticSession(); + try { + session.runToolLoop(provider, registry, "sys", "my prompt"); + } catch (Exception e) { + throw new RuntimeException(e); + } + captured.addAll(session.getMessages()); + }); + + assertEquals("user", captured.get(0).get("role")); + assertEquals("my prompt", captured.get(0).get("content")); + assertEquals("assistant", captured.get(1).get("role")); + } + + @Test + public void testResumesExistingMessages() { + // When messages is already populated (retry case), the prompt is NOT prepended again. + MockProvider provider = new MockProvider(MockResponse.done("ok")); + ToolRegistry registry = new ToolRegistry(); + List> captured = new ArrayList<>(); + + runInActivity( + () -> { + AgenticSession session = new AgenticSession(); + // Simulate a partially-restored session. + session.restore( + new SessionCheckpoint( + Arrays.asList( + Map.of("role", "user", "content", "original"), + Map.of( + "role", + "assistant", + "content", + Collections.singletonList(Map.of("type", "text", "text", "thinking")))), + new ArrayList<>())); + try { + session.runToolLoop(provider, registry, "sys", "ignored prompt"); + } catch (Exception e) { + throw new RuntimeException(e); + } + captured.addAll(session.getMessages()); + }); + + assertEquals("original", captured.get(0).get("content")); + assertEquals("assistant", captured.get(1).get("role")); + } + + @Test + public void testWithToolCalls() { + List collected = new ArrayList<>(); + FakeToolRegistry fakeRegistry = new FakeToolRegistry(); + fakeRegistry.register( + ToolDefinition.builder() + .name("collect") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> { + collected.add((String) input.get("v")); + return "ok"; + }); + + MockProvider provider = + new MockProvider( + MockResponse.toolCall("collect", Collections.singletonMap("v", "first")), + MockResponse.toolCall("collect", Collections.singletonMap("v", "second")), + MockResponse.done("done")) + .withRegistry(fakeRegistry); + + List> captured = new ArrayList<>(); + runInActivity( + () -> { + AgenticSession session = new AgenticSession(); + try { + session.runToolLoop(provider, fakeRegistry, "sys", "go"); + } catch (Exception e) { + throw new RuntimeException(e); + } + captured.addAll(session.getMessages()); + }); + + assertEquals(Arrays.asList("first", "second"), collected); + // user + (assistant + tool_result_wrapper)*2 + final assistant + assertTrue(captured.size() > 4); + } + + @Test + public void testCheckpointOnEachTurn() { + // Verifies runToolLoop heartbeats inside an activity context without error. + ToolRegistry registry = new ToolRegistry(); + MockProvider provider = new MockProvider(MockResponse.done("turn1")); + List> captured = new ArrayList<>(); + + runInActivity( + () -> { + AgenticSession session = new AgenticSession(); + try { + session.runToolLoop(provider, registry, "sys", "prompt"); + } catch (Exception e) { + throw new RuntimeException(e); + } + captured.addAll(session.getMessages()); + }); + + assertFalse(captured.isEmpty()); + } + + // ── runWithSession ──────────────────────────────────────────────────────────── + + @Test + public void testRunWithSession_freshStart() { + MockProvider provider = new MockProvider(MockResponse.done("done")); + ToolRegistry registry = new ToolRegistry(); + List> capturedMessages = new ArrayList<>(); + + runInActivity( + () -> { + try { + AgenticSession.runWithSession( + session -> { + session.runToolLoop(provider, registry, "sys", "hello"); + capturedMessages.addAll(session.getMessages()); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + assertFalse(capturedMessages.isEmpty()); + assertEquals("hello", capturedMessages.get(0).get("content")); + } + + @Test + public void testRunWithSession_restoreFromCheckpoint() { + // Pre-seed a checkpoint so the session is restored on first call. + MockProvider provider = new MockProvider(MockResponse.done("done")); + ToolRegistry registry = new ToolRegistry(); + List> capturedMessages = new ArrayList<>(); + + SessionCheckpoint checkpoint = new SessionCheckpoint(); + checkpoint.messages.add(Map.of("role", "user", "content", "restored prompt")); + + TestActivityEnvironment env = TestActivityEnvironment.newInstance(); + try { + env.setHeartbeatDetails(checkpoint); + env.registerActivitiesImplementations( + new TestOp() { + @Override + public void execute() { + try { + AgenticSession.runWithSession( + session -> { + session.runToolLoop(provider, registry, "sys", "ignored"); + capturedMessages.addAll(session.getMessages()); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }); + env.newActivityStub(TestOp.class).execute(); + } finally { + env.close(); + } + + assertFalse(capturedMessages.isEmpty()); + // The restored message should be the first — not "ignored". + assertEquals("restored prompt", capturedMessages.get(0).get("content")); + } + + // ── Checkpoint round-trip test (T6) ────────────────────────────────────────── + + /** + * Verifies that a SessionCheckpoint with nested tool_calls survives a Jackson JSON + * serialize/deserialize cycle with all fields intact. Guards against the class of bug where + * nested maps lose their type after deserialization (cf. .NET List<object?> bug). + */ + @Test + public void testCheckpoint_RoundTrip() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + + Map fnMap = new LinkedHashMap<>(); + fnMap.put("name", "my_tool"); + fnMap.put("arguments", "{\"x\":1}"); + + Map toolCall = new LinkedHashMap<>(); + toolCall.put("id", "call_abc"); + toolCall.put("type", "function"); + toolCall.put("function", fnMap); + + Map assistantMsg = new LinkedHashMap<>(); + assistantMsg.put("role", "assistant"); + assistantMsg.put("tool_calls", Collections.singletonList(toolCall)); + + Map issue = new LinkedHashMap<>(); + issue.put("type", "smell"); + issue.put("file", "Foo.java"); + + SessionCheckpoint original = new SessionCheckpoint( + Collections.singletonList(assistantMsg), + Collections.singletonList(issue)); + + // Simulate Temporal heartbeat round-trip via Jackson. + String json = mapper.writeValueAsString(original); + SessionCheckpoint restored = mapper.readValue(json, SessionCheckpoint.class); + + assertEquals(1, restored.messages.size()); + assertEquals("assistant", restored.messages.get(0).get("role")); + + // tool_calls must survive as a list of maps after round-trip. + @SuppressWarnings("unchecked") + List> toolCallsRestored = + (List>) restored.messages.get(0).get("tool_calls"); + assertNotNull(toolCallsRestored); + assertEquals(1, toolCallsRestored.size()); + assertEquals("call_abc", toolCallsRestored.get(0).get("id")); + + @SuppressWarnings("unchecked") + Map fnRestored = + (Map) toolCallsRestored.get(0).get("function"); + assertEquals("my_tool", fnRestored.get("name")); + + assertEquals(1, restored.issues.size()); + assertEquals("smell", restored.issues.get(0).get("type")); + assertEquals("Foo.java", restored.issues.get(0).get("file")); + } +} diff --git a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java new file mode 100644 index 0000000000..73f90fec80 --- /dev/null +++ b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java @@ -0,0 +1,265 @@ +package io.temporal.toolregistry; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.Assume; +import org.junit.Test; + +/** Unit tests for {@link ToolRegistry} and {@link ToolRegistry#runToolLoop}. */ +public class ToolRegistryTest { + + // ── dispatch ────────────────────────────────────────────────────────────────── + + @Test + public void testDispatch_basicCall() throws Exception { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("greet") + .description("greets a user") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> "hello " + input.get("name")); + + String result = registry.dispatch("greet", Collections.singletonMap("name", "world")); + assertEquals("hello world", result); + } + + @Test + public void testDispatch_unknownTool() { + ToolRegistry registry = new ToolRegistry(); + try { + registry.dispatch("missing", Collections.emptyMap()); + fail("expected exception"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("missing")); + } + } + + @Test + public void testDispatch_handlerException_propagates() { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("boom") + .description("always fails") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> { + throw new RuntimeException("kaboom"); + }); + + try { + registry.dispatch("boom", Collections.emptyMap()); + fail("expected exception"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("kaboom")); + } + } + + // ── definitions ─────────────────────────────────────────────────────────────── + + @Test + public void testDefinitions_returnsCopy() { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("a") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> "ok"); + + List defs = registry.definitions(); + assertEquals(1, defs.size()); + } + + @Test + public void testDefinitions_multipleTools() { + ToolRegistry registry = new ToolRegistry(); + for (String name : Arrays.asList("alpha", "beta", "gamma")) { + registry.register( + ToolDefinition.builder() + .name(name) + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> "ok"); + } + assertEquals(3, registry.definitions().size()); + assertEquals("alpha", registry.definitions().get(0).getName()); + } + + // ── toAnthropic ─────────────────────────────────────────────────────────────── + + @Test + public void testToAnthropic_structure() { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("my_tool") + .description("does something") + .inputSchema(Map.of("type", "object")) + .build(), + input -> "ok"); + + List> result = registry.toAnthropic(); + assertEquals(1, result.size()); + assertEquals("my_tool", result.get(0).get("name")); + assertEquals("does something", result.get(0).get("description")); + assertNotNull(result.get(0).get("input_schema")); + } + + // ── toOpenAI ────────────────────────────────────────────────────────────────── + + @Test + public void testToOpenAI_structure() { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("my_tool") + .description("does something") + .inputSchema( + Map.of("type", "object", "properties", Map.of("x", Map.of("type", "string")))) + .build(), + input -> "ok"); + + List> result = registry.toOpenAI(); + assertEquals(1, result.size()); + assertEquals("function", result.get(0).get("type")); + @SuppressWarnings("unchecked") + Map fn = (Map) result.get(0).get("function"); + assertEquals("my_tool", fn.get("name")); + assertEquals("does something", fn.get("description")); + assertNotNull(fn.get("parameters")); + } + + // ── runToolLoop ─────────────────────────────────────────────────────────────── + + @Test + public void testRunToolLoop_singleDone() throws Exception { + ToolRegistry registry = new ToolRegistry(); + io.temporal.toolregistry.testing.MockProvider provider = + new io.temporal.toolregistry.testing.MockProvider( + io.temporal.toolregistry.testing.MockResponse.done("finished")); + + List> msgs = ToolRegistry.runToolLoop(provider, registry, "sys", "hello"); + + // user + assistant + assertEquals(2, msgs.size()); + assertEquals("user", msgs.get(0).get("role")); + assertEquals("hello", msgs.get(0).get("content")); + assertEquals("assistant", msgs.get(1).get("role")); + } + + @Test + public void testRunToolLoop_withToolCall() throws Exception { + java.util.List collected = new java.util.ArrayList<>(); + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("collect") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> { + collected.add((String) input.get("v")); + return "ok"; + }); + + io.temporal.toolregistry.testing.MockProvider provider = + new io.temporal.toolregistry.testing.MockProvider( + io.temporal.toolregistry.testing.MockResponse.toolCall( + "collect", Collections.singletonMap("v", "first")), + io.temporal.toolregistry.testing.MockResponse.toolCall( + "collect", Collections.singletonMap("v", "second")), + io.temporal.toolregistry.testing.MockResponse.done("all done")); + provider.withRegistry( + new io.temporal.toolregistry.testing.FakeToolRegistry() { + { + register( + ToolDefinition.builder() + .name("collect") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> { + collected.add((String) input.get("v")); + return "ok"; + }); + } + }); + + List> msgs = ToolRegistry.runToolLoop(provider, registry, "sys", "go"); + + assertEquals(Arrays.asList("first", "second"), collected); + // user + (assistant + tool_result_wrapper)*2 + final assistant + assertTrue(msgs.size() > 4); + } + + // ── Integration tests (skipped unless RUN_INTEGRATION_TESTS is set) ─────────── + + private static ToolRegistry makeRecordRegistry(List collected) throws Exception { + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("record") + .description("Record a value") + .inputSchema( + Map.of( + "type", "object", + "properties", Map.of("value", Map.of("type", "string")), + "required", List.of("value"))) + .build(), + input -> { + collected.add((String) input.get("value")); + return "recorded"; + }); + return registry; + } + + @Test + public void testIntegration_Anthropic() throws Exception { + Assume.assumeNotNull(System.getenv("RUN_INTEGRATION_TESTS")); + String apiKey = System.getenv("ANTHROPIC_API_KEY"); + Assume.assumeNotNull(apiKey); + + List collected = new ArrayList<>(); + ToolRegistry registry = makeRecordRegistry(collected); + Provider provider = + new AnthropicProvider( + AnthropicConfig.builder().apiKey(apiKey).build(), + registry, + "You must call record() exactly once with value='hello'."); + + ToolRegistry.runToolLoop( + provider, registry, "", "Please call the record tool with value='hello'."); + + assertTrue("expected 'hello' in collected", collected.contains("hello")); + } + + @Test + public void testIntegration_OpenAI() throws Exception { + Assume.assumeNotNull(System.getenv("RUN_INTEGRATION_TESTS")); + String apiKey = System.getenv("OPENAI_API_KEY"); + Assume.assumeNotNull(apiKey); + + List collected = new ArrayList<>(); + ToolRegistry registry = makeRecordRegistry(collected); + Provider provider = + new OpenAIProvider( + OpenAIConfig.builder().apiKey(apiKey).build(), + registry, + "You must call record() exactly once with value='hello'."); + + ToolRegistry.runToolLoop( + provider, registry, "", "Please call the record tool with value='hello'."); + + assertTrue("expected 'hello' in collected", collected.contains("hello")); + } +} diff --git a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/testing/TestingUtilitiesTest.java b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/testing/TestingUtilitiesTest.java new file mode 100644 index 0000000000..ac61ba963b --- /dev/null +++ b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/testing/TestingUtilitiesTest.java @@ -0,0 +1,240 @@ +package io.temporal.toolregistry.testing; + +import static org.junit.Assert.*; + +import io.temporal.toolregistry.ToolDefinition; +import io.temporal.toolregistry.TurnResult; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.Test; + +/** Unit tests for the testing utilities package (no Temporal server or API key needed). */ +public class TestingUtilitiesTest { + + // ── MockResponse ────────────────────────────────────────────────────────────── + + @Test + public void testMockResponse_done_isStop() { + MockResponse r = MockResponse.done("finished"); + assertTrue(r.isStop()); + assertEquals(1, r.getContent().size()); + assertEquals("text", r.getContent().get(0).get("type")); + assertEquals("finished", r.getContent().get(0).get("text")); + } + + @Test + public void testMockResponse_done_noText() { + MockResponse r = MockResponse.done(); + assertTrue(r.isStop()); + assertTrue(r.getContent().isEmpty()); + } + + @Test + public void testMockResponse_toolCall_structure() { + MockResponse r = MockResponse.toolCall("my_tool", Collections.singletonMap("x", "1"), "id42"); + assertFalse(r.isStop()); + assertEquals(1, r.getContent().size()); + Map block = r.getContent().get(0); + assertEquals("tool_use", block.get("type")); + assertEquals("my_tool", block.get("name")); + assertEquals("id42", block.get("id")); + assertEquals(Collections.singletonMap("x", "1"), block.get("input")); + } + + @Test + public void testMockResponse_toolCall_generatesId() { + MockResponse r1 = MockResponse.toolCall("t", Collections.emptyMap()); + MockResponse r2 = MockResponse.toolCall("t", Collections.emptyMap()); + String id1 = (String) r1.getContent().get(0).get("id"); + String id2 = (String) r2.getContent().get(0).get("id"); + assertNotNull(id1); + assertNotNull(id2); + assertNotEquals(id1, id2); + } + + // ── MockProvider ───────────────────────────────────────────────────────────── + + @Test + public void testMockProvider_returnsResponsesInOrder() throws Exception { + MockProvider provider = + new MockProvider(MockResponse.done("first"), MockResponse.done("second")); + + List> msgs = new ArrayList<>(); + msgs.add(Map.of("role", "user", "content", "hello")); + + TurnResult t1 = provider.runTurn(msgs, Collections.emptyList()); + assertTrue(t1.isDone()); + @SuppressWarnings("unchecked") + List> content1 = + (List>) t1.getNewMessages().get(0).get("content"); + assertEquals("first", content1.get(0).get("text")); + + msgs.addAll(t1.getNewMessages()); + TurnResult t2 = provider.runTurn(msgs, Collections.emptyList()); + assertTrue(t2.isDone()); + @SuppressWarnings("unchecked") + List> content2 = + (List>) t2.getNewMessages().get(0).get("content"); + assertEquals("second", content2.get(0).get("text")); + } + + @Test + public void testMockProvider_throwsWhenExhausted() { + MockProvider provider = new MockProvider(MockResponse.done("only one")); + List> msgs = + Collections.singletonList(Map.of("role", "user", "content", "x")); + try { + provider.runTurn(msgs, Collections.emptyList()); + provider.runTurn(msgs, Collections.emptyList()); // second call should throw + fail("expected exception"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("ran out")); + } + } + + @Test + public void testMockProvider_toolCallDispatches() throws Exception { + FakeToolRegistry fake = new FakeToolRegistry(); + fake.register( + ToolDefinition.builder() + .name("my_tool") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> "result-" + input.get("v")); + + MockProvider provider = + new MockProvider( + MockResponse.toolCall("my_tool", Collections.singletonMap("v", "42"), "call1"), + MockResponse.done("done")) + .withRegistry(fake); + + List> msgs = new ArrayList<>(); + msgs.add(Map.of("role", "user", "content", "go")); + + TurnResult t1 = provider.runTurn(msgs, Collections.emptyList()); + assertFalse(t1.isDone()); + // The turn should produce: assistant message + tool_result wrapper + assertEquals(2, t1.getNewMessages().size()); + + // Verify the tool was dispatched with the right input. + assertEquals(1, fake.getCalls().size()); + assertEquals("my_tool", fake.getCalls().get(0).getName()); + assertEquals(Collections.singletonMap("v", "42"), fake.getCalls().get(0).getInput()); + } + + // ── FakeToolRegistry ───────────────────────────────────────────────────────── + + @Test + public void testFakeToolRegistry_recordsCalls() throws Exception { + FakeToolRegistry fake = new FakeToolRegistry(); + fake.register( + ToolDefinition.builder() + .name("fn") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> "ok"); + + fake.dispatch("fn", Collections.singletonMap("a", "b")); + fake.dispatch("fn", Collections.singletonMap("a", "c")); + + assertEquals(2, fake.getCalls().size()); + assertEquals("fn", fake.getCalls().get(0).getName()); + assertEquals(Collections.singletonMap("a", "b"), fake.getCalls().get(0).getInput()); + assertEquals(Collections.singletonMap("a", "c"), fake.getCalls().get(1).getInput()); + } + + @Test + public void testFakeToolRegistry_clearCalls() throws Exception { + FakeToolRegistry fake = new FakeToolRegistry(); + fake.register( + ToolDefinition.builder() + .name("fn") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> "ok"); + fake.dispatch("fn", Collections.emptyMap()); + assertEquals(1, fake.getCalls().size()); + + fake.clearCalls(); + assertEquals(0, fake.getCalls().size()); + } + + // ── CrashAfterTurns ────────────────────────────────────────────────────────── + + @Test + public void testCrashAfterTurns_crashesAtRightTime() { + CrashAfterTurns crasher = new CrashAfterTurns(2); + List> msgs = + Collections.singletonList(Map.of("role", "user", "content", "x")); + + try { + crasher.runTurn(msgs, Collections.emptyList()); // turn 1 — OK + crasher.runTurn(msgs, Collections.emptyList()); // turn 2 — OK + crasher.runTurn(msgs, Collections.emptyList()); // turn 3 — crash + fail("expected exception"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("crashed after 2")); + } + } + + @Test + public void testCrashAfterTurns_withDelegate() throws Exception { + MockProvider inner = new MockProvider(MockResponse.done("t1"), MockResponse.done("t2")); + CrashAfterTurns crasher = new CrashAfterTurns(1, inner); + List> msgs = + Collections.singletonList(Map.of("role", "user", "content", "x")); + + TurnResult result = crasher.runTurn(msgs, Collections.emptyList()); // delegates + assertTrue(result.isDone()); + + try { + crasher.runTurn(msgs, Collections.emptyList()); // crashes + fail("expected exception"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("crashed")); + } + } + + // ── MockAgenticSession ──────────────────────────────────────────────────────── + + @Test + public void testMockAgenticSession_capturesPrompt() { + MockAgenticSession mock = new MockAgenticSession(); + mock.runToolLoop(null, null, "sys", "the prompt"); + assertEquals("the prompt", mock.getCapturedPrompt()); + } + + @Test + public void testMockAgenticSession_preSeedIssues() { + MockAgenticSession mock = new MockAgenticSession(); + mock.getMutableIssues().add(Collections.singletonMap("desc", "pre-existing")); + + List> issues = mock.getIssues(); + assertEquals(1, issues.size()); + assertEquals("pre-existing", issues.get(0).get("desc")); + } + + @Test + public void testMockAgenticSession_doesNotCallProvider() { + // runToolLoop should be a no-op — no exceptions from null provider. + MockAgenticSession mock = new MockAgenticSession(); + mock.runToolLoop(null, null, null, "prompt"); + assertEquals("prompt", mock.getCapturedPrompt()); + } + + // ── DispatchCall ───────────────────────────────────────────────────────────── + + @Test + public void testDispatchCall_getters() { + Map input = Collections.singletonMap("k", "v"); + DispatchCall call = new DispatchCall("tool_name", input); + assertEquals("tool_name", call.getName()); + assertEquals(input, call.getInput()); + assertTrue(call.toString().contains("tool_name")); + } +} From 8445d8602843d19af87ec2ae2bbe196b5901442a Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Sun, 12 Apr 2026 22:41:01 -0600 Subject: [PATCH 2/5] Add MCP tool-wrapping support to ToolRegistry Adds McpTool POJO and ToolRegistry.fromMcpTools static method that converts a list of MCP tool descriptors into a populated ToolRegistry. Handlers default to no-ops; callers override with register after construction. Null inputSchema is normalized to an empty object schema. Co-Authored-By: Claude Sonnet 4.6 --- temporal-tool-registry/README.md | 19 +++++++++ .../io/temporal/toolregistry/McpTool.java | 42 +++++++++++++++++++ .../temporal/toolregistry/ToolRegistry.java | 26 ++++++++++++ .../toolregistry/ToolRegistryTest.java | 26 ++++++++++++ 4 files changed, 113 insertions(+) create mode 100644 temporal-tool-registry/src/main/java/io/temporal/toolregistry/McpTool.java diff --git a/temporal-tool-registry/README.md b/temporal-tool-registry/README.md index 685d905601..9c64919bcf 100644 --- a/temporal-tool-registry/README.md +++ b/temporal-tool-registry/README.md @@ -200,3 +200,22 @@ Recommended timeouts: |---|---| | Standard (Claude 3.x, GPT-4o) | 30 s | | Reasoning (o1, o3, extended thinking) | 300 s | + +## MCP integration + +`ToolRegistry.fromMcpTools` converts a list of `McpTool` descriptors into a populated +registry. Handlers default to no-ops that return an empty string; override them with +`register` after construction. + +```java +// mcpTools is List — populate from your MCP client. +ToolRegistry registry = ToolRegistry.fromMcpTools(mcpTools); + +// Override specific handlers before running the loop. +registry.register( + ToolDefinition.builder().name("read_file") /* ... */ .build(), + input -> readFile((String) input.get("path"))); +``` + +`McpTool` mirrors the MCP protocol's `Tool` object: `name`, `description`, and +`inputSchema` (a `Map` containing a JSON Schema object). diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/McpTool.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/McpTool.java new file mode 100644 index 0000000000..b8d01d25ba --- /dev/null +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/McpTool.java @@ -0,0 +1,42 @@ +package io.temporal.toolregistry; + +import java.util.Map; +import java.util.Objects; + +/** + * MCP-compatible tool descriptor. + * + *

Any MCP {@code Tool} object can be adapted to this class by copying its {@code name}, + * {@code description}, and {@code inputSchema} fields. + */ +public final class McpTool { + + private final String name; + private final String description; + private final Map inputSchema; + + /** + * Creates an MCP tool descriptor. + * + * @param name tool name + * @param description human-readable description + * @param inputSchema JSON Schema for the tool's input object (may be {@code null}) + */ + public McpTool(String name, String description, Map inputSchema) { + this.name = Objects.requireNonNull(name, "name"); + this.description = description != null ? description : ""; + this.inputSchema = inputSchema; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public Map getInputSchema() { + return inputSchema; + } +} diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java index 5cb9a1555a..d2246c8e3c 100644 --- a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java @@ -39,6 +39,32 @@ public class ToolRegistry { private final List defs = new ArrayList<>(); private final Map handlers = new HashMap<>(); + /** + * Creates a {@link ToolRegistry} from a list of MCP tool descriptors. + * + *

Each tool is registered with a no-op handler (returning an empty string). Override handlers + * by calling {@link #register} with the same name after construction. + * + * @param tools MCP tool descriptors + * @return a new registry populated from the MCP tool list + */ + public static ToolRegistry fromMcpTools(List tools) { + ToolRegistry registry = new ToolRegistry(); + Map emptySchema = Map.of("type", "object", "properties", Map.of()); + for (McpTool tool : tools) { + Map schema = + tool.getInputSchema() != null ? tool.getInputSchema() : emptySchema; + registry.register( + ToolDefinition.builder() + .name(tool.getName()) + .description(tool.getDescription()) + .inputSchema(schema) + .build(), + input -> ""); + } + return registry; + } + /** Registers a tool definition and its handler. */ public void register(ToolDefinition definition, ToolHandler handler) { defs.add(definition); diff --git a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java index 73f90fec80..c31046fba2 100644 --- a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java +++ b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java @@ -202,6 +202,32 @@ public void testRunToolLoop_withToolCall() throws Exception { assertTrue(msgs.size() > 4); } + // ── fromMcpTools ────────────────────────────────────────────────────────────── + + @Test + public void testFromMcpTools_populatesRegistry() throws Exception { + McpTool t1 = + new McpTool( + "read_file", + "Read a file", + Map.of( + "type", "object", + "properties", Map.of("path", Map.of("type", "string")), + "required", List.of("path"))); + McpTool t2 = new McpTool("list_dir", null, null); // null schema → empty object schema + + ToolRegistry reg = ToolRegistry.fromMcpTools(Arrays.asList(t1, t2)); + + List defs = reg.definitions(); + assertEquals(2, defs.size()); + assertEquals("read_file", defs.get(0).getName()); + assertEquals("Read a file", defs.get(0).getDescription()); + assertEquals("list_dir", defs.get(1).getName()); + assertEquals("object", defs.get(1).getInputSchema().get("type")); + // no-op handler returns empty string + assertEquals("", reg.dispatch("read_file", Map.of("path", "/etc/hosts"))); + } + // ── Integration tests (skipped unless RUN_INTEGRATION_TESTS is set) ─────────── private static ToolRegistry makeRecordRegistry(List collected) throws Exception { From d943985db11eb8de20bda20617c6753515de5228 Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Sun, 12 Apr 2026 23:23:26 -0600 Subject: [PATCH 3/5] fix: add is_error handling to AnthropicProvider, add provider error test - Set is_error=Boolean.TRUE on Anthropic tool result maps when a handler throws, matching the Anthropic API spec; OpenAI has no equivalent field - Add testAnthropicProvider_HandlerError_SetsIsError using an in-process HTTP server to verify is_error propagation without a real API key - Update README to clarify positioning vs Python/TypeScript framework plugins Co-Authored-By: Claude Sonnet 4.6 --- temporal-tool-registry/README.md | 2 + .../toolregistry/AnthropicProvider.java | 5 ++ .../toolregistry/ToolRegistryTest.java | 71 +++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/temporal-tool-registry/README.md b/temporal-tool-registry/README.md index 9c64919bcf..8effeba116 100644 --- a/temporal-tool-registry/README.md +++ b/temporal-tool-registry/README.md @@ -13,6 +13,8 @@ A Temporal Activity is a function that Temporal monitors and retries automatical New to Temporal? → https://docs.temporal.io/develop +**Python or TypeScript user?** Those SDKs also ship framework-level integrations (`openai_agents`, `google_adk_agents`, `langgraph`, `@temporalio/ai-sdk`) for teams already using a specific agent framework. ToolRegistry is the equivalent story for direct Anthropic/OpenAI calls, and shares the same API surface across all six Temporal SDKs. + ## Install Add to your `build.gradle`: diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java index fbddaae718..402f3c9aa0 100644 --- a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AnthropicProvider.java @@ -119,15 +119,20 @@ public TurnResult runTurn(List> messages, List input = (Map) call.get("input"); String result; + boolean isError = false; try { result = registry.dispatch(name, input); } catch (Exception e) { result = "error: " + e.getMessage(); + isError = true; } Map toolResult = new LinkedHashMap<>(); toolResult.put("type", "tool_result"); toolResult.put("tool_use_id", id); toolResult.put("content", result); + if (isError) { + toolResult.put("is_error", Boolean.TRUE); + } toolResults.add(toolResult); } Map toolResultMsg = new LinkedHashMap<>(); diff --git a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java index c31046fba2..753f9bd570 100644 --- a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java +++ b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java @@ -228,6 +228,77 @@ public void testFromMcpTools_populatesRegistry() throws Exception { assertEquals("", reg.dispatch("read_file", Map.of("path", "/etc/hosts"))); } + // ── AnthropicProvider is_error / handler error tests ───────────────────────── + + /** + * Verifies that when a tool handler throws, the Anthropic tool result carries is_error=true and + * the turn does not propagate the exception. + */ + @Test + public void testAnthropicProvider_HandlerError_SetsIsError() throws Exception { + // Start a minimal HTTP server to mock the Anthropic API. + com.sun.net.httpserver.HttpServer server = + com.sun.net.httpserver.HttpServer.create(new java.net.InetSocketAddress(0), 0); + + server.createContext( + "/", + exchange -> { + String body = + "{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\"," + + "\"content\":[{\"type\":\"tool_use\",\"id\":\"c1\"," + + "\"name\":\"boom\",\"input\":{}}]," + + "\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"tool_use\"," + + "\"usage\":{\"input_tokens\":10,\"output_tokens\":5}}"; + byte[] bytes = body.getBytes(java.nio.charset.StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, bytes.length); + try (java.io.OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + } + }); + server.start(); + + int port = server.getAddress().getPort(); + String baseUrl = "http://localhost:" + port; + + ToolRegistry registry = new ToolRegistry(); + registry.register( + ToolDefinition.builder() + .name("boom") + .description("d") + .inputSchema(Collections.singletonMap("type", "object")) + .build(), + input -> { + throw new RuntimeException("intentional failure"); + }); + + Provider provider = + new AnthropicProvider( + AnthropicConfig.builder().apiKey("test-key").baseUrl(baseUrl).build(), + registry, + "sys"); + + List> messages = new ArrayList<>(); + messages.add(new java.util.LinkedHashMap<>(Map.of("role", "user", "content", "go"))); + + TurnResult result = provider.runTurn(messages, registry.definitions()); + server.stop(0); + + assertFalse(result.isDone()); + assertEquals(2, result.getNewMessages().size()); + + Map toolResultMsg = result.getNewMessages().get(1); + assertEquals("user", toolResultMsg.get("role")); + @SuppressWarnings("unchecked") + List> toolResults = + (List>) toolResultMsg.get("content"); + assertEquals(1, toolResults.size()); + assertEquals("tool_result", toolResults.get(0).get("type")); + assertEquals(Boolean.TRUE, toolResults.get(0).get("is_error")); + String content = (String) toolResults.get(0).get("content"); + assertTrue("error message should contain failure text", content.contains("intentional failure")); + } + // ── Integration tests (skipped unless RUN_INTEGRATION_TESTS is set) ─────────── private static ToolRegistry makeRecordRegistry(List collected) throws Exception { From f8345bbc686690cc69ad5e11ca46f8218e961c18 Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Mon, 13 Apr 2026 00:14:50 -0600 Subject: [PATCH 4/5] =?UTF-8?q?feat(tool-registry):=20rename=20issues?= =?UTF-8?q?=E2=86=92results,=20remove=20system=20param,=20timeout=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename AgenticSession.issues → results, getIssues() → getResults(), addIssue() → addResult(); rename SessionCheckpoint JSON key 'issues' → 'results' - Remove unused system parameter from AgenticSession.runToolLoop and ToolRegistry.runToolLoop - Add ScheduleToCloseTimeout guidance to README - Update all test call sites Co-Authored-By: Claude Sonnet 4.6 --- temporal-tool-registry/README.md | 39 +++++++++++++------ .../temporal/toolregistry/AgenticSession.java | 39 +++++++++---------- .../toolregistry/SessionCheckpoint.java | 6 +-- .../temporal/toolregistry/ToolRegistry.java | 3 +- .../toolregistry/ToolRegistryTest.java | 8 ++-- 5 files changed, 54 insertions(+), 41 deletions(-) diff --git a/temporal-tool-registry/README.md b/temporal-tool-registry/README.md index 8effeba116..8afe531ad8 100644 --- a/temporal-tool-registry/README.md +++ b/temporal-tool-registry/README.md @@ -60,7 +60,7 @@ public List analyze(String prompt) throws Exception { Provider provider = new AnthropicProvider(cfg, registry, "You are a code reviewer. Call flag_issue for each problem you find."); - ToolRegistry.runToolLoop(provider, registry, "" /* system prompt: "" defers to provider default */, prompt); + ToolRegistry.runToolLoop(provider, registry, prompt); return issues; } ``` @@ -97,23 +97,23 @@ For multi-turn LLM conversations that must survive activity retries, use ```java @ActivityMethod public List longAnalysis(String prompt) throws Exception { - List issues = new ArrayList<>(); + List results = new ArrayList<>(); AgenticSession.runWithSession(session -> { ToolRegistry registry = new ToolRegistry(); registry.register( ToolDefinition.builder().name("flag").description("...").inputSchema(Map.of("type", "object")).build(), - input -> { session.addIssue(input); return "ok"; /* sent back to LLM */ }); + input -> { session.addResult(input); return "ok"; /* sent back to LLM */ }); AnthropicConfig cfg = AnthropicConfig.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); Provider provider = new AnthropicProvider(cfg, registry, "your system prompt"); - session.runToolLoop(provider, registry, "your system prompt", prompt); - issues.addAll(session.getIssues()); // capture after loop completes + session.runToolLoop(provider, registry, prompt); + results.addAll(session.getResults()); // capture after loop completes }); - return issues; + return results; } ``` @@ -135,7 +135,7 @@ public void testAnalyze() throws Exception { MockResponse.done("analysis complete")); List> msgs = - ToolRegistry.runToolLoop(provider, registry, "sys", "analyze"); + ToolRegistry.runToolLoop(provider, registry, "analyze"); assertTrue(msgs.size() > 2); } ``` @@ -156,7 +156,7 @@ incur billing — expect a few cents per full test run. ## Storing application results -`session.getIssues()` accumulates application-level +`session.getResults()` accumulates application-level results during the tool loop. Elements are serialized to JSON inside each heartbeat checkpoint — they must be plain maps/dicts with JSON-serializable values. A non-serializable value raises a non-retryable `ApplicationError` at heartbeat time rather than silently @@ -167,17 +167,17 @@ losing data on the next retry. Convert your domain type to a plain dict at the tool-call site and back after the session: ```java -record Issue(String type, String file) {} +record Result(String type, String file) {} // Inside tool handler: -session.addIssue(Map.of("type", "smell", "file", "Foo.java")); +session.addResult(Map.of("type", "smell", "file", "Foo.java")); // After session (using Jackson for convenient mapping): // requires jackson-databind in your build.gradle: // implementation 'com.fasterxml.jackson.core:jackson-databind:VERSION' ObjectMapper mapper = new ObjectMapper(); -List issues = session.getIssues().stream() - .map(m -> mapper.convertValue(m, Issue.class)) +List results = session.getResults().stream() + .map(m -> mapper.convertValue(m, Result.class)) .toList(); ``` @@ -203,6 +203,21 @@ Recommended timeouts: | Standard (Claude 3.x, GPT-4o) | 30 s | | Reasoning (o1, o3, extended thinking) | 300 s | +### Activity-level timeout + +Set `setScheduleToCloseTimeout` on the activity stub options to bound the entire conversation: + +```java +ActivityOptions opts = ActivityOptions.newBuilder() + .setScheduleToCloseTimeout(Duration.ofMinutes(10)) + .build(); +MyActivities stub = Workflow.newActivityStub(MyActivities.class, opts); +``` + +The per-turn client timeout and `ScheduleToCloseTimeout` are complementary: +- Per-turn timeout fires if one LLM call hangs (protects against a single stuck turn) +- `ScheduleToCloseTimeout` bounds the entire conversation including all retries (protects against runaway multi-turn loops) + ## MCP integration `ToolRegistry.fromMcpTools` converts a list of `McpTool` descriptors into a populated diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java index b82a6fee15..acd16912f0 100644 --- a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/AgenticSession.java @@ -13,15 +13,15 @@ import org.slf4j.LoggerFactory; /** - * Maintains conversation state (messages and issues) across multiple turns of a tool-calling loop, - * with heartbeat checkpointing for crash recovery. + * Maintains conversation state (messages and results) across multiple turns of a tool-calling + * loop, with heartbeat checkpointing for crash recovery. * *

Use {@link #runWithSession(SessionFn)} inside a Temporal activity to get automatic checkpoint * restore-on-retry and heartbeat on each turn: * *

{@code
  * AgenticSession.runWithSession(session -> {
- *     session.runToolLoop(provider, registry, systemPrompt, userPrompt);
+ *     session.runToolLoop(provider, registry, userPrompt);
  * });
  * }
* @@ -32,7 +32,7 @@ public class AgenticSession { private static final Logger log = LoggerFactory.getLogger(AgenticSession.class); private final List> messages = new ArrayList<>(); - private final List> issues = new ArrayList<>(); + private final List> results = new ArrayList<>(); /** Creates an empty session. */ public AgenticSession() {} @@ -50,13 +50,12 @@ public AgenticSession() {} * * @param provider the LLM provider adapter * @param registry the tool registry - * @param system the system prompt * @param prompt the initial user prompt (ignored if restoring from a checkpoint that already has * messages) * @throws ActivityCompletionException if the activity is cancelled * @throws Exception on API or dispatch errors */ - public void runToolLoop(Provider provider, ToolRegistry registry, String system, String prompt) + public void runToolLoop(Provider provider, ToolRegistry registry, String prompt) throws Exception { if (messages.isEmpty()) { Map userMsg = new java.util.LinkedHashMap<>(); @@ -83,16 +82,16 @@ public void runToolLoop(Provider provider, ToolRegistry registry, String system, * *

Throws {@link ActivityCompletionException} if the activity has been cancelled. * - * @throws ApplicationFailure (non-retryable) if any issue is not JSON-serializable + * @throws ApplicationFailure (non-retryable) if any result is not JSON-serializable */ public void checkpoint() throws ActivityCompletionException { ObjectMapper mapper = new ObjectMapper(); - for (int i = 0; i < issues.size(); i++) { + for (int i = 0; i < results.size(); i++) { try { - mapper.writeValueAsString(issues.get(i)); + mapper.writeValueAsString(results.get(i)); } catch (Exception e) { throw ApplicationFailure.newNonRetryableFailure( - "AgenticSession: issues[" + "AgenticSession: results[" + i + "] is not JSON-serializable: " + e.getMessage() @@ -100,7 +99,7 @@ public void checkpoint() throws ActivityCompletionException { "InvalidArgument"); } } - SessionCheckpoint cp = new SessionCheckpoint(messages, issues); + SessionCheckpoint cp = new SessionCheckpoint(messages, results); Activity.getExecutionContext().heartbeat(cp); } @@ -114,7 +113,7 @@ public void checkpoint() throws ActivityCompletionException { * *

{@code
    * AgenticSession.runWithSession(session -> {
-   *     session.runToolLoop(provider, registry, systemPrompt, userPrompt);
+   *     session.runToolLoop(provider, registry, userPrompt);
    * });
    * }
* @@ -152,21 +151,21 @@ public List> getMessages() { return Collections.unmodifiableList(messages); } - /** Returns an unmodifiable view of the issues collected during the session. */ - public List> getIssues() { - return Collections.unmodifiableList(issues); + /** Returns an unmodifiable view of the results collected during the session. */ + public List> getResults() { + return Collections.unmodifiableList(results); } - /** Appends an issue to the issue list. */ - public void addIssue(Map issue) { - issues.add(issue); + /** Appends a result to the results list. */ + public void addResult(Map result) { + results.add(result); } /** Restores session state from a checkpoint. Called by {@link #runWithSession} on retry. */ void restore(SessionCheckpoint checkpoint) { messages.clear(); messages.addAll(checkpoint.messages); - issues.clear(); - issues.addAll(checkpoint.issues); + results.clear(); + results.addAll(checkpoint.results); } } diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java index 79562be8bd..07f3010286 100644 --- a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/SessionCheckpoint.java @@ -15,13 +15,13 @@ class SessionCheckpoint { public int version = 1; public List> messages = new ArrayList<>(); - public List> issues = new ArrayList<>(); + public List> results = new ArrayList<>(); /** No-arg constructor required for Jackson deserialization. */ SessionCheckpoint() {} - SessionCheckpoint(List> messages, List> issues) { + SessionCheckpoint(List> messages, List> results) { this.messages = new ArrayList<>(messages); - this.issues = new ArrayList<>(issues); + this.results = new ArrayList<>(results); } } diff --git a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java index d2246c8e3c..b9669e9a99 100644 --- a/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java +++ b/temporal-tool-registry/src/main/java/io/temporal/toolregistry/ToolRegistry.java @@ -139,12 +139,11 @@ public List> toOpenAI() { * * @param provider the LLM provider adapter * @param registry the tool registry (may be the same object, provided for clarity) - * @param system the system prompt * @param prompt the initial user prompt * @return the full message history on completion */ public static List> runToolLoop( - Provider provider, ToolRegistry registry, String system, String prompt) throws Exception { + Provider provider, ToolRegistry registry, String prompt) throws Exception { List> messages = new ArrayList<>(); Map userMsg = new LinkedHashMap<>(); userMsg.put("role", "user"); diff --git a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java index 753f9bd570..3821e6a7e5 100644 --- a/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java +++ b/temporal-tool-registry/src/test/java/io/temporal/toolregistry/ToolRegistryTest.java @@ -148,7 +148,7 @@ public void testRunToolLoop_singleDone() throws Exception { new io.temporal.toolregistry.testing.MockProvider( io.temporal.toolregistry.testing.MockResponse.done("finished")); - List> msgs = ToolRegistry.runToolLoop(provider, registry, "sys", "hello"); + List> msgs = ToolRegistry.runToolLoop(provider, registry, "hello"); // user + assistant assertEquals(2, msgs.size()); @@ -195,7 +195,7 @@ public void testRunToolLoop_withToolCall() throws Exception { } }); - List> msgs = ToolRegistry.runToolLoop(provider, registry, "sys", "go"); + List> msgs = ToolRegistry.runToolLoop(provider, registry, "go"); assertEquals(Arrays.asList("first", "second"), collected); // user + (assistant + tool_result_wrapper)*2 + final assistant @@ -335,7 +335,7 @@ public void testIntegration_Anthropic() throws Exception { "You must call record() exactly once with value='hello'."); ToolRegistry.runToolLoop( - provider, registry, "", "Please call the record tool with value='hello'."); + provider, registry, "Please call the record tool with value='hello'."); assertTrue("expected 'hello' in collected", collected.contains("hello")); } @@ -355,7 +355,7 @@ public void testIntegration_OpenAI() throws Exception { "You must call record() exactly once with value='hello'."); ToolRegistry.runToolLoop( - provider, registry, "", "Please call the record tool with value='hello'."); + provider, registry, "Please call the record tool with value='hello'."); assertTrue("expected 'hello' in collected", collected.contains("hello")); } From 5de5391d40c89d61b136080883c94ceba58cc7cf Mon Sep 17 00:00:00 2001 From: lex00 <121451605+lex00@users.noreply.github.com> Date: Mon, 13 Apr 2026 00:22:33 -0600 Subject: [PATCH 5/5] =?UTF-8?q?fix=20stale=20local=20variable=20name=20in?= =?UTF-8?q?=20quickstart=20example=20(issues=20=E2=86=92=20results)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The quickstart example used `issues` as the local collector variable, left over from before the AgenticSession field was renamed to `getResults()` in round 2. Rename to `results` so the example is consistent with the other SDK READMEs. Co-Authored-By: Claude Sonnet 4.6 --- temporal-tool-registry/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporal-tool-registry/README.md b/temporal-tool-registry/README.md index 8afe531ad8..ab7bb637f8 100644 --- a/temporal-tool-registry/README.md +++ b/temporal-tool-registry/README.md @@ -38,7 +38,7 @@ import io.temporal.toolregistry.*; @ActivityMethod public List analyze(String prompt) throws Exception { - List issues = new ArrayList<>(); + List results = new ArrayList<>(); ToolRegistry registry = new ToolRegistry(); registry.register( ToolDefinition.builder() @@ -50,7 +50,7 @@ public List analyze(String prompt) throws Exception { "required", List.of("description"))) .build(), (Map input) -> { - issues.add((String) input.get("description")); + results.add((String) input.get("description")); return "recorded"; // this string is sent back to the LLM as the tool result }); @@ -61,7 +61,7 @@ public List analyze(String prompt) throws Exception { "You are a code reviewer. Call flag_issue for each problem you find."); ToolRegistry.runToolLoop(provider, registry, prompt); - return issues; + return results; } ```