diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..fdda5219d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -461,14 +461,31 @@ public Flowable run(InvocationContext invocationContext) { private Flowable run( Context spanContext, InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); + Flowable currentStepEvents = runOneStep(spanContext, invocationContext); + + Flowable processedEvents = + currentStepEvents + .concatMap( + event -> + invocationContext + .sessionService() + .appendEvent(invocationContext.session(), event) + .flatMap( + registeredEvent -> + invocationContext + .pluginManager() + .onEventCallback(invocationContext, registeredEvent) + .defaultIfEmpty(registeredEvent)) + .toFlowable()) + .cache(); + if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); - return currentStepEvents; + return processedEvents; } - return currentStepEvents.concatWith( - currentStepEvents + return processedEvents.concatWith( + processedEvents .toList() .flatMapPublisher( eventList -> { diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 44a281f72..0005c6ecb 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -570,19 +570,27 @@ private Flowable runAgentWithUpdatedSession( .agent() .runAsync(contextWithUpdatedSession) .concatMap( - agentEvent -> - this.sessionService - .appendEvent(updatedSession, agentEvent) - .flatMap( - registeredEvent -> { - // TODO: remove this hack after deprecating runAsync with Session. - copySessionStates(updatedSession, initialContext.session()); - return contextWithUpdatedSession - .pluginManager() - .onEventCallback(contextWithUpdatedSession, registeredEvent) - .defaultIfEmpty(registeredEvent); - }) - .toFlowable()); + agentEvent -> { + if (agentEvent.id() != null + && updatedSession.events().stream() + .anyMatch(e -> agentEvent.id().equals(e.id()))) { + // Already appended (e.g. by BaseLlmFlow). Still apply the hack. + copySessionStates(updatedSession, initialContext.session()); + return Flowable.just(agentEvent); + } + return this.sessionService + .appendEvent(updatedSession, agentEvent) + .flatMap( + registeredEvent -> { + // TODO: remove this hack after deprecating runAsync with Session. + copySessionStates(updatedSession, initialContext.session()); + return contextWithUpdatedSession + .pluginManager() + .onEventCallback(contextWithUpdatedSession, registeredEvent) + .defaultIfEmpty(registeredEvent); + }) + .toFlowable(); + }); // If beforeRunCallback returns content, emit it and skip agent Context capturedContext = Context.current(); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index ff75c97b0..0fd32252e 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -46,6 +46,7 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.BaseSessionService; @@ -588,12 +589,22 @@ public void onToolErrorCallback_error() { @Test public void onEventCallback_success() { when(plugin.onEventCallback(any(), any())) - .thenReturn(Maybe.just(TestUtils.createEvent("form plugin"))); + .thenAnswer( + invocation -> { + Event event = invocation.getArgument(1); + return Maybe.just( + Event.builder() + .id(event.id()) + .invocationId(event.invocationId()) + .author("model") + .content(createContent("from plugin")) + .build()); + }); List events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin"); + assertThat(simplifyEvents(events)).containsExactly("model: from plugin"); verify(plugin).onEventCallback(any(), any()); } @@ -1686,4 +1697,67 @@ public void runner_executesSaveArtifactFlow() { // agent was run assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); } + + @Test + public void runAsync_ensuresSequentialConsistencyForTools() { + // Arrange + TestLlm testLlm = + createTestLlm( + createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")), + createTextLlmResponse("Final response")); + + LlmAgent agent = + createTestAgentBuilder(testLlm) + .tools( + ImmutableList.of( + FunctionTool.create(RaceConditionTools.class, "tool1"), + FunctionTool.create(RaceConditionTools.class, "tool2"))) + .build(); + + Runner runner = + Runner.builder().app(App.builder().name("test").rootAgent(agent).build()).build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + // Act + var unused = + runner + .runAsync("user", session.id(), Content.fromParts(Part.fromText("start"))) + .toList() + .blockingGet(); + + // Assert + ImmutableList requests = ImmutableList.copyOf(testLlm.getRequests()); + assertThat(requests).hasSize(2); + + // Second request should contain the result of tool1 + LlmRequest secondRequest = requests.get(1); + List history = secondRequest.contents(); + + boolean foundToolResponse = false; + for (Content content : history) { + for (Part part : content.parts().get()) { + if (part.functionResponse().isPresent() + && part.functionResponse().get().name().isPresent() + && part.functionResponse().get().name().get().equals("tool1")) { + foundToolResponse = true; + assertThat(part.functionResponse().get().response().isPresent()).isTrue(); + assertThat(part.functionResponse().get().response().get()) + .isEqualTo(ImmutableMap.of("result", "result_value1")); + } + } + } + assertThat(foundToolResponse).isTrue(); + } + + public static class RaceConditionTools { + private RaceConditionTools() {} + + public static ImmutableMap tool1(String arg) { + return ImmutableMap.of("result", "result_" + arg); + } + + public static ImmutableMap tool2(String input) { + return ImmutableMap.of("status", "received_" + input); + } + } }