diff --git a/examples/google_tracer.rb b/examples/google_tracer.rb index 0f378eb..3c6ac3d 100755 --- a/examples/google_tracer.rb +++ b/examples/google_tracer.rb @@ -7,6 +7,7 @@ # Add lib directory to load path $LOAD_PATH.unshift(File.expand_path("../lib", __dir__)) +require "securerandom" require "openlayer" require "openlayer/integrations/google_conversational_search_tracer" require "google/cloud/discovery_engine/v1" @@ -19,19 +20,25 @@ api_key: ENV["OPENLAYER_API_KEY"] ) -# Enable tracing - this patches the client to send all queries to Openlayer +# Enable tracing - this patches the client to send all queries to Openlayer. +# additional_columns here is a static default applied to every trace sent +# through this client. Openlayer::Integrations::GoogleConversationalSearchTracer.trace_client( google_client, openlayer_client: openlayer, - inference_pipeline_id: ENV["OPENLAYER_INFERENCE_PIPELINE_ID"] + inference_pipeline_id: ENV["OPENLAYER_INFERENCE_PIPELINE_ID"], + additional_columns: {environment: "production"} ) # Use the client normally - all answer_query calls are now automatically traced! +# additional_columns here is per-call; it takes precedence over the static +# default above on a key conflict. response = google_client.answer_query( serving_config: ENV["GOOGLE_SERVING_CONFIG"], query: Google::Cloud::DiscoveryEngine::V1::Query.new( text: "What is the meaning of life?" - ) + ), + additional_columns: {trace_id: SecureRandom.uuid} ) puts "Answer: #{response.answer.answer_text}" diff --git a/lib/openlayer/integrations/google_conversational_search_tracer.rb b/lib/openlayer/integrations/google_conversational_search_tracer.rb index aef47a6..23ed398 100644 --- a/lib/openlayer/integrations/google_conversational_search_tracer.rb +++ b/lib/openlayer/integrations/google_conversational_search_tracer.rb @@ -21,15 +21,25 @@ module Integrations # Openlayer::Integrations::GoogleConversationalSearchTracer.trace_client( # google_client, # openlayer_client: openlayer, - # inference_pipeline_id: 'your-pipeline-id' + # inference_pipeline_id: 'your-pipeline-id', + # additional_columns: { environment: 'production' } # ) # - # # Now all answer_query calls are automatically traced + # # Now all answer_query calls are automatically traced! Pass + # # additional_columns on an individual call to attach data (like your + # # own trace ID) to just that row; it takes precedence over the + # # static defaults above on a key conflict. # response = google_client.answer_query( # serving_config: "projects/.../servingConfigs/default", - # query: { text: "What is the meaning of life?" } + # query: { text: "What is the meaning of life?" }, + # additional_columns: { trace_id: "abc-123" } # ) class GoogleConversationalSearchTracer + # Row keys computed by this tracer. Any key in a caller-supplied + # additional_columns hash matching one of these is dropped, so custom + # data can never overwrite core trace fields. + RESERVED_ROW_KEYS = [:query, :answer, :latency_ms, :timestamp, :metadata, :steps, :context, :session_id, :user_id].freeze + # Enable tracing on a Google ConversationalSearchService client # # @param client [Google::Cloud::DiscoveryEngine::V1::ConversationalSearchService::Client] @@ -42,8 +52,12 @@ class GoogleConversationalSearchTracer # Optional session ID to use for all traces. Takes precedence over auto-extracted sessions. # @param user_id [String, nil] # Optional user ID to use for all traces. + # @param additional_columns [Hash, nil] + # Optional static column values merged into every trace sent through this client (e.g. `{ environment: 'production' }`). + # A value passed to an individual answer_query call takes precedence over these on a key conflict. Keys colliding + # with a reserved row column (query, answer, latency_ms, timestamp, metadata, steps, context, session_id, user_id) are dropped. # @return [void] - def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil) + def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil, additional_columns: {}) # Store original method reference original_answer_query = client.method(:answer_query) @@ -52,6 +66,10 @@ def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session # Capture start time start_time = Time.now + # Extract per-call additional columns before forwarding to the + # real client; Google's client never sees this key + call_additional_columns = kwargs.delete(:additional_columns) + # Execute the original method response = original_answer_query.call(*args, **kwargs, &block) @@ -69,7 +87,9 @@ def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session openlayer_client: openlayer_client, inference_pipeline_id: inference_pipeline_id, session_id: session_id, - user_id: user_id + user_id: user_id, + additional_columns: additional_columns, + call_additional_columns: call_additional_columns ) rescue StandardError => e # Never break the user's application due to tracing errors @@ -95,8 +115,10 @@ def self.trace_client(client, openlayer_client:, inference_pipeline_id:, session # @param inference_pipeline_id [String] Pipeline ID # @param session_id [String, nil] Optional session ID (takes precedence over auto-extracted) # @param user_id [String, nil] Optional user ID + # @param additional_columns [Hash, nil] Optional static column values (see {.trace_client}) + # @param call_additional_columns [Hash, nil] Optional per-call column values; takes precedence over additional_columns # @return [void] - def self.send_trace(args:, kwargs:, response:, start_time:, end_time:, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil) + def self.send_trace(args:, kwargs:, response:, start_time:, end_time:, openlayer_client:, inference_pipeline_id:, session_id: nil, user_id: nil, additional_columns: {}, call_additional_columns: {}) # Calculate latency latency_ms = ((end_time - start_time) * 1000).round(2) @@ -199,6 +221,12 @@ def self.send_trace(args:, kwargs:, response:, start_time:, end_time:, openlayer trace_data[:config][:userIdColumnName] = "user_id" end + # Merge additional columns (per-call values take precedence over + # static defaults; keys colliding with reserved row columns are + # dropped so custom data can never corrupt core trace fields) + extra_columns = resolve_additional_columns(additional_columns, call_additional_columns) + trace_data[:rows][0].merge!(extra_columns) unless extra_columns.empty? + # Send to Openlayer openlayer_client .inference_pipelines @@ -594,6 +622,45 @@ def self.extract_query_understanding_info(answer) nil end + # Merge static and per-call additional columns into a single Hash of + # extra row columns. Call-level values take precedence over static + # ones on key conflict, and any key colliding with a reserved row + # column is dropped. + # + # @param static_columns [Object] Value passed to trace_client (expected Hash) + # @param call_columns [Object] Value passed to an individual answer_query call (expected Hash) + # @return [Hash] Extra columns safe to merge onto a trace row + def self.resolve_additional_columns(static_columns, call_columns) + merged = normalize_additional_columns(static_columns).merge(normalize_additional_columns(call_columns)) + + merged.each_with_object({}) do |(key, value), result| + if RESERVED_ROW_KEYS.include?(key) + warn_if_debug("[Openlayer] additional_columns key :#{key} collides with a reserved column and was ignored") + else + result[key] = value + end + end + end + + # Normalize an additional_columns value into a Hash with Symbol keys. + # Non-Hash input (or a key that can't be a Symbol) is dropped rather + # than raising, so a caller mistake can never break tracing. + # + # @param columns [Object] Expected to be a Hash of column name => value + # @return [Hash] + def self.normalize_additional_columns(columns) + return {} unless columns.is_a?(Hash) + + columns.each_with_object({}) do |(key, value), result| + next unless key.respond_to?(:to_sym) + + result[key.to_sym] = value + end + rescue StandardError => e + warn_if_debug("[Openlayer] Failed to normalize additional columns: #{e.message}") + {} + end + # Safely extract a field from an object # # @param obj [Object] Object to extract from @@ -659,6 +726,8 @@ def self.warn_if_debug(message) :extract_session, :extract_user_pseudo_id, :extract_query_understanding_info, + :resolve_additional_columns, + :normalize_additional_columns, :safe_extract, :safe_count, :extract_timestamp diff --git a/rbi/openlayer/integrations.rbi b/rbi/openlayer/integrations.rbi index 421bce5..2cf90e8 100644 --- a/rbi/openlayer/integrations.rbi +++ b/rbi/openlayer/integrations.rbi @@ -10,7 +10,8 @@ module Openlayer openlayer_client: Openlayer::Client, inference_pipeline_id: String, session_id: T.nilable(String), - user_id: T.nilable(String) + user_id: T.nilable(String), + additional_columns: T::Hash[Symbol, T.untyped] ).void end def self.trace_client( @@ -18,7 +19,8 @@ module Openlayer openlayer_client:, inference_pipeline_id:, session_id: nil, - user_id: nil + user_id: nil, + additional_columns: {} ) end end diff --git a/test/openlayer/integrations/google_conversational_search_tracer_test.rb b/test/openlayer/integrations/google_conversational_search_tracer_test.rb new file mode 100644 index 0000000..05620c0 --- /dev/null +++ b/test/openlayer/integrations/google_conversational_search_tracer_test.rb @@ -0,0 +1,144 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "../../../lib/openlayer/integrations/google_conversational_search_tracer" + +module Openlayer + module Test + module Integrations + end + end +end + +class Openlayer::Test::Integrations::GoogleConversationalSearchTracerTest < Minitest::Test + Tracer = Openlayer::Integrations::GoogleConversationalSearchTracer + + class FakeAnswer + attr_reader :answer_text + + def initialize(answer_text) + @answer_text = answer_text + end + end + + class FakeResponse + attr_reader :answer + + def initialize(answer_text) + @answer = FakeAnswer.new(answer_text) + end + end + + class FakeGoogleClient + def answer_query(serving_config:, query:) # rubocop:disable Lint/UnusedMethodArgument + FakeResponse.new("hi") + end + end + + class FakeDataResource + attr_reader :calls + + def initialize + @calls = [] + end + + def stream(inference_pipeline_id, **trace_data) + @calls << {inference_pipeline_id: inference_pipeline_id}.merge(trace_data) + end + end + + class FakeInferencePipelines + attr_reader :data + + def initialize(data) + @data = data + end + end + + class FakeOpenlayerClient + attr_reader :inference_pipelines + + def initialize + @data = FakeDataResource.new + @inference_pipelines = FakeInferencePipelines.new(@data) + end + + def last_row + @data.calls.last[:rows][0] + end + end + + def setup + @openlayer_client = FakeOpenlayerClient.new + @start_time = Time.now + @end_time = @start_time + 1 + end + + def trace_row(**overrides) + defaults = { + args: [], + kwargs: {query: "hello"}, + response: FakeResponse.new("hi"), + start_time: @start_time, + end_time: @end_time, + openlayer_client: @openlayer_client, + inference_pipeline_id: "pipeline-id" + } + + Tracer.send_trace(**defaults, **overrides) + @openlayer_client.last_row + end + + def test_static_additional_columns_appear_on_the_row + row = trace_row(additional_columns: {environment: "production"}) + + assert_equal("production", row[:environment]) + end + + def test_per_call_additional_columns_appear_on_the_row + row = trace_row(call_additional_columns: {trace_id: "abc-123"}) + + assert_equal("abc-123", row[:trace_id]) + end + + def test_per_call_additional_columns_override_static_on_conflict + row = trace_row( + additional_columns: {trace_id: "static-value"}, + call_additional_columns: {trace_id: "call-value"} + ) + + assert_equal("call-value", row[:trace_id]) + end + + def test_reserved_keys_are_dropped_even_as_string_keys + row = trace_row(additional_columns: {"answer" => "hijacked", trace_id: "abc-123"}) + + assert_equal("hi", row[:answer]) + assert_equal("abc-123", row[:trace_id]) + end + + def test_non_hash_additional_columns_does_not_raise + row = trace_row(additional_columns: "not-a-hash", call_additional_columns: nil) + + assert_equal("hi", row[:answer]) + end + + def test_trace_client_strips_additional_columns_before_forwarding_to_google_client + google_client = FakeGoogleClient.new + + Tracer.trace_client( + google_client, + openlayer_client: @openlayer_client, + inference_pipeline_id: "pipeline-id" + ) + + response = google_client.answer_query( + serving_config: "config", + query: "hello", + additional_columns: {trace_id: "abc-123"} + ) + + assert_equal("hi", response.answer.answer_text) + assert_equal("abc-123", @openlayer_client.last_row[:trace_id]) + end +end