Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions examples/google_tracer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Add lib directory to load path
$LOAD_PATH.unshift(File.expand_path("../lib", __dir__))

require "securerandom"
require "openlayer"
require "openlayer/integrations/google_conversational_search_tracer"
require "google/cloud/discovery_engine/v1"
Expand All @@ -19,19 +20,25 @@
api_key: ENV["OPENLAYER_API_KEY"]
)

# Enable tracing - this patches the client to send all queries to Openlayer
# Enable tracing - this patches the client to send all queries to Openlayer.
# additional_columns here is a static default applied to every trace sent
# through this client.
Openlayer::Integrations::GoogleConversationalSearchTracer.trace_client(
google_client,
openlayer_client: openlayer,
inference_pipeline_id: ENV["OPENLAYER_INFERENCE_PIPELINE_ID"]
inference_pipeline_id: ENV["OPENLAYER_INFERENCE_PIPELINE_ID"],
additional_columns: {environment: "production"}
)

# Use the client normally - all answer_query calls are now automatically traced!
# additional_columns here is per-call; it takes precedence over the static
# default above on a key conflict.
response = google_client.answer_query(
serving_config: ENV["GOOGLE_SERVING_CONFIG"],
query: Google::Cloud::DiscoveryEngine::V1::Query.new(
text: "What is the meaning of life?"
)
),
additional_columns: {trace_id: SecureRandom.uuid}
)

puts "Answer: #{response.answer.answer_text}"
Expand Down
81 changes: 75 additions & 6 deletions lib/openlayer/integrations/google_conversational_search_tracer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,25 @@ module Integrations
# Openlayer::Integrations::GoogleConversationalSearchTracer.trace_client(
# google_client,
# openlayer_client: openlayer,
# inference_pipeline_id: 'your-pipeline-id'
# inference_pipeline_id: 'your-pipeline-id',
# additional_columns: { environment: 'production' }
# )
#
# # Now all answer_query calls are automatically traced
# # Now all answer_query calls are automatically traced! Pass
# # additional_columns on an individual call to attach data (like your
# # own trace ID) to just that row; it takes precedence over the
# # static defaults above on a key conflict.
# response = google_client.answer_query(
# serving_config: "projects/.../servingConfigs/default",
# query: { text: "What is the meaning of life?" }
# query: { text: "What is the meaning of life?" },
# additional_columns: { trace_id: "abc-123" }
# )
class GoogleConversationalSearchTracer
# Row keys computed by this tracer. Any key in a caller-supplied
# additional_columns hash matching one of these is dropped, so custom
# data can never overwrite core trace fields.
RESERVED_ROW_KEYS = [:query, :answer, :latency_ms, :timestamp, :metadata, :steps, :context, :session_id, :user_id].freeze

# Enable tracing on a Google ConversationalSearchService client
#
# @param client [Google::Cloud::DiscoveryEngine::V1::ConversationalSearchService::Client]
Expand All @@ -42,8 +52,12 @@ class GoogleConversationalSearchTracer
# Optional session ID to use for all traces. Takes precedence over auto-extracted sessions.
# @param user_id [String, nil]
# Optional user ID to use for all traces.
# @param additional_columns [Hash, nil]
# Optional static column values merged into every trace sent through this client (e.g. `{ environment: 'production' }`).
# A value passed to an individual answer_query call takes precedence over these on a key conflict. Keys colliding
# with a reserved row column (query, answer, latency_ms, timestamp, metadata, steps, context, session_id, user_id) are dropped.
# @return [void]
def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil)
def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil, additional_columns: {})
# Store original method reference
original_answer_query = client.method(:answer_query)

Expand All @@ -52,6 +66,10 @@ def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session
# Capture start time
start_time = Time.now

# Extract per-call additional columns before forwarding to the
# real client; Google's client never sees this key
call_additional_columns = kwargs.delete(:additional_columns)

# Execute the original method
response = original_answer_query.call(*args, **kwargs, &block)

Expand All @@ -69,7 +87,9 @@ def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session
openlayer_client: openlayer_client,
inference_pipeline_id: inference_pipeline_id,
session_id: session_id,
user_id: user_id
user_id: user_id,
additional_columns: additional_columns,
call_additional_columns: call_additional_columns
)
rescue StandardError => e
# Never break the user's application due to tracing errors
Expand All @@ -95,8 +115,10 @@ def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session
# @param inference_pipeline_id [String] Pipeline ID
# @param session_id [String, nil] Optional session ID (takes precedence over auto-extracted)
# @param user_id [String, nil] Optional user ID
# @param additional_columns [Hash, nil] Optional static column values (see {.trace_client})
# @param call_additional_columns [Hash, nil] Optional per-call column values; takes precedence over additional_columns
# @return [void]
def self.send_trace(args:, kwargs:, response:, start_time:, end_time:, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil)
def self.send_trace(args:, kwargs:, response:, start_time:, end_time:, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil, additional_columns: {}, call_additional_columns: {})
# Calculate latency
latency_ms = ((end_time - start_time) * 1000).round(2)

