Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode;
import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate;
import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.CallbackContext;
Expand All @@ -41,8 +40,6 @@
import com.google.adk.tools.ToolContext;
import com.google.adk.tools.mcp.AbstractMcpTool;
import com.google.adk.utils.AgentEnums.AgentOrigin;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.retrying.RetrySettings;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.bigquery.BigQuery;
import com.google.cloud.bigquery.BigQueryException;
Expand All @@ -53,11 +50,7 @@
import com.google.cloud.bigquery.Table;
import com.google.cloud.bigquery.TableId;
import com.google.cloud.bigquery.TableInfo;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings;
import com.google.cloud.bigquery.storage.v1.StreamWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
Expand All @@ -70,9 +63,6 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand Down Expand Up @@ -100,11 +90,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {

private final BigQueryLoggerConfig config;
private final BigQuery bigQuery;
private final BigQueryWriteClient writeClient;
private final ScheduledExecutorService executor;
private final Object tableEnsuredLock = new Object();
@VisibleForTesting final BatchProcessor batchProcessor;
@VisibleForTesting final TraceManager traceManager;
private final PluginState state;
private volatile boolean tableEnsured = false;

public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException {
Expand All @@ -113,28 +100,14 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept

public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery)
throws IOException {
this(config, bigQuery, new PluginState(config));
}

BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery, PluginState state) {
super("bigquery_agent_analytics");
this.config = config;
this.bigQuery = bigQuery;
ThreadFactory threadFactory =
r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement());
this.executor = Executors.newScheduledThreadPool(1, threadFactory);
this.writeClient = createWriteClient(config);
this.traceManager = createTraceManager();

if (config.enabled()) {
StreamWriter writer = createWriter(config);
this.batchProcessor =
new BatchProcessor(
writer,
config.batchSize(),
config.batchFlushInterval(),
config.queueMaxSize(),
executor);
this.batchProcessor.start();
} else {
this.batchProcessor = null;
}
this.state = state;
}

private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException {
Expand Down Expand Up @@ -194,7 +167,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) {

try {
if (config.createViews()) {
var unused = executor.submit(() -> createAnalyticsViews(bigQuery, config));
var unused = state.getExecutor().submit(() -> createAnalyticsViews(bigQuery, config));
}
} catch (RuntimeException e) {
logger.log(Level.WARNING, "Failed to create/update BigQuery views for table: " + tableId, e);
Expand All @@ -209,48 +182,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) {
}
}

protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException {
if (config.credentials() != null) {
return BigQueryWriteClient.create(
BigQueryWriteSettings.newBuilder()
.setCredentialsProvider(FixedCredentialsProvider.create(config.credentials()))
.build());
}
return BigQueryWriteClient.create();
}

protected String getStreamName(BigQueryLoggerConfig config) {
return String.format(
"projects/%s/datasets/%s/tables/%s/streams/_default",
config.projectId(), config.datasetId(), config.tableName());
}

protected StreamWriter createWriter(BigQueryLoggerConfig config) {
BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig();
RetrySettings retrySettings =
RetrySettings.newBuilder()
.setMaxAttempts(retryConfig.maxRetries())
.setInitialRetryDelay(
org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis()))
.setRetryDelayMultiplier(retryConfig.multiplier())
.setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis()))
.build();

String streamName = getStreamName(config);
try {
return StreamWriter.newBuilder(streamName, writeClient)
.setRetrySettings(retrySettings)
.setWriterSchema(BigQuerySchema.getArrowSchema())
.build();
} catch (Exception e) {
throw new VerifyException("Failed to create StreamWriter for " + streamName, e);
}
}

protected TraceManager createTraceManager() {
return new TraceManager();
}

