From 7486cb57f7fe315e2b64e2c6dcaa06130274dde0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 9 Apr 2026 16:50:40 -0700 Subject: [PATCH] feat: Make BigQueryAgentAnalyticsPlugin state per-invocation This change introduces per-invocation instances of BatchProcessor and TraceManager, managed by ConcurrentHashMaps keyed by invocation ID. This ensures that analytics and tracing data are isolated for each concurrent invocation. BatchProcessors and TraceManagers are created lazily on the first event for a given invocation and are cleaned up when the invocation completes. PiperOrigin-RevId: 897370846 --- .../BigQueryAgentAnalyticsPlugin.java | 149 +++++---------- .../plugins/agentanalytics/PluginState.java | 158 ++++++++++++++++ .../BigQueryAgentAnalyticsPluginE2ETest.java | 20 +- .../BigQueryAgentAnalyticsPluginTest.java | 173 ++++++++++++++---- 4 files changed, 354 insertions(+), 146 deletions(-) create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 3c673b140..f3f004b55 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -21,7 +21,6 @@ import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode; import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject; -import static java.util.concurrent.TimeUnit.MILLISECONDS; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.CallbackContext; @@ -41,8 +40,6 @@ import com.google.adk.tools.ToolContext; import com.google.adk.tools.mcp.AbstractMcpTool; import com.google.adk.utils.AgentEnums.AgentOrigin; -import com.google.api.gax.core.FixedCredentialsProvider; -import com.google.api.gax.retrying.RetrySettings; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.bigquery.BigQuery; import com.google.cloud.bigquery.BigQueryException; @@ -53,11 +50,7 @@ import com.google.cloud.bigquery.Table; import com.google.cloud.bigquery.TableId; import com.google.cloud.bigquery.TableInfo; -import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; -import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; -import com.google.cloud.bigquery.storage.v1.StreamWriter; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -70,9 +63,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -100,11 +90,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { private final BigQueryLoggerConfig config; private final BigQuery bigQuery; - private final BigQueryWriteClient writeClient; - private final ScheduledExecutorService executor; private final Object tableEnsuredLock = new Object(); - @VisibleForTesting final BatchProcessor batchProcessor; - @VisibleForTesting final TraceManager traceManager; + private final PluginState state; private volatile boolean tableEnsured = false; public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { @@ -113,28 +100,14 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) throws IOException { + this(config, bigQuery, new PluginState(config)); + } + + BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery, PluginState state) { super("bigquery_agent_analytics"); this.config = config; this.bigQuery = bigQuery; - ThreadFactory threadFactory = - r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); - this.executor = Executors.newScheduledThreadPool(1, threadFactory); - this.writeClient = createWriteClient(config); - this.traceManager = createTraceManager(); - - if (config.enabled()) { - StreamWriter writer = createWriter(config); - this.batchProcessor = - new BatchProcessor( - writer, - config.batchSize(), - config.batchFlushInterval(), - config.queueMaxSize(), - executor); - this.batchProcessor.start(); - } else { - this.batchProcessor = null; - } + this.state = state; } private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { @@ -194,7 +167,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { try { if (config.createViews()) { - var unused = executor.submit(() -> createAnalyticsViews(bigQuery, config)); + var unused = state.getExecutor().submit(() -> createAnalyticsViews(bigQuery, config)); } } catch (RuntimeException e) { logger.log(Level.WARNING, "Failed to create/update BigQuery views for table: " + tableId, e); @@ -209,48 +182,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) { } } - protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { - if (config.credentials() != null) { - return BigQueryWriteClient.create( - BigQueryWriteSettings.newBuilder() - .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) - .build()); - } - return BigQueryWriteClient.create(); - } - - protected String getStreamName(BigQueryLoggerConfig config) { - return String.format( - "projects/%s/datasets/%s/tables/%s/streams/_default", - config.projectId(), config.datasetId(), config.tableName()); - } - - protected StreamWriter createWriter(BigQueryLoggerConfig config) { - BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); - RetrySettings retrySettings = - RetrySettings.newBuilder() - .setMaxAttempts(retryConfig.maxRetries()) - .setInitialRetryDelay( - org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis())) - .setRetryDelayMultiplier(retryConfig.multiplier()) - .setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis())) - .build(); - - String streamName = getStreamName(config); - try { - return StreamWriter.newBuilder(streamName, writeClient) - .setRetrySettings(retrySettings) - .setWriterSchema(BigQuerySchema.getArrowSchema()) - .build(); - } catch (Exception e) { - throw new VerifyException("Failed to create StreamWriter for " + streamName, e); - } - } - - protected TraceManager createTraceManager() { - return new TraceManager(); - } - private void logEvent( String eventType, InvocationContext invocationContext, @@ -265,7 +196,7 @@ private void logEvent( Object content, boolean isContentTruncated, Optional eventData) { - if (!config.enabled() || batchProcessor == null) { + if (!config.enabled()) { return; } if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) { @@ -274,6 +205,8 @@ private void logEvent( if (config.eventDenylist().contains(eventType)) { return; } + String invocationId = invocationContext.invocationId(); + BatchProcessor processor = state.getBatchProcessor(invocationId); // Ensure table exists before logging. ensureTableExistsOnce(); // Log common fields @@ -301,11 +234,12 @@ private void logEvent( row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext))); addTraceDetails(row, invocationContext, eventData); - batchProcessor.append(row); + processor.append(row); } private void addTraceDetails( Map row, InvocationContext invocationContext, Optional eventData) { + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); String traceId = eventData .flatMap(EventData::traceIdOverride) @@ -336,7 +270,7 @@ private void addTraceDetails( private Map getAttributes( EventData eventData, InvocationContext invocationContext) { Map attributes = new HashMap<>(eventData.extraAttributes()); - + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); attributes.put("root_agent_name", traceManager.getRootAgentName()); eventData.model().ifPresent(m -> attributes.put("model", m)); eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv)); @@ -375,25 +309,17 @@ private Map getAttributes( @Override public Completable close() { - if (batchProcessor != null) { - batchProcessor.close(); - } - if (writeClient != null) { - writeClient.close(); - } - try { - executor.shutdown(); - if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { - executor.shutdownNow(); - } - } catch (InterruptedException e) { - executor.shutdownNow(); - Thread.currentThread().interrupt(); - } + state.close(); return Completable.complete(); } + @VisibleForTesting + PluginState getState() { + return state; + } + private Optional getCompletedEventData(InvocationContext invocationContext) { + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); String traceId = traceManager.getTraceId(invocationContext); // Pop the invocation span from the trace manager. Optional popped = traceManager.popSpan(); @@ -426,7 +352,9 @@ public Maybe onUserMessageCallback( InvocationContext invocationContext, Content userMessage) { return Maybe.fromAction( () -> { - traceManager.ensureInvocationSpan(invocationContext); + state + .getTraceManager(invocationContext.invocationId()) + .ensureInvocationSpan(invocationContext); logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty()); if (userMessage.parts().isPresent()) { for (Part part : userMessage.parts().get()) { @@ -510,7 +438,7 @@ public Maybe onEventCallback(InvocationContext invocationContext, Event e @Override public Maybe beforeRunCallback(InvocationContext invocationContext) { - traceManager.ensureInvocationSpan(invocationContext); + state.getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext); return Maybe.fromAction( () -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty())); } @@ -524,8 +452,15 @@ public Completable afterRunCallback(InvocationContext invocationContext) { invocationContext, null, getCompletedEventData(invocationContext)); - batchProcessor.flush(); - traceManager.clearStack(); + BatchProcessor processor = state.removeProcessor(invocationContext.invocationId()); + if (processor != null) { + processor.flush(); + processor.close(); + } + TraceManager traceManager = state.removeTraceManager(invocationContext.invocationId()); + if (traceManager != null) { + traceManager.clearStack(); + } }); } @@ -533,7 +468,9 @@ public Completable afterRunCallback(InvocationContext invocationContext) { public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return Maybe.fromAction( () -> { - traceManager.pushSpan("agent:" + agent.name()); + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("agent:" + agent.name()); logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty()); }); } @@ -622,7 +559,9 @@ public Maybe beforeModelCallback( .setModel(req.model().orElse("")) .setExtraAttributes(attributes) .build(); - traceManager.pushSpan("llm_request"); + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("llm_request"); logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData)); }); } @@ -632,6 +571,8 @@ public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); // TODO(b/495809488): Add formatting of the content ParsedContent parsedContent = JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength()); @@ -728,6 +669,8 @@ public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); InvocationContext invocationContext = callbackContext.invocationContext(); Optional popped = traceManager.popSpan(); String spanId = popped.map(RecordData::spanId).orElse(null); @@ -762,7 +705,7 @@ public Maybe> beforeToolCallback( ImmutableMap contentMap = ImmutableMap.of( "tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node()); - traceManager.pushSpan("tool"); + state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool"); logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty()); }); } @@ -775,6 +718,8 @@ public Maybe> afterToolCallback( Map result) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); Optional popped = traceManager.popSpan(); TruncationResult truncationResult = smartTruncate(result, config.maxContentLength()); ImmutableMap contentMap = @@ -812,6 +757,8 @@ public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { return Maybe.fromAction( () -> { + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); Optional popped = traceManager.popSpan(); TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength()); String toolOrigin = getToolOrigin(tool); diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java new file mode 100644 index 000000000..faec47cf1 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java @@ -0,0 +1,158 @@ +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import java.io.IOException; +import java.util.Collection; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; + +/** Manages state for the BigQueryAgentAnalyticsPlugin. */ +class PluginState { + private final BigQueryLoggerConfig config; + private final ScheduledExecutorService executor; + private final BigQueryWriteClient writeClient; + private static final AtomicLong threadCounter = new AtomicLong(0); + // Map of invocation ID to BatchProcessor. + private final ConcurrentHashMap batchProcessors = + new ConcurrentHashMap<>(); + // Map of invocation ID to TraceManager. + private final ConcurrentHashMap traceManagers = new ConcurrentHashMap<>(); + + PluginState(BigQueryLoggerConfig config) throws IOException { + this.config = config; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + // One write client per plugin instance, shared by all invocations. + this.writeClient = createWriteClient(config); + } + + ScheduledExecutorService getExecutor() { + return executor; + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + if (config.credentials() != null) { + return BigQueryWriteClient.create( + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) + .build()); + } + return BigQueryWriteClient.create(); + } + + protected StreamWriter createWriter() { + BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(retryConfig.maxRetries()) + .setInitialRetryDelay( + org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setRetryDelayMultiplier(retryConfig.multiplier()) + .setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + @VisibleForTesting + String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + @VisibleForTesting + TraceManager getTraceManager(String invocationId) { + return traceManagers.computeIfAbsent(invocationId, id -> new TraceManager()); + } + + @VisibleForTesting + BatchProcessor getBatchProcessor(String invocationId) { + return batchProcessors.computeIfAbsent( + invocationId, + id -> { + StreamWriter writer = createWriter(); + BatchProcessor p = + new BatchProcessor( + writer, + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + p.start(); + return p; + }); + } + + @VisibleForTesting + Collection getTraceManagers() { + return traceManagers.values(); + } + + @VisibleForTesting + Collection getBatchProcessors() { + return batchProcessors.values(); + } + + @VisibleForTesting + TraceManager removeTraceManager(String invocationId) { + return traceManagers.remove(invocationId); + } + + @VisibleForTesting + protected BatchProcessor removeProcessor(String invocationId) { + return batchProcessors.remove(invocationId); + } + + void clearTraceManagers() { + traceManagers.clear(); + } + + void clearBatchProcessors() { + batchProcessors.clear(); + } + + void close() { + for (BatchProcessor processor : getBatchProcessors()) { + processor.close(); + } + for (TraceManager traceManager : getTraceManagers()) { + traceManager.clearStack(); + } + clearBatchProcessors(); + clearTraceManagers(); + + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java index 53faf3329..ef721e432 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java @@ -63,6 +63,7 @@ public final class BigQueryAgentAnalyticsPluginE2ETest { private StreamWriter mockWriter; private BigQueryWriteClient mockWriteClient; private BigQueryLoggerConfig config; + private PluginState state; private BigQueryAgentAnalyticsPlugin plugin; private Runner runner; private BaseAgent fakeAgent; @@ -92,26 +93,34 @@ public void setUp() throws Exception { when(mockWriter.append(any(ArrowRecordBatch.class))) .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); - plugin = - new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + state = + new PluginState(config) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } + + @Override + protected BatchProcessor removeProcessor(String invocationId) { + return null; + } }; + plugin = new BigQueryAgentAnalyticsPlugin(config, mockBigQuery, state); + when(mockWriter.append(any(ArrowRecordBatch.class))) .thenAnswer( invocation -> { ArrowRecordBatch recordedBatch = invocation.getArgument(0); + BatchProcessor batchProcessor = state.getBatchProcessors().iterator().next(); try (VectorSchemaRoot root = VectorSchemaRoot.create( - BigQuerySchema.getArrowSchema(), plugin.batchProcessor.allocator)) { + BigQuerySchema.getArrowSchema(), batchProcessor.allocator)) { VectorLoader loader = new VectorLoader(root); loader.load(recordedBatch); for (int i = 0; i < root.getRowCount(); i++) { @@ -150,8 +159,9 @@ public void runAgent_logsAgentStartingAndCompleted() throws Exception { // Ensure everything is flushed. The BatchProcessor flushes asynchronously sometimes, // but the direct flush() call should help. We wait up to 2 seconds for all 5 expected events. + BatchProcessor batchProcessor = state.getBatchProcessors().iterator().next(); for (int i = 0; i < 20 && capturedRows.size() < 5; i++) { - plugin.batchProcessor.flush(); + batchProcessor.flush(); if (capturedRows.size() < 5) { Thread.sleep(100); } diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index c7e35e3d6..fed1d81f1 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -62,7 +62,6 @@ import com.google.genai.types.GenerateContentResponse; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; -import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanContext; import io.opentelemetry.api.trace.Tracer; @@ -75,7 +74,11 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -113,6 +116,7 @@ public class BigQueryAgentAnalyticsPluginTest { private BaseAgent fakeAgent; private BigQueryLoggerConfig config; + private PluginState state; private BigQueryAgentAnalyticsPlugin plugin; private Handler mockHandler; private Tracer tracer; @@ -140,24 +144,21 @@ public void setUp() throws Exception { when(mockWriter.append(any(ArrowRecordBatch.class))) .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); - plugin = - new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + state = + new PluginState(config) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } - - @Override - protected TraceManager createTraceManager() { - return new TraceManager(tracer); - } }; + plugin = new BigQueryAgentAnalyticsPlugin(config, mockBigQuery, state); + Session session = Session.builder("session_id").appName("test_app").userId("test_user").build(); when(mockInvocationContext.session()).thenReturn(session); when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); @@ -183,7 +184,7 @@ public void onUserMessageCallback_appendsToWriter() throws Exception { Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -191,15 +192,15 @@ public void onUserMessageCallback_appendsToWriter() throws Exception { @Test public void beforeRunCallback_appendsToWriter() throws Exception { plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @Test public void afterRunCallback_flushesAndAppends() throws Exception { + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); - plugin.batchProcessor.flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -213,7 +214,7 @@ public void getStreamName_returnsCorrectFormat() { .tableName("test-table") .build(); - String streamName = plugin.getStreamName(config); + String streamName = state.getStreamName(config); assertEquals( "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default", @@ -253,7 +254,7 @@ public void onUserMessageCallback_handlesTableCreationFailure() throws Exception // Should not throw exception plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); verify(mockHandler, atLeastOnce()).publish(captor.capture()); @@ -280,7 +281,7 @@ public void onUserMessageCallback_handlesAppendFailure() throws Exception { plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); // Flush should handle the failed future from writer.append() - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); @@ -350,7 +351,8 @@ public void logEvent_populatesCommonFields() throws Exception { ArrowRecordBatch recordedBatch = invocation.getArgument(0); Schema schema = BigQuerySchema.getArrowSchema(); try (VectorSchemaRoot root = - VectorSchemaRoot.create(schema, plugin.batchProcessor.allocator)) { + VectorSchemaRoot.create( + schema, state.getBatchProcessor("invocation_id").allocator)) { VectorLoader loader = new VectorLoader(root); loader.load(recordedBatch); @@ -411,7 +413,7 @@ public void logEvent_populatesCommonFields() throws Exception { Content content = Content.fromParts(Part.fromText("test message")); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); assertTrue(failureMessage[0], checksPassed[0]); } @@ -429,12 +431,12 @@ public void logEvent_populatesTraceDetails() throws Exception { Span mockSpan = Span.wrap(mockSpanContext); try (Scope scope = mockSpan.makeCurrent()) { - plugin.traceManager.attachCurrentSpan(); + state.getTraceManager("invocation_id").attachCurrentSpan(); Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals(traceId, row.get("trace_id")); assertEquals(spanId, row.get("span_id")); @@ -447,7 +449,7 @@ public void complexType_appendsToWriter() throws Exception { Content content = Content.fromParts(part); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -462,7 +464,7 @@ public void onEventCallback_populatesCorrectFields() throws Exception { plugin.onEventCallback(mockInvocationContext, event).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("STATE_DELTA", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -479,12 +481,12 @@ public void onModelErrorCallback_populatesCorrectFields() throws Exception { LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); Throwable error = new RuntimeException("model error message"); - plugin.traceManager.pushSpan("llm_request"); + state.getTraceManager("invocation_id").pushSpan("llm_request"); plugin .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) .blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = plugin.getState().getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("LLM_ERROR", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -524,13 +526,13 @@ public void afterModelCallback_populatesCorrectFields() throws Exception { tracer.spanBuilder("ambient").setParent(Context.current().with(parentSpan)).startSpan(); // Set valid ambient span context try (Scope scope = ambientSpan.makeCurrent()) { - plugin.traceManager.pushSpan("parent_request"); - plugin.traceManager.pushSpan("llm_request"); + state.getTraceManager("invocation_id").pushSpan("parent_request"); + state.getTraceManager("invocation_id").pushSpan("llm_request"); plugin.afterModelCallback(mockCallbackContext, adkResponse).blockingSubscribe(); } finally { ambientSpan.end(); } - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("LLM_RESPONSE", row.get("event_type")); ObjectNode contentMap = (ObjectNode) row.get("content"); @@ -562,10 +564,10 @@ public void afterToolCallback_populatesCorrectFields() throws Exception { ImmutableMap toolArgs = ImmutableMap.of("arg1", "value1"); ImmutableMap result = ImmutableMap.of("res1", "value2"); - plugin.traceManager.pushSpan("tool_request"); + state.getTraceManager("invocation_id").pushSpan("tool_request"); plugin.afterToolCallback(mockTool, toolArgs, mockToolContext, result).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("TOOL_COMPLETED", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -592,12 +594,12 @@ public AgentOrigin toolOrigin() { AgentTool a2aTool = AgentTool.create(a2aAgent); - plugin.traceManager.pushSpan("tool_request"); + state.getTraceManager("invocation_id").pushSpan("tool_request"); plugin .afterToolCallback(a2aTool, ImmutableMap.of(), mockToolContext, ImmutableMap.of()) .blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode contentMap = (ObjectNode) row.get("content"); assertEquals("A2A", contentMap.get("tool_origin").asText()); @@ -609,7 +611,7 @@ public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { Content content = Content.fromParts(Part.fromText("test message")); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode attributes = (ObjectNode) row.get("attributes"); assertTrue("attributes should contain session_metadata", attributes.has("session_metadata")); @@ -622,28 +624,25 @@ public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { @Test public void logEvent_excludesSessionMetadata_whenDisabled() throws Exception { BigQueryLoggerConfig disabledConfig = config.toBuilder().logSessionMetadata(false).build(); - BigQueryAgentAnalyticsPlugin disabledPlugin = - new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery) { + PluginState disabledState = + new PluginState(disabledConfig) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } - - @Override - protected TraceManager createTraceManager() { - return new TraceManager(GlobalOpenTelemetry.getTracer("test-plugin-disabled")); - } }; + BigQueryAgentAnalyticsPlugin disabledPlugin = + new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery, disabledState); Content content = Content.fromParts(Part.fromText("test message")); disabledPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = disabledPlugin.batchProcessor.queue.poll(); + Map row = disabledState.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode attributes = (ObjectNode) row.get("attributes"); assertFalse( @@ -767,6 +766,100 @@ public void createAnalyticsViews_executesQueries() throws Exception { .anyMatch(q -> q.contains("CREATE OR REPLACE VIEW `project.dataset.v_llm_response`"))); } + @Test + public void multipleInvocations_logsCorrectly() throws Exception { + BigQueryLoggerConfig testConfig = config.toBuilder().batchSize(10).build(); + PluginState testState = + new PluginState(testConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; + BigQueryAgentAnalyticsPlugin testPlugin = + new BigQueryAgentAnalyticsPlugin(testConfig, mockBigQuery, testState); + + InvocationContext context1 = mock(InvocationContext.class); + when(context1.invocationId()).thenReturn("inv-1"); + when(context1.agent()).thenReturn(fakeAgent); + when(context1.session()).thenReturn(Session.builder("s1").build()); + + InvocationContext context2 = mock(InvocationContext.class); + when(context2.invocationId()).thenReturn("inv-2"); + when(context2.agent()).thenReturn(fakeAgent); + when(context2.session()).thenReturn(Session.builder("s2").build()); + + var unused1 = testPlugin.beforeRunCallback(context1).blockingGet(); + var unused2 = + testPlugin + .onUserMessageCallback(context1, Content.fromParts(Part.fromText("msg1"))) + .blockingGet(); + + var unused3 = testPlugin.beforeRunCallback(context2).blockingGet(); + var unused4 = + testPlugin + .onUserMessageCallback(context2, Content.fromParts(Part.fromText("msg2"))) + .blockingGet(); + + // Verify processors are created and have correct data in their queues + BatchProcessor p1 = testState.getBatchProcessor("inv-1"); + BatchProcessor p2 = testState.getBatchProcessor("inv-2"); + + assertNotNull("Processor for inv-1 should exist", p1); + assertNotNull("Processor for inv-2 should exist", p2); + assertFalse("Queue for inv-1 should not be empty", p1.queue.isEmpty()); + assertFalse("Queue for inv-2 should not be empty", p2.queue.isEmpty()); + + assertTrue( + "All logs for inv-1 should have correct invocation_id", + p1.queue.stream().allMatch(row -> row.get("invocation_id").equals("inv-1"))); + assertTrue( + "All logs for inv-2 should have correct invocation_id", + p2.queue.stream().allMatch(row -> row.get("invocation_id").equals("inv-2"))); + + // Now flush and verify writer was called + testPlugin.afterRunCallback(context1).blockingAwait(); + testPlugin.afterRunCallback(context2).blockingAwait(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void logEvent_createsUniqueProcessorPerInvocation() throws Exception { + int numInvocations = 5; + ExecutorService testExecutor = Executors.newFixedThreadPool(numInvocations); + Set processors = ConcurrentHashMap.newKeySet(); + CountDownLatch latch = new CountDownLatch(numInvocations); + + for (int i = 0; i < numInvocations; i++) { + final String invocationId = "inv-" + i; + testExecutor.execute( + () -> { + try { + InvocationContext context = mock(InvocationContext.class); + when(context.invocationId()).thenReturn(invocationId); + when(context.agent()).thenReturn(fakeAgent); + Session session = Session.builder("s").build(); + when(context.session()).thenReturn(session); + + plugin.beforeRunCallback(context).blockingSubscribe(); + processors.add(state.getBatchProcessor(invocationId)); + } finally { + latch.countDown(); + } + }); + } + + latch.await(); + assertEquals(numInvocations, processors.size()); + testExecutor.shutdown(); + } + private static class FakeAgent extends BaseAgent { FakeAgent(String name) { super(name, "description", null, null, null);