Expand Down Expand Up @@ -199,6 +221,12 @@ def self.send_trace(args:, kwargs:, response:, start_time:, end_time:, openlayer
trace_data[:config][:userIdColumnName] = "user_id"
end

# Merge additional columns (per-call values take precedence over
# static defaults; keys colliding with reserved row columns are
# dropped so custom data can never corrupt core trace fields)
extra_columns = resolve_additional_columns(additional_columns, call_additional_columns)
trace_data[:rows][0].merge!(extra_columns) unless extra_columns.empty?

# Send to Openlayer
openlayer_client
.inference_pipelines
Expand Down Expand Up @@ -594,6 +622,45 @@ def self.extract_query_understanding_info(answer)
nil
end

# Merge static and per-call additional columns into a single Hash of
# extra row columns. Call-level values take precedence over static
# ones on key conflict, and any key colliding with a reserved row
# column is dropped.
#
# @param static_columns [Object] Value passed to trace_client (expected Hash)
# @param call_columns [Object] Value passed to an individual answer_query call (expected Hash)
# @return [Hash] Extra columns safe to merge onto a trace row
def self.resolve_additional_columns(static_columns, call_columns)
merged = normalize_additional_columns(static_columns).merge(normalize_additional_columns(call_columns))

merged.each_with_object({}) do |(key, value), result|
if RESERVED_ROW_KEYS.include?(key)
warn_if_debug("[Openlayer] additional_columns key :#{key} collides with a reserved column and was ignored")
else
result[key] = value
end
end
end

# Normalize an additional_columns value into a Hash with Symbol keys.
# Non-Hash input (or a key that can't be a Symbol) is dropped rather
# than raising, so a caller mistake can never break tracing.
#
# @param columns [Object] Expected to be a Hash of column name => value
# @return [Hash]
def self.normalize_additional_columns(columns)
return {} unless columns.is_a?(Hash)

columns.each_with_object({}) do |(key, value), result|
next unless key.respond_to?(:to_sym)

result[key.to_sym] = value
end
rescue StandardError => e
warn_if_debug("[Openlayer] Failed to normalize additional columns: #{e.message}")
{}
end

# Safely extract a field from an object
#
# @param obj [Object] Object to extract from
Expand Down Expand Up @@ -659,6 +726,8 @@ def self.warn_if_debug(message)
:extract_session,
:extract_user_pseudo_id,
:extract_query_understanding_info,
:resolve_additional_columns,
:normalize_additional_columns,
:safe_extract,
:safe_count,
:extract_timestamp
Expand Down
6 changes: 4 additions & 2 deletions rbi/openlayer/integrations.rbi
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ module Openlayer
openlayer_client: Openlayer::Client,
inference_pipeline_id: String,
session_id: T.nilable(String),
user_id: T.nilable(String)
user_id: T.nilable(String),
additional_columns: T::Hash[Symbol, T.untyped]
).void
end
def self.trace_client(
client,
openlayer_client:,
inference_pipeline_id:,
session_id: nil,
user_id: nil
user_id: nil,
additional_columns: {}
)
end
end
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# frozen_string_literal: true

require_relative "../test_helper"
require_relative "../../../lib/openlayer/integrations/google_conversational_search_tracer"

module Openlayer
module Test
module Integrations
end
end
end

class Openlayer::Test::Integrations::GoogleConversationalSearchTracerTest < Minitest::Test
Tracer = Openlayer::Integrations::GoogleConversationalSearchTracer

class FakeAnswer
attr_reader :answer_text

def initialize(answer_text)
@answer_text = answer_text
end
end

class FakeResponse
attr_reader :answer

def initialize(answer_text)
@answer = FakeAnswer.new(answer_text)
end
end

class FakeGoogleClient
def answer_query(serving_config:, query:) # rubocop:disable Lint/UnusedMethodArgument
FakeResponse.new("hi")
end
end

class FakeDataResource
attr_reader :calls

def initialize
@calls = []
end

def stream(inference_pipeline_id, **trace_data)
@calls << {inference_pipeline_id: inference_pipeline_id}.merge(trace_data)
end
end

class FakeInferencePipelines
attr_reader :data

def initialize(data)
@data = data
end
end

class FakeOpenlayerClient
attr_reader :inference_pipelines

def initialize
@data = FakeDataResource.new
@inference_pipelines = FakeInferencePipelines.new(@data)
end

def last_row
@data.calls.last[:rows][0]
end
end

def setup
@openlayer_client = FakeOpenlayerClient.new
@start_time = Time.now
@end_time = @start_time + 1
end

def trace_row(**overrides)
defaults = {
args: [],
kwargs: {query: "hello"},
response: FakeResponse.new("hi"),
start_time: @start_time,
end_time: @end_time,
openlayer_client: @openlayer_client,
inference_pipeline_id: "pipeline-id"
}

Tracer.send_trace(**defaults, **overrides)
@openlayer_client.last_row
end

def test_static_additional_columns_appear_on_the_row
row = trace_row(additional_columns: {environment: "production"})

assert_equal("production", row[:environment])
end

def test_per_call_additional_columns_appear_on_the_row
row = trace_row(call_additional_columns: {trace_id: "abc-123"})

assert_equal("abc-123", row[:trace_id])
end

def test_per_call_additional_columns_override_static_on_conflict
row = trace_row(
additional_columns: {trace_id: "static-value"},
call_additional_columns: {trace_id: "call-value"}
)

assert_equal("call-value", row[:trace_id])
end

def test_reserved_keys_are_dropped_even_as_string_keys
row = trace_row(additional_columns: {"answer" => "hijacked", trace_id: "abc-123"})

assert_equal("hi", row[:answer])
assert_equal("abc-123", row[:trace_id])
end

def test_non_hash_additional_columns_does_not_raise
row = trace_row(additional_columns: "not-a-hash", call_additional_columns: nil)

assert_equal("hi", row[:answer])
end

def test_trace_client_strips_additional_columns_before_forwarding_to_google_client
google_client = FakeGoogleClient.new

Tracer.trace_client(
google_client,
openlayer_client: @openlayer_client,
inference_pipeline_id: "pipeline-id"
)

response = google_client.answer_query(
serving_config: "config",
query: "hello",
additional_columns: {trace_id: "abc-123"}
)

assert_equal("hi", response.answer.answer_text)
assert_equal("abc-123", @openlayer_client.last_row[:trace_id])
end
end