private void logEvent(
String eventType,
InvocationContext invocationContext,
Expand All @@ -265,7 +196,7 @@ private void logEvent(
Object content,
boolean isContentTruncated,
Optional<EventData> eventData) {
if (!config.enabled() || batchProcessor == null) {
if (!config.enabled()) {
return;
}
if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) {
Expand All @@ -274,6 +205,8 @@ private void logEvent(
if (config.eventDenylist().contains(eventType)) {
return;
}
String invocationId = invocationContext.invocationId();
BatchProcessor processor = state.getBatchProcessor(invocationId);
// Ensure table exists before logging.
ensureTableExistsOnce();
// Log common fields
Expand Down Expand Up @@ -301,11 +234,12 @@ private void logEvent(
row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext)));

addTraceDetails(row, invocationContext, eventData);
batchProcessor.append(row);
processor.append(row);
}

private void addTraceDetails(
Map<String, Object> row, InvocationContext invocationContext, Optional<EventData> eventData) {
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
String traceId =
eventData
.flatMap(EventData::traceIdOverride)
Expand Down Expand Up @@ -336,7 +270,7 @@ private void addTraceDetails(
private Map<String, Object> getAttributes(
EventData eventData, InvocationContext invocationContext) {
Map<String, Object> attributes = new HashMap<>(eventData.extraAttributes());

TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
attributes.put("root_agent_name", traceManager.getRootAgentName());
eventData.model().ifPresent(m -> attributes.put("model", m));
eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv));
Expand Down Expand Up @@ -375,25 +309,17 @@ private Map<String, Object> getAttributes(

@Override
public Completable close() {
if (batchProcessor != null) {
batchProcessor.close();
}
if (writeClient != null) {
writeClient.close();
}
try {
executor.shutdown();
if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) {
executor.shutdownNow();
}
} catch (InterruptedException e) {
executor.shutdownNow();
Thread.currentThread().interrupt();
}
state.close();
return Completable.complete();
}

@VisibleForTesting
PluginState getState() {
return state;
}

private Optional<EventData> getCompletedEventData(InvocationContext invocationContext) {
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
String traceId = traceManager.getTraceId(invocationContext);
// Pop the invocation span from the trace manager.
Optional<RecordData> popped = traceManager.popSpan();
Expand Down Expand Up @@ -426,7 +352,9 @@ public Maybe<Content> onUserMessageCallback(
InvocationContext invocationContext, Content userMessage) {
return Maybe.fromAction(
() -> {
traceManager.ensureInvocationSpan(invocationContext);
state
.getTraceManager(invocationContext.invocationId())
.ensureInvocationSpan(invocationContext);
logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty());
if (userMessage.parts().isPresent()) {
for (Part part : userMessage.parts().get()) {
Expand Down Expand Up @@ -510,7 +438,7 @@ public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event e

@Override
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
traceManager.ensureInvocationSpan(invocationContext);
state.getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext);
return Maybe.fromAction(
() -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty()));
}
Expand All @@ -524,16 +452,25 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
invocationContext,
null,
getCompletedEventData(invocationContext));
batchProcessor.flush();
traceManager.clearStack();
BatchProcessor processor = state.removeProcessor(invocationContext.invocationId());
if (processor != null) {
processor.flush();
processor.close();
}
TraceManager traceManager = state.removeTraceManager(invocationContext.invocationId());
if (traceManager != null) {
traceManager.clearStack();
}
});
}

@Override
public Maybe<Content> beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) {
return Maybe.fromAction(
() -> {
traceManager.pushSpan("agent:" + agent.name());
state
.getTraceManager(callbackContext.invocationContext().invocationId())
.pushSpan("agent:" + agent.name());
logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty());
});
}
Expand Down Expand Up @@ -622,7 +559,9 @@ public Maybe<LlmResponse> beforeModelCallback(
.setModel(req.model().orElse(""))
.setExtraAttributes(attributes)
.build();
traceManager.pushSpan("llm_request");
state
.getTraceManager(callbackContext.invocationContext().invocationId())
.pushSpan("llm_request");
logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData));
});
}
Expand All @@ -632,6 +571,8 @@ public Maybe<LlmResponse> afterModelCallback(
CallbackContext callbackContext, LlmResponse llmResponse) {
return Maybe.fromAction(
() -> {
TraceManager traceManager =
state.getTraceManager(callbackContext.invocationContext().invocationId());
// TODO(b/495809488): Add formatting of the content
ParsedContent parsedContent =
JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength());
Expand Down Expand Up @@ -728,6 +669,8 @@ public Maybe<LlmResponse> onModelErrorCallback(
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
return Maybe.fromAction(
() -> {
TraceManager traceManager =
state.getTraceManager(callbackContext.invocationContext().invocationId());
InvocationContext invocationContext = callbackContext.invocationContext();
Optional<RecordData> popped = traceManager.popSpan();
String spanId = popped.map(RecordData::spanId).orElse(null);
Expand Down Expand Up @@ -762,7 +705,7 @@ public Maybe<Map<String, Object>> beforeToolCallback(
ImmutableMap<String, Object> contentMap =
ImmutableMap.of(
"tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node());
traceManager.pushSpan("tool");
state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool");
logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty());
});
}
Expand All @@ -775,6 +718,8 @@ public Maybe<Map<String, Object>> afterToolCallback(
Map<String, Object> result) {
return Maybe.fromAction(
() -> {
TraceManager traceManager =
state.getTraceManager(toolContext.invocationContext().invocationId());
Optional<RecordData> popped = traceManager.popSpan();
TruncationResult truncationResult = smartTruncate(result, config.maxContentLength());
ImmutableMap<String, Object> contentMap =
Expand Down Expand Up @@ -812,6 +757,8 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
BaseTool tool, Map<String, Object> toolArgs, ToolContext toolContext, Throwable error) {
return Maybe.fromAction(
() -> {
TraceManager traceManager =
state.getTraceManager(toolContext.invocationContext().invocationId());
Optional<RecordData> popped = traceManager.popSpan();
TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength());
String toolOrigin = getToolOrigin(tool);
Expand Down
Loading