diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 3c673b140..f3f004b55 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -21,7 +21,6 @@ import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode; import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject; -import static java.util.concurrent.TimeUnit.MILLISECONDS; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.CallbackContext; @@ -41,8 +40,6 @@ import com.google.adk.tools.ToolContext; import com.google.adk.tools.mcp.AbstractMcpTool; import com.google.adk.utils.AgentEnums.AgentOrigin; -import com.google.api.gax.core.FixedCredentialsProvider; -import com.google.api.gax.retrying.RetrySettings; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.bigquery.BigQuery; import com.google.cloud.bigquery.BigQueryException; @@ -53,11 +50,7 @@ import com.google.cloud.bigquery.Table; import com.google.cloud.bigquery.TableId; import com.google.cloud.bigquery.TableInfo; -import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; -import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; -import com.google.cloud.bigquery.storage.v1.StreamWriter; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -70,9 +63,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -100,11 +90,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { private final BigQueryLoggerConfig config; private final BigQuery bigQuery; - private final BigQueryWriteClient writeClient; - private final ScheduledExecutorService executor; private final Object tableEnsuredLock = new Object(); - @VisibleForTesting final BatchProcessor batchProcessor; - @VisibleForTesting final TraceManager traceManager; + private final PluginState state; private volatile boolean tableEnsured = false; public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { @@ -113,28 +100,14 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) throws IOException { + this(config, bigQuery, new PluginState(config)); + } + + BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery, PluginState state) { super("bigquery_agent_analytics"); this.config = config; this.bigQuery = bigQuery; - ThreadFactory threadFactory = - r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); - this.executor = Executors.newScheduledThreadPool(1, threadFactory); - this.writeClient = createWriteClient(config); - this.traceManager = createTraceManager(); - - if (config.enabled()) { - StreamWriter writer = createWriter(config); - this.batchProcessor = - new BatchProcessor( - writer, - config.batchSize(), - config.batchFlushInterval(), - config.queueMaxSize(), - executor); - this.batchProcessor.start(); - } else { - this.batchProcessor = null; - } + this.state = state; } private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { @@ -194,7 +167,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { try { if (config.createViews()) { - var unused = executor.submit(() -> createAnalyticsViews(bigQuery, config)); + var unused = state.getExecutor().submit(() -> createAnalyticsViews(bigQuery, config)); } } catch (RuntimeException e) { logger.log(Level.WARNING, "Failed to create/update BigQuery views for table: " + tableId, e); @@ -209,48 +182,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) { } } - protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { - if (config.credentials() != null) { - return BigQueryWriteClient.create( - BigQueryWriteSettings.newBuilder() - .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) - .build()); - } - return BigQueryWriteClient.create(); - } - - protected String getStreamName(BigQueryLoggerConfig config) { - return String.format( - "projects/%s/datasets/%s/tables/%s/streams/_default", - config.projectId(), config.datasetId(), config.tableName()); - } - - protected StreamWriter createWriter(BigQueryLoggerConfig config) { - BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); - RetrySettings retrySettings = - RetrySettings.newBuilder() - .setMaxAttempts(retryConfig.maxRetries()) - .setInitialRetryDelay( - org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis())) - .setRetryDelayMultiplier(retryConfig.multiplier()) - .setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis())) - .build(); - - String streamName = getStreamName(config); - try { - return StreamWriter.newBuilder(streamName, writeClient) - .setRetrySettings(retrySettings) - .setWriterSchema(BigQuerySchema.getArrowSchema()) - .build(); - } catch (Exception e) { - throw new VerifyException("Failed to create StreamWriter for " + streamName, e); - } - } - - protected TraceManager createTraceManager() { - return new TraceManager(); - } - private void logEvent( String eventType, InvocationContext invocationContext, @@ -265,7 +196,7 @@ private void logEvent( Object content, boolean isContentTruncated, Optional eventData) { - if (!config.enabled() || batchProcessor == null) { + if (!config.enabled()) { return; } if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) { @@ -274,6 +205,8 @@ private void logEvent( if (config.eventDenylist().contains(eventType)) { return; } + String invocationId = invocationContext.invocationId(); + BatchProcessor processor = state.getBatchProcessor(invocationId); // Ensure table exists before logging. ensureTableExistsOnce(); // Log common fields @@ -301,11 +234,12 @@ private void logEvent( row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext))); addTraceDetails(row, invocationContext, eventData); - batchProcessor.append(row); + processor.append(row); } private void addTraceDetails( Map row, InvocationContext invocationContext, Optional eventData) { + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); String traceId = eventData .flatMap(EventData::traceIdOverride) @@ -336,7 +270,7 @@ private void addTraceDetails( private Map getAttributes( EventData eventData, InvocationContext invocationContext) { Map attributes = new HashMap<>(eventData.extraAttributes()); - + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); attributes.put("root_agent_name", traceManager.getRootAgentName()); eventData.model().ifPresent(m -> attributes.put("model", m)); eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv)); @@ -375,25 +309,17 @@ private Map getAttributes( @Override public Completable close() { - if (batchProcessor != null) { - batchProcessor.close(); - } - if (writeClient != null) { - writeClient.close(); - } - try { - executor.shutdown(); - if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { - executor.shutdownNow(); - } - } catch (InterruptedException e) { - executor.shutdownNow(); - Thread.currentThread().interrupt(); - } + state.close(); return Completable.complete(); } + @VisibleForTesting + PluginState getState() { + return state; + } + private Optional getCompletedEventData(InvocationContext invocationContext) { + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); String traceId = traceManager.getTraceId(invocationContext); // Pop the invocation span from the trace manager. Optional popped = traceManager.popSpan(); @@ -426,7 +352,9 @@ public Maybe onUserMessageCallback( InvocationContext invocationContext, Content userMessage) { return Maybe.fromAction( () -> { - traceManager.ensureInvocationSpan(invocationContext); + state + .getTraceManager(invocationContext.invocationId()) + .ensureInvocationSpan(invocationContext); logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty()); if (userMessage.parts().isPresent()) { for (Part part : userMessage.parts().get()) { @@ -510,7 +438,7 @@ public Maybe onEventCallback(InvocationContext invocationContext, Event e @Override public Maybe beforeRunCallback(InvocationContext invocationContext) { - traceManager.ensureInvocationSpan(invocationContext); + state.getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext); return Maybe.fromAction( () -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty())); } @@ -524,8 +452,15 @@ public Completable afterRunCallback(InvocationContext invocationContext) { invocationContext, null, getCompletedEventData(invocationContext)); - batchProcessor.flush(); - traceManager.clearStack(); + BatchProcessor processor = state.removeProcessor(invocationContext.invocationId()); + if (processor != null) { + processor.flush(); + processor.close(); + } + TraceManager traceManager = state.removeTraceManager(invocationContext.invocationId()); + if (traceManager != null) { + traceManager.clearStack(); + } }); } @@ -533,7 +468,9 @@ public Completable afterRunCallback(InvocationContext invocationContext) { public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return Maybe.fromAction( () -> { - traceManager.pushSpan("agent:" + agent.name()); + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("agent:" + agent.name()); logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty()); }); } @@ -622,7 +559,9 @@ public Maybe beforeModelCallback( .setModel(req.model().orElse("")) .setExtraAttributes(attributes) .build(); - traceManager.pushSpan("llm_request"); + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("llm_request"); logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData)); }); } @@ -632,6 +571,8 @@ public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); // TODO(b/495809488): Add formatting of the content ParsedContent parsedContent = JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength()); @@ -728,6 +669,8 @@ public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); InvocationContext invocationContext = callbackContext.invocationContext(); Optional popped = traceManager.popSpan(); String spanId = popped.map(RecordData::spanId).orElse(null); @@ -762,7 +705,7 @@ public Maybe> beforeToolCallback( ImmutableMap contentMap = ImmutableMap.of( "tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node()); - traceManager.pushSpan("tool"); + state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool"); logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty()); }); } @@ -775,6 +718,8 @@ public Maybe> afterToolCallback( Map result) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); Optional popped = traceManager.popSpan(); TruncationResult truncationResult = smartTruncate(result, config.maxContentLength()); ImmutableMap contentMap = @@ -812,6 +757,8 @@ public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); Optional popped = traceManager.popSpan(); TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength()); String toolOrigin = getToolOrigin(tool); diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java new file mode 100644 index 000000000..faec47cf1 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java @@ -0,0 +1,158 @@ +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import java.io.IOException; +import java.util.Collection; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; + +/** Manages state for the BigQueryAgentAnalyticsPlugin. */ +class PluginState { + private final BigQueryLoggerConfig config; + private final ScheduledExecutorService executor; + private final BigQueryWriteClient writeClient; + private static final AtomicLong threadCounter = new AtomicLong(0); + // Map of invocation ID to BatchProcessor. + private final ConcurrentHashMap batchProcessors = + new ConcurrentHashMap<>(); + // Map of invocation ID to TraceManager. + private final ConcurrentHashMap traceManagers = new ConcurrentHashMap<>(); + + PluginState(BigQueryLoggerConfig config) throws IOException { + this.config = config; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + // One write client per plugin instance, shared by all invocations. + this.writeClient = createWriteClient(config); + } + + ScheduledExecutorService getExecutor() { + return executor; + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + if (config.credentials() != null) { + return BigQueryWriteClient.create( + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) + .build()); + } + return BigQueryWriteClient.create(); + } + + protected StreamWriter createWriter() { + BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(retryConfig.maxRetries()) + .setInitialRetryDelay( + org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setRetryDelayMultiplier(retryConfig.multiplier()) + .setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + @VisibleForTesting + String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + @VisibleForTesting + TraceManager getTraceManager(String invocationId) { + return traceManagers.computeIfAbsent(invocationId, id -> new TraceManager()); + } + + @VisibleForTesting + BatchProcessor getBatchProcessor(String invocationId) { + return batchProcessors.computeIfAbsent( + invocationId, + id -> { + StreamWriter writer = createWriter(); + BatchProcessor p = + new BatchProcessor( + writer, + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + p.start(); + return p; + }); + } + + @VisibleForTesting + Collection getTraceManagers() { + return traceManagers.values(); + } + + @VisibleForTesting + Collection getBatchProcessors() { + return batchProcessors.values(); + } + + @VisibleForTesting + TraceManager removeTraceManager(String invocationId) { + return traceManagers.remove(invocationId); + } + + @VisibleForTesting + protected BatchProcessor removeProcessor(String invocationId) { + return batchProcessors.remove(invocationId); + } + + void clearTraceManagers() { + traceManagers.clear(); + } + + void clearBatchProcessors() { + batchProcessors.clear(); + } + + void close() { + for (BatchProcessor processor : getBatchProcessors()) { + processor.close(); + } + for (TraceManager traceManager : getTraceManagers()) { + traceManager.clearStack(); + } + clearBatchProcessors(); + clearTraceManagers(); + + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java index 53faf3329..ef721e432 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java @@ -63,6 +63,7 @@ public final class BigQueryAgentAnalyticsPluginE2ETest { private StreamWriter mockWriter; private BigQueryWriteClient mockWriteClient; private BigQueryLoggerConfig config; + private PluginState state; private BigQueryAgentAnalyticsPlugin plugin; private Runner runner; private BaseAgent fakeAgent; @@ -92,26 +93,34 @@ public void setUp() throws Exception { when(mockWriter.append(any(ArrowRecordBatch.class))) .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); - plugin = - new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + state = + new PluginState(config) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } + + @Override + protected BatchProcessor removeProcessor(String invocationId) { + return null; + } }; + plugin = new BigQueryAgentAnalyticsPlugin(config, mockBigQuery, state); + when(mockWriter.append(any(ArrowRecordBatch.class))) .thenAnswer( invocation -> { ArrowRecordBatch recordedBatch = invocation.getArgument(0); + BatchProcessor batchProcessor = state.getBatchProcessors().iterator().next(); try (VectorSchemaRoot root = VectorSchemaRoot.create( - BigQuerySchema.getArrowSchema(), plugin.batchProcessor.allocator)) { + BigQuerySchema.getArrowSchema(), batchProcessor.allocator)) { VectorLoader loader = new VectorLoader(root); loader.load(recordedBatch); for (int i = 0; i < root.getRowCount(); i++) { @@ -150,8 +159,9 @@ public void runAgent_logsAgentStartingAndCompleted() throws Exception { // Ensure everything is flushed. The BatchProcessor flushes asynchronously sometimes, // but the direct flush() call should help. We wait up to 2 seconds for all 5 expected events. + BatchProcessor batchProcessor = state.getBatchProcessors().iterator().next(); for (int i = 0; i < 20 && capturedRows.size() < 5; i++) { - plugin.batchProcessor.flush(); + batchProcessor.flush(); if (capturedRows.size() < 5) { Thread.sleep(100); } diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index c7e35e3d6..fed1d81f1 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -62,7 +62,6 @@ import com.google.genai.types.GenerateContentResponse; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; -import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanContext; import io.opentelemetry.api.trace.Tracer; @@ -75,7 +74,11 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -113,6 +116,7 @@ public class BigQueryAgentAnalyticsPluginTest { private BaseAgent fakeAgent; private BigQueryLoggerConfig config; + private PluginState state; private BigQueryAgentAnalyticsPlugin plugin; private Handler mockHandler; private Tracer tracer; @@ -140,24 +144,21 @@ public void setUp() throws Exception { when(mockWriter.append(any(ArrowRecordBatch.class))) .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); - plugin = - new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + state = + new PluginState(config) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } - - @Override - protected TraceManager createTraceManager() { - return new TraceManager(tracer); - } }; + plugin = new BigQueryAgentAnalyticsPlugin(config, mockBigQuery, state); + Session session = Session.builder("session_id").appName("test_app").userId("test_user").build(); when(mockInvocationContext.session()).thenReturn(session); when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); @@ -183,7 +184,7 @@ public void onUserMessageCallback_appendsToWriter() throws Exception { Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -191,15 +192,15 @@ public void onUserMessageCallback_appendsToWriter() throws Exception { @Test public void beforeRunCallback_appendsToWriter() throws Exception { plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @Test public void afterRunCallback_flushesAndAppends() throws Exception { + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); - plugin.batchProcessor.flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -213,7 +214,7 @@ public void getStreamName_returnsCorrectFormat() { .tableName("test-table") .build(); - String streamName = plugin.getStreamName(config); + String streamName = state.getStreamName(config); assertEquals( "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default", @@ -253,7 +254,7 @@ public void onUserMessageCallback_handlesTableCreationFailure() throws Exception // Should not throw exception plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); verify(mockHandler, atLeastOnce()).publish(captor.capture()); @@ -280,7 +281,7 @@ public void onUserMessageCallback_handlesAppendFailure() throws Exception { plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); // Flush should handle the failed future from writer.append() - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); @@ -350,7 +351,8 @@ public void logEvent_populatesCommonFields() throws Exception { ArrowRecordBatch recordedBatch = invocation.getArgument(0); Schema schema = BigQuerySchema.getArrowSchema(); try (VectorSchemaRoot root = - VectorSchemaRoot.create(schema, plugin.batchProcessor.allocator)) { + VectorSchemaRoot.create( + schema, state.getBatchProcessor("invocation_id").allocator)) { VectorLoader loader = new VectorLoader(root); loader.load(recordedBatch); @@ -411,7 +413,7 @@ public void logEvent_populatesCommonFields() throws Exception { Content content = Content.fromParts(Part.fromText("test message")); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); assertTrue(failureMessage[0], checksPassed[0]); } @@ -429,12 +431,12 @@ public void logEvent_populatesTraceDetails() throws Exception { Span mockSpan = Span.wrap(mockSpanContext); try (Scope scope = mockSpan.makeCurrent()) { - plugin.traceManager.attachCurrentSpan(); + state.getTraceManager("invocation_id").attachCurrentSpan(); Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals(traceId, row.get("trace_id")); assertEquals(spanId, row.get("span_id")); @@ -447,7 +449,7 @@ public void complexType_appendsToWriter() throws Exception { Content content = Content.fromParts(part); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -462,7 +464,7 @@ public void onEventCallback_populatesCorrectFields() throws Exception { plugin.onEventCallback(mockInvocationContext, event).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("STATE_DELTA", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -479,12 +481,12 @@ public void onModelErrorCallback_populatesCorrectFields() throws Exception { LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); Throwable error = new RuntimeException("model error message"); - plugin.traceManager.pushSpan("llm_request"); + state.getTraceManager("invocation_id").pushSpan("llm_request"); plugin .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) .blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = plugin.getState().getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("LLM_ERROR", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -524,13 +526,13 @@ public void afterModelCallback_populatesCorrectFields() throws Exception { tracer.spanBuilder("ambient").setParent(Context.current().with(parentSpan)).startSpan(); // Set valid ambient span context try (Scope scope = ambientSpan.makeCurrent()) { - plugin.traceManager.pushSpan("parent_request"); - plugin.traceManager.pushSpan("llm_request"); + state.getTraceManager("invocation_id").pushSpan("parent_request"); + state.getTraceManager("invocation_id").pushSpan("llm_request"); plugin.afterModelCallback(mockCallbackContext, adkResponse).blockingSubscribe(); } finally { ambientSpan.end(); } - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("LLM_RESPONSE", row.get("event_type")); ObjectNode contentMap = (ObjectNode) row.get("content"); @@ -562,10 +564,10 @@ public void afterToolCallback_populatesCorrectFields() throws Exception { ImmutableMap toolArgs = ImmutableMap.of("arg1", "value1"); ImmutableMap result = ImmutableMap.of("res1", "value2"); - plugin.traceManager.pushSpan("tool_request"); + state.getTraceManager("invocation_id").pushSpan("tool_request"); plugin.afterToolCallback(mockTool, toolArgs, mockToolContext, result).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("TOOL_COMPLETED", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -592,12 +594,12 @@ public AgentOrigin toolOrigin() { AgentTool a2aTool = AgentTool.create(a2aAgent); - plugin.traceManager.pushSpan("tool_request"); + state.getTraceManager("invocation_id").pushSpan("tool_request"); plugin .afterToolCallback(a2aTool, ImmutableMap.of(), mockToolContext, ImmutableMap.of()) .blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode contentMap = (ObjectNode) row.get("content"); assertEquals("A2A", contentMap.get("tool_origin").asText()); @@ -609,7 +611,7 @@ public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { Content content = Content.fromParts(Part.fromText("test message")); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode attributes = (ObjectNode) row.get("attributes"); assertTrue("attributes should contain session_metadata", attributes.has("session_metadata")); @@ -622,28 +624,25 @@ public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { @Test public void logEvent_excludesSessionMetadata_whenDisabled() throws Exception { BigQueryLoggerConfig disabledConfig = config.toBuilder().logSessionMetadata(false).build(); - BigQueryAgentAnalyticsPlugin disabledPlugin = - new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery) { + PluginState disabledState = + new PluginState(disabledConfig) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } - - @Override - protected TraceManager createTraceManager() { - return new TraceManager(GlobalOpenTelemetry.getTracer("test-plugin-disabled")); - } }; + BigQueryAgentAnalyticsPlugin disabledPlugin = + new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery, disabledState); Content content = Content.fromParts(Part.fromText("test message")); disabledPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = disabledPlugin.batchProcessor.queue.poll(); + Map row = disabledState.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode attributes = (ObjectNode) row.get("attributes"); assertFalse( @@ -767,6 +766,100 @@ public void createAnalyticsViews_executesQueries() throws Exception { .anyMatch(q -> q.contains("CREATE OR REPLACE VIEW `project.dataset.v_llm_response`"))); } + @Test + public void multipleInvocations_logsCorrectly() throws Exception { + BigQueryLoggerConfig testConfig = config.toBuilder().batchSize(10).build(); + PluginState testState = + new PluginState(testConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; + BigQueryAgentAnalyticsPlugin testPlugin = + new BigQueryAgentAnalyticsPlugin(testConfig, mockBigQuery, testState); + + InvocationContext context1 = mock(InvocationContext.class); + when(context1.invocationId()).thenReturn("inv-1"); + when(context1.agent()).thenReturn(fakeAgent); + when(context1.session()).thenReturn(Session.builder("s1").build()); + + InvocationContext context2 = mock(InvocationContext.class); + when(context2.invocationId()).thenReturn("inv-2"); + when(context2.agent()).thenReturn(fakeAgent); + when(context2.session()).thenReturn(Session.builder("s2").build()); + + var unused1 = testPlugin.beforeRunCallback(context1).blockingGet(); + var unused2 = + testPlugin + .onUserMessageCallback(context1, Content.fromParts(Part.fromText("msg1"))) + .blockingGet(); + + var unused3 = testPlugin.beforeRunCallback(context2).blockingGet(); + var unused4 = + testPlugin + .onUserMessageCallback(context2, Content.fromParts(Part.fromText("msg2"))) + .blockingGet(); + + // Verify processors are created and have correct data in their queues + BatchProcessor p1 = testState.getBatchProcessor("inv-1"); + BatchProcessor p2 = testState.getBatchProcessor("inv-2"); + + assertNotNull("Processor for inv-1 should exist", p1); + assertNotNull("Processor for inv-2 should exist", p2); + assertFalse("Queue for inv-1 should not be empty", p1.queue.isEmpty()); + assertFalse("Queue for inv-2 should not be empty", p2.queue.isEmpty()); + + assertTrue( + "All logs for inv-1 should have correct invocation_id", + p1.queue.stream().allMatch(row -> row.get("invocation_id").equals("inv-1"))); + assertTrue( + "All logs for inv-2 should have correct invocation_id", + p2.queue.stream().allMatch(row -> row.get("invocation_id").equals("inv-2"))); + + // Now flush and verify writer was called + testPlugin.afterRunCallback(context1).blockingAwait(); + testPlugin.afterRunCallback(context2).blockingAwait(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void logEvent_createsUniqueProcessorPerInvocation() throws Exception { + int numInvocations = 5; + ExecutorService testExecutor = Executors.newFixedThreadPool(numInvocations); + Set processors = ConcurrentHashMap.newKeySet(); + CountDownLatch latch = new CountDownLatch(numInvocations); + + for (int i = 0; i < numInvocations; i++) { + final String invocationId = "inv-" + i; + testExecutor.execute( + () -> { + try { + InvocationContext context = mock(InvocationContext.class); + when(context.invocationId()).thenReturn(invocationId); + when(context.agent()).thenReturn(fakeAgent); + Session session = Session.builder("s").build(); + when(context.session()).thenReturn(session); + + plugin.beforeRunCallback(context).blockingSubscribe(); + processors.add(state.getBatchProcessor(invocationId)); + } finally { + latch.countDown(); + } + }); + } + + latch.await(); + assertEquals(numInvocations, processors.size()); + testExecutor.shutdown(); + } + private static class FakeAgent extends BaseAgent { FakeAgent(String name) { super(name, "description", null, null, null);